72 changes: 54 additions & 18 deletions src/diffusers/models/transformers/transformer_z_image.py
@@ -25,6 +25,7 @@
from ...models.attention_processor import Attention
from ...models.modeling_utils import ModelMixin
from ...models.normalization import RMSNorm
from ...utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_dispatch import dispatch_attention_fn
from ..modeling_outputs import Transformer2DModelOutput
@@ -323,37 +324,72 @@ def __init__(
        self.axes_lens = axes_lens
        assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
        self.freqs_cis = None
        self.freqs_real = None
        self.freqs_imag = None

    @staticmethod
    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
        if is_torch_npu_available():
            # On Ascend NPU, keep the rotary tables as separate float32 cos/sin tensors
            # instead of a single complex64 tensor.
            freqs_real_list = []
            freqs_imag_list = []
            for d, e in zip(dim, end):
                freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
                timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
                freqs = torch.outer(timestep, freqs).float()
                freqs_real_list.append(torch.cos(freqs))
                freqs_imag_list.append(torch.sin(freqs))
            return freqs_real_list, freqs_imag_list
        else:
            freqs_cis = []
            for d, e in zip(dim, end):
                freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
                timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
                freqs = torch.outer(timestep, freqs).float()
                freqs_cis.append(torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64))
            return freqs_cis

    def __call__(self, ids: torch.Tensor):
        assert ids.ndim == 2
        assert ids.shape[-1] == len(self.axes_dims)
        device = ids.device

        if is_torch_npu_available():
            if self.freqs_real is None or self.freqs_imag is None:
                freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
                self.freqs_real = [fr.to(device) for fr in freqs_real]
                self.freqs_imag = [fi.to(device) for fi in freqs_imag]
            elif self.freqs_real[0].device != device:
                # Ensure the cached tables are on the same device as ids
                self.freqs_real = [fr.to(device) for fr in self.freqs_real]
                self.freqs_imag = [fi.to(device) for fi in self.freqs_imag]

            result = []
            for i in range(len(self.axes_dims)):
                index = ids[:, i]
                real_part = self.freqs_real[i][index]
                imag_part = self.freqs_imag[i][index]
                result.append(torch.complex(real_part, imag_part))
        else:
            if self.freqs_cis is None:
                self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
                self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
            elif self.freqs_cis[0].device != device:
                # Ensure freqs_cis are on the same device as ids
                self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]

            result = []
            for i in range(len(self.axes_dims)):
                index = ids[:, i]
                result.append(self.freqs_cis[i][index])

        return torch.cat(result, dim=-1)


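Reviewer note: to sanity-check that the split cos/sin tables used on NPU reproduce the complex64 path, a minimal standalone sketch (not part of this PR; d, e, and theta are illustrative values) is:

import torch

# One RoPE axis, mirroring precompute_freqs_cis above: d = per-axis dim, e = max position index.
d, e, theta = 64, 1024, 256.0

freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
timestep = torch.arange(e, dtype=torch.float64)
freqs = torch.outer(timestep, freqs).float()

# Complex path (default branch)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)

# Split path (NPU branch): separate float32 cos/sin tables, recombined at lookup time
recombined = torch.complex(torch.cos(freqs), torch.sin(freqs))

assert torch.allclose(recombined, freqs_cis)
print("split cos/sin tables match the complex64 table")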