Use lower-right causal mask alignment consistently #2967
base: main
Conversation
Clarify that MLX uses lower-right alignment for causal masks when T_q != T_kv, which differs from PyTorch's default upper-left alignment. Relates to ml-explore#2835
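For concreteness, here is a minimal sketch of the two alignments (plain NumPy with illustrative shapes; not part of the PR itself):

```python
import numpy as np

T_q, T_kv = 3, 5

# Upper-left alignment (PyTorch's default with is_causal=True):
# query i may attend to keys 0..i.
upper_left = np.tril(np.ones((T_q, T_kv), dtype=bool))

# Lower-right alignment (what mask="causal" is described as using here):
# the last query sees all keys; earlier queries are shifted back by T_kv - T_q.
lower_right = np.tril(np.ones((T_q, T_kv), dtype=bool), k=T_kv - T_q)

print(upper_left.astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]]
print(lower_right.astype(int))
# [[1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]
```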
zcbenz
left a comment
I don't think PyTorch has a causal_lower_right option for SDPA and the description is not really right.
|
Hey @zcbenz, it does have causal_lower_right since 2.3, and it can be used with SDPA via the attn_mask parameter. I ran a small script to verify this. Here is the tutorial that documents it explicitly: https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html. I also verified the masks are mathematically identical. For example, with T_q=2, T_kv=4, PyTorch's causal_lower_right mask and MLX's mask="causal" are identical, while the is_causal=True (upper-left) mask is different. This is also consistent with MLX's CUDA backend, which uses cuDNN's set_causal_mask_bottom_right. Is there something specific about the description you think is incorrect? If your concern is that causal_lower_right isn't a direct SDPA parameter (like is_causal=True) but rather a separate utility class, I could clarify the wording to use the full module path torch.nn.attention.bias.causal_lower_right. |
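Since the thread does not include the script itself, here is a hedged sketch of the kind of cross-check described above (assumes PyTorch >= 2.3 and MLX are installed; shapes, seeds, and tolerance are illustrative, not from the original script):

```python
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right
import mlx.core as mx

B, H, T_q, T_kv, D = 1, 2, 2, 4, 8
rng = np.random.default_rng(0)
q = rng.standard_normal((B, H, T_q, D)).astype(np.float32)
k = rng.standard_normal((B, H, T_kv, D)).astype(np.float32)
v = rng.standard_normal((B, H, T_kv, D)).astype(np.float32)

# PyTorch SDPA with a lower-right causal bias passed via attn_mask.
out_pt = F.scaled_dot_product_attention(
    torch.from_numpy(q), torch.from_numpy(k), torch.from_numpy(v),
    attn_mask=causal_lower_right(T_q, T_kv))

# MLX SDPA with mask="causal".
out_mx = mx.fast.scaled_dot_product_attention(
    mx.array(q), mx.array(k), mx.array(v), scale=D ** -0.5, mask="causal")

# Expected to agree if MLX aligns the causal mask to the lower right.
print(np.allclose(out_pt.numpy(), np.array(out_mx), atol=1e-5))
```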
|
Thanks for linking the docs, this is new to me. On the behavior, it actually depends on whether T_q is larger or smaller than T_kv: mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp, lines 204 to 208 (at 9052f67).
|
The mask uses lower-right alignment when T_q <= T_kv and upper-left when T_q > T_kv.
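One equivalent way to describe that behavior is a diagonal offset clamped at zero. A rough sketch (hypothetical Python, not the actual backend code):

```python
import numpy as np

def causal_mask_current(T_q, T_kv):
    # Offset clamped at zero: lower-right alignment when T_q <= T_kv,
    # but it degrades to upper-left alignment when T_q > T_kv.
    offset = max(T_kv - T_q, 0)
    return np.tril(np.ones((T_q, T_kv), dtype=bool), k=offset)

print(causal_mask_current(2, 4).astype(int))  # lower-right aligned
print(causal_mask_current(4, 2).astype(int))  # upper-left aligned
```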
|
Thanks! Fixed to describe the conditional alignment behavior 🙏 |
zcbenz
left a comment
Looks good to me. /cc @awni for a second look.
|
The comment definitely makes sense. But I also find it a bit strange that we switch from lower right to upper left depending on whether the query is longer or shorter than the keys. It's quite rare for the query to be longer than the keys, which is why we never really looked at it carefully. I'm wondering if we should change the behavior in that case rather than documenting something that is a bit unusual? Or maybe it's a good idea to keep it this way? |
|
I agree the current behavior is unusual, and using lower right for all cases would be a better choice. |
|
@Anri-Lombard what do you think about changing the behavior to always be lower right even when QL > KL? Do you want to send a patch to this PR, or open a new one instead? |
|
Hey @awni, always lower-right makes sense. The change is minimal (unless I'm missing something) - just two cuDNN locations (forward/backward) and the CPU fallback offset calculation. I'll update this PR to make the behavior change instead of just documenting it 👍 |
|
Yes, the change should be pretty straightforward. We may also need to update the mask index calculation in the Metal kernels. If you add a test for this case as well (qL > kL), that would be great. I can help with the Metal kernels if needed. |
- cuDNN: Always use set_causal_mask_bottom_right() instead of conditionally selecting based on qL vs kL. This aligns with FlashAttention/PyTorch behavior.
- Steel kernels: Add NaN protection for the sum_score == 0 edge case when all keys are masked.
Enable scaled_dot_product_attention to handle cases where the query sequence is longer than the key sequence with a causal mask. When qL > kL, early queries have no keys to attend to and output zeros. Changes:
- Remove the Metal routing guard that blocked qL > kL for causal mask
- Fix the CPU fallback to use proper lower-right alignment (not clamped)
- Zero out output rows where queries have no keys to attend to (row_pos < 0)
- Update test references to handle all-masked rows correctly
|
@awni and @zcbenz updated and took a stab at the Metal kernels as well - feel free to push changes directly, or point out where I deviated if you don't mind the extra time, so I can learn the convention preferences more 🙏 For qL > kL, early queries have no keys to attend to. Softmax over all-masked values gives uniform weights (exp(finite_min - finite_min) = 1), not zeros. Following the convention from pytorch/pytorch#108108, we explicitly zero these rows... I think this is the only "big" change. |
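To make the softmax point concrete, a tiny numeric check (NumPy; the row length is illustrative):

```python
import numpy as np

# A fully masked row of attention scores: every entry is the most negative
# finite float rather than -inf.
scores = np.full(4, np.finfo(np.float32).min, dtype=np.float32)

# Numerically stable softmax: exp(finite_min - finite_min) = exp(0) = 1 for
# each entry, so the weights come out uniform rather than zero.
weights = np.exp(scores - scores.max())
weights /= weights.sum()
print(weights)  # [0.25 0.25 0.25 0.25]
```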
|
@awni you mentioned the tests: the existing test shapes (127, 65, ...) with mask="causal" already cover the qL > kL case. Would you prefer an explicit test that verifies early queries output zeros? 🙏 |
|
Nope if it's already tested that is fine! |
|
I don't think we should ensure 0s in the qL > kL case. A problem I've looked at in the past is what to do when every key position for a given query is masked, and right now it's not consistent. For now let's leave it as undefined behavior and then look into a more principled fix if necessary. I also would rather not reduce performance overall to handle an edge case we don't really care much about. |
Per review feedback, leave qL > kL with causal mask as undefined behavior rather than ensuring zeros. This avoids performance overhead for an edge case. Tests skip this undefined case.
|
@awni done - removed the zero-row handling. The qL > kL + causal case is now undefined behavior as suggested. Tests skip that case. |
Summary
mask="causal"uses lower-right alignmentis_causal=True(upper-left)When
T_q != T_kv, this distinction matters:References:
Relates to #2835