Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions tests/pytorch/test_torch_compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch
from contextlib import nullcontext

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling


@pytest.mark.skipif(torch.__version__ < "2", reason="torch.compile not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for TE Linear")
@pytest.mark.parametrize(
    "use_fp8,with_backward",
    [
        (False, False),
        (False, True),
        (True, False),
        (True, True),
    ],
    ids=["fp16_fwd", "fp16_fwd_bwd", "fp8_fwd", "fp8_fwd_bwd"],
)
def test_te_linear_fullgraph_compile(use_fp8, with_backward):
    """Compile a TE Linear with ``fullgraph=True`` and run a forward pass.

    Optionally runs a backward pass through the compiled module as well,
    with and without an FP8 autocast recipe.
    """
    if use_fp8:
        available, why = te.is_fp8_available(return_reason=True)
        if not available:
            pytest.skip(why)

    # Frozen-weight bf16 layer; only the input may carry gradients.
    linear = te.Linear(128, 64, device="cuda").to(dtype=torch.bfloat16)
    for p in linear.parameters():
        p.requires_grad_(False)
    inp = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16, requires_grad=with_backward)

    recipe = DelayedScaling() if use_fp8 else None
    ctx = nullcontext() if not use_fp8 else te.autocast(enabled=True, recipe=recipe)

    with ctx:
        if use_fp8:
            # Initialize FP8 state eagerly so compilation sees a stable module.
            linear.init_fp8_metadata()
        compiled = torch.compile(linear, fullgraph=True)
        result = compiled(inp)
        assert result.shape == (16, 64)
        if with_backward:
            result.sum().backward()

    if with_backward:
        assert inp.grad is not None
18 changes: 14 additions & 4 deletions transformer_engine/pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -1938,19 +1938,29 @@ def _fsdp_scatter_tensors(
fsdp_group: dist_group_type,
*tensors: torch.Tensor,
):
shapes = []
shapes = _collect_fsdp_tensor_shapes(*tensors)
if fsdp_group is not None:
for t in tensors:
if isinstance(t, torch.Tensor):
targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t]
for target in targets:
shapes.append(target.data.shape)
safely_set_viewless_tensor_data(
target,
split_tensor_into_1d_equal_chunks(target.data, fsdp_group, new_buffer=True),
)
else:
shapes.append(None)
return shapes


def _collect_fsdp_tensor_shapes(*tensors: torch.Tensor) -> List[Optional[Tuple[int, ...]]]:
    """Collect tensor data shapes in the same order used by FSDP scatter/gather helpers."""
    shapes: List[Optional[Tuple[int, ...]]] = []
    for tensor in tensors:
        # Non-tensor entries keep a placeholder so ordering stays aligned.
        if not isinstance(tensor, torch.Tensor):
            shapes.append(None)
            continue
        # Quantized tensors may carry several underlying data tensors.
        if isinstance(tensor, QuantizedTensor):
            data_tensors = tensor.get_data_tensors()
        else:
            data_tensors = [tensor]
        shapes.extend(data.data.shape for data in data_tensors)
    return shapes


Expand Down
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch._opaque_base import OpaqueBase

from .. import cpp_extensions as tex
from ..constants import TE_DType
Expand Down Expand Up @@ -195,7 +196,7 @@ def __post_init__(self):
self.init_fn = get_default_init_method()


class WeightGradStore:
class WeightGradStore(OpaqueBase):
"""
A class to manage weight gradient storage and computation in Transformer modules.
This class enables split backward propagation for better memory efficiency.
Expand Down
Loading