Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 57 additions & 12 deletions transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,29 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle,
return heuristicResult.algo;
}

// Device helper: compute the element (not byte) offset of tensor `idx` within the
// grouped tensor's contiguous buffer, based on which shape metadata is available.
// Three cases, checked in priority order:
//   1. Explicit per-tensor offset array provided → use it directly.
//   2. Per-tensor first/last dims provided but no offsets → cumulative sum of
//      (first * last) products over all preceding tensors.
//   3. Fully uniform shapes → idx * uniform_first * uniform_last.
// NOTE(review): in case 2 each call loops over all preceding tensors, so when every
// thread of a setup kernel invokes this with its own idx the aggregate work is
// O(num_tensors^2). Acceptable for small group counts; for large groups, prefer
// precomputing offsets on the host (case 1) or a parallel prefix sum.
__forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorShapeInfo &meta,
                                                                 size_t idx) {
  if (meta.offsets) {
    // Case 1: caller supplied explicit per-tensor element offsets.
    return meta.offsets[idx];
  } else if (meta.first_dims != nullptr || meta.last_dims != nullptr) {
    // Case 2: offset[i] = sum_{j < i} (first_dims[j] * last_dims[j]).
    // A missing dim array falls back to the uniform value for that dimension.
    int64_t cumsum = 0;
    for (size_t i = 0; i < idx; i++) {
      int64_t f = meta.first_dims ? meta.first_dims[i] : meta.uniform_first;
      int64_t l = meta.last_dims ? meta.last_dims[i] : meta.uniform_last;
      cumsum += f * l;
    }
    return cumsum;
  } else {
    // Case 3: every tensor in the group has the same (uniform) shape.
    return static_cast<int64_t>(idx) * meta.uniform_first * meta.uniform_last;
  }
}
Comment on lines +448 to +464
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

O(n²) complexity in parallel kernel. Each of the n threads calls this function with a different idx, and for case 2 (per-tensor dims without explicit offsets), thread idx performs a sequential loop from 0 to idx-1. This creates O(1 + 2 + ... + n) = O(n²) total work across all threads.

For large numbers of groups, consider either:

  1. Computing offsets once on CPU and passing them explicitly
  2. Using a parallel prefix sum (scan) to compute cumulative offsets
  3. Documenting this limitation if group counts are expected to be small


// Single kernel that sets up all GEMM parameters.
// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix dimensions,
// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes.
Expand All @@ -464,15 +487,11 @@ __global__ void setup_grouped_gemm_kernel(
int64_t d_first = D_meta.first_dims ? D_meta.first_dims[idx] : D_meta.uniform_first;
int64_t d_last = D_meta.last_dims ? D_meta.last_dims[idx] : D_meta.uniform_last;

// Compute offsets (from array or compute from uniform dims)
int64_t a_offset =
A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last);
int64_t b_offset =
B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last);
int64_t c_offset =
C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last);
int64_t d_offset =
D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last);
// Compute offsets (from explicit array, cumulative from per-tensor dims, or uniform)
int64_t a_offset = compute_grouped_tensor_offset(A_meta, idx);
int64_t b_offset = compute_grouped_tensor_offset(B_meta, idx);
int64_t c_offset = compute_grouped_tensor_offset(C_meta, idx);
int64_t d_offset = compute_grouped_tensor_offset(D_meta, idx);

// Compute data pointers
A_ptrs[idx] = a_base + a_offset * a_elem_size;
Expand All @@ -487,9 +506,8 @@ __global__ void setup_grouped_gemm_kernel(
a_cols[idx] = static_cast<int>(a_first);
b_rows[idx] = static_cast<int>(b_last);
b_cols[idx] = static_cast<int>(b_first);
// For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N).
d_rows[idx] = static_cast<int>(d_first);
d_cols[idx] = static_cast<int>(d_last);
d_rows[idx] = static_cast<int>(d_last);
d_cols[idx] = static_cast<int>(d_first);

// Fill alpha/beta pointers (per-matrix)
alpha_ptrs[idx] = alpha_ptr + idx;
Expand Down Expand Up @@ -535,6 +553,10 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) {

} // namespace

