feat: support Ulysses Anything Attention #12996
Conversation
sayakpaul left a comment
Thanks for starting this work!
Let's think about whether we can let the user specify Ulysses Anything explicitly through the configs instead of an env var.
Now we can enable Ulysses Anything attention through the ContextParallelConfig:
pipe.transformer.enable_parallelism(
    config=ContextParallelConfig(
        ulysses_degree=world_size,
        ulysses_anything=True,
    )
)
Let me prepare more test cases for Ulysses Anything.
FLUX.1-dev, 1008x1008
import os
import argparse
import time
import torch
import torch.distributed as dist
from diffusers import (
FluxPipeline,
FluxTransformer2DModel,
ContextParallelConfig,
PipelineQuantizationConfig,
)
def parse_args():
parser = argparse.ArgumentParser(description="Context Parallelism")
# torch.compile flags
parser.add_argument(
"--compile",
action="store_true",
help="Enable torch.compile for the pipeline if set.",
)
parser.add_argument(
"--quantize",
action="store_true",
help="Enable quantization for the pipeline if set.",
)
parser.add_argument(
"--ulysses-anything",
action="store_true",
help="Enable debug mode if set.",
)
# height and width
parser.add_argument(
"--height",
type=int,
default=None,
help="Height of the generated image.",
)
parser.add_argument(
"--width",
type=int,
default=None,
help="Width of the generated image.",
)
return parser.parse_args()
args = parse_args()
print(args)
if dist.is_available():
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
rank = dist.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
world_size = dist.get_world_size()
torch.cuda.set_device(device)
else:
rank = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
world_size = 1
pipe: FluxPipeline = FluxPipeline.from_pretrained(
os.environ.get(
"FLUX_DIR",
"black-forest-labs/FLUX.1-dev",
),
torch_dtype=torch.bfloat16,
quantization_config=(
PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16,
},
components_to_quantize=["text_encoder_2"],
)
if args.quantize
else None
),
).to("cuda")
assert isinstance(pipe.transformer, FluxTransformer2DModel)
pipe.transformer.set_attention_backend("native")
if world_size > 1:
pipe.transformer.enable_parallelism(
config=ContextParallelConfig(
ulysses_degree=world_size,
ulysses_anything=args.ulysses_anything,
)
)
pipe.set_progress_bar_config(disable=rank != 0)
# Set default prompt
prompt = "A cat holding a sign that says hello world"
height = 1008 if args.height is None else args.height
width = 1008 if args.width is None else args.width
def run_pipe(pipe: FluxPipeline):
image = pipe(
prompt,
height=height,
width=width,
num_inference_steps=28,
generator=torch.Generator("cpu").manual_seed(0),
).images[0]
return image
if args.compile:
torch._dynamo.config.recompile_limit = 256
torch._dynamo.config.accumulated_recompile_limit = 8096
torch._inductor.config.reorder_for_compute_comm_overlap = True
pipe.transformer = torch.compile(pipe.transformer)
# warmup
_ = run_pipe(pipe)
start = time.time()
image = run_pipe(pipe)
end = time.time()
if rank == 0:
time_cost = end - start
save_path = f"flux.{height}x{width}_ulysses{world_size}.png"
print(f"Time cost: {time_cost:.2f}s")
print(f"Saving image to {save_path}")
image.save(save_path)
if dist.is_initialized():
dist.destroy_process_group()
test cmds:
torchrun --nproc_per_node=1 test_flux.py # baseline
torchrun --nproc_per_node=2 test_flux.py # standard ulysses, failed
torchrun --nproc_per_node=2 test_flux.py --ulysses-anything # working as expected
torchrun --nproc_per_node=2 test_flux.py --ulysses-anything --compile # working as expected
before this pr:
torchrun --nproc_per_node=2 test_flux.py # standard ulysses, failed
...
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 188, in new_forward
[rank1]: args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 158, in pre_forward
[rank1]: input_val = self._prepare_cp_input(input_val, cpm)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 216, in _prepare_cp_input
[rank1]: return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 273, in shard
[rank1]: assert tensor.size()[dim] % mesh.size() == 0, (
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError: Tensor size along dimension to be sharded must be divisible by mesh sizeafter this pr: torchrun --nproc_per_node=2 test_flux.py --ulysses-anything --compile # working as expected
...
|████████████████████████████████████████████████████████| 28/28 [00:11<00:00, 2.35it/s]
Time cost: 12.58s
Saving image to flux.1008x1008_ulysses2_compile1.png
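(For reference, the 1008x1008 failure above is expected arithmetic, assuming I have FLUX's token packing right: the VAE downsamples by 8x and the transformer packs 2x2 latent patches, so 1008 / 16 = 63 tokens per side, i.e. 63 x 63 = 3969 image tokens; 3969 is odd, so it cannot be split evenly across 2 ranks.)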
@sayakpaul I can also provide a case to demonstrate support for any head number via z-image (e.g., head num = 30 with Ulysses Anything-4) in a separate PR after this one is ready. For now, you can quickly try it using the examples in cache-dit.
cd cache-dit/examples
torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py zimage --parallel ulysses --ulysses-anything
logs:
INFO 01-20 03:07:50 [base.py:622] ----------------------------------------------------------------------------------------------------
INFO 01-20 03:07:50 [base.py:395] 🤖 Example Init Config Summary:
INFO 01-20 03:07:50 [base.py:418] - Model: /workspace/dev/vipdev/hf_models/Z-Image-Turbo
INFO 01-20 03:07:50 [base.py:418] - Task Type: T2I - Text to Image
INFO 01-20 03:07:50 [base.py:418] - Torch Dtype: torch.bfloat16
INFO 01-20 03:07:50 [base.py:418] - LoRA Weights: None
INFO 01-20 03:07:50 [base.py:212] 🤖 Example Input Summary:
INFO 01-20 03:07:50 [base.py:212] - prompt: Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.
INFO 01-20 03:07:50 [base.py:212] - height: 1024
INFO 01-20 03:07:50 [base.py:212] - width: 1024
INFO 01-20 03:07:50 [base.py:212] - guidance_scale: 0.0
INFO 01-20 03:07:50 [base.py:212] - num_inference_steps: 9
INFO 01-20 03:07:50 [base.py:212] - generator: device cpu, seed 0
INFO 01-20 03:07:50 [base.py:307] 🤖 Example Output Summary:
INFO 01-20 03:07:50 [base.py:323] - Model: zimage
INFO 01-20 03:07:50 [base.py:323] - Optimization: C0_Q0_NONE_Ulysses4_ulysses_anything
INFO 01-20 03:07:50 [base.py:323] - Device: NVIDIA L20 x 4
INFO 01-20 03:07:50 [base.py:323] - Load Time: 12.48s
INFO 01-20 03:07:50 [base.py:323] - Warmup Time: 3.08s
INFO 01-20 03:07:50 [base.py:323] - Inference Time: 2.36s
INFO 01-20 03:07:50 [base.py:246] Image saved to zimage.1024x1024.C0_Q0_NONE_Ulysses4_ulysses_anything.png
INFO 01-20 03:07:50 [base.py:633] ----------------------------------------------------------------------------------------------------
@sayakpaul Hi~ Can you take a look at the latest updates? Thanks~
sayakpaul left a comment
Thanks, looking much better for now.
Could you also provide a benchmark comparing all four CP methods (Ring, Ulysses, Unified, Ulysses Anything)?
query = query_wait()  # type: torch.Tensor
key = key_wait()  # type: torch.Tensor
value = value_wait()  # type: torch.Tensor
There are three waits here and a couple more later. Do these not introduce communication overhead?
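For context, a minimal sketch of the wait pattern in question, assuming an initialized process group (illustrative only, not the PR's exact code). All three all-to-alls are enqueued before any wait, so they can be pipelined on the communication stream; each wait() then only blocks when its tensor is first consumed:
import torch
import torch.distributed as dist

def all_to_all_async(t: torch.Tensor):
    # Launch the collective without blocking and hand back a closure
    # that waits for completion and returns the exchanged tensor.
    out = torch.empty_like(t)
    work = dist.all_to_all_single(out, t.contiguous(), async_op=True)
    def _wait() -> torch.Tensor:
        work.wait()  # blocks only at the point of first use
        return out
    return _wait

def exchange_qkv(q, k, v):
    # Enqueue q/k/v exchanges back-to-back so they overlap rather than
    # serialize; the waits below resolve in consumption order.
    query_wait, key_wait, value_wait = map(all_to_all_async, (q, k, v))
    return query_wait(), key_wait(), value_wait()
In this shape, the waits add synchronization points but no extra traffic beyond the three collectives themselves.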
sayakpaul left a comment
LGTM! Thanks for working on this.
Let's also add this to the docs?
Additionally, let's provide a comparison between the different backends like so:
FLUX.1-dev, L20, 28 steps
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I will add more performance results next week~ I'm very sorry, I've been extremely busy these past few days.
Can we please maintain the same structure of the table as #12996 (review)? The benchmarking was conducted using this script.
copy that~
Benchmark Ulysses Anything
import os
import time
import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--cp-backend",
type=str,
choices=["ring", "ulysses", "unified", 'ulysses_anything'],
default="ulysses",
help="Context parallel backend to use.",
)
parser.add_argument(
"--height",
type=int,
default=1024,
help="Height of the generated image.",
)
parser.add_argument(
"--width",
type=int,
default=1024,
help="Width of the generated image.",
)
return parser.parse_args()
def setup_distributed():
if not dist.is_initialized():
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
return device
def main():
args = parse_args()
device = setup_distributed()
world_size = dist.get_world_size()
pipeline = DiffusionPipeline.from_pretrained(
os.environ.get(
"FLUX_DIR",
"black-forest-labs/FLUX.1-dev",
),
torch_dtype=torch.bfloat16,
).to(device)
# Always use `_native_cudnn` because `ring` doesn't support the default backend. This helps ensure a fair comparison.
pipeline.transformer.set_attention_backend("_native_cudnn")
if args.cp_backend == "ring":
cp_config = ContextParallelConfig(ring_degree=world_size)
elif args.cp_backend == "unified":
cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2)
elif args.cp_backend == "ulysses":
cp_config = ContextParallelConfig(ulysses_degree=world_size)
elif args.cp_backend == "ulysses_anything":
cp_config = ContextParallelConfig(ulysses_degree=world_size, ulysses_anything=True)
else:
raise ValueError(f"Unsupported cp_backend: {args.cp_backend}")
pipeline.transformer.enable_parallelism(config=cp_config)
prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
def run_pipe(pipeline, steps=50):
image = pipeline(
prompt,
guidance_scale=3.5,
num_inference_steps=steps,
generator=torch.Generator().manual_seed(42),
height=args.height,
width=args.width,
).images[0]
return image
# warmup
_ = run_pipe(pipeline, steps=10)
start = time.time()
image = run_pipe(pipeline, steps=50)
end = time.time()
if dist.get_rank() == 0:
save_path = f"output_{args.height}x{args.width}_{args.cp_backend}.png"
image.save(save_path)
print(f"Saved image to {save_path}, time taken: {end - start:.2f} seconds")
# Assume iters=steps=50, compute the metrics: Time / Iter (ms), Steps / Sec
total_time = end - start
time_per_iter = (total_time / 50) * 1000 # in milliseconds
steps_per_sec = 50 / total_time
print(f"Time per Iter: {time_per_iter:.2f} ms, Steps per Sec: {steps_per_sec:.2f}")
if dist.is_initialized():
dist.destroy_process_group()
if __name__ == "__main__":
main()
# Example usage:
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend ring # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend ulysses # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend unified # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend ulysses_anything # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend ring # failed
# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend ulysses # failed
# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend unified # failed
# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend ulysses_anything # success
@sayakpaul @DN6 We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention, and Ulysses Anything Attention using this script on a node of 4 L20 GPUs. The results are summarized as follows:
From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention.
PTAL~ Thanks~
The benchmark script for Qwen-Image:
import os
import time
import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig, PipelineQuantizationConfig
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--cp-backend",
type=str,
choices=["ring", "ulysses", "unified", 'ulysses_anything'],
default="ulysses",
help="Context parallel backend to use.",
)
parser.add_argument(
"--height",
type=int,
default=1024,
help="Height of the generated image.",
)
parser.add_argument(
"--width",
type=int,
default=1024,
help="Width of the generated image.",
)
# no quantization
parser.add_argument(
"--no-quantize",
action="store_true",
help="Disable quantization for the pipeline if set.",
)
return parser.parse_args()
def setup_distributed():
if not dist.is_initialized():
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
return device
def main():
args = parse_args()
device = setup_distributed()
world_size = dist.get_world_size()
pipeline = DiffusionPipeline.from_pretrained(
os.environ.get(
"QWEN_IMAGE_DIR",
"Qwen/Qwen-Image",
),
torch_dtype=torch.bfloat16,
# Keep this to avoid cpu offload during benchmarking.
quantization_config=PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16,
},
components_to_quantize=["text_encoder", "transformer"],
) if not args.no_quantize else None,
).to(device)
# Always use `_native_cudnn` because `ring` doesn't support the default backend. This helps ensure a fair comparison.
pipeline.transformer.set_attention_backend("_native_cudnn")
if args.cp_backend == "ring":
cp_config = ContextParallelConfig(ring_degree=world_size)
elif args.cp_backend == "unified":
cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2)
elif args.cp_backend == "ulysses":
cp_config = ContextParallelConfig(ulysses_degree=world_size)
elif args.cp_backend == "ulysses_anything":
cp_config = ContextParallelConfig(ulysses_degree=world_size, ulysses_anything=True)
else:
raise ValueError(f"Unsupported cp_backend: {args.cp_backend}")
pipeline.transformer.enable_parallelism(config=cp_config)
positive_magic = {
"en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
"zh": ", 超清,4K,电影级构图.", # for chinese prompt
}
prompt = (
"A coffee shop entrance features a chalkboard sign reading "
'"Qwen Coffee 😊 $2 per cup," with a neon light beside it '
'displaying "通义千问". Next to it hangs a poster showing a '
"beautiful Chinese woman, and beneath the poster is written "
'"π≈3.1415926-53589793-23846264-33832795-02384197". '
"Ultra HD, 4K, cinematic composition"
)
def run_pipe(pipeline, steps=50):
image = pipeline(
prompt + positive_magic["en"],
true_cfg_scale=4.0,
negative_prompt=" ",
num_inference_steps=steps,
generator=torch.Generator("cpu").manual_seed(0),
height=args.height,
width=args.width,
).images[0]
return image
# warmup
_ = run_pipe(pipeline, steps=10)
start = time.time()
image = run_pipe(pipeline, steps=50)
end = time.time()
if dist.get_rank() == 0:
save_path = f"qwen_{args.height}x{args.width}_{args.cp_backend}.png"
image.save(save_path)
print(f"Saved image to {save_path}, time taken: {end - start:.2f} seconds")
# Assume iters=steps=50, compute the metrics: Time / Iter (ms), Steps / Sec
total_time = end - start
time_per_iter = (total_time / 50) * 1000 # in milliseconds
steps_per_sec = 50 / total_time
print(f"Time per Iter: {time_per_iter:.2f} ms, Steps per Sec: {steps_per_sec:.2f}")
if dist.is_initialized():
dist.destroy_process_group()
if __name__ == "__main__":
main()
# Example usage: (add --no-quantize to disable quantization if the VRAM is sufficient, e.g., >= 80GiB)
# torchrun --nproc_per_node=2 check_ulysses_anything_qwen.py --cp-backend ring # failed
# torchrun --nproc_per_node=2 check_ulysses_anything_qwen.py --cp-backend ulysses # success
# torchrun --nproc_per_node=2 check_ulysses_anything_qwen.py --cp-backend ulysses_anything # success
# torchrun --nproc_per_node=4 check_ulysses_anything_qwen.py --cp-backend ring # failed
# torchrun --nproc_per_node=4 check_ulysses_anything_qwen.py --cp-backend ulysses # failed
# torchrun --nproc_per_node=4 check_ulysses_anything_qwen.py --cp-backend unified # failed
# torchrun --nproc_per_node=4 check_ulysses_anything_qwen.py --cp-backend ulysses_anything # success
The results look good to me~
@DN6 Hi~ Could you also take a look at this PR to see if I need to make any updates? |
DN6 left a comment
Changes look good to me, but I think we can organise the modules differently.
- Move PartitionAnythingSharder and AllGatherAnythingFunction to context_parallel.py
- Move TemplatedUlyssesAnythingAttention and all associated helper functions to attention_dispatch.py
- _gather_size_by_comm can probably go into _modeling_parallel.py
| ulysses | failed | failed | failed | 1008x1008 |
| ring | failed | failed | failed | 1008x1008 |
| unified_balanced | failed | failed | failed | 1008x1008 |
Is this from a failed eval? Can it be removed?













Fixes #12706. This PR implements Ulysses Anything attention for diffusers in order to support [ANY] sequence length and [ANY] head count for Ulysses.
@sayakpaul @DN6 @yiyixuxu
About Ulysses Anything Attention
Please refer to our docs for more details: https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention
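To illustrate the core idea, here is a minimal sketch under my own assumptions (not the PR's actual implementation): shard the sequence into chunks that may differ in size by one, and on the gather side first exchange per-rank lengths, then pad to the max length so a regular all_gather can be used.
import torch
import torch.distributed as dist

def shard_uneven(x: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    # torch.tensor_split allows chunk sizes that differ by at most one,
    # so x.size(dim) does not need to be divisible by the world size.
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    return torch.tensor_split(x, world_size, dim=dim)[rank]

def gather_uneven(shard: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    # 1) Exchange each rank's shard length.
    local_len = torch.tensor([shard.size(dim)], device=shard.device)
    lens = [torch.empty_like(local_len) for _ in range(world_size)]
    dist.all_gather(lens, local_len, group=group)
    sizes = [int(n.item()) for n in lens]
    max_len = max(sizes)
    # 2) Pad every shard to the max length so all_gather sees equal shapes.
    pad_shape = list(shard.shape)
    pad_shape[dim] = max_len - shard.size(dim)
    padded = torch.cat([shard, shard.new_zeros(pad_shape)], dim=dim)
    bufs = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(bufs, padded, group=group)
    # 3) Strip the padding and reassemble the full sequence.
    return torch.cat([b.narrow(dim, 0, n) for b, n in zip(bufs, sizes)], dim=dim)
This only shows why divisibility is no longer required; the PR itself also covers the head count and the Ulysses all-to-all path.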
Test
Qwen-Image && Qwen-Image-2512
test cmds:
before this pr:
after this pr: