
Add fused_adam, quantized_model_init, and fsdp2 example#2698

Open
pstjohn wants to merge 2 commits into NVIDIA:main from pstjohn:pstjohn/fsdp2-fused-adam

Conversation

@pstjohn
Contributor

@pstjohn pstjohn commented Feb 22, 2026

Summary

  • Fix FusedAdam to work with PyTorch-native FSDP2 (fully_shard) when parameters are DTensor-wrapped Float8Tensor/QuantizedTensor
  • Fix fuse_wgrad_accumulation guard to avoid crashing with vanilla FSDP2 (previously assumed Megatron-Core FSDP exclusively)
  • Add examples for quantized_model_init on single-GPU (main.py) and multi-GPU FSDP2 (fully_shard.py)
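The first fix, unwrapping the DTensor before the optimizer's kernels see the parameter, can be sketched with a minimal torch-free mock. Note that `DTensor`, `Float8Tensor`, and `unwrap_local` below are stand-ins that only model the wrapping structure, not the PR's actual code:

```python
# Minimal mock of the unwrapping FusedAdam needs before calling its
# multi-tensor kernels. Real code uses torch.distributed.tensor.DTensor
# and Transformer Engine's Float8Tensor; these stubs are illustrative.

class Float8Tensor:
    """Stand-in for a quantized parameter that can dequantize to fp32."""
    def __init__(self, values):
        self.values = values

    def dequantize(self):
        return [float(v) for v in self.values]

class DTensor:
    """Stand-in for FSDP2's sharded wrapper around a local tensor."""
    def __init__(self, local_tensor):
        self._local_tensor = local_tensor

def unwrap_local(param):
    """Peel off the DTensor wrapper so kernels see a plain local tensor."""
    return param._local_tensor if isinstance(param, DTensor) else param

# FSDP2 hands the optimizer a DTensor whose local shard is Float8:
p = DTensor(Float8Tensor([1, 2, 3]))
local = unwrap_local(p)
master = local.dequantize() if isinstance(local, Float8Tensor) else local
print(master)  # [1.0, 2.0, 3.0]
```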

Note: fuse_wgrad_accumulation remains incompatible with vanilla FSDP2

fuse_wgrad_accumulation still cannot be used with vanilla FSDP2. The feature writes weight gradients directly into main_grad and returns None to autograd, bypassing FSDP2's reduce-scatter. Each rank ends up with an unreduced gradient. Megatron-Core FSDP solves this by wiring get_main_grad() into its own reduce-scatter infrastructure. Vanilla FSDP2 does not yet expose an equivalent hook.
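The mismatch described above can be simulated in plain Python (no torch; all names are illustrative): FSDP2 only reduce-scatters what the backward hook returns, so a grad written out-of-band into main_grad never gets reduced.

```python
# Toy model of why fuse_wgrad_accumulation breaks under vanilla FSDP2.
# Each "rank" computes a local wgrad; FSDP2 only reduce-scatters what
# autograd returns. With fusion on, the grad is written directly into
# main_grad and None is returned, so the reduction never sees it.

def fsdp2_step(rank_wgrads, fuse):
    main_grads = []
    returned = []
    for g in rank_wgrads:
        if fuse:
            main_grads.append(g)   # direct write, bypasses autograd
            returned.append(None)  # FSDP2 sees nothing to reduce
        else:
            main_grads.append(0.0)
            returned.append(g)     # normal path: FSDP2 reduces this
    # FSDP2's reduce-scatter averages only the returned grads
    if any(r is not None for r in returned):
        reduced = sum(returned) / len(returned)
        return [reduced] * len(rank_wgrads)
    return main_grads  # unreduced per-rank values

print(fsdp2_step([1.0, 3.0], fuse=False))  # [2.0, 2.0] -> ranks agree
print(fsdp2_step([1.0, 3.0], fuse=True))   # [1.0, 3.0] -> ranks diverge
```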

Fixes #2682

@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch 2 times, most recently from 22604c4 to 4d89e04 on February 23, 2026 15:28
@pstjohn pstjohn marked this pull request as ready for review February 23, 2026 17:27
@greptile-apps
Contributor

greptile-apps bot commented Feb 23, 2026

Greptile Summary

This PR enables FusedAdam optimizer to work correctly with PyTorch-native FSDP2 (fully_shard) when parameters are DTensor-wrapped Float8Tensor/QuantizedTensor objects. The core fix extracts local tensors from DTensor wrappers before passing them to multi-tensor kernels and state initialization routines.

Key changes:

  • Modified FusedAdam to extract _local_tensor from DTensor params/gradients before operations
  • Added comprehensive test suite for FSDP2 + FusedAdam integration across multiple FP8 recipes
  • Includes tests for checkpoint save/load, safetensors export, and store_param_remainders mode
  • Added well-documented examples demonstrating single-GPU and multi-GPU FSDP2 usage with quantized_model_init
  • Correctly marks fuse_wgrad_accumulation + FSDP2 test as xfail (incompatible without Megatron-Core hooks)

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The changes are well-isolated to DTensor handling in FusedAdam, thoroughly tested with 8 new test cases covering edge cases, and include comprehensive documentation. The xfail test properly documents the known limitation with fuse_wgrad_accumulation.
  • No files require special attention

Important Files Changed

  • transformer_engine/pytorch/optimizers/fused_adam.py: Adds DTensor support for FSDP2 by extracting local tensors before operations, ensuring optimizer states are plain CUDA tensors
  • tests/pytorch/distributed/run_fsdp2_fused_adam.py: Comprehensive test suite covering FusedAdam with FSDP2, FP8 params, checkpoint save/load, and expected failure for fuse_wgrad_accumulation
  • examples/pytorch/quantized_model_init/fully_shard.py: Multi-GPU FSDP2 example showing meta-device init, sharding, FP8 quantization, and checkpoint/safetensors export

Sequence Diagram

sequenceDiagram
    participant User
    participant Model as TransformerLayer<br/>(FP8 params via quantized_model_init)
    participant FSDP2 as fully_shard<br/>(DTensor wrapping)
    participant Optimizer as FusedAdam<br/>(FP32 master weights)
    participant Kernel as Multi-tensor kernels

    User->>Model: Build model in quantized_model_init context
    Model-->>User: Float8Tensor/QuantizedTensor params
    User->>FSDP2: Apply fully_shard(model)
    FSDP2-->>User: DTensor-wrapped params
    Note over FSDP2: DTensor._local_tensor = Float8Tensor
    User->>Optimizer: FusedAdam(params, master_weights=True)
    Optimizer->>Optimizer: Extract local_tensor from DTensor
    Optimizer->>Optimizer: Dequantize QuantizedTensor
    Optimizer-->>User: FP32 master_param states initialized
    User->>Model: Forward + backward pass
    Model-->>Optimizer: Gradients (DTensor-wrapped)
    Optimizer->>Optimizer: Extract p_grad._local_tensor
    Optimizer->>Optimizer: Extract p._local_tensor for FP8 params
    Optimizer->>Kernel: Call multi_tensor_adam with plain CUDA tensors
    Kernel-->>Optimizer: Updated FP32 master weights
    Optimizer->>Optimizer: Quantize and update FP8 params

Last reviewed commit: 96c123e

@greptile-apps greptile-apps bot left a comment

6 files reviewed, no comments


@cspades cspades (Member) left a comment

LGTM, clean edits.

# to get a plain float32 copy for the master weight.
local_param = param._local_tensor if isinstance(param, DTensor) else param
if isinstance(local_param, QuantizedTensor):
    master = local_param.dequantize().clone().detach().float()
@cspades cspades commented Feb 24, 2026:
Should we use dequantize(dtype=torch.float32) to fuse the cast into the de-quantization's output buffer? (Likely not a big deal, since I don't think this will change anything numerically, and you only call this function during init and whenever you save and load DCP checkpoints.)
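The trade-off being suggested here is avoiding a second output buffer by fusing the cast into the dequantize kernel. A toy mock that counts allocations (QuantizedTensor, Buf, and the dtype strings are stand-ins, not the real Transformer Engine API):

```python
# Mock contrasting dequantize()+.float() (two output buffers) with
# dequantize(dtype="float32") (one buffer, cast fused into the kernel).

allocations = []

class Buf:
    """Stand-in for a tensor; records every buffer allocation."""
    def __init__(self, values, dtype):
        allocations.append(dtype)
        self.values, self.dtype = list(values), dtype

    def float(self):
        return Buf(self.values, "float32")  # extra copy

class QuantizedTensor:
    def __init__(self, values):
        self.values = values

    def dequantize(self, dtype=None):
        # Fused path writes straight into a buffer of the target dtype.
        return Buf(self.values, dtype or "bfloat16")

q = QuantizedTensor([1, 2])

allocations.clear()
q.dequantize().float()          # unfused: bf16 buffer, then fp32 copy
unfused = len(allocations)

allocations.clear()
q.dequantize(dtype="float32")   # fused: single fp32 buffer
fused = len(allocations)

print(unfused, fused)  # 2 1
```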

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch from 0103b53 to 3c3dbd2 on February 24, 2026 20:06
@greptile-apps greptile-apps bot left a comment

6 files reviewed, no comments

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@greptile-apps greptile-apps bot left a comment

6 files reviewed, no comments

@XueSongTap

@pstjohn Hi, thanks for the great work! Does this PR plan to also handle the BF16 path? I noticed the BF16 branch still operates on the original p/p_grad without unwrapping when they're DTensors. In my experiments with FSDP2 + BF16, I'm seeing non-trivial overhead during the optimizer step from repeated DTensor dispatch. Curious if that's intentional or a planned follow-up.
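The dispatch overhead described in this comment can be illustrated with a toy mock: every op on the wrapper goes through a dispatch layer, while ops on the unwrapped local tensor do not. (Class names are stand-ins; real DTensor dispatch goes through torch's `__torch_dispatch__` machinery.)

```python
# Toy illustration of per-op DTensor dispatch overhead versus
# unwrapping once and operating on the plain local tensor.

dispatch_count = 0

class LocalTensor:
    def __init__(self, v):
        self.v = v

    def __mul__(self, s):
        return LocalTensor(self.v * s)  # plain op, no dispatch

class DTensor:
    def __init__(self, local):
        self._local_tensor = local

    def __mul__(self, s):
        global dispatch_count
        dispatch_count += 1             # simulated dispatch-layer hop
        return DTensor(self._local_tensor * s)

p = DTensor(LocalTensor(2.0))
for _ in range(3):
    p = p * 0.5                         # wrapped path: one hop per op
through_wrapper = dispatch_count

dispatch_count = 0
local = p._local_tensor                 # unwrap once, then plain ops
for _ in range(3):
    local = local * 0.5
print(through_wrapper, dispatch_count)  # 3 0
```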



Development

Successfully merging this pull request may close these issues.

Example of quantized_model_init for low-precision compute weights and fp32 main weights with fsdp2

3 participants