Fix Flash Attention 3 API compatibility for window size parameters #2704
jhvmhg wants to merge 7 commits into NVIDIA:main
Replace single window_size parameter with window_size_left and window_size_right in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.
- Update function signature in flash_attn_interface
- Maintain backward compatibility where possible
- Ensure consistency with Flash Attention v2 implementation
Signed-off-by: Chaoyang Mei <1192554423@qq.com>
Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
Greptile Summary
Updates Flash Attention API compatibility by replacing single
Key changes:
Confidence Score: 5/5
Important Files Changed
Last reviewed commit: de8483f
Additional Comments (1)
Rename causal parameter to is_causal in flash_attn_bwd function to align with flash-attn v2.7.0+ API changes. This ensures consistency with the updated flash-attn library interface for backward pass operations. Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
a245229 to f9752ca
jhvmhg left a comment
Fix Flash Attention 3 backward API parameter naming
Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v2.7.0+ API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.
for more information, see https://pre-commit.ci
… fix/flash_attn3_support_CP
Replace keyword arguments with positional arguments in flash_attn_fwd and flash_attn_bwd to abstract away parameter naming differences (causal vs is_causal) between flash-attn versions. This provides a more robust interface that is resilient to future API changes in the flash-attn library.
- Convert window_size_left, window_size_right, and causal parameters to positional args in both forward and backward functions
- Eliminate version-specific parameter naming dependencies
- Simplify compatibility handling across flash-attn v2.7.0+ variants
Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
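The trade-off behind this commit can be illustrated with a minimal sketch. The functions `bwd_old_style` and `bwd_new_style` are hypothetical stand-ins, not the real flash-attn signatures; they are assumed to differ only in whether the last parameter is named `causal` or `is_causal`. The helpers `call_positionally` and `keyword_call_fails` are likewise invented for illustration.

```python
def bwd_old_style(dout, window_size_left, window_size_right, causal):
    """Hypothetical backward API that names the flag `causal`."""
    return (window_size_left, window_size_right, causal)


def bwd_new_style(dout, window_size_left, window_size_right, is_causal):
    """Hypothetical backward API that names the flag `is_causal`."""
    return (window_size_left, window_size_right, is_causal)


def call_positionally(bwd_fn, dout, window_size, attn_mask_type):
    """Call either variant without naming the renamed parameter."""
    left, right = window_size
    # The mask type is a string, so convert it to a boolean first.
    return bwd_fn(dout, left, right, "causal" in attn_mask_type)


def keyword_call_fails(bwd_fn):
    """Return True if passing `causal=` by keyword raises TypeError."""
    try:
        bwd_fn(None, -1, 0, causal=True)
    except TypeError:
        return True
    return False


# The same positional call site works with both variants.
results = [
    call_positionally(fn, None, (-1, 0), "causal")
    for fn in (bwd_old_style, bwd_new_style)
]
```

Positional calls only stay correct while the parameter order matches across versions. If a release reorders or inserts parameters, the values bind to the wrong names without raising an error, so this approach trades one kind of fragility for another.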
  softmax_lse_per_step[i],
  *fa_backward_args_thd,
- causal="causal" in ctx.attn_mask_type,
+ ctx.attn_mask_type,
ctx.attn_mask_type is a string (e.g., "causal", "no_mask"), not a boolean. Pass "causal" in ctx.attn_mask_type instead so that the argument is converted to a boolean.
Suggested change:
- ctx.attn_mask_type,
+ "causal" in ctx.attn_mask_type,
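Why the raw string is a problem can be shown with a short sketch. It assumes the callee evaluates the argument for truthiness; a strictly typed binding might reject the string outright instead. The helper `is_causal` and the list of mask-type strings are illustrative, not taken from the PR.

```python
def is_causal(attn_mask_type: str) -> bool:
    """Convert a mask-type string to the boolean the kernel flag expects."""
    return "causal" in attn_mask_type


# Illustrative mask-type strings (assumed values for this sketch).
mask_types = ["causal", "padding_causal", "no_mask"]

# Any non-empty string is truthy, so the raw string would read as "causal"
# even for "no_mask" if it were used directly as a flag.
truthiness = {m: bool(m) for m in mask_types}

# The substring test gives the intended answer for each mask type.
membership = {m: is_causal(m) for m in mask_types}
```

With the raw string, every mask type would enable causal masking under the truthiness assumption; the substring test distinguishes them.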
Rename causal parameter to is_causal in flash_attn_bwd function to align with flash-attn v3 API changes. This ensures consistency with the updated flash-attn library interface for backward pass operations. Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
Replace single window_size parameter with window_size_left and window_size_right in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Motivation:
The flash-attn library v2.7.0+ introduced breaking API changes that cause compatibility issues with TransformerEngine's Flash Attention 3 integration. These updates ensure seamless operation with newer versions of the flash-attn library while maintaining correctness of both forward and backward attention computations.
Related API Changes:
- flash-attn v2.7.0+ split window_size into window_size_left and window_size_right
- flash-attn v3+ renamed causal parameter to is_causal in backward pass
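One way to handle the window-size split at a call site is sketched below. The helper `window_kwargs` and its `takes_split_window` flag are hypothetical names introduced here; how a caller decides which form the installed library expects is outside this sketch.

```python
from typing import Any, Dict, Tuple


def window_kwargs(
    window_size: Tuple[int, int], takes_split_window: bool
) -> Dict[str, Any]:
    """Build keyword arguments for either window-size convention.

    `takes_split_window` is a hypothetical flag: True when the target
    function expects window_size_left / window_size_right, False when it
    expects a single window_size tuple.
    """
    if takes_split_window:
        left, right = window_size
        return {"window_size_left": left, "window_size_right": right}
    return {"window_size": window_size}


split_form = window_kwargs((-1, 0), takes_split_window=True)
tuple_form = window_kwargs((-1, 0), takes_split_window=False)
```

The resulting dictionary can be unpacked into the call with `**`, which keeps the version-specific naming in one place.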
Type of change
Changes
Please list the changes introduced in this PR:
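- Replace single window_size parameter with window_size_left and window_size_right in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.
- Rename causal parameter to is_causal in flash_attn_bwd function to align with flash-attn v3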
Checklist: