[Bug] Fix QwenImageEditPlus Series on NPU #13017
Conversation
sayakpaul left a comment
Thanks, left some comments.
```python
if (
    attn_mask is not None
    and torch.all(attn_mask != 0).item()
):
```
```diff
-if (
-    attn_mask is not None
-    and torch.all(attn_mask != 0).item()
-):
+if attn_mask is not None and torch.all(attn_mask != 0):
```
Won't it work?
> Won't it work?
diffusers/src/diffusers/models/attention_dispatch.py
Lines 1131 to 1136 in 5c92a77
```python
# Skip Attention Mask if all values are 1, `None` mask can speedup the computation
if (
    attn_mask is not None
    and torch.all(attn_mask != 0).item()
):
    attn_mask = None
```
Thanks for the reply!
Since NPU FA does not support the `[B, seq_len_kv]` mask shape passed by QwenImageEditPlus, and the unsqueeze/expand steps needed to adapt it slow down execution, I added logic to bypass them when the mask is all 1s. This noticeably improves speed under context parallelism, as shown in the test results below (a short illustrative sketch of the two paths follows the table):
| Stage | Cards | End-to-End Time (s) |
|---|---|---|
| Skip expand mask (set to `None`) | 1 | 108.22 |
| Skip expand mask (set to `None`) | 4 | 49.83 |
| Expand mask | 1 | 108.62 |
| Expand mask | 4 | 57.74 |
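For context, the two rows correspond roughly to the following paths. This is an illustrative sketch only, not the exact diff; `prepare_mask_for_npu_fa` is a made-up name:

```python
import torch


def prepare_mask_for_npu_fa(attn_mask, Sq, skip_all_ones=True):
    """Illustrative only: mirrors the idea of the change, not the actual helper."""
    if attn_mask is None:
        return None
    # "Skip expand mask": an all-ones mask carries no information, so pass None
    # and let the NPU fused-attention kernel take its unmasked fast path.
    if skip_all_ones and torch.all(attn_mask != 0):
        return None
    # "Expand mask": broadcast [B, Skv] -> [B, 1, Sq, Skv], inverted to bool,
    # matching the reshape done in this PR's helper, before calling the kernel.
    B, Skv = attn_mask.shape
    inverted = ~attn_mask.to(torch.bool)
    return inverted.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
```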
That's fine. I am asking if this condition would work (i.e., no `.item()`):

```python
if attn_mask is not None and torch.all(attn_mask != 0):
```
Thanks, that worked. I've removed `.item()` and pushed the update.
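(Quick sanity check, not part of the PR: a 0-dim tensor used in a boolean context goes through `Tensor.__bool__`, which gives the same result `.item()` did here.)

```python
import torch

mask = torch.ones(2, 8)
cond = torch.all(mask != 0)  # 0-dim bool tensor
# Using the tensor in a boolean context is equivalent to calling .item() here.
assert bool(cond) == cond.item()
if cond:  # works without .item()
    print("all ones -> mask can be dropped")
```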
```python
# Skip Attention Mask if all values are 1, `None` mask can speedup the computation
if (
    attn_mask is not None
    and torch.all(attn_mask != 0).item()
```
Same as above.
```diff
-per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
+per_sample_len = torch.where(
+    has_active,
+    active_positions.max(dim=1).values + 1,
+    torch.as_tensor(text_seq_len, device=encoder_hidden_states.device)
+)
```
Seems like an unrelated change? If so, could you undo it?
This change is to fix #13015.
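Concretely, `torch.as_tensor(text_seq_len)` with no `device=` lands on the CPU while the other `torch.where` operands live on the accelerator, so the change creates the fallback directly on `encoder_hidden_states.device`. A minimal sketch with made-up values (CPU stands in for the NPU device):

```python
import torch

device = torch.device("cpu")  # stands in for the NPU device in this illustration
text_seq_len = 4
has_active = torch.tensor([True, False], device=device)
active_positions = torch.tensor([[0, 1, 3, 0], [0, 0, 0, 0]], device=device)

per_sample_len = torch.where(
    has_active,
    active_positions.max(dim=1).values + 1,
    # Creating the fallback on the target device keeps all torch.where
    # operands on the same device.
    torch.as_tensor(text_seq_len, device=device),
)
print(per_sample_len)  # tensor([4, 4])
```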
```python
if (
    attn_mask is not None
    and attn_mask.ndim == 2
    and attn_mask.shape[0] == query.shape[0]
    and attn_mask.shape[1] == key.shape[1]
):
    B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
    attn_mask = ~attn_mask.to(torch.bool)
    attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
```
Would it make sense to have a small utility named `_maybe_modify_attn_mask_npu()` so that it can be reused in the two places (here and above)?
Thanks for your suggestion! I've added the `_maybe_modify_attn_mask_npu()` helper:
diffusers/src/diffusers/models/attention_dispatch.py
Lines 1114 to 1135 in 020a232
```python
def _maybe_modify_attn_mask_npu(
    query: torch.Tensor,
    key: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None
):
    # Skip Attention Mask if all values are 1, `None` mask can speedup the computation
    if (attn_mask is not None and torch.all(attn_mask != 0)):
        attn_mask = None
    # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k]
    # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
    if (
        attn_mask is not None
        and attn_mask.ndim == 2
        and attn_mask.shape[0] == query.shape[0]
        and attn_mask.shape[1] == key.shape[1]
    ):
        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
        attn_mask = ~attn_mask.to(torch.bool)
        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
    return attn_mask
```
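A quick shape check of the helper (hypothetical shapes; assumes `_maybe_modify_attn_mask_npu` from the snippet above is in scope):

```python
import torch

B, Sq, Skv = 2, 16, 24
query = torch.randn(B, Sq, 64)  # only shape[0] / shape[1] are read by the helper
key = torch.randn(B, Skv, 64)

# A [B, Skv] mask with real padding is inverted to bool and broadcast for NPU FA.
attn_mask = torch.ones(B, Skv)
attn_mask[0, -4:] = 0
out = _maybe_modify_attn_mask_npu(query, key, attn_mask)
print(out.shape, out.dtype)  # torch.Size([2, 1, 16, 24]) torch.bool

# An all-ones mask is dropped entirely so the kernel can run unmasked.
print(_maybe_modify_attn_mask_npu(query, key, torch.ones(B, Skv)))  # None
```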
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
What does this PR do?
Fixes: #13015
Fixes: #13016
Test code:
Results are in the comments below.
Who can review?
cc @yiyixuxu @sayakpaul @asomoza @DN6