Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
cd8b8ad
Enable semantic roles emitted by module/op and consumed by custom rec…
negvet Jan 23, 2026
fddeba4
Update quantization factories
negvet Jan 23, 2026
82b84ff
Fix tests
negvet Jan 23, 2026
4346231
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 23, 2026
a81f54a
Swap tensor:module
negvet Jan 27, 2026
700ea04
Better naming
negvet Jan 27, 2026
d7ca20b
Introduce QuantizerRole frozen data class instead of a string
negvet Feb 17, 2026
ed59556
Shrink module_type vocabulary
negvet Feb 19, 2026
ade46a6
Merge branch 'main' into semantic_quantizer_roles
negvet Feb 20, 2026
b1a4aed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2026
6e1ee37
Fix numerics exact test
negvet Feb 20, 2026
b9753f2
Set defaults, make custom recipe forward compatible
negvet Feb 20, 2026
ad67247
remove position from QuantizerRole
negvet Feb 20, 2026
e6be76a
Set good defaults
negvet Feb 20, 2026
a86fdad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2026
d323f66
Resolve naming: make every module/op distinguishable via name
negvet Feb 24, 2026
c9eae0f
Configure output/grad_input roles, defaults to None
negvet Feb 24, 2026
ea3c135
Remove is_gemm()
negvet Feb 24, 2026
aaf980f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2026
aad3512
Enable base recipes via CustomRecipe and quantization factories
Feb 25, 2026
8d7c91f
Add factory example - NVFP4 for Linear, MXFP8 for GroupedLinear
Feb 25, 2026
736cd72
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 25, 2026
ddf727c
Merge branch 'main' into semantic_quantizer_roles
negvet Feb 25, 2026
b6bfdf8
Fix custom recipe test
Feb 25, 2026
41656ab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 25, 2026
cca370c
Test fine-grained quantization targets
negvet Feb 26, 2026
343f653
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions tests/pytorch/distributed/run_numerics_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.custom_recipes import quantization_nvfp4
from transformer_engine.pytorch.custom_recipes import quantization_ref_nvfp4
from transformer_engine.pytorch.custom_recipes import utils
from run_layer_with_overlap import _compare_tensors

Expand Down Expand Up @@ -56,39 +56,36 @@ def get_nvfp4_quantizer_factory():
enabled.

Returns:
A factory function that takes a role string and returns a quantizer instance
A factory function that takes a QuantizerRole and returns a quantizer instance
"""

def factory(role):
if role == "linear_input":
return quantization_nvfp4.NVFP4QuantizerRef(
if role.tensor_type == "input":
return quantization_ref_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True, # RHT enabled for input
with_rht=True,
)
elif role == "linear_weight":
return quantization_nvfp4.NVFP4QuantizerRef(
elif role.tensor_type == "weight":
return quantization_ref_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16), # 2D quantization for weight
quant_tile_shape=(16, 16),
pow_2_scales=False,
with_rht=False,
)
elif role == "linear_output":
# Output quantization not used
elif role.tensor_type == "output":
return None
elif role == "linear_grad_output":
return quantization_nvfp4.NVFP4QuantizerRef(
elif role.tensor_type == "grad_output":
return quantization_ref_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True, # RHT enabled for grad_output
with_rht=True,
)
elif role == "linear_grad_input":
# Grad input quantization not used
elif role.tensor_type == "grad_input":
return None
else:
# For any other roles, return None
return None

return factory
Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils


Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/nvfp4/test_nvfp4_group_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.common.recipe import NVFP4BlockScaling
Expand Down
23 changes: 10 additions & 13 deletions tests/pytorch/nvfp4/test_nvfp4_module_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.custom_recipes import quantization_nvfp4
from transformer_engine.pytorch.custom_recipes import quantization_ref_nvfp4
from transformer_engine.pytorch.custom_recipes import utils


Expand Down Expand Up @@ -76,39 +76,36 @@ def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bo
with_2d_quantization: Whether to use 2D quantization (16x16 tiles for weights)

Returns:
A factory function that takes a role string and returns a quantizer instance
A factory function that takes a QuantizerRole and returns a quantizer instance
"""

def factory(role):
if role == "linear_input":
return quantization_nvfp4.NVFP4QuantizerRef(
if role.tensor_type == "input":
return quantization_ref_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=with_rht,
)
elif role == "linear_weight":
return quantization_nvfp4.NVFP4QuantizerRef(
elif role.tensor_type == "weight":
return quantization_ref_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16) if with_2d_quantization else (1, 16),
pow_2_scales=False,
with_rht=False,
)
elif role == "linear_output":
# Output quantization not used
elif role.tensor_type == "output":
return None
elif role == "linear_grad_output":
return quantization_nvfp4.NVFP4QuantizerRef(
elif role.tensor_type == "grad_output":
return quantization_ref_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=with_rht,
)
elif role == "linear_grad_input":
# Grad input quantization not used
elif role.tensor_type == "grad_input":
return None
else:
# For any other roles, return None
return None

return factory
Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.common.recipe import NVFP4BlockScaling
Expand Down
Loading