diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 61c478b03c4f..de12ac032eef 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1111,6 +1111,42 @@ def _sage_attention_backward_op(
     raise NotImplementedError("Backward pass is not implemented for Sage attention.")
 
 
+def _maybe_modify_attn_mask_npu(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None
+):
+    # Skip the attention mask if all values are non-zero; a `None` mask speeds up the computation
+    if attn_mask is not None and torch.all(attn_mask != 0):
+        attn_mask = None
+
+    # Reshape attention mask: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
+    # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
+    if (
+        attn_mask is not None
+        and attn_mask.ndim == 2
+        and attn_mask.shape[0] == query.shape[0]
+        and attn_mask.shape[1] == key.shape[1]
+    ):
+        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+        attn_mask = ~attn_mask.to(torch.bool)
+        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
+
+    # Reshape attention mask: [batch_size, 1, 1, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
+    if (
+        attn_mask is not None
+        and attn_mask.ndim == 4
+        and attn_mask.shape[0] == query.shape[0]
+        and attn_mask.shape[-1] == key.shape[1]
+        and attn_mask.shape[-2] == 1
+    ):
+        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+        attn_mask = ~attn_mask.to(torch.bool)
+        attn_mask = attn_mask.expand(B, 1, Sq, Skv)
+
+    return attn_mask
+
+
 def _npu_attention_forward_op(
     ctx: torch.autograd.function.FunctionCtx,
     query: torch.Tensor,
@@ -1126,13 +1162,16 @@ def _npu_attention_forward_op(
     _parallel_config: Optional["ParallelConfig"] = None,
 ):
     if return_lse:
-        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
+        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
+
+    attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
 
     out = npu_fusion_attention(
         query,
         key,
         value,
         query.size(2),  # num_heads
+        atten_mask=attn_mask,
         input_layout="BSND",
         pse=None,
         scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2421,16 +2460,17 @@ def _native_npu_attention(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    if attn_mask is not None:
-        raise ValueError("`attn_mask` is not supported for NPU attention")
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
     if _parallel_config is None:
+        attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
+
         out = npu_fusion_attention(
             query,
             key,
             value,
             query.size(2),  # num_heads
+            atten_mask=attn_mask,
             input_layout="BSND",
             pse=None,
             scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2445,7 +2485,7 @@ def _native_npu_attention(
             query,
             key,
             value,
-            None,
+            attn_mask,
             dropout_p,
             None,
             scale,
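A minimal, self-contained sketch of what the new `_maybe_modify_attn_mask_npu` helper does to a typical 2D padding mask (sizes are hypothetical, `npu_fusion_attention` itself is not called, and the snippet runs on CPU):

    import torch

    # Hypothetical sizes: batch 2, 5 query tokens, 7 key tokens, 4 heads, head_dim 8 (BSND layout).
    query = torch.randn(2, 5, 4, 8)
    key = torch.randn(2, 7, 4, 8)

    # 2D padding mask: 1 = attend, 0 = padded. One padded key position in sample 0.
    attn_mask = torch.ones(2, 7, dtype=torch.int64)
    attn_mask[0, -1] = 0

    # Mirrors the 2D branch of the helper: npu_fusion_attention expects an
    # *inverted* boolean mask of shape [B, 1, Sq, Skv], where True = masked out.
    B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
    npu_mask = ~attn_mask.to(torch.bool)
    npu_mask = npu_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()

    assert npu_mask.shape == (2, 1, 5, 7)
    assert npu_mask[0, 0, :, -1].all()  # the padded key position is masked for every query

Note that an all-ones mask never reaches this branch: the helper's first check turns it into `None`, so the fused kernel can skip masking entirely.
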
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index cf11d8e01fb4..07701db60c80 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -164,7 +164,11 @@ def compute_text_seq_len_from_mask(
     position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
     active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
     has_active = encoder_hidden_states_mask.any(dim=1)
-    per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
+    per_sample_len = torch.where(
+        has_active,
+        active_positions.max(dim=1).values + 1,
+        torch.as_tensor(text_seq_len, device=encoder_hidden_states.device)
+    )
 
     return text_seq_len, per_sample_len, encoder_hidden_states_mask
 
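The transformer_qwenimage.py change is a one-line device fix: without an explicit `device`, `torch.as_tensor(text_seq_len)` produces a CPU scalar tensor, and on some accelerator backends (including the NPU target of this PR) `torch.where` rejects operands on mixed devices. A small sketch of the corrected pattern (run here on CPU; substitute the accelerator device where available):

    import torch

    device = torch.device("cpu")  # e.g. torch.device("npu") on Ascend
    text_seq_len = 10
    has_active = torch.tensor([True, False], device=device)
    max_positions = torch.tensor([2, 0], device=device)

    # The fallback scalar is created on the same device as the mask,
    # so torch.where never sees mixed devices.
    per_sample_len = torch.where(
        has_active,
        max_positions + 1,
        torch.as_tensor(text_seq_len, device=device),
    )
    print(per_sample_len)  # tensor([ 3, 10])
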
diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py
index 5983c34ab640..a3e50132eb38 100644
--- a/src/diffusers/models/transformers/transformer_z_image.py
+++ b/src/diffusers/models/transformers/transformer_z_image.py
@@ -25,6 +25,7 @@
 from ...models.attention_processor import Attention
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import RMSNorm
+from ...utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_dispatch import dispatch_attention_fn
 from ..modeling_outputs import Transformer2DModelOutput
@@ -323,37 +324,72 @@ def __init__(
         self.axes_lens = axes_lens
         assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
         self.freqs_cis = None
+        self.freqs_real = None
+        self.freqs_imag = None
 
     @staticmethod
     def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
         with torch.device("cpu"):
-            freqs_cis = []
-            for i, (d, e) in enumerate(zip(dim, end)):
-                freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
-                timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
-                freqs = torch.outer(timestep, freqs).float()
-                freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)  # complex64
-                freqs_cis.append(freqs_cis_i)
-
-            return freqs_cis
+            if is_torch_npu_available():
+                freqs_real_list = []
+                freqs_imag_list = []
+                for i, (d, e) in enumerate(zip(dim, end)):
+                    freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
+                    timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
+                    freqs = torch.outer(timestep, freqs).float()
+                    freqs_real = torch.cos(freqs)
+                    freqs_imag = torch.sin(freqs)
+                    freqs_real_list.append(freqs_real.to(torch.float32))
+                    freqs_imag_list.append(freqs_imag.to(torch.float32))
+
+                return freqs_real_list, freqs_imag_list
+            else:
+                freqs_cis = []
+                for i, (d, e) in enumerate(zip(dim, end)):
+                    freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
+                    timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
+                    freqs = torch.outer(timestep, freqs).float()
+                    freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)  # complex64
+                    freqs_cis.append(freqs_cis_i)
+                return freqs_cis
 
     def __call__(self, ids: torch.Tensor):
         assert ids.ndim == 2
         assert ids.shape[-1] == len(self.axes_dims)
         device = ids.device
-        if self.freqs_cis is None:
-            self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
-            self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
+        if is_torch_npu_available():
+            if self.freqs_real is None or self.freqs_imag is None:
+                freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
+                self.freqs_real = [fr.to(device) for fr in freqs_real]
+                self.freqs_imag = [fi.to(device) for fi in freqs_imag]
+            else:
+                # Ensure freqs_real/freqs_imag are on the same device as ids
+                if self.freqs_real[0].device != device:
+                    self.freqs_real = [fr.to(device) for fr in self.freqs_real]
+                    self.freqs_imag = [fi.to(device) for fi in self.freqs_imag]
+
+            result = []
+            for i in range(len(self.axes_dims)):
+                index = ids[:, i]
+                real_part = self.freqs_real[i][index]
+                imag_part = self.freqs_imag[i][index]
+                complex_part = torch.complex(real_part, imag_part)
+                result.append(complex_part)
         else:
-            # Ensure freqs_cis are on the same device as ids
-            if self.freqs_cis[0].device != device:
+            if self.freqs_cis is None:
+                self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
                 self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
+            else:
+                # Ensure freqs_cis are on the same device as ids
+                if self.freqs_cis[0].device != device:
+                    self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
+
+            result = []
+            for i in range(len(self.axes_dims)):
+                index = ids[:, i]
+                result.append(self.freqs_cis[i][index])
 
-        result = []
-        for i in range(len(self.axes_dims)):
-            index = ids[:, i]
-            result.append(self.freqs_cis[i][index])
         return torch.cat(result, dim=-1)
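The transformer_z_image.py change sidesteps complex tensors on NPU: `torch.polar` builds a complex64 RoPE table, which the NPU path avoids, presumably because complex dtypes are poorly supported on the Ascend backend at precompute time; instead it stores separate float32 cos/sin tables and recombines them with `torch.complex` only at lookup. A self-contained sketch of the equivalence for a single axis (values are illustrative; runs on CPU):

    import torch

    d, e, theta = 8, 16, 256.0  # one axis: head-dim slice 8, max position 16

    inv_freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
    angles = torch.outer(torch.arange(e, dtype=torch.float64), inv_freqs).float()

    # Reference path: complex table via torch.polar (unit magnitude, angle = position * freq).
    freqs_cis = torch.polar(torch.ones_like(angles), angles).to(torch.complex64)

    # NPU path: keep cos/sin separately, rebuild the complex values per lookup.
    freqs_real, freqs_imag = torch.cos(angles), torch.sin(angles)
    ids = torch.tensor([0, 3, 15])  # positions queried for this axis
    rebuilt = torch.complex(freqs_real[ids], freqs_imag[ids])

    assert torch.allclose(rebuilt, freqs_cis[ids])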