[PyTorch][Fused Attn] Add support for cuDNN to return Softmax Stats always and Max when return_max_logit=True #2677

Open
sudhakarsingh27 wants to merge 16 commits into NVIDIA:main from sudhakarsingh27:fix_return_stats_max_cudnn

Conversation

@sudhakarsingh27
Collaborator

Description

cuDNN recently made it possible to return any subset of {Stats, SumExp, Max}. This PR adapts TE to always request Stats from cuDNN, plus the Max tensor when return_max_logit=True. (Note that Stats = log(SumExp) + Max.)
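For intuition, the relation the PR relies on is the standard numerically stable logsumexp decomposition. A minimal NumPy sketch (illustrative only, not TE code; shapes are made up):

```python
import numpy as np

# Stats = log(SumExp) + Max, where for attention logits x = Q @ K.T the
# reductions run over the key dimension:
#   Max    = max(x)               -> [b, h, sq, 1]
#   SumExp = sum(exp(x - max(x))) -> [b, h, sq, 1]
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 8, 16)).astype(np.float32)  # [b, h, sq, sk] logits

max_ = x.max(axis=-1, keepdims=True)
sum_exp = np.exp(x - max_).sum(axis=-1, keepdims=True)
stats = np.log(sum_exp) + max_

# Equal (up to float error) to the direct, numerically unstable logsumexp:
direct = np.log(np.exp(x).sum(axis=-1, keepdims=True))
assert np.allclose(stats, direct, atol=1e-5)
```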

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • fused_attn_f16_arbitrary_seqlen.cu
    • Removed references to the SumExp tensor, which is no longer needed since cuDNN returns Stats by default.
    • Set generate_stats=true, which forces cuDNN to always return the Stats tensor (needed in the backward pass).
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py
    • Removed the code that manually computed Stats = log(SumExp) + Max, since cuDNN now returns Stats directly and TE no longer needs SumExp.
  • Corresponding documentation
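The forward-output ordering these changes imply can be sketched as follows (hedged: the helper and variable names are illustrative, based on the layout described in this PR, not TE's actual code):

```python
def unpack_fwd_outputs(output_tensors, return_max_logit):
    """Illustrative unpacking of the new forward output order:
    [O, Stats, Max, rng_state, ...] when return_max_logit=True,
    [O, Stats, rng_state, ...]      otherwise."""
    out, stats = output_tensors[0], output_tensors[1]  # Stats is always present now
    if return_max_logit:
        max_tensor, rest = output_tensors[2], output_tensors[3:]
    else:
        max_tensor, rest = None, output_tensors[2:]
    return out, stats, max_tensor, rest

# Example with placeholder strings standing in for tensors:
o, s, m, rest = unpack_fwd_outputs(["O", "Stats", "Max", "rng"], True)
assert (o, s, m, rest) == ("O", "Stats", "Max", ["rng"])
o, s, m, rest = unpack_fwd_outputs(["O", "Stats", "rng"], False)
assert m is None and rest == ["rng"]
```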

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

sudhakarsingh27 and others added 5 commits February 12, 2026 13:12
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Feb 12, 2026

Greptile Summary

This PR simplifies how TE interacts with cuDNN's softmax statistics outputs for fused attention. Instead of conditionally requesting either Stats or {Max, SumExp} based on return_max_logit, the code now always requests Stats from cuDNN (by setting generate_stats=true) and additionally requests Max when return_max_logit=true. This leverages cuDNN's recent ability to return any subset of {Stats, SumExp, Max}.

  • Removes all SumExp tensor handling from the forward pass graph construction and tensor packing
  • Eliminates the Python-side manual computation Stats = log(SumExp) + Max, since cuDNN now provides Stats directly
  • Changes tensor output order to [Stats, (optional Max), rng_state, ...] from the previous [Max, SumExp, rng_state, ...] when return_max_logit=true
  • Renames generate_max_sum_exp to return_max_logit in FADescriptor_v1 cache key struct
  • Updates documentation in fused_attn.h across all 4 API functions

Confidence Score: 4/5

  • This PR is a clean simplification that removes unnecessary tensor handling and leverages cuDNN's improved API. The tensor ordering is consistent across all layers.
  • Score of 4 reflects that the changes are logically sound with consistent tensor ordering across CUDA kernel, C++ extension, and Python layers. The simplification removes complexity (SumExp handling, manual Stats computation) rather than adding it. One point deducted because there are no new tests added to cover the changed tensor ordering, and the interaction with cuDNN's subset output feature is not easily verifiable without runtime testing.
  • Pay close attention to fused_attn_f16_arbitrary_seqlen.cu — it contains the core tensor reordering logic across both allocation (size==0) and extraction (size>=2) paths that must stay in sync.

Important Files Changed

  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
    Core change: always sets generate_stats=true so cuDNN returns Stats directly, removes SumExp tensor handling, and reorders tensor packing to [Stats, (optional Max), rng_state, ...]. Tensor ordering is consistent across allocation and extraction paths.
  • transformer_engine/common/fused_attn/utils.h
    Renames the generate_max_sum_exp field to return_max_logit in the FADescriptor_v1 struct and its comparison operator. Straightforward rename with no behavioral change.
  • transformer_engine/common/include/transformer_engine/fused_attn.h
    Updates documentation for the return_max_logit parameter in 4 locations to reflect the new behavior: "Whether to produce Max along with Stats" instead of the old "Whether to produce Max and Sum_Exp, or Stats".
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py
    Removes the manual Stats = log(SumExp) + Max computation since cuDNN now returns Stats directly. Updates tensor indexing to match the new output order [O, Stats, Max, rng_state, ...] when return_max_logit=True.
  • transformer_engine/pytorch/csrc/extensions/attention.cpp
    Updates comments and tensor layout documentation to reflect the new output order. The allocation logic correctly handles both paths: S always first, then Max (conditional), then rng_state.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["cuDNN SDPA Forward\n(generate_stats=true)"] --> B["Stats tensor\n(always returned)"]
    A -->|"return_max_logit=true"| C["Max tensor\n(conditionally returned)"]
    
    B --> D["Aux_CTX_Tensors[0] = Stats"]
    C --> E["Aux_CTX_Tensors[1] = Max"]
    
    D --> F["attention.cpp\nAllocates output_tensors"]
    E --> F
    
    F -->|"return_max_logit=true"| G["output_tensors:\n[O, Stats, Max, rng_state, ...]"]
    F -->|"return_max_logit=false"| H["output_tensors:\n[O, Stats, rng_state, ...]"]
    
    G --> I["fused_attn.py\nmax_logit = amax(output_tensors[2])\naux_ctx = [Stats, rng_state, ...]"]
    H --> J["fused_attn.py\naux_ctx = [Stats, rng_state, ...]"]
    
    I --> K["Backward Pass\naux_ctx[0] = Stats\naux_ctx[1] = rng_state"]
    J --> K

Last reviewed commit: 1102738

@greptile-apps greptile-apps bot left a comment: 3 files reviewed, 1 comment

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…27/TransformerEngine into fix_return_stats_max_cudnn
@greptile-apps greptile-apps bot left a comment: 3 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment: 3 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment: 3 files reviewed, 1 comment

@greptile-apps
Contributor

greptile-apps bot commented Feb 17, 2026

Additional Comments (1)

transformer_engine/pytorch/cpp_extensions/fused_attn.py
Stale docstring: wrong formula for softmaxStats

The public docstring still describes softmaxStats as log(sum(e^(x - max(x)))), which is log(SumExp). However, with this PR, the returned tensor is cuDNN's Stats = log(SumExp) + Max, not just log(SumExp). This formula was already incorrect before this PR (the old code computed Max + log(SumExp) and stored it as stats), but the PR is an opportunity to correct it.

                       softmaxStats: torch.Tensor
                           log(sum(e^(x - max(x)))) + max(x), where x=Q*K.T (i.e. Stats = log(SumExp) + Max)
                           shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32

@greptile-apps greptile-apps bot left a comment: 3 files reviewed, 1 comment

stats = output_tensors[1] + torch.log(output_tensors[2])
# thd: output_tensors: out [tq, h, d], Stats [tq, h, 1], Max [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1] (there's no typo here)
Collaborator

Do we need the "there's no typo here" :)

Collaborator Author

I deliberately added it because I didn't believe it and checked the shapes myself :P

size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();

Collaborator
@cyanguwa cyanguwa Feb 18, 2026

You might need to make these changes in the "Aux_CTX_Tensors->size == 0" sections in _fwd/bwd_qkvpacked/kvpacked APIs as well. Please check. Thanks!

Collaborator Author

Looks like I don't need to, because the nvte_fused..._qkvpacked APIs are in fused_attn.cpp, which calls fused_attn_f16_arbitrary... just like the regular nvte_fused_fwd/bwd

@greptile-apps greptile-apps bot left a comment: 3 files reviewed, 1 comment

# Max -> max_logit [h]
max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
aux_ctx_tensors = [stats]
max_logit = torch.amax(output_tensors[2], dim=amax_dims).to(dtype=output_tensors[0].dtype)
Collaborator
@KshitijLakhani KshitijLakhani Feb 19, 2026

Maybe I understood this incorrectly, but isn't TE now also supposed to receive Max from cuDNN directly (like Stats, except that Stats is always returned while Max can be toggled) rather than calling amax() in TE?

(Sudhakar: Why am I able to update your comment? )

Collaborator Author
@sudhakarsingh27 sudhakarsingh27 Feb 20, 2026

cuDNN returns Max ([b, h, sq, 1]), so it's an additional softmax statistic (apparently, the subset (Stats, Max) is enough for cuDNN bwd, rather than the full set (Stats, SumExp, Max)).

Further, for muon, we need to do amax on it to get a tensor of shape [h]. return_max_logit in TE controls whether to fetch Max from cuDNN.

Perf-wise, it'd be nice for cuDNN to do the additional reduction and return the [h]-shaped tensor for muon as well, but that's out of scope for this PR.

(Kshitij: looks like I can as well)
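The reduction described above can be sketched like this (shapes taken from the comment; NumPy stands in for torch.amax, and the variable names are illustrative, not TE's actual code):

```python
import numpy as np

# cuDNN's Max tensor has shape [b, h, sq, 1]; for muon, TE reduces it to a
# per-head [h] tensor by taking the max over the batch, sequence, and
# trailing singleton dimensions.
b, h, sq = 2, 4, 8
max_tensor = np.random.default_rng(1).standard_normal((b, h, sq, 1)).astype(np.float32)

# Analogous to torch.amax(max_tensor, dim=(0, 2, 3)):
max_logit = max_tensor.max(axis=(0, 2, 3))
assert max_logit.shape == (h,)
```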

@greptile-apps greptile-apps bot left a comment: 3 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment: 5 files reviewed, 1 comment

@sudhakarsingh27 sudhakarsingh27 force-pushed the fix_return_stats_max_cudnn branch from 21ca43a to becc3ad on February 20, 2026 at 19:41
@greptile-apps greptile-apps bot left a comment: 5 files reviewed, 1 comment

@greptile-apps
Contributor

greptile-apps bot commented Feb 20, 2026

Additional Comments (1)

transformer_engine/common/include/transformer_engine/fused_attn.h
Entire file has been reformatted with unintentional 3-space indentation changes. This creates a large diff unrelated to the actual feature changes. Revert the formatting to match the original file structure.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 force-pushed the fix_return_stats_max_cudnn branch from d4568db to 8f40cab on February 20, 2026 at 20:00
@greptile-apps greptile-apps bot left a comment: 5 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment: 5 files reviewed, no comments

3 participants