diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py index 0f3d2cbbf0..7ddbb18077 100644 --- a/tests/pytorch/distributed/run_numerics_exact.py +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -22,7 +22,7 @@ ) from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE -from transformer_engine.pytorch.custom_recipes import quantization_nvfp4 +from transformer_engine.pytorch.custom_recipes import quantization_ref_nvfp4 from transformer_engine.pytorch.custom_recipes import utils from run_layer_with_overlap import _compare_tensors @@ -56,39 +56,36 @@ def get_nvfp4_quantizer_factory(): enabled. Returns: - A factory function that takes a role string and returns a quantizer instance + A factory function that takes a QuantizerRole and returns a quantizer instance """ def factory(role): - if role == "linear_input": - return quantization_nvfp4.NVFP4QuantizerRef( + if role.tensor_type == "input": + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, - with_rht=True, # RHT enabled for input + with_rht=True, ) - elif role == "linear_weight": - return quantization_nvfp4.NVFP4QuantizerRef( + elif role.tensor_type == "weight": + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(16, 16), # 2D quantization for weight + quant_tile_shape=(16, 16), pow_2_scales=False, with_rht=False, ) - elif role == "linear_output": - # Output quantization not used + elif role.tensor_type == "output": return None - elif role == "linear_grad_output": - return quantization_nvfp4.NVFP4QuantizerRef( + elif role.tensor_type == "grad_output": + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, - with_rht=True, # RHT enabled for grad_output + with_rht=True, ) - elif role == 
"linear_grad_input": - # Grad input quantization not used + elif role.tensor_type == "grad_input": return None else: - # For any other roles, return None return None return factory diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 911b7660dc..8afb103056 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -8,7 +8,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 01a4a01205..019c5bd566 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -13,7 +13,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.constants import TE_DType from transformer_engine.common.recipe import NVFP4BlockScaling diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py index a96fea3af0..d727433a28 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -6,7 +6,7 @@ import torch import transformer_engine.pytorch as te from transformer_engine.common import recipe -from 
transformer_engine.pytorch.custom_recipes import quantization_nvfp4 +from transformer_engine.pytorch.custom_recipes import quantization_ref_nvfp4 from transformer_engine.pytorch.custom_recipes import utils @@ -76,39 +76,36 @@ def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bo with_2d_quantization: Whether to use 2D quantization (16x16 tiles for weights) Returns: - A factory function that takes a role string and returns a quantizer instance + A factory function that takes a QuantizerRole and returns a quantizer instance """ def factory(role): - if role == "linear_input": - return quantization_nvfp4.NVFP4QuantizerRef( + if role.tensor_type == "input": + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=with_rht, ) - elif role == "linear_weight": - return quantization_nvfp4.NVFP4QuantizerRef( + elif role.tensor_type == "weight": + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16) if with_2d_quantization else (1, 16), pow_2_scales=False, with_rht=False, ) - elif role == "linear_output": - # Output quantization not used + elif role.tensor_type == "output": return None - elif role == "linear_grad_output": - return quantization_nvfp4.NVFP4QuantizerRef( + elif role.tensor_type == "grad_output": + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=with_rht, ) - elif role == "linear_grad_input": - # Grad input quantization not used + elif role.tensor_type == "grad_input": return None else: - # For any other roles, return None return None return factory diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 80ccb2f23d..99f0f5cdd6 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -7,7 +7,7 @@ import 
transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.common.recipe import NVFP4BlockScaling from transformer_engine.pytorch.constants import TE_DType diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index 98be9a4f54..a9178b25d8 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -12,7 +12,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.constants import TE_DType from transformer_engine.common.recipe import NVFP4BlockScaling diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 4de49115b3..612e254bf9 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -16,8 +16,15 @@ GroupedLinear, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.quantization import QuantizerRole import transformer_engine.pytorch.ops as te_ops -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import ( +from transformer_engine.pytorch.custom_recipes.quantization_recipes_base import ( + current_scaling_quantizer_factory, + mxfp8_quantizer_factory, + float8_block_scaling_quantizer_factory, + nvfp4_quantizer_factory, +) +from 
transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import ( nvfp4_ref_rht_2d_quantizer_factory, ) @@ -90,9 +97,9 @@ def test_custom_recipe_sanity(module_type): # Single factory: map roles to quantizers def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -127,9 +134,9 @@ def test_custom_recipe_grouped_linear_sanity(): inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -189,9 +196,9 @@ def test_custom_recipe_matches_current_scaling(): # Custom: single factory returning quantizers per role to match Float8CurrentScaling def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -199,7 +206,12 @@ def quantizer_factory(role): with autocast(enabled=True, recipe=custom_recipe): out_custom = model_custom(inp_custom) - # Assert 
dtypes for custom quantizers match reference mapping + # Assert dtypes for custom quantizers match reference mapping. + # The output (fwd) and grad_input (bwd) slots receive role=None + # (unknown consumer) and get E4M3 from our factory. The reference + # recipe uses E4M3 for fwd output and E5M2 for bwd grad_input, + # but these quantizers are typically unused so the mismatch doesn't + # affect GEMM results. cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] cus_fwd_w = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] cus_fwd_out = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] @@ -209,7 +221,7 @@ def quantizer_factory(role): assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3 assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3 assert cus_bwd_go.dtype == tex.DType.kFloat8E5M2 - assert cus_bwd_gi.dtype == tex.DType.kFloat8E5M2 + assert cus_bwd_gi.dtype == tex.DType.kFloat8E4M3 # role=None fallback loss_custom = (out_custom.float() * scale.view(1, -1)).sum() loss_custom.backward() @@ -246,9 +258,9 @@ def test_custom_recipe_ops_linear_2_1_layout(): inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -276,39 +288,40 @@ def test_custom_recipe_factory_invocation_counts_and_cycling(): op = Linear(in_features, out_features, params_dtype=torch.bfloat16) inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) - # Counters per role + # Counters per tensor_type. 
The output (fwd) and grad_input (bwd) + # slots have role=None by default (unknown consumer), so we count + # those separately. counts = { - "linear_input": 0, - "linear_weight": 0, - "linear_output": 0, - "linear_grad_output": 0, - "linear_grad_input": 0, + "input": 0, + "weight": 0, + "grad_output": 0, + None: 0, } def quantizer_factory(role): - if role in counts: - counts[role] += 1 - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: + counts[None] += 1 return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) - if role in ("linear_grad_output", "linear_grad_input"): + assert isinstance(role, QuantizerRole), f"Expected QuantizerRole, got {type(role)}" + assert role.module_type == "linear" + if role.tensor_type in counts: + counts[role.tensor_type] += 1 + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda")) return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) custom = recipe.CustomRecipe(qfactory=quantizer_factory) - # Run fwd+bwd once; for a single GEMM, expect forward to build 3 quantizers (cycled from 1 factory), - # and backward to build 2 quantizers (cycled from 1 factory). 
with autocast(enabled=True, recipe=custom): out = op(inp) loss = out.float().sum() loss.backward() - # Single GEMM: forward should request input, weight, output; backward grad_output, grad_input - assert counts["linear_input"] == 1 - assert counts["linear_weight"] == 1 - assert counts["linear_output"] == 1 - assert counts["linear_grad_output"] == 1 - assert counts["linear_grad_input"] == 1 + # Forward: input, weight, output(None); backward: grad_output, grad_input(None) + assert counts["input"] == 1 + assert counts["weight"] == 1 + assert counts["grad_output"] == 1 + assert counts[None] == 2, f"Expected 2 None roles (output + grad_input), got {counts[None]}" def test_factories_return_distinct_instances_and_buffers(): @@ -330,3 +343,331 @@ def factory(): # Mutating one should not affect the other q1.scale.fill_(123.0) assert not torch.equal(q1.scale, q2.scale) + + +def _run_linear_fwd_bwd(model, inp, recipe): + """Run forward + backward with a given recipe and return (output, inp.grad, param grads).""" + with autocast(enabled=True, recipe=recipe): + out = model(inp) + loss = out.float().sum() + loss.backward() + param_grads = {n: p.grad.clone() for n, p in model.named_parameters() if p.grad is not None} + return out.clone(), inp.grad.clone(), param_grads + + +def _make_pair(in_features=128, out_features=128, batch=32, seed=42): + """Create a pair of identical Linear models and matching inputs.""" + torch.manual_seed(seed) + model_ref = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda() + model_cus = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda() + model_cus.load_state_dict(model_ref.state_dict()) + + base_inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16) + inp_ref = base_inp.clone().detach().requires_grad_(True) + inp_cus = base_inp.clone().detach().requires_grad_(True) + return model_ref, model_cus, inp_ref, inp_cus + + +def _assert_match(out_ref, out_cus, grad_ref, grad_cus, 
pgrads_ref, pgrads_cus): + """Assert exact match of outputs and all gradients.""" + assert torch.allclose( + out_ref, out_cus, rtol=0.0, atol=0.0 + ), f"Forward mismatch: max diff = {(out_ref - out_cus).abs().max()}" + assert torch.allclose( + grad_ref, grad_cus, rtol=0.0, atol=0.0 + ), f"Input grad mismatch: max diff = {(grad_ref - grad_cus).abs().max()}" + for name in pgrads_ref: + assert torch.allclose(pgrads_ref[name], pgrads_cus[name], rtol=0.0, atol=0.0), ( + f"Param grad '{name}' mismatch: max diff = " + f"{(pgrads_ref[name] - pgrads_cus[name]).abs().max()}" + ) + + +def test_factory_matches_current_scaling(): + """current_scaling_quantizer_factory should produce bit-identical results + to the built-in Float8CurrentScaling recipe.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.Float8CurrentScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=current_scaling_quantizer_factory) + ) + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_factory_matches_mxfp8(): + """mxfp8_quantizer_factory should produce bit-identical results + to the built-in MXFP8BlockScaling recipe.""" + available, reason = te.is_mxfp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"MXFP8 unsupported: {reason}") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.MXFP8BlockScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=mxfp8_quantizer_factory) + ) + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) 
+ + +def test_factory_matches_block_scaling(): + """float8_block_scaling_quantizer_factory should produce bit-identical results + to the built-in Float8BlockScaling recipe.""" + available = te.is_fp8_block_scaling_available() + if not torch.cuda.is_available() or not available: + pytest.skip("Float8 block scaling unsupported on this device") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.Float8BlockScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=float8_block_scaling_quantizer_factory) + ) + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_factory_matches_nvfp4(): + """nvfp4_quantizer_factory should produce bit-identical results + to the built-in NVFP4BlockScaling recipe.""" + available = te.is_nvfp4_available() + if not torch.cuda.is_available() or not available: + pytest.skip("NVFP4 unsupported on this device") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.NVFP4BlockScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=nvfp4_quantizer_factory) + ) + + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_custom_recipe_quantization_targets(): + """Validate fine-grained per-module quantization targeting via QuantizerRole. + + Four transformer layers, each assembled at a different abstraction level. 
+ The default recipe is NVFP4; specific modules are overridden: + + Layer 0 - ``TransformerLayer`` (name="tl0") -> all MXFP8 + Layer 1 - ``TransformerLayer`` (name="tl1") -> NVFP4 (default), + except fc2 overridden to MXFP8 + Layer 2 - ``MultiheadAttention`` + ``LayerNormMLP`` + (name prefix "tl2") -> NVFP4 (default), + except qkv and fc1 overridden to Float8 block-scaling + Layer 3 - Individual blocks (name prefix "tl3") -> NVFP4 (default), + except proj overridden to Float8 current-scaling + + The test validates that: + * The factory receives QuantizerRole objects with correct names + * Different quantizer types are dispatched per module + * Forward + backward complete successfully through all four layers + """ + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + if not te.is_mxfp8_available(): + pytest.skip("MXFP8 unsupported on this device") + if not te.is_nvfp4_available(): + pytest.skip("NVFP4 unsupported on this device") + if not te.is_fp8_block_scaling_available(): + pytest.skip("Float8 block scaling unsupported on this device") + + torch.manual_seed(42) + + H = 64 # hidden_size + FFN = 64 # ffn_hidden_size + NH = 4 # num_heads + KV = H // NH # kv_channels + B = 4 # batch + S = 8 # seq_len + common = dict(params_dtype=torch.bfloat16, bias=False) + + # Layer 0: TransformerLayer -> MXFP8 + tl0 = te.TransformerLayer( + H, + FFN, + NH, + hidden_dropout=0.0, + attention_dropout=0.0, + name="tl0", + **common, + ).cuda() + + # Layer 1: TransformerLayer -> NVFP4 default, fc2 overridden to MXFP8 + tl1 = te.TransformerLayer( + H, + FFN, + NH, + hidden_dropout=0.0, + attention_dropout=0.0, + name="tl1", + **common, + ).cuda() + + # Layer 2: MHA + LayerNormMLP -> NVFP4 default, qkv and fc1 to block-scaling + tl2_mha = te.MultiheadAttention( + H, + NH, + KV, + attention_dropout=0.0, + input_layernorm=True, + return_bias=True, + 
name="tl2.self_attention", + **common, + ).cuda() + tl2_mlp = LayerNormMLP(H, FFN, name="tl2.layernorm_mlp", **common).cuda() + + # Layer 3: Individual blocks with DPA -> NVFP4 default, proj to current-scaling + tl3_qkv = LayerNormLinear(H, 3 * H, name="tl3.qkv", **common).cuda() + tl3_dpa = te.DotProductAttention(NH, KV, attention_dropout=0.0, name="tl3.core_attention") + tl3_proj = Linear(H, H, name="tl3.proj", **common).cuda() + tl3_fc1 = LayerNormLinear(H, FFN, name="tl3.fc1", **common).cuda() + tl3_fc2 = Linear(FFN, H, name="tl3.fc2", **common).cuda() + + # ------------------------------------------------------------------ + # Recording + dispatching factory + # ------------------------------------------------------------------ + recorded_roles = [] + + def targeting_factory(role): + recorded_roles.append(role) + + if role is None: + return nvfp4_quantizer_factory(role) + + assert isinstance(role, QuantizerRole), f"Expected QuantizerRole, got {type(role)}" + + # Layer 0 (tl0.*): all MXFP8 + if role.name.startswith("tl0"): + return mxfp8_quantizer_factory(role) + + # Layer 1 (tl1.*): NVFP4 default, but fc2 overridden to MXFP8 + if role.name == "tl1.layernorm_mlp.fc2": + return mxfp8_quantizer_factory(role) + + # Layer 2: block scaling for qkv and fc1, rest falls through to default + if role.name == "tl2.self_attention.layernorm_linear_qkv": + return float8_block_scaling_quantizer_factory(role) + if role.name == "tl2.layernorm_mlp.fc1": + return float8_block_scaling_quantizer_factory(role) + + # Layer 3: current-scaling for proj, rest falls through to default + if role.name == "tl3.proj": + return current_scaling_quantizer_factory(role) + + # Default: NVFP4 + return nvfp4_quantizer_factory(role) + + custom_recipe = recipe.CustomRecipe(qfactory=targeting_factory) + + # ------------------------------------------------------------------ + # Forward + backward + # ------------------------------------------------------------------ + inp = torch.randn(S, B, H, 
device="cuda", dtype=torch.bfloat16, requires_grad=True) + + with autocast(enabled=True, recipe=custom_recipe): + # Layer 0 & 1: TransformerLayer + h = tl1(tl0(inp)) + + # Layer 2: MHA + residual + LayerNormMLP + residual + attn_out, _ = tl2_mha(h) + h = h + attn_out + h = h + tl2_mlp(h) + + # Layer 3: individual blocks with DPA + residual = h + qkv = tl3_qkv(h).view(S, B, 3, NH, KV) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + attn = tl3_dpa(q, k, v).view(S, B, H) + h = residual + tl3_proj(attn) + residual = h + h = residual + tl3_fc2(torch.nn.functional.gelu(tl3_fc1(h))) + + loss = h.float().sum() + loss.backward() + + # ------------------------------------------------------------------ + # Assertions + # ------------------------------------------------------------------ + + assert inp.grad is not None, "Input gradient is None" + + # -- Name propagation check -- + # The factory dispatches on role.name, so if a TE module fails to propagate + # names (e.g. TransformerLayer -> MHA -> LayerNormLinear) the factory would + # silently fall through to the default recipe. The quantizer-type assertions + # below would catch that too, but checking names explicitly gives a clearer + # error message pointing at the broken name rather than a wrong quantizer type. 
+ role_names = {r.name for r in recorded_roles if r is not None} + + def _tl_names(prefix): + """Expected role names for a standard TransformerLayer with given prefix.""" + return { + f"{prefix}.self_attention.layernorm_linear_qkv", + f"{prefix}.self_attention.proj", + f"{prefix}.layernorm_mlp.fc1", + f"{prefix}.layernorm_mlp.fc2", + } + + all_expected = ( + _tl_names("tl0") + | _tl_names("tl1") + | _tl_names("tl2") + | {"tl3.qkv", "tl3.proj", "tl3.fc1", "tl3.fc2"} + ) + missing = all_expected - role_names + assert not missing, ( + f"Expected module names not seen in QuantizerRole.name: {missing}\n" + f"Recorded names: {sorted(role_names)}" + ) + + for r in recorded_roles: + if r is not None and r.module_type: + assert r.module_type == "linear", f"Unexpected module_type={r.module_type} for role {r}" + + # -- Quantizer-type checks -- + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer + + def _check_q(mod, expected_cls, label=""): + q = mod.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + assert isinstance(q, expected_cls), ( + f"{mod.name}{' (' + label + ')' if label else ''}: " + f"expected {expected_cls.__name__}, got {type(q).__name__}" + ) + + # Layer 0: all MXFP8 + _check_q(tl0.self_attention.layernorm_qkv, MXFP8Quantizer) + _check_q(tl0.self_attention.proj, MXFP8Quantizer) + + # Layer 1: NVFP4 default, fc2 overridden to MXFP8 + _check_q(tl1.self_attention.layernorm_qkv, NVFP4Quantizer, "default") + _check_q(tl1.self_attention.proj, NVFP4Quantizer, "default") + assert any( + r is not None and r.name == "tl1.layernorm_mlp.fc2" and r.tensor_type == "input" + for r in recorded_roles + ), "tl1.layernorm_mlp.fc2 input role not recorded" + + # Layer 2: block-scaling on qkv and fc1, NVFP4 on proj and fc2 + _check_q(tl2_mha.layernorm_qkv, Float8BlockQuantizer) + 
_check_q(tl2_mha.proj, NVFP4Quantizer, "default") + + # Layer 3: current-scaling on proj, NVFP4 on everything else + _check_q(tl3_proj, Float8CurrentScalingQuantizer) + for mod in [tl3_qkv, tl3_fc1, tl3_fc2]: + _check_q(mod, NVFP4Quantizer, "default") diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 99ab9c4984..3b964a5af9 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -14,7 +14,7 @@ from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.custom_recipes.quantization import MMParams -from transformer_engine.pytorch.custom_recipes.quantization_current_scaling import ( +from transformer_engine.pytorch.custom_recipes.quantization_ref_current_scaling import ( CurrentScalingQuantizerRef, ) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 18577b0eb4..1db8110ae3 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -499,19 +499,24 @@ class CustomRecipe(Recipe): Parameters ---------- qfactory : Callable - Factory callable that returns a quantizer instance for a - given semantic tensor role. - The callable is typically invoked as:: + Factory callable that returns a quantizer instance for a given `QuantizerRole`. + The callable is invoked as:: qfactory( - role: str, - ) + role: QuantizerRole, + ) -> Optional[Quantizer] - Where `role` is one of the following strings for e.g. te.Linear - (stable public contract): + `QuantizerRole` is a frozen dataclass with the following fields: - - forward: "linear_input", "linear_weight", "linear_output" - - backward: "linear_grad_output", "linear_grad_input" + - `module_type` (str): module type (empty string when not set), e.g. 
+ `"linear"`, `"grouped_linear"`, `"dpa"`. + - `tensor_type` (str): what tensor is being quantized (empty + string when not set), e.g. `"input"`, `"weight"`, `"grad_output"`. + - `name` (str): caller-provided module instance name (empty + string when not set), e.g. `"qkv"`, `"proj"`, `"fc1"`, `"fc2"`. + + See `transformer_engine.pytorch.quantization.QuantizerRole` + for full documentation. """ qfactory: Callable[..., Any] diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 5e1eb6954b..4880959546 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -48,6 +48,7 @@ from transformer_engine.pytorch.quantization import is_fp8_block_scaling_available from transformer_engine.pytorch.quantization import is_nvfp4_available from transformer_engine.pytorch.quantization import get_default_recipe +from transformer_engine.pytorch.quantization import QuantizerRole from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.utils import is_bf16_available diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 64db4646f6..875ba21fab 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -284,6 +284,8 @@ class DotProductAttention(TransformerEngineBaseModule): `_). :math:`\text{max_logit} = \max(S)`, where :math:`S = \text{mask}(Q \cdot K^T \cdot \text{softmax_scale} + \text{bias})` of shape ``[b, h, s_q, s_kv]``, and :math:`\text{max_logit}` is of shape ``[h]``. + name : Optional[str], default = None + module instance name. 
Parallelism parameters ---------------------- @@ -343,8 +345,9 @@ def __init__( softmax_scale: Optional[float] = None, softmax_type: str = "vanilla", return_max_logit: Optional[bool] = False, + name: Optional[str] = None, ) -> None: - super().__init__() + super().__init__(name=name) self.logger = logging.getLogger("DotProductAttention") self.logger.setLevel(attn_log._log_level) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 01c4955d78..78a1c4dde6 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -8,7 +8,7 @@ from typing import Callable, List, Optional, Tuple, Union import torch -from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.quantization import FP8GlobalStateManager, QuantizerRole from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm @@ -461,6 +461,7 @@ def __init__( layer_number=self.layer_number, attention_type=self.attention_type, softmax_type=self.softmax_type, + name=name + ".core_attention" if name is not None else None, ) # Linear @@ -478,6 +479,42 @@ def __init__( **common_gemm_kwargs, ) + def _update_output_quantizer_roles( + self, + qkv_fp8_output: bool, + proj_fp8_grad: bool, + ) -> None: + """Set output / grad-input quantizer roles on QKV and proj linears. + + When the QKV linear's output feeds directly into DPA (``fp8_mha``), + the role is switched from the default linear-consumer assumption to + DPA-consumer roles. Otherwise roles are reset to ``None`` so the + modules fall back to their defaults. 
+ """ + dpa_name = self.core_attention.name or "" + qkv_output_role = ( + QuantizerRole(module_type="dpa", tensor_type="qkv", name=dpa_name) + if qkv_fp8_output + else None + ) + proj_grad_input_role = ( + QuantizerRole(module_type="dpa", tensor_type="do", name=dpa_name) + if proj_fp8_grad + else None + ) + if self.attention_type == "self": + if self.input_layernorm: + self.layernorm_qkv.output_quantizer_role = qkv_output_role + else: + self.qkv.output_quantizer_role = qkv_output_role + elif self.attention_type == "cross": + if self.input_layernorm: + self.layernorm_query.output_quantizer_role = qkv_output_role + else: + self.query_layer.output_quantizer_role = qkv_output_role + self.key_value.output_quantizer_role = qkv_output_role + self.proj.grad_input_quantizer_role = proj_grad_input_role + def _create_qk_norm_modules( self, qk_norm_type: Optional[str], @@ -795,6 +832,8 @@ def forward( # Proj Gemm: match DPA output except for Float8CurrentScaling proj_fp8_grad = dpa_fp8_output and not float8_current_scaling + self._update_output_quantizer_roles(qkv_fp8_output, proj_fp8_grad) + layernorm_output = None if self.attention_type == "self": # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn] diff --git a/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py b/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py new file mode 100644 index 0000000000..5c563da8a9 --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py @@ -0,0 +1,106 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Quantizer factory examples. + +Demonstrates how to use the ``CustomRecipe`` + ``qfactory`` interface to apply +*different* quantization recipes to different module/tensor types/instances within the same model. 
+ +Usage:: + + from transformer_engine.common.recipe import CustomRecipe + from transformer_engine.pytorch.quantization import autocast + from transformer_engine.pytorch.custom_recipes.quantization_factory_examples import ( + nvfp4_linear_mxfp8_grouped_linear_factory, + ) + + recipe = CustomRecipe(qfactory=nvfp4_linear_mxfp8_grouped_linear_factory) + with autocast(recipe=recipe): + output = model(input) +""" + +from __future__ import annotations + +from typing import Optional + +import transformer_engine_torch as tex + +from transformer_engine.pytorch.quantization import QuantizerRole + + +def nvfp4_linear_mxfp8_grouped_linear_factory( + role: Optional[QuantizerRole], +): + """Quantizer factory: NVFP4 for ``Linear``, MXFP8 for ``GroupedLinear``. + + Dispatch logic: + * ``role.module_type == "grouped_linear"`` -> MXFP8 (E4M3, block-32) + * everything else (``"linear"`` or unknown) -> NVFP4 (E2M1) + + NVFP4 settings follow the built-in ``NVFP4BlockScaling`` defaults: + * Weights: 2D quantization (16x16), no RHT, no stochastic rounding + * Inputs: 1D quantization, RHT enabled, no stochastic rounding + * Grads: 1D quantization, RHT enabled, stochastic rounding enabled + """ + is_grouped_linear = role is not None and role.module_type == "grouped_linear" + + if is_grouped_linear: + return _make_mxfp8_quantizer() + + return _make_nvfp4_quantizer(role) + + +def _make_mxfp8_quantizer(): + """Return an MXFP8 quantizer with default settings (E4M3, block-32, E8M0 scales).""" + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + return MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + ) + + +def _make_nvfp4_quantizer(role: Optional[QuantizerRole]): + """Return an NVFP4 quantizer configured per tensor role. + + Mirrors :class:`NVFP4BlockScaling` recipe defaults. 
+ """ + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + is_linear = role is not None and role.module_type == "linear" + is_weight = is_linear and role.tensor_type == "weight" + is_grad = is_linear and role.tensor_type == "grad_output" + + if is_weight: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=True, + stochastic_rounding=False, + with_random_sign_mask=True, + ) + + if is_grad: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=True, + with_random_sign_mask=True, + ) + + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=True, + ) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py b/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py new file mode 100644 index 0000000000..0823a4c7bf --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py @@ -0,0 +1,160 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Quantizer factory examples using real silicon quantizers. + +Each factory below replicates the behaviour of built-in TE recipe but via the +``CustomRecipe`` + ``qfactory`` interface. This is useful when you want to +start from a known-good recipe and then selectively override quantizer settings +for specific layers / tensor types. 
+ +Usage (any factory):: + + from transformer_engine.common.recipe import CustomRecipe + from transformer_engine.pytorch.quantization import autocast + from transformer_engine.pytorch.custom_recipes.quantization_recipes_base import ( + nvfp4_quantizer_factory, + ) + + recipe = CustomRecipe(qfactory=nvfp4_quantizer_factory) + with autocast(recipe=recipe): + output = model(input) +""" + +from __future__ import annotations + +from typing import Optional + +import torch +import transformer_engine_torch as tex + +from transformer_engine.pytorch.quantization import QuantizerRole + + +def current_scaling_quantizer_factory( + role: Optional[QuantizerRole], +) -> "Float8CurrentScalingQuantizer": + """Factory that mirrors :class:`Float8CurrentScaling` recipe defaults. + + * Forward tensors (input, weight) → E4M3 + * Backward tensors (grad_output) → E5M2 + """ + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, + ) + + is_backward = role is not None and role.tensor_type == "grad_output" + fp8_dtype = tex.DType.kFloat8E5M2 if is_backward else tex.DType.kFloat8E4M3 + + return Float8CurrentScalingQuantizer( + fp8_dtype=fp8_dtype, + device=torch.device("cuda"), + force_pow_2_scales=False, # constrain scale to powers of 2 + amax_epsilon=0.0, # clamp amax from below to avoid div-by-zero + ) + + +def mxfp8_quantizer_factory( + role: Optional[QuantizerRole], +) -> "MXFP8Quantizer": + """Factory that mirrors :class:`MXFP8BlockScaling` recipe defaults. + + * E4M3 by default for all tensors + * Block size 32, power-of-2 (E8M0) scales + """ + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + return MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + ) + + +def float8_block_scaling_quantizer_factory( + role: Optional[QuantizerRole], +) -> "Float8BlockQuantizer": + """Factory that mirrors :class:`Float8BlockScaling` recipe defaults. 
+ + * E4M3 by default for all tensors + * Weights use 2D block scaling, everything else uses 1D + * Power-of-2 scales by default + """ + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + ) + + is_weight = ( + role is not None + and role.module_type in ("linear", "grouped_linear") + and role.tensor_type == "weight" + ) + block_scaling_dim = 2 if is_weight else 1 + + return Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, # clamp amax from below to avoid div-by-zero + force_pow_2_scales=True, + block_scaling_dim=block_scaling_dim, # 1 = 1D (1×128), 2 = 2D (128×128) + ) + + +def nvfp4_quantizer_factory( + role: Optional[QuantizerRole], +) -> "NVFP4Quantizer": + """Factory that mirrors :class:`NVFP4BlockScaling` recipe defaults. + + * All tensors quantized to E2M1 (FP4) + * Weights: 2D quantization (16x16 blocks), no RHT, no stochastic rounding + * Inputs: 1D quantization, RHT enabled, no stochastic rounding + * Grads: 1D quantization, RHT enabled, stochastic rounding enabled + + Quantizer knobs: + fp4_dtype - E2M1 (only supported format) + with_rht - randomized Hadamard transform (smooths outliers) + with_post_rht_amax - recompute amax after RHT (should match with_rht) + with_2d_quantization - 16x16 2D blocks (vs 1x16 1D) + stochastic_rounding - probabilistic rounding to reduce quant bias + with_random_sign_mask - random sign flip in the Hadamard matrix + """ + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + is_linear = role is not None and role.module_type in ("linear", "grouped_linear") + is_weight = is_linear and role.tensor_type == "weight" + is_grad = is_linear and role.tensor_type == "grad_output" + + if is_weight: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=True, + stochastic_rounding=False, + with_random_sign_mask=True, + ) + + if 
is_grad: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=True, + with_random_sign_mask=True, + ) + + # For input and unknown roles + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=True, + ) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py similarity index 98% rename from transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py rename to transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py index 5bdc537e4b..0034b739cb 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py @@ -18,17 +18,18 @@ def current_scaling_ref_quantizer_factory(role): """Factory function for current scaling reference quantizer. - Usage with CustomRecipe and autocast: + Receives a :class:`~transformer_engine.pytorch.quantization.QuantizerRole`. + + Backward tensors use E5M2, everything else uses E4M3. 
+ + Usage with CustomRecipe and autocast:: + custom_recipe = recipe.CustomRecipe(qfactory=current_scaling_ref_quantizer_factory) with autocast(recipe=custom_recipe): output = model(input) """ - if role in ("linear_input", "linear_weight"): - dtype = torch.float8_e4m3fn - elif role in ("linear_output", "linear_grad_output"): - dtype = torch.float8_e5m2 - else: - return None + is_backward = role is not None and role.tensor_type == "grad_output" + dtype = torch.float8_e5m2 if is_backward else torch.float8_e4m3fn return CurrentScalingQuantizerRef( dtype=dtype, rowwise=True, diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py similarity index 98% rename from transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py rename to transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index d00d0c8b94..0b29977fb4 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -18,33 +18,32 @@ def nvfp4_ref_rht_2d_quantizer_factory(role): """ Quantizer factory for NVFP4 recipe reference implementation (RHT and 2D quantization for weights). - Usage with CustomRecipe and autocast: + Receives a :class:`~transformer_engine.pytorch.quantization.QuantizerRole`. 
+ + Usage with CustomRecipe and autocast:: + custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory) - with autocast(fp8_recipe=custom_recipe): + with autocast(recipe=custom_recipe): output = model(input) """ - if role == "linear_input": - return NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=True, - ) - if role == "linear_weight": + is_weight_tensor_in_gemm = ( + role is not None + and role.module_type in ("linear", "grouped_linear") + and role.tensor_type == "weight" + ) + if is_weight_tensor_in_gemm: # 2D quantization for weights in GEMM-based modules return NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16), pow_2_scales=False, with_rht=False, ) - if role == "linear_grad_output": - return NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=True, - ) - return None + return NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ) def cast_to_fp4x2(x): diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 4858383c26..1c5176834d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -28,6 +28,7 @@ Float8BlockScalingRecipeState, NVFP4BlockScalingRecipeState, FP8GlobalStateManager, + QuantizerRole, RecipeState, ) from ..distributed import ( @@ -631,6 +632,8 @@ def __init__(self, name: Optional[str] = None) -> None: self.activation_dtype: Optional[torch.dtype] = None self.wgrad_accumulation_and_reduce_hooks = [] self.wgrad_store = None + self._output_quantizer_role: Optional[QuantizerRole] = None + self._grad_input_quantizer_role: Optional[QuantizerRole] = None if not TEDebugState.debug_enabled: TEDebugState.initialize() @@ -651,6 +654,72 @@ def module_setattr(self, name: str, value: Any) -> None: """ super().__setattr__(name, value) + 
@property + def output_quantizer_role(self) -> Optional[QuantizerRole]: + """Caller-configurable :class:`QuantizerRole` for the forward output quantizer. + + When set, overrides the default role used by :meth:`get_quantizer_roles` + for the forward-pass output quantizer slot. Setting this after + quantizers have been created forces their recreation on the next + forward pass. + + See also :attr:`grad_input_quantizer_role` for the backward-pass + counterpart. + """ + return self._output_quantizer_role + + @output_quantizer_role.setter + def output_quantizer_role(self, role: Optional[QuantizerRole]) -> None: + if role == self._output_quantizer_role: + return + self._output_quantizer_role = role + if self.fp8_meta_tensors_initialized: + self.fp8_meta_tensors_initialized = False + + @property + def grad_input_quantizer_role(self) -> Optional[QuantizerRole]: + """Caller-configurable :class:`QuantizerRole` for the grad-input quantizer. + + Backward-pass counterpart of :attr:`output_quantizer_role`. + """ + return self._grad_input_quantizer_role + + @grad_input_quantizer_role.setter + def grad_input_quantizer_role(self, role: Optional[QuantizerRole]) -> None: + if role == self._grad_input_quantizer_role: + return + self._grad_input_quantizer_role = role + if self.fp8_meta_tensors_initialized: + self.fp8_meta_tensors_initialized = False + + def _warn_missing_output_quantizer_role( + self, + fp8_output: bool, + fp8_grad: bool, + ) -> None: + """Warn when quantized output is requested but no consumer role is set. + + Only relevant for ``CustomRecipe`` where the ``qfactory`` dispatches + on roles. Built-in recipes ignore role metadata. + """ + recipe = FP8GlobalStateManager.get_fp8_recipe() + if not recipe.custom(): + return + if fp8_output and self._output_quantizer_role is None: + warnings.warn( + f"{type(self).__name__}: fp8_output=True but " + "output_quantizer_role is not set. 
The CustomRecipe qfactory " + "will receive None for the output quantizer role.", + stacklevel=3, + ) + if fp8_grad and self._grad_input_quantizer_role is None: + warnings.warn( + f"{type(self).__name__}: fp8_grad=True but " + "grad_input_quantizer_role is not set. The CustomRecipe " + "qfactory will receive None for the grad-input quantizer role.", + stacklevel=3, + ) + def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """ Delayed scaling only. @@ -732,15 +801,34 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2 # Initialize recipe state and quantizers + roles = self.get_quantizer_roles(fwd=fwd, num_quantizers=num_fp8_tensors) + if roles is not None: + assert ( + len(roles) == num_fp8_tensors + ), f"Recipe roles must match number of quantizers ({len(roles)=} vs {num_fp8_tensors=})" recipe_state = RecipeState.create( recipe, mode=("forward" if fwd else "backward"), num_quantizers=num_fp8_tensors, + roles=roles, ) self.fp8_meta[fp8_meta_tensor_key] = recipe_state self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers() + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """Return an ordered list of :class:`QuantizerRole` for quantizers. + + The returned list must have length `num_quantizers`. + Returning `None` means "no explicit roles". 
+ """ + return None + def _update_weight_quantizers(self) -> None: """Update the quantizers for the weight tensors.""" weight_tensors = self._get_weight_tensors() diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index b381073d78..e5a09241a1 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -23,7 +23,7 @@ _2X_ACC_WGRAD, ) from ._common import WeightGradStore -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( divide, cast_if_needed, @@ -744,6 +744,33 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``GroupedLinear``. + + For grouped GEMMs we repeat the same pattern for each GEMM in + order. The output (fwd) and grad-input (bwd) slots default to + ``None`` (unknown consumer). Set :attr:`output_quantizer_role` / + :attr:`grad_input_quantizer_role` to provide consumer identity. + """ + name = self.name or "" + if fwd: + base = [ + QuantizerRole(module_type="grouped_linear", tensor_type="input", name=name), + QuantizerRole(module_type="grouped_linear", tensor_type="weight", name=name), + self._output_quantizer_role, + ] + else: + base = [ + QuantizerRole(module_type="grouped_linear", tensor_type="grad_output", name=name), + self._grad_input_quantizer_role, + ] + return [base[i % len(base)] for i in range(num_quantizers)] + def make_grouped_weights(self, defer_init=False) -> None: """ Convert parameters into a GroupedTensor and re-register them as parameters. 
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 27632db15b..d55c4c49c1 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -26,7 +26,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( assert_dim_for_fp8_exec, assert_dim_for_all_gather, @@ -1413,6 +1413,32 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: elif recipe.nvfp4(): self._customize_quantizers_nvfp4(fwd, recipe) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``LayerNormLinear``. + + The output (fwd) and grad-input (bwd) slots default to ``None`` + (unknown consumer). Set :attr:`output_quantizer_role` / + :attr:`grad_input_quantizer_role` to provide consumer identity. 
+ """ + name = self.name or "" + if fwd: + base = [ + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + self._output_quantizer_role, + ] + else: + base = [ + QuantizerRole(module_type="linear", tensor_type="grad_output", name=name), + self._grad_input_quantizer_role, + ] + return [base[i % len(base)] for i in range(num_quantizers)] + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1611,6 +1637,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): if not self.fp8: return [None] * 6 + + self._warn_missing_output_quantizer_role(fp8_output, fp8_grad) + grad_input_quantizer = None grad_weight_quantizer = None grad_output_quantizer = None diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index b8823e46ca..c6c8e6758d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -27,7 +27,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..jit import ( bias_gelu_fused, bgrad_dgelu_fused, @@ -1980,6 +1980,44 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: elif recipe.nvfp4(): self._customize_quantizers_nvfp4(fwd, recipe) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``LayerNormMLP``. + + Each internal GEMM (fc1, fc2) gets a distinct name suffix so that + custom-recipe factories can target them individually. + + The module's final output (fc2 fwd) and final grad (fc1 bwd) + slots default to ``None`` (unknown consumer). Set + :attr:`output_quantizer_role` / :attr:`grad_input_quantizer_role` + to provide consumer identity. 
Internal boundaries use fixed + roles with known consumer identity. + """ + base_name = self.name or "" + fc1_name = f"{base_name}.fc1" if base_name else "fc1" + fc2_name = f"{base_name}.fc2" if base_name else "fc2" + if fwd: + base = [ + QuantizerRole(module_type="linear", tensor_type="input", name=fc1_name), + QuantizerRole(module_type="linear", tensor_type="weight", name=fc1_name), + QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name), + QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name), + QuantizerRole(module_type="linear", tensor_type="weight", name=fc2_name), + self._output_quantizer_role, + ] + else: + base = [ + QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc1_name), + self._grad_input_quantizer_role, + QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc2_name), + QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc1_name), + ] + return [base[i % len(base)] for i in range(num_quantizers)] + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -2187,6 +2225,9 @@ def forward( return out def _get_quantizers(self, fp8_output, is_grad_enabled): + if self.fp8: + self._warn_missing_output_quantizer_role(fp8_output, False) + ( fc1_input_quantizer, fc1_output_quantizer, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a55429d33d..adf44208e5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -25,7 +25,7 @@ _2X_ACC_WGRAD, ) from ._common import noop_cat, WeightGradStore -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( cast_if_needed, clear_tensor_data, @@ -1309,6 +1309,32 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True + def get_quantizer_roles( + self, + *, + fwd: bool, 
+ num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``Linear``. + + The output (fwd) and grad-input (bwd) slots default to ``None`` + (unknown consumer). Set :attr:`output_quantizer_role` / + :attr:`grad_input_quantizer_role` to provide consumer identity. + """ + name = self.name or "" + if fwd: + base = [ + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + self._output_quantizer_role, + ] + else: + base = [ + QuantizerRole(module_type="linear", tensor_type="grad_output", name=name), + self._grad_input_quantizer_role, + ] + return [base[i % len(base)] for i in range(num_quantizers)] + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) @@ -1479,6 +1505,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): if not self.fp8: return [None] * 6 + + self._warn_missing_output_quantizer_role(fp8_output, fp8_grad) + grad_input_quantizer = None grad_weight_quantizer = None grad_output_quantizer = None diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 48376a297f..dfb1b7f741 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -19,7 +19,7 @@ gather_along_first_dim, reduce_scatter_along_first_dim, ) -from ...quantization import FP8GlobalStateManager, Recipe +from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe from ...module.base import ( _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -270,6 +270,21 @@ def num_quantizers(self, mode: str) -> int: return 1 return 0 + def get_quantizer_roles(self, mode: str) -> Optional[list[QuantizerRole]]: + name = getattr(self, "name", "") or "" + if mode == "forward": + # BasicLinear owns input and weight quantizers. 
+ # Output quantizer is provided by the next op (as its input quantizer). + return [ + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + ] + if mode == "backward": + # BasicLinear owns grad_output quantizer. + # Grad_input quantizer is provided by the previous op (as its grad_output quantizer). + return [QuantizerRole(module_type="linear", tensor_type="grad_output", name=name)] + return None + def reset_parameters(self) -> None: """Initialize parameter buffers and values""" diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 54b3f00117..3a59e3b229 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -16,6 +16,7 @@ from transformer_engine.common.recipe import Recipe from ..quantization import ( FP8GlobalStateManager, + QuantizerRole, RecipeState, autocast, ) @@ -209,6 +210,15 @@ def num_quantizers( """ return 0 + def get_quantizer_roles(self, mode: str) -> Optional[list[QuantizerRole]]: + """Return an ordered list of :class:`QuantizerRole` for quantizers. + + The returned list must be aligned with the internal quantizer ordering and + must have length ``num_quantizers(mode)`` for supported modes. + Returning ``None`` means "no explicit roles". 
+ """ + return None + def get_input_quantizer(self) -> Optional[Quantizer]: if self.num_quantizers("forward") > 0: return self.get_quantizer("forward", 0) @@ -268,10 +278,17 @@ def reset_recipe_state( ) # Construct quantization recipe state + roles = self.get_quantizer_roles(mode) + if roles is not None: + assert len(roles) == num_quantizers, ( + "Recipe roles must match number of quantizers " + f"({len(roles)=} vs {num_quantizers=})" + ) recipe_state = RecipeState.create( recipe, mode=mode, num_quantizers=num_quantizers, + roles=roles, ) fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( forward=(mode == "forward"), diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index eba547afb0..ff3c0b9a1c 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import dataclasses import itertools import functools import warnings @@ -41,9 +42,53 @@ "is_nvfp4_available", "get_default_recipe", "get_align_size_for_quantization", + "QuantizerRole", ] +@dataclasses.dataclass(frozen=True) +class QuantizerRole: + """Identity of a tensor slot requesting a quantizer. + + TE modules populate all fields they know about. + User factories inspect only the fields they care about. + + .. warning:: + **EXPERIMENTAL**: QuantizerRole is experimental, still under active development, + and the API is subject to change without notice. Use at your own risk. + + Fields + ------ + module_type : str + Module type that emits this role, e.g. `"linear"`, `"grouped_linear"`, `"dpa"`. + Empty string when not provided. + tensor_type : str + What tensor is being quantized, in the module's own vocabulary. + Linear modules: `"input"`, `"weight"`, `"grad_output"`, etc. + DPA: `"qkv"`, `"s"`, etc. + Empty string when not provided. + name : str + Caller-provided module instance name (e.g. set by the training + framework), e.g. 
+ `"qkv"`, `"proj"`, `"fc1"`, `"fc2"`, `"linear_39"`. + Empty string when not provided. + """ + + module_type: str = "" + tensor_type: str = "" + name: str = "" + + def __str__(self) -> str: + parts = [] + if self.module_type: + parts.append(f"module_type={self.module_type}") + if self.tensor_type: + parts.append(f"tensor_type={self.tensor_type}") + if self.name: + parts.append(f"name={self.name}") + return "|".join(parts) if parts else "QuantizerRole()" + + @functools.lru_cache(maxsize=None) def check_fp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" @@ -992,6 +1037,7 @@ def create( mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, + roles: Optional[list[QuantizerRole]] = None, ) -> RecipeState: """Factory method to create the state for a quantization recipe @@ -1005,6 +1051,8 @@ def create( Number of quantizers to create state for. device: torch.device, default = default CUDA device Device for quantized tensors. + roles: list of QuantizerRole, optional + Semantic roles for each quantizer slot. 
Returns ------- @@ -1028,12 +1076,15 @@ def create( cls = CustomRecipeState else: raise ValueError(f"{recipe.__class__.__name__} is not supported") - return cls( + state = cls( recipe, mode=mode, num_quantizers=num_quantizers, device=device, ) + # Optional QuantizerRole objects + state.roles = roles + return state @abc.abstractmethod def make_quantizers(self) -> list: @@ -1381,26 +1432,24 @@ def __init__( def make_quantizers(self) -> list: qfactory = self.recipe.qfactory - out = [] - # TODO(negvet): make_quantizers() should take roles from the operation - # Hardcode linear-specific roles for now - roles: List[str] - if self.mode == "forward": - roles = [ - ("linear_input", "linear_weight", "linear_output")[i % 3] - for i in range(self.num_quantizers) - ] - elif self.mode == "backward": - roles = [ - ("linear_grad_output", "linear_grad_input")[i % 2] - for i in range(self.num_quantizers) - ] - else: - roles = ["unknown"] * self.num_quantizers + roles: List[QuantizerRole] = getattr(self, "roles", None) + if roles is None: + warnings.warn( + "CustomRecipeState: no QuantizerRole list provided by the module/op. " + "Falling back to bare QuantizerRole() defaults. " + "Override get_quantizer_roles() to provide meaningful roles.", + stacklevel=2, + ) + roles = [QuantizerRole() for _ in range(self.num_quantizers)] + if len(roles) != self.num_quantizers: + raise ValueError( + "CustomRecipeState requires roles to match num_quantizers " + f"({len(roles)=} vs {self.num_quantizers=})" + ) + out = [] for i in range(self.num_quantizers): - # Get quantizer from the user defined factory quantizer = qfactory(roles[i]) out.append(quantizer) return out