Skip to content

Commit 261611c

Browse files
committed
ROPE computing suitable for NPU
1 parent 17c0e79 commit 261611c

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

src/diffusers/models/embeddings.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,9 +1169,18 @@ def get_1d_rotary_pos_embed(
11691169
if is_npu:
11701170
freqs = freqs.float()
11711171
if use_real and repeat_interleave_real:
1172-
# flux, hunyuan-dit, cogvideox
1173-
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
1174-
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
1172+
if is_npu:
1173+
# flux, hunyuan-dit, cogvideox that suitable for NPU
1174+
freqs_cos = (
1175+
freqs.cos().T.repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).T.float().contiguous()
1176+
) # [S, D]
1177+
freqs_sin = (
1178+
freqs.sin().T.repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).T.float().contiguous()
1179+
) # [S, D]
1180+
else:
1181+
# flux, hunyuan-dit, cogvideox
1182+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
1183+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
11751184
return freqs_cos, freqs_sin
11761185
elif use_real:
11771186
# stable audio, allegro

0 commit comments

Comments
 (0)