diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index b3e216dc4f..ac9322d651 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -440,6 +440,29 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, return heuristicResult.algo; } +// Device helper: compute the element offset for tensor `idx` given shape metadata. +// Three cases: +// 1. Explicit per-tensor offset array provided → use it directly. +// 2. Per-tensor first/last dims provided but no offsets → cumulative sum of (first*last) products. +// 3. Fully uniform shapes → idx * uniform_first * uniform_last. +__forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorShapeInfo &meta, + size_t idx) { + if (meta.offsets) { + return meta.offsets[idx]; + } else if (meta.first_dims != nullptr || meta.last_dims != nullptr) { + // offset[i] = sum_{j < i} (first_dims[j] * last_dims[j]) + int64_t cumsum = 0; + for (size_t i = 0; i < idx; i++) { + int64_t f = meta.first_dims ? meta.first_dims[i] : meta.uniform_first; + int64_t l = meta.last_dims ? meta.last_dims[i] : meta.uniform_last; + cumsum += f * l; + } + return cumsum; + } else { + return static_cast(idx) * meta.uniform_first * meta.uniform_last; + } +} + // Single kernel that sets up all GEMM parameters. // Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix dimensions, // but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. @@ -464,15 +487,11 @@ __global__ void setup_grouped_gemm_kernel( int64_t d_first = D_meta.first_dims ? D_meta.first_dims[idx] : D_meta.uniform_first; int64_t d_last = D_meta.last_dims ? D_meta.last_dims[idx] : D_meta.uniform_last; - // Compute offsets (from array or compute from uniform dims) - int64_t a_offset = - A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); - int64_t b_offset = - B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); - int64_t c_offset = - C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); - int64_t d_offset = - D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); + // Compute offsets (from explicit array, cumulative from per-tensor dims, or uniform) + int64_t a_offset = compute_grouped_tensor_offset(A_meta, idx); + int64_t b_offset = compute_grouped_tensor_offset(B_meta, idx); + int64_t c_offset = compute_grouped_tensor_offset(C_meta, idx); + int64_t d_offset = compute_grouped_tensor_offset(D_meta, idx); // Compute data pointers A_ptrs[idx] = a_base + a_offset * a_elem_size; @@ -487,9 +506,8 @@ __global__ void setup_grouped_gemm_kernel( a_cols[idx] = static_cast(a_first); b_rows[idx] = static_cast(b_last); b_cols[idx] = static_cast(b_first); - // For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N). - d_rows[idx] = static_cast(d_first); - d_cols[idx] = static_cast(d_last); + d_rows[idx] = static_cast(d_last); + d_cols[idx] = static_cast(d_first); // Fill alpha/beta pointers (per-matrix) alpha_ptrs[idx] = alpha_ptr + idx; @@ -535,6 +553,10 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { } // namespace +size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors) { + return grouped_gemm_setup_workspace_size(num_tensors); +} + void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, @@ -642,4 +664,27 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); } +size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors) { + NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ", + CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); + return 0; +} + #endif // CUBLAS_VERSION >= 130200 + +namespace { + +__global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) dst[idx] = static_cast(src[idx]); +} + +} // namespace + +void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) { + if (n == 0) return; + const int threads = 256; + const int blocks = static_cast((n + threads - 1) / threads); + convert_int32_to_int64_kernel<<>>(src, dst, n); + NVTE_CHECK_CUDA(cudaGetLastError()); +} diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 7403448722..ec370d46d1 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -329,6 +329,31 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * - Shape compatibility: if transa=false, transb=false: * - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i]) */ +/*! \brief Return the required size in bytes for the setup workspace of grouped GEMM. + * + * The setup workspace stores pointer arrays and per-matrix dimension arrays used + * by the grouped GEMM kernel. Its size depends only on the number of tensors (GEMMs) + * in the group and is independent of matrix dimensions. + * + * Pass the result as the size of the workspace_setup tensor in nvte_grouped_gemm. + * + * \param[in] num_tensors Number of tensors (GEMMs) in the group. + * \return Required size in bytes for workspace_setup. + */ +size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors); + +/*! \brief Convert a device array of int32 values to int64 values. + * + * Useful for preparing group_sizes for nvte_grouped_gemm when the caller + * holds int32 sizes and needs int64 values on the device. + * + * \param[in] src Device pointer to source int32 array. + * \param[out] dst Device pointer to destination int64 array. + * \param[in] n Number of elements. + * \param[in] stream CUDA stream. + */ +void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream); + void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index a34cb030bf..283b96d995 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -24,6 +24,7 @@ get_device_compute_capability, initialize_cgemm_communicator, get_cgemm_num_max_streams, + get_grouped_gemm_setup_workspace_size, ) from .base import BasePrimitive, register_primitive @@ -591,7 +592,7 @@ def lowering( assert_cublas_requirements( scaling_mode, lhs_contracting_size, - "LHS", + f"LHS {lhs_aval.shape} with contracting dims {lhs_cdims}", ) rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) rhs_contracting_size = ( @@ -602,7 +603,7 @@ def lowering( assert_cublas_requirements( scaling_mode, rhs_contracting_size, - "RHS", + f"RHS {rhs_aval.shape} with contracting dims {rhs_cdims}", ) args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) @@ -1421,9 +1422,188 @@ def impl( register_primitive(GroupedGemmCopySizesPrimitive) +class GroupedGemmCudaGraphablePrimitive(BasePrimitive): + """ + Primitive for grouped GEMM using nvte_grouped_gemm (cuda-graphable, BF16 only). + """ + + name = "te_grouped_gemm_cuda_graphable_ffi" + multiple_results = True + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + lhs_data_aval, + lhs_scale_inv_aval, + rhs_data_aval, + rhs_scale_inv_aval, + bias_aval, + group_sizes_aval, + alpha, + beta, + *, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + use_async_d2h_group_sizes, + ): + """ + Grouped GEMM operation (cuda-graphable via nvte_grouped_gemm). + + Args: + lhs_data: Left-hand side input matrix data, 1D flattened array + lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array + rhs_data: Right-hand side input matrix data, 1D flattened array + rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array + bias: Bias matrix of shape (G, N) + group_sizes: 1D int32 array containing the sizes of each group + alpha: Per-group alpha scaling factors (float32) + beta: Per-group beta scaling factors (float32) + M: Number of rows in the output matrix + N: Number of columns in the output matrix + K: Number of columns in the left-hand side matrix + lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed + rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed + scaling_mode: Scaling mode for the GEMM operations + out_dtype: Data type of the output tensors + has_bias: Boolean indicating if bias tensors are provided + is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation + where both lhs and rhs are 2D matrices and output is (G, M, N) + + Returns: + A jnp.ndarray containing the result of the grouped GEMM operation + """ + del lhs_data_aval, rhs_data_aval, bias_aval, alpha, beta + del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes + # TODO(Phuong): move some shape checks from Cpp to here + workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams + workspace_alignment_padding = 256 + tensor_scaling_sinv_aligment = 16 + mxfp8_scaling_sinv_alignment_padding = 256 + # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not + # necessarily 256 bytes aligned, we add some padding to ensure alignment. + workspace_size += workspace_alignment_padding + if scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING.value, + ScalingMode.CURRENT_TENSOR_SCALING.value, + ): + # For tensor scaling, each matrix has a single scale value, but it + # needs to be aligned to 16 bytes for CUDA 12.9.1 and later. + workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_aligment + workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_aligment + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + # We also pad scale_inv swizzle buffers size for 256 bytes alignment. + workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + + workspace_size += get_grouped_gemm_setup_workspace_size(group_sizes_aval.size) + workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) + + # Temporary buffer for int32 → int64 conversion of group_sizes on device. + int64_workspace_size = group_sizes_aval.size * jnp.dtype(jnp.int64).itemsize + int64_workspace_aval = jax.core.ShapedArray(shape=(int64_workspace_size,), dtype=jnp.uint8) + + out_shape = (M, N) + if is_grouped_dense_wgrad: + num_tensors = group_sizes_aval.size + out_shape = (num_tensors, M, N) + out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + return (out_aval, workspace_aval, int64_workspace_aval) + + @staticmethod + def outer_abstract(*args, **kwargs): + (out_aval, _, _) = GroupedGemmCudaGraphablePrimitive.abstract(*args, **kwargs) + return (out_aval,) + + @staticmethod + def lowering( + ctx, + *args, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + use_async_d2h_group_sizes, + ): + del out_dtype + return jax.ffi.ffi_lowering(GroupedGemmCudaGraphablePrimitive.name)( + ctx, + *args, + M=M, + N=N, + K=K, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode.value, + has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, + ) + + @staticmethod + def impl( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + alpha, + beta, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + use_async_d2h_group_sizes, + ): + assert GroupedGemmCudaGraphablePrimitive.inner_primitive is not None + (out, _, _) = GroupedGemmCudaGraphablePrimitive.inner_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + alpha, + beta, + M=M, + N=N, + K=K, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode, + out_dtype=out_dtype, + has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, + ) + return (out,) + + +register_primitive(GroupedGemmCudaGraphablePrimitive) + + class GroupedGemmPrimitive(BasePrimitive): """ - Primitive for grouped GEMM + Primitive for grouped GEMM using nvte_multi_tensor_gemm (supports all scaling modes). """ name = "te_grouped_gemm_ffi" @@ -1903,6 +2083,28 @@ def grouped_gemm_copy_group_sizes( return out +def _can_use_cuda_graphable_grouped_gemm( + scaling_mode: ScalingMode, + dtype: jnp.dtype, + has_bias: bool, +) -> bool: + """Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters.""" + # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy + # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay + # feature-compatible with the main branch. + # Bias can be supported in a kernel or in pure-JAX in the future. + + try: + get_grouped_gemm_setup_workspace_size(1) + except RuntimeError as e: + if "cublas" in str(e).lower(): + # If the workspace size function is not available, it means the cuda-graphable implementation is not available. + return False + raise e + + return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias + + def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], rhs: Union[jnp.ndarray, GroupedScaledTensor1x], @@ -1937,8 +2139,6 @@ def grouped_gemm( lhs: [M, K] or [K, N] rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] """ - # TODO(Phuong): implement the group_offset - group_offset = group_offset or jnp.zeros((1,), jnp.int32) # TODO(Phuong): implement the precision del precision @@ -2074,29 +2274,65 @@ def grouped_gemm( else: assert group_sizes.size == rhs_shape[0] - assert group_offset.size == 1 - has_bias = bias is not None assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias - (out,) = GroupedGemmPrimitive.outer_primitive.bind( - lhs_data, - lhs_scale_inv, - rhs_data, - rhs_scale_inv, - bias, - group_sizes, - group_offset, - M=M, - N=N, - K=K_lhs, - lhs_is_trans=lhs_is_trans, - rhs_is_trans=rhs_is_trans, - scaling_mode=scaling_mode.value, - out_dtype=out_dtype, - has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, - use_async_d2h_group_sizes=use_async_d2h_group_sizes, + use_cuda_graphable = _can_use_cuda_graphable_grouped_gemm( + scaling_mode, lhs_data.dtype, has_bias ) + + if use_cuda_graphable: + assert group_offset is None, ( + "group_offset is not supported in the cuda graphable path and is instead computed" + " internally assuming contiguous grouping. Any padding is included in the group_sizes" + " and padded with zeros to not affect the result of the MoE block." + ) + group_sizes = group_sizes.astype(jnp.int32) + num_gemms = group_sizes.shape[0] + alpha = jnp.ones((num_gemms,), jnp.float32) + beta = jnp.zeros((num_gemms,), jnp.float32) + (out,) = GroupedGemmCudaGraphablePrimitive.outer_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + alpha, + beta, + M=M, + N=N, + K=K_lhs, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode.value, + out_dtype=out_dtype, + has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, + ) + else: + # TODO(Phuong): implement the group_offset + group_offset = group_offset or jnp.zeros((1,), jnp.int32) + assert group_offset.size == 1 + (out,) = GroupedGemmPrimitive.outer_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + group_offset, + M=M, + N=N, + K=K_lhs, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode.value, + out_dtype=out_dtype, + has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, + ) return out diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 1fcecb0e96..bf4e833c89 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -97,7 +97,9 @@ def abstract( dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] out_shape = x_aval.shape - assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert ( + scale_aval is None or scale_aval.dtype == jnp.float32 + ), f"scale must be float32 but received {scale_aval}" if stochastic_rounding: assert ScalingMode( scaling_mode @@ -1213,7 +1215,7 @@ def grouped_quantize( assert n_groups == len( quantizer.quantizers ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}" - scale = jnp.empty((n_groups,), jnp.float32) + scale = jnp.ones((n_groups,), jnp.float32) if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: for i, quantizer_i in enumerate(quantizer.quantizers): diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 1c0bc52b88..bc5bf32ae5 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -138,6 +138,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmCudaGraphableHandler); // Amax XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 4303682bfb..73c4a5ede1 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -409,6 +409,384 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGro .Ret() // dummy_output .Attr("num_gemms")); +class JAXX_GroupedTensorWrapper { + public: + JAXX_GroupedTensorWrapper() = delete; + JAXX_GroupedTensorWrapper(JAXX_Scaling_Mode scaling_mode, size_t num_tensors, + NVTEShape const &dataShape); + JAXX_GroupedTensorWrapper(JAXX_GroupedTensorWrapper const &) = delete; + JAXX_GroupedTensorWrapper &operator=(JAXX_GroupedTensorWrapper const &) = delete; + JAXX_GroupedTensorWrapper(JAXX_GroupedTensorWrapper &&other) noexcept + : m_data_shape(other.m_data_shape), + m_grouped_tensor(other.m_grouped_tensor), + m_data_tensor(other.m_data_tensor), + m_scale_inv_tensor(other.m_scale_inv_tensor), + m_sizes_tensor(other.m_sizes_tensor), + m_offsets_tensor(other.m_offsets_tensor) { + other.m_grouped_tensor = nullptr; + } + JAXX_GroupedTensorWrapper &operator=(JAXX_GroupedTensorWrapper &&) = delete; + ~JAXX_GroupedTensorWrapper(); + + void set_rowwise(Buffer_Type const &data, std::optional const &scale_inv); + void set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, + NVTEGroupedTensorParam group_sizes_param_name); + // Set only group sizes (no offsets); the setup kernel will compute offsets from sizes. + void set_group_sizes_only(const int64_t *sizes_ptr, size_t num_tensors, + NVTEGroupedTensorParam group_sizes_param_name); + + operator NVTEGroupedTensor() const { return m_grouped_tensor; } + NVTEGroupedTensor const &get_grouped_tensor() const; + + private: + NVTEShape m_data_shape{}; + NVTEGroupedTensor m_grouped_tensor{}; + + // Internal tensors. These need to be kept alive as long as the grouped tensor is alive. + NVTEBasicTensor m_data_tensor{}; + NVTEBasicTensor m_scale_inv_tensor{}; + + NVTEBasicTensor m_sizes_tensor{}; + NVTEBasicTensor m_offsets_tensor{}; +}; + +JAXX_GroupedTensorWrapper::JAXX_GroupedTensorWrapper(JAXX_Scaling_Mode scaling_mode, + size_t num_tensors, + NVTEShape const &dataShape) { + m_data_shape = dataShape; + m_grouped_tensor = + nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, dataShape); +} + +JAXX_GroupedTensorWrapper::~JAXX_GroupedTensorWrapper() { + if (m_grouped_tensor != nullptr) { + nvte_destroy_grouped_tensor(m_grouped_tensor); + } +} + +void JAXX_GroupedTensorWrapper::set_rowwise(Buffer_Type const &data, + std::optional const &scale_inv) { + NVTEDType data_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())); + m_data_tensor = + NVTEBasicTensor{reinterpret_cast(data.untyped_data()), data_dtype, m_data_shape}; + + nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedRowwiseData, &m_data_tensor); + + if (scale_inv.has_value()) { + NVTEDType scale_inv_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())); + NVTEShape logical_scale_shape{}; + if (scale_inv->dimensions().size() == 1) { + logical_scale_shape.ndim = 1; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + } else if (scale_inv->dimensions().size() == 2) { + logical_scale_shape.ndim = 2; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + logical_scale_shape.data[1] = scale_inv->dimensions()[1]; + } else { + NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM scale_inv but received ndim=", + scale_inv->dimensions().size()); + } + m_scale_inv_tensor = NVTEBasicTensor{reinterpret_cast(scale_inv->untyped_data()), + scale_inv_dtype, logical_scale_shape}; + nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedRowwiseScaleInv, + &m_scale_inv_tensor); + } +} + +void JAXX_GroupedTensorWrapper::set_group_info(Buffer_Type const &group_sizes, + Buffer_Type const &group_offsets, + NVTEGroupedTensorParam group_sizes_param_name) { + NVTEDType sizes_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(group_sizes.element_type())); + NVTEDType offsets_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(group_offsets.element_type())); + + NVTE_CHECK(sizes_dtype == NVTEDType::kNVTEInt64, "group_sizes must be of type int64."); + NVTE_CHECK(offsets_dtype == NVTEDType::kNVTEInt64, "group_offsets must be of type int64."); + + size_t num_tensors = group_sizes.dimensions()[0]; + NVTE_CHECK(group_sizes.dimensions().size() == 1, + "group_sizes must be a 1D tensor with length equal to the number of tensors."); + NVTE_CHECK(group_offsets.dimensions().size() == 1, + "group_offsets must be a 1D tensor with length equal to the number of tensors."); + NVTE_CHECK(group_offsets.dimensions()[0] == num_tensors, + "group_sizes and group_offsets must have the same number of elements."); + + NVTEShape shape{}; + shape.ndim = 1; + shape.data[0] = num_tensors; + + m_sizes_tensor = NVTEBasicTensor{reinterpret_cast(group_sizes.untyped_data()), + NVTEDType::kNVTEInt64, shape}; + m_offsets_tensor = NVTEBasicTensor{reinterpret_cast(group_offsets.untyped_data()), + NVTEDType::kNVTEInt64, shape}; + + nvte_set_grouped_tensor_param(&m_grouped_tensor, group_sizes_param_name, &m_sizes_tensor); + nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedTensorOffsets, &m_offsets_tensor); +} + +void JAXX_GroupedTensorWrapper::set_group_sizes_only( + const int64_t *sizes_ptr, size_t num_tensors, NVTEGroupedTensorParam group_sizes_param_name) { + NVTEShape shape{}; + shape.ndim = 1; + shape.data[0] = num_tensors; + m_sizes_tensor = NVTEBasicTensor{reinterpret_cast(const_cast(sizes_ptr)), + NVTEDType::kNVTEInt64, shape}; + nvte_set_grouped_tensor_param(&m_grouped_tensor, group_sizes_param_name, &m_sizes_tensor); + // Intentionally no offset tensor: offsets will be computed by the setup kernel. +} + +NVTEGroupedTensor const &JAXX_GroupedTensorWrapper::get_grouped_tensor() const { + return m_grouped_tensor; +} + +JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, + std::optional scale_inv, + JAXX_Scaling_Mode scaling_mode, size_t num_tensors, + NVTEShape const &dataShape) { + JAXX_GroupedTensorWrapper grouped_tensor_wrapper(scaling_mode, num_tensors, dataShape); + if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING) { + scale_inv = std::nullopt; + } + grouped_tensor_wrapper.set_rowwise(data, scale_inv); + + return std::move(grouped_tensor_wrapper); +} + +Error_Type GroupedGemmCudaGraphableFFI( + cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, + Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type alpha, + Buffer_Type beta, Result_Type output, Result_Type workspace, Result_Type int64_workspace, + size_t m, size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, + JAXX_Scaling_Mode scaling_mode, bool has_bias, bool is_grouped_dense_wgrad, + bool use_async_d2h_group_sizes) { + // Notes on matrix layouts and transpose: + // Jax uses row-major data_layout, on entering this function, each input matrix pair: + // A: row-major [m, k] for N - [k, m] for T + // B: row-major [k, n] for N - [n, k] for T + // on exiting this function, JAX expect: + // C: row-major with size [m, n]. + // cuBLAS uses column-major data_layout, in this view, each input matrix pair: + // A: column-major with size [k, m] for T - [m, k] for N + // B: column-major with size [n, k] for T - [k, n] for N + // + // If we call cuBLAS GEMM for A * B, the output will be: + // C: column-major with size [m, n] --> row-major with size [n, m]. + // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. + + int num_streams = nvte_get_num_compute_streams(); + + // Inputs + auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); + auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); + auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv.untyped_data()); + auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv.untyped_data()); + auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type()); + auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type()); + auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type()); + auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type()); + auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; + auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); + + NVTE_CHECK(group_sizes.dimensions().size() == 1); + size_t num_gemms = group_sizes.dimensions()[0]; + + // Convert int32 group_sizes to int64 into the dedicated output buffer. + NVTE_CHECK(group_sizes.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); + nvte_convert_int32_to_int64(reinterpret_cast(group_sizes.untyped_data()), + int64_sizes_ptr, num_gemms, stream); + + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, + "Only non-quantized grouped GEMM is supported in current implementation."); + + // It is weird that TE/Common GEMM only use colwise for MXFP8 + const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); + const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; + const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans; + const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans; + + // Outputs + auto out_ptr = reinterpret_cast(output->untyped_data()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned + auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); + workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); + auto workspace_total_size = product(workspace->dimensions()); + + auto lhs_sinv_size = product(lhs_sinv.dimensions()); + auto rhs_sinv_size = product(rhs_sinv.dimensions()); + const size_t workspace_alignment_padding = 256; + const size_t tensor_scaling_sinv_aligment = 16; + const size_t mxfp8_scaling_sinv_alignment_padding = 256; + auto workspace_size = workspace_total_size - workspace_alignment_padding; + if (is_mxfp8_scaling) { + // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded by 128x4. + workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding); + } else if (is_tensor_scaling) { + // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned + // by 16 bytes to meet the requirement of CUDA 12.9.1 and later. + workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size); + } + workspace_size = workspace_size / num_streams; + auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams; + swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr); + auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; + swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr); + auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr; // Already 256B aligned + auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment; + + size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); + size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); + size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); + size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); + size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); + size_t out_dtype_bytes = te_dtype_bytes(out_dtype); + + NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); + NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, + "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); + + size_t expected_lhs_size = m * k; + size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); + size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); + size_t actual_lhs_size = product(lhs_data.dimensions()); + size_t actual_rhs_size = product(rhs_data.dimensions()); + size_t actual_out_size = product(output->dimensions()); + NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", + expected_lhs_size, ", got ", actual_lhs_size); + if (!is_grouped_dense_wgrad) { + NVTE_CHECK(expected_rhs_size == actual_rhs_size, + "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, + " = ", expected_rhs_size, ", got ", actual_rhs_size); + NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m, + " * ", n, " = ", expected_out_size, ", got ", actual_out_size); + } else { + NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k, + " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size); + NVTE_CHECK(expected_out_size == actual_out_size, + "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n, + " = ", expected_out_size, ", got ", actual_out_size); + } + + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + bool grad = false; + bool accumulate = false; + bool use_split_accumulator = false; + auto bias_shape = std::vector{has_bias ? n : 0}; + const int arch = cuda::sm_arch(); + + if (arch < 100 && is_fp8_gemm) { + NVTE_CHECK(!lhs_is_trans && rhs_is_trans, + "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", + "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); + } + + const size_t workspace_setup_size = nvte_grouped_gemm_setup_workspace_size(num_gemms); + TensorWrapper workspace_setup(workspace_ptr, std::vector{workspace_setup_size}, + DType::kByte); + TensorWrapper workspace_cublas(workspace_ptr + workspace_setup_size, + std::vector{workspace_size}, DType::kByte); + + TensorWrapper alpha_tensor(static_cast(alpha.untyped_data()), + std::vector{num_gemms}, + convert_ffi_datatype_to_te_dtype(alpha.element_type())); + TensorWrapper beta_tensor(static_cast(beta.untyped_data()), + std::vector{num_gemms}, + convert_ffi_datatype_to_te_dtype(beta.element_type())); + + if (is_grouped_dense_wgrad) { + NVTE_CHECK(lhs_is_trans && !rhs_is_trans, + "For grouped dense wgrad, only TN GEMM is supported in TE/JAX currently."); + + //// RHS + NVTEShape rhsShape{.data = {k, n}, .ndim = 2}; + auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); + rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + + //// LHS + NVTEShape lhsShape{.data = {k, m}, .ndim = 2}; + lhs_is_trans = true; + auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + + //// OUTPUT + NVTEShape outShape{.data = {num_gemms * m, n}, .ndim = 2}; + auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, + num_gemms, outShape); + + nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, + alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), + workspace_cublas.data(), + nullptr, // config (use defaults) + stream); + + return ffi_with_cuda_error_check(); + } + + // Nominal case for FWD or DGRAD + + //// RHS + NVTEShape rhsShape{.data = {num_gemms * k, n}, .ndim = 2}; + if (rhs_is_trans) { + rhsShape.data[0] = num_gemms * n; + rhsShape.data[1] = k; + } + auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); + + //// LHS + NVTEShape lhsShape{.data = {m, k}, .ndim = 2}; + if (lhs_is_trans) { + std::swap(lhsShape.data[0], lhsShape.data[1]); + } + auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, + lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); + + //// OUTPUT + NVTEShape outShape{.data = {m, n}, .ndim = 2}; + auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, + num_gemms, outShape); + out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + + nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, + alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), + workspace_cublas.data(), + nullptr, // config (use defaults) + stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmCudaGraphableHandler, GroupedGemmCudaGraphableFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs_data + .Arg() // lhs_sinv + .Arg() // rhs_data + .Arg() // rhs_sinv + .Arg() // bias + .Arg() // group_sizes (int32) + .Arg() // alpha + .Arg() // beta + .Ret() // output + .Ret() // workspace + .Ret() // int64_workspace + .Attr("M") + .Attr("N") + .Attr("K") + .Attr("lhs_is_trans") + .Attr("rhs_is_trans") + .Attr("scaling_mode") + .Attr("has_bias") + .Attr("is_grouped_dense_wgrad") + .Attr("use_async_d2h_group_sizes"), + FFI_CudaGraph_Traits); + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 71de897d9b..87acb287a7 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -7,6 +7,7 @@ #include "../extensions.h" #include "cgemm_helper.h" #include "common/util/cuda_runtime.h" +#include "transformer_engine/gemm.h" namespace transformer_engine { namespace jax { @@ -75,6 +76,9 @@ pybind11::dict Registrations() { dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); + dict["te_grouped_gemm_cuda_graphable_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(GroupedGemmCudaGraphableHandler)); // Amax dict["te_rht_amax_ffi"] = pybind11::dict( @@ -105,6 +109,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); + m.def("get_grouped_gemm_setup_workspace_size", &nvte_grouped_gemm_setup_workspace_size); pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index dd7d2a47ba..98b043ef35 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -4,7 +4,11 @@ """Transformer Engine bindings for JAX""" from .module import DenseGeneral, LayerNorm from .module import LayerNormDenseGeneral, LayerNormMLP -from .module import wrap_function_in_te_state_module, make_dot_general_cls +from .module import ( + wrap_function_in_te_state_module, + make_dot_general_cls, + make_ragged_dot_cls, +) from .transformer import extend_logical_axis_rules from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType @@ -16,6 +20,7 @@ "LayerNormMLP", "wrap_function_in_te_state_module", "make_dot_general_cls", + "make_ragged_dot_cls", "extend_logical_axis_rules", "DotProductAttention", "MultiHeadAttention", diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 3d82d8f0b4..1758e338c4 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -17,7 +17,7 @@ from jax.ad_checkpoint import checkpoint_name -from ..dense import dense +from ..dense import dense, grouped_dense from ..layernorm import canonicalize_norm_type from ..layernorm import layernorm @@ -377,6 +377,7 @@ def generate_quantizer_set( variable_collection: str = None, quantization_checkpoint_name: Optional[str] = None, fp8_recipe=None, + n_groups: int = None, ): """ Generate a set of FP8 meta for a GEMM. @@ -409,6 +410,7 @@ def generate_quantizer_set( fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set, checkpoint_name=quantization_checkpoint_name, + n_groups=n_groups, ) return quantizer_set @@ -1379,12 +1381,13 @@ def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] class TEWrapper(te.flax.module.TransformerEngineBase): """Wrapper Flax module for TransformerEngine quantization support.""" - def generate_quantizer_set(self, postfix: str = ""): + def generate_quantizer_set(self, postfix: str = "", n_groups: int = None): OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" return super().generate_quantizer_set( postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, fp8_recipe=quantization_recipe, + n_groups=n_groups, ) @nn.compact @@ -1438,3 +1441,26 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): ) return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general") + + +def make_ragged_dot_cls(quantization_recipe): + """Creates a ragged dot (grouped GEMM) class for use with TE state module.""" + assert quantization_recipe is None, "Ragged dot grouped GEMM does not support quantization yet" + + def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): + del kwargs # Unused + num_groups = group_sizes.shape[0] + quantizer_set = generate_quantizer_set(n_groups=num_groups) + + out = grouped_dense( + x, + kernel, + group_sizes=group_sizes, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + ) + return out + + return wrap_function_in_te_state_module( + te_grouped_dot_general, quantization_recipe, "ragged_dot" + )()