From eb382fa3b0870c49648a5327c07c2a322a47bfa9 Mon Sep 17 00:00:00 2001 From: David Mo Date: Thu, 15 Jan 2026 14:31:00 +0800 Subject: [PATCH 01/13] z-image support npu --- .../transformers/transformer_z_image.py | 62 ++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 5983c34ab640..5afa6bb5b49f 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention_dispatch import dispatch_attention_fn from ..modeling_outputs import Transformer2DModelOutput +from ...utils import is_torch_npu_available ADALN_EMBED_DIM = 256 @@ -311,6 +312,62 @@ def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): return x +class RopeEmbedderNPU: + def __init__( + self, + theta: float = 256.0, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (64, 128, 128), + ): + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" + self.freqs_cis = None + self.freqs_real = None + self.freqs_imag = None + + @staticmethod + def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): + with torch.device("cpu"): + freqs_real_list = [] + freqs_imag_list = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_real = torch.cos(freqs) + freqs_imag = torch.sin(freqs) + freqs_real_list.append(freqs_real.to(torch.float32)) + freqs_imag_list.append(freqs_imag.to(torch.float32)) + + return freqs_real_list, freqs_imag_list + + def __call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.freqs_real is None or self.freqs_imag is None: + freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_real = [fr.to(device) for fr in freqs_real] + self.freqs_imag = [fi.to(device) for fi in freqs_imag] + else: + # Ensure freqs_cis are on the same device as ids + if self.freqs_real[0].device != device: + self.freqs_real = [fr.to(device) for fr in freqs_real] + self.freqs_imag = [fi.to(device) for fi in freqs_imag] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + real_part = self.freqs_real[i][index] + imag_part = self.freqs_imag[i][index] + complex_part = torch.complex(real_part, imag_part) + result.append(complex_part) + return torch.cat(result, dim=-1) + + class RopeEmbedder: def __init__( self, @@ -478,7 +535,10 @@ def __init__( self.axes_dims = axes_dims self.axes_lens = axes_lens - self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + if is_torch_npu_available: + self.rope_embedder = RopeEmbedderNPU(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + else: + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) def unpatchify( self, From b103f42cedb70a85dd9321e17c74c26ad775e6fb Mon Sep 17 00:00:00 2001 From: zhangtao Date: Wed, 21 Jan 2026 15:44:11 +0000 Subject: [PATCH 02/13] [Bug Fix][Qwen-Image-Edit] Fix 
Qwen-Image-Edit series on NPU --- src/diffusers/models/attention_dispatch.py | 13 +++++++++++-- .../models/transformers/transformer_qwenimage.py | 6 +++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 61c478b03c4f..26d950935fa0 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1133,6 +1133,7 @@ def _npu_attention_forward_op( key, value, query.size(2), # num_heads + atten_mask=attn_mask, input_layout="BSND", pse=None, scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, @@ -2422,7 +2423,14 @@ def _native_npu_attention( _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: if attn_mask is not None: - raise ValueError("`attn_mask` is not supported for NPU attention") + q_seqlen, kv_seqlen = query.size(-2), key.size(-2) + if 0 not in attn_mask: + attn_mask = None + elif attn_mask.dim() not in [2, 4] or attn_mask.size(-2) != q_seqlen or attn_mask.size(-1) != kv_seqlen: + raise ValueError("The attn_mask must be a 2D tensor with shape [q_seqlen, kv_seqlen]," + " or a 4D tensor with shape [batch_size, num_heads, q_seqlen, kv_seqlen]") + else: + attn_mask = ~attn_mask.to(torch.bool) if return_lse: raise ValueError("NPU attention backend does not support setting `return_lse=True`.") if _parallel_config is None: @@ -2431,6 +2439,7 @@ def _native_npu_attention( key, value, query.size(2), # num_heads + atten_mask=attn_mask, input_layout="BSND", pse=None, scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, @@ -2445,7 +2454,7 @@ def _native_npu_attention( query, key, value, - None, + attn_mask, dropout_p, None, scale, diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index cf11d8e01fb4..07701db60c80 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -164,7 +164,11 @@ def compute_text_seq_len_from_mask( position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long) active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(())) has_active = encoder_hidden_states_mask.any(dim=1) - per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len)) + per_sample_len = torch.where( + has_active, + active_positions.max(dim=1).values + 1, + torch.as_tensor(text_seq_len, device=encoder_hidden_states.device) + ) return text_seq_len, per_sample_len, encoder_hidden_states_mask From 3ed2a75f0a1a9fc710567f668b4c421f9fe5b5a2 Mon Sep 17 00:00:00 2001 From: zhangtao Date: Wed, 21 Jan 2026 16:12:34 +0000 Subject: [PATCH 03/13] Enhance NPU attention handling by converting attention mask to boolean and refining mask checks. 
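torch_npu's npu_fusion_attention expects a boolean atten_mask in which True marks positions to be masked out, the inverse of the keep-mask convention used by the diffusers attention interface, so the incoming mask is cast to bool and inverted before dispatch, and an all-True mask is dropped entirely. A rough sketch of the intended conversion (variable names and tensor values below are illustrative, not part of this change):

    import torch

    keep_mask = torch.tensor([[1, 1, 1, 0]]).bool()  # diffusers convention: True/1 = attend
    if keep_mask.all():
        atten_mask = None          # nothing is masked: pass no mask and use the faster kernel path
    else:
        atten_mask = ~keep_mask    # npu_fusion_attention convention: True = masked out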
--- src/diffusers/models/attention_dispatch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 26d950935fa0..8b6b373cdda3 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2423,14 +2423,17 @@ def _native_npu_attention( _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: if attn_mask is not None: + # https://www.hiascend.com/document/detail/zh/Pytorch/730/ptmoddevg/trainingmigrguide/performance_tuning_0034.html + attn_mask = attn_mask.bool() q_seqlen, kv_seqlen = query.size(-2), key.size(-2) - if 0 not in attn_mask: + + if attn_mask.all().item(): attn_mask = None elif attn_mask.dim() not in [2, 4] or attn_mask.size(-2) != q_seqlen or attn_mask.size(-1) != kv_seqlen: raise ValueError("The attn_mask must be a 2D tensor with shape [q_seqlen, kv_seqlen]," " or a 4D tensor with shape [batch_size, num_heads, q_seqlen, kv_seqlen]") else: - attn_mask = ~attn_mask.to(torch.bool) + attn_mask = ~attn_mask if return_lse: raise ValueError("NPU attention backend does not support setting `return_lse=True`.") if _parallel_config is None: From 50055640dd583a84acaed708cffcd2f0ff6c27f5 Mon Sep 17 00:00:00 2001 From: zhangtao Date: Wed, 21 Jan 2026 16:55:53 +0000 Subject: [PATCH 04/13] Refine attention mask handling in NPU attention function to improve validation and conversion logic. --- src/diffusers/models/attention_dispatch.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 8b6b373cdda3..4f6eb8af3aad 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2422,18 +2422,19 @@ def _native_npu_attention( return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + # attn_mask = None if attn_mask is not None: # https://www.hiascend.com/document/detail/zh/Pytorch/730/ptmoddevg/trainingmigrguide/performance_tuning_0034.html - attn_mask = attn_mask.bool() q_seqlen, kv_seqlen = query.size(-2), key.size(-2) - if attn_mask.all().item(): - attn_mask = None - elif attn_mask.dim() not in [2, 4] or attn_mask.size(-2) != q_seqlen or attn_mask.size(-1) != kv_seqlen: - raise ValueError("The attn_mask must be a 2D tensor with shape [q_seqlen, kv_seqlen]," - " or a 4D tensor with shape [batch_size, num_heads, q_seqlen, kv_seqlen]") + if attn_mask.dim() not in [2, 4] or attn_mask.size(-2) != q_seqlen or attn_mask.size(-1) != kv_seqlen: + if torch.all(attn_mask != 0).item(): + attn_mask = None + else: + raise ValueError("The attn_mask must be a 2D tensor with shape [q_seqlen, kv_seqlen]," + " or a 4D tensor with shape [batch_size, num_heads, q_seqlen, kv_seqlen]") else: - attn_mask = ~attn_mask + attn_mask = ~attn_mask.to(torch.bool) if return_lse: raise ValueError("NPU attention backend does not support setting `return_lse=True`.") if _parallel_config is None: From e042b0d177685d80d78cb9fb41825bf13349e46e Mon Sep 17 00:00:00 2001 From: zhangtao Date: Thu, 22 Jan 2026 01:25:41 +0000 Subject: [PATCH 05/13] Clean Code --- src/diffusers/models/attention_dispatch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 4f6eb8af3aad..a132978241e4 100644 --- a/src/diffusers/models/attention_dispatch.py +++ 
b/src/diffusers/models/attention_dispatch.py @@ -2422,7 +2422,6 @@ def _native_npu_attention( return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - # attn_mask = None if attn_mask is not None: # https://www.hiascend.com/document/detail/zh/Pytorch/730/ptmoddevg/trainingmigrguide/performance_tuning_0034.html q_seqlen, kv_seqlen = query.size(-2), key.size(-2) From e4bbc6d343a55467b036cdc2e1bec3836c7dec6d Mon Sep 17 00:00:00 2001 From: Zhibin Mo <97496981+luren55@users.noreply.github.com> Date: Thu, 22 Jan 2026 15:44:22 +0800 Subject: [PATCH 06/13] Update attention_dispatch.py --- src/diffusers/models/attention_dispatch.py | 26 ++++++++++++---------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index f4ec49703850..72573caa3117 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2195,18 +2195,20 @@ def _native_attention( attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) if _parallel_config is None: - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - out = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, - ) - out = out.permute(0, 2, 1, 3) + out = npu_fusion_attention( + query, + key, + value, + query.size(2), # num_heads + input_layout="BSND", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0 - dropout_p, + sync=False, + inner_precise=0, + )[0] else: out = _templated_context_parallel_attention( query, From 5c92a7762dba7e7a3c425a3b237af8b7773103e4 Mon Sep 17 00:00:00 2001 From: zhangtao Date: Thu, 22 Jan 2026 09:21:57 +0000 Subject: [PATCH 07/13] Refine attention mask processing in NPU attention functions to enhance performance and validation. 
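Padding masks reach the NPU backends as 2D keep-masks of shape [batch_size, seq_len_k], while npu_fusion_attention documents a [batch_size, 1, seq_len_q, seq_len_k] layout, so the mask is inverted and broadcast before the call, and an all-ones mask is skipped so the unmasked fast path is used. A minimal sketch of the reshaping (batch and sequence sizes are made up for the example; the real code derives them from query/key):

    import torch

    B, Sq, Skv = 2, 16, 24
    keep_mask = torch.randint(0, 2, (B, Skv))          # [B, Skv] keep-mask, 1 = attend
    if torch.all(keep_mask != 0).item():
        attn_mask = None                               # no padding: skip masking entirely
    else:
        attn_mask = ~keep_mask.to(torch.bool)          # True = masked out
        # [B, Skv] -> [B, 1, Skv] -> [B, Sq, Skv] -> [B, 1, Sq, Skv]
        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()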
--- src/diffusers/models/attention_dispatch.py | 50 ++++++++++++++++------ 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index a132978241e4..39dddce52dc1 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1127,6 +1127,25 @@ def _npu_attention_forward_op( ): if return_lse: raise ValueError("NPU attention backend does not support setting `return_lse=True`.") + + # Skip Attention Mask if all values are 1, `None` mask can speedup the computation + if ( + attn_mask is not None + and torch.all(attn_mask != 0).item() + ): + attn_mask = None + + # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k] + # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md + if ( + attn_mask is not None + and attn_mask.ndim == 2 + and attn_mask.shape[0] == query.shape[0] + and attn_mask.shape[1] == key.shape[1] + ): + B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1] + attn_mask = ~attn_mask.to(torch.bool) + attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous() out = npu_fusion_attention( query, @@ -2422,21 +2441,28 @@ def _native_npu_attention( return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - if attn_mask is not None: - # https://www.hiascend.com/document/detail/zh/Pytorch/730/ptmoddevg/trainingmigrguide/performance_tuning_0034.html - q_seqlen, kv_seqlen = query.size(-2), key.size(-2) - - if attn_mask.dim() not in [2, 4] or attn_mask.size(-2) != q_seqlen or attn_mask.size(-1) != kv_seqlen: - if torch.all(attn_mask != 0).item(): - attn_mask = None - else: - raise ValueError("The attn_mask must be a 2D tensor with shape [q_seqlen, kv_seqlen]," - " or a 4D tensor with shape [batch_size, num_heads, q_seqlen, kv_seqlen]") - else: - attn_mask = ~attn_mask.to(torch.bool) if return_lse: raise ValueError("NPU attention backend does not support setting `return_lse=True`.") if _parallel_config is None: + # Skip Attention Mask if all values are 1, `None` mask can speedup the computation + if ( + attn_mask is not None + and torch.all(attn_mask != 0).item() + ): + attn_mask = None + + # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k] + # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md + if ( + attn_mask is not None + and attn_mask.ndim == 2 + and attn_mask.shape[0] == query.shape[0] + and attn_mask.shape[1] == key.shape[1] + ): + B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1] + attn_mask = ~attn_mask.to(torch.bool) + attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous() + out = npu_fusion_attention( query, key, From 8abfddd571c2fe87581d490f46e722a1b77b4372 Mon Sep 17 00:00:00 2001 From: zhangtao Date: Thu, 22 Jan 2026 13:22:22 +0000 Subject: [PATCH 08/13] Remove item() ops on npu fa backend. 
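torch.all already returns a zero-dim boolean tensor, and a zero-dim bool tensor can be used directly in a Python condition, so the explicit .item() call is redundant. A small illustration (values are only for the example):

    import torch

    mask = torch.tensor([[1, 1, 0]])
    no_padding = torch.all(mask != 0)   # zero-dim bool tensor, stays on the mask's device
    if not no_padding:                  # usable directly as a condition, no .item() needed
        mask = ~mask.to(torch.bool)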
--- src/diffusers/models/attention_dispatch.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 39dddce52dc1..1851c42fd0dc 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1129,10 +1129,7 @@ def _npu_attention_forward_op( raise ValueError("NPU attention backend does not support setting `return_lse=True`.") # Skip Attention Mask if all values are 1, `None` mask can speedup the computation - if ( - attn_mask is not None - and torch.all(attn_mask != 0).item() - ): + if (attn_mask is not None and torch.all(attn_mask != 0)): attn_mask = None # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k] @@ -2445,10 +2442,7 @@ def _native_npu_attention( raise ValueError("NPU attention backend does not support setting `return_lse=True`.") if _parallel_config is None: # Skip Attention Mask if all values are 1, `None` mask can speedup the computation - if ( - attn_mask is not None - and torch.all(attn_mask != 0).item() - ): + if (attn_mask is not None and torch.all(attn_mask != 0)): attn_mask = None # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k] From 677c0ffb12f868b04d1365508eda113b3e9e4737 Mon Sep 17 00:00:00 2001 From: Zhibin Mo <97496981+luren55@users.noreply.github.com> Date: Thu, 22 Jan 2026 23:29:47 +0800 Subject: [PATCH 09/13] attention_dispatch.py backup --- src/diffusers/models/attention_dispatch.py | 26 ++++++++++------------ 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 72573caa3117..f4ec49703850 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2195,20 +2195,18 @@ def _native_attention( attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) if _parallel_config is None: - out = npu_fusion_attention( - query, - key, - value, - query.size(2), # num_heads - input_layout="BSND", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0 - dropout_p, - sync=False, - inner_precise=0, - )[0] + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) else: out = _templated_context_parallel_attention( query, From 020a23275939ffd5e5242cace159a1087b3b4ff9 Mon Sep 17 00:00:00 2001 From: zhangtao Date: Thu, 22 Jan 2026 16:26:24 +0000 Subject: [PATCH 10/13] Reuse NPU attention mask by `_maybe_modify_attn_mask_npu` --- src/diffusers/models/attention_dispatch.py | 54 ++++++++++------------ 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 1851c42fd0dc..56604ed39a62 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1111,23 +1111,11 @@ def _sage_attention_backward_op( raise NotImplementedError("Backward pass is not implemented for Sage attention.") -def _npu_attention_forward_op( - ctx: torch.autograd.function.FunctionCtx, +def _maybe_modify_attn_mask_npu( query: torch.Tensor, key: torch.Tensor, - 
value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - enable_gqa: bool = False, - return_lse: bool = False, - _save_ctx: bool = True, - _parallel_config: Optional["ParallelConfig"] = None, + attn_mask: Optional[torch.Tensor] = None ): - if return_lse: - raise ValueError("NPU attention backend does not support setting `return_lse=True`.") - # Skip Attention Mask if all values are 1, `None` mask can speedup the computation if (attn_mask is not None and torch.all(attn_mask != 0)): attn_mask = None @@ -1143,6 +1131,28 @@ def _npu_attention_forward_op( B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1] attn_mask = ~attn_mask.to(torch.bool) attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous() + + return attn_mask + + +def _npu_attention_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + if return_lse: + raise ValueError("NPU attention backend does not support setting `return_lse=True`.") + + attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask) out = npu_fusion_attention( query, @@ -2441,21 +2451,7 @@ def _native_npu_attention( if return_lse: raise ValueError("NPU attention backend does not support setting `return_lse=True`.") if _parallel_config is None: - # Skip Attention Mask if all values are 1, `None` mask can speedup the computation - if (attn_mask is not None and torch.all(attn_mask != 0)): - attn_mask = None - - # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k] - # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md - if ( - attn_mask is not None - and attn_mask.ndim == 2 - and attn_mask.shape[0] == query.shape[0] - and attn_mask.shape[1] == key.shape[1] - ): - B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1] - attn_mask = ~attn_mask.to(torch.bool) - attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous() + attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask) out = npu_fusion_attention( query, From 1dc7cc5600141344d7e7042196b492b3291cfeb7 Mon Sep 17 00:00:00 2001 From: luren55 Date: Mon, 26 Jan 2026 13:51:57 +0800 Subject: [PATCH 11/13] merge RopeEmbedderNPU into RopeEmbedder --- .../transformers/transformer_z_image.py | 124 +++++++----------- 1 file changed, 50 insertions(+), 74 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 5afa6bb5b49f..a62a1638ad9a 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -312,7 +312,7 @@ def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): return x -class RopeEmbedderNPU: +class RopeEmbedder: def __init__( self, theta: float = 256.0, @@ -330,87 +330,66 @@ def __init__( @staticmethod def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): with torch.device("cpu"): - freqs_real_list = [] - freqs_imag_list = [] - for i, (d, e) in enumerate(zip(dim, end)): - freqs = 1.0 / (theta ** 
(torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) - timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) - freqs = torch.outer(timestep, freqs).float() - freqs_real = torch.cos(freqs) - freqs_imag = torch.sin(freqs) - freqs_real_list.append(freqs_real.to(torch.float32)) - freqs_imag_list.append(freqs_imag.to(torch.float32)) - - return freqs_real_list, freqs_imag_list + if is_torch_npu_available: + freqs_real_list = [] + freqs_imag_list = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_real = torch.cos(freqs) + freqs_imag = torch.sin(freqs) + freqs_real_list.append(freqs_real.to(torch.float32)) + freqs_imag_list.append(freqs_imag.to(torch.float32)) + + return freqs_real_list, freqs_imag_list + else: + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + freqs_cis.append(freqs_cis_i) + return freqs_cis def __call__(self, ids: torch.Tensor): assert ids.ndim == 2 assert ids.shape[-1] == len(self.axes_dims) device = ids.device - if self.freqs_real is None or self.freqs_imag is None: - freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) - self.freqs_real = [fr.to(device) for fr in freqs_real] - self.freqs_imag = [fi.to(device) for fi in freqs_imag] - else: - # Ensure freqs_cis are on the same device as ids - if self.freqs_real[0].device != device: + if is_torch_npu_available: + if self.freqs_real is None or self.freqs_imag is None: + freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) self.freqs_real = [fr.to(device) for fr in freqs_real] self.freqs_imag = [fi.to(device) for fi in freqs_imag] + else: + # Ensure freqs_cis are on the same device as ids + if self.freqs_real[0].device != device: + self.freqs_real = [fr.to(device) for fr in freqs_real] + self.freqs_imag = [fi.to(device) for fi in freqs_imag] - result = [] - for i in range(len(self.axes_dims)): - index = ids[:, i] - real_part = self.freqs_real[i][index] - imag_part = self.freqs_imag[i][index] - complex_part = torch.complex(real_part, imag_part) - result.append(complex_part) - return torch.cat(result, dim=-1) - - -class RopeEmbedder: - def __init__( - self, - theta: float = 256.0, - axes_dims: List[int] = (16, 56, 56), - axes_lens: List[int] = (64, 128, 128), - ): - self.theta = theta - self.axes_dims = axes_dims - self.axes_lens = axes_lens - assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" - self.freqs_cis = None - - @staticmethod - def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): - with torch.device("cpu"): - freqs_cis = [] - for i, (d, e) in enumerate(zip(dim, end)): - freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) - timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) - freqs = torch.outer(timestep, freqs).float() - freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 - freqs_cis.append(freqs_cis_i) - - return freqs_cis - - 
def __call__(self, ids: torch.Tensor): - assert ids.ndim == 2 - assert ids.shape[-1] == len(self.axes_dims) - device = ids.device - - if self.freqs_cis is None: - self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) - self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + real_part = self.freqs_real[i][index] + imag_part = self.freqs_imag[i][index] + complex_part = torch.complex(real_part, imag_part) + result.append(complex_part) else: - # Ensure freqs_cis are on the same device as ids - if self.freqs_cis[0].device != device: + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + else: + # Ensure freqs_cis are on the same device as ids + if self.freqs_cis[0].device != device: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] - result = [] - for i in range(len(self.axes_dims)): - index = ids[:, i] - result.append(self.freqs_cis[i][index]) + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + result.append(self.freqs_cis[i][index]) + return torch.cat(result, dim=-1) @@ -535,10 +514,7 @@ def __init__( self.axes_dims = axes_dims self.axes_lens = axes_lens - if is_torch_npu_available: - self.rope_embedder = RopeEmbedderNPU(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) - else: - self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) def unpatchify( self, From b7d13259bfd713d2de1a1ae8c5dcb598abc1c53d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 28 Jan 2026 04:00:47 +0000 Subject: [PATCH 12/13] Apply style fixes --- src/diffusers/models/transformers/transformer_z_image.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index a62a1638ad9a..a3e50132eb38 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -25,10 +25,10 @@ from ...models.attention_processor import Attention from ...models.modeling_utils import ModelMixin from ...models.normalization import RMSNorm +from ...utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph from ..attention_dispatch import dispatch_attention_fn from ..modeling_outputs import Transformer2DModelOutput -from ...utils import is_torch_npu_available ADALN_EMBED_DIM = 256 @@ -342,8 +342,8 @@ def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): freqs_real_list.append(freqs_real.to(torch.float32)) freqs_imag_list.append(freqs_imag.to(torch.float32)) - return freqs_real_list, freqs_imag_list - else: + return freqs_real_list, freqs_imag_list + else: freqs_cis = [] for i, (d, e) in enumerate(zip(dim, end)): freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) @@ -389,7 +389,7 @@ def __call__(self, ids: torch.Tensor): for i in range(len(self.axes_dims)): index = ids[:, i] result.append(self.freqs_cis[i][index]) - + return torch.cat(result, dim=-1) From e67707217490bc66a6c00fd42e36cf09688ce628 Mon Sep 17 00:00:00 2001 From: zhangtao Date: Wed, 28 Jan 2026 15:07:07 +0800 Subject: 
[PATCH 13/13] Feat. Support Z-Image attention mask for NPU --- src/diffusers/models/attention_dispatch.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 56604ed39a62..de12ac032eef 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1132,6 +1132,18 @@ def _maybe_modify_attn_mask_npu( attn_mask = ~attn_mask.to(torch.bool) attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous() + # Reshape Attention Mask: [batch_size, 1, 1, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k] + if ( + attn_mask is not None + and attn_mask.ndim == 4 + and attn_mask.shape[0] == query.shape[0] + and attn_mask.shape[-1] == key.shape[1] + and attn_mask.shape[-2] == 1 + ): + B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1] + attn_mask = ~attn_mask.to(torch.bool) + attn_mask = attn_mask.expand(B, 1, Sq, Skv) + return attn_mask
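Note on the new branch above: it additionally accepts a keep-mask that has already been broadcast to [batch_size, 1, 1, seq_len_k] (the mask form forwarded for Z-Image, per the patch subject) and expands it to the [batch_size, 1, seq_len_q, seq_len_k] drop-mask layout. A rough sketch with illustrative sizes, not part of the diff:

    import torch

    B, Sq, Skv = 2, 16, 24
    keep_mask = torch.ones(B, 1, 1, Skv, dtype=torch.bool)   # already-broadcast keep-mask
    drop_mask = ~keep_mask.to(torch.bool)                     # True = masked out
    drop_mask = drop_mask.expand(B, 1, Sq, Skv)               # [B, 1, Sq, Skv]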