// Public entry point: required byte size of the grouped-GEMM setup workspace for a
// group of `num_tensors` GEMMs. Thin forwarder to the internal (anonymous-namespace)
// helper so the header-declared symbol stays stable.
size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors) {
  const size_t required_bytes = grouped_gemm_setup_workspace_size(num_tensors);
  return required_bytes;
}

void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
const NVTETensor beta, NVTETensor workspace_setup,
Expand Down Expand Up @@ -642,4 +664,27 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT
CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer.");
}

// Stub compiled when cuBLAS is too old for grouped GEMM support. The symbol must
// still exist so callers link; it always raises with an actionable message.
// Fixes: the error previously named `nvte_grouped_gemm` instead of this function,
// and the unused parameter triggered -Wunused-parameter.
size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors) {
  (void)num_tensors;  // unused: this build cannot compute a workspace size
  NVTE_ERROR("nvte_grouped_gemm_setup_workspace_size requires cuBLAS 13.2+, but compile-time "
             "cuBLAS version is ",
             CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer.");
  return 0;  // unreachable; keeps compilers satisfied about the return path
}

#endif // CUBLAS_VERSION >= 130200

namespace {

// Element-wise widening copy: dst[i] = (int64_t)src[i] for i in [0, n).
// Launch with a 1-D grid covering at least n threads; out-of-range threads return.
__global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, size_t n) {
  // Promote blockIdx.x to size_t BEFORE multiplying: blockIdx.x * blockDim.x is an
  // unsigned-int product and can silently wrap for very large n, even though the
  // result is assigned to a size_t.
  const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < n) dst[idx] = static_cast<int64_t>(src[idx]);
}

} // namespace

// Host launcher: asynchronously widen a device array of int32 values to int64 on
// `stream`. No-op for n == 0 (avoids a zero-block launch).
//
// \param[in]  src    Device pointer to the source int32 array.
// \param[out] dst    Device pointer to the destination int64 array.
// \param[in]  n      Number of elements to convert.
// \param[in]  stream CUDA stream the kernel is enqueued on.
void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) {
  if (n == 0) return;
  constexpr int kThreadsPerBlock = 256;
  const int num_blocks = static_cast<int>((n + kThreadsPerBlock - 1) / kThreadsPerBlock);
  convert_int32_to_int64_kernel<<<num_blocks, kThreadsPerBlock, 0, stream>>>(src, dst, n);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
25 changes: 25 additions & 0 deletions transformer_engine/common/include/transformer_engine/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,31 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
* - Shape compatibility: if transa=false, transb=false:
* - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i])
*/
/*! \brief Return the required size in bytes for the setup workspace of grouped GEMM.
*
* The setup workspace stores pointer arrays and per-matrix dimension arrays used
* by the grouped GEMM kernel. Its size depends only on the number of tensors (GEMMs)
* in the group and is independent of matrix dimensions.
*
* Pass the result as the size of the workspace_setup tensor in nvte_grouped_gemm.
*
* \param[in] num_tensors Number of tensors (GEMMs) in the group.
* \return Required size in bytes for workspace_setup.
*/
size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors);

/*! \brief Convert a device array of int32 values to int64 values.
*
* Useful for preparing group_sizes for nvte_grouped_gemm when the caller
* holds int32 sizes and needs int64 values on the device.
*
* \param[in] src Device pointer to source int32 array.
* \param[out] dst Device pointer to destination int64 array.
* \param[in] n Number of elements.
* \param[in] stream CUDA stream.
*/
void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream);

void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
const NVTETensor beta, NVTETensor workspace_setup,
Expand Down
Loading
Loading