36 changes: 32 additions & 4 deletions src/diffusers/models/attention_dispatch.py
@@ -1111,6 +1111,30 @@ def _sage_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Sage attention.")


def _maybe_modify_attn_mask_npu(
query: torch.Tensor,
key: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None
):
# Drop the attention mask when every value is nonzero; passing `None` lets the fused kernel skip masking and speeds up the computation
if attn_mask is not None and torch.all(attn_mask != 0):
attn_mask = None

# Reshape the attention mask: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
# https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
if (
attn_mask is not None
and attn_mask.ndim == 2
and attn_mask.shape[0] == query.shape[0]
and attn_mask.shape[1] == key.shape[1]
):
B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
attn_mask = ~attn_mask.to(torch.bool)  # invert: the incoming mask marks valid tokens, `atten_mask` marks positions to exclude (see doc link above)
attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()

return attn_mask
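
A small standalone sketch of what `_maybe_modify_attn_mask_npu` does to a `[batch_size, seq_len_k]` padding mask (toy shapes; the transform is re-created inline here for illustration and is not part of the diff):

```python
import torch

B, S_q, S_kv = 2, 4, 6  # toy batch size, query length, key/value length

# A [B, S_kv] padding mask where nonzero marks valid key tokens and 0 marks padding.
attn_mask = torch.ones(B, S_kv, dtype=torch.long)
attn_mask[:, -2:] = 0  # last two key positions are padding

# Same transform as the helper: invert to "True = masked", then broadcast to [B, 1, S_q, S_kv].
npu_mask = (~attn_mask.to(torch.bool)).unsqueeze(1).expand(B, S_q, S_kv).unsqueeze(1).contiguous()
print(npu_mask.shape)     # torch.Size([2, 1, 4, 6])
print(npu_mask[0, 0, 0])  # tensor([False, False, False, False,  True,  True])

# An all-ones mask would instead be dropped (returned as None) so the fused kernel skips masking.
full_mask = torch.ones(B, S_kv, dtype=torch.long)
print(torch.all(full_mask != 0))  # tensor(True)
```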


def _npu_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1126,13 +1150,16 @@ def _npu_attention_forward_op(
_parallel_config: Optional["ParallelConfig"] = None,
):
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")

attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)

out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
atten_mask=attn_mask,
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2421,16 +2448,17 @@ def _native_npu_attention(
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for NPU attention")
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
if _parallel_config is None:
attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)

out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
atten_mask=attn_mask,
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2445,7 +2473,7 @@
query,
key,
value,
None,
attn_mask,
dropout_p,
None,
scale,
6 changes: 5 additions & 1 deletion src/diffusers/models/transformers/transformer_qwenimage.py
@@ -164,7 +164,11 @@ def compute_text_seq_len_from_mask(
position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
has_active = encoder_hidden_states_mask.any(dim=1)
per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
per_sample_len = torch.where(
has_active,
active_positions.max(dim=1).values + 1,
torch.as_tensor(text_seq_len, device=encoder_hidden_states.device)
)
Comment on lines -167 to +171

Member:
Seems like an unrelated change? If so, could you undo it?

Contributor Author:
This change is to fix #13015.

return text_seq_len, per_sample_len, encoder_hidden_states_mask
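
For context on the #13015 fix discussed above, a minimal sketch of the fixed pattern, assuming the issue is a device mismatch between the mask tensors and the CPU scalar fallback (toy mask and the `device` variable below are illustrative only):

```python
import torch

# Stand-in for whatever accelerator the model runs on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text_seq_len = 8
encoder_hidden_states_mask = torch.tensor(
    [[1, 1, 1, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.bool, device=device
)
position_ids = torch.arange(text_seq_len, device=device, dtype=torch.long)
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
has_active = encoder_hidden_states_mask.any(dim=1)

# Keeping the fallback scalar on the same device as the mask avoids the
# device-mismatch error torch.where can raise on non-CPU backends.
per_sample_len = torch.where(
    has_active,
    active_positions.max(dim=1).values + 1,
    torch.as_tensor(text_seq_len, device=device),
)
print(per_sample_len)  # tensor([3, 8]): rows with no active token fall back to the full length
```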

