Add NVTE_BACKWARD_MODE=default|unquant|dequant #2644

zianglih wants to merge 46 commits into NVIDIA:main

Conversation
Greptile Summary

This PR introduces NVTE_BACKWARD_MODE=default|unquant|dequant.

Key changes:

Minor issues found:

The implementation is solid with extensive test coverage. Most critical issues were already identified in previous review threads.

Confidence Score: 4/5

Important Files Changed

Last reviewed commit: 4f87a26
I'll work on potential unit test breakage.
    # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
    # requires column-wise usage
-   if ctx.grad_output_quantizer is not None:
+   if ctx.grad_output_quantizer is not None and use_fp8_bwd:
This line seems redundant, since you already skip the quantization step in base.py's grad_output_preprocess?
    not ctx.use_bias
    and not ctx.requires_wgrad
    and ctx.grad_output_quantizer is not None
    and use_fp8_bwd
    recipe = cls.get_fp8_recipe()
    if recipe is not None and recipe.delayed():
        # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
        return False
Maybe it's better to assert an error for delayed scaling? Okay with both.

I agree. If the user specifies an unsupported combination, it's better to fail loudly than to silently disobey their instructions.
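The "fail loudly" approach could look roughly like the sketch below. This is a minimal illustration using a toy stand-in for the delayed-scaling recipe class, not Transformer Engine's actual implementation; the class body and error message are assumptions for illustration.

```python
import os
from dataclasses import dataclass


@dataclass
class DelayedScaling:
    """Toy stand-in for the delayed-scaling recipe discussed above."""

    margin: int = 0

    def __post_init__(self):
        # Fail loudly instead of silently ignoring the user's request:
        # delayed scaling has no unquant/dequant backward path.
        mode = os.getenv("NVTE_BACKWARD_MODE", "default")
        assert mode == "default", (
            "NVTE_BACKWARD_MODE=unquant/dequant is not supported with delayed "
            "scaling; unset NVTE_BACKWARD_MODE or use a different recipe."
        )
```

With this, constructing the recipe under an unsupported mode raises immediately at setup time rather than producing results that quietly diverge from what the user asked for.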
    # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
    # requires column-wise usage
-   if ctx.grad_output_quantizer is not None:
+   if ctx.grad_output_quantizer is not None and use_fp8_bwd:
This seems redundant too, if we skip quantization in grad_output_preprocess.
for more information, see https://pre-commit.ci
Signed-off-by: Ziang Li <ziangli@umich.edu>
I have finished the refactor. All new unit tests passed on B200. No new failing tests in the entire PyTorch test suite compared to main.
    assert backward_mode == "default", (
        "NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP."
        "Replace LayerNormMLP with LayerNormLinear + Linear to enable unquant/dequant backward."
    )

Missing space between sentences in the assertion message: the adjacent string literals concatenate with no separator. Suggested change:

    assert backward_mode == "default", (
        "NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP. "
        "Replace LayerNormMLP with LayerNormLinear + Linear to enable unquant/dequant backward."
    )
    if backward_mode == "unquant":
        # Note, NVTE_BACKWARD_MODE=unquant is ignored when delayed scaling is used.

Misleading comment: it says the mode is "ignored", but DelayedScaling.__post_init__ will actually raise an assertion error. Consider rephrasing to: "Note: the DelayedScaling recipe does not support NVTE_BACKWARD_MODE=unquant (raises an assertion error)."
    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
    if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()):
        input_quantizer.optimize_for_gemm = False
        if grad_output_quantizer is not None:
            grad_output_quantizer.optimize_for_gemm = False

Missing a comment explaining why optimize_for_gemm must be disabled for MXFP8/NVFP4 in dequant mode; add a brief explanation of the constraint.
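One way to make the constraint explicit is a comment at the gating site. The sketch below uses toy classes, and the stated rationale (that GEMM-optimized block-scaled layouts cannot be directly dequantized for the high-precision backward) is an assumption for illustration, not confirmed Transformer Engine internals:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Quantizer:
    """Toy stand-in for a TE quantizer with a GEMM-layout optimization flag."""

    optimize_for_gemm: bool = True


def configure_dequant_backward(
    input_q: Quantizer,
    grad_q: Optional[Quantizer],
    recipe_is_block_scaled: bool,
) -> None:
    # In dequant mode the backward pass dequantizes the stored fprop tensors
    # back to high precision. Block-scaled formats (MXFP8/NVFP4) laid out for
    # direct GEMM consumption are assumed here not to support that direct
    # dequantize, so the GEMM-layout optimization must be turned off.
    if recipe_is_block_scaled:
        input_q.optimize_for_gemm = False
        if grad_q is not None:
            grad_q.optimize_for_gemm = False
```

Even if the exact reason differs, placing it in a comment next to the flag assignment is what the review is asking for.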
Description

Add an NVTE_BACKWARD_MODE=default|unquant|dequant env var:

default: existing default quantization behavior
unquant: quantized fprop + high-precision wgrad & dgrad using the unquantized activation and weight
dequant: quantized fprop + high-precision wgrad & dgrad using the activation and weight dequantized directly from the fprop quantized values

(An earlier revision of this PR used an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high-precision wgrad & dgrad.)

Type of change
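Reading and validating the new env var could be sketched as follows. The helper name and error message are hypothetical, not the PR's actual code:

```python
import os

VALID_MODES = ("default", "unquant", "dequant")


def get_backward_mode() -> str:
    """Read NVTE_BACKWARD_MODE, defaulting to 'default' and rejecting typos."""
    mode = os.getenv("NVTE_BACKWARD_MODE", "default")
    if mode not in VALID_MODES:
        raise ValueError(
            f"NVTE_BACKWARD_MODE must be one of {VALID_MODES}, got {mode!r}"
        )
    return mode
```

In practice a user would set the variable before launching training, e.g. `NVTE_BACKWARD_MODE=dequant python train.py`, and the modules would consult the parsed mode when deciding whether to run wgrad/dgrad GEMMs in high precision.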
Changes
Please list the changes introduced in this PR:
Checklist: