From b103f42cedb70a85dd9321e17c74c26ad775e6fb Mon Sep 17 00:00:00 2001
From: zhangtao
Date: Wed, 21 Jan 2026 15:44:11 +0000
Subject: [PATCH 1/7] [Bug Fix][Qwen-Image-Edit] Fix Qwen-Image-Edit series on NPU

---
 src/diffusers/models/attention_dispatch.py       | 13 +++++++++++--
 .../models/transformers/transformer_qwenimage.py |  6 +++++-
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 61c478b03c4f..26d950935fa0 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1133,6 +1133,7 @@ def _npu_attention_forward_op(
         key,
         value,
         query.size(2),  # num_heads
+        atten_mask=attn_mask,
         input_layout="BSND",
         pse=None,
         scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2422,7 +2423,14 @@ def _native_npu_attention(
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
     if attn_mask is not None:
-        raise ValueError("`attn_mask` is not supported for NPU attention")
+        q_seqlen, kv_seqlen = query.size(-2), key.size(-2)
+        if 0 not in attn_mask:
+            attn_mask = None
+        elif attn_mask.dim() not in [2, 4] or attn_mask.size(-2) != q_seqlen or attn_mask.size(-1) != kv_seqlen:
+            raise ValueError("The attn_mask must be a 2D tensor with shape [q_seqlen, kv_seqlen],"
+                             " or a 4D tensor with shape [batch_size, num_heads, q_seqlen, kv_seqlen]")
+        else:
+            attn_mask = ~attn_mask.to(torch.bool)
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
     if _parallel_config is None:
@@ -2431,6 +2439,7 @@ def _native_npu_attention(
             key,
             value,
             query.size(2),  # num_heads
+            atten_mask=attn_mask,
             input_layout="BSND",
             pse=None,
             scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2445,7 +2454,7 @@ def _native_npu_attention(
             query,
             key,
             value,
-            None,
+            attn_mask,
             dropout_p,
             None,
             scale,
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index cf11d8e01fb4..07701db60c80 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -164,7 +164,11 @@ def compute_text_seq_len_from_mask(
     position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
     active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
     has_active = encoder_hidden_states_mask.any(dim=1)
-    per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
+    per_sample_len = torch.where(
+        has_active,
+        active_positions.max(dim=1).values + 1,
+        torch.as_tensor(text_seq_len, device=encoder_hidden_states.device)
+    )
     return text_seq_len, per_sample_len, encoder_hidden_states_mask
 
 
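The following is a quick CPU-only sketch of the `compute_text_seq_len_from_mask` arithmetic touched by the last hunk of PATCH 1 (not part of the patch; the toy mask values are made up). The replacement scalar handed to `torch.where` has to live on the same device as the mask, which is what the added `device=encoder_hidden_states.device` guarantees; without it, the bare `torch.as_tensor(text_seq_len)` is created on the CPU and trips a device-mismatch error on NPU:

    import torch

    # Toy mask: sample 0 has 3 valid text tokens, sample 1 uses all 5.
    encoder_hidden_states_mask = torch.tensor(
        [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool
    )
    text_seq_len = encoder_hidden_states_mask.shape[1]

    position_ids = torch.arange(text_seq_len, device=encoder_hidden_states_mask.device)
    active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
    has_active = encoder_hidden_states_mask.any(dim=1)
    per_sample_len = torch.where(
        has_active,
        active_positions.max(dim=1).values + 1,
        # The fallback scalar must sit on the same device as the mask
        # (in the real code the device comes from encoder_hidden_states).
        torch.as_tensor(text_seq_len, device=encoder_hidden_states_mask.device),
    )
    print(per_sample_len)  # tensor([3, 5])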
From 3ed2a75f0a1a9fc710567f668b4c421f9fe5b5a2 Mon Sep 17 00:00:00 2001
From: zhangtao
Date: Wed, 21 Jan 2026 16:12:34 +0000
Subject: [PATCH 2/7] Enhance NPU attention handling by converting attention mask to boolean and refining mask checks.

---
 src/diffusers/models/attention_dispatch.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 26d950935fa0..8b6b373cdda3 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -2423,14 +2423,17 @@ def _native_npu_attention(
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
     if attn_mask is not None:
+        # https://www.hiascend.com/document/detail/zh/Pytorch/730/ptmoddevg/trainingmigrguide/performance_tuning_0034.html
+        attn_mask = attn_mask.bool()
         q_seqlen, kv_seqlen = query.size(-2), key.size(-2)
-        if 0 not in attn_mask:
+
+        if attn_mask.all().item():
             attn_mask = None
         elif attn_mask.dim() not in [2, 4] or attn_mask.size(-2) != q_seqlen or attn_mask.size(-1) != kv_seqlen:
             raise ValueError("The attn_mask must be a 2D tensor with shape [q_seqlen, kv_seqlen],"
                              " or a 4D tensor with shape [batch_size, num_heads, q_seqlen, kv_seqlen]")
         else:
-            attn_mask = ~attn_mask.to(torch.bool)
+            attn_mask = ~attn_mask
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
     if _parallel_config is None:

From 50055640dd583a84acaed708cffcd2f0ff6c27f5 Mon Sep 17 00:00:00 2001
From: zhangtao
Date: Wed, 21 Jan 2026 16:55:53 +0000
Subject: [PATCH 3/7] Refine attention mask handling in NPU attention function to improve validation and conversion logic.

---
 src/diffusers/models/attention_dispatch.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 8b6b373cdda3..4f6eb8af3aad 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -2422,18 +2422,19 @@ def _native_npu_attention(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    # attn_mask = None
     if attn_mask is not None:
         # https://www.hiascend.com/document/detail/zh/Pytorch/730/ptmoddevg/trainingmigrguide/performance_tuning_0034.html
-        attn_mask = attn_mask.bool()
         q_seqlen, kv_seqlen = query.size(-2), key.size(-2)
 
-        if attn_mask.all().item():
-            attn_mask = None
-        elif attn_mask.dim() not in [2, 4] or attn_mask.size(-2) != q_seqlen or attn_mask.size(-1) != kv_seqlen:
-            raise ValueError("The attn_mask must be a 2D tensor with shape [q_seqlen, kv_seqlen],"
-                             " or a 4D tensor with shape [batch_size, num_heads, q_seqlen, kv_seqlen]")
+        if attn_mask.dim() not in [2, 4] or attn_mask.size(-2) != q_seqlen or attn_mask.size(-1) != kv_seqlen:
+            if torch.all(attn_mask != 0).item():
+                attn_mask = None
+            else:
+                raise ValueError("The attn_mask must be a 2D tensor with shape [q_seqlen, kv_seqlen],"
+                                 " or a 4D tensor with shape [batch_size, num_heads, q_seqlen, kv_seqlen]")
         else:
-            attn_mask = ~attn_mask
+            attn_mask = ~attn_mask.to(torch.bool)
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
     if _parallel_config is None:
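For reference, a stand-alone sketch of the mask handling as it stands after PATCH 3 (plain PyTorch; `_demo_validate_mask` is an illustrative name, not a function from the patch): a mask with an unexpected shape is tolerated as long as it is all ones, because it then carries no masking information, while an informative mask of the wrong shape still raises, and a well-formed mask is inverted for the NPU kernel.

    import torch

    def _demo_validate_mask(attn_mask, q_seqlen, kv_seqlen):
        # Same control flow as the `if attn_mask is not None:` block after PATCH 3.
        if attn_mask.dim() not in [2, 4] or attn_mask.size(-2) != q_seqlen or attn_mask.size(-1) != kv_seqlen:
            if torch.all(attn_mask != 0).item():
                # Unexpected shape, but the mask is all ones: it carries no information.
                return None
            raise ValueError("unsupported attn_mask shape")
        # Invert: the NPU fused kernel expects True at positions to be masked out.
        return ~attn_mask.to(torch.bool)

    q_len, kv_len = 4, 6
    print(_demo_validate_mask(torch.ones(2, kv_len), q_len, kv_len))      # None (wrong shape, all ones)
    print(_demo_validate_mask(torch.ones(q_len, kv_len), q_len, kv_len))  # all-False boolean mask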
From e042b0d177685d80d78cb9fb41825bf13349e46e Mon Sep 17 00:00:00 2001
From: zhangtao
Date: Thu, 22 Jan 2026 01:25:41 +0000
Subject: [PATCH 4/7] Clean Code

---
 src/diffusers/models/attention_dispatch.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 4f6eb8af3aad..a132978241e4 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -2422,7 +2422,6 @@ def _native_npu_attention(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    # attn_mask = None
     if attn_mask is not None:
         # https://www.hiascend.com/document/detail/zh/Pytorch/730/ptmoddevg/trainingmigrguide/performance_tuning_0034.html
         q_seqlen, kv_seqlen = query.size(-2), key.size(-2)

From 5c92a7762dba7e7a3c425a3b237af8b7773103e4 Mon Sep 17 00:00:00 2001
From: zhangtao
Date: Thu, 22 Jan 2026 09:21:57 +0000
Subject: [PATCH 5/7] Refine attention mask processing in NPU attention functions to enhance performance and validation.

---
 src/diffusers/models/attention_dispatch.py | 50 ++++++++++++++++------
 1 file changed, 38 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index a132978241e4..39dddce52dc1 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1127,6 +1127,25 @@ def _npu_attention_forward_op(
 ):
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
+
+    # Skip Attention Mask if all values are 1, `None` mask can speed up the computation
+    if (
+        attn_mask is not None
+        and torch.all(attn_mask != 0).item()
+    ):
+        attn_mask = None
+
+    # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
+    # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
+    if (
+        attn_mask is not None
+        and attn_mask.ndim == 2
+        and attn_mask.shape[0] == query.shape[0]
+        and attn_mask.shape[1] == key.shape[1]
+    ):
+        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+        attn_mask = ~attn_mask.to(torch.bool)
+        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
 
     out = npu_fusion_attention(
         query,
@@ -2422,21 +2441,28 @@ def _native_npu_attention(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    if attn_mask is not None:
-        # https://www.hiascend.com/document/detail/zh/Pytorch/730/ptmoddevg/trainingmigrguide/performance_tuning_0034.html
-        q_seqlen, kv_seqlen = query.size(-2), key.size(-2)
-
-        if attn_mask.dim() not in [2, 4] or attn_mask.size(-2) != q_seqlen or attn_mask.size(-1) != kv_seqlen:
-            if torch.all(attn_mask != 0).item():
-                attn_mask = None
-            else:
-                raise ValueError("The attn_mask must be a 2D tensor with shape [q_seqlen, kv_seqlen],"
-                                 " or a 4D tensor with shape [batch_size, num_heads, q_seqlen, kv_seqlen]")
-        else:
-            attn_mask = ~attn_mask.to(torch.bool)
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
     if _parallel_config is None:
+        # Skip Attention Mask if all values are 1, `None` mask can speed up the computation
+        if (
+            attn_mask is not None
+            and torch.all(attn_mask != 0).item()
+        ):
+            attn_mask = None
+
+        # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
+        # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
+        if (
+            attn_mask is not None
+            and attn_mask.ndim == 2
+            and attn_mask.shape[0] == query.shape[0]
+            and attn_mask.shape[1] == key.shape[1]
+        ):
+            B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+            attn_mask = ~attn_mask.to(torch.bool)
+            attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
+
         out = npu_fusion_attention(
             query,
             key,
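The reshape introduced in PATCH 5 can be exercised with plain PyTorch; this sketch (not part of the patch, shapes are made up) shows a `[batch_size, seq_len_k]` key-padding mask being inverted and broadcast to the `[batch_size, 1, seq_len_q, seq_len_k]` layout that the fused NPU kernel expects for `atten_mask` with the BSND input layout:

    import torch

    batch_size, seq_len_q, seq_len_k, num_heads, head_dim = 2, 4, 6, 8, 16
    query = torch.randn(batch_size, seq_len_q, num_heads, head_dim)  # BSND layout
    key = torch.randn(batch_size, seq_len_k, num_heads, head_dim)

    # Key-padding mask: 1 = valid token, 0 = padding (sample 1 has two padded keys).
    attn_mask = torch.ones(batch_size, seq_len_k, dtype=torch.bool)
    attn_mask[1, 4:] = False

    B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
    atten_mask = ~attn_mask  # True = mask this position out
    atten_mask = atten_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()

    assert atten_mask.shape == (batch_size, 1, seq_len_q, seq_len_k)
    assert atten_mask[1, 0, :, 4:].all() and not atten_mask[0].any()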
From 8abfddd571c2fe87581d490f46e722a1b77b4372 Mon Sep 17 00:00:00 2001
From: zhangtao
Date: Thu, 22 Jan 2026 13:22:22 +0000
Subject: [PATCH 6/7] Remove item() ops in the NPU FA backend.

---
 src/diffusers/models/attention_dispatch.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 39dddce52dc1..1851c42fd0dc 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1129,10 +1129,7 @@ def _npu_attention_forward_op(
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
 
     # Skip Attention Mask if all values are 1, `None` mask can speed up the computation
-    if (
-        attn_mask is not None
-        and torch.all(attn_mask != 0).item()
-    ):
+    if (attn_mask is not None and torch.all(attn_mask != 0)):
         attn_mask = None
 
     # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
@@ -2445,10 +2442,7 @@ def _native_npu_attention(
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
     if _parallel_config is None:
         # Skip Attention Mask if all values are 1, `None` mask can speed up the computation
-        if (
-            attn_mask is not None
-            and torch.all(attn_mask != 0).item()
-        ):
+        if (attn_mask is not None and torch.all(attn_mask != 0)):
             attn_mask = None
 
         # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
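A small CPU-only check of the property PATCH 6 relies on: `torch.all(x != 0)` returns a zero-dim boolean tensor whose Python truth value matches the old `.item()` call, so the branch taken is unchanged. Whether dropping `.item()` also avoids an explicit device-to-host synchronization is the motivation stated in the commit message and is not verified here:

    import torch

    mask = torch.tensor([[1, 1, 0], [1, 1, 1]])

    with_item = bool(torch.all(mask != 0).item())  # old form: explicit host read
    without_item = bool(torch.all(mask != 0))      # new form: 0-dim tensor, same truth value
    assert with_item == without_item
    assert not with_item                           # the mask contains a zero

    all_ones = torch.ones(2, 3)
    assert bool(torch.all(all_ones != 0))          # an all-ones mask would be dropped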
From 020a23275939ffd5e5242cace159a1087b3b4ff9 Mon Sep 17 00:00:00 2001
From: zhangtao
Date: Thu, 22 Jan 2026 16:26:24 +0000
Subject: [PATCH 7/7] Reuse NPU attention mask handling via `_maybe_modify_attn_mask_npu`

---
 src/diffusers/models/attention_dispatch.py | 54 ++++++++++++------------
 1 file changed, 25 insertions(+), 29 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 1851c42fd0dc..56604ed39a62 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1111,23 +1111,11 @@ def _sage_attention_backward_op(
     raise NotImplementedError("Backward pass is not implemented for Sage attention.")
 
 
-def _npu_attention_forward_op(
-    ctx: torch.autograd.function.FunctionCtx,
+def _maybe_modify_attn_mask_npu(
     query: torch.Tensor,
     key: torch.Tensor,
-    value: torch.Tensor,
-    attn_mask: Optional[torch.Tensor] = None,
-    dropout_p: float = 0.0,
-    is_causal: bool = False,
-    scale: Optional[float] = None,
-    enable_gqa: bool = False,
-    return_lse: bool = False,
-    _save_ctx: bool = True,
-    _parallel_config: Optional["ParallelConfig"] = None,
+    attn_mask: Optional[torch.Tensor] = None
 ):
-    if return_lse:
-        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
-
     # Skip Attention Mask if all values are 1, `None` mask can speed up the computation
     if (attn_mask is not None and torch.all(attn_mask != 0)):
         attn_mask = None
@@ -1143,6 +1131,28 @@ def _npu_attention_forward_op(
         B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
         attn_mask = ~attn_mask.to(torch.bool)
         attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
+
+    return attn_mask
+
+
+def _npu_attention_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    if return_lse:
+        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
+
+    attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
 
     out = npu_fusion_attention(
         query,
@@ -2441,21 +2451,7 @@ def _native_npu_attention(
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
     if _parallel_config is None:
-        # Skip Attention Mask if all values are 1, `None` mask can speed up the computation
-        if (attn_mask is not None and torch.all(attn_mask != 0)):
-            attn_mask = None
-
-        # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
-        # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
-        if (
-            attn_mask is not None
-            and attn_mask.ndim == 2
-            and attn_mask.shape[0] == query.shape[0]
-            and attn_mask.shape[1] == key.shape[1]
-        ):
-            B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
-            attn_mask = ~attn_mask.to(torch.bool)
-            attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
+        attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
 
         out = npu_fusion_attention(
             query,
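Finally, a CPU-only sanity check (a sketch, independent of the patches and of `torch_npu`) that the inverted, broadcast mask convention used by `_maybe_modify_attn_mask_npu` encodes the same attention pattern as an ordinary SDPA keep-mask: applying it as an additive `-inf` mask reproduces the reference output.

    import torch
    import torch.nn.functional as F

    batch, heads, seq_q, seq_k, dim = 2, 3, 4, 6, 8
    q = torch.randn(batch, heads, seq_q, dim)
    k = torch.randn(batch, heads, seq_k, dim)
    v = torch.randn(batch, heads, seq_k, dim)

    keep = torch.ones(batch, seq_k, dtype=torch.bool)  # True = may attend
    keep[1, 4:] = False                                 # padded keys in sample 1
    sdpa_mask = keep[:, None, None, :]                  # [B, 1, 1, S_k]

    # Same construction as the helper above: invert and broadcast to 4D.
    npu_style = (~keep).unsqueeze(1).expand(batch, seq_q, seq_k).unsqueeze(1)
    additive = torch.zeros(batch, 1, seq_q, seq_k).masked_fill(npu_style, float("-inf"))

    ref = F.scaled_dot_product_attention(q, k, v, attn_mask=sdpa_mask)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive)
    torch.testing.assert_close(ref, out)

This only checks the mask layout and inversion semantics; the numerical behaviour of the fused Ascend kernel itself is not exercised here.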