
Add NVTE_BACKWARD_MODE=default|unquant|dequant #2644

Open
zianglih wants to merge 46 commits into NVIDIA:main from zianglih:keep-bwd

Conversation


@zianglih zianglih commented Feb 3, 2026

Description


Add an NVTE_BACKWARD_MODE=default|unquant|dequant env var (initially named NVTE_KEEP_BACKWARD_UNQUANTIZED) for quantized fprop with high-precision wgrad & dgrad:

  • default: existing default quantization behavior
  • unquant: quantized fprop + high precision wgrad & dgrad using unquantized activation and weight
  • dequant: quantized fprop + high precision wgrad & dgrad using activation and weight dequantized directly from the fprop quantized values
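The mode selection above can be illustrated with a minimal sketch of how such an env var might be parsed and validated. This is a hypothetical helper, not Transformer Engine's actual code; the function name and error message are assumptions:

```python
import os

# Hypothetical helper illustrating the three-way mode switch described above.
# The real Transformer Engine implementation may parse/validate differently.
_VALID_MODES = ("default", "unquant", "dequant")

def get_backward_mode() -> str:
    """Read NVTE_BACKWARD_MODE, defaulting to 'default', and reject typos."""
    mode = os.getenv("NVTE_BACKWARD_MODE", "default")
    if mode not in _VALID_MODES:
        raise ValueError(
            f"NVTE_BACKWARD_MODE={mode!r} is invalid; expected one of {_VALID_MODES}"
        )
    return mode
```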

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring


Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


greptile-apps bot commented Feb 3, 2026

Greptile Summary

This PR introduces the NVTE_BACKWARD_MODE environment variable to control backward-pass precision in FP8 training. The implementation adds three modes: default (quantized backward), unquant (saves original high-precision operands), and dequant (dequantizes saved quantized operands during backward).

Key changes:

  • Added backward_mode field to all recipe classes with proper validation
  • DelayedScaling only supports default mode (enforced via assertion)
  • LayerNormMLP only supports default mode (enforced via assertion)
  • Modified saving logic to preserve original vs quantized tensors based on mode
  • Comprehensive test coverage added with 1446 lines of tests
  • Properly disables FP8 quantization flags and quantizers for non-default modes
  • Special handling for empty grouped splits in dequant mode
  • Fuser tracks backward_mode to trigger rebuilds when mode changes
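The mode-dependent tensor-saving behavior listed above can be sketched roughly as follows. All names here are invented for illustration, not TE's actual API:

```python
# Illustrative sketch of mode-dependent save-for-backward, per the summary
# above. Function and argument names are hypothetical, not TE's actual API.
def tensors_to_save(mode, inp_hp, weight_hp, inp_q, weight_q):
    """Pick which operands autograd should stash for wgrad/dgrad."""
    if mode == "unquant":
        # Save the original high-precision operands for the backward GEMMs.
        return inp_hp, weight_hp
    # "default" runs quantized backward GEMMs on the saved quantized operands;
    # "dequant" saves the same quantized operands but dequantizes them to
    # high precision during the backward pass.
    return inp_q, weight_q
```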

Minor issues found:

  • Formatting: Missing space in LayerNormMLP assertion message
  • Documentation: Misleading comment about DelayedScaling compatibility
  • Documentation: Missing explanation for optimize_for_gemm disabling with MXFP8/NVFP4

The implementation is solid with extensive test coverage. Most critical issues were already identified in previous review threads.

Confidence Score: 4/5

  • Safe to merge with minor documentation/style improvements recommended
  • Comprehensive implementation with extensive test coverage (1446 lines). Core logic is sound with proper tensor saving/restoration, quantization control, and edge case handling. Only minor style and documentation issues found. Previous review threads already identified and discussed the major design constraints (DelayedScaling/LayerNormMLP limitations).
  • No files require special attention - the minor style/documentation issues can be addressed in follow-up or accepted as-is

Important Files Changed

Filename Overview
transformer_engine/common/recipe/__init__.py Adds backward_mode field to all recipe classes with validation; DelayedScaling enforces default mode only
transformer_engine/pytorch/ops/basic/basic_linear.py Implements backward_mode logic for saving original vs quantized tensors; disables columnwise usage for non-default modes
transformer_engine/pytorch/module/linear.py Forces save_original_input=True for unquant mode; dequantizes operands in backward pass; comment about "ignored" is misleading
transformer_engine/pytorch/module/layernorm_linear.py Saves high-precision ln_out_hp for unquant mode; dequantizes in backward for dequant mode; properly handles saved tensors
transformer_engine/pytorch/module/grouped_linear.py Implements dequant with special handling for empty splits; forces save_original_input=True for unquant mode
tests/pytorch/test_backward_mode.py Comprehensive test coverage for backward modes across multiple modules, recipes, and edge cases; includes dequant reference validation

Last reviewed commit: 4f87a26

greptile-apps bot reviewed: 17 files reviewed, 3 comments

greptile-apps bot reviewed: 6 files reviewed, no comments


zianglih commented Feb 3, 2026

I'll work on potential unit test breakage.

greptile-apps bot reviewed: 5 files reviewed, no comments
greptile-apps bot reviewed: 5 files reviewed, 4 comments
greptile-apps bot reviewed: 4 files reviewed, 1 comment
greptile-apps bot reviewed: 5 files reviewed, 1 comment
greptile-apps bot reviewed: 5 files reviewed, 2 comments
greptile-apps bot reviewed: 4 files reviewed, no comments
greptile-apps bot reviewed: 5 files reviewed, no comments
greptile-apps bot reviewed: 5 files reviewed, no comments
greptile-apps bot reviewed: 5 files reviewed, no comments

# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.grad_output_quantizer is not None:
if ctx.grad_output_quantizer is not None and use_fp8_bwd:
Collaborator:
this line seems redundant since you already skip the quantization step in base.py grad_output_preprocess?

not ctx.use_bias
and not ctx.requires_wgrad
and ctx.grad_output_quantizer is not None
and use_fp8_bwd
Collaborator:
same comment as above

recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
# Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
return False
Collaborator:
Maybe it's better to assert an error for delayed scaling? Okay with both.

Collaborator:
I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to secretly disobey their instructions.
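The fail-loudly behavior the collaborators suggest could look like the sketch below. This is an assumption about how the check might be written, not the PR's actual code; the function and argument names are invented:

```python
# Hypothetical sketch of the reviewers' suggestion: raise on an unsupported
# recipe/mode combination instead of silently falling back to default.
def check_backward_mode(recipe_is_delayed: bool, backward_mode: str) -> None:
    if recipe_is_delayed and backward_mode != "default":
        raise ValueError(
            f"NVTE_BACKWARD_MODE={backward_mode} is not supported with "
            "delayed scaling; unset it or use a different recipe."
        )
```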

# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.grad_output_quantizer is not None:
if ctx.grad_output_quantizer is not None and use_fp8_bwd:
Collaborator:
this seems redundant too if we skip quant in grad_output_preprocess

pre-commit-ci bot and others added 21 commits February 24, 2026 15:28
Signed-off-by: Ziang Li <ziangli@umich.edu>
zianglih (Author) commented:

I have finished the refactor. All new unit tests pass on B200, and there are no new failing tests in the full PyTorch test suite compared to main.

+ python3 -m pytest --tb=auto --junitxml=/sgl-workspace/logs/te_pr_full_pytorch_unittest_20260224_185629/xml/pytest_test_backward_mode.xml /root/TransformerEngine-pr/tests/pytorch/test_backward_mode.py
============================= test session starts ==============================
platform linux -- Python 3.12.3, pytest-8.2.1, pluggy-1.6.0
rootdir: /root/TransformerEngine-pr
configfile: pyproject.toml
plugins: hydra-core-1.3.2, anyio-4.12.1, typeguard-4.4.4
collected 1034 items

tests/pytorch/test_backward_mode.py ........ssssssss.sss...s...s..ss...s [  3%]
...s..ss...s...s..ss...s...s..ssssssssss.sss...s...s..ss...s...s..ss...s [ 10%]
...s..ss...s...s..ssssssssss.sss..........s...........s...........s..... [ 17%]
......s.ssssssss.sss..........s...........s...........s...........s.ssss [ 24%]
ssss.sss...s...s..ss...s...s..ss...s...s..ss...s...s..ssssssssss.sss...s [ 31%]
...s..ss...s...s..ss...s...s..ss...s...s..ssssssssss.sss...s...s..ss...s [ 38%]
...s..ss...s...s..ss...s...s..ssssssssss.sss...s...s..ss...s...s..ss...s [ 45%]
...s..ss...s...s..ssssssssss.sss..........s...........s...........s..... [ 52%]
......s.ssssssss.sss..........s...........s...........s...........s.ssss [ 59%]
ssss.sss...s...s..ss...s...s..ss...s...s..ss...s...s..ssssssssss.sss...s [ 66%]
...s..ss...s...s..ss...s...s..ss...s...s..ss...s...s...s...s.sss.sss...s [ 73%]
...s...s...s.sss.sss...s...s...s...s.sss.sss...s...s...s...s.sss.sss...s [ 80%]
...s...s...s.sss.sss...s...s...s...s.sss.sss.sss.sss..ss..ss.sss.sss..s. [ 87%]
..s..sss.sss..ss..ss.sss.sss..s...s...ss..ss..s...s...ss..ss..s...s...s. [ 94%]
..s...s...s...s...s...ss..s...ss..ss..s...ss.....s.......s....           [100%]

- generated xml file: /sgl-workspace/logs/te_pr_full_pytorch_unittest_20260224_185629/xml/pytest_test_backward_mode.xml -
======================= 632 passed, 402 skipped in 8.56s =======================

@zianglih zianglih marked this pull request as ready for review February 25, 2026 03:55
greptile-apps bot reviewed: 17 files reviewed, no comments

Signed-off-by: Ziang Li <ziangli@umich.edu>
greptile-apps bot reviewed: 17 files reviewed, no comments

Signed-off-by: Ziang Li <ziangli@umich.edu>
Comment on lines +242 to +244
assert backward_mode == "default", (
"NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP."
"Replace LayerNormMLP with LayerNormLinear + Linear to enable unquant/dequant backward."
greptile-apps bot:
missing space between sentences in assertion message

Suggested change
assert backward_mode == "default", (
"NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP."
"Replace LayerNormMLP with LayerNormLinear + Linear to enable unquant/dequant backward."
assert backward_mode == "default", (
"NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP. "
"Replace LayerNormMLP with LayerNormLinear + Linear to enable unquant/dequant backward."
)

Comment on lines +136 to +137
if backward_mode == "unquant":
# Note, NVTE_BACKWARD_MODE=unquant is ignored when delayed scaling is used.
greptile-apps bot:
misleading comment - suggests the mode is "ignored" but DelayedScaling.__post_init__ will actually raise an assertion error

consider rephrasing to: "Note: DelayedScaling recipe does not support NVTE_BACKWARD_MODE=unquant (will raise assertion error)"

Comment on lines +1542 to +1546
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()):
input_quantizer.optimize_for_gemm = False
if grad_output_quantizer is not None:
grad_output_quantizer.optimize_for_gemm = False
greptile-apps bot:
missing comment explaining why optimize_for_gemm must be disabled for MXFP8/NVFP4 in dequant mode

add brief explanation of the constraint

