28 changes: 28 additions & 0 deletions docs/source/en/training/distributed_inference.md
@@ -343,6 +343,34 @@ We ran a benchmark with Ulysses, Ring, and Unified Attention with [this script](

From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.


### Ulysses Anything Attention

The default Ulysses Attention mechanism requires the sequence length of the hidden states to be divisible by the number of devices, which significantly limits its practical application. [Ulysses Anything Attention](https://github.com/huggingface/diffusers/pull/12996) is a variant of Ulysses Attention that supports arbitrary sequence lengths and arbitrary numbers of attention heads, making Ulysses Attention more versatile in practice.

[`ContextParallelConfig`] supports Ulysses Anything Attention through the `ulysses_degree` and `ulysses_anything` options. Note that Ulysses Anything Attention is not currently supported by Unified Attention. Pass a [`ContextParallelConfig`] with `ulysses_degree` greater than 1 and `ulysses_anything=True` to [`~ModelMixin.enable_parallelism`].

```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ulysses_anything=True))
```

> [!TIP]
> To avoid multiple forced CUDA syncs caused by H2D and D2H transfers, add the **gloo** backend in `init_process_group` (for example, `backend="cpu:gloo,cuda:nccl"`). This significantly reduces communication latency.
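As a concrete reference, here is a minimal end-to-end sketch combining the configuration above with a combined gloo/nccl process group. It is only an illustration under a few assumptions: a two-GPU launch (for example, `torchrun --nproc_per_node=2 demo.py`), `ContextParallelConfig` imported from `diffusers` as elsewhere in this guide, and `FluxPipeline` with a placeholder prompt, seed, and output path.

```py
import torch
import torch.distributed as dist

from diffusers import ContextParallelConfig, FluxPipeline

# Register gloo alongside nccl so the small CPU-side shape-gathering
# collectives avoid extra H2D/D2H transfers.
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
rank = dist.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(device)
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=2, ulysses_anything=True)
)

# 1008x1008 produces a sequence length that standard Ulysses cannot split
# evenly across the devices; Ulysses Anything handles the uneven split.
image = pipeline(
    "A photo of a cat",
    height=1008,
    width=1008,
    generator=torch.Generator(device=device).manual_seed(0),
).images[0]

if rank == 0:
    image.save("output.png")
dist.destroy_process_group()
```

The benchmark script linked below is a more complete reference for this setup.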

We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention, and Ulysses Anything Attention using [this script](https://github.com/huggingface/diffusers/pull/12996#issuecomment-3797695999) on a node with 4 L20 GPUs. The results are summarized below:

| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW)|
|--------------------|------------------|-------------|------------------|------------|
| ulysses | 281.07 | 3.56 | 37.11 | 1024x1024 |
| ring | 351.34 | 2.85 | 37.01 | 1024x1024 |
| unified_balanced | 324.37 | 3.08 | 37.16 | 1024x1024 |
| ulysses_anything | 280.94 | 3.56 | 37.11 | 1024x1024 |
| ulysses | failed | failed | failed | 1008x1008 |
| ring | failed | failed | failed | 1008x1008 |
| unified_balanced | failed | failed | failed | 1008x1008 |
| ulysses_anything   | 278.40           | 3.59        | 36.99            | 1008x1008  |

> Review comment on lines +367 to +369 (Collaborator): Is this from a failed eval? Can it be removed?

From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention.

### parallel_config

Pass `parallel_config` during model initialization to enable context parallelism.
14 changes: 13 additions & 1 deletion src/diffusers/hooks/context_parallel.py
@@ -28,6 +28,7 @@
ContextParallelModelPlan,
ContextParallelOutput,
)
from ..models._ulysses_anything_utils import PartitionAnythingSharder
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook
@@ -208,6 +209,10 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) ->
)
return x
else:
if self.parallel_config.ulysses_anything:
return PartitionAnythingSharder.shard_anything(
x, cp_input.split_dim, self.parallel_config._flattened_mesh
)
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)


@@ -233,7 +238,14 @@ def post_forward(self, module, output):
for i, cpm in enumerate(self.metadata):
if cpm is None:
continue
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
if self.parallel_config.ulysses_anything:
output[i] = PartitionAnythingSharder.unshard_anything(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)
else:
output[i] = EquipartitionSharder.unshard(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)

return output[0] if is_tensor else tuple(output)

8 changes: 8 additions & 0 deletions src/diffusers/models/_modeling_parallel.py
@@ -67,6 +67,9 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
# Whether to enable ulysses anything attention to support
# any sequence lengths and any head numbers.
ulysses_anything: bool = False

_rank: int = None
_world_size: int = None
@@ -94,6 +97,11 @@ def __post_init__(self):
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
)
if self.ulysses_anything:
if self.ulysses_degree == 1:
raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.")
if self.ring_degree > 1:
raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.")

@property
def mesh_shape(self) -> Tuple[int, int]:
286 changes: 286 additions & 0 deletions src/diffusers/models/_ulysses_anything_utils.py
@@ -0,0 +1,286 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from: https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/attention/_templated_ulysses.py
import copy
import functools
from typing import Callable, List, Tuple

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as fc
import torch.nn.functional as F

from ..utils.torch_utils import maybe_allow_in_graph


# Helper functions for shape gathering
def _get_rank_world_size(group: dist.ProcessGroup) -> Tuple[int, int]:
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
return rank, world_size


def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]:
r"""Gather the local size from all ranks.
    size: int, local size
    return: List[int], list of sizes from all ranks
"""
# NOTE(Serving/CP Safety):
# Do NOT cache this collective result.
#
# In "Ulysses Anything" mode, `size` (e.g. per-rank local seq_len / S_LOCAL)
# may legitimately differ across ranks. If we cache based on the *local* `size`,
# different ranks can have different cache hit/miss patterns across time.
#
# That can lead to a catastrophic distributed hang:
# - some ranks hit cache and *skip* dist.all_gather()
# - other ranks miss cache and *enter* dist.all_gather()
# This mismatched collective participation will stall the process group and
# eventually trigger NCCL watchdog timeouts (often surfacing later as ALLTOALL
# timeouts in Ulysses attention).
world_size = dist.get_world_size(group=group)
# HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead
comm_backends = str(dist.get_backend(group=group))
# NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl")
gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator()
gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)]
dist.all_gather(
gathered_sizes,
torch.tensor([size], device=gather_device, dtype=torch.int64),
group=group,
)

gathered_sizes = [s[0].item() for s in gathered_sizes]
    # NOTE: Don't use tolist() here, it causes a graph break:
    # the inductor backend compiler fails on aten._local_scalar_dense.default.
return gathered_sizes


# Helper functions to pad/unpad head dimension for QKV and O projections
def _maybe_pad_qkv_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]:
r"""Maybe pad the head dimension to be divisible by world_size.
    x: torch.Tensor, shape (B, S_LOCAL, H, D)
    H: int, original global head num
    return: Tuple[torch.Tensor, int], padded tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD
"""
_, world_size = _get_rank_world_size(group)
H_PAD = 0
if H % world_size != 0:
H_PAD = world_size - (H % world_size)
NEW_H_LOCAL = (H + H_PAD) // world_size
# e.g., Allow: H=30, world_size=8 -> NEW_H_LOCAL=4, H_PAD=2.
# NOT ALLOW: H=30, world_size=16 -> NEW_H_LOCAL=2, H_PAD=14.
assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
return x, H_PAD


def _maybe_unpad_qkv_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
r"""Maybe unpad the head dimension.
    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D)
    H_PAD: int, head padding num
    return: torch.Tensor, unpadded tensor (B, S_GLOBAL, H_LOCAL, D)
"""
rank, world_size = _get_rank_world_size(group)
# Only the last rank may have padding
if H_PAD > 0 and rank == world_size - 1:
x = x[:, :, :-H_PAD, :]
return x.contiguous()


def _maybe_pad_o_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]:
r"""Maybe pad the head dimension to be divisible by world_size.
    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D)
    H: int, original global head num
    return: Tuple[torch.Tensor, int], padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD
"""
if H is None:
return x, 0

rank, world_size = _get_rank_world_size(group)
H_PAD = 0
# Only the last rank may need padding
if H % world_size != 0:
# We need to broadcast H_PAD to all ranks to keep consistency
# in unpadding step later for all ranks.
H_PAD = world_size - (H % world_size)
NEW_H_LOCAL = (H + H_PAD) // world_size
assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
if rank == world_size - 1:
x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
return x, H_PAD


def _maybe_unpad_o_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
r"""Maybe unpad the head dimension.
    x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D)
    H_PAD: int, head padding num
    return: torch.Tensor, unpadded tensor (B, S_LOCAL, H_GLOBAL, D)
"""
if H_PAD > 0:
x = x[:, :, :-H_PAD, :]
return x.contiguous()


# Helper functions for all-to-all communication with Ulysses Anything Attention
def _wait_tensor(tensor) -> torch.Tensor:
if isinstance(tensor, fc.AsyncCollectiveTensor):
tensor = tensor.wait()

return tensor


def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict:
# query: (B, S_LOCAL, H_GLOBAL, D)
assert len(query.shape) == 4, "Query tensor must be 4-dimensional of shape (B, S_LOCAL, H_GLOBAL, D)"
extra_kwargs = {}
extra_kwargs["NUM_QO_HEAD"] = query.shape[2]
extra_kwargs["Q_S_LOCAL"] = query.shape[1]
# Add other kwargs if needed in future
return extra_kwargs


@maybe_allow_in_graph
def all_to_all_single_any_qkv_async(
x: torch.Tensor, group: dist.ProcessGroup, **kwargs
) -> Callable[..., torch.Tensor]:
r"""
    x: torch.Tensor, shape (B, S_LOCAL, H, D)
    return: Callable that returns (B, S_GLOBAL, H_LOCAL, D)
"""
_, world_size = _get_rank_world_size(group)
B, S_LOCAL, H, D = x.shape
x, H_PAD = _maybe_pad_qkv_head(x, H, group)
H_LOCAL = (H + H_PAD) // world_size
# (world_size, S_LOCAL, B, H_LOCAL, D)
x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()

input_split_sizes = [S_LOCAL] * world_size
    # S_LOCAL may differ across ranks in the dynamic shape case. Since we
    # don't know the other ranks' shapes at this point, we have to
    # all-gather S_LOCAL from all ranks first.
output_split_sizes = _gather_size_by_comm(S_LOCAL, group)
x = x.flatten(0, 1) # (world_size * S_LOCAL, B, H_LOCAL, D)
x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

def wait() -> torch.Tensor:
nonlocal x, H_PAD
x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D)
# (S_GLOBAL, B, H_LOCAL, D)
# -> (B, S_GLOBAL, H_LOCAL, D)
x = x.permute(1, 0, 2, 3).contiguous()
x = _maybe_unpad_qkv_head(x, H_PAD, group)
return x

return wait


@maybe_allow_in_graph
def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **kwargs) -> Callable[..., torch.Tensor]:
r"""
    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D)
    return: Callable that returns (B, S_LOCAL, H_GLOBAL, D)
"""
# Assume H is provided in kwargs, since we can't infer H from x's shape.
# The padding logic needs H to determine if padding is necessary.
H = kwargs.get("NUM_QO_HEAD", None)
rank, world_size = _get_rank_world_size(group)
x, H_PAD = _maybe_pad_o_head(x, H, group)
shape = x.shape # (B, S_GLOBAL, H_LOCAL, D)
(B, S_GLOBAL, H_LOCAL, D) = shape

    # input_split: e.g., S_GLOBAL=9 -> input splits across ranks [[5, 4], [5, 4], ...]
    # output_split: e.g., S_GLOBAL=9 -> output splits across ranks [[5, 5], [4, 4], ...]

    # WARN: In some cases, e.g., joint attention in Qwen-Image, S_LOCAL cannot be
    # inferred from a tensor split, because if c = torch.cat((a, b)) and world_size=4,
    # then c.tensor_split(4)[0].shape[1] may not equal
    # a.tensor_split(4)[0].shape[1] + b.tensor_split(4)[0].shape[1].

S_LOCAL = kwargs.get("Q_S_LOCAL")
input_split_sizes = _gather_size_by_comm(S_LOCAL, group)
x = x.permute(1, 0, 2, 3).contiguous() # (S_GLOBAL, B, H_LOCAL, D)
output_split_sizes = [S_LOCAL] * world_size
x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

def wait() -> torch.Tensor:
nonlocal x, H_PAD
x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D)
x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D)
x = x.permute(2, 1, 0, 3, 4).contiguous()
x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D)
x = _maybe_unpad_o_head(x, H_PAD, group)
return x

return wait


@functools.lru_cache(maxsize=64)
def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int) -> List[List[int]]:
gather_shapes = []
for i in range(world_size):
rank_shape = list(copy.deepcopy(shape))
rank_shape[dim] = gather_dims[i]
gather_shapes.append(rank_shape)
return gather_shapes


@maybe_allow_in_graph
def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.ProcessGroup) -> torch.Tensor:
_, world_size = _get_rank_world_size(group)
tensor = tensor.contiguous()
shape = tensor.shape
rank_dim = shape[dim]
gather_dims = _gather_size_by_comm(rank_dim, group)

gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size)

gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes]

dist.all_gather(gathered_tensors, tensor, group=group)
gathered_tensor = torch.cat(gathered_tensors, dim=dim)
return gathered_tensor


class AllGatherAnythingFunction(torch.autograd.Function):
@staticmethod
    def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.ProcessGroup):
ctx.dim = dim
ctx.group = group
ctx.world_size = dist.get_world_size(group)
ctx.rank = dist.get_rank(group)
gathered_tensor = _all_gather_anything(tensor, dim, group)
return gathered_tensor

@staticmethod
def backward(ctx, grad_output):
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
# function may return fewer than the specified number of chunks!
grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim)
return grad_splits[ctx.rank], None, None


class PartitionAnythingSharder:
@classmethod
def shard_anything(
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
) -> torch.Tensor:
assert tensor.size()[dim] >= mesh.size(), (
f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}."
)
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
# function may return fewer than the specified number of chunks!
return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())]

@classmethod
def unshard_anything(
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
) -> torch.Tensor:
tensor = tensor.contiguous()
tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group())
return tensor