From dd5642917e0306cfc5b6764b2eabba2c3c36ce77 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 19 Jan 2026 08:12:38 +0000 Subject: [PATCH 01/11] feat: support Ulysses Anything Attention --- src/diffusers/hooks/context_parallel.py | 8 + src/diffusers/models/_ulysses_anything.py | 403 +++++++++++++++++++++ src/diffusers/models/attention_dispatch.py | 48 ++- src/diffusers/utils/constants.py | 1 + 4 files changed, 445 insertions(+), 15 deletions(-) create mode 100644 src/diffusers/models/_ulysses_anything.py diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py index 6491d17b4f46..965879f3180b 100644 --- a/src/diffusers/hooks/context_parallel.py +++ b/src/diffusers/hooks/context_parallel.py @@ -28,7 +28,9 @@ ContextParallelModelPlan, ContextParallelOutput, ) +from ..models._ulysses_anything import PartitionAnythingSharder from ..utils import get_logger +from ..utils.constants import DIFFUSERS_ULYSSES_ANYTHING from ..utils.torch_utils import unwrap_module from .hooks import HookRegistry, ModelHook @@ -256,6 +258,9 @@ def backward(ctx, grad_output): class EquipartitionSharder: @classmethod def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: + if DIFFUSERS_ULYSSES_ANYTHING: + return PartitionAnythingSharder.shard_anything(tensor, dim, mesh) + # NOTE: the following assertion does not have to be true in general. We simply enforce it for now # because the alternate case has not yet been tested/required for any model. assert tensor.size()[dim] % mesh.size() == 0, ( @@ -269,6 +274,9 @@ def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_me @classmethod def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: + if DIFFUSERS_ULYSSES_ANYTHING: + return PartitionAnythingSharder.unshard_anything(tensor, dim, mesh) + tensor = tensor.contiguous() tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group()) return tensor diff --git a/src/diffusers/models/_ulysses_anything.py b/src/diffusers/models/_ulysses_anything.py new file mode 100644 index 000000000000..2aaf573b8166 --- /dev/null +++ b/src/diffusers/models/_ulysses_anything.py @@ -0,0 +1,403 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Adapted from: https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/attention/_templated_ulysses.py +import copy +import functools +from typing import Callable, List, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.distributed._functional_collectives as fc +import torch.nn.functional as F + +from diffusers.models._modeling_parallel import ParallelConfig + + +def _wait_tensor(tensor) -> torch.Tensor: + if isinstance(tensor, fc.AsyncCollectiveTensor): + tensor = tensor.wait() + + return tensor + + +def _get_rank_world_size( + group: dist.ProcessGroup, +) -> Tuple[int, int]: + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + return rank, world_size + + +@functools.lru_cache(maxsize=128) +def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]: + r"""Gather the local size from all ranks. + size: int, local size return: List[int], list of size from all ranks + """ + world_size = dist.get_world_size(group=group) + # HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead + comm_backends = str(dist.get_backend(group=group)) + # NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl") + gather_device = "cpu" if "cpu" in comm_backends else "cuda" + gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)] + dist.all_gather( + gathered_sizes, + torch.tensor([size], device=gather_device, dtype=torch.int64), + group=group, + ) + + gathered_sizes = [s[0].item() for s in gathered_sizes] + # NOTE: DON'T use tolist here due to graph break - Explanation: + # Backend compiler `inductor` failed with aten._local_scalar_dense.default + return gathered_sizes + + +# Helper functions to pad/unpad head dimension for QKV and O projections +def _maybe_pad_qkv_head( + x: torch.Tensor, + H: int, + group: dist.ProcessGroup, +) -> Tuple[torch.Tensor, int]: + r"""Maybe pad the head dimension to be divisible by world_size. + x: torch.Tensor, shape (B, S_LOCAL, H, D) H: int, original global head num return: Tuple[torch.Tensor, int], padded + tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD + """ + _, world_size = _get_rank_world_size(group) + H_PAD = 0 + if H % world_size != 0: + H_PAD = world_size - (H % world_size) + NEW_H_LOCAL = (H + H_PAD) // world_size + # e.g., Allow: H=30, world_size=8 -> NEW_H_LOCAL=4, H_PAD=2. + # NOT ALLOW: H=30, world_size=16 -> NEW_H_LOCAL=2, H_PAD=14. + assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}" + x = F.pad(x, (0, 0, 0, H_PAD)).contiguous() + return x, H_PAD + + +def _maybe_unpad_qkv_head( + x: torch.Tensor, + H_PAD: int, + group: dist.ProcessGroup, +) -> torch.Tensor: + r"""Maybe unpad the head dimension. + x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, + unpadded tensor (B, S_GLOBAL, H_LOCAL, D) + """ + rank, world_size = _get_rank_world_size(group) + # Only the last rank may have padding + if H_PAD > 0 and rank == world_size - 1: + x = x[:, :, :-H_PAD, :] + return x.contiguous() + + +def _maybe_pad_o_head( + x: torch.Tensor, + H: int, + group: dist.ProcessGroup, +) -> Tuple[torch.Tensor, int]: + r"""Maybe pad the head dimension to be divisible by world_size. 
+ x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) H: int, original global head num return: Tuple[torch.Tensor, int], + padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD + """ + if H is None: + return x, 0 + + rank, world_size = _get_rank_world_size(group) + H_PAD = 0 + # Only the last rank may need padding + if H % world_size != 0: + # We need to broadcast H_PAD to all ranks to keep consistency + # in unpadding step later for all ranks. + H_PAD = world_size - (H % world_size) + NEW_H_LOCAL = (H + H_PAD) // world_size + assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}" + if rank == world_size - 1: + x = F.pad(x, (0, 0, 0, H_PAD)).contiguous() + return x, H_PAD + + +def _maybe_unpad_o_head( + x: torch.Tensor, + H_PAD: int, + group: dist.ProcessGroup, +) -> torch.Tensor: + r"""Maybe unpad the head dimension. + x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, + unpadded tensor (B, S_LOCAL, H_GLOBAL, D) + """ + if H_PAD > 0: + x = x[:, :, :-H_PAD, :] + return x.contiguous() + + +# Helper functions to for all-to-all communication with Ulysses Anything Attention +def _comm_metadata( + query: torch.Tensor, + **kwargs, +) -> dict: + num_qo_head = query.shape[2] # (B, S_LOCAL, H_GLOBAL, D) + extra_kwargs = {} + extra_kwargs["num_qo_head"] = num_qo_head + # May ddd other kwargs if needed in future + return extra_kwargs + + +@torch.compiler.allow_in_graph +def _all_to_all_single_any_qkv_async( + x: torch.Tensor, + group: dist.ProcessGroup, + **kwargs, +) -> Callable[..., torch.Tensor]: + r""" + x: torch.Tensor, shape (B, S_LOCAL, H, D) return: Callable that returns (B, S_GLOBAL, H_LOCAL, D) + """ + _, world_size = _get_rank_world_size(group) + B, S_LOCAL, H, D = x.shape + x, H_PAD = _maybe_pad_qkv_head(x, H, group) + H_LOCAL = (H + H_PAD) // world_size + # (world_size, S_LOCAL, B, H_LOCAL, D) + x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + + input_split_sizes = [S_LOCAL] * world_size + # S_LOCAL maybe not equal for all ranks in dynamic shape case, + # since we don't know the actual shape before this timing, thus, + # we have to use all gather to collect the S_LOCAL first. + output_split_sizes = _gather_size_by_comm(S_LOCAL, group) + x = x.flatten(0, 1) # (world_size * S_LOCAL, B, H_LOCAL, D) + x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group) + + def wait() -> torch.Tensor: + nonlocal x, H_PAD + x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D) + # (S_GLOBAL, B, H_LOCAL, D) + # -> (B, S_GLOBAL, H_LOCAL, D) + x = x.permute(1, 0, 2, 3).contiguous() + x = _maybe_unpad_qkv_head(x, H_PAD, group) + return x + + return wait + + +@torch.compiler.allow_in_graph +def _all_to_all_single_any_o_async( + x: torch.Tensor, + group: dist.ProcessGroup, + **kwargs, +) -> Callable[..., torch.Tensor]: + r""" + x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) return: Callable that returns (B, S_LOCAL, H_GLOBAL, D) + """ + # Assume H is provided in kwargs, since we can't infer H from x's shape. + # The padding logic needs H to determine if padding is necessary. + H = kwargs.get("num_qo_head", None) + rank, world_size = _get_rank_world_size(group) + x, H_PAD = _maybe_pad_o_head(x, H, group) + shape = x.shape # (B, S_GLOBAL, H_LOCAL, D) + (B, S_GLOBAL, H_LOCAL, D) = shape + # NOTE: We use tensor_split here to ensure the same split policy + # that we have used in the EquipartitionSharder sharding strategy. 
Please + # note that the 'tensor_split' splits a tensor into multiple sub-tensors, + # all of which are views of input, thus may not introduce extra IO access. + input_split_sizes = [o.size(1) for o in torch.tensor_split(x, world_size, dim=1)] + # input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..] + # output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..] + S_LOCAL = input_split_sizes[rank] + x = x.permute(1, 0, 2, 3).contiguous() # (S_GLOBAL, B, H_LOCAL, D) + output_split_sizes = [S_LOCAL] * world_size + x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group) + + def wait() -> torch.Tensor: + nonlocal x, H_PAD + x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D) + x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D) + x = x.permute(2, 1, 0, 3, 4).contiguous() + x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D) + x = _maybe_unpad_o_head(x, H_PAD, group) + return x + + return wait + + +class TemplatedUlyssesAnythingAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + **kwargs, + ): + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + group = ulysses_mesh.get_group() + + ctx.forward_op = forward_op + ctx.backward_op = backward_op + ctx._parallel_config = _parallel_config + + metadata = _comm_metadata(query) + query_wait = _all_to_all_single_any_qkv_async(query, group, **metadata) + key_wait = _all_to_all_single_any_qkv_async(key, group, **metadata) + value_wait = _all_to_all_single_any_qkv_async(value, group, **metadata) + + query = query_wait() # type: torch.Tensor + key = key_wait() # type: torch.Tensor + value = value_wait() # type: torch.Tensor + + out = forward_op( + ctx, + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + _save_ctx=False, + _parallel_config=_parallel_config, + ) + if return_lse: + out, lse, *_ = out + + # out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D) + out_wait = _all_to_all_single_any_o_async(out, group, **metadata) + + if return_lse: + # lse: (B, S_Q_GLOBAL, H_LOCAL) + lse = lse.unsqueeze(-1) # (B, S_Q_GLOBAL, H_LOCAL, D=1) + lse_wait = _all_to_all_single_any_o_async(lse, group, **metadata) + out = out_wait() # type: torch.Tensor + lse = lse_wait() # type: torch.Tensor + lse = lse.squeeze(-1).contiguous() # (B, S_Q_LOCAL, H_GLOBAL) + else: + out = out_wait() # type: torch.Tensor + lse = None + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + ): + raise NotImplementedError("Backward pass for Ulysses Anything Attention in diffusers is not implemented yet.") + + +@functools.lru_cache(maxsize=64) +def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int) -> List[List[int]]: + gather_shapes = [] + for i in range(world_size): + rank_shape = list(copy.deepcopy(shape)) + rank_shape[dim] = gather_dims[i] + gather_shapes.append(rank_shape) + return gather_shapes + + +@torch.compiler.allow_in_graph +def _all_gather_anything( # noqa: F811 + tensor: torch.Tensor, + dim: int, + group: dist.device_mesh.DeviceMesh, +) -> torch.Tensor: + _, world_size 
= _get_rank_world_size(group) + tensor = tensor.contiguous() + shape = tensor.shape + rank_dim = shape[dim] + gather_dims = _gather_size_by_comm(rank_dim, group) + + gather_shapes = _fill_gather_shapes( + tuple(shape), + tuple(gather_dims), + dim, + world_size, + ) + + gathered_tensors = [ + torch.empty( + shape, + device=tensor.device, + dtype=tensor.dtype, + ) + for shape in gather_shapes + ] + + dist.all_gather(gathered_tensors, tensor, group=group) + gathered_tensor = torch.cat(gathered_tensors, dim=dim) + return gathered_tensor + + +class AllGatherAnythingFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + dim: int, + group: dist.device_mesh.DeviceMesh, + ): + ctx.dim = dim + ctx.group = group + ctx.world_size = dist.get_world_size(group) + ctx.rank = dist.get_rank(group) + gathered_tensor = _all_gather_anything(tensor, dim, group) + return gathered_tensor + + @staticmethod + def backward(ctx, grad_output): + # NOTE: We use `tensor_split` instead of chunk, because the `chunk` + # function may return fewer than the specified number of chunks! + grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim) + return grad_splits[ctx.rank], None, None + + +class PartitionAnythingSharder: + @classmethod + def shard_anything( + cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh + ) -> torch.Tensor: + assert tensor.size()[dim] >= mesh.size(), ( + f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}." + ) + # NOTE: We use `tensor_split` instead of chunk, because the `chunk` + # function may return fewer than the specified number of chunks! For example, + # x = torch.tensor([1,2,3,4,5]), torch.chunk(x, 4) will return only 3 chunks: + # (tensor([1, 2]), tensor([3, 4]), tensor([5])). This behavior can lead to + # inconsistencies when sharding tensors across multiple devices. In contrast, + # tensor_split will always return the specified number of chunks, the last chunk + # may be smaller if the tensor size is not divisible by the number of chunks. + # For example, torch.tensor_split(x, 4) will return 4 chunks: + # (tensor([1, 2]), tensor([3]), tensor([4]), tensor([5])). + return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())] + + @classmethod + def unshard_anything( + cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh + ) -> torch.Tensor: + tensor = tensor.contiguous() + # NOTE: We use AllGatherAnythingFunction to support gathering + # tensors with complex and uneven sizes across all ranks. It handles the + # case where the tensor size (the seq_len of hidden_states) along the + # specified dimension is not divisible by the number of ranks in the mesh. 
+ tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group()) + return tensor diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 61c478b03c4f..b76283eb1cfc 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -43,7 +43,8 @@ is_xformers_available, is_xformers_version, ) -from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS +from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ULYSSES_ANYTHING +from ._ulysses_anything import TemplatedUlyssesAnythingAttention if TYPE_CHECKING: @@ -1618,20 +1619,37 @@ def _templated_context_parallel_attention( _parallel_config, ) elif _parallel_config.context_parallel_config.ulysses_degree > 1: - return TemplatedUlyssesAttention.apply( - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - enable_gqa, - return_lse, - forward_op, - backward_op, - _parallel_config, - ) + if DIFFUSERS_ULYSSES_ANYTHING: + # For Any sequence lengths and Any head num support + return TemplatedUlyssesAnythingAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + else: + return TemplatedUlyssesAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) else: raise ValueError("Reaching this branch of code is unexpected. Please report a bug.") diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index c46fa4363483..b9407d8945dd 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -46,6 +46,7 @@ DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES +DIFFUSERS_ULYSSES_ANYTHING = os.getenv("DIFFUSERS_ULYSSES_ANYTHING", "0").upper() in ENV_VARS_TRUE_VALUES # Below should be `True` if the current version of `peft` and `transformers` are compatible with # PEFT backend. 
Will automatically fall back to PEFT backend if the correct versions of the libraries are From 123f5264bbdd2163cf224194365e0d931f1a325c Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 19 Jan 2026 09:16:33 +0000 Subject: [PATCH 02/11] feat: support Ulysses Anything Attention --- src/diffusers/models/_ulysses_anything.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/_ulysses_anything.py b/src/diffusers/models/_ulysses_anything.py index 2aaf573b8166..3517c3ede4d4 100644 --- a/src/diffusers/models/_ulysses_anything.py +++ b/src/diffusers/models/_ulysses_anything.py @@ -22,8 +22,8 @@ import torch.distributed._functional_collectives as fc import torch.nn.functional as F -from diffusers.models._modeling_parallel import ParallelConfig - +from ..utils.torch_utils import maybe_allow_in_graph +from ._modeling_parallel import ParallelConfig def _wait_tensor(tensor) -> torch.Tensor: if isinstance(tensor, fc.AsyncCollectiveTensor): @@ -153,7 +153,7 @@ def _comm_metadata( return extra_kwargs -@torch.compiler.allow_in_graph +@maybe_allow_in_graph def _all_to_all_single_any_qkv_async( x: torch.Tensor, group: dist.ProcessGroup, @@ -189,7 +189,7 @@ def wait() -> torch.Tensor: return wait -@torch.compiler.allow_in_graph +@maybe_allow_in_graph def _all_to_all_single_any_o_async( x: torch.Tensor, group: dist.ProcessGroup, @@ -315,7 +315,7 @@ def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, wo return gather_shapes -@torch.compiler.allow_in_graph +@maybe_allow_in_graph def _all_gather_anything( # noqa: F811 tensor: torch.Tensor, dim: int, From af9af62d2d1e64eda8df5b7b95687f3840f46932 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 19 Jan 2026 09:17:11 +0000 Subject: [PATCH 03/11] feat: support Ulysses Anything Attention --- src/diffusers/models/_ulysses_anything.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/_ulysses_anything.py b/src/diffusers/models/_ulysses_anything.py index 3517c3ede4d4..a5594ec8cbd5 100644 --- a/src/diffusers/models/_ulysses_anything.py +++ b/src/diffusers/models/_ulysses_anything.py @@ -25,6 +25,7 @@ from ..utils.torch_utils import maybe_allow_in_graph from ._modeling_parallel import ParallelConfig + def _wait_tensor(tensor) -> torch.Tensor: if isinstance(tensor, fc.AsyncCollectiveTensor): tensor = tensor.wait() From b4d3f077f47325ccd949fbae68fa8cfdf0c5a446 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Tue, 20 Jan 2026 02:24:39 +0000 Subject: [PATCH 04/11] feat: support Ulysses Anything Attention --- src/diffusers/hooks/context_parallel.py | 22 ++- src/diffusers/models/_modeling_parallel.py | 3 + ...anything.py => _ulysses_anything_utils.py} | 187 +++--------------- src/diffusers/models/attention_dispatch.py | 86 +++++++- src/diffusers/utils/constants.py | 1 - 5 files changed, 124 insertions(+), 175 deletions(-) rename src/diffusers/models/{_ulysses_anything.py => _ulysses_anything_utils.py} (65%) diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py index 965879f3180b..53e2b53d986e 100644 --- a/src/diffusers/hooks/context_parallel.py +++ b/src/diffusers/hooks/context_parallel.py @@ -28,9 +28,8 @@ ContextParallelModelPlan, ContextParallelOutput, ) -from ..models._ulysses_anything import PartitionAnythingSharder +from ..models._ulysses_anything_utils import PartitionAnythingSharder from ..utils import get_logger -from ..utils.constants import DIFFUSERS_ULYSSES_ANYTHING from ..utils.torch_utils import unwrap_module from .hooks import 
HookRegistry, ModelHook @@ -210,6 +209,10 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> ) return x else: + if self.parallel_config.ulysses_anything: + return PartitionAnythingSharder.shard_anything( + x, cp_input.split_dim, self.parallel_config._flattened_mesh + ) return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh) @@ -235,7 +238,14 @@ def post_forward(self, module, output): for i, cpm in enumerate(self.metadata): if cpm is None: continue - output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh) + if self.parallel_config.ulysses_anything: + output[i] = PartitionAnythingSharder.unshard_anything( + output[i], cpm.gather_dim, self.parallel_config._flattened_mesh + ) + else: + output[i] = EquipartitionSharder.unshard( + output[i], cpm.gather_dim, self.parallel_config._flattened_mesh + ) return output[0] if is_tensor else tuple(output) @@ -258,9 +268,6 @@ def backward(ctx, grad_output): class EquipartitionSharder: @classmethod def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: - if DIFFUSERS_ULYSSES_ANYTHING: - return PartitionAnythingSharder.shard_anything(tensor, dim, mesh) - # NOTE: the following assertion does not have to be true in general. We simply enforce it for now # because the alternate case has not yet been tested/required for any model. assert tensor.size()[dim] % mesh.size() == 0, ( @@ -274,9 +281,6 @@ def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_me @classmethod def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: - if DIFFUSERS_ULYSSES_ANYTHING: - return PartitionAnythingSharder.unshard_anything(tensor, dim, mesh) - tensor = tensor.contiguous() tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group()) return tensor diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 1c7703a13c52..f301ba771cc7 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -67,6 +67,9 @@ class ContextParallelConfig: convert_to_fp32: bool = True # TODO: support alltoall rotate_method: Literal["allgather", "alltoall"] = "allgather" + # Whether to enable ulysses anything attention to support + # any sequence lengths and any head numbers. 
+ ulysses_anything: bool = False _rank: int = None _world_size: int = None diff --git a/src/diffusers/models/_ulysses_anything.py b/src/diffusers/models/_ulysses_anything_utils.py similarity index 65% rename from src/diffusers/models/_ulysses_anything.py rename to src/diffusers/models/_ulysses_anything_utils.py index a5594ec8cbd5..fd85f374d0f5 100644 --- a/src/diffusers/models/_ulysses_anything.py +++ b/src/diffusers/models/_ulysses_anything_utils.py @@ -15,7 +15,7 @@ # Adapted from: https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/attention/_templated_ulysses.py import copy import functools -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Tuple import torch import torch.distributed as dist @@ -23,19 +23,10 @@ import torch.nn.functional as F from ..utils.torch_utils import maybe_allow_in_graph -from ._modeling_parallel import ParallelConfig -def _wait_tensor(tensor) -> torch.Tensor: - if isinstance(tensor, fc.AsyncCollectiveTensor): - tensor = tensor.wait() - - return tensor - - -def _get_rank_world_size( - group: dist.ProcessGroup, -) -> Tuple[int, int]: +# Helper functions for shape gathering +def _get_rank_world_size(group: dist.ProcessGroup) -> Tuple[int, int]: world_size = dist.get_world_size(group=group) rank = dist.get_rank(group=group) return rank, world_size @@ -50,7 +41,7 @@ def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]: # HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead comm_backends = str(dist.get_backend(group=group)) # NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl") - gather_device = "cpu" if "cpu" in comm_backends else "cuda" + gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator() gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)] dist.all_gather( gathered_sizes, @@ -65,11 +56,7 @@ def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]: # Helper functions to pad/unpad head dimension for QKV and O projections -def _maybe_pad_qkv_head( - x: torch.Tensor, - H: int, - group: dist.ProcessGroup, -) -> Tuple[torch.Tensor, int]: +def _maybe_pad_qkv_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]: r"""Maybe pad the head dimension to be divisible by world_size. x: torch.Tensor, shape (B, S_LOCAL, H, D) H: int, original global head num return: Tuple[torch.Tensor, int], padded tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD @@ -86,11 +73,7 @@ def _maybe_pad_qkv_head( return x, H_PAD -def _maybe_unpad_qkv_head( - x: torch.Tensor, - H_PAD: int, - group: dist.ProcessGroup, -) -> torch.Tensor: +def _maybe_unpad_qkv_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor: r"""Maybe unpad the head dimension. x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, unpadded tensor (B, S_GLOBAL, H_LOCAL, D) @@ -102,11 +85,7 @@ def _maybe_unpad_qkv_head( return x.contiguous() -def _maybe_pad_o_head( - x: torch.Tensor, - H: int, - group: dist.ProcessGroup, -) -> Tuple[torch.Tensor, int]: +def _maybe_pad_o_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]: r"""Maybe pad the head dimension to be divisible by world_size. 
x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) H: int, original global head num return: Tuple[torch.Tensor, int], padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD @@ -128,11 +107,7 @@ def _maybe_pad_o_head( return x, H_PAD -def _maybe_unpad_o_head( - x: torch.Tensor, - H_PAD: int, - group: dist.ProcessGroup, -) -> torch.Tensor: +def _maybe_unpad_o_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor: r"""Maybe unpad the head dimension. x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, unpadded tensor (B, S_LOCAL, H_GLOBAL, D) @@ -143,10 +118,14 @@ def _maybe_unpad_o_head( # Helper functions to for all-to-all communication with Ulysses Anything Attention -def _comm_metadata( - query: torch.Tensor, - **kwargs, -) -> dict: +def _wait_tensor(tensor) -> torch.Tensor: + if isinstance(tensor, fc.AsyncCollectiveTensor): + tensor = tensor.wait() + + return tensor + + +def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict: num_qo_head = query.shape[2] # (B, S_LOCAL, H_GLOBAL, D) extra_kwargs = {} extra_kwargs["num_qo_head"] = num_qo_head @@ -155,10 +134,8 @@ def _comm_metadata( @maybe_allow_in_graph -def _all_to_all_single_any_qkv_async( - x: torch.Tensor, - group: dist.ProcessGroup, - **kwargs, +def all_to_all_single_any_qkv_async( + x: torch.Tensor, group: dist.ProcessGroup, **kwargs ) -> Callable[..., torch.Tensor]: r""" x: torch.Tensor, shape (B, S_LOCAL, H, D) return: Callable that returns (B, S_GLOBAL, H_LOCAL, D) @@ -191,11 +168,7 @@ def wait() -> torch.Tensor: @maybe_allow_in_graph -def _all_to_all_single_any_o_async( - x: torch.Tensor, - group: dist.ProcessGroup, - **kwargs, -) -> Callable[..., torch.Tensor]: +def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **kwargs) -> Callable[..., torch.Tensor]: r""" x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) return: Callable that returns (B, S_LOCAL, H_GLOBAL, D) """ @@ -207,9 +180,7 @@ def _all_to_all_single_any_o_async( shape = x.shape # (B, S_GLOBAL, H_LOCAL, D) (B, S_GLOBAL, H_LOCAL, D) = shape # NOTE: We use tensor_split here to ensure the same split policy - # that we have used in the EquipartitionSharder sharding strategy. Please - # note that the 'tensor_split' splits a tensor into multiple sub-tensors, - # all of which are views of input, thus may not introduce extra IO access. + # that we have used in the EquipartitionSharder sharding strategy. input_split_sizes = [o.size(1) for o in torch.tensor_split(x, world_size, dim=1)] # input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..] # output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..] 
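For readers unfamiliar with the split policy referenced in the comments above, here is a small standalone illustration of why the sharder relies on `torch.tensor_split` rather than `torch.chunk` for uneven sizes. It echoes the `[1, 2, 3, 4, 5]` example from the comment trimmed in this patch; the shapes are arbitrary and only meant to mirror the `S_GLOBAL=9` case mentioned above:

```py
import torch

x = torch.tensor([1, 2, 3, 4, 5])

# torch.chunk may return FEWER chunks than requested when the size is uneven:
print([t.tolist() for t in torch.chunk(x, 4)])
# [[1, 2], [3, 4], [5]]  -> only 3 chunks for 4 ranks

# torch.tensor_split always returns exactly the requested number of chunks,
# with trailing chunks one element shorter when the size is not divisible:
print([t.tolist() for t in torch.tensor_split(x, 4)])
# [[1, 2], [3], [4], [5]]  -> one (possibly smaller) shard per rank

# The same policy applied to a sequence dimension, e.g. S_GLOBAL=9 on 2 ranks:
seq = torch.randn(1, 9, 24, 128)  # (B, S_GLOBAL, H, D)
print([t.size(1) for t in torch.tensor_split(seq, 2, dim=1)])
# [5, 4] -> matches the input_split example in the comment above
```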
@@ -230,82 +201,6 @@ def wait() -> torch.Tensor: return wait -class TemplatedUlyssesAnythingAttention(torch.autograd.Function): - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor], - dropout_p: float, - is_causal: bool, - scale: Optional[float], - enable_gqa: bool, - return_lse: bool, - forward_op, - backward_op, - _parallel_config: Optional["ParallelConfig"] = None, - **kwargs, - ): - ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh - group = ulysses_mesh.get_group() - - ctx.forward_op = forward_op - ctx.backward_op = backward_op - ctx._parallel_config = _parallel_config - - metadata = _comm_metadata(query) - query_wait = _all_to_all_single_any_qkv_async(query, group, **metadata) - key_wait = _all_to_all_single_any_qkv_async(key, group, **metadata) - value_wait = _all_to_all_single_any_qkv_async(value, group, **metadata) - - query = query_wait() # type: torch.Tensor - key = key_wait() # type: torch.Tensor - value = value_wait() # type: torch.Tensor - - out = forward_op( - ctx, - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - enable_gqa, - return_lse, - _save_ctx=False, - _parallel_config=_parallel_config, - ) - if return_lse: - out, lse, *_ = out - - # out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D) - out_wait = _all_to_all_single_any_o_async(out, group, **metadata) - - if return_lse: - # lse: (B, S_Q_GLOBAL, H_LOCAL) - lse = lse.unsqueeze(-1) # (B, S_Q_GLOBAL, H_LOCAL, D=1) - lse_wait = _all_to_all_single_any_o_async(lse, group, **metadata) - out = out_wait() # type: torch.Tensor - lse = lse_wait() # type: torch.Tensor - lse = lse.squeeze(-1).contiguous() # (B, S_Q_LOCAL, H_GLOBAL) - else: - out = out_wait() # type: torch.Tensor - lse = None - - return (out, lse) if return_lse else out - - @staticmethod - def backward( - ctx: torch.autograd.function.FunctionCtx, - grad_out: torch.Tensor, - *args, - ): - raise NotImplementedError("Backward pass for Ulysses Anything Attention in diffusers is not implemented yet.") - - @functools.lru_cache(maxsize=64) def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int) -> List[List[int]]: gather_shapes = [] @@ -317,32 +212,16 @@ def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, wo @maybe_allow_in_graph -def _all_gather_anything( # noqa: F811 - tensor: torch.Tensor, - dim: int, - group: dist.device_mesh.DeviceMesh, -) -> torch.Tensor: +def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh) -> torch.Tensor: _, world_size = _get_rank_world_size(group) tensor = tensor.contiguous() shape = tensor.shape rank_dim = shape[dim] gather_dims = _gather_size_by_comm(rank_dim, group) - gather_shapes = _fill_gather_shapes( - tuple(shape), - tuple(gather_dims), - dim, - world_size, - ) + gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size) - gathered_tensors = [ - torch.empty( - shape, - device=tensor.device, - dtype=tensor.dtype, - ) - for shape in gather_shapes - ] + gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes] dist.all_gather(gathered_tensors, tensor, group=group) gathered_tensor = torch.cat(gathered_tensors, dim=dim) @@ -351,12 +230,7 @@ def _all_gather_anything( # noqa: F811 class AllGatherAnythingFunction(torch.autograd.Function): @staticmethod - def forward( - ctx, - tensor: 
torch.Tensor, - dim: int, - group: dist.device_mesh.DeviceMesh, - ): + def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh): ctx.dim = dim ctx.group = group ctx.world_size = dist.get_world_size(group) @@ -381,14 +255,7 @@ def shard_anything( f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}." ) # NOTE: We use `tensor_split` instead of chunk, because the `chunk` - # function may return fewer than the specified number of chunks! For example, - # x = torch.tensor([1,2,3,4,5]), torch.chunk(x, 4) will return only 3 chunks: - # (tensor([1, 2]), tensor([3, 4]), tensor([5])). This behavior can lead to - # inconsistencies when sharding tensors across multiple devices. In contrast, - # tensor_split will always return the specified number of chunks, the last chunk - # may be smaller if the tensor size is not divisible by the number of chunks. - # For example, torch.tensor_split(x, 4) will return 4 chunks: - # (tensor([1, 2]), tensor([3]), tensor([4]), tensor([5])). + # function may return fewer than the specified number of chunks! return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())] @classmethod @@ -396,9 +263,5 @@ def unshard_anything( cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh ) -> torch.Tensor: tensor = tensor.contiguous() - # NOTE: We use AllGatherAnythingFunction to support gathering - # tensors with complex and uneven sizes across all ranks. It handles the - # case where the tensor size (the seq_len of hidden_states) along the - # specified dimension is not divisible by the number of ranks in the mesh. tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group()) return tensor diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index b76283eb1cfc..4cd238dd5502 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -43,8 +43,12 @@ is_xformers_available, is_xformers_version, ) -from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ULYSSES_ANYTHING -from ._ulysses_anything import TemplatedUlyssesAnythingAttention +from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS +from ._ulysses_anything_utils import ( + all_to_all_single_any_o_async, + all_to_all_single_any_qkv_async, + ulysses_anything_metadata, +) if TYPE_CHECKING: @@ -1502,6 +1506,82 @@ def backward( return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None +class TemplatedUlyssesAnythingAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + **kwargs, + ): + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + group = ulysses_mesh.get_group() + + ctx.forward_op = forward_op + ctx.backward_op = backward_op + ctx._parallel_config = _parallel_config + + metadata = ulysses_anything_metadata(query) + query_wait = all_to_all_single_any_qkv_async(query, group, **metadata) + key_wait = all_to_all_single_any_qkv_async(key, group, **metadata) + value_wait = all_to_all_single_any_qkv_async(value, group, **metadata) + + query = query_wait() # 
type: torch.Tensor + key = key_wait() # type: torch.Tensor + value = value_wait() # type: torch.Tensor + + out = forward_op( + ctx, + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + _save_ctx=False, # ulysses anything only support forward pass now. + _parallel_config=_parallel_config, + ) + if return_lse: + out, lse, *_ = out + + # out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D) + out_wait = all_to_all_single_any_o_async(out, group, **metadata) + + if return_lse: + # lse: (B, S_Q_GLOBAL, H_LOCAL) + lse = lse.unsqueeze(-1) # (B, S_Q_GLOBAL, H_LOCAL, D=1) + lse_wait = all_to_all_single_any_o_async(lse, group, **metadata) + out = out_wait() # type: torch.Tensor + lse = lse_wait() # type: torch.Tensor + lse = lse.squeeze(-1).contiguous() # (B, S_Q_LOCAL, H_GLOBAL) + else: + out = out_wait() # type: torch.Tensor + lse = None + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + ): + raise NotImplementedError("Backward pass for Ulysses Anything Attention in diffusers is not implemented yet.") + + def _templated_unified_attention( query: torch.Tensor, key: torch.Tensor, @@ -1619,7 +1699,7 @@ def _templated_context_parallel_attention( _parallel_config, ) elif _parallel_config.context_parallel_config.ulysses_degree > 1: - if DIFFUSERS_ULYSSES_ANYTHING: + if _parallel_config.context_parallel_config.ulysses_anything: # For Any sequence lengths and Any head num support return TemplatedUlyssesAnythingAttention.apply( query, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index b9407d8945dd..c46fa4363483 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -46,7 +46,6 @@ DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES -DIFFUSERS_ULYSSES_ANYTHING = os.getenv("DIFFUSERS_ULYSSES_ANYTHING", "0").upper() in ENV_VARS_TRUE_VALUES # Below should be `True` if the current version of `peft` and `transformers` are compatible with # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are From 403c204ab8d29b5ef6197ebeb46f97c6e923fac6 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Wed, 21 Jan 2026 13:19:36 +0000 Subject: [PATCH 05/11] fix UAA broken while using joint attn --- .../models/_ulysses_anything_utils.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/_ulysses_anything_utils.py b/src/diffusers/models/_ulysses_anything_utils.py index fd85f374d0f5..c17c0ab84812 100644 --- a/src/diffusers/models/_ulysses_anything_utils.py +++ b/src/diffusers/models/_ulysses_anything_utils.py @@ -32,7 +32,6 @@ def _get_rank_world_size(group: dist.ProcessGroup) -> Tuple[int, int]: return rank, world_size -@functools.lru_cache(maxsize=128) def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]: r"""Gather the local size from all ranks. 
size: int, local size return: List[int], list of size from all ranks @@ -126,10 +125,12 @@ def _wait_tensor(tensor) -> torch.Tensor: def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict: - num_qo_head = query.shape[2] # (B, S_LOCAL, H_GLOBAL, D) + # query: (B, S_LOCAL, H_GLOBAL, D) + assert len(query.shape) == 4, "Query tensor must be 4-dimensional of shape (B, S_LOCAL, H_GLOBAL, D)" extra_kwargs = {} - extra_kwargs["num_qo_head"] = num_qo_head - # May ddd other kwargs if needed in future + extra_kwargs["NUM_QO_HEAD"] = query.shape[2] + extra_kwargs["Q_S_LOCAL"] = query.shape[1] + # Add other kwargs if needed in future return extra_kwargs @@ -174,17 +175,22 @@ def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **k """ # Assume H is provided in kwargs, since we can't infer H from x's shape. # The padding logic needs H to determine if padding is necessary. - H = kwargs.get("num_qo_head", None) + H = kwargs.get("NUM_QO_HEAD", None) rank, world_size = _get_rank_world_size(group) x, H_PAD = _maybe_pad_o_head(x, H, group) shape = x.shape # (B, S_GLOBAL, H_LOCAL, D) (B, S_GLOBAL, H_LOCAL, D) = shape - # NOTE: We use tensor_split here to ensure the same split policy - # that we have used in the EquipartitionSharder sharding strategy. - input_split_sizes = [o.size(1) for o in torch.tensor_split(x, world_size, dim=1)] + # input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..] # output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..] - S_LOCAL = input_split_sizes[rank] + + # WARN: In some cases, e.g, joint attn in Qwen-Image, the S_LOCAL can not infer + # from tensor split due to: if c = torch.cat((a, b)), world_size=4, then, + # c.tensor_split(4)[0].shape[1] may != to (a.tensor_split(4)[0].shape[1] + + # b.tensor_split(4)[0].shape[1]) + + S_LOCAL = kwargs.get("Q_S_LOCAL") + input_split_sizes = _gather_size_by_comm(S_LOCAL, group) x = x.permute(1, 0, 2, 3).contiguous() # (S_GLOBAL, B, H_LOCAL, D) output_split_sizes = [S_LOCAL] * world_size x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group) From 9280e2b6321304e52e76b90db32107d60b0bc1c3 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Wed, 21 Jan 2026 13:24:30 +0000 Subject: [PATCH 06/11] update --- src/diffusers/models/_ulysses_anything_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/_ulysses_anything_utils.py b/src/diffusers/models/_ulysses_anything_utils.py index c17c0ab84812..284ba6241b99 100644 --- a/src/diffusers/models/_ulysses_anything_utils.py +++ b/src/diffusers/models/_ulysses_anything_utils.py @@ -32,6 +32,7 @@ def _get_rank_world_size(group: dist.ProcessGroup) -> Tuple[int, int]: return rank, world_size +@functools.lru_cache(maxsize=128) def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]: r"""Gather the local size from all ranks. 
size: int, local size return: List[int], list of size from all ranks From 4caa87eacd52d562ac8c8af6267a001af5bba956 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Fri, 23 Jan 2026 08:21:50 +0000 Subject: [PATCH 07/11] post check --- src/diffusers/models/_modeling_parallel.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index f301ba771cc7..c91403ad49dc 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -97,6 +97,11 @@ def __post_init__(self): raise NotImplementedError( f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." ) + if self.ulysses_anything: + if self.ulysses_degree == 1: + raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.") + if self.ring_degree > 1: + raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.") @property def mesh_shape(self) -> Tuple[int, int]: From 140ece8ecc880321f3f86f7efa49ec392e8a1eef Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 26 Jan 2026 02:54:21 +0000 Subject: [PATCH 08/11] add docs --- .../en/training/distributed_inference.md | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index e7ec1480aabd..13be261878c3 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -343,6 +343,34 @@ We ran a benchmark with Ulysess, Ring, and Unified Attention with [this script]( From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention. + +### Ulysses Anything Attention + +The default Ulysses Attention mechanism requires that the sequence length of hidden states must be divisible by the number of devices. This imposes significant limitations on the practical application of Ulysses Attention. [Ulysses Anything Attention](https://github.com/huggingface/diffusers/pull/12996) is a variant of Ulysses Attention that supports arbitrary sequence lengths and arbitrary numbers of attention heads, thereby enhancing the versatility of Ulysses Attention in practical use. + +[`ContextParallelConfig`] supports Ulysses Anything Attention by specifying both `ulysses_degree` and `ulysses_anything`. Please note that Ulysses Anything Attention is not currently supported by Unified Attention. Pass the [`ContextParallelConfig`] with both `ulysses_degree` set to bigger than 1 and `ulysses_anything=True` to [`~ModelMixin.enable_parallelism`]. + +```py +pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ulysses_anything=True)) +``` + +> [!TIP] To avoid multiple forced CUDA sync caused by H2D and D2H transfers, please add the **gloo** backend in `init_process_group`. This will significantly reduce communication latency. + +We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention and Ulysses Anything Attention with [this script](https://github.com/huggingface/diffusers/pull/12996) on a node of 4 L20 GPUs. 
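Regarding the tip above, a minimal sketch of a process-group setup that routes the small size-gathering collectives over Gloo while keeping NCCL for the heavy GPU collectives; the mixed-backend string follows the `dist.init_process_group(backend="cpu:gloo,cuda:nccl")` example noted in the implementation, and should be adapted to your launcher:

```py
import torch
import torch.distributed as dist

# Mixed backend: Gloo serves CPU tensors (the tiny per-rank size all_gathers),
# NCCL serves CUDA tensors (the large all-to-all / all_gather collectives).
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
```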
The results are summarized as follows: + +| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW)| +|--------------------|------------------|-------------|------------------|------------| +| ulysses | 281.07 | 3.56 | 37.11 | 1024x1024 | +| ring | 351.34 | 2.85 | 37.01 | 1024x1024 | +| unified_balanced | 324.37 | 3.08 | 37.16 | 1024x1024 | +| ulysses_anything | 280.94 | 3.56 | 37.11 | 1024x1024 | +| ulysses | failed | failed | failed | 1008x1008 | +| ring | failed | failed | failed | 1008x1008 | +| unified_balanced | failed | failed | failed | 1008x1008 | +| ulysses_anything | 278.40 | 3.59 | 36.99 | 1008x1008 | + +From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention. + ### parallel_config Pass `parallel_config` during model initialization to enable context parallelism. From 4cc56b9c844323cbb7fb1543a78bd19534b8a996 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 26 Jan 2026 02:56:41 +0000 Subject: [PATCH 09/11] add docs --- docs/source/en/training/distributed_inference.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index 13be261878c3..bdaa2ae8ffff 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -356,7 +356,7 @@ pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_deg > [!TIP] To avoid multiple forced CUDA sync caused by H2D and D2H transfers, please add the **gloo** backend in `init_process_group`. This will significantly reduce communication latency. -We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention and Ulysses Anything Attention with [this script](https://github.com/huggingface/diffusers/pull/12996) on a node of 4 L20 GPUs. The results are summarized as follows: +We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention and Ulysses Anything Attention with [this script](https://github.com/huggingface/diffusers/pull/12996#issuecomment-3797695999) on a node of 4 L20 GPUs. The results are summarized as follows: | CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW)| |--------------------|------------------|-------------|------------------|------------| From a74820b596e824c174ecb320663a58b27272fa21 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Wed, 28 Jan 2026 05:51:34 +0000 Subject: [PATCH 10/11] remove lru cache --- src/diffusers/models/_ulysses_anything_utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/_ulysses_anything_utils.py b/src/diffusers/models/_ulysses_anything_utils.py index 284ba6241b99..bbbce66d6ad7 100644 --- a/src/diffusers/models/_ulysses_anything_utils.py +++ b/src/diffusers/models/_ulysses_anything_utils.py @@ -32,11 +32,23 @@ def _get_rank_world_size(group: dist.ProcessGroup) -> Tuple[int, int]: return rank, world_size -@functools.lru_cache(maxsize=128) def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]: r"""Gather the local size from all ranks. size: int, local size return: List[int], list of size from all ranks """ + # NOTE(Serving/CP Safety): + # Do NOT cache this collective result. + # + # In "Ulysses Anything" mode, `size` (e.g. per-rank local seq_len / S_LOCAL) + # may legitimately differ across ranks. 
If we cache based on the *local* `size`, + # different ranks can have different cache hit/miss patterns across time. + # + # That can lead to a catastrophic distributed hang: + # - some ranks hit cache and *skip* dist.all_gather() + # - other ranks miss cache and *enter* dist.all_gather() + # This mismatched collective participation will stall the process group and + # eventually trigger NCCL watchdog timeouts (often surfacing later as ALLTOALL + # timeouts in Ulysses attention). world_size = dist.get_world_size(group=group) # HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead comm_backends = str(dist.get_backend(group=group)) From f8f220926ff83cf38846ac446d095c2f1c9b4a33 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Thu, 29 Jan 2026 06:26:11 +0000 Subject: [PATCH 11/11] move codes --- src/diffusers/hooks/context_parallel.py | 77 ++++- src/diffusers/models/_modeling_parallel.py | 37 +++ .../models/_ulysses_anything_utils.py | 286 ------------------ src/diffusers/models/attention_dispatch.py | 157 +++++++++- 4 files changed, 262 insertions(+), 295 deletions(-) delete mode 100644 src/diffusers/models/_ulysses_anything_utils.py diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py index 53e2b53d986e..44c048032f67 100644 --- a/src/diffusers/hooks/context_parallel.py +++ b/src/diffusers/hooks/context_parallel.py @@ -11,12 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import copy +import functools import inspect from dataclasses import dataclass -from typing import Dict, List, Type, Union +from typing import Dict, List, Tuple, Type, Union import torch +import torch.distributed as dist if torch.distributed.is_available(): @@ -27,10 +29,10 @@ ContextParallelInput, ContextParallelModelPlan, ContextParallelOutput, + _gather_size_by_comm, ) -from ..models._ulysses_anything_utils import PartitionAnythingSharder from ..utils import get_logger -from ..utils.torch_utils import unwrap_module +from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module from .hooks import HookRegistry, ModelHook @@ -286,6 +288,73 @@ def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_ return tensor +class AllGatherAnythingFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh): + ctx.dim = dim + ctx.group = group + ctx.world_size = dist.get_world_size(group) + ctx.rank = dist.get_rank(group) + gathered_tensor = _all_gather_anything(tensor, dim, group) + return gathered_tensor + + @staticmethod + def backward(ctx, grad_output): + # NOTE: We use `tensor_split` instead of chunk, because the `chunk` + # function may return fewer than the specified number of chunks! + grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim) + return grad_splits[ctx.rank], None, None + + +class PartitionAnythingSharder: + @classmethod + def shard_anything( + cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh + ) -> torch.Tensor: + assert tensor.size()[dim] >= mesh.size(), ( + f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}." + ) + # NOTE: We use `tensor_split` instead of chunk, because the `chunk` + # function may return fewer than the specified number of chunks! 
+ return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())] + + @classmethod + def unshard_anything( + cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh + ) -> torch.Tensor: + tensor = tensor.contiguous() + tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group()) + return tensor + + +@functools.lru_cache(maxsize=64) +def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int) -> List[List[int]]: + gather_shapes = [] + for i in range(world_size): + rank_shape = list(copy.deepcopy(shape)) + rank_shape[dim] = gather_dims[i] + gather_shapes.append(rank_shape) + return gather_shapes + + +@maybe_allow_in_graph +def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh) -> torch.Tensor: + world_size = dist.get_world_size(group=group) + + tensor = tensor.contiguous() + shape = tensor.shape + rank_dim = shape[dim] + gather_dims = _gather_size_by_comm(rank_dim, group) + + gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size) + + gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes] + + dist.all_gather(gathered_tensors, tensor, group=group) + gathered_tensor = torch.cat(gathered_tensors, dim=dim) + return gathered_tensor + + def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]: if name.count("*") > 1: raise ValueError("Wildcard '*' can only be used once in the name") diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index c91403ad49dc..a8fe2c7f5517 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union import torch +import torch.distributed as dist from ..utils import get_logger @@ -265,3 +266,39 @@ def __repr__(self): # # ContextParallelOutput: # specifies how to gather the input tensor in the post-forward hook in the layer it is attached to + + +# Below are utility functions for distributed communication in context parallelism. +def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]: + r"""Gather the local size from all ranks. + size: int, local size return: List[int], list of size from all ranks + """ + # NOTE(Serving/CP Safety): + # Do NOT cache this collective result. + # + # In "Ulysses Anything" mode, `size` (e.g. per-rank local seq_len / S_LOCAL) + # may legitimately differ across ranks. If we cache based on the *local* `size`, + # different ranks can have different cache hit/miss patterns across time. + # + # That can lead to a catastrophic distributed hang: + # - some ranks hit cache and *skip* dist.all_gather() + # - other ranks miss cache and *enter* dist.all_gather() + # This mismatched collective participation will stall the process group and + # eventually trigger NCCL watchdog timeouts (often surfacing later as ALLTOALL + # timeouts in Ulysses attention). 
+ world_size = dist.get_world_size(group=group) + # HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead + comm_backends = str(dist.get_backend(group=group)) + # NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl") + gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator() + gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)] + dist.all_gather( + gathered_sizes, + torch.tensor([size], device=gather_device, dtype=torch.int64), + group=group, + ) + + gathered_sizes = [s[0].item() for s in gathered_sizes] + # NOTE: DON'T use tolist here due to graph break - Explanation: + # Backend compiler `inductor` failed with aten._local_scalar_dense.default + return gathered_sizes diff --git a/src/diffusers/models/_ulysses_anything_utils.py b/src/diffusers/models/_ulysses_anything_utils.py deleted file mode 100644 index bbbce66d6ad7..000000000000 --- a/src/diffusers/models/_ulysses_anything_utils.py +++ /dev/null @@ -1,286 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Adapted from: https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/attention/_templated_ulysses.py -import copy -import functools -from typing import Callable, List, Tuple - -import torch -import torch.distributed as dist -import torch.distributed._functional_collectives as fc -import torch.nn.functional as F - -from ..utils.torch_utils import maybe_allow_in_graph - - -# Helper functions for shape gathering -def _get_rank_world_size(group: dist.ProcessGroup) -> Tuple[int, int]: - world_size = dist.get_world_size(group=group) - rank = dist.get_rank(group=group) - return rank, world_size - - -def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]: - r"""Gather the local size from all ranks. - size: int, local size return: List[int], list of size from all ranks - """ - # NOTE(Serving/CP Safety): - # Do NOT cache this collective result. - # - # In "Ulysses Anything" mode, `size` (e.g. per-rank local seq_len / S_LOCAL) - # may legitimately differ across ranks. If we cache based on the *local* `size`, - # different ranks can have different cache hit/miss patterns across time. - # - # That can lead to a catastrophic distributed hang: - # - some ranks hit cache and *skip* dist.all_gather() - # - other ranks miss cache and *enter* dist.all_gather() - # This mismatched collective participation will stall the process group and - # eventually trigger NCCL watchdog timeouts (often surfacing later as ALLTOALL - # timeouts in Ulysses attention). 
- world_size = dist.get_world_size(group=group) - # HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead - comm_backends = str(dist.get_backend(group=group)) - # NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl") - gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator() - gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)] - dist.all_gather( - gathered_sizes, - torch.tensor([size], device=gather_device, dtype=torch.int64), - group=group, - ) - - gathered_sizes = [s[0].item() for s in gathered_sizes] - # NOTE: DON'T use tolist here due to graph break - Explanation: - # Backend compiler `inductor` failed with aten._local_scalar_dense.default - return gathered_sizes - - -# Helper functions to pad/unpad head dimension for QKV and O projections -def _maybe_pad_qkv_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]: - r"""Maybe pad the head dimension to be divisible by world_size. - x: torch.Tensor, shape (B, S_LOCAL, H, D) H: int, original global head num return: Tuple[torch.Tensor, int], padded - tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD - """ - _, world_size = _get_rank_world_size(group) - H_PAD = 0 - if H % world_size != 0: - H_PAD = world_size - (H % world_size) - NEW_H_LOCAL = (H + H_PAD) // world_size - # e.g., Allow: H=30, world_size=8 -> NEW_H_LOCAL=4, H_PAD=2. - # NOT ALLOW: H=30, world_size=16 -> NEW_H_LOCAL=2, H_PAD=14. - assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}" - x = F.pad(x, (0, 0, 0, H_PAD)).contiguous() - return x, H_PAD - - -def _maybe_unpad_qkv_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor: - r"""Maybe unpad the head dimension. - x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, - unpadded tensor (B, S_GLOBAL, H_LOCAL, D) - """ - rank, world_size = _get_rank_world_size(group) - # Only the last rank may have padding - if H_PAD > 0 and rank == world_size - 1: - x = x[:, :, :-H_PAD, :] - return x.contiguous() - - -def _maybe_pad_o_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]: - r"""Maybe pad the head dimension to be divisible by world_size. - x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) H: int, original global head num return: Tuple[torch.Tensor, int], - padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD - """ - if H is None: - return x, 0 - - rank, world_size = _get_rank_world_size(group) - H_PAD = 0 - # Only the last rank may need padding - if H % world_size != 0: - # We need to broadcast H_PAD to all ranks to keep consistency - # in unpadding step later for all ranks. - H_PAD = world_size - (H % world_size) - NEW_H_LOCAL = (H + H_PAD) // world_size - assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}" - if rank == world_size - 1: - x = F.pad(x, (0, 0, 0, H_PAD)).contiguous() - return x, H_PAD - - -def _maybe_unpad_o_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor: - r"""Maybe unpad the head dimension. 
- x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, - unpadded tensor (B, S_LOCAL, H_GLOBAL, D) - """ - if H_PAD > 0: - x = x[:, :, :-H_PAD, :] - return x.contiguous() - - -# Helper functions to for all-to-all communication with Ulysses Anything Attention -def _wait_tensor(tensor) -> torch.Tensor: - if isinstance(tensor, fc.AsyncCollectiveTensor): - tensor = tensor.wait() - - return tensor - - -def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict: - # query: (B, S_LOCAL, H_GLOBAL, D) - assert len(query.shape) == 4, "Query tensor must be 4-dimensional of shape (B, S_LOCAL, H_GLOBAL, D)" - extra_kwargs = {} - extra_kwargs["NUM_QO_HEAD"] = query.shape[2] - extra_kwargs["Q_S_LOCAL"] = query.shape[1] - # Add other kwargs if needed in future - return extra_kwargs - - -@maybe_allow_in_graph -def all_to_all_single_any_qkv_async( - x: torch.Tensor, group: dist.ProcessGroup, **kwargs -) -> Callable[..., torch.Tensor]: - r""" - x: torch.Tensor, shape (B, S_LOCAL, H, D) return: Callable that returns (B, S_GLOBAL, H_LOCAL, D) - """ - _, world_size = _get_rank_world_size(group) - B, S_LOCAL, H, D = x.shape - x, H_PAD = _maybe_pad_qkv_head(x, H, group) - H_LOCAL = (H + H_PAD) // world_size - # (world_size, S_LOCAL, B, H_LOCAL, D) - x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() - - input_split_sizes = [S_LOCAL] * world_size - # S_LOCAL maybe not equal for all ranks in dynamic shape case, - # since we don't know the actual shape before this timing, thus, - # we have to use all gather to collect the S_LOCAL first. - output_split_sizes = _gather_size_by_comm(S_LOCAL, group) - x = x.flatten(0, 1) # (world_size * S_LOCAL, B, H_LOCAL, D) - x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group) - - def wait() -> torch.Tensor: - nonlocal x, H_PAD - x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D) - # (S_GLOBAL, B, H_LOCAL, D) - # -> (B, S_GLOBAL, H_LOCAL, D) - x = x.permute(1, 0, 2, 3).contiguous() - x = _maybe_unpad_qkv_head(x, H_PAD, group) - return x - - return wait - - -@maybe_allow_in_graph -def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **kwargs) -> Callable[..., torch.Tensor]: - r""" - x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) return: Callable that returns (B, S_LOCAL, H_GLOBAL, D) - """ - # Assume H is provided in kwargs, since we can't infer H from x's shape. - # The padding logic needs H to determine if padding is necessary. - H = kwargs.get("NUM_QO_HEAD", None) - rank, world_size = _get_rank_world_size(group) - x, H_PAD = _maybe_pad_o_head(x, H, group) - shape = x.shape # (B, S_GLOBAL, H_LOCAL, D) - (B, S_GLOBAL, H_LOCAL, D) = shape - - # input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..] - # output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..] 
- - # WARN: In some cases, e.g, joint attn in Qwen-Image, the S_LOCAL can not infer - # from tensor split due to: if c = torch.cat((a, b)), world_size=4, then, - # c.tensor_split(4)[0].shape[1] may != to (a.tensor_split(4)[0].shape[1] + - # b.tensor_split(4)[0].shape[1]) - - S_LOCAL = kwargs.get("Q_S_LOCAL") - input_split_sizes = _gather_size_by_comm(S_LOCAL, group) - x = x.permute(1, 0, 2, 3).contiguous() # (S_GLOBAL, B, H_LOCAL, D) - output_split_sizes = [S_LOCAL] * world_size - x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group) - - def wait() -> torch.Tensor: - nonlocal x, H_PAD - x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D) - x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D) - x = x.permute(2, 1, 0, 3, 4).contiguous() - x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D) - x = _maybe_unpad_o_head(x, H_PAD, group) - return x - - return wait - - -@functools.lru_cache(maxsize=64) -def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int) -> List[List[int]]: - gather_shapes = [] - for i in range(world_size): - rank_shape = list(copy.deepcopy(shape)) - rank_shape[dim] = gather_dims[i] - gather_shapes.append(rank_shape) - return gather_shapes - - -@maybe_allow_in_graph -def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh) -> torch.Tensor: - _, world_size = _get_rank_world_size(group) - tensor = tensor.contiguous() - shape = tensor.shape - rank_dim = shape[dim] - gather_dims = _gather_size_by_comm(rank_dim, group) - - gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size) - - gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes] - - dist.all_gather(gathered_tensors, tensor, group=group) - gathered_tensor = torch.cat(gathered_tensors, dim=dim) - return gathered_tensor - - -class AllGatherAnythingFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh): - ctx.dim = dim - ctx.group = group - ctx.world_size = dist.get_world_size(group) - ctx.rank = dist.get_rank(group) - gathered_tensor = _all_gather_anything(tensor, dim, group) - return gathered_tensor - - @staticmethod - def backward(ctx, grad_output): - # NOTE: We use `tensor_split` instead of chunk, because the `chunk` - # function may return fewer than the specified number of chunks! - grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim) - return grad_splits[ctx.rank], None, None - - -class PartitionAnythingSharder: - @classmethod - def shard_anything( - cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh - ) -> torch.Tensor: - assert tensor.size()[dim] >= mesh.size(), ( - f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}." - ) - # NOTE: We use `tensor_split` instead of chunk, because the `chunk` - # function may return fewer than the specified number of chunks! 
- return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())] - - @classmethod - def unshard_anything( - cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh - ) -> torch.Tensor: - tensor = tensor.contiguous() - tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group()) - return tensor diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 4cd238dd5502..db5e563932cd 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -21,6 +21,8 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist +import torch.nn.functional as F if torch.distributed.is_available(): @@ -44,11 +46,8 @@ is_xformers_version, ) from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS -from ._ulysses_anything_utils import ( - all_to_all_single_any_o_async, - all_to_all_single_any_qkv_async, - ulysses_anything_metadata, -) +from ..utils.torch_utils import maybe_allow_in_graph +from ._modeling_parallel import _gather_size_by_comm if TYPE_CHECKING: @@ -1284,6 +1283,154 @@ def backward(ctx, grad_outputs): return (None, grad_input, None, None) +# Below are helper functions to handle abritrary head num and abritrary sequence length for Ulysses Anything Attention. +def _maybe_pad_qkv_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]: + r"""Maybe pad the head dimension to be divisible by world_size. + x: torch.Tensor, shape (B, S_LOCAL, H, D) H: int, original global head num return: Tuple[torch.Tensor, int], padded + tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD + """ + world_size = dist.get_world_size(group=group) + H_PAD = 0 + if H % world_size != 0: + H_PAD = world_size - (H % world_size) + NEW_H_LOCAL = (H + H_PAD) // world_size + # e.g., Allow: H=30, world_size=8 -> NEW_H_LOCAL=4, H_PAD=2. + # NOT ALLOW: H=30, world_size=16 -> NEW_H_LOCAL=2, H_PAD=14. + assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}" + x = F.pad(x, (0, 0, 0, H_PAD)).contiguous() + return x, H_PAD + + +def _maybe_unpad_qkv_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor: + r"""Maybe unpad the head dimension. + x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, + unpadded tensor (B, S_GLOBAL, H_LOCAL, D) + """ + rank = dist.get_rank(group=group) + world_size = dist.get_world_size(group=group) + # Only the last rank may have padding + if H_PAD > 0 and rank == world_size - 1: + x = x[:, :, :-H_PAD, :] + return x.contiguous() + + +def _maybe_pad_o_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]: + r"""Maybe pad the head dimension to be divisible by world_size. + x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) H: int, original global head num return: Tuple[torch.Tensor, int], + padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD + """ + if H is None: + return x, 0 + + rank = dist.get_rank(group=group) + world_size = dist.get_world_size(group=group) + H_PAD = 0 + # Only the last rank may need padding + if H % world_size != 0: + # We need to broadcast H_PAD to all ranks to keep consistency + # in unpadding step later for all ranks. 
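+        # Illustrative example (hypothetical sizes): H=30, world_size=8 -> H_PAD=2; only the last
+        # rank (which kept 30 - 7 * 4 = 2 real heads after the QKV unpad) pads back to 4 local heads.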
+        H_PAD = world_size - (H % world_size)
+        NEW_H_LOCAL = (H + H_PAD) // world_size
+        assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
+        if rank == world_size - 1:
+            x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
+    return x, H_PAD
+
+
+def _maybe_unpad_o_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
+    r"""Maybe unpad the head dimension.
+    x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor,
+    unpadded tensor (B, S_LOCAL, H_GLOBAL, D)
+    """
+    if H_PAD > 0:
+        x = x[:, :, :-H_PAD, :]
+    return x.contiguous()
+
+
+def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict:
+    # query: (B, S_LOCAL, H_GLOBAL, D)
+    assert len(query.shape) == 4, "Query tensor must be 4-dimensional with shape (B, S_LOCAL, H_GLOBAL, D)"
+    extra_kwargs = {}
+    extra_kwargs["NUM_QO_HEAD"] = query.shape[2]
+    extra_kwargs["Q_S_LOCAL"] = query.shape[1]
+    # Add other kwargs here if needed in the future
+    return extra_kwargs
+
+
+@maybe_allow_in_graph
+def all_to_all_single_any_qkv_async(
+    x: torch.Tensor, group: dist.ProcessGroup, **kwargs
+) -> Callable[..., torch.Tensor]:
+    r"""
+    x: torch.Tensor, shape (B, S_LOCAL, H, D) return: Callable that returns (B, S_GLOBAL, H_LOCAL, D)
+    """
+    world_size = dist.get_world_size(group=group)
+    B, S_LOCAL, H, D = x.shape
+    x, H_PAD = _maybe_pad_qkv_head(x, H, group)
+    H_LOCAL = (H + H_PAD) // world_size
+    # (world_size, S_LOCAL, B, H_LOCAL, D)
+    x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+
+    input_split_sizes = [S_LOCAL] * world_size
+    # S_LOCAL may not be equal across ranks in the dynamic-shape case, and we
+    # don't know the actual shapes ahead of time, so we have to use an
+    # all_gather to collect the per-rank S_LOCAL first.
+    output_split_sizes = _gather_size_by_comm(S_LOCAL, group)
+    x = x.flatten(0, 1)  # (world_size * S_LOCAL, B, H_LOCAL, D)
+    x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group)
+
+    def wait() -> torch.Tensor:
+        nonlocal x, H_PAD
+        x = _wait_tensor(x)  # (S_GLOBAL, B, H_LOCAL, D)
+        # (S_GLOBAL, B, H_LOCAL, D)
+        # -> (B, S_GLOBAL, H_LOCAL, D)
+        x = x.permute(1, 0, 2, 3).contiguous()
+        x = _maybe_unpad_qkv_head(x, H_PAD, group)
+        return x
+
+    return wait
+
+
+@maybe_allow_in_graph
+def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **kwargs) -> Callable[..., torch.Tensor]:
+    r"""
+    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) return: Callable that returns (B, S_LOCAL, H_GLOBAL, D)
+    """
+    # Assume H is provided in kwargs, since we can't infer H from x's shape.
+    # The padding logic needs H to determine if padding is necessary.
+    H = kwargs.get("NUM_QO_HEAD", None)
+    world_size = dist.get_world_size(group=group)
+
+    x, H_PAD = _maybe_pad_o_head(x, H, group)
+    shape = x.shape  # (B, S_GLOBAL, H_LOCAL, D)
+    (B, S_GLOBAL, H_LOCAL, D) = shape
+
+    # input_split: e.g., S_GLOBAL=9 input splits across ranks [[5, 4], [5, 4], ..]
+    # output_split: e.g., S_GLOBAL=9 output splits across ranks [[5, 5], [4, 4], ..]
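+    # Illustrative reading of the example above (world_size=2, hypothetical sizes): rank 0
+    # (S_LOCAL=5) sends row blocks of sizes [5, 4] and receives [5, 5]; rank 1 (S_LOCAL=4)
+    # sends [5, 4] and receives [4, 4], i.e. each rank gets back world_size * S_LOCAL rows.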
+
+    # WARN: In some cases, e.g., joint attention in Qwen-Image, S_LOCAL cannot be
+    # inferred by splitting the gathered tensor: if c = torch.cat((a, b)) and
+    # world_size=4, then c.tensor_split(4)[0].shape[1] may not equal
+    # a.tensor_split(4)[0].shape[1] + b.tensor_split(4)[0].shape[1].
+
+    S_LOCAL = kwargs.get("Q_S_LOCAL")
+    input_split_sizes = _gather_size_by_comm(S_LOCAL, group)
+    x = x.permute(1, 0, 2, 3).contiguous()  # (S_GLOBAL, B, H_LOCAL, D)
+    output_split_sizes = [S_LOCAL] * world_size
+    x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group)
+
+    def wait() -> torch.Tensor:
+        nonlocal x, H_PAD
+        x = _wait_tensor(x)  # (world_size * S_LOCAL, B, H_LOCAL, D)
+        x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D)
+        x = x.permute(2, 1, 0, 3, 4).contiguous()
+        x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D)
+        x = _maybe_unpad_o_head(x, H_PAD, group)
+        return x
+
+    return wait
+
+
 class TemplatedRingAttention(torch.autograd.Function):
     @staticmethod
     def forward(