Skip to content

Comments

Fix Flash Attention 3 API compatibility for window size parameters#2704

Open
jhvmhg wants to merge 7 commits intoNVIDIA:mainfrom
jhvmhg:fix/flash_attn3_support_CP
Open

Fix Flash Attention 3 API compatibility for window size parameters#2704
jhvmhg wants to merge 7 commits intoNVIDIA:mainfrom
jhvmhg:fix/flash_attn3_support_CP

Conversation

@jhvmhg
Copy link

@jhvmhg jhvmhg commented Feb 25, 2026

Replace single window_size parameter with window_size_left and window_size_right in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.

  • Update function signature in flash_attn_interface
  • Maintain backward compatibility where possible
  • Ensure consistency with Flash Attention v2 implementation

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

  1. Fix window size parameters in flash_attn_fwd - Replaces the single window_size parameter with separate window_size_left and window_size_right parameters to match the updated flash-attn v2.7.0+ API.
  2. Fix causal parameter naming in flash_attn_bwd - Renames causal to is_causal in the backward function signature for consistency with the latest flash-attn interface.

Motivation:

The flash-attn library v2.7.0+ introduced breaking API changes that cause compatibility issues with TransformerEngine's Flash Attention 3 integration. These updates ensure seamless operation with newer versions of the flash-attn library while maintaining correctness of both forward and backward attention computations.

Related API Changes:

flash-attn v2.7.0+ split window_size into window_size_left and window_size_right
flash-attn v3+ renamed causal parameter to is_causal in backward pass

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Replace single window_size parameter with window_size_left and window_size_right
    in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.
  • Rename causal parameter to is_causal in flash_attn_bwd function to align
    with flash-attn v3

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Replace single window_size parameter with window_size_left and window_size_right
in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.

- Update function signature in flash_attn_interface
- Maintain backward compatibility where possible
- Ensure consistency with Flash Attention v2 implementation

Signed-off-by: Chaoyang Mei <1192554423@qq.com>
Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 25, 2026

Greptile Summary

Updates Flash Attention API compatibility by replacing single window_size parameter with window_size_left and window_size_right for FA v2.7.0+ and FA3, and renaming causal to is_causal for FA3's backward pass.

Key changes:

  • Refactors conditional logic so FA3 and v2.7.0+ both use the new split window size parameters
  • Adds conditional is_causal vs causal parameter naming based on FA3 vs FA2
  • Maintains backward compatibility with FA v2.3+ to v2.6.x using single window_size tuple
  • Correctly uses "causal" in ctx.attn_mask_type to extract boolean from string mask type

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The changes are straightforward API compatibility updates that align with flash-attn library version updates. The conditional logic correctly separates v2.3-v2.6 (old API) from v2.7.0+ and FA3 (new API), and the parameter naming is handled appropriately for each version
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Updates Flash Attention 3 and v2.7.0+ API calls to use window_size_left/window_size_right parameters and is_causal parameter naming

Last reviewed commit: de8483f

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 25, 2026

Additional Comments (1)

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
removed causal parameter but other flash_attn_bwd calls in this file (lines 3222, 3832) still pass it - verify this inconsistency is intentional

Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v2.7.0+ API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@jhvmhg jhvmhg force-pushed the fix/flash_attn3_support_CP branch from a245229 to f9752ca Compare February 25, 2026 07:54
Copy link
Author

@jhvmhg jhvmhg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix Flash Attention 3 backward API parameter naming

Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v2.7.0+ API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

jhvmhg and others added 2 commits February 25, 2026 15:56
Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v2.7.0+ API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Replace keyword arguments with positional arguments in flash_attn_fwd and
flash_attn_bwd to abstract away parameter naming differences (causal vs
is_causal) between flash-attn versions. This provides a more robust
interface that is resilient to future API changes in the flash-attn library.

- Convert window_size_left, window_size_right, and causal parameters to
  positional args in both forward and backward functions
- Eliminate version-specific parameter naming dependencies
- Simplify compatibility handling across flash-attn v2.7.0+ variants

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

softmax_lse_per_step[i],
*fa_backward_args_thd,
causal="causal" in ctx.attn_mask_type,
ctx.attn_mask_type,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx.attn_mask_type is a string (e.g., "causal", "no_mask"), not a boolean. Should be "causal" in ctx.attn_mask_type to convert to boolean.

Suggested change
ctx.attn_mask_type,
"causal" in ctx.attn_mask_type,

@jhvmhg jhvmhg closed this Feb 25, 2026
Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v3 API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@jhvmhg jhvmhg reopened this Feb 25, 2026
@jhvmhg jhvmhg closed this Feb 25, 2026
@jhvmhg jhvmhg reopened this Feb 25, 2026
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant