diff --git a/examples/pytorch/quantized_model_init/fully_shard.py b/examples/pytorch/quantized_model_init/fully_shard.py new file mode 100644 index 0000000000..6131712001 --- /dev/null +++ b/examples/pytorch/quantized_model_init/fully_shard.py @@ -0,0 +1,266 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""FSDP2 distributed training with quantized model initialization. + +Extends the single-GPU ``main.py`` example to multi-GPU training using +PyTorch-native FSDP2 (``fully_shard``). The script demonstrates: + +1. **Meta-device initialization** -- Model parameters are created on the + ``meta`` device (zero memory), then FSDP2 sharding is applied, and + finally ``reset_parameters()`` materializes and quantizes only the + local shards on each rank's GPU. +2. ``quantized_model_init`` -- Flags the model for FP8 weight initialization + (actual quantization happens in ``reset_parameters`` after sharding). +3. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer. +4. ``FusedAdam`` with FP32 master weights for full-precision training updates. + +.. note:: + ``fuse_wgrad_accumulation`` is **not** used here. That feature writes + weight gradients directly into ``main_grad`` buffers, bypassing the + autograd gradient flow. FSDP2 requires gradients to go through its + reduce-scatter, so ``fuse_wgrad_accumulation`` needs Megatron-Core's + FSDP integration (which provides ``get_main_grad()``). 
+ +Usage:: + + torchrun --nproc-per-node 2 fully_shard.py +""" + +import os + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor + +import transformer_engine.pytorch as te +from transformer_engine.pytorch import QuantizedTensor +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule + +# ── Configuration (matches main.py) ────────────────────────────────── +HIDDEN_SIZE = 256 +FFN_HIDDEN_SIZE = 1024 +NUM_ATTENTION_HEADS = 8 +NUM_LAYERS = 3 +SEQ_LEN = 32 +BATCH_PER_RANK = 2 +NUM_STEPS = 5 +DTYPE = torch.bfloat16 + + +def dist_print(msg): + """Print only on rank 0.""" + if int(os.environ.get("RANK", "0")) == 0: + print(msg) + + +def main(): + # ── 1. Distributed setup ───────────────────────────────────────── + assert "TORCHELASTIC_RUN_ID" in os.environ, ( + "This script must be launched with torchrun, e.g.:\n" + " torchrun --nproc-per-node 2 fully_shard.py" + ) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl") + device = torch.device(f"cuda:{local_rank}") + + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # ── 2. Create model on meta device (zero memory) ──────────────── + # quantized_model_init sets the flag for FP8 weight initialization, + # but with device="meta" no actual memory is allocated yet. + with te.quantized_model_init(enabled=True): + model = torch.nn.Sequential( + *[ + te.TransformerLayer( + HIDDEN_SIZE, + FFN_HIDDEN_SIZE, + NUM_ATTENTION_HEADS, + fuse_qkv_params=True, + params_dtype=DTYPE, + hidden_dropout=0.0, + attention_dropout=0.0, + device="meta", + ) + for _ in range(NUM_LAYERS) + ] + ) + + # Verify all parameters are on meta device (no GPU memory used). 
+ for name, param in model.named_parameters(): + assert param.device == torch.device("meta"), f"{name} is not on meta device" + dist_print("Model created on meta device (zero GPU memory).") + + # ── 3. FSDP2 sharding ──────────────────────────────────────────── + # Apply sharding to the meta-device model. FSDP2 wraps parameters + # as DTensors but no GPU memory is allocated yet. + mesh = DeviceMesh("cuda", list(range(world_size))) + for child in model.children(): + fully_shard(child, mesh=mesh) + fully_shard(model, mesh=mesh) + dist_print("FSDP2 sharding applied to meta-device model.") + + # ── 4. Materialize parameters on GPU ────────────────────────────── + # reset_parameters() on each TE module materializes the local shard + # on CUDA, applies weight initialization, and quantizes to FP8. + for module in model.modules(): + if isinstance(module, TransformerEngineBaseModule): + module.reset_parameters() + + # Post-materialization verification. + for name, param in model.named_parameters(): + assert isinstance(param, DTensor), f"{name} is not a DTensor after sharding" + qt_count = sum( + 1 + for _, p in model.named_parameters() + if isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor) + ) + assert qt_count > 0, "No QuantizedTensor local tensors after materialization" + dist_print( + f"Parameters materialized: {qt_count} FP8 (QuantizedTensor) weight params " + "wrapped in DTensors." + ) + + # ── 5. Optimizer ───────────────────────────────────────────────── + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + dist_print("Using FusedAdam with master_weights=True.") + + # ── 6. 
Training loop ───────────────────────────────────────────── + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device) + target = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device) + + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + + with te.autocast(enabled=True): + output = model(x) + + loss = F.mse_loss(output, target) + loss.backward() + optimizer.step() + dist_print(f" Step {step}: loss = {loss.item():.6f}") + + # ── 7. Post-training assertions ────────────────────────────────── + dist_print("\nVerifying invariants ...") + + qt_after = 0 + for name, param in model.named_parameters(): + assert isinstance(param, DTensor), f"{name} lost DTensor wrapping" + if isinstance(param._local_tensor, QuantizedTensor): + qt_after += 1 + assert qt_after > 0, "No QuantizedTensor local tensors after training" + dist_print(f" {qt_after} params still have QuantizedTensor local tensors.") + + # Optimizer states: master weights and moments should be float32. + for param in model.parameters(): + state = optimizer.state[param] + if "master_param" in state: + assert ( + state["master_param"].dtype == torch.float32 + ), f"Master weight dtype {state['master_param'].dtype}, expected float32" + assert state["exp_avg"].dtype == torch.float32, "exp_avg should be float32" + assert state["exp_avg_sq"].dtype == torch.float32, "exp_avg_sq should be float32" + + dist_print("All assertions passed!") + dist_print(" - Linear weight parameters: QuantizedTensor (FP8) wrapped in DTensor") + dist_print(" - Optimizer master weights: float32") + dist_print(" - Optimizer states (exp_avg, exp_avg_sq): float32") + + # ── 8. Distributed checkpoint: save and load ───────────────────── + # torch.distributed.checkpoint (DCP) saves sharded state — each rank + # writes only its local shard. This preserves FP8 compute weights + # and the full optimizer state (master weights, moments, step count). 
+ import torch.distributed.checkpoint as dcp + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_optimizer_state_dict, + ) + + # Use a fixed path so all ranks agree on the checkpoint location. + checkpoint_dir = "/tmp/te_fsdp2_example_checkpoint" + dist_print(f"\nSaving distributed checkpoint to {checkpoint_dir} ...") + + # Save sharded checkpoint. DCP handles DTensor shards natively — + # each rank writes only its local shard to the filesystem. + dcp.save( + {"model": model.state_dict(), "optimizer": optimizer.state_dict()}, + checkpoint_id=checkpoint_dir, + ) + dist_print(" Checkpoint saved (FP8 weights + optimizer state).") + + # Load checkpoint back. Provide empty state dict containers with the + # same structure; DCP fills them from the saved files. + state_to_load = {"model": model.state_dict(), "optimizer": optimizer.state_dict()} + dcp.load(state_to_load, checkpoint_id=checkpoint_dir) + model.load_state_dict(state_to_load["model"]) + optimizer.load_state_dict(state_to_load["optimizer"]) + dist_print(" Checkpoint loaded — FP8 weights and optimizer state restored.") + + # Verify training continues after checkpoint load. + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True): + output = model(x) + loss = F.mse_loss(output, target) + loss.backward() + optimizer.step() + dist_print(f" Post-checkpoint training step: loss = {loss.item():.6f}") + + # ── 9. Save full-precision (FP32) model to safetensors ─────────── + # For inference or fine-tuning you typically want FP32 weights, not + # FP8 compute weights. The optimizer's master weight copies are the + # authoritative FP32 values (more precise than dequantizing FP8). + # All ranks must participate in gathering; only rank 0 saves. 
+ from safetensors.torch import save_file + + full_opts = StateDictOptions(full_state_dict=True, cpu_offload=True) + + full_model_state = get_model_state_dict(model, options=full_opts) + full_opt_state = get_optimizer_state_dict(model, optimizer, options=full_opts) + + rank = int(os.environ.get("RANK", "0")) + if rank == 0: + fp32_state = {} + opt_param_states = full_opt_state.get("state", {}) + + for key, value in full_model_state.items(): + if key in opt_param_states and "master_param" in opt_param_states[key]: + # Prefer optimizer's FP32 master weight (maintained throughout training). + fp32_state[key] = opt_param_states[key]["master_param"].float() + elif isinstance(value, QuantizedTensor): + # Fallback: dequantize FP8 → FP32 (e.g. if master_weights was off). + fp32_state[key] = value.dequantize().float() + else: + # Non-FP8 params (e.g. LayerNorm weights): cast to FP32. + fp32_state[key] = value.float() + + save_path = "/tmp/te_fsdp2_example_model_fp32.safetensors" + save_file(fp32_state, save_path) + dist_print(f"\nSaved FP32 model ({len(fp32_state)} params) to {save_path}") + + # Quick verification: all saved tensors are float32. + from safetensors.torch import load_file + + loaded = load_file(save_path) + for k, v in loaded.items(): + assert v.dtype == torch.float32, f"{k}: expected float32, got {v.dtype}" + dist_print(f" Verified: all {len(loaded)} tensors are float32.") + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/quantized_model_init/main.py b/examples/pytorch/quantized_model_init/main.py new file mode 100644 index 0000000000..a9d3480cad --- /dev/null +++ b/examples/pytorch/quantized_model_init/main.py @@ -0,0 +1,151 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Quantized model initialization with FusedAdam and gradient accumulation fusion. 
+ +Demonstrates three Transformer Engine features working together: + +1. ``quantized_model_init`` -- Initialize a model with low-precision (FP8) + parameters, avoiding the memory cost of storing both high-precision and + quantized copies of every weight. + +2. ``FusedAdam`` with master weights -- Maintain FP32 master copies of the + weights inside the optimizer so that the training update retains full + precision despite the model parameters being FP8. + +3. Gradient accumulation fusion -- Use ``fuse_wgrad_accumulation=True`` + together with per-parameter ``main_grad`` buffers so that weight + gradients are accumulated directly in FP32 via Tensor Cores, avoiding a + separate FP8-to-FP32 cast kernel. + +Usage:: + + python main.py +""" + +import torch +import transformer_engine.pytorch as te +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor + +# ── Configuration ────────────────────────────────────────────────────── +HIDDEN_SIZE = 256 +FFN_HIDDEN_SIZE = 1024 +NUM_ATTENTION_HEADS = 8 +SEQ_LEN = 32 +BATCH_SIZE = 2 +NUM_STEPS = 5 +DTYPE = torch.bfloat16 + + +def main(): + # ── 1. Create model with quantized parameters ───────────────────── + # + # Inside quantized_model_init, TransformerEngine modules store only the + # FP8 quantized copy of each parameter (a Float8Tensor), eliminating the + # memory overhead of a high-precision shadow copy. + with te.quantized_model_init(enabled=True): + model = te.TransformerLayer( + HIDDEN_SIZE, + FFN_HIDDEN_SIZE, + NUM_ATTENTION_HEADS, + fuse_wgrad_accumulation=True, + fuse_qkv_params=True, # required for fuse_wgrad_accumulation + params_dtype=DTYPE, + hidden_dropout=0.0, # disable dropout for this synthetic example + attention_dropout=0.0, + ) + + # Verify that linear-layer weight parameters are quantized. + # Biases and LayerNorm parameters are *not* quantized. 
+ quantized_count = 0 + for name, param in model.named_parameters(): + if isinstance(param, QuantizedTensor): + quantized_count += 1 + assert quantized_count > 0, "No QuantizedTensor parameters found" + print(f"Found {quantized_count} QuantizedTensor (FP8) weight parameters.") + + # ── 2. Allocate main_grad buffers (FP32) ────────────────────────── + # + # fuse_wgrad_accumulation causes weight-gradient GEMMs to write directly + # into ``param.main_grad`` in FP32 (via Tensor Core accumulation). + # Non-weight parameters (e.g. LayerNorm) still receive gradients through + # the normal ``param.grad`` path. + for param in model.parameters(): + param.main_grad = torch.zeros(param.shape, dtype=torch.float32, device=param.device) + + # ── 3. Optimizer with FP32 master weights ───────────────────────── + # + # use_decoupled_grad=True tells FusedAdam to read gradients from + # ``param.decoupled_grad`` instead of ``param.grad``. This avoids + # the dtype-mismatch error that would occur when assigning FP32 + # gradients to bfloat16 parameters via ``.grad``. + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + use_decoupled_grad=True, + ) + + # ── 4. Training loop ────────────────────────────────────────────── + # + # Use a fixed synthetic dataset so that loss decreases over steps. + x = torch.randn(SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE, device="cuda") + target = torch.randn(SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE, device="cuda") + + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + for param in model.parameters(): + param.main_grad.zero_() + + # Forward pass inside autocast to enable FP8 compute. + with te.autocast(enabled=True): + output = model(x) + + loss = torch.nn.functional.mse_loss(output, target) + loss.backward() + + # Consolidate gradients into main_grad. 
+ # * Weight params with fuse_wgrad_accumulation: backward already + # accumulated the gradient directly into main_grad (FP32). + # * Other params (e.g. LayerNorm): autograd set param.grad. + for param in model.parameters(): + if param.grad is not None: + param.main_grad.copy_(param.grad) + param.grad = None + + # Expose main_grad as decoupled_grad so FusedAdam can read it. + for param in model.parameters(): + param.decoupled_grad = param.main_grad + + optimizer.step() + print(f" Step {step}: loss = {loss.item():.6f}") + + # ── 5. Post-training assertions ─────────────────────────────────── + print("\nVerifying invariants ...") + + # Optimizer states. + for param in model.parameters(): + state = optimizer.state[param] + if "master_param" in state: + master = state["master_param"] + assert ( + master.dtype == torch.float32 + ), f"Master weight dtype {master.dtype}, expected float32" + assert state["exp_avg"].dtype == torch.float32, "exp_avg should be float32" + assert state["exp_avg_sq"].dtype == torch.float32, "exp_avg_sq should be float32" + + # main_grad buffers. + for param in model.parameters(): + assert param.main_grad.dtype == torch.float32, "main_grad should be float32" + + print("All assertions passed!") + print(" - Linear weight parameters: QuantizedTensor (FP8)") + print(" - Optimizer master weights: float32") + print(" - Optimizer states (exp_avg, exp_avg_sq): float32") + print(" - Gradient accumulation buffers (main_grad): float32") + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/distributed/run_fsdp2_fused_adam.py b/tests/pytorch/distributed/run_fsdp2_fused_adam.py new file mode 100644 index 0000000000..951d545e50 --- /dev/null +++ b/tests/pytorch/distributed/run_fsdp2_fused_adam.py @@ -0,0 +1,650 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""FSDP2 + FusedAdam compatibility tests. 
+ +Launched via torchrun from test_fused_optimizer.py. +""" + +import argparse +import functools +import os +from collections import Counter + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor + +import transformer_engine.pytorch as te +from transformer_engine.pytorch import QuantizedTensor +import transformer_engine.common.recipe + + +def get_recipe_from_string(recipe): + return getattr(transformer_engine.common.recipe, recipe)() + + +HIDDEN_SIZE = 256 +FFN_HIDDEN_SIZE = 1024 +NUM_ATTENTION_HEADS = 8 +NUM_LAYERS = 2 +SEQ_LEN = 32 +BATCH_PER_RANK = 2 +NUM_STEPS = 3 + + +def save_custom_attrs(module): + custom_attrs = {} + for name, param in module.named_parameters(): + if isinstance(param, QuantizedTensor): + ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")] + else: + ignore_keys = [] + attrs = vars(param) + custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys} + return custom_attrs + + +def restore_custom_attrs(module, custom_attrs): + for name, param in module.named_parameters(): + if name in custom_attrs: + for attr_name, attr_value in custom_attrs[name].items(): + setattr(param, attr_name, attr_value) + + +def _setup(): + """Common distributed setup. 
Returns (world_size, local_rank, device).""" + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + # CPU backend required for async save + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + device = torch.device(f"cuda:{local_rank}") + torch.manual_seed(42) + torch.cuda.manual_seed(42) + return world_size, local_rank, device + + +def _build_model(fp8_init, fuse_wgrad_accumulation=False, recipe=None): + """Build a Sequential of TransformerLayers, optionally with FP8 init.""" + if fp8_init: + ctx = te.quantized_model_init(enabled=True, recipe=recipe) + else: + from contextlib import nullcontext + + ctx = nullcontext() + with ctx: + model = torch.nn.Sequential( + *[ + te.TransformerLayer( + HIDDEN_SIZE, + FFN_HIDDEN_SIZE, + NUM_ATTENTION_HEADS, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + fuse_qkv_params=True, + params_dtype=torch.bfloat16, + hidden_dropout=0.0, + attention_dropout=0.0, + ) + for _ in range(NUM_LAYERS) + ] + ) + return model + + +def _shard_model(model, world_size): + """Apply FSDP2 sharding with save/restore custom attrs.""" + custom_attrs = save_custom_attrs(model) + mesh = DeviceMesh("cuda", list(range(world_size))) + for child in model.children(): + fully_shard(child, mesh=mesh) + fully_shard(model, mesh=mesh) + restore_custom_attrs(model, custom_attrs) + return model + + +def test_fused_adam_fp8_master_weights(recipe=None): + """FusedAdam with master_weights + FSDP2 + quantized_model_init (FP8 params). 
+ + Verifies: + - Optimizer states are created with correct dtype (float32) + - Training loop completes without error + - DTensor wrapping and QuantizedTensor local tensors are preserved + """ + world_size, _, device = _setup() + + model = _build_model(fp8_init=True, recipe=recipe) + + # Verify FP8 params created + qt_count = sum(1 for _, p in model.named_parameters() if isinstance(p, QuantizedTensor)) + assert qt_count > 0, "No QuantizedTensor local tensors before training" + + model = _shard_model(model, world_size) + + # Verify params are DTensors + for name, param in model.named_parameters(): + assert isinstance(param, DTensor), f"{name} is not DTensor" + + # Verify FP8 params after sharding + qt_count = sum( + 1 + for _, p in model.named_parameters() + if isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor) + ) + assert qt_count > 0, "No QuantizedTensor local tensors after sharding" + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = F.mse_loss(output, target) + loss.backward() + optimizer.step() + + # Verify optimizer states + for param in model.parameters(): + state = optimizer.state[param] + assert ( + state["exp_avg"].dtype == torch.float32 + ), f"exp_avg dtype {state['exp_avg'].dtype}, expected float32" + assert ( + state["exp_avg_sq"].dtype == torch.float32 + ), f"exp_avg_sq dtype {state['exp_avg_sq'].dtype}, expected float32" + if "master_param" in state: + assert ( + state["master_param"].dtype == torch.float32 + ), f"master_param dtype {state['master_param'].dtype}, expected float32" + + # Verify FP8 params preserved + qt_count = sum( + 1 + for _, p in 
model.named_parameters() + if isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor) + ) + assert qt_count > 0, "No QuantizedTensor local tensors after training" + + dist.destroy_process_group() + + +def test_fused_adam_bf16(recipe=None): + """FusedAdam with master_weights + FSDP2 + bf16 params (no FP8). + + Verifies the non-FP8 DTensor param path in step() works correctly. + """ + world_size, _, device = _setup() + + model = _build_model(fp8_init=False) + model = _shard_model(model, world_size) + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + losses = [] + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = F.mse_loss(output, target) + losses.append(loss.item()) + loss.backward() + optimizer.step() + + # Verify optimizer states are float32 + for param in model.parameters(): + state = optimizer.state[param] + assert state["exp_avg"].dtype == torch.float32 + assert state["exp_avg_sq"].dtype == torch.float32 + + # Verify loss decreased (basic sanity) + assert losses[-1] < losses[0], f"Loss did not decrease: {losses}" + + dist.destroy_process_group() + + +def test_fused_adam_fp8_no_master(recipe=None): + """FusedAdam without master_weights + FSDP2 + FP8 params. + + Verifies FusedAdam works with FSDP2 even without master weights enabled. 
+ """ + world_size, _, device = _setup() + + model = _build_model(fp8_init=True, recipe=recipe) + model = _shard_model(model, world_size) + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=False, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = F.mse_loss(output, target) + loss.backward() + optimizer.step() + + # Verify DTensors preserved + for name, param in model.named_parameters(): + assert isinstance(param, DTensor), f"{name} lost DTensor wrapping" + + dist.destroy_process_group() + + +def test_fused_adam_bf16_store_param_remainders(recipe=None): + """FusedAdam with master_weights + store_param_remainders + FSDP2 + bf16 params. + + store_param_remainders stores only the trailing 16 remainder bits (int16) + instead of full FP32 master params. The FP32 master can be reconstructed + from BF16 params + int16 remainders. Only works with bf16 params + fp32 + master weights. 
+ + Verifies: + - Training loop completes without error + - Optimizer master_param states are int16 (remainder bits) + - exp_avg and exp_avg_sq are float32 + - Loss decreases (basic sanity) + """ + world_size, _, device = _setup() + + model = _build_model(fp8_init=False) + model = _shard_model(model, world_size) + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + store_param_remainders=True, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + losses = [] + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = F.mse_loss(output, target) + losses.append(loss.item()) + loss.backward() + optimizer.step() + + # Verify model params are bf16 (required for store_param_remainders) + for name, param in model.named_parameters(): + assert ( + param.dtype == torch.bfloat16 + ), f"{name}: param dtype {param.dtype}, expected bfloat16" + + # Verify optimizer states + for name, param in model.named_parameters(): + state = optimizer.state[param] + assert ( + state["exp_avg"].dtype == torch.float32 + ), f"{name}: exp_avg dtype {state['exp_avg'].dtype}, expected float32" + assert ( + state["exp_avg_sq"].dtype == torch.float32 + ), f"{name}: exp_avg_sq dtype {state['exp_avg_sq'].dtype}, expected float32" + # store_param_remainders stores master_param as int16 remainder bits + if "master_param" in state: + assert ( + state["master_param"].dtype == torch.int16 + ), f"{name}: master_param dtype {state['master_param'].dtype}, expected int16" + + # Verify loss decreased (basic sanity) + assert losses[-1] < losses[0], f"Loss did not decrease: {losses}" + + dist.destroy_process_group() + + +def test_fuse_wgrad_accumulation(recipe=None): + """fuse_wgrad_accumulation=True + FSDP2 -- expected to fail. 
+ + With vanilla FSDP2, PyTorch's autograd Function.apply unwraps DTensor + inputs to local tensors. The local Float8Tensor inside the autograd + function does not have the `main_grad` attribute (which is set on the + DTensor parameter). This causes an AttributeError during backward. + + Additionally, even if main_grad were accessible, fuse_wgrad_accumulation + writes the gradient directly into main_grad and returns None to autograd, + bypassing FSDP2's reduce-scatter. + """ + world_size, _, device = _setup() + + model = _build_model(fp8_init=True, fuse_wgrad_accumulation=True, recipe=recipe) + + # Allocate main_grad buffers on the DTensor params + for param in model.parameters(): + param.main_grad = torch.zeros(param.shape, dtype=torch.float32, device=param.device) + + model = _shard_model(model, world_size) + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + use_decoupled_grad=True, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + # This is currently failing during backward because the local Float8Tensor + # inside the autograd function doesn't have main_grad. + optimizer.zero_grad(set_to_none=True) + for param in model.parameters(): + param.main_grad.zero_() + + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + + loss = F.mse_loss(output, target) + loss.backward() # Expected to raise AttributeError + + dist.destroy_process_group() + + +def test_safetensors_fp32_export(recipe=None): + """Export full-precision (FP32) model to safetensors from optimizer master weights. 
+ + Verifies: + - get_model_state_dict with full_state_dict gathers all params + - get_optimizer_state_dict with full_state_dict gathers optimizer state + - FP32 state dict is built from optimizer master weights + - All saved tensors are float32 + - Saved tensor shapes match expected (unsharded) shapes + """ + from safetensors.torch import load_file, save_file + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_optimizer_state_dict, + ) + + world_size, _, device = _setup() + + model = _build_model(fp8_init=True, recipe=recipe) + model = _shard_model(model, world_size) + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + # Train a few steps. + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = F.mse_loss(output, target) + loss.backward() + optimizer.step() + + # Gather full state dicts (all ranks participate). + full_opts = StateDictOptions(full_state_dict=True, cpu_offload=True) + full_model_state = get_model_state_dict(model, options=full_opts) + full_opt_state = get_optimizer_state_dict(model, optimizer, options=full_opts) + + rank = int(os.environ.get("RANK", "0")) + save_path = "/tmp/te_test_fsdp2_model_fp32.safetensors" + + if rank == 0: + # Build FP32 state dict from optimizer master weights. + fp32_state = {} + opt_param_states = full_opt_state.get("state", {}) + + for key, value in full_model_state.items(): + if key in opt_param_states and "master_param" in opt_param_states[key]: + fp32_state[key] = opt_param_states[key]["master_param"].float() + else: + fp32_state[key] = value.float() + + assert len(fp32_state) > 0, "FP32 state dict is empty" + + # Save and verify. 
+ save_file(fp32_state, save_path) + loaded = load_file(save_path) + + assert len(loaded) == len( + fp32_state + ), f"Loaded {len(loaded)} tensors, expected {len(fp32_state)}" + for k, v in loaded.items(): + assert v.dtype == torch.float32, f"{k}: expected float32, got {v.dtype}" + + # Clean up. + os.remove(save_path) + + dist.destroy_process_group() + + +def test_dcp_output_parity(recipe=None, async_save=False): + """DCP save/load round-trip produces bitwise-identical model outputs. + + 1. Builds and trains a model for NUM_STEPS + 2. Runs a forward pass and records the output + 3. Saves model + optimizer state via DCP + 4. Builds a *fresh* model + optimizer (same architecture) + 5. Loads the DCP checkpoint into the fresh model + 6. Runs the same forward pass and asserts outputs are identical + 7. Runs one more training step on both models and asserts outputs still match + """ + import torch.distributed.checkpoint as dcp + + world_size, local_rank, device = _setup() + + # ── Build and train the original model ─────────────────────────── + model = _build_model(fp8_init=True, recipe=recipe) + model = _shard_model(model, world_size) + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + for _ in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = F.mse_loss(output, target) + loss.backward() + optimizer.step() + + # Record reference output from the trained model. 
+ with torch.no_grad(): + with te.autocast(enabled=True, recipe=recipe): + ref_output = model(x).clone() + + # ── Save checkpoint ────────────────────────────────────────────── + checkpoint_dir = "/tmp/te_test_fsdp2_dcp_parity" + + if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling): + model_state = { + k: v for k, v in model.state_dict().items() if not k.endswith("_extra_state") + } + else: + model_state = model.state_dict() + + if not async_save: + dcp.save( + {"model": model_state, "optimizer": optimizer.state_dict()}, + checkpoint_id=checkpoint_dir, + ) + future = None + else: + future = dcp.async_save( + {"model": model_state, "optimizer": optimizer.state_dict()}, + checkpoint_id=checkpoint_dir, + ) + + # ── Build a fresh model and load the checkpoint ────────────────── + model2 = _build_model(fp8_init=True, recipe=recipe) + model2 = _shard_model(model2, world_size) + + optimizer2 = te.optimizers.FusedAdam( + model2.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + # Populate optimizer state so load_state_dict has matching structure. 
+ optimizer2.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + out_tmp = model2(x) + F.mse_loss(out_tmp, target).backward() + optimizer2.step() + + if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling): + model2_state = { + k: v for k, v in model2.state_dict().items() if not k.endswith("_extra_state") + } + else: + model2_state = model2.state_dict() + + state_to_load = {"model": model2_state, "optimizer": optimizer2.state_dict()} + + if async_save: + future.result() # Block on async save completion + + dcp.load(state_to_load, checkpoint_id=checkpoint_dir) + model2.load_state_dict( + state_to_load["model"], + strict=( + False if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling) else True + ), + ) + optimizer2.load_state_dict(state_to_load["optimizer"]) + + # ── Verify identical forward-pass output ───────────────────────── + with torch.no_grad(): + with te.autocast(enabled=True, recipe=recipe): + loaded_output = model2(x) + + if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling): + # DelayedScaling stores amax history and scaling factors in _extra_state, + # which cannot be saved via DCP due to non-deterministic pickle sizes + # across ranks. The fresh model therefore uses default scaling factors, + # producing small numerical differences from FP8 re-quantization. 
+ torch.testing.assert_close( + loaded_output, + ref_output, + rtol=0.05, + atol=0.1, + msg="Fresh model loaded from DCP checkpoint produces different output", + ) + else: + torch.testing.assert_close( + loaded_output, + ref_output, + rtol=0, + atol=0, + msg="Fresh model loaded from DCP checkpoint produces different output", + ) + + # ── Verify one more training step produces identical results ───── + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + out1 = model(x) + loss1 = F.mse_loss(out1, target) + loss1.backward() + optimizer.step() + + optimizer2.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + out2 = model2(x) + loss2 = F.mse_loss(out2, target) + loss2.backward() + optimizer2.step() + + if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling): + torch.testing.assert_close( + out2, + out1, + rtol=0.05, + atol=0.1, + msg="Training step after DCP load produces different output", + ) + else: + torch.testing.assert_close( + out2, out1, msg="Training step after DCP load produces different output" + ) + + # ── Cleanup ────────────────────────────────────────────────────── + import shutil + + if int(os.environ.get("RANK", "0")) == 0: + shutil.rmtree(checkpoint_dir, ignore_errors=True) + + dist.destroy_process_group() + + +TESTS = { + "fused_adam_fp8_master_weights": test_fused_adam_fp8_master_weights, + "fused_adam_bf16": test_fused_adam_bf16, + "fused_adam_fp8_no_master": test_fused_adam_fp8_no_master, + "fused_adam_bf16_store_param_remainders": test_fused_adam_bf16_store_param_remainders, + "fuse_wgrad_accumulation": test_fuse_wgrad_accumulation, + "dcp_output_parity": functools.partial(test_dcp_output_parity, async_save=False), + "dcp_output_parity_async": functools.partial(test_dcp_output_parity, async_save=True), + "safetensors_fp32_export": test_safetensors_fp32_export, +} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--test", 
required=True, choices=list(TESTS.keys())) + parser.add_argument( + "--recipe", + type=str, + default="MXFP8BlockScaling", + help="Quantizer type.", + choices=[ + "DelayedScaling", + "Float8CurrentScaling", + "Float8BlockScaling", + "MXFP8BlockScaling", + "NVFP4BlockScaling", + ], + ) + args = parser.parse_args() + recipe = get_recipe_from_string(args.recipe) + TESTS[args.test](recipe) diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py index 5df3468861..60d7cd2023 100644 --- a/tests/pytorch/distributed/run_fsdp2_model.py +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -9,12 +9,7 @@ import argparse import transformer_engine.pytorch as te -from transformer_engine.common.recipe import ( - Format, - DelayedScaling, - Float8CurrentScaling, - MXFP8BlockScaling, -) +import transformer_engine.common.recipe import torch import torch.distributed as dist @@ -43,14 +38,23 @@ def _parse_args(argv=None, namespace=None): parser.add_argument("--seq-length", type=int, default=128, help="Sequence length of input") parser.add_argument("--params-dtype", type=str, default="float32", help="Parameter dtype.") parser.add_argument( - "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." 
+ "--fp8-init", + action="store_true", + default=False, + help="Initialize primary weights in FP8.", ) parser.add_argument( "--recipe", type=str, - default="mx_fp8_block_scaling", + default="MXFP8BlockScaling", help="Quantizer type.", - choices=["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"], + choices=[ + "DelayedScaling", + "Float8CurrentScaling", + "Float8BlockScaling", + "MXFP8BlockScaling", + "NVFP4BlockScaling", + ], ) parser.add_argument( "--layer-type", @@ -110,15 +114,8 @@ def get_te_layer_from_string(layer_name): return te_layer_map[layer_name.lower()] -def get_recipe_from_string(recipe, fp8_format=Format.HYBRID): - if recipe == "delayed_scaling": - return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") - elif recipe == "current_scaling": - return Float8CurrentScaling(fp8_format=fp8_format) - elif recipe == "mx_fp8_block_scaling": - return MXFP8BlockScaling(fp8_format=fp8_format) - else: - raise ValueError(f"Unknown quantizer type: {recipe}") +def get_recipe_from_string(recipe): + return getattr(transformer_engine.common.recipe, recipe)() def init_te_model(config): @@ -244,7 +241,7 @@ def test_fp8_fsdp2_allgather(model): module.unshard() # Make sure allgathered parameters match exactly for name, param in model.named_parameters(): - assert torch.allclose(param.dequantize(), fp32_allgathered_params[name]) + torch.testing.assert_close(param.dequantize(), fp32_allgathered_params[name]) # Revert model to original sharded state for module in model.modules(): # Not all modules are wrapped/sharded with FSDP2. 
@@ -278,8 +275,7 @@ def _train(args): device = torch.device(f"cuda:{LOCAL_RANK}") # FP8 Configuration - fp8_format = Format.HYBRID - fp8_recipe = get_recipe_from_string(args.recipe, fp8_format) + fp8_recipe = get_recipe_from_string(args.recipe) build_model_context_args = {} if not args.fp8_init: @@ -292,13 +288,13 @@ def _train(args): build_model_context_args["enabled"] = True build_model_context_args["recipe"] = fp8_recipe - dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device)/1e6} MB") + dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device) / 1e6} MB") # Create the model on the meta/cuda device as per args with build_model_context(**build_model_context_args): model, inp_shape, out_shape = init_te_model(args) dist_print( f"Memory after model init on device {args.device}:" - f" {torch.cuda.memory_allocated(device)/1e6} MB" + f" {torch.cuda.memory_allocated(device) / 1e6} MB" ) # Creating a DeviceMesh for fully_shard @@ -319,7 +315,7 @@ def _train(args): dist_print(f" Sharded parameters materialized and initialized on cuda device.") dist_print( - f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB" + f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device) / 1e6} MB" ) optimizer = optim.Adam(model.parameters(), lr=1e-3) @@ -327,11 +323,20 @@ def _train(args): for iteration in range(args.iter): # Zero the parameter gradients optimizer.zero_grad() - input_data = torch.randn(inp_shape).to(device) - with te.autocast(enabled=True, recipe=fp8_recipe): - output = model(input_data) - target = torch.randn(out_shape).to(device) - loss = F.mse_loss(output, target) + + input_data = torch.randn(inp_shape, device=device) + target = torch.randn(out_shape, device=device) + + # NVFP4BlockScaling requires bfloat16 inputs in both the forward and backward passes. 
+        with (
+            torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+            if args.recipe == "NVFP4BlockScaling"
+            else nullcontext()
+        ):
+            with te.autocast(enabled=True, recipe=fp8_recipe):
+                output = model(input_data)
+                loss = F.mse_loss(output, target)
+
         loss.backward()
         optimizer.step()
         dist_print(f"Iteration {iteration} completed with loss {loss.item()}")
diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py
index e328e57758..a60ab4fafa 100644
--- a/tests/pytorch/distributed/test_torch_fsdp2.py
+++ b/tests/pytorch/distributed/test_torch_fsdp2.py
@@ -3,18 +3,47 @@
 # See LICENSE for license information.
 
 import os
-import pytest
 import subprocess
 from pathlib import Path
 
-import transformer_engine.pytorch as te
+import pytest
 import torch
 
+import transformer_engine.pytorch as te
+from transformer_engine.pytorch import fp8
 
-fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
-mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
 NUM_PROCS: int = torch.cuda.device_count()
 
+# Each entry: (recipe_class_name, check_fn)
+_FP8_RECIPE_CONFIGS = [
+    ("DelayedScaling", fp8.check_fp8_support),
+    ("Float8CurrentScaling", fp8.check_fp8_support),
+    ("Float8BlockScaling", fp8.check_fp8_block_scaling_support),
+    ("MXFP8BlockScaling", fp8.check_mxfp8_support),
+    ("NVFP4BlockScaling", fp8.check_nvfp4_support),
+]
+
+
+def _parametrize_fp8_recipes():
+    """Generate pytest.param objects with xfail marks for unsupported FP8 recipes."""
+    params = []
+    for name, check_fn in _FP8_RECIPE_CONFIGS:
+        supported, reason = check_fn()
+        params.append(
+            pytest.param(
+                name,
+                id=name,
+                marks=pytest.mark.xfail(condition=not supported, reason=reason),
+            )
+        )
+    return params
+
+
+@pytest.fixture(params=_parametrize_fp8_recipes())
+def fp_recipe(request):
+    """Parametrized fixture yielding each TE FP8 recipe class name in turn."""
+    return request.param
+
+
 def 
_run_test(fp_init, sharding_dims, recipe, layer_type):
     test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
 
@@ -32,7 +61,7 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type):
     test_cmd += ["--recipe", recipe]
     test_cmd += ["--layer-type", layer_type]
 
-    result = subprocess.run(test_cmd, env=os.environ, check=True)
+    subprocess.run(test_cmd, env=os.environ, check=True)
 
 
 @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
@@ -40,20 +69,122 @@
 @pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
 @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
 @pytest.mark.parametrize("fp8_init", (False, True))
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
 @pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
-def test_distributed(fp8_init, sharding_dims, recipe, layer_type):
-
-    # Skip invalid configurations
-    if torch.cuda.device_count() < 4:
-        pytest.skip("FSDP2 test requires at least 4 GPUs")
-
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    elif not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-
-    _run_test(fp8_init, sharding_dims, recipe, layer_type)
+def test_distributed(fp8_init, sharding_dims, fp_recipe, layer_type):
+
+    if fp_recipe in ("Float8BlockScaling", "NVFP4BlockScaling") and fp8_init:
+        pytest.xfail(f"{fp_recipe} + fp8_init: test_fp8_fsdp2_allgather is currently failing.")
+
+    _run_test(fp8_init, sharding_dims, fp_recipe, layer_type)
+
+
+## ── FusedAdam + FSDP2 tests ─────────────────────────────────────────
+
+
+def _run_fused_adam_test(test_name, recipe="DelayedScaling"):
+    """Launch an FSDP2 + FusedAdam test via torchrun."""
+    test_path = Path(__file__).parent.resolve() / "run_fsdp2_fused_adam.py"
+    nproc = min(NUM_PROCS, 2)  # These tests only need 2 GPUs
+    test_cmd = [
+        
"torchrun", + f"--nproc_per_node={nproc}", + str(test_path), + "--test", + test_name, + "--recipe", + recipe, + ] + + subprocess.run(test_cmd, env=os.environ, check=True) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_fused_adam_fp8_master_weights(fp_recipe): + """FusedAdam(master_weights=True) + FSDP2 + quantized_model_init.""" + if fp_recipe in ("Float8BlockScaling", "MXFP8BlockScaling", "NVFP4BlockScaling"): + pytest.xfail( + f"{fp_recipe}: quantized_model_init and FSDP2 is not currently supported, since the " + "block tensor is dequantized before we flatten it for FSDP2." + ) + _run_fused_adam_test("fused_adam_fp8_master_weights", fp_recipe) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_fused_adam_bf16(fp_recipe): + """FusedAdam(master_weights=True) + FSDP2 + bf16 params (no FP8).""" + _run_fused_adam_test("fused_adam_bf16", fp_recipe) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_fused_adam_fp8_no_master(fp_recipe): + """FusedAdam(master_weights=False) + FSDP2 + FP8 params.""" + if fp_recipe == "MXFP8BlockScaling": + pytest.xfail( + "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " + "MXFP8 quantized tensors, causing illegal memory access" + ) + _run_fused_adam_test("fused_adam_fp8_no_master", fp_recipe) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_fused_adam_bf16_store_param_remainders(fp_recipe): + """FusedAdam(master_weights=True, store_param_remainders=True) + FSDP2 + bf16.""" + _run_fused_adam_test("fused_adam_bf16_store_param_remainders", fp_recipe) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_dcp_output_parity(fp_recipe): + """DCP save/load round-trip into a fresh model produces identical outputs.""" + if fp_recipe == "MXFP8BlockScaling": + pytest.xfail( + "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " + "MXFP8 quantized tensors, causing illegal 
memory access" + ) + _run_fused_adam_test("dcp_output_parity", fp_recipe) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_dcp_output_parity_async(fp_recipe): + """DCP save/load round-trip into a fresh model produces identical outputs.""" + if fp_recipe in ("DelayedScaling", "Float8CurrentScaling"): + pytest.xfail( + f"async DCP save/load with {fp_recipe} produces different outputs: " + "the async staging may capture stale tensor state for FP8 scaling " + "factors, causing numerical divergence after reload" + ) + if fp_recipe == "MXFP8BlockScaling": + pytest.xfail( + "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " + "MXFP8 quantized tensors, causing illegal memory access" + ) + _run_fused_adam_test("dcp_output_parity_async", fp_recipe) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_safetensors_fp32_export(fp_recipe): + """Export FP32 model from optimizer master weights to safetensors.""" + if fp_recipe == "MXFP8BlockScaling": + pytest.xfail( + "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " + "MXFP8 quantized tensors, causing illegal memory access" + ) + _run_fused_adam_test("safetensors_fp32_export", fp_recipe) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +@pytest.mark.xfail( + reason=( + "fuse_wgrad_accumulation is incompatible with vanilla FSDP2: " + "autograd Function.apply unwraps DTensors to local tensors, so " + "main_grad (set on the DTensor) is inaccessible during backward. " + "Additionally, the fused wgrad GEMM bypasses FSDP2's reduce-scatter." 
+    ),
+    raises=subprocess.CalledProcessError,
+    strict=True,
+)
+def test_fsdp2_fuse_wgrad_accumulation(fp_recipe):
+    """fuse_wgrad_accumulation=True + FSDP2 -- expected to fail."""
+    _run_fused_adam_test("fuse_wgrad_accumulation", fp_recipe)
 
 
 def test_dummy() -> None:
@@ -63,3 +194,10 @@ def test_dummy() -> None:
     """
     pass
