From 500113dd038f2bb01ef9371b2416f3bb644ac6e4 Mon Sep 17 00:00:00 2001
From: Zhongbo Zhu
Date: Mon, 5 Jan 2026 23:32:23 -0800
Subject: [PATCH 1/3] fix

Signed-off-by: Zhongbo Zhu
---
 transformer_engine/common/common.h            |  4 ++-
 ...cast_col_hadamard_transform_cast_fusion.cu | 29 +++++++++++++------
 .../transformer_engine/transformer_engine.h   |  8 +++++
 .../common/transformer_engine.cpp             |  6 ++++
 .../pytorch/csrc/extensions/cast.cpp          |  7 +++++
 5 files changed, 44 insertions(+), 10 deletions(-)

diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index 0bc9536844..c1b7a5accb 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -395,6 +395,7 @@ struct QuantizationConfig {
   bool nvfp4_2d_quantization = false;
   bool stochastic_rounding = false;
   bool use_fast_math = false;
+  NVTETensor tile_scheduler_workspace = nullptr;
 
   static constexpr size_t attr_sizes[] = {
       sizeof(bool),        // force_pow_2_scales
@@ -404,7 +405,8 @@
       sizeof(NVTETensor),  // rng_seed and offset
       sizeof(bool),        // nvfp4_2d_quantization
       sizeof(bool),        // stochastic_rounding
-      sizeof(bool)         // use_fast_math
+      sizeof(bool),        // use_fast_math
+      sizeof(NVTETensor)   // tile_scheduler_workspace
   };
 };

diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
index 8b077f6f1f..cc0195042f 100644
--- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
+++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
@@ -1125,8 +1125,9 @@ template (tile_scheduler_workspace), 0,
+                                  sizeof(uint32_t), stream));
 
   // Launch kernel
   cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream};
@@ -1308,8 +1308,6 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz
                                       tile_scheduler_workspace, mma, rng_state);
   NVTE_CHECK_CUDA(cudaGetLastError());
   NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed.");
-
-  NVTE_CHECK_CUDA(cudaFreeAsync(tile_scheduler_workspace, stream));
 }
 
 }  // namespace
@@ -1399,6 +1397,17 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector(rng_state_tensor.data.dptr);
   }
 
+  uint32_t *tile_scheduler_workspace = nullptr;
+  NVTE_CHECK(quant_config.tile_scheduler_workspace != nullptr,
+             "Tile scheduler workspace must be provided.");
+  Tensor &tile_scheduler_workspace_tensor =
+      *convertNVTETensorCheck(quant_config.tile_scheduler_workspace);
+  NVTE_CHECK(tile_scheduler_workspace_tensor.dtype() == DType::kInt32 &&
+                 tile_scheduler_workspace_tensor.data.shape == std::vector{1},
+             "Tile scheduler workspace must be a tensor with shape [1] and dtype int32.");
+  tile_scheduler_workspace =
+      reinterpret_cast(tile_scheduler_workspace_tensor.data.dptr);
+
   // Template arguments
   using TA = cute::bfloat16_t;
   using TB = cute::bfloat16_t;
@@ -1461,7 +1470,9 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector(rowwise_data_base_ptr),
       /*SFA=*/reinterpret_cast(rowwise_scale_inv_base_ptr),
       /*args=*/kernel_args,
-      /*rng_state=*/rng_state, /*sm_count=*/sm_count,
+      /*rng_state=*/rng_state,
+      /*tile_scheduler_workspace=*/tile_scheduler_workspace,
+      /*sm_count=*/sm_count,
       /*stream=*/stream, /*k_tile_size=*/k_tile_size);
   } else {
     NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=",
diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h
index 7fc9d78980..b67ddba802 100644
--- a/transformer_engine/common/include/transformer_engine/transformer_engine.h
+++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h
@@ -343,6 +343,8 @@ enum NVTEQuantizationConfigAttribute {
    * inconsistently between kernels. */
   kNVTEQuantizationConfigUseFastMath = 7,
+  /*! Tile scheduler workspace (NVTETensor with 1 uint32_t element) */
+  kNVTEQuantizationConfigTileSchedulerWorkspace = 8,
   kNVTEQuantizationConfigNumAttributes
 };
@@ -1009,6 +1011,12 @@ class QuantizationConfigWrapper {
                                             &use_fast_math, sizeof(bool));
   }
 
+  /*! \brief Set tile scheduler workspace */
+  void set_tile_scheduler_workspace(NVTETensor tile_scheduler_workspace) {
+    nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigTileSchedulerWorkspace,
+                                           &tile_scheduler_workspace, sizeof(NVTETensor));
+  }
+
  private:
   /*! \brief Wrapped NVTEQuantizationConfig. */
   NVTEQuantizationConfig config_ = nullptr;

diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp
index 370d9723cf..b15280ad34 100644
--- a/transformer_engine/common/transformer_engine.cpp
+++ b/transformer_engine/common/transformer_engine.cpp
@@ -902,6 +902,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
     case kNVTEQuantizationConfigUseFastMath:
       std::memcpy(buf, &config_.use_fast_math, attr_size);
       break;
+    case kNVTEQuantizationConfigTileSchedulerWorkspace:
+      std::memcpy(buf, &config_.tile_scheduler_workspace, attr_size);
+      break;
     default:
       NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")");
   }
@@ -949,6 +952,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
     case kNVTEQuantizationConfigUseFastMath:
       std::memcpy(&config_.use_fast_math, buf, attr_size);
       break;
+    case kNVTEQuantizationConfigTileSchedulerWorkspace:
+      std::memcpy(&config_.tile_scheduler_workspace, buf, attr_size);
+      break;
     default:
       NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")");
   }

diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 3bbc99b444..101448287b 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -872,6 +872,13 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
   auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer.rht_matrix);
 
   if (all_aligned_token_dim) {
+    // allocate a tile scheduler workspace
+    auto tile_scheduler_workspace_torch =
+        at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32));
+    auto nvte_tile_scheduler_workspace =
+        makeTransformerEngineTensor(tile_scheduler_workspace_torch);
+    // assign the workspace tensor
+    quant_config_list[0].set_tile_scheduler_workspace(nvte_tile_scheduler_workspace.data());
     // call the fully-fused grouped kernel for rowwise quantization & colwise RHT quantization transpose
     nvte_group_hadamard_transform_cast_fusion(
         input.data(), reinterpret_cast(nvte_tensor_output_list.data()),
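For readers following the series: the sketch below illustrates the calling convention that PATCH 1/3 establishes, where the caller owns a 1-element int32 workspace tensor and passes it to the fused kernel through the quantization config (this interface is reworked in PATCH 2/3). The function name and parameters in the sketch are illustrative only; the two entry points used, nvte_set_quantization_config_attribute with the new kNVTEQuantizationConfigTileSchedulerWorkspace attribute and nvte_group_hadamard_transform_cast_fusion, are the ones the patch touches, and the launcher zeroes the workspace itself (the cudaMemsetAsync added above), so the caller does not need to clear it.

// Illustrative sketch, not part of the series: PATCH 1/3 calling convention.
#include <cuda_runtime.h>

#include <transformer_engine/hadamard_transform.h>
#include <transformer_engine/transformer_engine.h>

void launch_with_config_workspace(const NVTETensor input, NVTETensor *outputs,
                                  const NVTETensor rht_matrix, const size_t *split_sections,
                                  size_t num_tensors, NVTEQuantizationConfig quant_config,
                                  NVTETensor workspace_int32_shape1,  // int32 tensor of shape [1]
                                  cudaStream_t stream) {
  // Attach the caller-owned workspace tensor to the config via the attribute added in PATCH 1/3.
  nvte_set_quantization_config_attribute(quant_config,
                                         kNVTEQuantizationConfigTileSchedulerWorkspace,
                                         &workspace_int32_shape1, sizeof(NVTETensor));
  // The fused grouped kernel reads the workspace from the config; no per-call
  // cudaMallocAsync/cudaFreeAsync is needed anymore (that is what the patch removes).
  nvte_group_hadamard_transform_cast_fusion(input, outputs, rht_matrix, split_sections,
                                            num_tensors, quant_config, stream);
}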
From 80543a6f357e3551e05775a794e206d9aa84bb1c Mon Sep 17 00:00:00 2001
From: Zhongbo Zhu
Date: Tue, 6 Jan 2026 15:10:42 -0800
Subject: [PATCH 2/3] resolve review comments

Signed-off-by: Zhongbo Zhu
---
 transformer_engine/common/common.h            |  4 +---
 ...cast_col_hadamard_transform_cast_fusion.cu | 22 +++++++++----------
 .../transformer_engine/hadamard_transform.h   |  3 ++-
 .../transformer_engine/transformer_engine.h   |  8 -------
 .../common/transformer_engine.cpp             |  6 -----
 .../pytorch/csrc/extensions/cast.cpp          |  6 ++---
 6 files changed, 16 insertions(+), 33 deletions(-)

diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index c1b7a5accb..0bc9536844 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -395,7 +395,6 @@ struct QuantizationConfig {
   bool nvfp4_2d_quantization = false;
   bool stochastic_rounding = false;
   bool use_fast_math = false;
-  NVTETensor tile_scheduler_workspace = nullptr;
 
   static constexpr size_t attr_sizes[] = {
       sizeof(bool),        // force_pow_2_scales
@@ -405,8 +404,7 @@
       sizeof(NVTETensor),  // rng_seed and offset
       sizeof(bool),        // nvfp4_2d_quantization
       sizeof(bool),        // stochastic_rounding
-      sizeof(bool),        // use_fast_math
-      sizeof(NVTETensor)   // tile_scheduler_workspace
+      sizeof(bool)         // use_fast_math
   };
 };

diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
index cc0195042f..1ef1f81e82 100644
--- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
+++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
@@ -1316,7 +1316,8 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz
 void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector &output_list,
                                           const size_t *split_sections, size_t num_tensors,
                                           const Tensor &hadamard_matrix_,
-                                          QuantizationConfig &quant_config, cudaStream_t stream) {
+                                          QuantizationConfig &quant_config, Tensor &quant_workspace,
+                                          cudaStream_t stream) {
   NVTE_API_CALL(group_hadamard_transform_cast_fusion);
 
   using transformer_engine::detail::kMaxTensorsPerKernel;
@@ -1398,15 +1399,10 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector{1},
-             "Tile scheduler workspace must be a tensor with shape [1] and dtype int32.");
-  tile_scheduler_workspace =
-      reinterpret_cast(tile_scheduler_workspace_tensor.data.dptr);
+  NVTE_CHECK(quant_workspace.data.dptr != nullptr, "Quantization workspace must be provided.");
+  NVTE_CHECK(quant_workspace.data.buffer_size_bytes() >= sizeof(uint32_t),
+             "Quantization workspace must be at least 4 bytes.");
+  tile_scheduler_workspace = reinterpret_cast(quant_workspace.data.dptr);
 
   // Template arguments
   using TA = cute::bfloat16_t;
   using TB = cute::bfloat16_t;
@@ -1489,7 +1485,7 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
                                                const size_t *split_sections,
                                                const size_t num_tensors,
                                                const NVTEQuantizationConfig quant_config,
-                                               cudaStream_t stream) {
+                                               NVTETensor quant_workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_hadamard_transform_cast_fusion);
   using namespace transformer_engine;
   NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0.");
@@ -1500,6 +1496,8 @@
     output_list[i] = convertNVTETensorCheck(outputs[i]);
   }
 
+  Tensor *quant_workspace_tensor = convertNVTETensorCheck(quant_workspace);
+
   QuantizationConfig quant_config_cpp;
   if (quant_config != nullptr) {
     quant_config_cpp = *reinterpret_cast(quant_config);
@@ -1508,5 +1506,5 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
   // Call the multi-tensor Hadamard transform amax implementation.
   group_hadamard_transform_cast_fusion(*input_tensor, output_list, split_sections, num_tensors,
                                        *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp,
-                                       stream);
+                                       *quant_workspace_tensor, stream);
 }

diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h
index b6e9719aad..7f3f281473 100644
--- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h
+++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h
@@ -115,13 +115,14 @@ void nvte_group_hadamard_transform_cast_fusion_columnwise(
  *  \param[in]  split_sections   Array specifying splits in dimension 0 for each output tensor.
  *  \param[in]  num_tensors      Number of output tensors, must be > 0.
  *  \param[in]  quant_config     Quantization configuration.
+ *  \param[in]  quant_workspace  Quantization workspace, a device buffer that can be opaque
  *  \param[in]  stream           CUDA stream used for the operation.
  */
 void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs,
                                                const NVTETensor hadamard_matrix,
                                                const size_t* split_sections, size_t num_tensors,
                                                const NVTEQuantizationConfig quant_config,
-                                               cudaStream_t stream);
+                                               NVTETensor quant_workspace, cudaStream_t stream);

diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h
index b67ddba802..7fc9d78980 100644
--- a/transformer_engine/common/include/transformer_engine/transformer_engine.h
+++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h
@@ -343,8 +343,6 @@ enum NVTEQuantizationConfigAttribute {
    * inconsistently between kernels. */
   kNVTEQuantizationConfigUseFastMath = 7,
-  /*! Tile scheduler workspace (NVTETensor with 1 uint32_t element) */
-  kNVTEQuantizationConfigTileSchedulerWorkspace = 8,
   kNVTEQuantizationConfigNumAttributes
 };
@@ -1011,12 +1009,6 @@ class QuantizationConfigWrapper {
                                             &use_fast_math, sizeof(bool));
   }
 
-  /*! \brief Set tile scheduler workspace */
-  void set_tile_scheduler_workspace(NVTETensor tile_scheduler_workspace) {
-    nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigTileSchedulerWorkspace,
-                                           &tile_scheduler_workspace, sizeof(NVTETensor));
-  }
-
  private:
   /*! \brief Wrapped NVTEQuantizationConfig. */
   NVTEQuantizationConfig config_ = nullptr;

diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp
index b15280ad34..370d9723cf 100644
--- a/transformer_engine/common/transformer_engine.cpp
+++ b/transformer_engine/common/transformer_engine.cpp
@@ -902,9 +902,6 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
     case kNVTEQuantizationConfigUseFastMath:
       std::memcpy(buf, &config_.use_fast_math, attr_size);
       break;
-    case kNVTEQuantizationConfigTileSchedulerWorkspace:
-      std::memcpy(buf, &config_.tile_scheduler_workspace, attr_size);
-      break;
     default:
       NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")");
   }
@@ -952,9 +949,6 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
     case kNVTEQuantizationConfigUseFastMath:
       std::memcpy(&config_.use_fast_math, buf, attr_size);
       break;
-    case kNVTEQuantizationConfigTileSchedulerWorkspace:
-      std::memcpy(&config_.tile_scheduler_workspace, buf, attr_size);
-      break;
     default:
       NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")");
   }

diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 101448287b..60dc26e0a9 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -877,12 +877,12 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
         at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32));
     auto nvte_tile_scheduler_workspace =
         makeTransformerEngineTensor(tile_scheduler_workspace_torch);
-    // assign the workspace tensor
-    quant_config_list[0].set_tile_scheduler_workspace(nvte_tile_scheduler_workspace.data());
     // call the fully-fused grouped kernel for rowwise quantization & colwise RHT quantization transpose
+    // Note that the workspace field can be opaque, this will leave room for furture extensions
     nvte_group_hadamard_transform_cast_fusion(
         input.data(), reinterpret_cast(nvte_tensor_output_list.data()),
-        rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0], stream);
+        rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0],
+        nvte_tile_scheduler_workspace.data(), stream);
   } else {
     // Separate quantization for rowwise usage and columnwise usage
     // Rowwise quantization fusion with grouped version
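After the review changes in PATCH 2/3, the workspace is no longer a quantization-config attribute: it is an explicit NVTETensor argument placed just before the stream, and the only requirement the implementation checks is a device buffer of at least 4 bytes (sizeof(uint32_t)), so a caller may reuse a persistent buffer instead of allocating one per call. A minimal illustrative sketch of the resulting call, with hypothetical function and variable names:

// Illustrative sketch, not part of the series: PATCH 2/3 calling convention.
#include <cuda_runtime.h>

#include <transformer_engine/hadamard_transform.h>

void launch_with_workspace_argument(const NVTETensor input, NVTETensor *outputs,
                                    const NVTETensor rht_matrix, const size_t *split_sections,
                                    size_t num_tensors, const NVTEQuantizationConfig quant_config,
                                    NVTETensor quant_workspace,  // device buffer, >= 4 bytes
                                    cudaStream_t stream) {
  // quant_config no longer carries the tile-scheduler workspace; it is passed positionally.
  nvte_group_hadamard_transform_cast_fusion(input, outputs, rht_matrix, split_sections,
                                            num_tensors, quant_config, quant_workspace, stream);
}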
From eba28af26f54bcba2144aa3ab1cc1650c272735e Mon Sep 17 00:00:00 2001
From: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Date: Tue, 6 Jan 2026 15:44:26 -0800
Subject: [PATCH 3/3] Comment tweaks

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
---
 .../common/include/transformer_engine/hadamard_transform.h | 2 +-
 transformer_engine/pytorch/csrc/extensions/cast.cpp        | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h
index 7f3f281473..13103cc388 100644
--- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h
+++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h
@@ -115,7 +115,7 @@ void nvte_group_hadamard_transform_cast_fusion_columnwise(
  *  \param[in]  split_sections   Array specifying splits in dimension 0 for each output tensor.
  *  \param[in]  num_tensors      Number of output tensors, must be > 0.
  *  \param[in]  quant_config     Quantization configuration.
- *  \param[in]  quant_workspace  Quantization workspace, a device buffer that can be opaque
+ *  \param[in]  quant_workspace  Workspace buffer. Must be at least 4 bytes.
  *  \param[in]  stream           CUDA stream used for the operation.
  */
 void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs,

diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 60dc26e0a9..4e5e5223f7 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -878,7 +878,6 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
     auto nvte_tile_scheduler_workspace =
         makeTransformerEngineTensor(tile_scheduler_workspace_torch);
     // call the fully-fused grouped kernel for rowwise quantization & colwise RHT quantization transpose
-    // Note that the workspace field can be opaque, this will leave room for furture extensions
     nvte_group_hadamard_transform_cast_fusion(
         input.data(), reinterpret_cast(nvte_tensor_output_list.data()),
         rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0],
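For reference, the declaration that the series converges on in hadamard_transform.h, assembled from the hunks above as a reading aid (the earlier, unchanged \param lines are elided here):

/*! ...
 *  \param[in]  quant_config     Quantization configuration.
 *  \param[in]  quant_workspace  Workspace buffer. Must be at least 4 bytes.
 *  \param[in]  stream           CUDA stream used for the operation.
 */
void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs,
                                               const NVTETensor hadamard_matrix,
                                               const size_t* split_sections, size_t num_tensors,
                                               const NVTEQuantizationConfig quant_config,
                                               NVTETensor quant_workspace, cudaStream_t stream);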