diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 61c478b03c4f..56604ed39a62 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1111,6 +1111,30 @@ def _sage_attention_backward_op(
     raise NotImplementedError("Backward pass is not implemented for Sage attention.")
 
 
+def _maybe_modify_attn_mask_npu(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None
+):
+    # Drop the attention mask if every position is unmasked (all values non-zero); a `None` mask speeds up the computation
+    if attn_mask is not None and torch.all(attn_mask != 0):
+        attn_mask = None
+
+    # Reshape the attention mask: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
+    # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
+    if (
+        attn_mask is not None
+        and attn_mask.ndim == 2
+        and attn_mask.shape[0] == query.shape[0]
+        and attn_mask.shape[1] == key.shape[1]
+    ):
+        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+        attn_mask = ~attn_mask.to(torch.bool)
+        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
+
+    return attn_mask
+
+
 def _npu_attention_forward_op(
     ctx: torch.autograd.function.FunctionCtx,
     query: torch.Tensor,
@@ -1126,13 +1150,16 @@ def _npu_attention_forward_op(
     _parallel_config: Optional["ParallelConfig"] = None,
 ):
     if return_lse:
-        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
+        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
+
+    attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
 
     out = npu_fusion_attention(
         query,
         key,
         value,
         query.size(2),  # num_heads
+        atten_mask=attn_mask,
         input_layout="BSND",
         pse=None,
         scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2421,16 +2448,17 @@ def _native_npu_attention(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    if attn_mask is not None:
-        raise ValueError("`attn_mask` is not supported for NPU attention")
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
     if _parallel_config is None:
+        attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
+
         out = npu_fusion_attention(
             query,
             key,
             value,
             query.size(2),  # num_heads
+            atten_mask=attn_mask,
             input_layout="BSND",
             pse=None,
             scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2445,7 +2473,7 @@ def _native_npu_attention(
             query,
             key,
             value,
-            None,
+            attn_mask,
             dropout_p,
             None,
             scale,
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index cf11d8e01fb4..07701db60c80 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -164,7 +164,11 @@ def compute_text_seq_len_from_mask(
     position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
     active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
    has_active = encoder_hidden_states_mask.any(dim=1)
-    per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
+    per_sample_len = torch.where(
+        has_active,
+        active_positions.max(dim=1).values + 1,
+        torch.as_tensor(text_seq_len, device=encoder_hidden_states.device)
+    )
     return text_seq_len, per_sample_len, encoder_hidden_states_mask
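Reviewer note (not part of the patch): the sketch below mirrors what the new `_maybe_modify_attn_mask_npu` helper in `attention_dispatch.py` is meant to do, so the shape/dtype contract can be checked on CPU without Ascend hardware. It assumes the "BSND" layout used by the backend and that `atten_mask` expects a boolean `[B, 1, Sq, Skv]` tensor where `True` marks masked positions (consistent with the inversion in the diff); the function name and sample shapes here are illustrative only.

```python
from typing import Optional

import torch


def maybe_modify_attn_mask_npu_sketch(
    query: torch.Tensor,   # [B, Sq, N, D] ("BSND" layout)
    key: torch.Tensor,     # [B, Skv, N, D]
    attn_mask: Optional[torch.Tensor] = None,  # [B, Skv]; 1/True = keep token
) -> Optional[torch.Tensor]:
    # A fully "keep" mask carries no information; returning None lets the
    # fused kernel skip the masking path entirely.
    if attn_mask is not None and torch.all(attn_mask != 0):
        return None

    # Broadcast a per-key padding mask to [B, 1, Sq, Skv], inverting it so
    # True marks positions that must be masked out.
    if (
        attn_mask is not None
        and attn_mask.ndim == 2
        and attn_mask.shape[0] == query.shape[0]
        and attn_mask.shape[1] == key.shape[1]
    ):
        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
        attn_mask = ~attn_mask.to(torch.bool)
        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
    return attn_mask


if __name__ == "__main__":
    q = torch.randn(2, 4, 8, 64)  # B=2, Sq=4, heads=8, head_dim=64
    k = torch.randn(2, 5, 8, 64)  # Skv=5
    pad_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

    out = maybe_modify_attn_mask_npu_sketch(q, k, pad_mask)
    assert out.shape == (2, 1, 4, 5) and out.dtype == torch.bool
    assert out[0, 0, 0].tolist() == [False, False, False, True, True]

    # An all-ones mask collapses to None (fast path in the fused kernel).
    assert maybe_modify_attn_mask_npu_sketch(q, k, torch.ones(2, 5)) is None
```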
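Reviewer note (not part of the patch): the `transformer_qwenimage.py` change pins the `torch.as_tensor(text_seq_len)` fallback to the mask's device. Below is a minimal sketch of the fixed computation; the `device` selection and sample mask values are assumptions for illustration (the PR itself targets Ascend NPU), and the point is simply that all three `torch.where` operands end up on the same device.

```python
import torch

# Stand-in device for illustration; the PR targets Ascend NPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

text_seq_len = 6
# True = active text token; the second sample is fully padded.
encoder_hidden_states_mask = torch.tensor(
    [[1, 1, 1, 1, 0, 0], [0, 0, 0, 0, 0, 0]], dtype=torch.bool, device=device
)

position_ids = torch.arange(text_seq_len, device=device, dtype=torch.long)
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
has_active = encoder_hidden_states_mask.any(dim=1)

# Fixed form: the scalar fallback is created on the same device as the other
# operands, so torch.where never mixes devices.
per_sample_len = torch.where(
    has_active,
    active_positions.max(dim=1).values + 1,
    torch.as_tensor(text_seq_len, device=device),
)
print(per_sample_len.tolist())  # [4, 6]
```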