
Conversation

@zhangtao0408 zhangtao0408 commented Jan 22, 2026

What does this PR do?

Fixes: #13015
Fixes: #13016

Test code:

import torch
import torch_npu
import torch.distributed as dist

import os, time
from PIL import Image
from diffusers import QwenImageEditPlusPipeline, ContextParallelConfig
from diffusers.utils import load_image

# Initialize the distributed environment
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))

if world_size > 1 and not dist.is_initialized():
    dist.init_process_group(backend="hccl")
    rank = dist.get_rank()
    device = torch.device("npu", rank % torch.npu.device_count())
    torch.npu.set_device(device)
else:
    device = "npu"

image1 = load_image("https://github.com/vipshop/cache-dit/raw/main/examples/data/edit2509_1.jpg")
image2 = load_image("https://github.com/vipshop/cache-dit/raw/main/examples/data/edit2509_2.jpg")
prompt = "The magician bear is on the left, the alchemist bear is on the right, facing each other in the central park square"

pipe = QwenImageEditPlusPipeline.from_pretrained(
    "/PATH/TO/Qwen-Image-Edit-2509",
    torch_dtype=torch.bfloat16
).to(device)
pipe.transformer.set_attention_backend("_native_npu")

pipe.set_progress_bar_config(disable=rank != 0)
pipe.enable_model_cpu_offload(device=device)

if world_size > 1:
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(ulysses_degree=world_size)
    )

with torch.inference_mode():
    # Inference
    torch.npu.synchronize()
    start_time = time.time()
    output = pipe(
        image=[image1, image2],
        prompt=prompt,
        generator=torch.Generator(device="cpu").manual_seed(0),
        true_cfg_scale=4.0,
        negative_prompt=" ",
        num_inference_steps=20,
        num_images_per_prompt=1,
        height=1024,
        width=1024,
    )
    torch.npu.synchronize()
    end_time = time.time()
    
    inference_time = end_time - start_time
    if rank == 0:
        output_image = output.images[0]
        output_image.save(f"qwen-image-ulysses{world_size}-time{inference_time:.2f}s.png")
        print(f"image saved at qwen-image-ulysses{world_size}-time{inference_time:.2f}s.png")
  • Run the code

# 1 card
python3 qwen_image_edit_test.py

# 4 cards
torchrun --nproc_per_node=4 qwen_image_edit_test.py

Results are in the comment below.

Who can review?

cc @yiyixuxu @sayakpaul @asomoza @DN6

@zhangtao0408 (Contributor Author)

Results Log

  • Before PR
Traceback (most recent call last):
  File "/home/qwen_image_edit_test.py", line 42, in <module>
    _ = pipe(
        ^^^^^
  File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py", line 803, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.13/lib/python3.11/site-packages/accelerate/hooks.py", line 175, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 923, in forward
    text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 167, in compute_text_seq_len_from_mask
    per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device. Expected NPU tensor, please check whether the input tensor device is correct.
[ERROR] 2026-01-22-02:30:34 (PID:633567, Device:0, RankID:-1) ERR01002 OPS invalid type
  • After PR

1 card

python3 qwen_image_edit_test.py
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 13.13it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 15.45it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.12it/s]
Attention backends are an experimental feature and the API may be subject to change.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:14<00:00,  7.43s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [01:34<00:00,  4.70s/it]
image saved at qwen-image-ulysses1-time109.56s.png

4 cards

torchrun --nproc_per_node=4 qwen_image_edit_test.py
W0121 16:48:33.258000 627474 site-packages/torch/distributed/run.py:774] 
W0121 16:48:33.258000 627474 site-packages/torch/distributed/run.py:774] *****************************************
W0121 16:48:33.258000 627474 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0121 16:48:33.258000 627474 site-packages/torch/distributed/run.py:774] *****************************************
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  8.11it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 14.83it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 14.30it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 15.67it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 13.18it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:03<00:00,  1.95it/s]
Attention backends are an experimental feature and the API may be subject to change.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:11<00:00,  5.91s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:35<00:00,  1.77s/it]
image saved at qwen-image-ulysses4-time49.67s.png

Results images

1 card

[image: qwen-image-ulysses1-time109.56s.png]

4 cards

[image: qwen-image-ulysses4-time49.67s.png]

@sayakpaul (Member) left a comment


Thanks, left some comments.

Comment on lines 1132 to 1135

if (
    attn_mask is not None
    and torch.all(attn_mask != 0).item()
):
Member

Suggested change

if (
    attn_mask is not None
    and torch.all(attn_mask != 0).item()
):

if attn_mask is not None and torch.all(attn_mask != 0):

Won't it work?

Contributor Author

Won't it work?

# Skip Attention Mask if all values are 1, `None` mask can speedup the computation
if (
    attn_mask is not None
    and torch.all(attn_mask != 0).item()
):
    attn_mask = None

Thanks for the reply!

Since NPU fused attention does not support the [B, seq_len_kv] mask shape passed by QwenImageEditPlus, and the unsqueeze/expand operations slow down execution, I added logic to bypass those steps when the mask is all ones. This optimization significantly improves speed under context parallelism, as shown in the test results below (a small equivalence sketch follows the table):

| Stage | Cards | End-to-end time (s) |
| --- | --- | --- |
| Skip expand mask (set to None) | 1 | 108.22 |
| Skip expand mask (set to None) | 4 | 49.83 |
| Expand mask | 1 | 108.62 |
| Expand mask | 4 | 57.74 |
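
For intuition, here is a minimal CPU sketch of why an all-ones padding mask can safely be replaced with None. It uses standard torch.nn.functional.scaled_dot_product_attention rather than the NPU kernel, and the shapes are illustrative:

import torch
import torch.nn.functional as F

# All-ones [B, Skv] keep-mask, expanded to the [B, 1, Sq, Skv] boolean form
# (True = attend) accepted by scaled_dot_product_attention. With no masked
# positions, mask=None is mathematically equivalent and skips materializing
# and expanding the mask tensor.
B, H, Sq, Skv, D = 2, 4, 16, 16, 8
q = torch.randn(B, H, Sq, D)
k = torch.randn(B, H, Skv, D)
v = torch.randn(B, H, Skv, D)

pad_mask = torch.ones(B, Skv, dtype=torch.bool)
full_mask = pad_mask[:, None, None, :].expand(B, 1, Sq, Skv)

out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=full_mask)
out_none = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
assert torch.allclose(out_masked, out_none, atol=1e-6)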

@sayakpaul (Member) commented Jan 22, 2026

That's fine. I am asking if this condition would work (i.e., no item()):
if attn_mask is not None and torch.all(attn_mask != 0):

Contributor Author

Thanks, that worked. I've removed item() and pushed the update.

# Skip Attention Mask if all values are 1, `None` mask can speedup the computation
if (
    attn_mask is not None
    and torch.all(attn_mask != 0).item()
Member

Same as above.

Comment on lines -167 to +171

per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
per_sample_len = torch.where(
    has_active,
    active_positions.max(dim=1).values + 1,
    torch.as_tensor(text_seq_len, device=encoder_hidden_states.device)
)
Member

Seems like an unrelated change? If so, could you undo it?

Contributor Author

This change fixes #13015: torch.as_tensor(text_seq_len) without a device argument creates a CPU scalar tensor, which triggers the cross-device RuntimeError in torch.where shown in the "Before PR" traceback above (a minimal repro sketch follows).
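
A hypothetical minimal reproduction of that pattern, for context. The surrounding computation of active_positions here is illustrative, not the actual diffusers code; the point is only that the fallback branch of torch.where must live on the same device as the other operands:

import torch

# On accelerator backends that do not auto-promote across devices (such as NPU),
# torch.as_tensor(int) yields a CPU scalar and torch.where raises a cross-device error.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text_seq_len = 4
# [B, S] keep-mask; the second row has no active tokens
mask = torch.tensor([[1, 1, 0, 0], [0, 0, 0, 0]], device=device, dtype=torch.bool)

has_active = mask.any(dim=1)
positions = torch.arange(mask.shape[1], device=device)
active_positions = positions * mask  # zero out inactive positions (illustrative)
per_sample_len = torch.where(
    has_active,
    active_positions.max(dim=1).values + 1,
    torch.as_tensor(text_seq_len, device=device),  # the fix: pin the fallback to the same device
)
print(per_sample_len)  # tensor([2, 4]) on `device`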

Comment on lines 2450 to 2458

if (
    attn_mask is not None
    and attn_mask.ndim == 2
    and attn_mask.shape[0] == query.shape[0]
    and attn_mask.shape[1] == key.shape[1]
):
    B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
    attn_mask = ~attn_mask.to(torch.bool)
    attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
Member

Would it make sense to have a small utility named _maybe_modify_attn_mask_npu() so that it can be reused in the two places (here and above)?

Contributor Author

Thanks for your suggestion! I've added the _maybe_modify_attn_mask_npu() helper:

def _maybe_modify_attn_mask_npu(
    query: torch.Tensor,
    key: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
):
    # Skip the attention mask if all values are 1; a `None` mask speeds up the computation
    if attn_mask is not None and torch.all(attn_mask != 0):
        attn_mask = None
    # Reshape the attention mask: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
    # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
    if (
        attn_mask is not None
        and attn_mask.ndim == 2
        and attn_mask.shape[0] == query.shape[0]
        and attn_mask.shape[1] == key.shape[1]
    ):
        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
        attn_mask = ~attn_mask.to(torch.bool)
        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
    return attn_mask
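
For reference, a small CPU-only sketch of what the reshape branch produces; the shapes are illustrative and the on-device behavior of npu_fusion_attention is not exercised here:

import torch

# A [B, Skv] keep-mask (1 = attend) becomes an inverted [B, 1, Sq, Skv] boolean
# mask (True = masked out), matching the ~mask + unsqueeze/expand steps above.
B, Sq, Skv = 2, 3, 4
keep_mask = torch.tensor([[1, 1, 1, 0], [1, 0, 0, 0]])  # [B, Skv]

masked_out = ~keep_mask.to(torch.bool)  # [B, Skv], True = masked out
full = masked_out.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
print(full.shape)  # torch.Size([2, 1, 3, 4])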

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

