[JAX] Integrate BF16 Grouped GEMM with on-device group sizes #2680
jberchtold-nvidia wants to merge 16 commits into NVIDIA:main from
Conversation
Greptile Summary
This PR integrates a new BF16 grouped GEMM implementation that is CUDA-graph safe because it eliminates the D2H memcpy and stream synchronization. The implementation stores group sizes on device and computes offsets there. Key changes:
Backend selection logic
Previous review items addressed
Confidence Score: 4/5
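The backend selection described above can be sketched as a single predicate (names here are illustrative assumptions, not the actual TE/JAX API): the CUDA-graphable path is taken only for plain BF16 GEMMs, while anything with FP8/MXFP8 scaling or a bias falls back to the legacy primitive.

```cpp
#include <cassert>
#include <string>

// Hypothetical sketch of the dispatch predicate: the CUDA-graph-safe grouped
// GEMM handles only unscaled, bias-free BF16; everything else uses the
// legacy GroupedGemmPrimitive. Function and parameter names are assumptions.
bool can_use_cuda_graphable(const std::string &dtype, bool has_scaling,
                            bool has_bias) {
  return dtype == "bfloat16" && !has_scaling && !has_bias;
}
```

A predicate like this keeps the fallback decision in one place, so new constraints (e.g. future quantization support) only need to be added once.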
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[JAX grouped_gemm call] --> B{Can use CUDA-graphable?}
B -->|BF16, no scaling, no bias| C[GroupedGemmCudaGraphablePrimitive]
B -->|FP8/MXFP8 or has bias| D[GroupedGemmPrimitive legacy]
C --> E[Python: Convert group_sizes to int32]
E --> F[C++: GroupedGemmCudaGraphableFFI]
F --> G[Allocate int64_workspace for group_sizes]
F --> H[Call nvte_convert_int32_to_int64]
H --> I[CUDA kernel: int32 -> int64 conversion]
F --> J[Create JAXX_GroupedTensorWrapper]
J --> K[set_group_sizes_only with int64 sizes]
K --> L[nvte_grouped_gemm]
L --> M[Setup workspace calculation]
M --> N[CUDA kernel: setup_grouped_gemm_kernel]
N --> O[compute_grouped_tensor_offset for each tensor]
O --> P[Compute pointers and dimensions]
P --> Q[cuBLASLt grouped matmul API]
Q --> R[Return output]
D --> S[Legacy path with D2H sync]
S --> T[nvte_multi_tensor_gemm]
Last reviewed commit: 70d8f78
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…void enabling JAX x64 globally
…sion to a different suffix
.Ret<Buffer_Type>()  // dummy_output
.Attr<int64_t>("num_gemms"));

class JAXX_GroupedTensorWrapper {
There is a PyTorch refactor PR that adds a similar GroupedTensorWrapper struct to TE/common. I will switch to that once it is merged, but I don't want to block this PR on it; the logic here is straightforward, so switching should be a small change.
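For illustration only, a grouped-tensor wrapper of this kind typically bundles one contiguous buffer with per-group sizes so each group's matrix can be located without a separate tensor object per group. The fields below are assumptions, not the actual JAXX_GroupedTensorWrapper layout:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch of a grouped-tensor wrapper; field names and layout
// are assumptions, not the real struct in this PR.
struct GroupedTensorSketch {
  void *base = nullptr;              // one contiguous buffer for all groups
  std::vector<int64_t> group_sizes;  // rows per group (device memory in the
                                     // real code, host vector here)
  int64_t cols = 0;                  // shared trailing dimension

  // Total rows across all groups, i.e. the leading dim of the packed buffer.
  int64_t total_rows() const {
    int64_t sum = 0;
    for (int64_t g : group_sizes) sum += g;
    return sum;
  }
};
```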
/te-ci L0 jax
/te-ci L0 jax
…right cuBLAS version
/te-ci
__forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorShapeInfo &meta,
                                                                 size_t idx) {
  if (meta.offsets) {
    return meta.offsets[idx];
  } else if (meta.first_dims != nullptr || meta.last_dims != nullptr) {
    // offset[i] = sum_{j < i} (first_dims[j] * last_dims[j])
    int64_t cumsum = 0;
    for (size_t i = 0; i < idx; i++) {
      int64_t f = meta.first_dims ? meta.first_dims[i] : meta.uniform_first;
      int64_t l = meta.last_dims ? meta.last_dims[i] : meta.uniform_last;
      cumsum += f * l;
    }
    return cumsum;
  } else {
    return static_cast<int64_t>(idx) * meta.uniform_first * meta.uniform_last;
  }
}
O(n²) complexity in parallel kernel. Each of the n threads calls this function with a different idx, and for case 2 (per-tensor dims without explicit offsets), thread idx performs a sequential loop from 0 to idx-1. This creates O(1 + 2 + ... + n) = O(n²) total work across all threads.
For large numbers of groups, consider one of the following:
- Computing offsets once on CPU and passing them explicitly
- Using a parallel prefix sum (scan) to compute cumulative offsets
- Documenting this limitation if group counts are expected to be small
Description
Integrate new BF16 grouped GEMM from TE common/cuBLASLt that supports on-device group sizes without a D2H memcpy and stream sync. This grouped GEMM is faster and CUDA-graph safe.
Also fixes #2659
Type of change
Changes
make_ragged_dot_cls for easy integration into existing models. This will be most useful when quantization is supported and storing recipe state is required.
Checklist: