From 7c46453afb2f786ccb3b889df204b8e3f76aefde Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 13 Feb 2026 09:22:44 -0800 Subject: [PATCH 01/14] Grouped GEMM Signed-off-by: Jeremy Berchtold --- .../common/gemm/cublaslt_gemm.cu | 8 +- .../common/gemm/cublaslt_grouped_gemm.cu | 5 +- transformer_engine/jax/cpp_extensions/gemm.py | 115 +++-- .../jax/cpp_extensions/quantization.py | 10 +- .../jax/csrc/extensions/gemm.cpp | 474 ++++++++---------- transformer_engine/jax/flax/__init__.py | 7 +- transformer_engine/jax/flax/module.py | 27 +- transformer_engine/jax/quantize/quantizer.py | 2 +- 8 files changed, 346 insertions(+), 302 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index c58c3cb47a..241e30764a 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -154,8 +154,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK(ret.lda % 16 == 0, - "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); + // NVTE_CHECK(ret.lda % 16 == 0, + // "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. @@ -245,8 +245,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK(ret.ldb % 16 == 0, - "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); + // NVTE_CHECK(ret.ldb % 16 == 0, + // "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { if (is_B_transposed) { diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index b3e216dc4f..a2434419dc 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -487,9 +487,8 @@ __global__ void setup_grouped_gemm_kernel( a_cols[idx] = static_cast(a_first); b_rows[idx] = static_cast(b_last); b_cols[idx] = static_cast(b_first); - // For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N). - d_rows[idx] = static_cast(d_first); - d_cols[idx] = static_cast(d_last); + d_rows[idx] = static_cast(d_last); + d_cols[idx] = static_cast(d_first); // Fill alpha/beta pointers (per-matrix) alpha_ptrs[idx] = alpha_ptr + idx; diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 71f133bfc4..0949d5e462 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -583,27 +583,27 @@ def lowering( ) lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed) - lhs_contracting_size = ( - reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) - if lhs_transposed - else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) - ) - assert_cublas_requirements( - scaling_mode, - lhs_contracting_size, - "LHS", - ) - rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) - rhs_contracting_size = ( - reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary]) - if rhs_transposed - else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) - ) - assert_cublas_requirements( - scaling_mode, - rhs_contracting_size, - "RHS", - ) + # lhs_contracting_size = ( + # reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) + # if lhs_transposed + # else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) + # ) + # assert_cublas_requirements( + # scaling_mode, + # lhs_contracting_size, + # f"LHS {lhs_aval.shape} with contracting dims {lhs_cdims}", + # ) + # rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) + # rhs_contracting_size = ( + # reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary]) + # if rhs_transposed + # else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) + # ) + # assert_cublas_requirements( + # scaling_mode, + # rhs_contracting_size, + # f"RHS {rhs_aval.shape} with contracting dims {rhs_cdims}", + # ) args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) kwargs = { @@ -936,7 +936,15 @@ def _parse_operand_output_specs( # Non-contracting dims of RHS always needs to be gathered along the FSDP axis rhs_non_cspecs = tuple( - None if spec is not None and spec == gsr.fsdp_resource else spec + ( + None + if spec is not None + and ( + spec == gsr.fsdp_resource + or (isinstance(spec, tuple) and gsr.fsdp_resource in spec) + ) + else spec + ) for spec in rhs_non_cspecs ) @@ -1420,7 +1428,7 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16) + impl_static_args = (10, 11, 12, 13, 14, 15, 16, 17, 18, 19) inner_primitive = None outer_primitive = None @@ -1432,7 +1440,10 @@ def abstract( rhs_scale_inv_aval, bias_aval, group_sizes_aval, - group_offset_aval, + group_offset_lhs_aval, + group_offset_out_aval, + alpha, + beta, *, M, N, @@ -1470,7 +1481,7 @@ def abstract( Returns: A jnp.ndarray containing the result of the grouped GEMM operation """ - del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval + del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_out_aval del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes # TODO(Phuong): move some shape checks from Cpp to here workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams @@ -1492,11 +1503,16 @@ def abstract( # We also pad scale_inv swizzle buffers size for 256 bytes alignment. workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + + workspace_size += ( + 1024 * 1024 + ) # HACK: properly make a workspace_setup buffer in addition to the workspace_cublas buffer workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) out_shape = (M, N) if is_grouped_dense_wgrad: - out_shape = (group_sizes_aval.size, M, N) + num_tensors = group_sizes_aval.size + out_shape = (num_tensors, M, N) out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) return (out_aval, workspace_aval) @@ -1543,7 +1559,10 @@ def impl( rhs_scale_inv, bias, group_sizes, - group_offset, + group_offset_lhs, + group_offset_out, + alpha, + beta, M, N, K, @@ -1563,7 +1582,10 @@ def impl( rhs_scale_inv, bias, group_sizes, - group_offset, + group_offset_lhs, + group_offset_out, + alpha, + beta, M=M, N=N, K=K, @@ -1929,8 +1951,9 @@ def grouped_gemm( lhs: [M, K] or [K, N] rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] """ - # TODO(Phuong): implement the group_offset - group_offset = group_offset or jnp.zeros((1,), jnp.int32) + + assert group_offset is None, "group_offset is not yet implemented" + assert jax.config.jax_enable_x64, "Grouped GEMM currently requires jax_enable_x64 to be True for correct behavior" # TODO(Phuong): implement the precision del precision @@ -2066,12 +2089,35 @@ def grouped_gemm( else: assert group_sizes.size == rhs_shape[0] - assert group_offset.size == 1 - has_bias = bias is not None assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias + # TODO(jberchtold): move the int64 and offset computation to C++ side in a kernel to avoid needing JAX to support int64 + group_sizes = group_sizes.astype(jnp.int64) + # Compute group_offset as cumulative sum of group_sizes, starting with 0 + group_offset = jnp.concatenate( + [jnp.array([0], dtype=jnp.int64), jnp.cumsum(group_sizes, dtype=jnp.int64)[:-1]] + ) + if is_grouped_dense_wgrad: + group_offset_lhs = ( + group_offset * M + ) # Offset is by number of elements total, not number of rows + # HACK: this _out is really the rhs in this case + group_offset_out = ( + group_offset * N + ) # Offset is by number of elements total, not number of rows + else: + group_offset_lhs = ( + group_offset * K_lhs + ) # Offset is by number of elements total, not number of rows + group_offset_out = ( + group_offset * N + ) # Offset is by number of elements total, not number of rows + + num_gemms = group_sizes.shape[0] # Due to interlaced zeros to support int64 + alpha = jnp.ones((num_gemms,), jnp.float32) + beta = jnp.zeros((num_gemms,), jnp.float32) (out,) = GroupedGemmPrimitive.outer_primitive.bind( lhs_data, lhs_scale_inv, @@ -2079,7 +2125,10 @@ def grouped_gemm( rhs_scale_inv, bias, group_sizes, - group_offset, + group_offset_lhs, + group_offset_out, + alpha, + beta, M=M, N=N, K=K_lhs, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 1fcecb0e96..f851bbebc1 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -20,7 +20,6 @@ from .base import BasePrimitive, register_primitive from .misc import ( get_padded_spec, - check_valid_batch_dims, te_dtype_to_jax_dtype, jax_dtype_to_te_dtype, multidim_transpose, @@ -97,7 +96,9 @@ def abstract( dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] out_shape = x_aval.shape - assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert ( + scale_aval is None or scale_aval.dtype == jnp.float32 + ), f"scale must be float32 but received {scale_aval}" if stochastic_rounding: assert ScalingMode( scaling_mode @@ -1213,7 +1214,7 @@ def grouped_quantize( assert n_groups == len( quantizer.quantizers ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}" - scale = jnp.empty((n_groups,), jnp.float32) + scale = jnp.ones((n_groups,), jnp.float32) if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: for i, quantizer_i in enumerate(quantizer.quantizers): @@ -1249,7 +1250,8 @@ def grouped_quantize( ) = GroupedQuantizePrimitive.outer_primitive.bind( x, scale, - group_sizes, + # TODO(jberchtold): Remove this int32 cast once GMM does not require JAX int64 dtype + group_sizes.astype(jnp.int32), out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, q_layout=q_layout, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 4303682bfb..5a17c38d37 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -409,12 +409,146 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGro .Ret() // dummy_output .Attr("num_gemms")); +class JAXX_GroupedTensorWrapper { + public: + JAXX_GroupedTensorWrapper() = delete; + JAXX_GroupedTensorWrapper(JAXX_Scaling_Mode scaling_mode, size_t num_tensors, + NVTEShape const &dataShape); + JAXX_GroupedTensorWrapper(JAXX_GroupedTensorWrapper const &) = delete; + JAXX_GroupedTensorWrapper &operator=(JAXX_GroupedTensorWrapper const &) = delete; + JAXX_GroupedTensorWrapper(JAXX_GroupedTensorWrapper &&other) noexcept + : m_data_shape(other.m_data_shape), + m_grouped_tensor(other.m_grouped_tensor), + m_data_tensor(other.m_data_tensor), + m_scale_inv_tensor(other.m_scale_inv_tensor), + m_sizes_tensor(other.m_sizes_tensor), + m_offsets_tensor(other.m_offsets_tensor) { + other.m_grouped_tensor = nullptr; + } + JAXX_GroupedTensorWrapper &operator=(JAXX_GroupedTensorWrapper &&) = delete; + ~JAXX_GroupedTensorWrapper(); + + void set_rowwise(Buffer_Type const &data, std::optional const &scale_inv); + void set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, + NVTEGroupedTensorParam group_sizes_param_name); + + operator NVTEGroupedTensor() const { return m_grouped_tensor; } + NVTEGroupedTensor const &get_grouped_tensor() const; + + private: + NVTEShape m_data_shape{}; + NVTEGroupedTensor m_grouped_tensor{}; + + // Internal tensors. These need to be kept alive as long as the grouped tensor is alive. + NVTEBasicTensor m_data_tensor{}; + NVTEBasicTensor m_scale_inv_tensor{}; + + NVTEBasicTensor m_sizes_tensor{}; + NVTEBasicTensor m_offsets_tensor{}; +}; + +JAXX_GroupedTensorWrapper::JAXX_GroupedTensorWrapper(JAXX_Scaling_Mode scaling_mode, + size_t num_tensors, + NVTEShape const &dataShape) { + m_data_shape = dataShape; + m_grouped_tensor = + nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, dataShape); +} + +JAXX_GroupedTensorWrapper::~JAXX_GroupedTensorWrapper() { + if (m_grouped_tensor != nullptr) { + nvte_destroy_grouped_tensor(m_grouped_tensor); + } +} + +void JAXX_GroupedTensorWrapper::set_rowwise(Buffer_Type const &data, + std::optional const &scale_inv) { + NVTEDType data_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())); + m_data_tensor = + NVTEBasicTensor{reinterpret_cast(data.untyped_data()), data_dtype, m_data_shape}; + + nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedRowwiseData, &m_data_tensor); + + if (scale_inv.has_value()) { + NVTEDType scale_inv_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())); + NVTEShape logical_scale_shape{}; + if (scale_inv->dimensions().size() == 1) { + logical_scale_shape.ndim = 1; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + } else if (scale_inv->dimensions().size() == 2) { + logical_scale_shape.ndim = 2; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + logical_scale_shape.data[1] = scale_inv->dimensions()[1]; + } else { + NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM scale_inv but received ndim=", + scale_inv->dimensions().size()); + } + m_scale_inv_tensor = NVTEBasicTensor{reinterpret_cast(scale_inv->untyped_data()), + scale_inv_dtype, logical_scale_shape}; + nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedRowwiseScaleInv, + &m_scale_inv_tensor); + } +} + +void JAXX_GroupedTensorWrapper::set_group_info(Buffer_Type const &group_sizes, + Buffer_Type const &group_offsets, + NVTEGroupedTensorParam group_sizes_param_name) { + NVTEDType sizes_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(group_sizes.element_type())); + NVTEDType offsets_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(group_offsets.element_type())); + + NVTE_CHECK(sizes_dtype == NVTEDType::kNVTEInt64, "group_sizes must be of type int64."); + NVTE_CHECK(offsets_dtype == NVTEDType::kNVTEInt64, "group_offsets must be of type int64."); + + size_t num_tensors = group_sizes.dimensions()[0]; + NVTE_CHECK(group_sizes.dimensions().size() == 1, + "group_sizes must be a 1D tensor with length equal to the number of tensors."); + NVTE_CHECK(group_offsets.dimensions().size() == 1, + "group_offsets must be a 1D tensor with length equal to the number of tensors."); + NVTE_CHECK(group_offsets.dimensions()[0] == num_tensors, + "group_sizes and group_offsets must have the same number of elements."); + + NVTEShape shape{}; + shape.ndim = 1; + shape.data[0] = num_tensors; + + m_sizes_tensor = NVTEBasicTensor{reinterpret_cast(group_sizes.untyped_data()), + NVTEDType::kNVTEInt64, shape}; + m_offsets_tensor = NVTEBasicTensor{reinterpret_cast(group_offsets.untyped_data()), + NVTEDType::kNVTEInt64, shape}; + + nvte_set_grouped_tensor_param(&m_grouped_tensor, group_sizes_param_name, &m_sizes_tensor); + nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedTensorOffsets, &m_offsets_tensor); +} + +NVTEGroupedTensor const &JAXX_GroupedTensorWrapper::get_grouped_tensor() const { + return m_grouped_tensor; +} + +JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, + std::optional scale_inv, + JAXX_Scaling_Mode scaling_mode, size_t num_tensors, + NVTEShape const &dataShape) { + JAXX_GroupedTensorWrapper grouped_tensor_wrapper(scaling_mode, num_tensors, dataShape); + if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING) { + scale_inv = std::nullopt; + } + grouped_tensor_wrapper.set_rowwise(data, scale_inv); + + return std::move(grouped_tensor_wrapper); +} + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, - Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, - bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, - bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { + Buffer_Type group_sizes, Buffer_Type group_offset_lhs, + Buffer_Type group_offset_out, Buffer_Type alpha, Buffer_Type beta, + Result_Type output, Result_Type workspace, size_t m, size_t n, size_t k, + bool lhs_is_trans, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, + bool has_bias, bool is_grouped_dense_wgrad, + bool use_async_d2h_group_sizes) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -446,6 +580,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type NVTE_CHECK(group_sizes.dimensions().size() == 1); size_t num_gemms = group_sizes.dimensions()[0]; + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Only non-quantized grouped GEMM is supported in current implementation."); + // It is weird that TE/Common GEMM only use colwise for MXFP8 const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || @@ -491,22 +627,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); size_t out_dtype_bytes = te_dtype_bytes(out_dtype); - if (is_tensor_scaling) { - size_t dpitch = tensor_scaling_sinv_aligment; - size_t spitch = lhs_sinv_dtype_bytes; - size_t width = lhs_sinv_dtype_bytes; - size_t height = lhs_sinv_size; - cudaMemcpy2DAsync(lhs_scatter_aligned_ptr, dpitch, lhs_sinv_ptr, spitch, width, height, - cudaMemcpyDeviceToDevice, stream); - spitch = rhs_sinv_dtype_bytes; - width = rhs_sinv_dtype_bytes; - height = rhs_sinv_size; - cudaMemcpy2DAsync(rhs_scatter_aligned_ptr, dpitch, rhs_sinv_ptr, spitch, width, height, - cudaMemcpyDeviceToDevice, stream); - lhs_sinv_ptr = lhs_scatter_aligned_ptr; - rhs_sinv_ptr = rhs_scatter_aligned_ptr; - } - NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); @@ -533,29 +653,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type " = ", expected_out_size, ", got ", actual_out_size); } - size_t dim_list_bytes = sizeof(int32_t) * num_gemms; - std::vector dim_list_host(num_gemms); - size_t host_num_gemms = 0; - if (use_async_d2h_group_sizes) { - host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); - NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, - " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); - } else { - auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); - } - size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - if (!is_grouped_dense_wgrad) { - NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, - ", got sum(group_sizes)=", sum_group_sizes); - } else { - NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, - ", got sum(group_sizes)=", sum_group_sizes); - } - auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); bool grad = false; bool accumulate = false; @@ -569,221 +666,86 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); } - // These lists are to keep the TensorWrapper objects alive - std::vector lhs_wrapper_list; - std::vector rhs_wrapper_list; - std::vector lhs_swizzle_wrapper_list; // For MXFP8 scale_inv swizzling - std::vector rhs_swizzle_wrapper_list; - std::vector bias_wrapper_list; - std::vector pre_gelu_wrapper_list; - std::vector out_wrapper_list; - std::vector workspace_wrapper_list; - - // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM - std::vector lhs_list; - std::vector rhs_list; - std::vector lhs_swizzle_list; - std::vector rhs_swizzle_list; - std::vector bias_list; - std::vector pre_gelu_list; - std::vector out_list; - std::vector workspace_list; - - size_t lhs_sinv_total_size = 0; - size_t rhs_sinv_total_size = 0; - - std::vector zero_out_dptr_list; - std::vector zero_out_size_list; - - for (size_t i = 0; i < num_gemms; i++) { - // Matrix data shapes - size_t m_i = dim_list_host[i]; - auto lhs_shape_i = std::vector{m_i, k}; - auto rhs_shape_i = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; - auto out_shape_i = std::vector{m_i, n}; - if (is_grouped_dense_wgrad) { - size_t k_i = dim_list_host[i]; - lhs_shape_i[0] = lhs_is_trans ? k_i : m; - lhs_shape_i[1] = lhs_is_trans ? m : k_i; - rhs_shape_i[0] = rhs_is_trans ? n : k_i; - rhs_shape_i[1] = rhs_is_trans ? k_i : n; - out_shape_i[0] = m; - out_shape_i[1] = n; - } - - size_t lhs_size = lhs_shape_i[0] * lhs_shape_i[1]; - size_t rhs_size = rhs_shape_i[0] * rhs_shape_i[1]; - size_t out_size = out_shape_i[0] * out_shape_i[1]; - bool is_empty_gemm = lhs_size == 0 || rhs_size == 0; - if (is_empty_gemm && out_size > 0) { - zero_out_dptr_list.push_back(out_ptr); - zero_out_size_list.push_back(out_size * out_dtype_bytes); - } - - // Set matrix data pointers - auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto out_i = TensorWrapper(static_cast(out_ptr), out_shape_i, out_dtype); - void *lhs_vptr = static_cast(lhs_ptr); - void *rhs_vptr = static_cast(rhs_ptr); - if (rhs_use_colwise) // MatA to enter cuBLAS - rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - else - rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - if (lhs_use_colwise) // MatB to enter cuBLAS - lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - else - lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - - // Set scale_inv shapes and pointers - void *rhs_sinv_vptr = static_cast(rhs_sinv_ptr); - void *lhs_sinv_vptr = static_cast(lhs_sinv_ptr); - size_t lhs_sinv_size_i = 0; - size_t rhs_sinv_size_i = 0; - if (is_tensor_scaling) { - auto tensor_scaling_sinv_shape = std::vector{1}; - // If is_empty_gemm, scale_inv does not have the corresponding value, do not move the pointers - if (!is_empty_gemm) { - lhs_sinv_size_i = tensor_scaling_sinv_aligment / lhs_sinv_dtype_bytes; - rhs_sinv_size_i = tensor_scaling_sinv_aligment / rhs_sinv_dtype_bytes; - } - if (rhs_use_colwise) // MatA to enter cuBLAS - rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape); - else - rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape); - if (lhs_use_colwise) // MatB to enter cuBLAS - lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); - else - lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); - } else if (is_mxfp8_scaling) { - auto lhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - void *swizzled_lhs_sinv_vptr = static_cast(swizzled_lhs_sinv_ptr); - void *swizzled_rhs_sinv_vptr = static_cast(swizzled_rhs_sinv_ptr); - - // {lhs, rhs}_swizzle_i point to unswizzled scale_inv data as input, while {lhs, rhs}_i - // point to swizzled scale_inv data (store on workspace, only used for GEMM). - // Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers - auto lhs_sinv_shape_i = - get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); - auto rhs_sinv_shape_i = - get_block_scale_shape(scaling_mode, rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise); - lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1]; - rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1]; - if (lhs_use_colwise) { - lhs_swizzle_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - lhs_swizzle_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - lhs_i.set_columnwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - } else { - lhs_swizzle_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - } - lhs_i.set_with_gemm_swizzled_scales(true); - if (rhs_use_colwise) { - rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - rhs_i.set_columnwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - } else { - rhs_swizzle_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - } - rhs_i.set_with_gemm_swizzled_scales(true); - - if (!is_empty_gemm) { - lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i)); - rhs_swizzle_wrapper_list.push_back(std::move(rhs_swizzle_i)); - lhs_swizzle_list.push_back(lhs_swizzle_wrapper_list.back().data()); - rhs_swizzle_list.push_back(rhs_swizzle_wrapper_list.back().data()); - } - } else { - NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, - "Unsupported scaling mode: ", static_cast(scaling_mode)); - } - - auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); - auto pre_gelu_i = TensorWrapper(nullptr, std::vector{0}, out_dtype); - - // Update pointer for the next GEMM pair - lhs_ptr += lhs_size * lhs_dtype_bytes; - rhs_ptr += rhs_size * rhs_dtype_bytes; - out_ptr += out_size * out_dtype_bytes; - if (is_fp8_gemm) { - lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; - rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; - lhs_sinv_total_size += lhs_sinv_size_i; - rhs_sinv_total_size += rhs_sinv_size_i; - if (is_mxfp8_scaling) { - swizzled_lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; - swizzled_rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; - } - } - if (has_bias) bias_ptr += n * bias_dtype_bytes; - - // Move objects to the lists to keep them alive - if (is_empty_gemm) continue; - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); - out_wrapper_list.push_back(std::move(out_i)); - bias_wrapper_list.push_back(std::move(bias_i)); - pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); - - lhs_list.push_back(lhs_wrapper_list.back().data()); - rhs_list.push_back(rhs_wrapper_list.back().data()); - bias_list.push_back(bias_wrapper_list.back().data()); - pre_gelu_list.push_back(pre_gelu_wrapper_list.back().data()); - out_list.push_back(out_wrapper_list.back().data()); - } - - auto workspace_shape = std::vector{workspace_size}; - for (int i = 0; i < num_streams; i++) { - auto workspace_i = - TensorWrapper(static_cast(workspace_ptr), workspace_shape, DType::kByte); - workspace_wrapper_list.push_back(std::move(workspace_i)); - workspace_list.push_back(workspace_wrapper_list.back().data()); - workspace_ptr += workspace_size; - } - - if (is_fp8_gemm) { - if (is_tensor_scaling) { - lhs_sinv_size *= tensor_scaling_sinv_aligment; - rhs_sinv_size *= tensor_scaling_sinv_aligment; - } - NVTE_CHECK(lhs_sinv_total_size <= lhs_sinv_size, "Actual total lhs_sinv size ", - lhs_sinv_total_size, " exceeds estimated upper bound ", lhs_sinv_size); - NVTE_CHECK(rhs_sinv_total_size <= rhs_sinv_size, "Actual total rhs_sinv size ", - rhs_sinv_total_size, " exceeds estimated upper bound ", rhs_sinv_size); + constexpr size_t workspace_setup_size = 1024 * 1024; // HACK: dummy workspace for setup + TensorWrapper workspace_setup(workspace_ptr, std::vector{workspace_setup_size}, + DType::kByte); + TensorWrapper workspace_cublas(workspace_ptr + workspace_setup_size, + std::vector{workspace_size}, DType::kByte); + + TensorWrapper alpha_tensor(static_cast(alpha.untyped_data()), + std::vector{num_gemms}, + convert_ffi_datatype_to_te_dtype(alpha.element_type())); + TensorWrapper beta_tensor(static_cast(beta.untyped_data()), + std::vector{num_gemms}, + convert_ffi_datatype_to_te_dtype(beta.element_type())); + + if (is_grouped_dense_wgrad) { + NVTE_CHECK(lhs_is_trans && !rhs_is_trans, + "For grouped dense wgrad, only TN GEMM is supported in TE/JAX currently."); + + //// RHS + NVTEShape rhsShape{.data = {k, n}, .ndim = 2}; + auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); + rhs_tensor.set_group_info(group_sizes, group_offset_out, kNVTEGroupedFirstDims); + + //// LHS + NVTEShape lhsShape{.data = {k, m}, .ndim = 2}; + lhs_is_trans = true; + auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); + lhs_tensor.set_group_info(group_sizes, group_offset_lhs, kNVTEGroupedFirstDims); + + //// OUTPUT + NVTEShape outShape{.data = {num_gemms * m, n}, .ndim = 2}; + auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, + num_gemms, outShape); + + // Output needs to be zeroed in case any group sizes have size zero, meaning the expert weight isn't used in the fwd, meaning the corresponding output gradient should be zero. But using the grouped GEMM, the output buffer contains uninitialized data. + // TODO(jberchtold): make this memset smaller by only zeroing the expert weights that correspond to groups with size zero. + cudaMemsetAsync(output->untyped_data(), 0, output->size_bytes(), stream); + + nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, + alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), + workspace_cublas.data(), + nullptr, // config (use defaults) + stream); + + return ffi_with_cuda_error_check(); } - size_t num_non_empty_gemms = lhs_list.size(); + // Nominal case for FWD or DGRAD - if (is_mxfp8_scaling) { - for (int i = 0; i < num_non_empty_gemms; i++) { - // The i-th GEMM will use the (i % num_streams)-th stream to compute, - // use the same stream to swizzle the scaling factors to make sure that - // the swizzling is done before the GEMM computation starts. - int stream_id = i % num_streams; - cudaStream_t stream_i = nvte_get_compute_stream(stream_id); - nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i); - nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); - } + //// RHS + NVTEShape rhsShape{.data = {num_gemms * k, n}, .ndim = 2}; + if (rhs_is_trans) { + rhsShape.data[0] = num_gemms * n; + rhsShape.data[1] = k; } + auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM - size_t num_zero_outs = zero_out_dptr_list.size(); - for (int i = 0; i < num_zero_outs; i++) { - int stream_id = i % num_streams; - cudaStream_t stream_i = nvte_get_compute_stream(stream_id); - void *dptr = zero_out_dptr_list[i]; - size_t count = zero_out_size_list[i]; - NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i)); + //// LHS + NVTEShape lhsShape{.data = {m, k}, .ndim = 2}; + if (lhs_is_trans) { + std::swap(lhsShape.data[0], lhsShape.data[1]); } - - nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, - grad, workspace_list.data(), accumulate, use_split_accumulator, - num_math_sm, stream); + auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); + lhs_tensor.set_group_info(group_sizes, group_offset_lhs, + lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); + + //// OUTPUT + NVTEShape outShape{.data = {m, n}, .ndim = 2}; + auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, + num_gemms, outShape); + out_tensor.set_group_info(group_sizes, group_offset_out, kNVTEGroupedFirstDims); + + // This memset is required because the group sizes may not fill the full buffer since we overallocate for the worst case. However, in theory unused space on the grouped axis should not be utilizied downstream, but it seems like somehow it is utilized. + // TODO(jberchtold): try removing this + cudaMemsetAsync(output->untyped_data(), 0, output->size_bytes(), stream); + + nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, + alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), + workspace_cublas.data(), + nullptr, // config (use defaults) + stream); return ffi_with_cuda_error_check(); } @@ -797,7 +759,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Arg() // rhs_sinv .Arg() // bias .Arg() // group_sizes - .Arg() // group_offset + .Arg() // group_offset_lhs + .Arg() // group_offset_out + .Arg() // alpha + .Arg() // beta .Ret() // output .Ret() // workspace .Attr("M") @@ -808,7 +773,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("scaling_mode") .Attr("has_bias") .Attr("is_grouped_dense_wgrad") - .Attr("use_async_d2h_group_sizes")); + .Attr("use_async_d2h_group_sizes"), + FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index dd7d2a47ba..98b043ef35 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -4,7 +4,11 @@ """Transformer Engine bindings for JAX""" from .module import DenseGeneral, LayerNorm from .module import LayerNormDenseGeneral, LayerNormMLP -from .module import wrap_function_in_te_state_module, make_dot_general_cls +from .module import ( + wrap_function_in_te_state_module, + make_dot_general_cls, + make_ragged_dot_cls, +) from .transformer import extend_logical_axis_rules from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType @@ -16,6 +20,7 @@ "LayerNormMLP", "wrap_function_in_te_state_module", "make_dot_general_cls", + "make_ragged_dot_cls", "extend_logical_axis_rules", "DotProductAttention", "MultiHeadAttention", diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 3d82d8f0b4..8fba2b1853 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -17,7 +17,7 @@ from jax.ad_checkpoint import checkpoint_name -from ..dense import dense +from ..dense import dense, grouped_dense from ..layernorm import canonicalize_norm_type from ..layernorm import layernorm @@ -377,6 +377,7 @@ def generate_quantizer_set( variable_collection: str = None, quantization_checkpoint_name: Optional[str] = None, fp8_recipe=None, + n_groups: int = None, ): """ Generate a set of FP8 meta for a GEMM. @@ -409,6 +410,7 @@ def generate_quantizer_set( fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set, checkpoint_name=quantization_checkpoint_name, + n_groups=n_groups, ) return quantizer_set @@ -1379,12 +1381,13 @@ def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] class TEWrapper(te.flax.module.TransformerEngineBase): """Wrapper Flax module for TransformerEngine quantization support.""" - def generate_quantizer_set(self, postfix: str = ""): + def generate_quantizer_set(self, postfix: str = "", n_groups: int = None): OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" return super().generate_quantizer_set( postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, fp8_recipe=quantization_recipe, + n_groups=n_groups, ) @nn.compact @@ -1438,3 +1441,23 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): ) return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general") + + +def make_ragged_dot_cls(quantization_recipe): + assert quantization_recipe is None, "Ragged dot grouped GEMM does not support quantization yet" + def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): + num_groups = group_sizes.shape[0] + quantizer_set = generate_quantizer_set(n_groups=num_groups) + + out = grouped_dense( + x, + kernel, + group_sizes=group_sizes, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set + ) + return out + + return wrap_function_in_te_state_module( + te_grouped_dot_general, quantization_recipe, "ragged_dot" + )() diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index f5ca6aeaed..1923932692 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -68,7 +68,7 @@ def compute_scale_from_amax( sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale) assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}" - return sf + return sf.astype(jnp.float32) @register_pytree_node_class From 5a968454f01aa5fa2f3f0c45682cc76079cc38ab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Feb 2026 17:42:29 +0000 Subject: [PATCH 02/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 4 +++- transformer_engine/jax/csrc/extensions/gemm.cpp | 5 +++-- transformer_engine/jax/flax/module.py | 3 ++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0949d5e462..b3267bf182 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1953,7 +1953,9 @@ def grouped_gemm( """ assert group_offset is None, "group_offset is not yet implemented" - assert jax.config.jax_enable_x64, "Grouped GEMM currently requires jax_enable_x64 to be True for correct behavior" + assert ( + jax.config.jax_enable_x64 + ), "Grouped GEMM currently requires jax_enable_x64 to be True for correct behavior" # TODO(Phuong): implement the precision del precision diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 5a17c38d37..1725309869 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -580,7 +580,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type NVTE_CHECK(group_sizes.dimensions().size() == 1); size_t num_gemms = group_sizes.dimensions()[0]; - NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Only non-quantized grouped GEMM is supported in current implementation."); + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, + "Only non-quantized grouped GEMM is supported in current implementation."); // It is weird that TE/Common GEMM only use colwise for MXFP8 const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); @@ -774,7 +775,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("has_bias") .Attr("is_grouped_dense_wgrad") .Attr("use_async_d2h_group_sizes"), - FFI_CudaGraph_Traits); + FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 8fba2b1853..a661f30356 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1445,6 +1445,7 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): def make_ragged_dot_cls(quantization_recipe): assert quantization_recipe is None, "Ragged dot grouped GEMM does not support quantization yet" + def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): num_groups = group_sizes.shape[0] quantizer_set = generate_quantizer_set(n_groups=num_groups) @@ -1454,7 +1455,7 @@ def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwa kernel, group_sizes=group_sizes, contracting_dims=((1,), (1,)), - quantizer_set=quantizer_set + quantizer_set=quantizer_set, ) return out From 49b45fa22f58a4af655aa2567b25e918a45cea5c Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 24 Feb 2026 10:38:43 -0800 Subject: [PATCH 03/14] disable cuda-graph for GMM Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/csrc/extensions/gemm.cpp | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 1725309869..63b7417a7b 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -680,6 +680,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); + fprintf(stderr, "Before GEMM:\n"); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + // fflush(stderr); + if (is_grouped_dense_wgrad) { NVTE_CHECK(lhs_is_trans && !rhs_is_trans, "For grouped dense wgrad, only TN GEMM is supported in TE/JAX currently."); @@ -710,6 +714,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type nullptr, // config (use defaults) stream); + fprintf(stderr, "After GEMM:\n"); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + // fflush(stderr); + return ffi_with_cuda_error_check(); } @@ -748,6 +756,11 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type nullptr, // config (use defaults) stream); + + fprintf(stderr, "After GEMM:\n"); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + // fflush(stderr); + return ffi_with_cuda_error_check(); } @@ -774,8 +787,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("scaling_mode") .Attr("has_bias") .Attr("is_grouped_dense_wgrad") - .Attr("use_async_d2h_group_sizes"), - FFI_CudaGraph_Traits); + .Attr("use_async_d2h_group_sizes")/*, + FFI_CudaGraph_Traits*/); } // namespace jax } // namespace transformer_engine From 593a790b640f9634a5644bf32aeb07c3f3a5251f Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 24 Feb 2026 11:23:30 -0800 Subject: [PATCH 04/14] proper workspace size Signed-off-by: Jeremy Berchtold --- .../common/gemm/cublaslt_grouped_gemm.cu | 4 ++++ .../common/include/transformer_engine/gemm.h | 13 +++++++++++++ transformer_engine/jax/cpp_extensions/gemm.py | 10 +++++++--- transformer_engine/jax/csrc/extensions/gemm.cpp | 2 +- 4 files changed, 25 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index a2434419dc..3f1710b784 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -534,6 +534,10 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { } // namespace +size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors) { + return grouped_gemm_setup_workspace_size(num_tensors); +} + void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 7403448722..805a2f8834 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -329,6 +329,19 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * - Shape compatibility: if transa=false, transb=false: * - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i]) */ +/*! \brief Return the required size in bytes for the setup workspace of grouped GEMM. + * + * The setup workspace stores pointer arrays and per-matrix dimension arrays used + * by the grouped GEMM kernel. Its size depends only on the number of tensors (GEMMs) + * in the group and is independent of matrix dimensions. + * + * Pass the result as the size of the workspace_setup tensor in nvte_grouped_gemm. + * + * \param[in] num_tensors Number of tensors (GEMMs) in the group. + * \return Required size in bytes for workspace_setup. + */ +size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors); + void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index b3267bf182..60dac6a806 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1504,9 +1504,13 @@ def abstract( workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding - workspace_size += ( - 1024 * 1024 - ) # HACK: properly make a workspace_setup buffer in addition to the workspace_cublas buffer + # Setup workspace: 6 pointer arrays (void*, 8 bytes each) + 6 int arrays (4 bytes each), + # aligned to 256 bytes. Layout: [A_ptrs, B_ptrs, C_ptrs, D_ptrs, alpha_ptrs, beta_ptrs, + # a_rows, a_cols, b_rows, b_cols, d_rows, d_cols] + num_tensors = group_sizes_aval.size + setup_workspace_size = 6 * num_tensors * 8 + 6 * num_tensors * 4 # 72 * num_tensors bytes + setup_workspace_size = ((setup_workspace_size + 255) // 256) * 256 # align to 256 bytes + workspace_size += setup_workspace_size workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) out_shape = (M, N) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 63b7417a7b..ba00dd570d 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -667,7 +667,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); } - constexpr size_t workspace_setup_size = 1024 * 1024; // HACK: dummy workspace for setup + const size_t workspace_setup_size = nvte_grouped_gemm_setup_workspace_size(num_gemms); TensorWrapper workspace_setup(workspace_ptr, std::vector{workspace_setup_size}, DType::kByte); TensorWrapper workspace_cublas(workspace_ptr + workspace_setup_size, From ae344615fabe6b8c3578596a323828d1a62553de Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 24 Feb 2026 11:34:28 -0800 Subject: [PATCH 05/14] remove duplicate workspace size logic in Python gemm.py Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 9 ++------- transformer_engine/jax/csrc/extensions/pybind.cpp | 2 ++ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 60dac6a806..ee465b9417 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -24,6 +24,7 @@ get_device_compute_capability, initialize_cgemm_communicator, get_cgemm_num_max_streams, + get_grouped_gemm_setup_workspace_size, ) from .base import BasePrimitive, register_primitive @@ -1504,13 +1505,7 @@ def abstract( workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding - # Setup workspace: 6 pointer arrays (void*, 8 bytes each) + 6 int arrays (4 bytes each), - # aligned to 256 bytes. Layout: [A_ptrs, B_ptrs, C_ptrs, D_ptrs, alpha_ptrs, beta_ptrs, - # a_rows, a_cols, b_rows, b_cols, d_rows, d_cols] - num_tensors = group_sizes_aval.size - setup_workspace_size = 6 * num_tensors * 8 + 6 * num_tensors * 4 # 72 * num_tensors bytes - setup_workspace_size = ((setup_workspace_size + 255) // 256) * 256 # align to 256 bytes - workspace_size += setup_workspace_size + workspace_size += get_grouped_gemm_setup_workspace_size(group_sizes_aval.size) workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) out_shape = (M, N) diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 71de897d9b..d4bb3bab00 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -7,6 +7,7 @@ #include "../extensions.h" #include "cgemm_helper.h" #include "common/util/cuda_runtime.h" +#include "transformer_engine/gemm.h" namespace transformer_engine { namespace jax { @@ -105,6 +106,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); + m.def("get_grouped_gemm_setup_workspace_size", &nvte_grouped_gemm_setup_workspace_size); pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) From 7e99c649a151152b6fc28ab499a335185e5717a1 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 24 Feb 2026 12:07:50 -0800 Subject: [PATCH 06/14] use group_sizes as int32 and handle int64 and offsets inside FFI to avoid enabling JAX x64 globally Signed-off-by: Jeremy Berchtold --- .../common/gemm/cublaslt_grouped_gemm.cu | 51 ++++++++++++++---- .../common/include/transformer_engine/gemm.h | 12 +++++ transformer_engine/jax/cpp_extensions/gemm.py | 53 +++++-------------- .../jax/csrc/extensions/gemm.cpp | 43 +++++++++++---- 4 files changed, 100 insertions(+), 59 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 3f1710b784..292f3d6a21 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -440,6 +440,29 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, return heuristicResult.algo; } +// Device helper: compute the element offset for tensor `idx` given shape metadata. +// Three cases: +// 1. Explicit per-tensor offset array provided → use it directly. +// 2. Per-tensor first/last dims provided but no offsets → cumulative sum of (first*last) products. +// 3. Fully uniform shapes → idx * uniform_first * uniform_last. +__forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorShapeInfo &meta, + size_t idx) { + if (meta.offsets) { + return meta.offsets[idx]; + } else if (meta.first_dims != nullptr || meta.last_dims != nullptr) { + // offset[i] = sum_{j < i} (first_dims[j] * last_dims[j]) + int64_t cumsum = 0; + for (size_t i = 0; i < idx; i++) { + int64_t f = meta.first_dims ? meta.first_dims[i] : meta.uniform_first; + int64_t l = meta.last_dims ? meta.last_dims[i] : meta.uniform_last; + cumsum += f * l; + } + return cumsum; + } else { + return static_cast(idx) * meta.uniform_first * meta.uniform_last; + } +} + // Single kernel that sets up all GEMM parameters. // Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix dimensions, // but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. @@ -464,15 +487,11 @@ __global__ void setup_grouped_gemm_kernel( int64_t d_first = D_meta.first_dims ? D_meta.first_dims[idx] : D_meta.uniform_first; int64_t d_last = D_meta.last_dims ? D_meta.last_dims[idx] : D_meta.uniform_last; - // Compute offsets (from array or compute from uniform dims) - int64_t a_offset = - A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); - int64_t b_offset = - B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); - int64_t c_offset = - C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); - int64_t d_offset = - D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); + // Compute offsets (from explicit array, cumulative from per-tensor dims, or uniform) + int64_t a_offset = compute_grouped_tensor_offset(A_meta, idx); + int64_t b_offset = compute_grouped_tensor_offset(B_meta, idx); + int64_t c_offset = compute_grouped_tensor_offset(C_meta, idx); + int64_t d_offset = compute_grouped_tensor_offset(D_meta, idx); // Compute data pointers A_ptrs[idx] = a_base + a_offset * a_elem_size; @@ -532,12 +551,26 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); } +__global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) dst[idx] = static_cast(src[idx]); +} + } // namespace size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors) { return grouped_gemm_setup_workspace_size(num_tensors); } +void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, + cudaStream_t stream) { + if (n == 0) return; + const int threads = 256; + const int blocks = static_cast((n + threads - 1) / threads); + convert_int32_to_int64_kernel<<>>(src, dst, n); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 805a2f8834..ec370d46d1 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -342,6 +342,18 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor */ size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors); +/*! \brief Convert a device array of int32 values to int64 values. + * + * Useful for preparing group_sizes for nvte_grouped_gemm when the caller + * holds int32 sizes and needs int64 values on the device. + * + * \param[in] src Device pointer to source int32 array. + * \param[out] dst Device pointer to destination int64 array. + * \param[in] n Number of elements. + * \param[in] stream CUDA stream. + */ +void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream); + void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index ee465b9417..36ba54c79f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1429,7 +1429,7 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = (10, 11, 12, 13, 14, 15, 16, 17, 18, 19) + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17) inner_primitive = None outer_primitive = None @@ -1441,8 +1441,6 @@ def abstract( rhs_scale_inv_aval, bias_aval, group_sizes_aval, - group_offset_lhs_aval, - group_offset_out_aval, alpha, beta, *, @@ -1482,7 +1480,7 @@ def abstract( Returns: A jnp.ndarray containing the result of the grouped GEMM operation """ - del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_out_aval + del lhs_data_aval, rhs_data_aval, bias_aval del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes # TODO(Phuong): move some shape checks from Cpp to here workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams @@ -1508,16 +1506,22 @@ def abstract( workspace_size += get_grouped_gemm_setup_workspace_size(group_sizes_aval.size) workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) + # Temporary buffer for int32 → int64 conversion of group_sizes on device. + int64_workspace_size = group_sizes_aval.size * jnp.dtype(jnp.int64).itemsize + int64_workspace_aval = jax.core.ShapedArray( + shape=(int64_workspace_size,), dtype=jnp.uint8 + ) + out_shape = (M, N) if is_grouped_dense_wgrad: num_tensors = group_sizes_aval.size out_shape = (num_tensors, M, N) out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) - return (out_aval, workspace_aval) + return (out_aval, workspace_aval, int64_workspace_aval) @staticmethod def outer_abstract(*args, **kwargs): - (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs) + (out_aval, _, _) = GroupedGemmPrimitive.abstract(*args, **kwargs) return (out_aval,) @staticmethod @@ -1558,8 +1562,6 @@ def impl( rhs_scale_inv, bias, group_sizes, - group_offset_lhs, - group_offset_out, alpha, beta, M, @@ -1574,15 +1576,13 @@ def impl( use_async_d2h_group_sizes, ): assert GroupedGemmPrimitive.inner_primitive is not None - (out, _) = GroupedGemmPrimitive.inner_primitive.bind( + (out, _, _) = GroupedGemmPrimitive.inner_primitive.bind( lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, group_sizes, - group_offset_lhs, - group_offset_out, alpha, beta, M=M, @@ -1952,9 +1952,6 @@ def grouped_gemm( """ assert group_offset is None, "group_offset is not yet implemented" - assert ( - jax.config.jax_enable_x64 - ), "Grouped GEMM currently requires jax_enable_x64 to be True for correct behavior" # TODO(Phuong): implement the precision del precision @@ -2094,29 +2091,9 @@ def grouped_gemm( assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias - # TODO(jberchtold): move the int64 and offset computation to C++ side in a kernel to avoid needing JAX to support int64 - group_sizes = group_sizes.astype(jnp.int64) - # Compute group_offset as cumulative sum of group_sizes, starting with 0 - group_offset = jnp.concatenate( - [jnp.array([0], dtype=jnp.int64), jnp.cumsum(group_sizes, dtype=jnp.int64)[:-1]] - ) - if is_grouped_dense_wgrad: - group_offset_lhs = ( - group_offset * M - ) # Offset is by number of elements total, not number of rows - # HACK: this _out is really the rhs in this case - group_offset_out = ( - group_offset * N - ) # Offset is by number of elements total, not number of rows - else: - group_offset_lhs = ( - group_offset * K_lhs - ) # Offset is by number of elements total, not number of rows - group_offset_out = ( - group_offset * N - ) # Offset is by number of elements total, not number of rows - - num_gemms = group_sizes.shape[0] # Due to interlaced zeros to support int64 + group_sizes = group_sizes.astype(jnp.int32) + + num_gemms = group_sizes.shape[0] alpha = jnp.ones((num_gemms,), jnp.float32) beta = jnp.zeros((num_gemms,), jnp.float32) (out,) = GroupedGemmPrimitive.outer_primitive.bind( @@ -2126,8 +2103,6 @@ def grouped_gemm( rhs_scale_inv, bias, group_sizes, - group_offset_lhs, - group_offset_out, alpha, beta, M=M, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index ba00dd570d..5be78af53b 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -431,6 +431,9 @@ class JAXX_GroupedTensorWrapper { void set_rowwise(Buffer_Type const &data, std::optional const &scale_inv); void set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, NVTEGroupedTensorParam group_sizes_param_name); + // Set only group sizes (no offsets); the setup kernel will compute offsets from sizes. + void set_group_sizes_only(const int64_t *sizes_ptr, size_t num_tensors, + NVTEGroupedTensorParam group_sizes_param_name); operator NVTEGroupedTensor() const { return m_grouped_tensor; } NVTEGroupedTensor const &get_grouped_tensor() const; @@ -524,6 +527,18 @@ void JAXX_GroupedTensorWrapper::set_group_info(Buffer_Type const &group_sizes, nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedTensorOffsets, &m_offsets_tensor); } +void JAXX_GroupedTensorWrapper::set_group_sizes_only( + const int64_t *sizes_ptr, size_t num_tensors, + NVTEGroupedTensorParam group_sizes_param_name) { + NVTEShape shape{}; + shape.ndim = 1; + shape.data[0] = num_tensors; + m_sizes_tensor = NVTEBasicTensor{reinterpret_cast(const_cast(sizes_ptr)), + NVTEDType::kNVTEInt64, shape}; + nvte_set_grouped_tensor_param(&m_grouped_tensor, group_sizes_param_name, &m_sizes_tensor); + // Intentionally no offset tensor: offsets will be computed by the setup kernel. +} + NVTEGroupedTensor const &JAXX_GroupedTensorWrapper::get_grouped_tensor() const { return m_grouped_tensor; } @@ -543,9 +558,9 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type group_offset_lhs, - Buffer_Type group_offset_out, Buffer_Type alpha, Buffer_Type beta, - Result_Type output, Result_Type workspace, size_t m, size_t n, size_t k, + Buffer_Type group_sizes, Buffer_Type alpha, Buffer_Type beta, + Result_Type output, Result_Type workspace, Result_Type int64_workspace, + size_t m, size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { @@ -580,6 +595,13 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type NVTE_CHECK(group_sizes.dimensions().size() == 1); size_t num_gemms = group_sizes.dimensions()[0]; + // Convert int32 group_sizes to int64 into the dedicated output buffer. + NVTE_CHECK(group_sizes.element_type() == xla::ffi::DataType::S32, + "group_sizes must be int32."); + auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); + nvte_convert_int32_to_int64(reinterpret_cast(group_sizes.untyped_data()), + int64_sizes_ptr, num_gemms, stream); + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Only non-quantized grouped GEMM is supported in current implementation."); @@ -691,13 +713,13 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type //// RHS NVTEShape rhsShape{.data = {k, n}, .ndim = 2}; auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - rhs_tensor.set_group_info(group_sizes, group_offset_out, kNVTEGroupedFirstDims); + rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); //// LHS NVTEShape lhsShape{.data = {k, m}, .ndim = 2}; lhs_is_trans = true; auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - lhs_tensor.set_group_info(group_sizes, group_offset_lhs, kNVTEGroupedFirstDims); + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); //// OUTPUT NVTEShape outShape{.data = {num_gemms * m, n}, .ndim = 2}; @@ -737,14 +759,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type std::swap(lhsShape.data[0], lhsShape.data[1]); } auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - lhs_tensor.set_group_info(group_sizes, group_offset_lhs, - lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, + lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); //// OUTPUT NVTEShape outShape{.data = {m, n}, .ndim = 2}; auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape); - out_tensor.set_group_info(group_sizes, group_offset_out, kNVTEGroupedFirstDims); + out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); // This memset is required because the group sizes may not fill the full buffer since we overallocate for the worst case. However, in theory unused space on the grouped axis should not be utilizied downstream, but it seems like somehow it is utilized. // TODO(jberchtold): try removing this @@ -772,13 +794,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Arg() // rhs_data .Arg() // rhs_sinv .Arg() // bias - .Arg() // group_sizes - .Arg() // group_offset_lhs - .Arg() // group_offset_out + .Arg() // group_sizes (int32) .Arg() // alpha .Arg() // beta .Ret() // output .Ret() // workspace + .Ret() // int64_workspace .Attr("M") .Attr("N") .Attr("K") From a661e9e5fd13af8f0737912bc5d81a4638f9020b Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 24 Feb 2026 13:06:49 -0800 Subject: [PATCH 07/14] restore previous non-cuda-graphable grouped GEMM FFI and move new version to a different suffix Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 270 +++++++++-- transformer_engine/jax/csrc/extensions.h | 1 + .../jax/csrc/extensions/gemm.cpp | 438 +++++++++++++++++- .../jax/csrc/extensions/pybind.cpp | 3 + 4 files changed, 652 insertions(+), 60 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 36ba54c79f..708c91ded8 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1422,12 +1422,12 @@ def impl( register_primitive(GroupedGemmCopySizesPrimitive) -class GroupedGemmPrimitive(BasePrimitive): +class GroupedGemmCudaGraphablePrimitive(BasePrimitive): """ - Primitive for grouped GEMM + Primitive for grouped GEMM using nvte_grouped_gemm (cuda-graphable, BF16 only). """ - name = "te_grouped_gemm_ffi" + name = "te_grouped_gemm_cuda_graphable_ffi" multiple_results = True impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17) inner_primitive = None @@ -1456,7 +1456,7 @@ def abstract( use_async_d2h_group_sizes, ): """ - Grouped GEMM operation. + Grouped GEMM operation (cuda-graphable via nvte_grouped_gemm). Args: lhs_data: Left-hand side input matrix data, 1D flattened array @@ -1464,8 +1464,9 @@ def abstract( rhs_data: Right-hand side input matrix data, 1D flattened array rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array bias: Bias matrix of shape (G, N) - group_sizes: 1D array containing the sizes of each group - group_offset: 1D array containing offsets for each group (not yet implemented) + group_sizes: 1D int32 array containing the sizes of each group + alpha: Per-group alpha scaling factors (float32) + beta: Per-group beta scaling factors (float32) M: Number of rows in the output matrix N: Number of columns in the output matrix K: Number of columns in the left-hand side matrix @@ -1521,7 +1522,7 @@ def abstract( @staticmethod def outer_abstract(*args, **kwargs): - (out_aval, _, _) = GroupedGemmPrimitive.abstract(*args, **kwargs) + (out_aval, _, _) = GroupedGemmCudaGraphablePrimitive.abstract(*args, **kwargs) return (out_aval,) @staticmethod @@ -1540,7 +1541,7 @@ def lowering( use_async_d2h_group_sizes, ): del out_dtype - return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( + return jax.ffi.ffi_lowering(GroupedGemmCudaGraphablePrimitive.name)( ctx, *args, M=M, @@ -1575,8 +1576,8 @@ def impl( is_grouped_dense_wgrad, use_async_d2h_group_sizes, ): - assert GroupedGemmPrimitive.inner_primitive is not None - (out, _, _) = GroupedGemmPrimitive.inner_primitive.bind( + assert GroupedGemmCudaGraphablePrimitive.inner_primitive is not None + (out, _, _) = GroupedGemmCudaGraphablePrimitive.inner_primitive.bind( lhs_data, lhs_scale_inv, rhs_data, @@ -1599,6 +1600,174 @@ def impl( return (out,) +register_primitive(GroupedGemmCudaGraphablePrimitive) + + +class GroupedGemmPrimitive(BasePrimitive): + """ + Primitive for grouped GEMM using nvte_multi_tensor_gemm (supports all scaling modes). + """ + + name = "te_grouped_gemm_ffi" + multiple_results = True + impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + lhs_data_aval, + lhs_scale_inv_aval, + rhs_data_aval, + rhs_scale_inv_aval, + bias_aval, + group_sizes_aval, + group_offset_aval, + *, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + use_async_d2h_group_sizes, + ): + """ + Grouped GEMM operation. + + Args: + lhs_data: Left-hand side input matrix data, 1D flattened array + lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array + rhs_data: Right-hand side input matrix data, 1D flattened array + rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array + bias: Bias matrix of shape (G, N) + group_sizes: 1D array containing the sizes of each group + group_offset: 1D array containing offsets for each group (not yet implemented) + M: Number of rows in the output matrix + N: Number of columns in the output matrix + K: Number of columns in the left-hand side matrix + lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed + rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed + scaling_mode: Scaling mode for the GEMM operations + out_dtype: Data type of the output tensors + has_bias: Boolean indicating if bias tensors are provided + is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation + where both lhs and rhs are 2D matrices and output is (G, M, N) + + Returns: + A jnp.ndarray containing the result of the grouped GEMM operation + """ + del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval + del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes + # TODO(Phuong): move some shape checks from Cpp to here + workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams + workspace_alignment_padding = 256 + tensor_scaling_sinv_aligment = 16 + mxfp8_scaling_sinv_alignment_padding = 256 + # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not + # necessarily 256 bytes aligned, we add some padding to ensure alignment. + workspace_size += workspace_alignment_padding + if scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING.value, + ScalingMode.CURRENT_TENSOR_SCALING.value, + ): + # For tensor scaling, each matrix has a single scale value, but it + # needs to be aligned to 16 bytes for CUDA 12.9.1 and later. + workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_aligment + workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_aligment + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + # We also pad scale_inv swizzle buffers size for 256 bytes alignment. + workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) + + out_shape = (M, N) + if is_grouped_dense_wgrad: + out_shape = (group_sizes_aval.size, M, N) + out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + return (out_aval, workspace_aval) + + @staticmethod + def outer_abstract(*args, **kwargs): + (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs) + return (out_aval,) + + @staticmethod + def lowering( + ctx, + *args, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + use_async_d2h_group_sizes, + ): + del out_dtype + return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( + ctx, + *args, + M=M, + N=N, + K=K, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode.value, + has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, + ) + + @staticmethod + def impl( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + group_offset, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + use_async_d2h_group_sizes, + ): + assert GroupedGemmPrimitive.inner_primitive is not None + (out, _) = GroupedGemmPrimitive.inner_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + group_offset, + M=M, + N=N, + K=K, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode, + out_dtype=out_dtype, + has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, + ) + return (out,) + + register_primitive(GroupedGemmPrimitive) @@ -1951,8 +2120,6 @@ def grouped_gemm( rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] """ - assert group_offset is None, "group_offset is not yet implemented" - # TODO(Phuong): implement the precision del precision @@ -2091,29 +2258,60 @@ def grouped_gemm( assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias - group_sizes = group_sizes.astype(jnp.int32) - - num_gemms = group_sizes.shape[0] - alpha = jnp.ones((num_gemms,), jnp.float32) - beta = jnp.zeros((num_gemms,), jnp.float32) - (out,) = GroupedGemmPrimitive.outer_primitive.bind( - lhs_data, - lhs_scale_inv, - rhs_data, - rhs_scale_inv, - bias, - group_sizes, - alpha, - beta, - M=M, - N=N, - K=K_lhs, - lhs_is_trans=lhs_is_trans, - rhs_is_trans=rhs_is_trans, - scaling_mode=scaling_mode.value, - out_dtype=out_dtype, - has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, - use_async_d2h_group_sizes=use_async_d2h_group_sizes, + # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy + # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay + # feature-compatible with the main branch. + _use_cuda_graphable = ( + scaling_mode == ScalingMode.NO_SCALING and lhs_data.dtype == jnp.bfloat16 ) + + if _use_cuda_graphable: + assert group_offset is None, "group_offset is not supported in the cuda graphable path and is instead computed internally assuming contiguous grouping. Any padding is included in the group_sizes and padded with zeros to not affect the result of the MoE block." + group_sizes = group_sizes.astype(jnp.int32) + num_gemms = group_sizes.shape[0] + alpha = jnp.ones((num_gemms,), jnp.float32) + beta = jnp.zeros((num_gemms,), jnp.float32) + (out,) = GroupedGemmCudaGraphablePrimitive.outer_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + alpha, + beta, + M=M, + N=N, + K=K_lhs, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode.value, + out_dtype=out_dtype, + has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, + ) + else: + # TODO(Phuong): implement the group_offset + group_offset = group_offset or jnp.zeros((1,), jnp.int32) + assert group_offset.size == 1 + (out,) = GroupedGemmPrimitive.outer_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + group_offset, + M=M, + N=N, + K=K_lhs, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode.value, + out_dtype=out_dtype, + has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, + ) return out diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 1c0bc52b88..bc5bf32ae5 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -138,6 +138,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmCudaGraphableHandler); // Amax XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 5be78af53b..508b9aeda9 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -556,14 +556,16 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, return std::move(grouped_tensor_wrapper); } -Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, - Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type alpha, Buffer_Type beta, - Result_Type output, Result_Type workspace, Result_Type int64_workspace, - size_t m, size_t n, size_t k, - bool lhs_is_trans, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, - bool has_bias, bool is_grouped_dense_wgrad, - bool use_async_d2h_group_sizes) { +Error_Type GroupedGemmCudaGraphableFFI(cudaStream_t stream, Buffer_Type lhs_data, + Buffer_Type lhs_sinv, Buffer_Type rhs_data, + Buffer_Type rhs_sinv, Buffer_Type bias, + Buffer_Type group_sizes, Buffer_Type alpha, Buffer_Type beta, + Result_Type output, Result_Type workspace, + Result_Type int64_workspace, size_t m, size_t n, size_t k, + bool lhs_is_trans, bool rhs_is_trans, + JAXX_Scaling_Mode scaling_mode, bool has_bias, + bool is_grouped_dense_wgrad, + bool use_async_d2h_group_sizes) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -702,10 +704,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); - fprintf(stderr, "Before GEMM:\n"); - NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); - // fflush(stderr); - if (is_grouped_dense_wgrad) { NVTE_CHECK(lhs_is_trans && !rhs_is_trans, "For grouped dense wgrad, only TN GEMM is supported in TE/JAX currently."); @@ -736,10 +734,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type nullptr, // config (use defaults) stream); - fprintf(stderr, "After GEMM:\n"); - NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); - // fflush(stderr); - return ffi_with_cuda_error_check(); } @@ -778,15 +772,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type nullptr, // config (use defaults) stream); - - fprintf(stderr, "After GEMM:\n"); - NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); - // fflush(stderr); - return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmCudaGraphableHandler, GroupedGemmCudaGraphableFFI, FFI::Bind() .Ctx() // stream .Arg() // lhs_data @@ -808,8 +797,409 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("scaling_mode") .Attr("has_bias") .Attr("is_grouped_dense_wgrad") - .Attr("use_async_d2h_group_sizes")/*, - FFI_CudaGraph_Traits*/); + .Attr("use_async_d2h_group_sizes"), + FFI_CudaGraph_Traits); + +Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, + Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, + Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, + Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, + bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, + bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { + // Notes on matrix layouts and transpose: + // Jax uses row-major data_layout, on entering this function, each input matrix pair: + // A: row-major [m, k] for N - [k, m] for T + // B: row-major [k, n] for N - [n, k] for T + // on exiting this function, JAX expect: + // C: row-major with size [m, n]. + // cuBLAS uses column-major data_layout, in this view, each input matrix pair: + // A: column-major with size [k, m] for T - [m, k] for N + // B: column-major with size [n, k] for T - [k, n] for N + // + // If we call cuBLAS GEMM for A * B, the output will be: + // C: column-major with size [m, n] --> row-major with size [n, m]. + // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. + + int num_streams = nvte_get_num_compute_streams(); + + // Inputs + auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); + auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); + auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv.untyped_data()); + auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv.untyped_data()); + auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type()); + auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type()); + auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type()); + auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type()); + auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; + auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); + + NVTE_CHECK(group_sizes.dimensions().size() == 1); + size_t num_gemms = group_sizes.dimensions()[0]; + + // It is weird that TE/Common GEMM only use colwise for MXFP8 + const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); + const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; + const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans; + const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans; + + // Outputs + auto out_ptr = reinterpret_cast(output->untyped_data()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned + auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); + workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); + auto workspace_total_size = product(workspace->dimensions()); + + auto lhs_sinv_size = product(lhs_sinv.dimensions()); + auto rhs_sinv_size = product(rhs_sinv.dimensions()); + const size_t workspace_alignment_padding = 256; + const size_t tensor_scaling_sinv_aligment = 16; + const size_t mxfp8_scaling_sinv_alignment_padding = 256; + auto workspace_size = workspace_total_size - workspace_alignment_padding; + if (is_mxfp8_scaling) { + // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded by 128x4. + workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding); + } else if (is_tensor_scaling) { + // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned + // by 16 bytes to meet the requirement of CUDA 12.9.1 and later. + workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size); + } + workspace_size = workspace_size / num_streams; + auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams; + swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr); + auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; + swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr); + auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr; // Already 256B aligned + auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment; + + size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); + size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); + size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); + size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); + size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); + size_t out_dtype_bytes = te_dtype_bytes(out_dtype); + + if (is_tensor_scaling) { + size_t dpitch = tensor_scaling_sinv_aligment; + size_t spitch = lhs_sinv_dtype_bytes; + size_t width = lhs_sinv_dtype_bytes; + size_t height = lhs_sinv_size; + cudaMemcpy2DAsync(lhs_scatter_aligned_ptr, dpitch, lhs_sinv_ptr, spitch, width, height, + cudaMemcpyDeviceToDevice, stream); + spitch = rhs_sinv_dtype_bytes; + width = rhs_sinv_dtype_bytes; + height = rhs_sinv_size; + cudaMemcpy2DAsync(rhs_scatter_aligned_ptr, dpitch, rhs_sinv_ptr, spitch, width, height, + cudaMemcpyDeviceToDevice, stream); + lhs_sinv_ptr = lhs_scatter_aligned_ptr; + rhs_sinv_ptr = rhs_scatter_aligned_ptr; + } + + NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); + NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, + "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); + + size_t expected_lhs_size = m * k; + size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); + size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); + size_t actual_lhs_size = product(lhs_data.dimensions()); + size_t actual_rhs_size = product(rhs_data.dimensions()); + size_t actual_out_size = product(output->dimensions()); + NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", + expected_lhs_size, ", got ", actual_lhs_size); + if (!is_grouped_dense_wgrad) { + NVTE_CHECK(expected_rhs_size == actual_rhs_size, + "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, + " = ", expected_rhs_size, ", got ", actual_rhs_size); + NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m, + " * ", n, " = ", expected_out_size, ", got ", actual_out_size); + } else { + NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k, + " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size); + NVTE_CHECK(expected_out_size == actual_out_size, + "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n, + " = ", expected_out_size, ", got ", actual_out_size); + } + + size_t dim_list_bytes = sizeof(int32_t) * num_gemms; + std::vector dim_list_host(num_gemms); + size_t host_num_gemms = 0; + if (use_async_d2h_group_sizes) { + host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); + NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, + " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); + } else { + auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + } + size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + if (!is_grouped_dense_wgrad) { + NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, + ", got sum(group_sizes)=", sum_group_sizes); + } else { + NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, + ", got sum(group_sizes)=", sum_group_sizes); + } + + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + bool grad = false; + bool accumulate = false; + bool use_split_accumulator = false; + auto bias_shape = std::vector{has_bias ? n : 0}; + const int arch = cuda::sm_arch(); + + if (arch < 100 && is_fp8_gemm) { + NVTE_CHECK(!lhs_is_trans && rhs_is_trans, + "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", + "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); + } + + // These lists are to keep the TensorWrapper objects alive + std::vector lhs_wrapper_list; + std::vector rhs_wrapper_list; + std::vector lhs_swizzle_wrapper_list; // For MXFP8 scale_inv swizzling + std::vector rhs_swizzle_wrapper_list; + std::vector bias_wrapper_list; + std::vector pre_gelu_wrapper_list; + std::vector out_wrapper_list; + std::vector workspace_wrapper_list; + + // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM + std::vector lhs_list; + std::vector rhs_list; + std::vector lhs_swizzle_list; + std::vector rhs_swizzle_list; + std::vector bias_list; + std::vector pre_gelu_list; + std::vector out_list; + std::vector workspace_list; + + size_t lhs_sinv_total_size = 0; + size_t rhs_sinv_total_size = 0; + + std::vector zero_out_dptr_list; + std::vector zero_out_size_list; + + for (size_t i = 0; i < num_gemms; i++) { + // Matrix data shapes + size_t m_i = dim_list_host[i]; + auto lhs_shape_i = std::vector{m_i, k}; + auto rhs_shape_i = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; + auto out_shape_i = std::vector{m_i, n}; + if (is_grouped_dense_wgrad) { + size_t k_i = dim_list_host[i]; + lhs_shape_i[0] = lhs_is_trans ? k_i : m; + lhs_shape_i[1] = lhs_is_trans ? m : k_i; + rhs_shape_i[0] = rhs_is_trans ? n : k_i; + rhs_shape_i[1] = rhs_is_trans ? k_i : n; + out_shape_i[0] = m; + out_shape_i[1] = n; + } + + size_t lhs_size = lhs_shape_i[0] * lhs_shape_i[1]; + size_t rhs_size = rhs_shape_i[0] * rhs_shape_i[1]; + size_t out_size = out_shape_i[0] * out_shape_i[1]; + bool is_empty_gemm = lhs_size == 0 || rhs_size == 0; + if (is_empty_gemm && out_size > 0) { + zero_out_dptr_list.push_back(out_ptr); + zero_out_size_list.push_back(out_size * out_dtype_bytes); + } + + // Set matrix data pointers + auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto out_i = TensorWrapper(static_cast(out_ptr), out_shape_i, out_dtype); + void *lhs_vptr = static_cast(lhs_ptr); + void *rhs_vptr = static_cast(rhs_ptr); + if (rhs_use_colwise) // MatA to enter cuBLAS + rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); + else + rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); + if (lhs_use_colwise) // MatB to enter cuBLAS + lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); + else + lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); + + // Set scale_inv shapes and pointers + void *rhs_sinv_vptr = static_cast(rhs_sinv_ptr); + void *lhs_sinv_vptr = static_cast(lhs_sinv_ptr); + size_t lhs_sinv_size_i = 0; + size_t rhs_sinv_size_i = 0; + if (is_tensor_scaling) { + auto tensor_scaling_sinv_shape = std::vector{1}; + // If is_empty_gemm, scale_inv does not have the corresponding value, do not move the pointers + if (!is_empty_gemm) { + lhs_sinv_size_i = tensor_scaling_sinv_aligment / lhs_sinv_dtype_bytes; + rhs_sinv_size_i = tensor_scaling_sinv_aligment / rhs_sinv_dtype_bytes; + } + if (rhs_use_colwise) // MatA to enter cuBLAS + rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape); + else + rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape); + if (lhs_use_colwise) // MatB to enter cuBLAS + lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); + else + lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); + } else if (is_mxfp8_scaling) { + auto lhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto rhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + void *swizzled_lhs_sinv_vptr = static_cast(swizzled_lhs_sinv_ptr); + void *swizzled_rhs_sinv_vptr = static_cast(swizzled_rhs_sinv_ptr); + + // {lhs, rhs}_swizzle_i point to unswizzled scale_inv data as input, while {lhs, rhs}_i + // point to swizzled scale_inv data (store on workspace, only used for GEMM). + // Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers + auto lhs_sinv_shape_i = + get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); + auto rhs_sinv_shape_i = + get_block_scale_shape(scaling_mode, rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise); + lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1]; + rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1]; + if (lhs_use_colwise) { + lhs_swizzle_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); + lhs_swizzle_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + lhs_i.set_columnwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + } else { + lhs_swizzle_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); + lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + } + lhs_i.set_with_gemm_swizzled_scales(true); + if (rhs_use_colwise) { + rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); + rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + rhs_i.set_columnwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + } else { + rhs_swizzle_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); + rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + } + rhs_i.set_with_gemm_swizzled_scales(true); + + if (!is_empty_gemm) { + lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i)); + rhs_swizzle_wrapper_list.push_back(std::move(rhs_swizzle_i)); + lhs_swizzle_list.push_back(lhs_swizzle_wrapper_list.back().data()); + rhs_swizzle_list.push_back(rhs_swizzle_wrapper_list.back().data()); + } + } else { + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, + "Unsupported scaling mode: ", static_cast(scaling_mode)); + } + + auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + auto pre_gelu_i = TensorWrapper(nullptr, std::vector{0}, out_dtype); + + // Update pointer for the next GEMM pair + lhs_ptr += lhs_size * lhs_dtype_bytes; + rhs_ptr += rhs_size * rhs_dtype_bytes; + out_ptr += out_size * out_dtype_bytes; + if (is_fp8_gemm) { + lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; + rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; + lhs_sinv_total_size += lhs_sinv_size_i; + rhs_sinv_total_size += rhs_sinv_size_i; + if (is_mxfp8_scaling) { + swizzled_lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; + swizzled_rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; + } + } + if (has_bias) bias_ptr += n * bias_dtype_bytes; + + // Move objects to the lists to keep them alive + if (is_empty_gemm) continue; + lhs_wrapper_list.push_back(std::move(lhs_i)); + rhs_wrapper_list.push_back(std::move(rhs_i)); + out_wrapper_list.push_back(std::move(out_i)); + bias_wrapper_list.push_back(std::move(bias_i)); + pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); + + lhs_list.push_back(lhs_wrapper_list.back().data()); + rhs_list.push_back(rhs_wrapper_list.back().data()); + bias_list.push_back(bias_wrapper_list.back().data()); + pre_gelu_list.push_back(pre_gelu_wrapper_list.back().data()); + out_list.push_back(out_wrapper_list.back().data()); + } + + auto workspace_shape = std::vector{workspace_size}; + for (int i = 0; i < num_streams; i++) { + auto workspace_i = + TensorWrapper(static_cast(workspace_ptr), workspace_shape, DType::kByte); + workspace_wrapper_list.push_back(std::move(workspace_i)); + workspace_list.push_back(workspace_wrapper_list.back().data()); + workspace_ptr += workspace_size; + } + + if (is_fp8_gemm) { + if (is_tensor_scaling) { + lhs_sinv_size *= tensor_scaling_sinv_aligment; + rhs_sinv_size *= tensor_scaling_sinv_aligment; + } + NVTE_CHECK(lhs_sinv_total_size <= lhs_sinv_size, "Actual total lhs_sinv size ", + lhs_sinv_total_size, " exceeds estimated upper bound ", lhs_sinv_size); + NVTE_CHECK(rhs_sinv_total_size <= rhs_sinv_size, "Actual total rhs_sinv size ", + rhs_sinv_total_size, " exceeds estimated upper bound ", rhs_sinv_size); + } + + size_t num_non_empty_gemms = lhs_list.size(); + + if (is_mxfp8_scaling) { + for (int i = 0; i < num_non_empty_gemms; i++) { + // The i-th GEMM will use the (i % num_streams)-th stream to compute, + // use the same stream to swizzle the scaling factors to make sure that + // the swizzling is done before the GEMM computation starts. + int stream_id = i % num_streams; + cudaStream_t stream_i = nvte_get_compute_stream(stream_id); + nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i); + nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); + } + } + + // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM + size_t num_zero_outs = zero_out_dptr_list.size(); + for (int i = 0; i < num_zero_outs; i++) { + int stream_id = i % num_streams; + cudaStream_t stream_i = nvte_get_compute_stream(stream_id); + void *dptr = zero_out_dptr_list[i]; + size_t count = zero_out_size_list[i]; + NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i)); + } + + nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), + pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, + grad, workspace_list.data(), accumulate, use_split_accumulator, + num_math_sm, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs_data + .Arg() // lhs_sinv + .Arg() // rhs_data + .Arg() // rhs_sinv + .Arg() // bias + .Arg() // group_sizes + .Arg() // group_offset + .Ret() // output + .Ret() // workspace + .Attr("M") + .Attr("N") + .Attr("K") + .Attr("lhs_is_trans") + .Attr("rhs_is_trans") + .Attr("scaling_mode") + .Attr("has_bias") + .Attr("is_grouped_dense_wgrad") + .Attr("use_async_d2h_group_sizes")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index d4bb3bab00..87acb287a7 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -76,6 +76,9 @@ pybind11::dict Registrations() { dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); + dict["te_grouped_gemm_cuda_graphable_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(GroupedGemmCudaGraphableHandler)); // Amax dict["te_rht_amax_ffi"] = pybind11::dict( From 6fd7f169ddc95d2feb9c3d54942f1f636e7c7839 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Feb 2026 21:11:51 +0000 Subject: [PATCH 08/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_grouped_gemm.cu | 5 ++-- transformer_engine/jax/cpp_extensions/gemm.py | 14 +++++------ .../jax/csrc/extensions/gemm.cpp | 23 ++++++++----------- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 292f3d6a21..7b248aae7d 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -446,7 +446,7 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, // 2. Per-tensor first/last dims provided but no offsets → cumulative sum of (first*last) products. // 3. Fully uniform shapes → idx * uniform_first * uniform_last. __forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorShapeInfo &meta, - size_t idx) { + size_t idx) { if (meta.offsets) { return meta.offsets[idx]; } else if (meta.first_dims != nullptr || meta.last_dims != nullptr) { @@ -562,8 +562,7 @@ size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors) { return grouped_gemm_setup_workspace_size(num_tensors); } -void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, - cudaStream_t stream) { +void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) { if (n == 0) return; const int threads = 256; const int blocks = static_cast((n + threads - 1) / threads); diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 708c91ded8..69793b758b 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1509,9 +1509,7 @@ def abstract( # Temporary buffer for int32 → int64 conversion of group_sizes on device. int64_workspace_size = group_sizes_aval.size * jnp.dtype(jnp.int64).itemsize - int64_workspace_aval = jax.core.ShapedArray( - shape=(int64_workspace_size,), dtype=jnp.uint8 - ) + int64_workspace_aval = jax.core.ShapedArray(shape=(int64_workspace_size,), dtype=jnp.uint8) out_shape = (M, N) if is_grouped_dense_wgrad: @@ -2261,12 +2259,14 @@ def grouped_gemm( # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay # feature-compatible with the main branch. - _use_cuda_graphable = ( - scaling_mode == ScalingMode.NO_SCALING and lhs_data.dtype == jnp.bfloat16 - ) + _use_cuda_graphable = scaling_mode == ScalingMode.NO_SCALING and lhs_data.dtype == jnp.bfloat16 if _use_cuda_graphable: - assert group_offset is None, "group_offset is not supported in the cuda graphable path and is instead computed internally assuming contiguous grouping. Any padding is included in the group_sizes and padded with zeros to not affect the result of the MoE block." + assert group_offset is None, ( + "group_offset is not supported in the cuda graphable path and is instead computed" + " internally assuming contiguous grouping. Any padding is included in the group_sizes" + " and padded with zeros to not affect the result of the MoE block." + ) group_sizes = group_sizes.astype(jnp.int32) num_gemms = group_sizes.shape[0] alpha = jnp.ones((num_gemms,), jnp.float32) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 508b9aeda9..9e27ca5c2f 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -528,8 +528,7 @@ void JAXX_GroupedTensorWrapper::set_group_info(Buffer_Type const &group_sizes, } void JAXX_GroupedTensorWrapper::set_group_sizes_only( - const int64_t *sizes_ptr, size_t num_tensors, - NVTEGroupedTensorParam group_sizes_param_name) { + const int64_t *sizes_ptr, size_t num_tensors, NVTEGroupedTensorParam group_sizes_param_name) { NVTEShape shape{}; shape.ndim = 1; shape.data[0] = num_tensors; @@ -556,16 +555,13 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, return std::move(grouped_tensor_wrapper); } -Error_Type GroupedGemmCudaGraphableFFI(cudaStream_t stream, Buffer_Type lhs_data, - Buffer_Type lhs_sinv, Buffer_Type rhs_data, - Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type alpha, Buffer_Type beta, - Result_Type output, Result_Type workspace, - Result_Type int64_workspace, size_t m, size_t n, size_t k, - bool lhs_is_trans, bool rhs_is_trans, - JAXX_Scaling_Mode scaling_mode, bool has_bias, - bool is_grouped_dense_wgrad, - bool use_async_d2h_group_sizes) { +Error_Type GroupedGemmCudaGraphableFFI( + cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, + Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type alpha, + Buffer_Type beta, Result_Type output, Result_Type workspace, Result_Type int64_workspace, + size_t m, size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, + JAXX_Scaling_Mode scaling_mode, bool has_bias, bool is_grouped_dense_wgrad, + bool use_async_d2h_group_sizes) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -598,8 +594,7 @@ Error_Type GroupedGemmCudaGraphableFFI(cudaStream_t stream, Buffer_Type lhs_data size_t num_gemms = group_sizes.dimensions()[0]; // Convert int32 group_sizes to int64 into the dedicated output buffer. - NVTE_CHECK(group_sizes.element_type() == xla::ffi::DataType::S32, - "group_sizes must be int32."); + NVTE_CHECK(group_sizes.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); nvte_convert_int32_to_int64(reinterpret_cast(group_sizes.untyped_data()), int64_sizes_ptr, num_gemms, stream); From 0d5837daf195ff15826e5f7a236aa897b25ed176 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 24 Feb 2026 13:37:07 -0800 Subject: [PATCH 09/14] cleanup and lint fixes Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 9 +++++---- transformer_engine/jax/cpp_extensions/quantization.py | 4 ++-- transformer_engine/jax/csrc/extensions/gemm.cpp | 8 -------- transformer_engine/jax/flax/module.py | 2 ++ transformer_engine/jax/quantize/quantizer.py | 2 +- 5 files changed, 10 insertions(+), 15 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 69793b758b..032e809d88 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -583,7 +583,7 @@ def lowering( (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) ) - lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed) + _lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed) # lhs_contracting_size = ( # reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) # if lhs_transposed @@ -1481,7 +1481,7 @@ def abstract( Returns: A jnp.ndarray containing the result of the grouped GEMM operation """ - del lhs_data_aval, rhs_data_aval, bias_aval + del lhs_data_aval, rhs_data_aval, bias_aval, alpha, beta del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes # TODO(Phuong): move some shape checks from Cpp to here workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams @@ -2259,9 +2259,10 @@ def grouped_gemm( # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay # feature-compatible with the main branch. - _use_cuda_graphable = scaling_mode == ScalingMode.NO_SCALING and lhs_data.dtype == jnp.bfloat16 + # Bias can be supported in a kernel or in pure-JAX in the future. + use_cuda_graphable = scaling_mode == ScalingMode.NO_SCALING and lhs_data.dtype == jnp.bfloat16 and not has_bias - if _use_cuda_graphable: + if use_cuda_graphable: assert group_offset is None, ( "group_offset is not supported in the cuda graphable path and is instead computed" " internally assuming contiguous grouping. Any padding is included in the group_sizes" diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index f851bbebc1..bf4e833c89 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -20,6 +20,7 @@ from .base import BasePrimitive, register_primitive from .misc import ( get_padded_spec, + check_valid_batch_dims, te_dtype_to_jax_dtype, jax_dtype_to_te_dtype, multidim_transpose, @@ -1250,8 +1251,7 @@ def grouped_quantize( ) = GroupedQuantizePrimitive.outer_primitive.bind( x, scale, - # TODO(jberchtold): Remove this int32 cast once GMM does not require JAX int64 dtype - group_sizes.astype(jnp.int32), + group_sizes, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, q_layout=q_layout, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 9e27ca5c2f..73c4a5ede1 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -719,10 +719,6 @@ Error_Type GroupedGemmCudaGraphableFFI( auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape); - // Output needs to be zeroed in case any group sizes have size zero, meaning the expert weight isn't used in the fwd, meaning the corresponding output gradient should be zero. But using the grouped GEMM, the output buffer contains uninitialized data. - // TODO(jberchtold): make this memset smaller by only zeroing the expert weights that correspond to groups with size zero. - cudaMemsetAsync(output->untyped_data(), 0, output->size_bytes(), stream); - nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), workspace_cublas.data(), @@ -757,10 +753,6 @@ Error_Type GroupedGemmCudaGraphableFFI( num_gemms, outShape); out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - // This memset is required because the group sizes may not fill the full buffer since we overallocate for the worst case. However, in theory unused space on the grouped axis should not be utilizied downstream, but it seems like somehow it is utilized. - // TODO(jberchtold): try removing this - cudaMemsetAsync(output->untyped_data(), 0, output->size_bytes(), stream); - nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), workspace_cublas.data(), diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index a661f30356..1758e338c4 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1444,9 +1444,11 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): def make_ragged_dot_cls(quantization_recipe): + """Creates a ragged dot (grouped GEMM) class for use with TE state module.""" assert quantization_recipe is None, "Ragged dot grouped GEMM does not support quantization yet" def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): + del kwargs # Unused num_groups = group_sizes.shape[0] quantizer_set = generate_quantizer_set(n_groups=num_groups) diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 1923932692..f5ca6aeaed 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -68,7 +68,7 @@ def compute_scale_from_amax( sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale) assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}" - return sf.astype(jnp.float32) + return sf @register_pytree_node_class From d3ee0fcfb1b8a6c0cb07225bd0ffb32c51d42bfa Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 24 Feb 2026 13:48:41 -0800 Subject: [PATCH 10/14] re-add cublas alignment checks Signed-off-by: Jeremy Berchtold --- .../common/gemm/cublaslt_gemm.cu | 8 ++-- transformer_engine/jax/cpp_extensions/gemm.py | 44 +++++++++---------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 241e30764a..a6eef503f7 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -154,8 +154,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - // NVTE_CHECK(ret.lda % 16 == 0, - // "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); + NVTE_CHECK(ret.lda % 16 == 0, + "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. @@ -245,8 +245,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - // NVTE_CHECK(ret.ldb % 16 == 0, - // "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); + NVTE_CHECK(ret.ldb % 16 == 0, + "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { if (is_B_transposed) { diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 032e809d88..b50567aec9 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -583,28 +583,28 @@ def lowering( (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) ) - _lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed) - # lhs_contracting_size = ( - # reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) - # if lhs_transposed - # else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) - # ) - # assert_cublas_requirements( - # scaling_mode, - # lhs_contracting_size, - # f"LHS {lhs_aval.shape} with contracting dims {lhs_cdims}", - # ) - # rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) - # rhs_contracting_size = ( - # reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary]) - # if rhs_transposed - # else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) - # ) - # assert_cublas_requirements( - # scaling_mode, - # rhs_contracting_size, - # f"RHS {rhs_aval.shape} with contracting dims {rhs_cdims}", - # ) + lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed) + lhs_contracting_size = ( + reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) + if lhs_transposed + else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) + ) + assert_cublas_requirements( + scaling_mode, + lhs_contracting_size, + f"LHS {lhs_aval.shape} with contracting dims {lhs_cdims}", + ) + rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) + rhs_contracting_size = ( + reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary]) + if rhs_transposed + else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) + ) + assert_cublas_requirements( + scaling_mode, + rhs_contracting_size, + f"RHS {rhs_aval.shape} with contracting dims {rhs_cdims}", + ) args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) kwargs = { From 64406481cfb189906adda24a69443877179225c5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Feb 2026 21:49:40 +0000 Subject: [PATCH 11/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/gemm/cublaslt_gemm.cu | 2 +- transformer_engine/jax/cpp_extensions/gemm.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index a6eef503f7..c58c3cb47a 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -155,7 +155,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage NVTE_CHECK(ret.lda % 16 == 0, - "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); + "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index b50567aec9..567219b6ad 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2260,7 +2260,9 @@ def grouped_gemm( # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay # feature-compatible with the main branch. # Bias can be supported in a kernel or in pure-JAX in the future. - use_cuda_graphable = scaling_mode == ScalingMode.NO_SCALING and lhs_data.dtype == jnp.bfloat16 and not has_bias + use_cuda_graphable = ( + scaling_mode == ScalingMode.NO_SCALING and lhs_data.dtype == jnp.bfloat16 and not has_bias + ) if use_cuda_graphable: assert group_offset is None, ( From 661a8291e33d3c2ea8eef2dc57a3167bced0c3af Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 24 Feb 2026 16:00:09 -0800 Subject: [PATCH 12/14] fix symbol export when building with older cublas Signed-off-by: Jeremy Berchtold --- .../common/gemm/cublaslt_grouped_gemm.cu | 37 ++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 7b248aae7d..4a1ae992d3 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -551,25 +551,12 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); } -__global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, size_t n) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) dst[idx] = static_cast(src[idx]); -} - } // namespace size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors) { return grouped_gemm_setup_workspace_size(num_tensors); } -void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) { - if (n == 0) return; - const int threads = 256; - const int blocks = static_cast((n + threads - 1) / threads); - convert_int32_to_int64_kernel<<>>(src, dst, n); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, @@ -677,4 +664,28 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); } +size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors) { + NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ", + CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); + return 0; +} + + #endif // CUBLAS_VERSION >= 130200 + +namespace { + +__global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) dst[idx] = static_cast(src[idx]); +} + +} // namespace + +void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) { + if (n == 0) return; + const int threads = 256; + const int blocks = static_cast((n + threads - 1) / threads); + convert_int32_to_int64_kernel<<>>(src, dst, n); + NVTE_CHECK_CUDA(cudaGetLastError()); +} From bd5e6fb026371c223d9e31448c7d49794e318bd8 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 25 Feb 2026 09:48:05 -0800 Subject: [PATCH 13/14] Fix backend selection depending on whether TE was compiled with the right cuBLAS version Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 567219b6ad..dfac350860 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2082,6 +2082,30 @@ def grouped_gemm_copy_group_sizes( ) return out +def _can_use_cuda_graphable_grouped_gemm( + scaling_mode: ScalingMode, + dtype: jnp.dtype, + has_bias: bool, +) -> bool: + """Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters. + + """ + # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy + # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay + # feature-compatible with the main branch. + # Bias can be supported in a kernel or in pure-JAX in the future. + + try: + get_grouped_gemm_setup_workspace_size(1) + except RuntimeError as e: + if "cublas" in str(e).lower(): + # If the workspace size function is not available, it means the cuda-graphable implementation is not available. + return False + raise e + + return ( + scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias + ) def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], @@ -2256,13 +2280,7 @@ def grouped_gemm( assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias - # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy - # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay - # feature-compatible with the main branch. - # Bias can be supported in a kernel or in pure-JAX in the future. - use_cuda_graphable = ( - scaling_mode == ScalingMode.NO_SCALING and lhs_data.dtype == jnp.bfloat16 and not has_bias - ) + use_cuda_graphable = _can_use_cuda_graphable_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) if use_cuda_graphable: assert group_offset is None, ( From 60d5c4220adee68f095f355bad3f007a50c95ba4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Feb 2026 17:49:24 +0000 Subject: [PATCH 14/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_grouped_gemm.cu | 1 - transformer_engine/jax/cpp_extensions/gemm.py | 14 +++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 4a1ae992d3..ac9322d651 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -670,7 +670,6 @@ size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors) { return 0; } - #endif // CUBLAS_VERSION >= 130200 namespace { diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index dfac350860..283b96d995 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2082,14 +2082,13 @@ def grouped_gemm_copy_group_sizes( ) return out + def _can_use_cuda_graphable_grouped_gemm( scaling_mode: ScalingMode, dtype: jnp.dtype, has_bias: bool, ) -> bool: - """Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters. - - """ + """Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters.""" # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay # feature-compatible with the main branch. @@ -2103,9 +2102,8 @@ def _can_use_cuda_graphable_grouped_gemm( return False raise e - return ( - scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias - ) + return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias + def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], @@ -2280,7 +2278,9 @@ def grouped_gemm( assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias - use_cuda_graphable = _can_use_cuda_graphable_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) + use_cuda_graphable = _can_use_cuda_graphable_grouped_gemm( + scaling_mode, lhs_data.dtype, has_bias + ) if use_cuda_graphable: assert group_offset is None, (