From 08244c74ccbd081a023ecec82e0844176e556898 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 24 Feb 2026 15:33:32 +0000 Subject: [PATCH 1/2] init Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_compile.py | 49 + transformer_engine/pytorch/distributed.py | 18 +- transformer_engine/pytorch/module/_common.py | 3 +- transformer_engine/pytorch/module/base.py | 279 ++-- .../pytorch/module/fx_serialize.py | 190 +++ transformer_engine/pytorch/module/linear.py | 1151 +++++++++++++---- .../pytorch/quantized_tensor.py | 42 +- .../pytorch/tensor/float8_tensor.py | 68 +- .../pytorch/tensor/mxfp8_tensor.py | 35 + 9 files changed, 1446 insertions(+), 389 deletions(-) create mode 100644 tests/pytorch/test_torch_compile.py create mode 100644 transformer_engine/pytorch/module/fx_serialize.py diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py new file mode 100644 index 0000000000..29b2ca0a34 --- /dev/null +++ b/tests/pytorch/test_torch_compile.py @@ -0,0 +1,49 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch +from contextlib import nullcontext + +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling + + +@pytest.mark.skipif(torch.__version__ < "2", reason="torch.compile not available") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for TE Linear") +@pytest.mark.parametrize( + "use_fp8,with_backward", + [ + (False, False), + (False, True), + (True, False), + (True, True), + ], + ids=["fp16_fwd", "fp16_fwd_bwd", "fp8_fwd", "fp8_fwd_bwd"], +) +def test_te_linear_fullgraph_compile(use_fp8, with_backward): + if use_fp8: + fp8_available, reason = te.is_fp8_available(return_reason=True) + if not fp8_available: + pytest.skip(reason) + + model = te.Linear(128, 64, device="cuda").to(dtype=torch.bfloat16) + for param in model.parameters(): + param.requires_grad_(False) + x = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16, requires_grad=with_backward) + + fp8_recipe = DelayedScaling() if use_fp8 else None + maybe_fp8 = te.autocast(enabled=True, recipe=fp8_recipe) if use_fp8 else nullcontext() + + with maybe_fp8: + if use_fp8: + model.init_fp8_metadata() + compiled_model = torch.compile(model, fullgraph=True) + out = compiled_model(x) + assert out.shape == (16, 64) + if with_backward: + out.sum().backward() + + if with_backward: + assert x.grad is not None diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index f269e21b8c..26a18c5842 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1938,19 +1938,29 @@ def _fsdp_scatter_tensors( fsdp_group: dist_group_type, *tensors: torch.Tensor, ): - shapes = [] + shapes = _collect_fsdp_tensor_shapes(*tensors) if fsdp_group is not None: for t in tensors: if isinstance(t, torch.Tensor): targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t] for target in targets: - shapes.append(target.data.shape) safely_set_viewless_tensor_data( target, split_tensor_into_1d_equal_chunks(target.data, fsdp_group, new_buffer=True), ) - else: - shapes.append(None) + return shapes + + +def _collect_fsdp_tensor_shapes(*tensors: torch.Tensor) -> List[Optional[Tuple[int, ...]]]: + """Collect tensor data shapes in the same order used by FSDP scatter/gather helpers.""" + shapes: List[Optional[Tuple[int, ...]]] = [] + for t in tensors: + if isinstance(t, torch.Tensor): + targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t] + for target in targets: + shapes.append(target.data.shape) + else: + shapes.append(None) return shapes diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index bf5a230e84..617a7b531b 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -9,6 +9,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union import torch +from torch._opaque_base import OpaqueBase from .. import cpp_extensions as tex from ..constants import TE_DType @@ -195,7 +196,7 @@ def __post_init__(self): self.init_fn = get_default_init_method() -class WeightGradStore: +class WeightGradStore(OpaqueBase): """ A class to manage weight gradient storage and computation in Transformer modules. This class enables split backward propagation for better memory efficiency. diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 09b12afa21..c90a7153f7 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -57,6 +57,8 @@ from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled +from torch._opaque_base import OpaqueBase +from torch._library.opaque_object import register_opaque_type __all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"] @@ -603,6 +605,150 @@ def fill_userbuffers_buffer_for_all_gather( raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})") +class WeightWorkspace(OpaqueBase): + """Opaque reference object holding the FP8 weight workspace cache. + + This object is registered as a PyTorch opaque reference type so that + it can be passed through ``torch.library.custom_op`` schemas. Inside + the custom-op forward the cached quantized weight is retrieved (and + updated when necessary) via :meth:`get_weight_workspace`. + """ + + def __init__(self) -> None: + self._workspaces: Dict[str, QuantizedTensor] = {} + + # ------------------------------------------------------------------ + # get_weight_workspace – moved from TransformerEngineBaseModule + # ------------------------------------------------------------------ + def get_weight_workspace( + self, + *, + tensor: Optional[torch.Tensor] = None, + quantizer: Optional[Quantizer] = None, + cache_name: Optional[str] = None, + update_workspace: bool = True, + skip_update_flag: Optional[torch.Tensor] = None, + fsdp_group: Optional[dist_group_type] = None, + workspace_dtype: Optional[torch.dtype] = None, + ) -> QuantizedTensor: + """Get workspace buffer for weights and maybe update its values + + The workspace buffer may be cached for future function calls. + + Parameters + ---------- + tensor : torch.Tensor, optional + Values to copy into workspace. Required if the workspace + is being constructed or updated. + quantizer: Quantizer, optional + Quantizer used to cast the weights. Required if the + workspace is being constructed or updated. + cache_name: str, optional + Key for caching. + update_workspace: bool, default = True + Update workspace with values from `tensor`. + skip_update_flag: torch.Tensor, optional + GPU flag to skip updating the workspace. Take precedence + over `update_workspace` if provided. + fsdp_group: bool, default = None + FSDP process group that the weights are distributed over. + workspace_dtype: torch.dtype, default = None + If weight workspace contains high-precision tensor - for example + for debug quantization, this is dtype of the tensor. + """ + + # Handle case where weights are already quantized + # Note: Make sure weights have required usages, but do not + # destroy unnecessary usages since they may be used later. + if isinstance(tensor, QuantizedTensor): + update_rowwise_usage = True if quantizer.rowwise_usage else None + update_columnwise_usage = True if quantizer.columnwise_usage else None + tensor.update_usage( + rowwise_usage=update_rowwise_usage, + columnwise_usage=update_columnwise_usage, + ) + return tensor + + # Try getting workspace from cache + out = None + if cache_name is not None: + out = self._workspaces.get(cache_name, None) + + # Reset cache if workspace is invalid + if out is not None and quantizer is not None: + reset_cache = False + if isinstance(out, Float8TensorStorage): + if ( + not is_non_tn_fp8_gemm_supported() + and quantizer.columnwise_usage + and out._transpose is None + ): + reset_cache = True + elif isinstance(out, MXFP8TensorStorage): + if quantizer.rowwise_usage and out._rowwise_data is None: + reset_cache = True + elif quantizer.columnwise_usage and out._columnwise_data is None: + reset_cache = True + elif isinstance(out, NVFP4TensorStorage): + if quantizer.rowwise_usage and out._rowwise_data is None: + reset_cache = True + elif quantizer.columnwise_usage and out._columnwise_data is None: + reset_cache = True + if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer): + reset_cache = True + if reset_cache: + out = None + del self._workspaces[cache_name] + + # Gather cached Fp8 workspace if it's distributed + # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work + # for models initialized with Fp8 primary weights. + if ( + out is not None + and tensor is not None + and fsdp_group is not None + and out.data.shape != tensor.data.shape + ): + _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) + + # Construct workspace if needed + if out is None: + if tensor is None or quantizer is None: + raise ValueError( + "tensor and quantizer kwargs must be provided to construct FP8 workspace" + ) + + if cache_name is not None: + # Ensure the tensor in the cache is an instance of torch.Tensor, + # as it persists beyond a single forward pass. + # Setting internal=True would cause the data to be removed in prepare_for_saving(...). + quantizer_internal = quantizer.internal + quantizer.internal = False + out = quantizer.quantize(tensor, dtype=workspace_dtype) + if cache_name is not None: + quantizer.internal = quantizer_internal + + # Update cache + if cache_name is not None: + self._workspaces[cache_name] = out + return out + + # Update workspace if needed + if skip_update_flag is not None: + update_workspace = True + if update_workspace: + if tensor is None: + raise ValueError("tensor kwarg must be provided to update FP8 workspace") + if hasattr(out, "quantize_"): + out.quantize_(tensor, noop_flag=skip_update_flag) + else: + tex.quantize(tensor, quantizer, out, skip_update_flag) + return out + + +register_opaque_type(WeightWorkspace, typ="reference") + + class TransformerEngineBaseModule(torch.nn.Module, ABC): """Base TE module.""" @@ -627,7 +773,8 @@ def __init__(self, name: Optional[str] = None) -> None: self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val() self.fsdp_wrapped = False self.fsdp_group = None - self._fp8_workspaces: Dict[str, QuantizedTensor] = {} + self.weight_workspace = WeightWorkspace() + self._fp8_workspaces = self.weight_workspace._workspaces # backward-compat alias self.activation_dtype: Optional[torch.dtype] = None self.wgrad_accumulation_and_reduce_hooks = [] self.wgrad_store = None @@ -1034,6 +1181,7 @@ def prepare_forward( num_gemms: int = 1, allow_non_contiguous: bool = False, allow_different_data_and_param_types: bool = False, + defer_fp8_global_buffer_update: bool = False, ) -> torch.Tensor: """Checks and prepares for FWD execution.""" self.fast_setattr( @@ -1063,7 +1211,7 @@ def prepare_forward( "necessary when using sequence parallelism with FP8." ) - if not FP8GlobalStateManager.fp8_graph_capturing(): + if not defer_fp8_global_buffer_update and not FP8GlobalStateManager.fp8_graph_capturing(): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) # Activation recomputation is used and this is the first forward phase. @@ -1343,130 +1491,9 @@ def clear(self): def forward(self): """Needs override.""" - def get_weight_workspace( - self, - *, - tensor: Optional[torch.Tensor] = None, - quantizer: Optional[Quantizer] = None, - cache_name: Optional[str] = None, - update_workspace: bool = True, - skip_update_flag: Optional[torch.Tensor] = None, - fsdp_group: Optional[dist_group_type] = None, - workspace_dtype: Optional[torch.dtype] = None, - ) -> QuantizedTensor: - """Get workspace buffer for weights and maybe update its values - - The workspace buffer may be cached for future function calls. - - Parameters - ---------- - tensor : torch.Tensor, optional - Values to copy into workspace. Required if the workspace - is being constructed or updated. - quantizer: Quantizer, optional - Quantizer used to cast the weights. Required if the - workspace is being constructed or updated. - cache_name: str, optional - Key for caching. - update_workspace: bool, default = True - Update workspace with values from `tensor`. - skip_update_flag: torch.Tensor, optional - GPU flag to skip updating the workspace. Take precedence - over `update_workspace` if provided. - fsdp_group: bool, default = None - FSDP process group that the weights are distributed over. - workspace_dtype: torch.dtype, default = None - If weight workspace contains high-precision tensor - for example - for debug quantization, this is dtype of the tensor. - """ - - # Handle case where weights are already quantized - # Note: Make sure weights have required usages, but do not - # destroy unnecessary usages since they may be used later. - if isinstance(tensor, QuantizedTensor): - update_rowwise_usage = True if quantizer.rowwise_usage else None - update_columnwise_usage = True if quantizer.columnwise_usage else None - tensor.update_usage( - rowwise_usage=update_rowwise_usage, - columnwise_usage=update_columnwise_usage, - ) - return tensor - - # Try getting workspace from cache - out = None - if cache_name is not None: - out = self._fp8_workspaces.get(cache_name, None) - - # Reset cache if workspace is invalid - if out is not None and quantizer is not None: - reset_cache = False - if isinstance(out, Float8TensorStorage): - if ( - not is_non_tn_fp8_gemm_supported() - and quantizer.columnwise_usage - and out._transpose is None - ): - reset_cache = True - elif isinstance(out, MXFP8TensorStorage): - if quantizer.rowwise_usage and out._rowwise_data is None: - reset_cache = True - elif quantizer.columnwise_usage and out._columnwise_data is None: - reset_cache = True - elif isinstance(out, NVFP4TensorStorage): - if quantizer.rowwise_usage and out._rowwise_data is None: - reset_cache = True - elif quantizer.columnwise_usage and out._columnwise_data is None: - reset_cache = True - if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer): - reset_cache = True - if reset_cache: - out = None - del self._fp8_workspaces[cache_name] - - # Gather cached Fp8 workspace if it's distributed - # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work - # for models initialized with Fp8 primary weights. - if ( - out is not None - and tensor is not None - and fsdp_group is not None - and out.data.shape != tensor.data.shape - ): - _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) - - # Construct workspace if needed - if out is None: - if tensor is None or quantizer is None: - raise ValueError( - "tensor and quantizer kwargs must be provided to construct FP8 workspace" - ) - - if cache_name is not None: - # Ensure the tensor in the cache is an instance of torch.Tensor, - # as it persists beyond a single forward pass. - # Setting internal=True would cause the data to be removed in prepare_for_saving(...). - quantizer_internal = quantizer.internal - quantizer.internal = False - out = quantizer.quantize(tensor, dtype=workspace_dtype) - if cache_name is not None: - quantizer.internal = quantizer_internal - - # Update cache - if cache_name is not None: - self._fp8_workspaces[cache_name] = out - return out - - # Update workspace if needed - if skip_update_flag is not None: - update_workspace = True - if update_workspace: - if tensor is None: - raise ValueError("tensor kwarg must be provided to update FP8 workspace") - if hasattr(out, "quantize_"): - out.quantize_(tensor, noop_flag=skip_update_flag) - else: - tex.quantize(tensor, quantizer, out, skip_update_flag) - return out + def get_weight_workspace(self, **kwargs) -> QuantizedTensor: + """Delegate to self.weight_workspace.get_weight_workspace(...).""" + return self.weight_workspace.get_weight_workspace(**kwargs) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs diff --git a/transformer_engine/pytorch/module/fx_serialize.py b/transformer_engine/pytorch/module/fx_serialize.py new file mode 100644 index 0000000000..6fb6398828 --- /dev/null +++ b/transformer_engine/pytorch/module/fx_serialize.py @@ -0,0 +1,190 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Utilities for converting runtime values into FX-evaluable expressions.""" + +from __future__ import annotations + +import enum +import pickle +from functools import singledispatch +from typing import Dict, Tuple + +import torch +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage + + +class _SerializeContext: + def __init__(self) -> None: + self.globals: Dict[str, object] = {} + + def add_global(self, name: str, value: object) -> None: + existing = self.globals.get(name) + if existing is not None and existing is not value: + raise RuntimeError(f"FX serializer global name collision for '{name}'") + self.globals[name] = value + +@singledispatch +def _convert(value: object, ctx: _SerializeContext) -> str: + try: + payload = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL).hex() + except Exception as err: + raise TypeError( + f"Unsupported value for FX serialization: {type(value)!r}. " + "Register a dedicated converter for this subtype." + ) from err + ctx.add_global("pickle", pickle) + return f"pickle.loads(bytes.fromhex({payload!r}))" + + +@_convert.register(type(None)) +def _(value: None, ctx: _SerializeContext) -> str: + del value, ctx + return "None" + + +@_convert.register(bool) +def _(value: bool, ctx: _SerializeContext) -> str: + del ctx + return "True" if value else "False" + + +@_convert.register(int) +def _(value: int, ctx: _SerializeContext) -> str: + del ctx + return repr(value) + + +@_convert.register(float) +def _(value: float, ctx: _SerializeContext) -> str: + del ctx + return repr(value) + + +@_convert.register(str) +def _(value: str, ctx: _SerializeContext) -> str: + del ctx + return repr(value) + + +@_convert.register(tuple) +def _(value: tuple, ctx: _SerializeContext) -> str: + items = [_convert(item, ctx) for item in value] + if len(items) == 1: + return f"({items[0]},)" + return f"({', '.join(items)})" + + +@_convert.register(list) +def _(value: list, ctx: _SerializeContext) -> str: + items = [_convert(item, ctx) for item in value] + return f"[{', '.join(items)}]" + + +@_convert.register(dict) +def _(value: dict, ctx: _SerializeContext) -> str: + items = [f"{_convert(k, ctx)}: {_convert(v, ctx)}" for k, v in value.items()] + return f"{{{', '.join(items)}}}" + + +@_convert.register(torch.dtype) +def _(value: torch.dtype, ctx: _SerializeContext) -> str: + ctx.add_global("torch", torch) + return str(value) + + +@_convert.register(torch.device) +def _(value: torch.device, ctx: _SerializeContext) -> str: + ctx.add_global("torch", torch) + return f"torch.device({str(value)!r})" + + +@_convert.register(torch.Size) +def _(value: torch.Size, ctx: _SerializeContext) -> str: + ctx.add_global("torch", torch) + items = [_convert(v, ctx) for v in tuple(value)] + return f"torch.Size([{', '.join(items)}])" + + +@_convert.register(enum.Enum) +def _(value: enum.Enum, ctx: _SerializeContext) -> str: + enum_cls = type(value) + ctx.add_global(enum_cls.__name__, enum_cls) + return f"{enum_cls.__name__}.{value.name}" + + +def _convert_or_none(value: object, ctx: _SerializeContext) -> str: + try: + return _convert(value, ctx) + except TypeError: + return "None" + + +@_convert.register(QuantizedTensorStorage) +def _(value: QuantizedTensorStorage, ctx: _SerializeContext) -> str: + cls = type(value) + cls_name = cls.__name__ + ctx.add_global(cls_name, cls) + + if cls_name == "Float8TensorStorage": + return ( + f"{cls_name}(" + f"data={_convert(getattr(value, '_data'), ctx)}, " + f"fp8_scale_inv={_convert(getattr(value, '_scale_inv'), ctx)}, " + f"fp8_dtype={_convert(getattr(value, '_fp8_dtype'), ctx)}, " + f"data_transpose={_convert(getattr(value, '_transpose'), ctx)}, " + f"quantizer={_convert_or_none(getattr(value, '_quantizer'), ctx)}" + f")" + ) + + if cls_name == "MXFP8TensorStorage": + return ( + f"{cls_name}(" + f"rowwise_data={_convert(getattr(value, '_rowwise_data'), ctx)}, " + f"rowwise_scale_inv={_convert(getattr(value, '_rowwise_scale_inv'), ctx)}, " + f"columnwise_data={_convert(getattr(value, '_columnwise_data'), ctx)}, " + f"columnwise_scale_inv={_convert(getattr(value, '_columnwise_scale_inv'), ctx)}, " + f"fp8_dtype={_convert(getattr(value, '_fp8_dtype'), ctx)}, " + f"quantizer={_convert_or_none(getattr(value, '_quantizer'), ctx)}, " + f"with_gemm_swizzled_scales={_convert(getattr(value, '_with_gemm_swizzled_scales'), ctx)}" + f")" + ) + + if cls_name == "Float8BlockwiseQTensorStorage": + return ( + f"{cls_name}(" + f"rowwise_data={_convert(getattr(value, '_rowwise_data'), ctx)}, " + f"rowwise_scale_inv={_convert(getattr(value, '_rowwise_scale_inv'), ctx)}, " + f"columnwise_data={_convert(getattr(value, '_columnwise_data'), ctx)}, " + f"columnwise_scale_inv={_convert(getattr(value, '_columnwise_scale_inv'), ctx)}, " + f"fp8_dtype={_convert(getattr(value, '_fp8_dtype'), ctx)}, " + f"quantizer={_convert_or_none(getattr(value, '_quantizer'), ctx)}, " + f"is_2D_scaled={_convert(getattr(value, '_is_2D_scaled'), ctx)}" + f")" + ) + + if cls_name == "NVFP4TensorStorage": + return ( + f"{cls_name}(" + f"rowwise_data={_convert(getattr(value, '_rowwise_data'), ctx)}, " + f"rowwise_scale_inv={_convert(getattr(value, '_rowwise_scale_inv'), ctx)}, " + f"columnwise_data={_convert(getattr(value, '_columnwise_data'), ctx)}, " + f"columnwise_scale_inv={_convert(getattr(value, '_columnwise_scale_inv'), ctx)}, " + f"amax_rowwise={_convert(getattr(value, '_amax_rowwise'), ctx)}, " + f"amax_columnwise={_convert(getattr(value, '_amax_columnwise'), ctx)}, " + f"fp4_dtype={_convert(getattr(value, '_fp4_dtype'), ctx)}, " + f"quantizer={_convert_or_none(getattr(value, '_quantizer'), ctx)}, " + f"with_gemm_swizzled_scales={_convert(getattr(value, '_with_gemm_swizzled_scales'), ctx)}" + f")" + ) + + # Fall back to generic object serializer for unknown storage subclasses. + return _convert.dispatch(object)(value, ctx) + + +def convert_to_fx(value: object) -> Tuple[str, Dict[str, object]]: + """Build FX expression + globals that reconstruct value from scratch.""" + ctx = _SerializeContext() + expr = _convert(value, ctx) + return expr, ctx.globals diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 23ad8cacb0..720b4fe24e 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -20,10 +20,12 @@ get_dummy_wgrad, get_ub, TransformerEngineBaseModule, + WeightWorkspace, _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) +from .fx_serialize import convert_to_fx from ._common import noop_cat, WeightGradStore from ..quantization import FP8GlobalStateManager from ..utils import ( @@ -48,6 +50,7 @@ gather_along_first_dim, is_fp8_activation_recompute_enabled, in_fp8_activation_recompute_phase, + _collect_fsdp_tensor_shapes, _fsdp_scatter_tensors, _fsdp_gather_tensors, ) @@ -75,36 +78,155 @@ mark_activation_offload, ) from ...debug.pytorch.debug_state import TEDebugState +from torch._opaque_base import OpaqueBase +from torch._library.opaque_object import MemberType, is_opaque_type, register_opaque_type __all__ = ["Linear"] -class _Linear(torch.autograd.Function): +class _LinearNonTensorArgs(OpaqueBase): + """Opaque value object wrapping all non-tensor arguments for the TE linear custom op. + + This allows passing complex Python objects (quantizers, process groups, + weight workspaces, etc.) through the ``torch.library.custom_op`` schema + which only supports tensors and simple scalars natively. + """ + + def __init__(self, args: Tuple) -> None: + self.args = args + + def __eq__(self, other) -> bool: + if not isinstance(other, _LinearNonTensorArgs): + return NotImplemented + self_expr, _ = convert_to_fx(self.args) + other_expr, _ = convert_to_fx(other.args) + return self_expr == other_expr + + def __hash__(self) -> int: + args_expr, _ = convert_to_fx(self.args) + return hash(args_expr) + + def __fx_repr__(self) -> Tuple[str, Dict[str, object]]: + args_expr, args_globals = convert_to_fx(self.args) + return ( + f"_LinearNonTensorArgs({args_expr})", + {"_LinearNonTensorArgs": _LinearNonTensorArgs, **args_globals}, + ) + + +class _LinearQuantizerArg(OpaqueBase): + """Reference wrapper for passing module to custom ops.""" + + def __init__(self, module: "Linear") -> None: + self.module = module + + +class _LinearBackwardStateArg(OpaqueBase): + """Opaque value object for backward non-tensor arguments.""" + + def __init__(self, args: Tuple) -> None: + self.args = args + + def __eq__(self, other) -> bool: + if not isinstance(other, _LinearBackwardStateArg): + return NotImplemented + self_expr, _ = convert_to_fx(self.args) + other_expr, _ = convert_to_fx(other.args) + return self_expr == other_expr + + def __hash__(self) -> int: + try: + args_expr, _ = convert_to_fx(self.args) + return hash(args_expr) + except TypeError: + return id(self) + + def __fx_repr__(self) -> Tuple[str, Dict[str, object]]: + args_expr, args_globals = convert_to_fx(self.args) + return ( + f"_LinearBackwardStateArg({args_expr})", + {"_LinearBackwardStateArg": _LinearBackwardStateArg, **args_globals}, + ) + + +register_opaque_type(_LinearNonTensorArgs, typ="value") +register_opaque_type(WeightGradStore, typ="reference") +register_opaque_type( + _LinearQuantizerArg, + typ="reference", + members={"module": MemberType.USE_REAL}, +) +register_opaque_type( + _LinearBackwardStateArg, + typ="value", +) +if not is_opaque_type(Float8Quantizer): + register_opaque_type(Float8Quantizer, typ="reference") +if not is_opaque_type(Float8CurrentScalingQuantizer): + register_opaque_type(Float8CurrentScalingQuantizer, typ="reference") +if not is_opaque_type(MXFP8Quantizer): + register_opaque_type(MXFP8Quantizer, typ="reference") + + +class _Linear: """Linear semi-top level module Calls custom cuda extensions. """ + @staticmethod + def _reconstruct_owns_input( + inp: torch.Tensor, + weight: torch.Tensor, + is_grad_enabled: bool, + fp8: bool, + debug: bool, + input_quantizer: Optional[Quantizer], + weight_quantizer: Optional[Quantizer], + parallel_mode: Optional[str], + sequence_parallel: bool, + ub_overlap_ag_fprop: bool, + activation_dtype: torch.dtype, + save_original_input: bool, + ) -> bool: + """Reconstruct whether forward owned saved input storage.""" + backward_needs_input = is_grad_enabled and weight.requires_grad + if not backward_needs_input: + # Forward saved_inputmat is None in this case. + return True + if save_original_input: + return False + + with_input_all_gather_nccl = ( + parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop + ) + if fp8 or debug: + if isinstance(inp, QuantizedTensorStorage): + return False + if with_input_all_gather_nccl or ub_overlap_ag_fprop: + custom = is_custom(input_quantizer) or is_custom(weight_quantizer) + return not custom + return True + return inp.dtype != activation_dtype + @staticmethod def forward( - ctx, weight: torch.Tensor, inp: torch.Tensor, bias: Optional[torch.Tensor], - non_tensor_args: Tuple, - ) -> torch.Tensor: + weight_workspace: WeightWorkspace, + wgrad_store: WeightGradStore, + args: _LinearNonTensorArgs, + module_arg: _LinearQuantizerArg, + ) -> List[torch.Tensor]: # pylint: disable=missing-function-docstring + non_tensor_args = args.args + module = module_arg.module + recipe = module.fp8_meta.get("recipe") ( is_first_microbatch, fp8, fp8_calibration, - wgrad_store, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, fuse_wgrad_accumulation, cpu_offloading, tp_group, @@ -122,13 +244,34 @@ def forward( ub_bulk_wgrad, ub_name, fp8_output, # pylint: disable=unused-variable + fp8_grad, fsdp_group, - module, - skip_fp8_weight_update, symmetric_ar_type, save_original_input, debug, ) = non_tensor_args + skip_fp8_weight_update = None + if skip_fp8_weight_update is None and FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if skip_fp8_weight_update is not None: + is_first_microbatch = False + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = _Linear._select_quantizers( + module=module, + fp8_output=fp8_output, + fp8_grad=fp8_grad, + is_grad_enabled=is_grad_enabled, + debug=debug, + ) + if fp8 and recipe is not None and recipe.delayed(): + if not FP8GlobalStateManager.fp8_graph_capturing(): + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(module.fp8_meta) # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" @@ -248,25 +391,21 @@ def forward( # ------------------------------------------------------ # Prepare weight tensor + # Quantized workspace is prepared lazily inside custom op. # ------------------------------------------------------ - weightmat = weight if fp8 or debug: - # Configure quantizer - # No need to set the quantizer states if weight is already quantized if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): columnwise_usage = is_grad_enabled and inp.requires_grad if not columnwise_usage: columnwise_usage = ( - is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() + is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase() ) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) elif isinstance(weight, QuantizedTensor): - # If weight is already quantized, no need to set quantizer states weight_quantizer = weight._quantizer - # Get quantized weight + update_workspace = is_first_microbatch is None or is_first_microbatch - weightmat = module.get_weight_workspace( + weightmat = weight_workspace.get_weight_workspace( tensor=weight, quantizer=weight_quantizer, cache_name=(None if is_first_microbatch is None else "weight"), @@ -276,9 +415,8 @@ def forward( workspace_dtype=activation_dtype, ) weightmat.update_usage(rowwise_usage=True) - else: - weightmat = cast_if_needed(weightmat, activation_dtype) # Cast for AMP + weightmat = cast_if_needed(weight, activation_dtype) # Cast for AMP # ------------------------------------------------------ # Weight tensor is ready for GEMM... # ------------------------------------------------------ @@ -376,9 +514,7 @@ def forward( if save_original_input: inputmat = inp - ctx.weight_quantizer = weight_quantizer - - ctx.backward_input_needs_gather = ( + backward_input_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) @@ -389,7 +525,7 @@ def forward( and isinstance(inputmat, QuantizedTensorStorage) ): if ( - ctx.backward_input_needs_gather + backward_input_needs_gather and weight_quantizer.supports_only_rowwise_all_gather() ): # All-gather is not supported with FP8 column-wise data @@ -401,7 +537,8 @@ def forward( # Cached input tensor saved_inputmat = None if backward_needs_input: - saved_inputmat = inputmat + # Keep the original activation for backward in non-FP8 mode. + saved_inputmat = inp if not (fp8 or debug) else inputmat if cpu_offloading and saved_inputmat is not None: mark_activation_offload(saved_inputmat) @@ -409,8 +546,7 @@ def forward( # Scatter intermediate/activation tensors saved for the backward pass # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights nvtx_range_push(f"{nvtx_label}.fsdp_scatter") - ctx.fsdp_group = fsdp_group - ctx.fsdp_shapes = _fsdp_scatter_tensors( + fsdp_shapes = _fsdp_scatter_tensors( fsdp_group, saved_inputmat, weightmat if fp8 and not isinstance(weight, QuantizedTensorStorage) else None, @@ -418,110 +554,411 @@ def forward( nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") if cpu_offloading: - ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") - - if ctx.grad_added_to_main_grad: - # If you are passing torch.nn.Parameter through the Torch hooks, you will - # get back torch.Tensor. Torch rips off the Parameter wrapper. - # You need to preserve the weight object to have all the attributes user - # sets for the weights. Because of this, it is not recommended to offload - # weights if weights are externally touched outside this module - ctx.weight_object = weight - mark_not_offload(weight, weightmat, bias) - # TODO(ksivamani): Check memory usage - tensors_to_save, tensor_objects = prepare_for_saving( - saved_inputmat, - weightmat, - weight, - bias, + saved_inputmat_to_save = ( + saved_inputmat if isinstance(saved_inputmat, QuantizedTensorStorage) else None + ) + weightmat_to_save = weightmat if isinstance(weightmat, QuantizedTensorStorage) else None + tensors_to_save, _ = prepare_for_saving( + saved_inputmat_to_save, + weightmat_to_save, ) - ctx.save_for_backward(*tensors_to_save) - ctx.tensor_objects = tensor_objects - - ctx.activation_dtype = activation_dtype - ctx.fp8 = fp8 - ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.input_quantizer = input_quantizer - ctx.grad_input_quantizer = grad_input_quantizer - ctx.grad_weight_quantizer = grad_weight_quantizer - ctx.grad_output_quantizer = grad_output_quantizer - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - if fuse_wgrad_accumulation and weight.requires_grad: - # This check is needed to ensure that main_grad is not created - # during the forward pass when using MCore FSDP as it creates - # the main_grad buffer lazily before backprop - if hasattr(weight, "__fsdp_param__"): - # MCore FSDP creates main_grad lazily before backward - ctx.main_grad_func = weight.get_main_grad - else: - ctx.main_grad_func = lambda: weight.main_grad - - ctx.debug = debug - ctx.custom = custom - ctx.cpu_offloading = cpu_offloading - ctx.is_first_microbatch = is_first_microbatch - ctx.use_bias = bias is not None - ctx.sequence_parallel = sequence_parallel - ctx.tensor_parallel = tensor_parallel - ctx.inp_shape = inp.shape - ctx.parallel_mode = parallel_mode - ctx.tp_group = tp_group - ctx.ub_overlap_ag = ub_overlap_ag_dgrad - ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad - ctx.ub_bulk_dgrad = ub_bulk_dgrad - ctx.ub_bulk_wgrad = ub_bulk_wgrad - ctx.ub_name = ub_name - ctx.tp_size = tp_size - ctx.requires_dgrad = inp.requires_grad - ctx.requires_wgrad = weight.requires_grad - ctx.reduce_and_update_bwd_fp8_tensors = False - - ctx.owns_input = saved_inputmat is not inp - if ctx.fp8 and requires_grad(inp, weight, bias): - _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE - ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() - if in_fp8_activation_recompute_phase(): - FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module - ctx.wgrad_store = wgrad_store # ------------------------------------------------------ # Cached state for backward pass is ready... # ------------------------------------------------------ - return out + if is_grad_enabled: + tensors_to_save = _Linear._replace_none_saved_tensors(tensors_to_save, out) + return [out, *tensors_to_save] + return [out] @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - # pylint: disable=missing-function-docstring + def setup_context(ctx, inputs, output): + """Save state for backward pass. + + Called automatically after forward() when using apply(). + Receives the inputs and output of forward() and sets up ctx + for the backward pass. All non-tensor state is re-derived from + ``inputs`` (which includes ``non_tensor_args``). + """ + ( + weight, + inp, + bias, + weight_workspace, + wgrad_store, + args, + module_arg, + ) = inputs + module = module_arg.module + recipe = module.fp8_meta.get("recipe") + non_tensor_args = args.args + + # Unpack non_tensor_args (same order as Linear.forward packs them) + ( + is_first_microbatch, + fp8, + fp8_calibration, + fuse_wgrad_accumulation, + cpu_offloading, + tp_group, + tp_size, + sequence_parallel, + tensor_parallel, + activation_dtype, + parallel_mode, + is_grad_enabled, + ub_overlap_rs_fprop, + ub_overlap_ag_dgrad, + ub_overlap_ag_fprop, + ub_overlap_rs_dgrad, + ub_bulk_dgrad, + ub_bulk_wgrad, + ub_name, + fp8_output, # pylint: disable=unused-variable + fp8_grad, + fsdp_group, + symmetric_ar_type, # pylint: disable=unused-variable + save_original_input, # pylint: disable=unused-variable + debug, + ) = non_tensor_args + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = _Linear._select_quantizers( + module=module, + fp8_output=fp8_output, + fp8_grad=fp8_grad, + is_grad_enabled=is_grad_enabled, + debug=debug, + ) + + tensors_to_save = output[1:] + if tensors_to_save: + ctx.save_for_backward(*tensors_to_save) + ( + _tensors_to_save_ctx, + tensor_objects, + saved_inputmat, + weightmat, + ) = _Linear._build_saved_tensors_for_context( + inp=inp, + weight=weight, + bias=bias, + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + fp8=fp8, + debug=debug, + save_original_input=save_original_input, + is_grad_enabled=is_grad_enabled, + activation_dtype=activation_dtype, + ) + ctx.tensor_objects = tensor_objects + ctx.fsdp_shapes = _collect_fsdp_tensor_shapes( + saved_inputmat, + weightmat if fp8 and not isinstance(weight, QuantizedTensorStorage) else None, + ) + + ctx.backward_input_needs_gather = ( + weight.requires_grad and parallel_mode == "column" and sequence_parallel + ) + + ctx.fsdp_group = fsdp_group + + ctx.activation_dtype = activation_dtype + ctx.fp8 = fp8 + ctx.fp8_recipe = recipe if fp8 else None + ctx.inp_ref = inp + ctx.weight_ref = weight + ctx.bias_ref = bias + ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + + if fuse_wgrad_accumulation and weight.requires_grad: + # This check is needed to ensure that main_grad is not created + # during the forward pass when using MCore FSDP as it creates + # the main_grad buffer lazily before backprop + if hasattr(weight, "__fsdp_param__"): + # MCore FSDP creates main_grad lazily before backward + ctx.main_grad_func = weight.get_main_grad + else: + ctx.main_grad_func = lambda: weight.main_grad + + ctx.debug = debug + ctx.custom = is_custom(input_quantizer) or is_custom(weight_quantizer) + ctx.cpu_offloading = cpu_offloading + + if cpu_offloading: + ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") + if ctx.grad_added_to_main_grad: + ctx.weight_object = weight + ctx.is_first_microbatch = is_first_microbatch + ctx.sequence_parallel = sequence_parallel + ctx.tensor_parallel = tensor_parallel + ctx.tp_group = tp_group + + # Apply debug modifications to UB flags (matching forward behavior) + if debug: + ub_overlap_rs_dgrad = False + ub_bulk_wgrad = False + ub_bulk_dgrad = False + + ub_overlap_ag = ub_overlap_ag_dgrad + requires_dgrad = inp.requires_grad + requires_wgrad = weight.requires_grad + reduce_and_update_bwd_fp8_tensors = False + if fp8 and requires_grad(inp, weight, bias): + _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE + reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() + if in_fp8_activation_recompute_phase(): + FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + owns_input = _Linear._reconstruct_owns_input( + inp=inp, + weight=weight, + is_grad_enabled=is_grad_enabled, + fp8=fp8, + debug=debug, + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + ub_overlap_ag_fprop=ub_overlap_ag_fprop, + activation_dtype=activation_dtype, + save_original_input=save_original_input, + ) + ctx.wgrad_store = wgrad_store + ctx.module_ref_backward = module_arg + main_grad_func = ctx.main_grad_func if hasattr(ctx, "main_grad_func") else None + grad_added_to_main_grad = ( + ctx.grad_added_to_main_grad if hasattr(ctx, "grad_added_to_main_grad") else False + ) + weight_object = ctx.weight_object if hasattr(ctx, "weight_object") else None + ctx.args_value_backward = _LinearBackwardStateArg( + ( + list(inp.shape), # inp_shape + bias is not None, # use_bias + requires_dgrad, + requires_wgrad, + parallel_mode, + ub_overlap_ag, + ub_overlap_rs_dgrad, + ub_bulk_dgrad, + ub_bulk_wgrad, + ub_name, + tp_size, + reduce_and_update_bwd_fp8_tensors, + owns_input, + tensor_objects, + ctx.fsdp_shapes, + ctx.backward_input_needs_gather, + fsdp_group, + activation_dtype, + fp8, + recipe if fp8 else None, + fp8_output, + fp8_grad, + fuse_wgrad_accumulation, + debug, + cpu_offloading, + grad_added_to_main_grad, + weight_object, + is_first_microbatch, + sequence_parallel, + tp_group, + main_grad_func, + wgrad_store, + ) + ) + + @staticmethod + def _build_saved_tensors_for_context( + *, + inp: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + input_quantizer: Optional[Quantizer], + weight_quantizer: Optional[Quantizer], + fp8: bool, + debug: bool, + save_original_input: bool, + is_grad_enabled: bool, + activation_dtype: torch.dtype, + weightmat_override: Optional[torch.Tensor] = None, + ): + """Build placeholders and tensor object metadata for backward restore.""" + backward_needs_input = is_grad_enabled and weight.requires_grad + saved_inputmat = None + if backward_needs_input: + if save_original_input or not (fp8 or debug) or input_quantizer is None: + saved_inputmat = inp + else: + saved_inputmat = input_quantizer.make_empty_like_layout(inp, internal=True) + + if weightmat_override is not None: + weightmat = weightmat_override + elif fp8 or debug: + if isinstance(weight, QuantizedTensor): + weightmat = weight + elif weight_quantizer is not None: + weightmat = weight_quantizer.make_empty_like_layout(weight, internal=True) + else: + weightmat = weight + else: + weightmat = cast_if_needed(weight, activation_dtype) + + saved_inputmat_to_save = ( + saved_inputmat if isinstance(saved_inputmat, QuantizedTensorStorage) else None + ) + weightmat_to_save = weightmat if isinstance(weightmat, QuantizedTensorStorage) else None + tensors_to_save, tensor_objects = prepare_for_saving( + saved_inputmat_to_save, + weightmat_to_save, + ) + return tensors_to_save, tensor_objects, saved_inputmat, weightmat + + @staticmethod + def _select_quantizers( + *, + module: "Linear", + fp8_output: bool, + fp8_grad: bool, + is_grad_enabled: bool, + debug: bool, + ): + quantizers = ( + module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if not debug + else module._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) + ) + if debug and module.no_debug_features_active(quantizers): + quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + return quantizers + + @staticmethod + def _replace_none_saved_tensors( + tensors_to_save: List[Optional[torch.Tensor]], + like_tensor: torch.Tensor, + ) -> List[torch.Tensor]: + """custom_op Tensor[] outputs cannot contain None or non-tensor values.""" + return [ + t if isinstance(t, torch.Tensor) else like_tensor.new_empty((0,)) + for t in tensors_to_save + ] + + @staticmethod + def _backward_impl( + ctx, + grad_output, + args_value_backward, + module_ref_backward, + saved_tensors_override=None, + inp_override: Optional[torch.Tensor] = None, + weight_override: Optional[torch.Tensor] = None, + bias_override: Optional[torch.Tensor] = None, + ): + # pylint: disable=missing-function-docstring + real_ctx = ctx + ( + inp_shape, + use_bias, + requires_dgrad, + requires_wgrad, + parallel_mode, + ub_overlap_ag, + ub_overlap_rs_dgrad, + ub_bulk_dgrad, + ub_bulk_wgrad, + ub_name, + tp_size, + reduce_and_update_bwd_fp8_tensors, + owns_input, + tensor_objects, + fsdp_shapes, + backward_input_needs_gather, + fsdp_group, + activation_dtype, + fp8, + fp8_recipe, + fp8_output, + fp8_grad, + fuse_wgrad_accumulation, + debug, + cpu_offloading, + grad_added_to_main_grad, + weight_object, + is_first_microbatch, + sequence_parallel, + tp_group, + main_grad_func, + wgrad_store, + ) = args_value_backward.args + module = module_ref_backward.module + ( + input_quantizer, + weight_quantizer, + _output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = _Linear._select_quantizers( + module=module, + fp8_output=fp8_output, + fp8_grad=fp8_grad, + is_grad_enabled=True, + debug=debug, + ) + custom = is_custom(input_quantizer) or is_custom(weight_quantizer) # NVTX label for profiling nvtx_label = "transformer_engine._Linear.backward" - if ctx.ub_name is not None: - nvtx_label = f"{nvtx_label}.{ctx.ub_name}" + if ub_name is not None: + nvtx_label = f"{nvtx_label}.{ub_name}" with get_nvtx_range_context("_Linear_backward"): - saved_tensors = ctx.saved_tensors - inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking - restore_from_saved(ctx.tensor_objects, saved_tensors) - ) - - # Delete the references to tensor objects once they've been consumed - # by the `restore_from_saved` method to construct back the actual tensors. - ctx.tensor_objects = None + inp_shape = tuple(inp_shape) + + saved_tensors = real_ctx.saved_tensors if saved_tensors_override is None else saved_tensors_override + restored_objects = restore_from_saved(tensor_objects, saved_tensors) + if len(tensor_objects) == 2: + inputmat_saved, weight_fp8_saved = restored_objects + inputmat = ( + inputmat_saved + if isinstance(tensor_objects[0], QuantizedTensorStorage) + else inp_override + ) + weight_fp8 = ( + weight_fp8_saved + if isinstance(tensor_objects[1], QuantizedTensorStorage) + else weight_override + ) + weight = weight_override + bias = bias_override + else: + inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking + restored_objects + ) + if isinstance(weight, QuantizedTensor): + weight_quantizer = weight._quantizer + if not requires_wgrad: + inputmat = None + if not use_bias: + bias = None # Since main_grad can be modified inplace, it should not be a part of saved_tensors main_grad = ( - ctx.main_grad_func() - if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad + main_grad_func() + if weight is not None and fuse_wgrad_accumulation and requires_wgrad and main_grad_func else None ) - if ctx.cpu_offloading: - if ctx.grad_added_to_main_grad: - weight = ctx.weight_object - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: + if cpu_offloading and grad_added_to_main_grad: + weight = weight_object + if requires_wgrad and fuse_wgrad_accumulation: weight.main_grad = main_grad # Gather intermediate/activation tensors if needed @@ -529,39 +966,39 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # shards/unshards the base weights so we don't do it ourselves nvtx_range_push(f"{nvtx_label}.fsdp_gather") _fsdp_gather_tensors( - ctx.fsdp_group, - ctx.fsdp_shapes, + fsdp_group, + fsdp_shapes, inputmat, weight_fp8, ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") # Configure Userbuffers communication (comm+GEMM overlap) - ctx.ub_obj_gradout = None + ub_obj_gradout = None ub_obj_dgrad = None ub_obj_wgrad = None ub_type_dgrad = None ub_type_wgrad = None - dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] - if ctx.ub_overlap_ag: + dgrad_shape = [reduce(multiply_op, inp_shape[:-1]), inp_shape[-1]] + if ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) - ub_obj_dgrad = ctx.ub_obj_gradout + ub_obj_gradout = get_ub(ub_name + "_dgrad", fp8) + ub_obj_dgrad = ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG - elif ctx.ub_overlap_rs_dgrad: + elif ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) - ub_obj_dgrad = ctx.ub_obj_gradout + ub_obj_gradout = get_ub(ub_name + "_dgrad", fp8) + ub_obj_dgrad = ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: - if ctx.ub_bulk_dgrad: + if ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) - ub_obj_dgrad = ctx.ub_obj_gradout + ub_obj_gradout = get_ub(ub_name + "_dgrad", fp8) + ub_obj_dgrad = ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG - if ctx.ub_bulk_wgrad: + if ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_wgrad = get_ub(ub_name + "_wgrad", fp8) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -575,10 +1012,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None: - quantizer = ctx.grad_output_quantizer + if grad_output_quantizer is not None: + quantizer = grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) - if ctx.ub_overlap_ag: + if ub_overlap_ag: # Userbuffers only supports communication for one # tensor usage at a time. Configure quantizer with # usage for only dgrad GEMM. @@ -591,22 +1028,30 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # NOTE: For `ctx.bias is True`, selected quantize kernel errors with # `cast_kernels.cuh:1322 in function fp8_quantize_arch_l_100: Not implemented scaling mode or fusion: NVTE_DELAYED_TENSOR_SCALING or IS_DBIAS=true on GPU with compute capability < 10.0.` if ( - not ctx.use_bias - and not ctx.requires_wgrad - and ctx.grad_output_quantizer is not None + not use_bias + and not requires_wgrad + and grad_output_quantizer is not None ): - ctx.grad_output_quantizer.set_usage(columnwise=False) + grad_output_quantizer.set_usage(columnwise=False) # Prepare grad output tensor nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") + grad_output_ctx = type("_LinearGradOutputCtx", (), {})() + grad_output_ctx.sequence_parallel = sequence_parallel + grad_output_ctx.fp8 = fp8 + grad_output_ctx.debug = debug + grad_output_ctx.ub_overlap_ag = ub_overlap_ag + grad_output_ctx.tp_group = tp_group + grad_output_ctx.ub_obj_gradout = ub_obj_gradout + grad_output_ctx.use_bias = use_bias ( grad_output, grad_bias, ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, + grad_output_ctx, grad_output, - ctx.parallel_mode == "row", - ctx.grad_output_quantizer, + parallel_mode == "row", + grad_output_quantizer, ) nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess") @@ -622,53 +1067,53 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # -------------------------------------------------- inputmat_total = None inputmat_total_work = None - if ctx.requires_wgrad: - if ctx.fp8 or ctx.debug: + if requires_wgrad: + if fp8 or debug: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass - elif ctx.debug or ctx.custom: + elif debug or custom: # Debug quantizer will be applied immediately before wgrad GEMM pass else: # Quantize input tensor - quantizer = ctx.input_quantizer + quantizer = input_quantizer if quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data quantizer.set_usage( rowwise=True, - columnwise=not ctx.backward_input_needs_gather, + columnwise=not backward_input_needs_gather, ) else: quantizer.set_usage(rowwise=False, columnwise=True) inputmat = quantizer(inputmat) else: if isinstance(inputmat, QuantizedTensorStorage): - inputmat = inputmat.dequantize(dtype=ctx.activation_dtype) + inputmat = inputmat.dequantize(dtype=activation_dtype) else: - inputmat = cast_if_needed(inputmat, ctx.activation_dtype) - if ctx.backward_input_needs_gather: + inputmat = cast_if_needed(inputmat, activation_dtype) + if backward_input_needs_gather: quantizer = None - if ctx.fp8 or ctx.debug: - quantizer = ctx.input_quantizer + if fp8 or debug: + quantizer = input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually quantizer.set_usage(rowwise=True, columnwise=False) else: # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) - if ctx.ub_bulk_dgrad: + if ub_bulk_dgrad: inputmat_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_dgrad, inputmat, quantizer, - ctx.tp_group, + tp_group, ) else: nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") inputmat_total, inputmat_total_work = gather_along_first_dim( inputmat, - ctx.tp_group, + tp_group, async_op=True, quantizer=quantizer, ) @@ -685,35 +1130,35 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad = None dgrad_work = None - if ctx.requires_dgrad: + if requires_dgrad: # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance( + if weight_quantizer is not None and isinstance( weight_fp8, QuantizedTensorStorage ): weight_fp8.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: - recipe = ctx.fp8_recipe + if fp8: + recipe = fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None: - ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + if grad_input_quantizer is not None: + grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter gemm_out = None reduce_scatter_out = None - if ctx.ub_overlap_rs_dgrad: + if ub_overlap_rs_dgrad: reduce_scatter_out = torch.empty( - dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device + dgrad_shape, dtype=activation_dtype, device=grad_output_arg.device ) - elif ctx.ub_bulk_wgrad: + elif ub_bulk_wgrad: gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False) # dgrad GEMM @@ -725,34 +1170,34 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer, + quantization_params=grad_input_quantizer, out=gemm_out, - out_dtype=ctx.activation_dtype, + out_dtype=activation_dtype, use_split_accumulator=use_split_accumulator, ub=ub_obj_dgrad, ub_type=ub_type_dgrad, extra_output=reduce_scatter_out, - bulk_overlap=ctx.ub_bulk_dgrad, + bulk_overlap=ub_bulk_dgrad, ) nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") # Prepare grad input tensor # Note: Perform tensor-parallel communication - if ctx.ub_overlap_rs_dgrad: + if ub_overlap_rs_dgrad: dgrad = reduce_scatter_out - elif ctx.ub_bulk_wgrad: + elif ub_bulk_wgrad: dgrad = ub_obj_wgrad.get_buffer(local_chunk=True) - elif ctx.parallel_mode == "column" and ctx.tp_size > 1: + elif parallel_mode == "column" and tp_size > 1: nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") dgrad = gemm_out - if ctx.sequence_parallel: + if sequence_parallel: dgrad, dgrad_work = reduce_scatter_along_first_dim( dgrad, - ctx.tp_group, + tp_group, async_op=True, ) else: - dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) + dgrad, dgrad_work = allreduce(dgrad, tp_group, async_op=True) nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") else: dgrad = gemm_out @@ -766,7 +1211,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # -------------------------------------------------- wgrad = None - if ctx.requires_wgrad: + if requires_wgrad: # Prepare input tensor # Note: Synchronize tensor-parallel communication and @@ -774,17 +1219,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if ctx.fp8 or ctx.debug: + if fp8 or debug: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: - ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) - inputmat_total = ctx.input_quantizer(inputmat_total) + input_quantizer.set_usage(rowwise=False, columnwise=True) + inputmat_total = input_quantizer(inputmat_total) # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): + if ub_overlap_ag and isinstance(grad_output_quantizer, MXFP8Quantizer): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -796,9 +1241,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_overlap_wgrad = get_ub(ub_name + "_wgrad", fp8) - ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) + grad_output_quantizer.set_usage(rowwise=False, columnwise=True) # We use the send stream to copy into the userbuffers. # This is the same stream that we will use to access the data in the AG, @@ -807,8 +1252,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_overlap_wgrad, grad_output_arg, - ctx.grad_output_quantizer, - ctx.tp_group, + grad_output_quantizer, + tp_group, ) # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm @@ -816,56 +1261,56 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if ctx.fp8 or ctx.debug: + if fp8 or debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: - ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output = ctx.grad_output_quantizer(grad_output) + grad_output_quantizer.set_usage(rowwise=False, columnwise=True) + grad_output = grad_output_quantizer(grad_output) # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: - recipe = ctx.fp8_recipe + if fp8: + recipe = fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator # Figure out whether to output wgrad GEMM directly into main grad - if ctx.is_first_microbatch is not None: + if is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( - ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch + fuse_wgrad_accumulation and not is_first_microbatch ) else: - accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + accumulate_wgrad_into_param_main_grad = fuse_wgrad_accumulation # Output buffer for overlapping FP8 grad input # reduce-scatter with wgrad GEMM reduce_scatter_out = None - if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): + if ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): reduce_scatter_out = torch.empty( - dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device + dgrad_shape, dtype=activation_dtype, device=grad_output_arg.device ) # Arguments to include in wgrad GEMM closure wgrad_gemm_kwargs = { "out_dtype": ( - main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype + main_grad.dtype if fuse_wgrad_accumulation else activation_dtype ), - "quantization_params": ctx.grad_weight_quantizer, + "quantization_params": grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) else False ), "layout": "NT", - "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "out": main_grad if fuse_wgrad_accumulation else None, + "bias": (bias if (grad_bias is None and not fp8) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, "ub_type": ub_type_wgrad, "extra_output": reduce_scatter_out, - "bulk_overlap": ctx.ub_bulk_wgrad, + "bulk_overlap": ub_bulk_wgrad, } def wgrad_gemm( @@ -886,7 +1331,7 @@ def wgrad_gemm( return dw, db # Choose whether to call wgrad GEMM now or delay - if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): + if wgrad_store is not None and wgrad_store.delay_wgrad_compute(): if ( wgrad_gemm_kwargs["ub"] is not None or wgrad_gemm_kwargs["ub_type"] is not None @@ -897,7 +1342,7 @@ def wgrad_gemm( "Delayed weight grad computation is not supported " "with Userbuffers (tensor-parallel communication overlapping)" ) - ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm) + wgrad_store.put([inputmat_total, grad_output], wgrad_gemm) else: # Call wgrad GEMM now @@ -909,18 +1354,18 @@ def wgrad_gemm( del grad_bias_ # Deallocate tensors if permitted - if ctx.owns_input: + if owns_input: # Input tensor is internal clear_tensor_data(inputmat_total) - elif ctx.backward_input_needs_gather: + elif backward_input_needs_gather: # Gathered input tensor is internal clear_tensor_data(inputmat_total) - if ctx.parallel_mode == "row" and ctx.sequence_parallel: + if parallel_mode == "row" and sequence_parallel: # Gathered grad output tensor is internal clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM - if ctx.ub_bulk_wgrad: + if ub_bulk_wgrad: if ub_obj_wgrad.is_fp8_ubuf(): dgrad = reduce_scatter_out else: @@ -931,7 +1376,7 @@ def wgrad_gemm( # -------------------------------------------------- # Don't return grad bias if not needed - if not ctx.use_bias: + if not use_bias: grad_bias = None # Make sure all tensor-parallel communication is finished @@ -942,10 +1387,10 @@ def wgrad_gemm( dgrad_work.wait() dgrad_work = None - if ctx.requires_wgrad: + if requires_wgrad: # Handle custom DDP from mcore. if ( - ctx.fuse_wgrad_accumulation + fuse_wgrad_accumulation and weight is not None and hasattr(weight, "grad_added_to_main_grad") ): @@ -961,26 +1406,252 @@ def wgrad_gemm( list(weight.main_grad.shape), weight.dtype, ) - elif ctx.fuse_wgrad_accumulation: + elif fuse_wgrad_accumulation: wgrad = None else: wgrad = None # Update FP8 scaling factors if needed - if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): + if reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors") FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") # Scatter fp8 weight buffers - if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage): - _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) + if fp8 and not isinstance(weight, QuantizedTensorStorage): + _fsdp_scatter_tensors(fsdp_group, weight_fp8) return ( wgrad, - dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, + dgrad.view(inp_shape) if requires_dgrad else None, grad_bias, + ) + + @staticmethod + def backward(ctx, *grads): + grad_output = grads[0] + if isinstance(grad_output, (list, tuple)): + grad_output = grad_output[0] + ( + _inp_shape, + use_bias, + requires_dgrad, + requires_wgrad, + _parallel_mode, + _ub_overlap_ag, + _ub_overlap_rs_dgrad, + _ub_bulk_dgrad, + _ub_bulk_wgrad, + _ub_name, + _tp_size, + _reduce_and_update_bwd_fp8_tensors, + _owns_input, + _tensor_objects, + _fsdp_shapes, + _backward_input_needs_gather, + _fsdp_group, + _activation_dtype, + _fp8, + _fp8_recipe, + _fp8_output, + _fp8_grad, + fuse_wgrad_accumulation, + _debug, + _cpu_offloading, + _grad_added_to_main_grad, + _weight_object, + _is_first_microbatch, + _sequence_parallel, + _tp_group, + _main_grad_func, + _wgrad_store, + ) = ctx.args_value_backward.args + + wgrad, dgrad, grad_bias = _te_linear_backward( + grad_output, + list(ctx.saved_tensors), + ctx.args_value_backward, + ctx.module_ref_backward, + ctx.inp_ref, + ctx.weight_ref, + ctx.bias_ref, + ) + return ( + wgrad if requires_wgrad and not fuse_wgrad_accumulation else None, + dgrad if requires_dgrad else None, + grad_bias if use_bias else None, + None, + None, None, + None, + ) + + +# Register _Linear.forward as a custom PyTorch operator +_te_linear_forward = torch.library.custom_op( + "te::linear_forward", mutates_args=() +)(_Linear.forward) + + +@_te_linear_forward.register_fake +def _( + weight: torch.Tensor, + inp: torch.Tensor, + bias: Optional[torch.Tensor], + weight_workspace: WeightWorkspace, + wgrad_store: WeightGradStore, + args: _LinearNonTensorArgs, + module_arg: _LinearQuantizerArg, +) -> List[torch.Tensor]: + # pylint: disable=unused-argument + del weight_workspace, wgrad_store + + out_shape = list(inp.shape) + out_shape[-1] = weight.shape[0] + out = inp.new_empty(out_shape) + ( + _is_first_microbatch, + fp8, + _fp8_calibration, + _fuse_wgrad_accumulation, + _cpu_offloading, + _tp_group, + _tp_size, + _sequence_parallel, + _tensor_parallel, + _activation_dtype, + _parallel_mode, + is_grad_enabled, + _ub_overlap_rs_fprop, + _ub_overlap_ag_dgrad, + _ub_overlap_ag_fprop, + _ub_overlap_rs_dgrad, + _ub_bulk_dgrad, + _ub_bulk_wgrad, + _ub_name, + _fp8_output, + _fp8_grad, + _fsdp_group, + _symmetric_ar_type, + _save_original_input, + debug, + ) = args.args + if is_grad_enabled: + module = module_arg.module + ( + _is_first_microbatch, + fp8, + _fp8_calibration, + _fuse_wgrad_accumulation, + _cpu_offloading, + _tp_group, + _tp_size, + _sequence_parallel, + _tensor_parallel, + activation_dtype, + _parallel_mode, + _is_grad_enabled, + _ub_overlap_rs_fprop, + _ub_overlap_ag_dgrad, + _ub_overlap_ag_fprop, + _ub_overlap_rs_dgrad, + _ub_bulk_dgrad, + _ub_bulk_wgrad, + _ub_name, + fp8_output, + fp8_grad, + _fsdp_group, + _symmetric_ar_type, + save_original_input, + debug, + ) = args.args + ( + input_quantizer, + weight_quantizer, + _output_quantizer, + _grad_input_quantizer, + _grad_weight_quantizer, + _grad_output_quantizer, + ) = _Linear._select_quantizers( + module=module, + fp8_output=fp8_output, + fp8_grad=fp8_grad, + is_grad_enabled=True, + debug=debug, + ) + tensors_to_save, _, _, _ = _Linear._build_saved_tensors_for_context( + inp=inp, + weight=weight, + bias=bias, + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + fp8=fp8, + debug=debug, + save_original_input=save_original_input, + is_grad_enabled=True, + activation_dtype=activation_dtype, ) + tensors_to_save = _Linear._replace_none_saved_tensors(tensors_to_save, out) + return [out, *tensors_to_save] + return [out] + + +def _te_linear_backward_impl( + grad_output: torch.Tensor, + saved_tensors: List[torch.Tensor], + args_value_backward: _LinearBackwardStateArg, + module_ref_backward: _LinearQuantizerArg, + inp_ref: Optional[torch.Tensor], + weight_ref: Optional[torch.Tensor], + bias_ref: Optional[torch.Tensor], +) -> List[torch.Tensor]: + class _NoSavedCtx: + saved_tensors = () + + grads = _Linear._backward_impl( + _NoSavedCtx(), + grad_output, + args_value_backward=args_value_backward, + module_ref_backward=module_ref_backward, + saved_tensors_override=saved_tensors, + inp_override=inp_ref, + weight_override=weight_ref, + bias_override=bias_ref, + ) + return [g if g is not None else grad_output.new_empty((0,)) for g in grads] + + +_te_linear_backward = torch.library.custom_op( + "te::linear_backward", mutates_args=() +)(_te_linear_backward_impl) + + +@_te_linear_backward.register_fake +def _( + grad_output: torch.Tensor, + saved_tensors: List[torch.Tensor], + args_value_backward: _LinearBackwardStateArg, + module_ref_backward: _LinearQuantizerArg, + inp_ref: Optional[torch.Tensor], + weight_ref: Optional[torch.Tensor], + bias_ref: Optional[torch.Tensor], +) -> List[torch.Tensor]: + del module_ref_backward, inp_ref, bias_ref + inp_shape, use_bias, requires_dgrad, *_ = args_value_backward.args + weight_for_dgrad = weight_ref if weight_ref is not None else saved_tensors[1] + wgrad = grad_output.new_empty((0,)) + dgrad = grad_output.new_empty((0,)) + if requires_dgrad: + dgrad_shape = tuple(inp_shape) + dgrad = weight_for_dgrad.new_empty(dgrad_shape) + grad_bias = grad_output.new_empty((0,)) + if use_bias: + grad_bias = grad_output.new_empty((grad_output.shape[-1],)) + return [wgrad, dgrad, grad_bias] + +_te_linear_forward.register_autograd( + _Linear.backward, + setup_context=_Linear.setup_context, +) class Linear(TransformerEngineBaseModule): @@ -1099,6 +1770,7 @@ def __init__( name: Optional[str] = None, ) -> None: super().__init__(name) + self._module_arg = _LinearQuantizerArg(self) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features @@ -1340,7 +2012,6 @@ def reset_parameters(self, defer_init=False): elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) - @no_torch_dynamo() def forward( self, inp: torch.Tensor, @@ -1376,13 +2047,6 @@ def forward( debug = self.is_debug_iter() - if FP8GlobalStateManager.fp8_graph_capturing(): - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() - else: - skip_fp8_weight_update = None - if skip_fp8_weight_update is not None: - is_first_microbatch = False - if self.ub_overlap_rs_fprop: if get_ub( self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() @@ -1394,47 +2058,18 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) + inp = self.prepare_forward( + inp, + allow_non_contiguous=isinstance(inp, QuantizedTensor), + defer_fp8_global_buffer_update=True, + ) try: weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers - - if is_grad_enabled: - linear_fn = _Linear.apply - autograd_ctx = [] - else: - linear_fn = _Linear.forward - autograd_ctx = [None] - non_tensor_args = ( is_first_microbatch, self.fp8, self.fp8_calibration, - self.wgrad_store, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), self.tp_group, @@ -1452,20 +2087,26 @@ def forward( self.ub_bulk_wgrad, self.ub_name, fp8_output, + fp8_grad, self.fsdp_group, - self, - skip_fp8_weight_update, self.symmetric_ar_type, self.save_original_input, debug, ) - out = linear_fn( - *autograd_ctx, + bias_arg = ( + bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None + ) + args = _LinearNonTensorArgs(non_tensor_args) + result = _te_linear_forward( weight_tensor, inp, - bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, - non_tensor_args, + bias_arg, + self.weight_workspace, + self.wgrad_store, + args, + self._module_arg, ) + out = result[0] finally: self.end_forward() if self.gemm_bias_unfused_add: diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index d78677bc83..99fefb5fa9 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -11,6 +11,7 @@ import math import torch +from torch._opaque_base import OpaqueBaseMeta from torch.utils._pytree import tree_map from transformer_engine.common.recipe import Recipe @@ -173,7 +174,11 @@ def restore_from_saved( return tensor_objects -class Quantizer(abc.ABC): +class _TEQuantizerMeta(OpaqueBaseMeta, abc.ABCMeta): + """Metaclass that combines PyTorch opaque-type and abstract-class behavior.""" + + +class Quantizer(metaclass=_TEQuantizerMeta): """Builder class for quantized tensors. This class is typically used to convert a high-precision tensor @@ -292,6 +297,40 @@ def make_empty( "required for construction of unintialized quantized tensor" ) + def make_empty_like_layout( + self, + source: Union[torch.Tensor, "QuantizedTensor", QuantizedTensorStorage], + *, + internal: Optional[bool] = None, + requires_grad: bool = False, + ) -> Union["QuantizedTensor", QuantizedTensorStorage]: + """Create an empty quantized object with layout matching ``source``. + + Default implementation builds a tensor variant and does not support + storage-only construction. Subclasses with dedicated storage classes + should override this for ``internal=True``. + """ + shape = tuple(source.size()) + dtype = getattr(source, "dtype", torch.float32) + device = getattr(source, "device", None) + + quantizer = self.copy() if hasattr(self, "copy") else self + if internal is None: + internal = quantizer.internal + quantizer.internal = internal + + if internal: + raise NotImplementedError( + f"{self.__class__.__name__} does not implement storage construction " + "for internal=True in make_empty_like_layout" + ) + return quantizer.make_empty( + shape=shape, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + def calibrate(self, tensor: torch.Tensor) -> None: """Calibrate quantizer state @@ -445,7 +484,6 @@ def expand_as(self, other: torch.Tensor) -> torch.Tensor: @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): - # Detach op if func == torch.ops.aten.detach.default: return args[0].detach() diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 55bca49af3..40b7959861 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -18,7 +18,7 @@ ) from ..utils import canonicalize_process_group, devices_match from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func -from ..quantized_tensor import QuantizedTensor, Quantizer +from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer from ._quantization_helpers import _IdentityFunc from ..constants import dist_group_type @@ -154,6 +154,39 @@ def make_empty( quantizer=self, ) + def make_empty_like_layout( + self, + source: Union[torch.Tensor, QuantizedTensor, QuantizedTensorStorage], + *, + internal: Optional[bool] = None, + requires_grad: bool = False, + ) -> Union[Float8Tensor, Float8TensorStorage]: + shape = tuple(source.size()) + dtype = getattr(source, "dtype", torch.float32) + device = getattr(source, "device", None) + + quantizer = self.copy() + if internal is None: + internal = quantizer.internal + quantizer.internal = internal + + out = quantizer.make_empty( + shape=shape, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + if not internal: + return out + + return Float8TensorStorage( + data=out._data, + fp8_scale_inv=out._scale_inv, + fp8_dtype=out._fp8_dtype, + data_transpose=out._transpose, + quantizer=quantizer, + ) + def calibrate(self, tensor: torch.Tensor) -> None: amin, amax = tensor.aminmax() self.amax.copy_(torch.max(-amin, amax)) @@ -374,6 +407,39 @@ def make_empty( quantizer=self, ) + def make_empty_like_layout( + self, + source: Union[torch.Tensor, QuantizedTensor, QuantizedTensorStorage], + *, + internal: Optional[bool] = None, + requires_grad: bool = False, + ) -> Union[Float8Tensor, Float8TensorStorage]: + shape = tuple(source.size()) + dtype = getattr(source, "dtype", torch.float32) + device = getattr(source, "device", None) + + quantizer = self.copy() + if internal is None: + internal = quantizer.internal + quantizer.internal = internal + + out = quantizer.make_empty( + shape=shape, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + if not internal: + return out + + return Float8TensorStorage( + data=out._data, + fp8_scale_inv=out._scale_inv, + fp8_dtype=out._fp8_dtype, + data_transpose=out._transpose, + quantizer=quantizer, + ) + def calibrate(self, tensor: torch.Tensor) -> None: # current scaling don't need to calibrate return diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 41d6c87f2b..39d05590fe 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -160,6 +160,41 @@ def make_empty( with_gemm_swizzled_scales=self.optimize_for_gemm, ) + def make_empty_like_layout( + self, + source: Union[torch.Tensor, QuantizedTensor, QuantizedTensorStorage], + *, + internal: Optional[bool] = None, + requires_grad: bool = False, + ) -> Union[MXFP8Tensor, MXFP8TensorStorage]: + shape = tuple(source.size()) + dtype = getattr(source, "dtype", torch.float32) + device = getattr(source, "device", None) + + quantizer = self.copy() + if internal is None: + internal = quantizer.internal + quantizer.internal = internal + + out = quantizer.make_empty( + shape=shape, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + if not internal: + return out + + return MXFP8TensorStorage( + rowwise_data=out._rowwise_data, + rowwise_scale_inv=out._rowwise_scale_inv, + columnwise_data=out._columnwise_data, + columnwise_scale_inv=out._columnwise_scale_inv, + fp8_dtype=out._fp8_dtype, + quantizer=quantizer, + with_gemm_swizzled_scales=out._with_gemm_swizzled_scales, + ) + def calibrate(self, tensor: torch.Tensor) -> None: # TODO(ksivamani): No calibration needed for mxfp8? pass From f0471685c080a8198fcfa72098ee837ce1465088 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Feb 2026 15:35:01 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/base.py | 5 ++- .../pytorch/module/fx_serialize.py | 9 ++-- transformer_engine/pytorch/module/linear.py | 41 ++++++++----------- 3 files changed, 27 insertions(+), 28 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index c90a7153f7..1d94963034 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1211,7 +1211,10 @@ def prepare_forward( "necessary when using sequence parallelism with FP8." ) - if not defer_fp8_global_buffer_update and not FP8GlobalStateManager.fp8_graph_capturing(): + if ( + not defer_fp8_global_buffer_update + and not FP8GlobalStateManager.fp8_graph_capturing() + ): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) # Activation recomputation is used and this is the first forward phase. diff --git a/transformer_engine/pytorch/module/fx_serialize.py b/transformer_engine/pytorch/module/fx_serialize.py index 6fb6398828..d4cf305544 100644 --- a/transformer_engine/pytorch/module/fx_serialize.py +++ b/transformer_engine/pytorch/module/fx_serialize.py @@ -25,6 +25,7 @@ def add_global(self, name: str, value: object) -> None: raise RuntimeError(f"FX serializer global name collision for '{name}'") self.globals[name] = value + @singledispatch def _convert(value: object, ctx: _SerializeContext) -> str: try: @@ -135,7 +136,7 @@ def _(value: QuantizedTensorStorage, ctx: _SerializeContext) -> str: f"fp8_dtype={_convert(getattr(value, '_fp8_dtype'), ctx)}, " f"data_transpose={_convert(getattr(value, '_transpose'), ctx)}, " f"quantizer={_convert_or_none(getattr(value, '_quantizer'), ctx)}" - f")" + ")" ) if cls_name == "MXFP8TensorStorage": @@ -148,7 +149,7 @@ def _(value: QuantizedTensorStorage, ctx: _SerializeContext) -> str: f"fp8_dtype={_convert(getattr(value, '_fp8_dtype'), ctx)}, " f"quantizer={_convert_or_none(getattr(value, '_quantizer'), ctx)}, " f"with_gemm_swizzled_scales={_convert(getattr(value, '_with_gemm_swizzled_scales'), ctx)}" - f")" + ")" ) if cls_name == "Float8BlockwiseQTensorStorage": @@ -161,7 +162,7 @@ def _(value: QuantizedTensorStorage, ctx: _SerializeContext) -> str: f"fp8_dtype={_convert(getattr(value, '_fp8_dtype'), ctx)}, " f"quantizer={_convert_or_none(getattr(value, '_quantizer'), ctx)}, " f"is_2D_scaled={_convert(getattr(value, '_is_2D_scaled'), ctx)}" - f")" + ")" ) if cls_name == "NVFP4TensorStorage": @@ -176,7 +177,7 @@ def _(value: QuantizedTensorStorage, ctx: _SerializeContext) -> str: f"fp4_dtype={_convert(getattr(value, '_fp4_dtype'), ctx)}, " f"quantizer={_convert_or_none(getattr(value, '_quantizer'), ctx)}, " f"with_gemm_swizzled_scales={_convert(getattr(value, '_with_gemm_swizzled_scales'), ctx)}" - f")" + ")" ) # Fall back to generic object serializer for unknown storage subclasses. diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 720b4fe24e..5791828918 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -398,7 +398,8 @@ def forward( columnwise_usage = is_grad_enabled and inp.requires_grad if not columnwise_usage: columnwise_usage = ( - is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase() + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() ) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) elif isinstance(weight, QuantizedTensor): @@ -922,7 +923,9 @@ def _backward_impl( with get_nvtx_range_context("_Linear_backward"): inp_shape = tuple(inp_shape) - saved_tensors = real_ctx.saved_tensors if saved_tensors_override is None else saved_tensors_override + saved_tensors = ( + real_ctx.saved_tensors if saved_tensors_override is None else saved_tensors_override + ) restored_objects = restore_from_saved(tensor_objects, saved_tensors) if len(tensor_objects) == 2: inputmat_saved, weight_fp8_saved = restored_objects @@ -952,7 +955,10 @@ def _backward_impl( # Since main_grad can be modified inplace, it should not be a part of saved_tensors main_grad = ( main_grad_func() - if weight is not None and fuse_wgrad_accumulation and requires_wgrad and main_grad_func + if weight is not None + and fuse_wgrad_accumulation + and requires_wgrad + and main_grad_func else None ) @@ -1027,11 +1033,7 @@ def _backward_impl( # results in `Assertion failed: output_tensor->has_data(). Quantizing in only the columnwise direction not supported yet!` # NOTE: For `ctx.bias is True`, selected quantize kernel errors with # `cast_kernels.cuh:1322 in function fp8_quantize_arch_l_100: Not implemented scaling mode or fusion: NVTE_DELAYED_TENSOR_SCALING or IS_DBIAS=true on GPU with compute capability < 10.0.` - if ( - not use_bias - and not requires_wgrad - and grad_output_quantizer is not None - ): + if not use_bias and not requires_wgrad and grad_output_quantizer is not None: grad_output_quantizer.set_usage(columnwise=False) # Prepare grad output tensor @@ -1135,9 +1137,7 @@ def _backward_impl( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if weight_quantizer is not None and isinstance( - weight_fp8, QuantizedTensorStorage - ): + if weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorStorage): weight_fp8.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator @@ -1293,9 +1293,7 @@ def _backward_impl( # Arguments to include in wgrad GEMM closure wgrad_gemm_kwargs = { - "out_dtype": ( - main_grad.dtype if fuse_wgrad_accumulation else activation_dtype - ), + "out_dtype": (main_grad.dtype if fuse_wgrad_accumulation else activation_dtype), "quantization_params": grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad @@ -1487,9 +1485,7 @@ def backward(ctx, *grads): # Register _Linear.forward as a custom PyTorch operator -_te_linear_forward = torch.library.custom_op( - "te::linear_forward", mutates_args=() -)(_Linear.forward) +_te_linear_forward = torch.library.custom_op("te::linear_forward", mutates_args=())(_Linear.forward) @_te_linear_forward.register_fake @@ -1620,9 +1616,9 @@ class _NoSavedCtx: return [g if g is not None else grad_output.new_empty((0,)) for g in grads] -_te_linear_backward = torch.library.custom_op( - "te::linear_backward", mutates_args=() -)(_te_linear_backward_impl) +_te_linear_backward = torch.library.custom_op("te::linear_backward", mutates_args=())( + _te_linear_backward_impl +) @_te_linear_backward.register_fake @@ -1648,6 +1644,7 @@ def _( grad_bias = grad_output.new_empty((grad_output.shape[-1],)) return [wgrad, dgrad, grad_bias] + _te_linear_forward.register_autograd( _Linear.backward, setup_context=_Linear.setup_context, @@ -2093,9 +2090,7 @@ def forward( self.save_original_input, debug, ) - bias_arg = ( - bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None - ) + bias_arg = bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None args = _LinearNonTensorArgs(non_tensor_args) result = _te_linear_forward( weight_tensor,