[PyTorch][Fused Attn] Add support for cuDNN to return Softmax Stats always and Max when return_max_logit=True#2677
Conversation
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Greptile Summary

This PR simplifies how TE interacts with cuDNN's softmax statistics outputs for fused attention. Instead of conditionally requesting either `Stats` or the `SumExp`/`Max` pair, TE now always requests `Stats` and additionally requests `Max` when `return_max_logit=True`.
Confidence Score: 4/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["cuDNN SDPA Forward\n(generate_stats=true)"] --> B["Stats tensor\n(always returned)"]
    A -->|"return_max_logit=true"| C["Max tensor\n(conditionally returned)"]
    B --> D["Aux_CTX_Tensors[0] = Stats"]
    C --> E["Aux_CTX_Tensors[1] = Max"]
    D --> F["attention.cpp\nAllocates output_tensors"]
    E --> F
    F -->|"return_max_logit=true"| G["output_tensors:\n[O, Stats, Max, rng_state, ...]"]
    F -->|"return_max_logit=false"| H["output_tensors:\n[O, Stats, rng_state, ...]"]
    G --> I["fused_attn.py\nmax_logit = amax(output_tensors[2])\naux_ctx = [Stats, rng_state, ...]"]
    H --> J["fused_attn.py\naux_ctx = [Stats, rng_state, ...]"]
    I --> K["Backward Pass\naux_ctx[0] = Stats\naux_ctx[1] = rng_state"]
    J --> K
```
Last reviewed commit: 1102738
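The routing of `output_tensors` described in the flowchart can be illustrated with a small self-contained sketch. Names here (`unpack_outputs`, `rest`) are illustrative only, not the actual TE code; the real logic lives in `fused_attn.py`:

```python
# Hedged sketch of the forward-output unpacking implied by the flowchart.
# output_tensors is [O, Stats, Max, rng_state, ...] when return_max_logit
# is true, and [O, Stats, rng_state, ...] otherwise.
def unpack_outputs(output_tensors, return_max_logit):
    out, stats = output_tensors[0], output_tensors[1]
    if return_max_logit:
        # Max sits at index 2 only when it was requested from cuDNN.
        max_tensor, rest = output_tensors[2], output_tensors[3:]
    else:
        max_tensor, rest = None, output_tensors[2:]
    # The backward pass only needs [Stats, rng_state, ...];
    # Max is consumed in the forward pass (for max_logit).
    aux_ctx_tensors = [stats] + rest
    return out, max_tensor, aux_ctx_tensors
```

Either way, `aux_ctx_tensors` ends up as `[Stats, rng_state, ...]`, matching nodes I/J/K of the flowchart.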
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…27/TransformerEngine into fix_return_stats_max_cudnn
Additional Comments (1)

The public docstring still describes
```python
stats = output_tensors[1] + torch.log(output_tensors[2])
# thd:  output_tensors: out [tq, h, d],    Stats [tq, h, 1],    Max [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1] (there's no typo here)
```
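For reference, the identity `Stats = log(SumExp) + Max` behind the removed recomputation above is the standard numerically stable log-sum-exp decomposition. A quick standalone check in plain Python (not TE code; `stats_via_max` is an illustrative name):

```python
import math

def stats_via_max(logits):
    # Stats = Max + log(sum(exp(x - Max))): the stable decomposition
    # underlying cuDNN's softmax statistics. Shifting by the max keeps
    # the exponentials from overflowing for large logits.
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))

logits = [2.0, 5.0, 3.0]
naive = math.log(sum(math.exp(x) for x in logits))  # direct log-sum-exp
stable = stats_via_max(logits)
```

Both forms agree to floating-point precision, which is why cuDNN can return `Stats` directly and TE no longer needs `SumExp`.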
Do we need the "there's no typo here" :)
I deliberately added it because I didn't believe it and checked the shapes myself :P
```cpp
size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
  const auto cudnn_runtime_version = cudnnGetVersion();
```
You might need to make these changes in the "Aux_CTX_Tensors->size == 0" sections in _fwd/bwd_qkvpacked/kvpacked APIs as well. Please check. Thanks!
Looks like I don't need to, because the nvte_fused...qkvpacked variants live in fused_attn.cpp, which calls fused_attn_f16_arbitrary... just like the regular nvte_fused_fwd/bwd entry points.
transformer_engine/common/include/transformer_engine/fused_attn.h
```python
# Max -> max_logit [h]
max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
aux_ctx_tensors = [stats]
max_logit = torch.amax(output_tensors[2], dim=amax_dims).to(dtype=output_tensors[0].dtype)
```
Maybe I understood this incorrectly, but isn't TE now also supposed to receive Max from cuDNN directly (like Stats, except that Stats is always returned while Max can be toggled), rather than calling amax() in TE?
(Sudhakar: Why am I able to update your comment? )
cuDNN does return Max ([b, h, sq, 1]) directly; it's an additional softmax statistic (apparently the subset (Stats, Max) is enough for the cuDNN backward pass, rather than the full set (Stats, SumExp, Max)).
Further, for muon, we need to do amax on it to get a tensor of shape [h]. return_max_logit in TE controls whether to fetch Max from cuDNN.
Perf-wise, it'd be nice for cuDNN to do the additional reduction and return the [h]-shaped tensor for muon as well, but that's out of scope for this PR.
(Kshitij: looks like I can as well)
…eturn_stats_max_cudnn
Force-pushed from 21ca43a to becc3ad
Force-pushed from d4568db to 8f40cab
Description
cuDNN recently made returning any subset of {Stats, SumExp, Max} possible. This PR adapts TE to always get `Stats` from cuDNN, plus the `Max` tensor if `return_max_logit=True`. (Note that `Stats = log(SumExp) + Max`.)

Type of change
Changes
Please list the changes introduced in this PR:
- `fused_attn_f16_arbitrary_seqlen.cu`
  - Removed the `SumExp` tensor, as it's not needed since cuDNN returns `Stats` by default.
  - Set `generate_stats=True`, which forces cuDNN to always return the `Stats` tensor (needed in the backward pass).
- `transformer_engine/pytorch/cpp_extensions/fused_attn.py`
  - Removed the recomputation `Stats = log(SumExp) + Max`, since cuDNN returns `Stats` directly and TE doesn't need `SumExp` from cuDNN.

Checklist: