[JAX] Integrate BF16 Grouped GEMM with on-device group sizes#2680

Open
jberchtold-nvidia wants to merge 16 commits into NVIDIA:main from jberchtold-nvidia:gmm

Conversation

@jberchtold-nvidia
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented Feb 13, 2026

Description

Integrate the new BF16 grouped GEMM from TE common/cuBLASLt that supports on-device group sizes without a D2H memcpy or stream sync. This grouped GEMM is faster and CUDA-graph safe.

Also fixes #2659
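For reference, the semantics of a grouped (ragged) GEMM with group sizes can be sketched in pure Python. This is an illustrative reference implementation only, not the TE/JAX API; all names here are hypothetical:

```python
# Reference semantics of a grouped GEMM (illustrative sketch, not the TE/JAX API).
# The rows of lhs are partitioned into groups by group_sizes, and each group i
# is multiplied by its own matrix rhs[i].

def grouped_gemm_reference(lhs, rhs, group_sizes):
    """lhs: M rows, each a list of K floats; rhs: G matrices of shape K x N;
    group_sizes: G ints summing to M. Returns an M x N result."""
    out = []
    row = 0
    for g, size in enumerate(group_sizes):
        for _ in range(size):
            out.append([
                sum(lhs[row][k] * rhs[g][k][n] for k in range(len(rhs[g])))
                for n in range(len(rhs[g][0]))
            ])
            row += 1
    return out

# Two groups of sizes 1 and 2, with K=2, N=1.
lhs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
rhs = [[[1.0], [0.0]],   # group 0 selects the first input column
       [[0.0], [1.0]]]   # group 1 selects the second
print(grouped_gemm_reference(lhs, rhs, [1, 2]))  # [[1.0], [4.0], [6.0]]
```

The point of the on-device variant is that `group_sizes` lives in GPU memory, so the launch does not need to read it back to the host.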

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added a new CUDA-graph-safe grouped GEMM to TE/JAX that is automatically used as the backend when the input data is bf16 (no scaling recipe) and no bias is required.
  • Exposed a new make_ragged_dot_cls for easy integration into existing models. This will be most useful once quantization is supported and storing recipe state is required.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft February 13, 2026 17:42
@greptile-apps
Contributor

greptile-apps bot commented Feb 13, 2026

Greptile Summary

This PR integrates a new BF16 grouped GEMM implementation that is CUDA-graph safe by eliminating D2H memcpy and stream synchronization. The implementation uses on-device group size storage and offset computation.

Key changes:

  • Added nvte_grouped_gemm backend with on-device offset computation from group sizes
  • Created new GroupedGemmCudaGraphablePrimitive that automatically activates for BF16 non-quantized inputs without bias
  • Implemented int32→int64 conversion kernel to handle JAX's default int32 while supporting int64 internally
  • Exposed make_ragged_dot_cls API for Flax integration with grouped dense operations
  • Fixed dimension ordering bug in setup kernel (d_rows/d_cols were swapped)

Backend selection logic:
The new CUDA-graphable path is used when: dtype is bfloat16, no scaling/quantization, and no bias. All other cases fall back to the existing nvte_multi_tensor_gemm path, preserving backward compatibility for FP8/MXFP8 workloads.
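The dispatch rule described above can be sketched as a small predicate. The function and parameter names below are hypothetical, chosen for illustration; the real logic lives in transformer_engine/jax/cpp_extensions/gemm.py:

```python
# Illustrative sketch of the backend-selection rule described above.
# Names are hypothetical, not the actual TE/JAX functions.

def use_cuda_graphable_grouped_gemm(dtype, scaling_mode, has_bias):
    """Return True when the new CUDA-graph-safe path applies:
    bfloat16 inputs, no scaling/quantization recipe, and no bias.
    All other cases fall back to the legacy nvte_multi_tensor_gemm path."""
    return dtype == "bfloat16" and scaling_mode is None and not has_bias

# BF16 without scaling or bias takes the new path ...
assert use_cuda_graphable_grouped_gemm("bfloat16", None, False)
# ... while FP8 scaling or a bias falls back to the legacy path.
assert not use_cuda_graphable_grouped_gemm("float8_e4m3", "delayed", False)
assert not use_cuda_graphable_grouped_gemm("bfloat16", None, True)
```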

Previous review items addressed:
All previously raised concerns about workspace allocation, output zeroing, offset computation in Python, alignment checks, and quantization support have been acknowledged in the review threads.

Confidence Score: 4/5

  • Safe to merge with one performance consideration for large group counts
  • The implementation is well-structured with proper fallback logic and comprehensive error checking. The O(n²) offset computation in the CUDA kernel could impact performance with very large numbers of groups, but for typical MoE workloads this should be acceptable. Previous review concerns have been noted and the feature is appropriately scoped to BF16 non-quantized cases.
  • Pay attention to cublaslt_grouped_gemm.cu - the offset computation has quadratic complexity that may need optimization for large group counts

Important Files Changed

Filename Overview
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Adds BF16 grouped GEMM support with on-device offset computation, int32->int64 conversion utility, and workspace size calculation. Contains O(n²) complexity issue in offset computation kernel.
transformer_engine/jax/csrc/extensions/gemm.cpp Implements new CUDA-graphable grouped GEMM FFI handler with proper tensor wrapper abstractions. Well-structured but contains complex workspace calculations.
transformer_engine/jax/cpp_extensions/gemm.py Adds new CUDA-graphable grouped GEMM primitive alongside existing implementation. Implements smart backend selection based on dtype/scaling mode. Clean integration.
transformer_engine/jax/flax/module.py Adds make_ragged_dot_cls for grouped GEMM integration and updates quantizer set generation to support groups. Properly guards against unsupported quantization.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[JAX grouped_gemm call] --> B{Can use CUDA-graphable?}
    B -->|BF16, no scaling, no bias| C[GroupedGemmCudaGraphablePrimitive]
    B -->|FP8/MXFP8 or has bias| D[GroupedGemmPrimitive legacy]
    
    C --> E[Python: Convert group_sizes to int32]
    E --> F[C++: GroupedGemmCudaGraphableFFI]
    F --> G[Allocate int64_workspace for group_sizes]
    F --> H[Call nvte_convert_int32_to_int64]
    H --> I[CUDA kernel: int32 -> int64 conversion]
    
    F --> J[Create JAXX_GroupedTensorWrapper]
    J --> K[set_group_sizes_only with int64 sizes]
    K --> L[nvte_grouped_gemm]
    
    L --> M[Setup workspace calculation]
    M --> N[CUDA kernel: setup_grouped_gemm_kernel]
    N --> O[compute_grouped_tensor_offset for each tensor]
    O --> P[Compute pointers and dimensions]
    
    P --> Q[cuBLASLt grouped matmul API]
    Q --> R[Return output]
    
    D --> S[Legacy path with D2H sync]
    S --> T[nvte_multi_tensor_gemm]

Last reviewed commit: 70d8f78


@greptile-apps greptile-apps bot left a comment


8 files reviewed, 8 comments


jberchtold-nvidia and others added 7 commits February 24, 2026 13:10
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…void enabling JAX x64 globally

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…sion to a different suffix

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
.Ret<Buffer_Type>() // dummy_output
.Attr<int64_t>("num_gemms"));

class JAXX_GroupedTensorWrapper {
@jberchtold-nvidia (Collaborator, Author) commented:

There is a PyTorch refactor PR that also adds a similar GroupedTensorWrapper struct to TE/common. I will move to that once it is merged, but I don't want to block this PR on it; the logic here is straightforward, and switching should be a small change.

jberchtold-nvidia and others added 3 commits February 24, 2026 13:37
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L0 jax

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L0 jax

right cuBLAS version

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci

@jberchtold-nvidia jberchtold-nvidia marked this pull request as ready for review February 25, 2026 21:32

@greptile-apps greptile-apps bot left a comment


9 files reviewed, 1 comment


Comment on lines +448 to +464
__forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorShapeInfo &meta,
                                                                 size_t idx) {
  if (meta.offsets) {
    return meta.offsets[idx];
  } else if (meta.first_dims != nullptr || meta.last_dims != nullptr) {
    // offset[i] = sum_{j < i} (first_dims[j] * last_dims[j])
    int64_t cumsum = 0;
    for (size_t i = 0; i < idx; i++) {
      int64_t f = meta.first_dims ? meta.first_dims[i] : meta.uniform_first;
      int64_t l = meta.last_dims ? meta.last_dims[i] : meta.uniform_last;
      cumsum += f * l;
    }
    return cumsum;
  } else {
    return static_cast<int64_t>(idx) * meta.uniform_first * meta.uniform_last;
  }
}
@greptile-apps (Contributor) commented:

O(n²) complexity in parallel kernel. Each of the n threads calls this function with a different idx, and for case 2 (per-tensor dims without explicit offsets), thread idx performs a sequential loop from 0 to idx-1. This creates O(1 + 2 + ... + n) = O(n²) total work across all threads.

For large numbers of groups, consider either:

  1. Computing offsets once on CPU and passing them explicitly
  2. Using a parallel prefix sum (scan) to compute cumulative offsets
  3. Documenting this limitation if group counts are expected to be small
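Option 2 amounts to replacing each thread's sequential loop with a single exclusive prefix sum over per-group element counts, bringing the total work down to O(n). A host-side Python sketch of that computation (a GPU implementation would use a parallel scan such as cub::DeviceScan::ExclusiveSum); the helper name here is illustrative:

```python
# Illustrative O(n) alternative to the per-thread loop: compute all offsets
# with one exclusive prefix sum over per-group element counts.
from itertools import accumulate

def grouped_tensor_offsets(first_dims, last_dims):
    """Exclusive prefix sum of per-group element counts:
    offsets[i] = sum_{j < i} first_dims[j] * last_dims[j]."""
    sizes = [f * l for f, l in zip(first_dims, last_dims)]
    # accumulate(..., initial=0) yields the exclusive scan plus the grand
    # total; drop the final total to get exactly one offset per group.
    return list(accumulate(sizes, initial=0))[:-1]

# Three groups of shapes (2,4), (3,4), (1,4) -> element counts 8, 12, 4.
print(grouped_tensor_offsets([2, 3, 1], [4, 4, 4]))  # [0, 8, 20]
```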



Development

Successfully merging this pull request may close these issues.

[Core] TE common nvte_grouped_gemm treats output layout as column-wise instead of rowwise
