From eb382fa3b0870c49648a5327c07c2a322a47bfa9 Mon Sep 17 00:00:00 2001 From: David Mo Date: Thu, 15 Jan 2026 14:31:00 +0800 Subject: [PATCH 1/5] z-image support npu --- .../transformers/transformer_z_image.py | 62 ++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 5983c34ab640..5afa6bb5b49f 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention_dispatch import dispatch_attention_fn from ..modeling_outputs import Transformer2DModelOutput +from ...utils import is_torch_npu_available ADALN_EMBED_DIM = 256 @@ -311,6 +312,62 @@ def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): return x +class RopeEmbedderNPU: + def __init__( + self, + theta: float = 256.0, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (64, 128, 128), + ): + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" + self.freqs_cis = None + self.freqs_real = None + self.freqs_imag = None + + @staticmethod + def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): + with torch.device("cpu"): + freqs_real_list = [] + freqs_imag_list = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_real = torch.cos(freqs) + freqs_imag = torch.sin(freqs) + freqs_real_list.append(freqs_real.to(torch.float32)) + freqs_imag_list.append(freqs_imag.to(torch.float32)) + + return freqs_real_list, freqs_imag_list + + def 
__call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.freqs_real is None or self.freqs_imag is None: + freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_real = [fr.to(device) for fr in freqs_real] + self.freqs_imag = [fi.to(device) for fi in freqs_imag] + else: + # Ensure freqs_cis are on the same device as ids + if self.freqs_real[0].device != device: + self.freqs_real = [fr.to(device) for fr in freqs_real] + self.freqs_imag = [fi.to(device) for fi in freqs_imag] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + real_part = self.freqs_real[i][index] + imag_part = self.freqs_imag[i][index] + complex_part = torch.complex(real_part, imag_part) + result.append(complex_part) + return torch.cat(result, dim=-1) + + class RopeEmbedder: def __init__( self, @@ -478,7 +535,10 @@ def __init__( self.axes_dims = axes_dims self.axes_lens = axes_lens - self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + if is_torch_npu_available: + self.rope_embedder = RopeEmbedderNPU(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + else: + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) def unpatchify( self, From e4bbc6d343a55467b036cdc2e1bec3836c7dec6d Mon Sep 17 00:00:00 2001 From: Zhibin Mo <97496981+luren55@users.noreply.github.com> Date: Thu, 22 Jan 2026 15:44:22 +0800 Subject: [PATCH 2/5] Update attention_dispatch.py --- src/diffusers/models/attention_dispatch.py | 26 ++++++++++++---------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index f4ec49703850..72573caa3117 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2195,18 +2195,20 @@ def 
_native_attention( attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) if _parallel_config is None: - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - out = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, - ) - out = out.permute(0, 2, 1, 3) + out = npu_fusion_attention( + query, + key, + value, + query.size(2), # num_heads + input_layout="BSND", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0 - dropout_p, + sync=False, + inner_precise=0, + )[0] else: out = _templated_context_parallel_attention( query, From 677c0ffb12f868b04d1365508eda113b3e9e4737 Mon Sep 17 00:00:00 2001 From: Zhibin Mo <97496981+luren55@users.noreply.github.com> Date: Thu, 22 Jan 2026 23:29:47 +0800 Subject: [PATCH 3/5] attention_dispatch.py backup --- src/diffusers/models/attention_dispatch.py | 26 ++++++++++------------ 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 72573caa3117..f4ec49703850 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2195,20 +2195,18 @@ def _native_attention( attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) if _parallel_config is None: - out = npu_fusion_attention( - query, - key, - value, - query.size(2), # num_heads - input_layout="BSND", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0 - dropout_p, - sync=False, - inner_precise=0, - )[0] + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + 
class RopeEmbedder:
    """Multi-axis rotary position embedding (RoPE) frequency lookup.

    For each axis ``i`` a table of ``axes_lens[i]`` positions times
    ``axes_dims[i] // 2`` frequencies is precomputed lazily on first call and
    then indexed by integer position ids.

    Two cache layouts exist:

    * default: one ``complex64`` table per axis built with ``torch.polar``;
    * Ascend NPU: separate ``float32`` cos/sin tables combined into a complex
      tensor only at lookup time — presumably because complex-dtype kernels
      are limited on NPU (NOTE(review): confirm against torch_npu docs).
    """

    def __init__(
        self,
        theta: float = 256.0,
        axes_dims: List[int] = (16, 56, 56),
        axes_lens: List[int] = (64, 128, 128),
    ):
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
        # Lazily populated caches; only one layout is ever filled per process.
        self.freqs_cis = None   # list of complex64 tables (default path)
        self.freqs_real = None  # list of float32 cos tables (NPU path)
        self.freqs_imag = None  # list of float32 sin tables (NPU path)

    @staticmethod
    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
        """Build the per-axis frequency tables on CPU.

        Returns ``(cos_tables, sin_tables)`` on NPU, otherwise a single list
        of ``complex64`` tables (one entry per axis).
        """
        with torch.device("cpu"):
            # BUGFIX: `is_torch_npu_available` is a function; the bare name is
            # always truthy, which silently forced the NPU layout everywhere.
            if is_torch_npu_available():
                freqs_real_list = []
                freqs_imag_list = []
                for d, e in zip(dim, end):
                    freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
                    timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
                    freqs = torch.outer(timestep, freqs).float()
                    freqs_real_list.append(torch.cos(freqs).to(torch.float32))
                    freqs_imag_list.append(torch.sin(freqs).to(torch.float32))
                return freqs_real_list, freqs_imag_list
            else:
                freqs_cis = []
                for d, e in zip(dim, end):
                    freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
                    timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
                    freqs = torch.outer(timestep, freqs).float()
                    freqs_cis.append(torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64))  # complex64
                return freqs_cis

    def __call__(self, ids: torch.Tensor):
        """Return concatenated complex frequencies for integer position ids.

        ``ids`` has shape ``(N, len(axes_dims))``; the result has shape
        ``(N, sum(axes_dims) // 2)`` and dtype ``complex64``.
        """
        assert ids.ndim == 2
        assert ids.shape[-1] == len(self.axes_dims)
        device = ids.device

        # BUGFIX: call the availability check (see precompute_freqs_cis).
        if is_torch_npu_available():
            if self.freqs_real is None or self.freqs_imag is None:
                freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
                self.freqs_real = [fr.to(device) for fr in freqs_real]
                self.freqs_imag = [fi.to(device) for fi in freqs_imag]
            elif self.freqs_real[0].device != device:
                # Keep cached tables on the same device as the incoming ids.
                # BUGFIX: this branch previously iterated the undefined locals
                # `freqs_real`/`freqs_imag`, raising NameError on any device
                # change after the first call.
                self.freqs_real = [fr.to(device) for fr in self.freqs_real]
                self.freqs_imag = [fi.to(device) for fi in self.freqs_imag]

            result = []
            for i in range(len(self.axes_dims)):
                index = ids[:, i]
                # Assemble complex frequencies late, from the cos/sin tables.
                result.append(torch.complex(self.freqs_real[i][index], self.freqs_imag[i][index]))
        else:
            if self.freqs_cis is None:
                self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
                self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
            elif self.freqs_cis[0].device != device:
                # Keep cached tables on the same device as the incoming ids.
                self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]

            result = []
            for i in range(len(self.axes_dims)):
                result.append(self.freqs_cis[i][ids[:, i]])

        return torch.cat(result, dim=-1)