+
+
+"""
+NOTE: async DCP save with DelayedScaling/Float8CurrentScaling currently
+diverges after reload; see the xfails on test_fsdp2_dcp_output_parity_async.
+Remove those xfails once the stale-staging issue is fixed.
+"""
diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py
index 5bdc537e4b..8580cf4a33 100644
--- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py
+++ b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py
@@ -218,6 +218,12 @@ def __init__(
         self.with_amax_reduction = False
         self.amax_reduction_group = None
 
+    def __getstate__(self):
+        """Exclude unpicklable process group from serialized state."""
+        state = self.__dict__.copy()
+        state["amax_reduction_group"] = None
+        return state
+
     @property
     def custom(self) -> bool:
         """Flag to indicate this quantizer is custom."""
diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py
index a87d968334..56fe63cdd6 100644
--- a/transformer_engine/pytorch/optimizers/fused_adam.py
+++ b/transformer_engine/pytorch/optimizers/fused_adam.py
@@ -394,8 +394,14 @@ def _initialize_state(
         store_param_remainders (bool): Store only trailing remainder bits.
         """
         dtype = self.name_to_dtype_map[state_name]
+        # Extract local tensor from DTensor (e.g. from FSDP2) to avoid
+        # QuantizedTensor.__torch_dispatch__ ignoring the dtype kwarg in
+        # torch.empty_like, and to ensure optimizer states are plain tensors.
+ local_param = param._local_tensor if isinstance(param, DTensor) else param # Handle QuantizedTensor by dequantizing first - param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param + param_for_empty = ( + local_param.dequantize() if isinstance(local_param, QuantizedTensor) else local_param + ) if store_param_remainders: data = torch.zeros_like(param_for_empty, dtype=torch.int16) else: @@ -440,7 +446,14 @@ def initialize_state(self, param, store_param_remainders): store_param_remainders=store_param_remainders, ) if not store_param_remainders: - self.set_scaled_state(param, "master_param", param.clone().detach().float()) + # Extract local tensor from DTensor and dequantize QuantizedTensor + # to get a plain float32 copy for the master weight. + local_param = param._local_tensor if isinstance(param, DTensor) else param + if isinstance(local_param, QuantizedTensor): + master = local_param.dequantize(dtype=torch.float32).clone().detach() + else: + master = local_param.clone().detach().float() + self.set_scaled_state(param, "master_param", master) def state_dict(self): """Override the state_dict() of pytorch. Before returning the state_dict, cast all @@ -575,6 +588,10 @@ def step(self, closure=None, grad_scaler=None): if p_grad is None: continue + # Extract local tensors from DTensors (e.g. from FSDP2) + # so that multi_tensor kernels receive plain CUDA tensors. 
+ if isinstance(p_grad, DTensor): + p_grad = p_grad._local_tensor if p_grad.data.is_sparse: raise RuntimeError("FusedAdam does not support sparse gradients.") @@ -594,10 +611,10 @@ def step(self, closure=None, grad_scaler=None): unscaled_lists[name].append(unscaled) scaled_lists[name].append(state[name]) state_scales[name].append(self._scales[p][name]) - if isinstance(p, Float8Tensor) or ( - isinstance(p, DTensor) and isinstance(p._local_tensor, Float8Tensor) - ): - p = p._local_tensor if isinstance(p, DTensor) else p + # Extract local tensor from DTensor param for multi_tensor kernels + if isinstance(p, DTensor): + p = p._local_tensor + if isinstance(p, Float8Tensor): out_dtype = p._fp8_dtype p_fp8_model.append(p._data.data) scale, amax, scale_inv = get_fp8_meta(p) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ecafb6ddfc..b473be4621 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -6,6 +6,7 @@ from __future__ import annotations from collections.abc import Iterable import math +import warnings from typing import Any, Optional, Tuple, Union import torch @@ -604,19 +605,27 @@ def forward( if tensor._is_2D_scaled: # For the case of 2D scaled tensor, the last 2 dimensions should not change if shape[-1] != ctx.shape[-1] or shape[-2] != ctx.shape[-2]: - raise RuntimeError( + warnings.warn( "2D scaled Float8BlockwiseQTensor does not support view " "the last 2 dimensions " - f"(attempted to view dims={tuple(tensor.shape)} to {tuple(shape)})" + f"(attempted to view dims={tuple(tensor.shape)} to {tuple(shape)}). 
" + "If you are using this for FSDP2 without compiled_autograd_enabled, " + "then ignore this warning since this view is not going to be used anywhere.", + stacklevel=2, ) + return tensor.dequantize().view(*shape) else: # For the case of 1D scaled tensor, the last dimension should not change if shape[-1] != ctx.shape[-1]: - raise RuntimeError( + warnings.warn( "1D scaled Float8BlockwiseQTensor does not support view " "the last dimension " - f"(attempted to view dims={tuple(tensor.shape)} to {tuple(shape)})" + f"(attempted to view dims={tuple(tensor.shape)} to {tuple(shape)}). " + "If you are using this for FSDP2 without compiled_autograd_enabled, " + "then ignore this warning since this view is not going to be used anywhere.", + stacklevel=2, ) + return tensor.dequantize().view(*shape) if list(shape) == list(tensor.shape): return tensor @@ -711,19 +720,27 @@ def forward( if tensor._is_2D_scaled: # For the case of 2D scaled tensor, the last 2 dimensions should not change if shape[-1] != ctx.shape[-1] or shape[-2] != ctx.shape[-2]: - raise RuntimeError( + warnings.warn( "2D scaled Float8BlockwiseQTensor does not support reshaping " "the last 2 dimensions " - f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)}). " + "If you are using this for FSDP2 without compiled_autograd_enabled, " + "then ignore this warning since this view is not going to be used anywhere.", + stacklevel=2, ) + return tensor.dequantize().reshape(*shape) else: # For the case of 1D scaled tensor, the last dimension should not change if shape[-1] != ctx.shape[-1]: - raise RuntimeError( + warnings.warn( "1D scaled Float8BlockwiseQTensor does not support reshaping " "the last dimension " - f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)}). 
" + "If you are using this for FSDP2 without compiled_autograd_enabled, " + "then ignore this warning since this view is not going to be used anywhere.", + stacklevel=2, ) + return tensor.dequantize().reshape(*shape) if list(shape) == list(tensor.shape): return tensor diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 55bca49af3..500fadf2b0 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -284,6 +284,12 @@ def __init__( self.force_pow_2_scales = force_pow_2_scales self.amax_epsilon = amax_epsilon + def __getstate__(self): + """Exclude unpicklable process group from serialized state.""" + state = self.__dict__.copy() + state["amax_reduction_group"] = None + return state + def copy(self) -> Float8CurrentScalingQuantizer: """Create shallow copy""" diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 41d6c87f2b..ea972dcbe7 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -365,6 +365,21 @@ def reshape(self, *shape: Tuple[int]) -> MXFP8Tensor: # pylint: disable=missing-function-docstring return _ReshapeFunc.apply(self, shape) + def untyped_storage(self) -> torch.UntypedStorage: + """Return the underlying UntypedStorage of the FP8 data. + + Note that MXFP8 tensor may involve multiple buffers: row-wise + FP8 data, row-wise scales, column-wise FP8 data, column-wise + scales. The UntypedStorage of the row-wise FP8 data is returned + if it exists, and otherwise the UntypedStorage of the + column-wise FP8 data. 
+ + """ + data = self._rowwise_data if self._rowwise_data is not None else self._columnwise_data + if data is not None: + return data.untyped_storage() + return torch.UntypedStorage(0, device=self.device) + def contiguous( self, memory_format: torch.memory_format = torch.contiguous_format, diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 66f986a900..fa0715e7fe 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -6,6 +6,7 @@ from __future__ import annotations from collections.abc import Iterable import math +import warnings from typing import Dict, Optional, Tuple, Union import functools @@ -157,6 +158,12 @@ def __init__( ) self.rht_matrix = get_rht_matrix(with_random_sign_mask, torch.cuda.current_device()) + def __getstate__(self): + """Exclude unpicklable process group from serialized state.""" + state = self.__dict__.copy() + state["amax_reduction_group"] = None + return state + def update_quantized( self, src: torch.Tensor, @@ -511,6 +518,19 @@ def clone(self) -> NVFP4Tensor: }, ) + def untyped_storage(self) -> torch.UntypedStorage: + """Return the underlying UntypedStorage of the FP4 data. + + The UntypedStorage of the row-wise data is returned if it + exists, and otherwise the UntypedStorage of the column-wise + data. 
+ + """ + data = self._rowwise_data if self._rowwise_data is not None else self._columnwise_data + if data is not None: + return data.untyped_storage() + return torch.UntypedStorage(0, device=self.device) + def view(self, *shape: Tuple[int]) -> NVFP4Tensor: # pylint: disable=missing-function-docstring return _ViewFunc.apply(self, shape) @@ -755,10 +775,14 @@ def forward( shape[i] = d_inferred break if shape[-1] != cur_shape[-1]: - raise RuntimeError( + warnings.warn( "NVFP4Tensor does not support reshaping inner dimension " - f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)}). " + "If you are using this for FSDP2 without compiled_autograd_enabled, " + "then ignore this warning since this view is not going to be used anywhere.", + stacklevel=2, ) + return tensor.dequantize().view(*shape) # Reshape data new_rowwise_data = None @@ -877,10 +901,14 @@ def forward( shape[i] = d_inferred break if shape[-1] != cur_shape[-1]: - raise RuntimeError( + warnings.warn( "NVFP4Tensor does not support reshaping inner dimension " - f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)}). " + "If you are using this for FSDP2 without compiled_autograd_enabled, " + "then ignore this warning since this view is not going to be used anywhere.", + stacklevel=2, ) + return tensor.dequantize().reshape(*shape) # Reshape data new_rowwise_data = None