[PyTorch] Add Casting-Free FP8-Flow-MoE Blockwise Optimizations #2544
Conversation
Quick question: does this work with MXFP8, or does it only apply to FP8? Thanks.
@rich-junwang, thanks for asking. This PR primarily targets the standard FP8 blockwise recipe, and the current scaling-aware FP8 transpose implementation is specialized for blockwise=128 scaling. Among the five optimizations described in this PR:
1. Add an FP8 row-wise scaling-aware transpose op for the wgrad column-wise data.
2. Support `Float8BlockwiseQTensor` input in `grouped_linear`.
3. `_rowwise_scale_inv` is propagated with a COMPACT layout along the `dispatch → permute → GroupedLinear` path.

Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
Co-authored-by: dantesuu@gmail.com
Co-authored-by: xzhu@zhejianglab.org
Co-authored-by: 123sssmmm@gmail.com
Greptile Summary
This PR implements a casting-free FP8 MoE dataflow optimization by introducing scaling-aware FP8 transpose operations that eliminate unnecessary dequantization/requantization steps.
Key Changes
Architecture Integration
The changes fit well into the existing FP8 infrastructure. The new path activates when blockwise-quantized tensors flow through
Issues Found
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Input
    participant GL as GroupedLinear
    participant FB as Float8BlockwiseQTensor
    participant TK as Triton Kernel
    participant GEMM as GroupedGEMM
    Note over User,GEMM: Forward Pass - FP8 Quantization Before Dispatch
    User->>GL: forward(inp, m_splits)
    GL->>GL: Check if inp is Float8BlockwiseQTensor
    alt Input is Float8BlockwiseQTensor
        GL->>FB: split_scaling_aware_fp8_transpose(m_splits)
        FB->>FB: Convert GEMM_READY → COMPACT format
        FB->>TK: blockwise_scaling_aware_fp8_transpose()
        Note over TK: Row-wise FP8 → Column-wise FP8<br/>via exponent manipulation
        TK->>TK: Extract sign, exp, mantissa
        TK->>TK: Compute k = exp_target - exp_source
        TK->>TK: Adjust exponent: exp_new = exp - k
        TK->>TK: Transpose and write columnwise data
        TK-->>FB: [rowwise_data, rowwise_scale_inv_t,<br/>columnwise_data, columnwise_scale_inv]
        FB-->>GL: List of Float8BlockwiseQTensorStorage
    else Input is standard tensor
        GL->>GL: tex.split_quantize(inp_view, m_splits)
    end
    GL->>GEMM: general_grouped_gemm(weights, inputmats)
    GEMM-->>GL: output
    GL-->>User: result
    Note over User,GEMM: Backward Pass - Scaling-Aware Transpose for Gradients
    User->>GL: backward(grad_output)
    GL->>GL: Check if grad_output is Float8BlockwiseQTensor
    alt grad_output is Float8BlockwiseQTensor
        GL->>FB: split_scaling_aware_fp8_transpose(m_splits)
        FB->>TK: blockwise_scaling_aware_fp8_transpose()
        TK-->>FB: transposed gradients
        FB-->>GL: quantized grad tensors
    else Standard grad_output
        GL->>GL: tex.split_quantize()
    end
    GL->>GEMM: dgrad computation
    GL->>GEMM: wgrad computation
    GEMM-->>GL: gradients
    GL-->>User: dgrad, wgrad
```
Additional Comments (3)
- `transformer_engine/pytorch/triton/blockwise_scaling_aware_fp8_transpose.py`, lines 100-101: logic: potential exponent underflow not handled. When `exp_new < 0` (not just `< 1`), the result could wrap around since `exp` is unsigned; consider clamping to 0 for negative values.
- `transformer_engine/pytorch/tensor/float8_blockwise_tensor.py`, lines 453-456: style: modifying `self._data_format` in place during a method call can cause unexpected side effects if the tensor is reused. Consider creating a new tensor or documenting this mutation clearly.
- `transformer_engine/pytorch/triton/blockwise_scaling_aware_fp8_transpose.py`, lines 94-99: style: verify that the scaling adjustment works correctly when `target_si` and `si` differ by more than 127 in exponent (i.e., orders of magnitude apart). Have you validated this scaling adjustment with extreme scale differences where `target_si` and `si` magnitudes differ significantly?
5 files reviewed, 3 comments
Description
This PR introduces blockwise, scaling-aware FP8 transpose optimizations for FP8 MoE that enable a casting-free, FP8-centric MoE dataflow in TransformerEngine by eliminating unnecessary cast and re-quantization steps, while maintaining numerical stability in existing FP8 training workflows.
This PR is designed to be used in conjunction with PR NVIDIA/Megatron-LM#2764
Further optimizations are introduced via two additional PRs:
Background / Motivation
The design and theoretical background of this PR are described in the paper:
FP8-Flow-MoE: A Casting-Free FP8 Recipe without Double Quantization Error
The following figure illustrates the optimized MoE dataflow and highlights the key optimization points (marked as ①–⑤).
1. FP8 Quantization Before Dispatch (DeepEP → GroupedLinear)
Quantization is performed before DeepEP dispatch, and row-wise FP8 tensors are directly fed into GroupedLinear.
- `dispatch → permute → expert computation` entirely in FP8
- `Float8BlockwiseQTensor` is propagated with a COMPACT layout (for `_rowwise_scale_inv`) along the `dispatch → permute → GroupedLinear` path, avoiding layout-induced `.T.contiguous()` calls and reducing unnecessary memory copies.

(Shown as marker ① in the figure)
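As a rough, non-authoritative illustration of the quantize-before-dispatch step, the following plain-PyTorch sketch shows 1x128 blockwise quantization producing row-wise FP8 data plus a compact per-block scale-inverse. The helper name, `BLOCK`, and `FP8_MAX` are illustrative assumptions, not the TransformerEngine API.

```python
import torch

BLOCK = 128       # 1x128 blockwise scaling, as in the blockwise FP8 recipe
FP8_MAX = 448.0   # largest magnitude representable in float8_e4m3fn

def blockwise_quantize(x: torch.Tensor):
    """Quantize an [M, K] BF16/FP32 tensor to row-wise FP8 with one scale per 1x128 block."""
    M, K = x.shape
    xb = x.view(M, K // BLOCK, BLOCK)
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    scale = FP8_MAX / amax                        # per-block quantization scale
    data = (xb * scale).to(torch.float8_e4m3fn)   # row-wise FP8 payload
    scale_inv = (1.0 / scale).squeeze(-1)         # [M, K // BLOCK], compact layout
    return data.view(M, K), scale_inv

# Quantize once, then let the FP8 data and scale_inv flow through dispatch/permute.
tokens = torch.randn(4, 256, dtype=torch.bfloat16)
fp8_data, rowwise_scale_inv = blockwise_quantize(tokens)
```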
2. Scaling-Aware FP8 Transpose for Wgrad
GroupedLinear requires column-wise (transposed) FP8 data for the wgrad GEMM, while the dispatched tensors are row-wise FP8.
To avoid `dequantize → transpose → requantize`, this PR introduces `scaling_aware_fp8_transpose`, which transposes row-wise-scaled FP8 data directly into column-wise FP8 by adjusting the stored exponents, as sketched below. (Shown as marker ④ in the figure)
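The sketch below is a simplified, bit-level illustration in plain PyTorch (not the Triton kernel) of the exponent-manipulation idea from the sequence diagram above: re-express row-wise-scaled E4M3 bytes under the column-wise target scale by shifting the 4-bit exponent field. It assumes power-of-two scale factors and normal (non-zero, non-subnormal) values; the function and argument names are hypothetical.

```python
import torch

def rescale_e4m3_bits(byte: torch.Tensor, si: torch.Tensor, target_si: torch.Tensor) -> torch.Tensor:
    """Re-express E4M3 bytes quantized with scale_inv `si` under scale_inv `target_si`.

    byte: uint8 view of FP8 data; si / target_si: power-of-two scale_inv factors.
    """
    sign = byte & 0x80
    exp = ((byte >> 3) & 0x0F).to(torch.int32)    # 4-bit exponent field
    mant = byte & 0x07                            # 3-bit mantissa
    # k = exp_target - exp_source, the difference of the scale exponents
    k = (torch.log2(target_si) - torch.log2(si)).round().to(torch.int32)
    exp_new = exp - k                             # so that code' * target_si == code * si
    # Clamp to keep the adjusted exponent inside the 4-bit field
    # (the underflow case flagged in the review comments above).
    exp_new = exp_new.clamp(0, 15).to(torch.uint8)
    return sign | (exp_new << 3) | mant
```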
3. Fused Permute + Padding / Unpermute + Unpadding
We fuse two memory movement operators along the MoE path:
- `permute + pad` in the forward pass
- `unpermute + unpad` in the backward pass

For details of this optimization, please refer to PR #1921.
(Shown as marker ② in the figure)
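For reference, the unfused behavior being replaced looks roughly like the sketch below: one memory pass to permute tokens into expert order and a second pass to pad each expert's chunk to an alignment boundary. The names and the `ALIGN` constant are illustrative; the fused kernels from PR #1921 do this in a single pass (and the reverse in backward).

```python
import torch
import torch.nn.functional as F

ALIGN = 16  # hypothetical row alignment required by the grouped GEMM

def permute_and_pad(tokens, sorted_indices, tokens_per_expert):
    """Unfused reference: gather tokens into expert order, then pad each expert chunk."""
    permuted = tokens.index_select(0, sorted_indices)            # memory pass 1: permute
    chunks = torch.split(permuted, tokens_per_expert.tolist())
    padded = []
    for chunk in chunks:                                         # memory pass 2: pad
        pad_rows = (-chunk.shape[0]) % ALIGN
        padded.append(F.pad(chunk, (0, 0, 0, pad_rows)))
    return torch.cat(padded, dim=0)
```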
4. Fused Activation + Quantization
Activation and FP8 quantization are fused into a single kernel that produces FP8 outputs directly, while enabling FP8 persistence.
(Shown as marker ③ in the figure)
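A plain-PyTorch sketch of the two operations being fused, under the same illustrative 1x128 blockwise scheme as earlier: SwiGLU activation followed by blockwise FP8 quantization of its output. The fused kernel in this PR performs both in one pass and emits FP8 directly; the gate/up ordering here is an assumption.

```python
import torch
import torch.nn.functional as F

BLOCK, FP8_MAX = 128, 448.0  # illustrative 1x128 blocks; E4M3 max magnitude

def swiglu_then_quantize(x: torch.Tensor):
    """Unfused reference: SwiGLU over [M, 2H], then 1x128 blockwise FP8 quantization."""
    gate, up = x.chunk(2, dim=-1)
    act = F.silu(gate) * up                                  # high-precision intermediate
    M, H = act.shape
    blocks = act.view(M, H // BLOCK, BLOCK)
    scale = FP8_MAX / blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    fp8 = (blocks * scale).to(torch.float8_e4m3fn).view(M, H)
    return fp8, (1.0 / scale).squeeze(-1)                    # FP8 output + per-block scale_inv
```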
5. Fine-Grained Recompute at the moe_expert Level
Because the entire `dispatch → permute → GroupedLinear` path stays in FP8, we enable fine-grained recomputation at the `moe_expert` level rather than the whole `moe` level (a minimal sketch follows below).
(Shown as marker ⑤ in the figure)
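A minimal sketch of what `moe_expert`-level recompute can look like with `torch.utils.checkpoint`, assuming hypothetical `router`, `dispatch`, `expert_forward`, and `combine` callables: only the expert computation is replayed in backward, and its saved inputs are the compact FP8 tensors rather than BF16 activations.

```python
import torch
from torch.utils.checkpoint import checkpoint

def moe_layer(hidden, router, dispatch, expert_forward, combine):
    """Hypothetical MoE layer with fine-grained recompute of the expert computation."""
    probs, indices = router(hidden)
    fp8_tokens, handle = dispatch(hidden, indices)     # FP8 tensors cross the all-to-all
    # Recompute only the expert GEMMs in backward; their saved inputs stay in FP8.
    expert_out = checkpoint(expert_forward, fp8_tokens, use_reentrant=False)
    return combine(expert_out, probs, handle)
```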
Performance Results
We evaluate FP8-Flow-MoE on DeepSeek-V3 (671B) to validate scalability and robustness under realistic large-scale training conditions.
Throughput
Measured throughput (TGS, tokens/GPU/s) under different expert parallelism (EP) on DeepSeek-V3 (671B):
- vs. BF16: +6% (EP8), +8% (EP16), +16% (EP32)
- vs. TransformerEngine blockwise FP8 recipe: +3% (EP8), +8% (EP16), up to +21% (EP32)
Memory Efficiency
With AC = selective checkpointing and recompute-modules = moe_expert:
Numerical Accuracy
We trained for more than 200B tokens. The loss deviation of FP8-Flow-MoE stays within 0.19% of both the BF16 baseline and the TransformerEngine blockwise FP8 recipe, with no observed instability or divergence.
Limitations
Type of change
Changes
Please list the changes introduced in this PR:
- `fused_bias_swiglu.py` and `fused_weighted_swiglu_quant.py`
- `fused_a2a.py`
- `Float8BlockwiseQTensor` inputs in `grouped_linear.py`
- `scaling_aware_fp8_transpose` operator in `triton/blockwise_scaling_aware_fp8_transpose.py`

Checklist: