
Conversation

@DefTruth (Contributor) commented Jan 19, 2026

Fixes #12706. This PR implements Ulysses Anything attention for diffusers, extending Ulysses context parallelism to support [ANY] sequence length and [ANY] head count (a minimal sketch of the core idea follows the list below):

  • Supports any sequence length
  • Supports any head count (e.g., Z-Image with head num = 30)
  • No extra padding when the sequence length is not divisible by the number of devices
  • No loss of precision
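
For intuition, here is a minimal sketch of the core idea (the helper names are illustrative, not the actual diffusers API): instead of asserting divisibility, compute uneven per-rank chunk sizes that sum exactly to the total, so nothing ever needs padding.

import torch

def uneven_chunk_sizes(total: int, world_size: int) -> list:
    # The first (total % world_size) ranks get one extra element, so the
    # sizes sum exactly to `total` and no padding is introduced.
    base, rem = divmod(total, world_size)
    return [base + (1 if r < rem else 0) for r in range(world_size)]

def shard_anything(tensor: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Drop-in replacement for an equipartition sharder that would assert
    # `tensor.size(dim) % world_size == 0`.
    sizes = uneven_chunk_sizes(tensor.size(dim), world_size)
    return torch.split(tensor, sizes, dim=dim)[rank]

# E.g., the 3969 image tokens of a 1008x1008 FLUX latent split across 4 ranks
# give [993, 992, 992, 992], and Z-Image's 30 heads give [8, 8, 7, 7].
print(uneven_chunk_sizes(3969, 4), uneven_chunk_sizes(30, 4))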

@sayakpaul @DN6 @yiyixuxu

About Ulysses Anything Attention

For more details, please refer to our docs: https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention

Test

Qwen-Image && Qwen-Image-2512

import torch
import argparse
from diffusers import QwenImagePipeline
import torch.distributed as dist
from diffusers import ContextParallelConfig
from diffusers.quantizers import PipelineQuantizationConfig


def parse_args():
    parser = argparse.ArgumentParser(description="Test Qwen-Image with Context Parallelism")
    parser.add_argument(
        "--use_2512",
        action="store_true",
        help="Use Qwen-Image-2512 model if set, otherwise use 2509 model.",
    )
    # torch.compile flags
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Enable torch.compile for the pipeline if set.",
    )
    parser.add_argument(
        "--quantize",
        action="store_true",
        help="Enable quantization for the pipeline if set.",
    )
    parser.add_argument(
        "--ulysses-anything",
        action="store_true",
        help="Enable debug mode if set.",
    )
    return parser.parse_args()

args = parse_args()

if dist.is_available():
    dist.init_process_group(backend="cpu:gloo,cuda:nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    world_size = dist.get_world_size()
    torch.cuda.set_device(device)
else:
    rank = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    world_size = 1

if args.use_2512:
    model_id = "Qwen/Qwen-Image-2512"
else:
    model_id = "Qwen/Qwen-Image"

pipe = QwenImagePipeline.from_pretrained(
    model_id, 
    torch_dtype=torch.bfloat16,
    quantization_config=(
        PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
            },
            components_to_quantize=["text_encoder", "transformer"],
        )
    ) if args.quantize else None,
)

if args.quantize:
    pipe.to(device)
else:
    pipe.enable_model_cpu_offload(device=device)

pipe.transformer.set_attention_backend("native")
if world_size > 1:
    from diffusers import QwenImageTransformer2DModel
    assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(
            ulysses_degree=world_size,
            ulysses_anything=args.ulysses_anything,
        )
    )

pipe.set_progress_bar_config(disable=rank != 0)

positive_magic = {
        "en": ", Ultra HD, 4K, cinematic composition.",  # for english prompt
        "zh": ", 超清,4K,电影级构图.",  # for chinese prompt
}
prompt = (
        "A coffee shop entrance features a chalkboard sign reading "
        '"Qwen Coffee 😊 $2 per cup," with a neon light beside it '
        'displaying "通义千问". Next to it hangs a poster showing a '
        "beautiful Chinese woman, and beneath the poster is written "
        '"π≈3.1415926-53589793-23846264-33832795-02384197". '
        "Ultra HD, 4K, cinematic composition"
)


if args.compile:
    torch._dynamo.config.recompile_limit = 256
    torch._dynamo.config.accumulated_recompile_limit = 8096
    torch._inductor.config.reorder_for_compute_comm_overlap = True
    pipe.transformer.compile_repeated_blocks()


def run_pipe():
    with torch.inference_mode():
        inputs = {
            "prompt": prompt + positive_magic["en"],
            "generator": torch.Generator(device="cpu").manual_seed(0),
            "true_cfg_scale": 4.0,
            "negative_prompt": " ",
            "num_inference_steps": 50,
            "num_images_per_prompt": 1,
            "height": 1024,
            "width": 1024,
        }
        output = pipe(**inputs)
        output_image = output.images[0]
    return output_image


if args.compile:
    # Warm-up run for compilation
    for _ in range(2):
        run_pipe()


output_image = run_pipe()

model_version = "2512" if args.use_2512 else None
if world_size > 1:
    if model_version is not None:
        save_path = f"output_image_{model_version}_ulysses{world_size}.png"
    else:
        save_path = f"output_image_ulysses{world_size}.png"
else:
    if model_version is not None:
        save_path = f"output_image_{model_version}.png"
    else:
        save_path = f"output_image.png"
if rank == 0:
    output_image.save(save_path)
    print(f"image saved at {save_path}")

if dist.is_initialized():
    dist.destroy_process_group()

test cmds:

torchrun --nproc_per_node=1 test_qwen_image.py --use_2512 # baseline 2512
torchrun --nproc_per_node=2 test_qwen_image.py --use_2512 # cp2 2512
torchrun --nproc_per_node=4 test_qwen_image.py --use_2512 # cp4 2512, standard ulysses failed
torchrun --nproc_per_node=4 test_qwen_image.py --use_2512 --ulysses-anything # cp4 2512, working as expected
torchrun --nproc_per_node=4 test_qwen_image.py --use_2512 --ulysses-anything --compile # cp4 2512 + compile, working as expected

Before this PR:

torchrun --nproc_per_node=4 test_qwen_image.py --use_2512 # standard ulysses failed

...
[rank3]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 190, in new_forward
[rank3]:     return function_reference.post_forward(module, output)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 201, in post_forward
[rank3]:     current_output = self._prepare_cp_input(current_output, cpm)
[rank3]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 213, in _prepare_cp_input
[rank3]:     return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 266, in shard
[rank3]:     assert tensor.size()[dim] % mesh.size() == 0, (
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size

After this PR:

torchrun --nproc_per_node=4 test_qwen_image.py --use_2512 --ulysses-anything

...
Attention backends are an experimental feature and the API may be subject to change.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
100%|█████████████████████████████████████████████████████████| 50/50 [00:49<00:00,  1.02it/s]
image saved at output_image_2512_ulysses4.png
| Qwen-Image-2512 | Qwen-Image-2512 Ulysses-2 | Qwen-Image-2512 Ulysses-Anything-4 |
|---|---|---|
| L20x1 w/ offload, 101s | L20x2 w/ offload, 75s | L20x4 w/ offload, 49s (standard Ulysses failed) |
| output_image_2512 | output_image_2512_ulysses2 | output_image_2512_ulysses4 |
  • compile (w/o offload, L20 48GiB)
# NO Compile, ~35s
torchrun --nproc_per_node=4 test_qwen_image.py --use_2512 --ulysses-anything --quantize 
# w/ Compile, ~32s
torchrun --nproc_per_node=4 test_qwen_image.py --use_2512 --ulysses-anything --quantize --compile

@DefTruth DefTruth marked this pull request as ready for review January 19, 2026 08:32
@sayakpaul sayakpaul added the performance Anything related to performance improvements, profiling and benchmarking label Jan 19, 2026
@sayakpaul (Member) left a comment


Thanks for starting this work!

Let's think about whether we can allow the user to specify Ulysses Anything more explicitly through the configs instead of an env var.

@DefTruth (Contributor Author)

> Thanks for starting this work!
>
> Let's think about whether we can allow the user to specify Ulysses Anything more explicitly through the configs instead of an env var.

Now we can enable Ulysses Anything attention through the ContextParallelConfig:

pipe.transformer.enable_parallelism(
    config=ContextParallelConfig(
        ulysses_degree=world_size,
        ulysses_anything=True,
    )
)

@DefTruth (Contributor Author)

Let me prepare more test cases for Ulysses Anything.

@DefTruth (Contributor Author) commented Jan 20, 2026

FLUX.1-dev, 1008x1008

import os
import argparse
import time
import torch
import torch.distributed as dist
from diffusers import (
    FluxPipeline,
    FluxTransformer2DModel,
    ContextParallelConfig,
    PipelineQuantizationConfig,
)

def parse_args():
    parser = argparse.ArgumentParser(description="Context Parallelism")
    # torch.compile flags
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Enable torch.compile for the pipeline if set.",
    )
    parser.add_argument(
        "--quantize",
        action="store_true",
        help="Enable quantization for the pipeline if set.",
    )
    parser.add_argument(
        "--ulysses-anything",
        action="store_true",
        help="Enable debug mode if set.",
    )
    # height and width
    parser.add_argument(
        "--height",
        type=int,
        default=None,
        help="Height of the generated image.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=None,
        help="Width of the generated image.",
    )
    return parser.parse_args()

args = parse_args()

print(args)

if dist.is_available():
    dist.init_process_group(backend="cpu:gloo,cuda:nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    world_size = dist.get_world_size()
    torch.cuda.set_device(device)
else:
    rank = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    world_size = 1

pipe: FluxPipeline = FluxPipeline.from_pretrained(
    os.environ.get(
        "FLUX_DIR",
        "black-forest-labs/FLUX.1-dev",
    ),
    torch_dtype=torch.bfloat16,
    quantization_config=(
        PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
            },
            components_to_quantize=["text_encoder_2"],
        )
        if args.quantize
        else None
    ),
).to("cuda")


assert isinstance(pipe.transformer, FluxTransformer2DModel)
pipe.transformer.set_attention_backend("native")
if world_size > 1:
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(
            ulysses_degree=world_size,
            ulysses_anything=args.ulysses_anything,
        )
    )


pipe.set_progress_bar_config(disable=rank != 0)

# Set default prompt
prompt = "A cat holding a sign that says hello world"


height = 1008 if args.height is None else args.height
width = 1008 if args.width is None else args.width


def run_pipe(pipe: FluxPipeline):
    image = pipe(
        prompt,
        height=height,
        width=width,
        num_inference_steps=28,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]
    return image


if args.compile:
    torch._dynamo.config.recompile_limit = 256
    torch._dynamo.config.accumulated_recompile_limit = 8096
    torch._inductor.config.reorder_for_compute_comm_overlap = True
    pipe.transformer = torch.compile(pipe.transformer)

# warmup
_ = run_pipe(pipe)

start = time.time()
image = run_pipe(pipe)
end = time.time()

if rank == 0:
    time_cost = end - start
    save_path = f"flux.{height}x{width}_ulysses{world_size}.png"
    print(f"Time cost: {time_cost:.2f}s")
    print(f"Saving image to {save_path}")
    image.save(save_path)

if dist.is_initialized():
    dist.destroy_process_group()

test cmds:

torchrun --nproc_per_node=1 test_flux.py # baseline
torchrun --nproc_per_node=2 test_flux.py # standard ulysses, failed
torchrun --nproc_per_node=2 test_flux.py --ulysses-anything # working as expected
torchrun --nproc_per_node=2 test_flux.py --ulysses-anything --compile # working as expected

Before this PR:

torchrun --nproc_per_node=2 test_flux.py # standard ulysses, failed

...
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 188, in new_forward
[rank1]:     args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 158, in pre_forward
[rank1]:     input_val = self._prepare_cp_input(input_val, cpm)
[rank1]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 216, in _prepare_cp_input
[rank1]:     return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 273, in shard
[rank1]:     assert tensor.size()[dim] % mesh.size() == 0, (
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size

After this PR:

torchrun --nproc_per_node=2 test_flux.py --ulysses-anything --compile # working as expected

...
|████████████████████████████████████████████████████████| 28/28 [00:11<00:00,  2.35it/s]
Time cost: 12.58s
Saving image to flux.1008x1008_ulysses2_compile1.png
| FLUX.1-dev | FLUX.1-dev Ulysses-Anything 2 | FLUX.1-dev Ulysses-Anything 2 + compile |
|---|---|---|
| L20x1, 23.26s | L20x2, 13.55s | L20x2, 12.58s |
| flux 1008x1008_ulysses1 | flux 1008x1008_ulysses2_compile0 | flux 1008x1008_ulysses2_compile1 |

@sayakpaul I can also provide a case demonstrating support for any head number via Z-Image (e.g., head num = 30 with Ulysses Anything-4) in a separate PR after this one is ready. For now, you can quickly try it using the examples in cache-dit; a sketch of the uneven head exchange follows the logs below.

cd cache-dit/examples
torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py zimage --parallel ulysses --ulysses-anything

logs:

INFO 01-20 03:07:50 [base.py:622] ----------------------------------------------------------------------------------------------------
INFO 01-20 03:07:50 [base.py:395] 🤖 Example Init Config Summary:
INFO 01-20 03:07:50 [base.py:418] - Model: /workspace/dev/vipdev/hf_models/Z-Image-Turbo
INFO 01-20 03:07:50 [base.py:418] - Task Type: T2I - Text to Image
INFO 01-20 03:07:50 [base.py:418] - Torch Dtype: torch.bfloat16
INFO 01-20 03:07:50 [base.py:418] - LoRA Weights: None
INFO 01-20 03:07:50 [base.py:212] 🤖 Example Input Summary:
INFO 01-20 03:07:50 [base.py:212] - prompt: Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.
INFO 01-20 03:07:50 [base.py:212] - height: 1024
INFO 01-20 03:07:50 [base.py:212] - width: 1024
INFO 01-20 03:07:50 [base.py:212] - guidance_scale: 0.0
INFO 01-20 03:07:50 [base.py:212] - num_inference_steps: 9
INFO 01-20 03:07:50 [base.py:212] - generator: device cpu, seed 0
INFO 01-20 03:07:50 [base.py:307] 🤖 Example Output Summary:
INFO 01-20 03:07:50 [base.py:323] - Model: zimage
INFO 01-20 03:07:50 [base.py:323] - Optimization: C0_Q0_NONE_Ulysses4_ulysses_anything
INFO 01-20 03:07:50 [base.py:323] - Device: NVIDIA L20 x 4
INFO 01-20 03:07:50 [base.py:323] - Load Time: 12.48s
INFO 01-20 03:07:50 [base.py:323] - Warmup Time: 3.08s
INFO 01-20 03:07:50 [base.py:323] - Inference Time: 2.36s
INFO 01-20 03:07:50 [base.py:246] Image saved to zimage.1024x1024.C0_Q0_NONE_Ulysses4_ulysses_anything.png
INFO 01-20 03:07:50 [base.py:633] ----------------------------------------------------------------------------------------------------
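
As a rough illustration (a sketch under assumed shapes, not the PR's actual code), the uneven head exchange can be expressed with torch.distributed.all_to_all, which permits per-pair tensors of different shapes: each rank sends differently sized head groups and receives differently sized sequence chunks.

import torch
import torch.distributed as dist

def ulysses_anything_all_to_all(x, seq_sizes, head_sizes, rank, group=None):
    # x: [local_seq, H, D] -- this rank's sequence shard with all H heads.
    # seq_sizes[r]: sequence shard length owned by rank r (sums to the full S).
    # head_sizes[r]: number of heads rank r owns after the exchange, e.g.
    # [8, 8, 7, 7] for Z-Image's 30 heads on 4 GPUs.
    world_size = len(seq_sizes)
    # Head-wise split: chunk r (with head_sizes[r] heads) is sent to rank r.
    send = [t.contiguous() for t in torch.split(x, head_sizes, dim=1)]
    # Receive one sequence chunk per rank, each carrying our own head subset.
    recv = [
        x.new_empty(seq_sizes[r], head_sizes[rank], x.size(-1))
        for r in range(world_size)
    ]
    dist.all_to_all(recv, send, group=group)
    # Full sequence, local heads only: [S, head_sizes[rank], D].
    return torch.cat(recv, dim=0)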

@DefTruth (Contributor Author) commented Jan 21, 2026

@sayakpaul Hi~ could you take a look at the latest updates? Thanks~

@sayakpaul (Member) left a comment


Thanks, looking much better for now.

Could you also provide a benchmark comparing all four CP methods (Ring, Ulysses, Unified, Ulysses Anything)?

Comment on lines +1539 to +1541
query = query_wait() # type: torch.Tensor
key = key_wait() # type: torch.Tensor
value = value_wait() # type: torch.Tensor
@sayakpaul (Member):

There are three waits here and a couple later. Do these not introduce communication overhead?

@DefTruth (Contributor Author):

No. Here we attempt to overlap the permute/reshape operations with the QKV communication, so theoretically this can only improve performance.
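
Schematically, the pattern looks like this (a hypothetical helper sketched to show the overlap, not the PR's actual code):

import torch
import torch.distributed as dist

def async_all_to_all(t: torch.Tensor, group=None):
    # Launch the exchange asynchronously and hand back a closure; the NCCL
    # transfer runs in the background until the closure is called.
    out = torch.empty_like(t)
    work = dist.all_to_all_single(out, t.contiguous(), group=group, async_op=True)
    def wait():
        work.wait()  # block only when the result is actually needed
        return out
    return wait

# query_wait = async_all_to_all(query)   # q transfer starts immediately
# key_wait = async_all_to_all(key)       # k transfer overlaps q's reshape
# value_wait = async_all_to_all(value)
# ... permute/reshape bookkeeping runs while the transfers are in flight ...
# query, key, value = query_wait(), key_wait(), value_wait()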


@sayakpaul (Member) left a comment

LGTM! Thanks for working on this.

Let's also add this to the docs?

Additionally, let's provide a comparison between the different backends, as in the attached screenshot.

@sayakpaul sayakpaul requested a review from DN6 January 23, 2026 09:08
@DefTruth (Contributor Author)

> Thanks, looking much better for now.
>
> Could you also provide a benchmark comparing all four CP methods (Ring, Ulysses, Unified, Ulysses Anything)?

FLUX.1-dev, L20, 28 steps

| HxW | Ring 2 | Ulysses 2 | Ulysses Anything 2 | Ring 4 | Ulysses 4 | Ulysses Anything 4 |
|---|---|---|---|---|---|---|
| 1008x1008 | failed | failed | 13.53s | failed | failed | 8.12s |
| 1024x1024 | 13.37s | 13.65s | 13.64s | 9.69s | 8.22s | 8.21s |

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DefTruth (Contributor Author)

I will add more performance results next week~ I'm very sorry. I've been extremely busy these past few days.

@sayakpaul (Member)

Can we please maintain the same structure of the table as #12996 (review)? The benchmarking was conducted using this script.

@DefTruth (Contributor Author)

> Can we please maintain the same structure of the table as #12996 (review)? The benchmarking was conducted using this script.

copy that~

@DefTruth (Contributor Author)

> Can we please maintain the same structure of the table as #12996 (review)? The benchmarking was conducted using this script.

Benchmark Ulysses Anything

import os
import time
import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cp-backend",
        type=str,
        choices=["ring", "ulysses", "unified", 'ulysses_anything'],
        default="ulysses",
        help="Context parallel backend to use.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help="Height of the generated image.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1024,
        help="Width of the generated image.",
    )
    return parser.parse_args()


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="cpu:gloo,cuda:nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device


def main():
    args = parse_args()

    device = setup_distributed()
    world_size = dist.get_world_size()

    pipeline = DiffusionPipeline.from_pretrained(
        os.environ.get(
            "FLUX_DIR",
            "black-forest-labs/FLUX.1-dev",
        ),
        torch_dtype=torch.bfloat16,
    ).to(device)
    # Always use `_native_cudnn` because `ring` doesn't support the default backend. This helps ensure a fair comparison.
    pipeline.transformer.set_attention_backend("_native_cudnn")

    if args.cp_backend == "ring":
        cp_config = ContextParallelConfig(ring_degree=world_size)
    elif args.cp_backend == "unified":
        cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2)
    elif args.cp_backend == "ulysses":
        cp_config = ContextParallelConfig(ulysses_degree=world_size)
    elif args.cp_backend == "ulysses_anything":
        cp_config = ContextParallelConfig(ulysses_degree=world_size, ulysses_anything=True)
    else:
        raise ValueError(f"Unsupported cp_backend: {args.cp_backend}")

    pipeline.transformer.enable_parallelism(config=cp_config)

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    def run_pipe(pipeline, steps=50):
        image = pipeline(
            prompt,
            guidance_scale=3.5,
            num_inference_steps=steps,
            generator=torch.Generator().manual_seed(42),
            height=args.height,
            width=args.width,
        ).images[0]
        return image
    
    # warmup
    _ = run_pipe(pipeline, steps=10)

    start = time.time()
    image = run_pipe(pipeline, steps=50)
    end = time.time()

    if dist.get_rank() == 0:
        save_path = f"output_{args.height}x{args.width}_{args.cp_backend}.png"
        image.save(save_path)
        print(f"Saved image to {save_path}, time taken: {end - start:.2f} seconds")
        # Assume iters=steps=50, compute the metrics: Time / Iter (ms), Steps / Sec 
        total_time = end - start
        time_per_iter = (total_time / 50) * 1000  # in milliseconds
        steps_per_sec = 50 / total_time
        print(f"Time per Iter: {time_per_iter:.2f} ms, Steps per Sec: {steps_per_sec:.2f}")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

# Example usage:
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend ring # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend ulysses # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend unified # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend ulysses_anything # success

# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend ring # failed
# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend ulysses # failed
# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend unified # failed
# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend ulysses_anything # success 

@sayakpaul @DN6 We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified, and Ulysses Anything attention using this script on a node of 4 L20 GPUs. The results are summarized as follows:

| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW) |
|---|---|---|---|---|
| ulysses | 281.07 | 3.56 | 37.11 | 1024x1024 |
| ring | 351.34 | 2.85 | 37.01 | 1024x1024 |
| unified_balanced | 324.37 | 3.08 | 37.16 | 1024x1024 |
| ulysses_anything | 280.94 | 3.56 | 37.11 | 1024x1024 |
| ulysses | failed | failed | failed | 1008x1008 |
| ring | failed | failed | failed | 1008x1008 |
| unified_balanced | failed | failed | failed | 1008x1008 |
| ulysses_anything | 278.40 | 3.59 | 36.99 | 1008x1008 |

From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention.

@DefTruth (Contributor Author)

PTAL~ Thanks~

@DefTruth (Contributor Author)

The output images:

| 1024x1024 | 1024x1024 | 1024x1024 | 1024x1024 | 1008x1008 |
|---|---|---|---|---|
| Ulysses | Ring | Unified | Ulysses Anything | Ulysses Anything |
| output_1024x1024_ulysses | output_1024x1024_ring | output_1024x1024_unified | output_1024x1024_ulysses_anything | output_1008x1008_ulysses_anything |

@DefTruth (Contributor Author)

The benchmark script for Qwen-Image:

import os
import time
import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig, PipelineQuantizationConfig


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cp-backend",
        type=str,
        choices=["ring", "ulysses", "unified", 'ulysses_anything'],
        default="ulysses",
        help="Context parallel backend to use.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help="Height of the generated image.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1024,
        help="Width of the generated image.",
    )
    # no quantization
    parser.add_argument(
        "--no-quantize",
        action="store_true",
        help="Disable quantization for the pipeline if set.",
    )
    return parser.parse_args()


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="cpu:gloo,cuda:nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device


def main():
    args = parse_args()

    device = setup_distributed()
    world_size = dist.get_world_size()

    pipeline = DiffusionPipeline.from_pretrained(
        os.environ.get(
            "QWEN_IMAGE_DIR",
            "Qwen/Qwen-Image",
        ),
        torch_dtype=torch.bfloat16,
        # Keep this to avoid cpu offload during benchmarking.
        quantization_config=PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
            },
            components_to_quantize=["text_encoder", "transformer"],
        ) if not args.no_quantize else None,
    ).to(device)
    # Always use `_native_cudnn` because `ring` doesn't support the default backend. This helps ensure a fair comparison.
    pipeline.transformer.set_attention_backend("_native_cudnn")

    if args.cp_backend == "ring":
        cp_config = ContextParallelConfig(ring_degree=world_size)
    elif args.cp_backend == "unified":
        cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2)
    elif args.cp_backend == "ulysses":
        cp_config = ContextParallelConfig(ulysses_degree=world_size)
    elif args.cp_backend == "ulysses_anything":
        cp_config = ContextParallelConfig(ulysses_degree=world_size, ulysses_anything=True)
    else:
        raise ValueError(f"Unsupported cp_backend: {args.cp_backend}")

    pipeline.transformer.enable_parallelism(config=cp_config)

    positive_magic = {
        "en": ", Ultra HD, 4K, cinematic composition.",  # for english prompt
        "zh": ", 超清,4K,电影级构图.",  # for chinese prompt
    }
    prompt = (
        "A coffee shop entrance features a chalkboard sign reading "
        '"Qwen Coffee 😊 $2 per cup," with a neon light beside it '
        'displaying "通义千问". Next to it hangs a poster showing a '
        "beautiful Chinese woman, and beneath the poster is written "
        '"π≈3.1415926-53589793-23846264-33832795-02384197". '
        "Ultra HD, 4K, cinematic composition"
    )

    def run_pipe(pipeline, steps=50):
        image = pipeline(
            prompt + positive_magic["en"],
            true_cfg_scale=4.0,
            negative_prompt=" ",
            num_inference_steps=steps,
            generator=torch.Generator("cpu").manual_seed(0),
            height=args.height,
            width=args.width,
        ).images[0]
        return image
    
    # warmup
    _ = run_pipe(pipeline, steps=10)

    start = time.time()
    image = run_pipe(pipeline, steps=50)
    end = time.time()

    if dist.get_rank() == 0:
        save_path = f"qwen_{args.height}x{args.width}_{args.cp_backend}.png"
        image.save(save_path)
        print(f"Saved image to {save_path}, time taken: {end - start:.2f} seconds")
        # Assume iters=steps=50, compute the metrics: Time / Iter (ms), Steps / Sec 
        total_time = end - start
        time_per_iter = (total_time / 50) * 1000  # in milliseconds
        steps_per_sec = 50 / total_time
        print(f"Time per Iter: {time_per_iter:.2f} ms, Steps per Sec: {steps_per_sec:.2f}")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

# Example usage: (add --no-quantize to disable quantization if the VRAM is sufficient, e.g., >= 80GiB)
# torchrun --nproc_per_node=2 check_ulysses_anything_qwen.py --cp-backend ring # failed
# torchrun --nproc_per_node=2 check_ulysses_anything_qwen.py --cp-backend ulysses # success
# torchrun --nproc_per_node=2 check_ulysses_anything_qwen.py --cp-backend ulysses_anything # success

# torchrun --nproc_per_node=4 check_ulysses_anything_qwen.py --cp-backend ring # failed
# torchrun --nproc_per_node=4 check_ulysses_anything_qwen.py --cp-backend ulysses # failed
# torchrun --nproc_per_node=4 check_ulysses_anything_qwen.py --cp-backend unified # failed
# torchrun --nproc_per_node=4 check_ulysses_anything_qwen.py --cp-backend ulysses_anything # success

@sayakpaul (Member) commented Jan 26, 2026

So, I also benchmarked on a node of 4 H100s and got the following:

| Backend Name | Steps / sec | Memory (GB) |
|---|---|---|
| ring | 3.97 | 33.85 |
| ulysses | 7.35 | 33.85 |
| unified | 4.56 | 33.85 |
| ulysses_anything | 7.48 | 33.85 |
Script:
import os
import time
import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig


def measure_memory():
    """Measure current GPU memory usage"""
    allocated = torch.cuda.memory_allocated() / 1024**3  # GB
    reserved = torch.cuda.memory_reserved() / 1024**3  # GB
    max_allocated = torch.cuda.max_memory_allocated() / 1024**3  # GB
    return {
        "allocated_gb": allocated,
        "reserved_gb": reserved,
        "max_allocated_gb": max_allocated,
    }

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cp-backend",
        type=str,
        choices=["ring", "ulysses", "unified", 'ulysses_anything'],
        default="ulysses",
        help="Context parallel backend to use.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help="Height of the generated image.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1024,
        help="Width of the generated image.",
    )
    return parser.parse_args()


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="cpu:gloo,cuda:nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device


def main():
    args = parse_args()

    device = setup_distributed()
    world_size = dist.get_world_size()

    pipeline = DiffusionPipeline.from_pretrained(
        os.environ.get(
            "FLUX_DIR",
            "black-forest-labs/FLUX.1-dev",
        ),
        torch_dtype=torch.bfloat16,
    ).to(device)

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    # Always use `_native_cudnn` because `ring` doesn't support the default backend. This helps ensure a fair comparison.
    pipeline.transformer.set_attention_backend("_native_cudnn")

    if args.cp_backend == "ring":
        cp_config = ContextParallelConfig(ring_degree=world_size)
    elif args.cp_backend == "unified":
        cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2)
    elif args.cp_backend == "ulysses":
        cp_config = ContextParallelConfig(ulysses_degree=world_size)
    elif args.cp_backend == "ulysses_anything":
        cp_config = ContextParallelConfig(ulysses_degree=world_size, ulysses_anything=True)
    else:
        raise ValueError(f"Unsupported cp_backend: {args.cp_backend}")

    pipeline.transformer.enable_parallelism(config=cp_config)

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    def run_pipe(pipeline, steps=50):
        image = pipeline(
            prompt,
            guidance_scale=3.5,
            num_inference_steps=steps,
            generator=torch.Generator().manual_seed(42),
            height=args.height,
            width=args.width,
        ).images[0]
        return image
    
    # warmup
    _ = run_pipe(pipeline, steps=10)

    start = time.time()
    image = run_pipe(pipeline, steps=50)
    end = time.time()

    if dist.get_rank() == 0:
        print(f"{args.cp_backend=}")
        save_path = f"output_{args.height}x{args.width}_{args.cp_backend}.png"
        image.save(save_path)
        print(f"Saved image to {save_path}, time taken: {end - start:.2f} seconds")
        # Assume iters=steps=50, compute the metrics: Time / Iter (ms), Steps / Sec 
        total_time = end - start
        time_per_iter = (total_time / 50) * 1000  # in milliseconds
        steps_per_sec = 50 / total_time
        print(f"Time per Iter: {time_per_iter:.2f} ms, Steps per Sec: {steps_per_sec:.2f}")

        max_allocated_gb = measure_memory()["max_allocated_gb"]
        print(f"{max_allocated_gb=}")


    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

# Example usage:
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend ring # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend ulysses # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend unified # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend ulysses_anything # success

# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend ring # failed
# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend ulysses # failed
# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend unified # failed
# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend ulysses_anything # success 
Output images (left to right): Unified, Ulysses, Ulysses Anything, Ring.

@DefTruth (Contributor Author)

> So, I also benchmarked on a node of 4 H100s and got the following:

The results LGTM~

@DefTruth (Contributor Author)

@DN6 Hi~ Could you also take a look at this PR to see if I need to make any updates?

@DN6 (Collaborator) left a comment

Changes look good to me, but I think we can organise the modules differently.

  1. Move PartitionAnythingSharder and AllGatherAnythingFunction to context_parallel.py

  2. Move TemplatedUlyssesAnythingAttention and all associated helper functions to attention_dispatch.py

  3. _gather_size_by_comm can probably go into _modeling_parallel.py (a rough sketch of what such a helper does follows below)
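
For context, a rough sketch of what a `_gather_size_by_comm`-style helper does (hypothetical code, assuming the usual `cpu:gloo,cuda:nccl` process group so int tensors can be gathered on CPU): every rank contributes its local shard length, so all ranks learn the uneven split layout before the all-to-all.

import torch
import torch.distributed as dist

def gather_sizes(local_size: int, group=None) -> list:
    # All-gather one int64 per rank; the result is the global split layout.
    world_size = dist.get_world_size(group)
    sizes = [torch.zeros(1, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(sizes, torch.tensor([local_size], dtype=torch.int64), group=group)
    return [int(s.item()) for s in sizes]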

Comment on lines +367 to +369
| ulysses | failed | failed | failed | 1008x1008 |
| ring | failed | failed | failed | 1008x1008 |
| unified_balanced | failed | failed | failed | 1008x1008 |
@DN6 (Collaborator):

Is this from a failed eval? Can it be removed?

Successfully merging this pull request may close these issues: [Feature] Ulysses Attention for any sequence length w/o padding (#12706).