diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md
index e7ec1480aabd..bdaa2ae8ffff 100644
--- a/docs/source/en/training/distributed_inference.md
+++ b/docs/source/en/training/distributed_inference.md
@@ -343,6 +343,35 @@ We ran a benchmark with Ulysess, Ring, and Unified Attention with [this script](
 From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.
+
+### Ulysses Anything Attention
+
+The default Ulysses Attention mechanism requires the sequence length of the hidden states to be divisible by the number of devices, which significantly limits its practical applicability. [Ulysses Anything Attention](https://github.com/huggingface/diffusers/pull/12996) is a variant of Ulysses Attention that supports arbitrary sequence lengths and arbitrary numbers of attention heads, making Ulysses Attention far more versatile in practice.
+
+Enable it through [`ContextParallelConfig`] by setting `ulysses_degree` to a value greater than 1 and `ulysses_anything=True`, then pass the config to [`~ModelMixin.enable_parallelism`]. Note that Ulysses Anything Attention is not currently supported by Unified Attention.
+
+```py
+pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ulysses_anything=True))
+```
+
+> [!TIP]
+> To avoid multiple forced CUDA synchronizations caused by H2D and D2H transfers, add the **gloo** backend in `init_process_group` (for example, `backend="cpu:gloo,cuda:nccl"`). This significantly reduces communication latency.
+
+We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention, and Ulysses Anything Attention with [this script](https://github.com/huggingface/diffusers/pull/12996#issuecomment-3797695999) on a node of 4 L20 GPUs. The results are summarized as follows:
+
+| CP Backend       | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW) |
+|------------------|------------------|-------------|------------------|-------------|
+| ulysses          | 281.07           | 3.56        | 37.11            | 1024x1024   |
+| ring             | 351.34           | 2.85        | 37.01            | 1024x1024   |
+| unified_balanced | 324.37           | 3.08        | 37.16            | 1024x1024   |
+| ulysses_anything | 280.94           | 3.56        | 37.11            | 1024x1024   |
+| ulysses          | failed           | failed      | failed           | 1008x1008   |
+| ring             | failed           | failed      | failed           | 1008x1008   |
+| unified_balanced | failed           | failed      | failed           | 1008x1008   |
+| ulysses_anything | 278.40           | 3.59        | 36.99            | 1008x1008   |
+
+From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as standard Ulysses Attention.
+
 ### parallel_config
 
 Pass `parallel_config` during model initialization to enable context parallelism.
diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
index 6491d17b4f46..44c048032f67 100644
--- a/src/diffusers/hooks/context_parallel.py
+++ b/src/diffusers/hooks/context_parallel.py
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
- +import copy +import functools import inspect from dataclasses import dataclass -from typing import Dict, List, Type, Union +from typing import Dict, List, Tuple, Type, Union import torch +import torch.distributed as dist if torch.distributed.is_available(): @@ -27,9 +29,10 @@ ContextParallelInput, ContextParallelModelPlan, ContextParallelOutput, + _gather_size_by_comm, ) from ..utils import get_logger -from ..utils.torch_utils import unwrap_module +from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module from .hooks import HookRegistry, ModelHook @@ -208,6 +211,10 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> ) return x else: + if self.parallel_config.ulysses_anything: + return PartitionAnythingSharder.shard_anything( + x, cp_input.split_dim, self.parallel_config._flattened_mesh + ) return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh) @@ -233,7 +240,14 @@ def post_forward(self, module, output): for i, cpm in enumerate(self.metadata): if cpm is None: continue - output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh) + if self.parallel_config.ulysses_anything: + output[i] = PartitionAnythingSharder.unshard_anything( + output[i], cpm.gather_dim, self.parallel_config._flattened_mesh + ) + else: + output[i] = EquipartitionSharder.unshard( + output[i], cpm.gather_dim, self.parallel_config._flattened_mesh + ) return output[0] if is_tensor else tuple(output) @@ -274,6 +288,73 @@ def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_ return tensor +class AllGatherAnythingFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh): + ctx.dim = dim + ctx.group = group + ctx.world_size = dist.get_world_size(group) + ctx.rank = dist.get_rank(group) + gathered_tensor = _all_gather_anything(tensor, dim, group) + return gathered_tensor + + @staticmethod + def backward(ctx, grad_output): + # NOTE: We use `tensor_split` instead of chunk, because the `chunk` + # function may return fewer than the specified number of chunks! + grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim) + return grad_splits[ctx.rank], None, None + + +class PartitionAnythingSharder: + @classmethod + def shard_anything( + cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh + ) -> torch.Tensor: + assert tensor.size()[dim] >= mesh.size(), ( + f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}." + ) + # NOTE: We use `tensor_split` instead of chunk, because the `chunk` + # function may return fewer than the specified number of chunks! 
+ return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())] + + @classmethod + def unshard_anything( + cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh + ) -> torch.Tensor: + tensor = tensor.contiguous() + tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group()) + return tensor + + +@functools.lru_cache(maxsize=64) +def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int) -> List[List[int]]: + gather_shapes = [] + for i in range(world_size): + rank_shape = list(copy.deepcopy(shape)) + rank_shape[dim] = gather_dims[i] + gather_shapes.append(rank_shape) + return gather_shapes + + +@maybe_allow_in_graph +def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh) -> torch.Tensor: + world_size = dist.get_world_size(group=group) + + tensor = tensor.contiguous() + shape = tensor.shape + rank_dim = shape[dim] + gather_dims = _gather_size_by_comm(rank_dim, group) + + gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size) + + gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes] + + dist.all_gather(gathered_tensors, tensor, group=group) + gathered_tensor = torch.cat(gathered_tensors, dim=dim) + return gathered_tensor + + def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]: if name.count("*") > 1: raise ValueError("Wildcard '*' can only be used once in the name") diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 1c7703a13c52..a8fe2c7f5517 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union import torch +import torch.distributed as dist from ..utils import get_logger @@ -67,6 +68,9 @@ class ContextParallelConfig: convert_to_fp32: bool = True # TODO: support alltoall rotate_method: Literal["allgather", "alltoall"] = "allgather" + # Whether to enable ulysses anything attention to support + # any sequence lengths and any head numbers. + ulysses_anything: bool = False _rank: int = None _world_size: int = None @@ -94,6 +98,11 @@ def __post_init__(self): raise NotImplementedError( f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." ) + if self.ulysses_anything: + if self.ulysses_degree == 1: + raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.") + if self.ring_degree > 1: + raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.") @property def mesh_shape(self) -> Tuple[int, int]: @@ -257,3 +266,39 @@ def __repr__(self): # # ContextParallelOutput: # specifies how to gather the input tensor in the post-forward hook in the layer it is attached to + + +# Below are utility functions for distributed communication in context parallelism. +def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]: + r"""Gather the local size from all ranks. + size: int, local size return: List[int], list of size from all ranks + """ + # NOTE(Serving/CP Safety): + # Do NOT cache this collective result. + # + # In "Ulysses Anything" mode, `size` (e.g. per-rank local seq_len / S_LOCAL) + # may legitimately differ across ranks. 
If we cache based on the *local* `size`, + # different ranks can have different cache hit/miss patterns across time. + # + # That can lead to a catastrophic distributed hang: + # - some ranks hit cache and *skip* dist.all_gather() + # - other ranks miss cache and *enter* dist.all_gather() + # This mismatched collective participation will stall the process group and + # eventually trigger NCCL watchdog timeouts (often surfacing later as ALLTOALL + # timeouts in Ulysses attention). + world_size = dist.get_world_size(group=group) + # HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead + comm_backends = str(dist.get_backend(group=group)) + # NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl") + gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator() + gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)] + dist.all_gather( + gathered_sizes, + torch.tensor([size], device=gather_device, dtype=torch.int64), + group=group, + ) + + gathered_sizes = [s[0].item() for s in gathered_sizes] + # NOTE: DON'T use tolist here due to graph break - Explanation: + # Backend compiler `inductor` failed with aten._local_scalar_dense.default + return gathered_sizes diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 61c478b03c4f..db5e563932cd 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -21,6 +21,8 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist +import torch.nn.functional as F if torch.distributed.is_available(): @@ -44,6 +46,8 @@ is_xformers_version, ) from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS +from ..utils.torch_utils import maybe_allow_in_graph +from ._modeling_parallel import _gather_size_by_comm if TYPE_CHECKING: @@ -1279,6 +1283,154 @@ def backward(ctx, grad_outputs): return (None, grad_input, None, None) +# Below are helper functions to handle abritrary head num and abritrary sequence length for Ulysses Anything Attention. +def _maybe_pad_qkv_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]: + r"""Maybe pad the head dimension to be divisible by world_size. + x: torch.Tensor, shape (B, S_LOCAL, H, D) H: int, original global head num return: Tuple[torch.Tensor, int], padded + tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD + """ + world_size = dist.get_world_size(group=group) + H_PAD = 0 + if H % world_size != 0: + H_PAD = world_size - (H % world_size) + NEW_H_LOCAL = (H + H_PAD) // world_size + # e.g., Allow: H=30, world_size=8 -> NEW_H_LOCAL=4, H_PAD=2. + # NOT ALLOW: H=30, world_size=16 -> NEW_H_LOCAL=2, H_PAD=14. + assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}" + x = F.pad(x, (0, 0, 0, H_PAD)).contiguous() + return x, H_PAD + + +def _maybe_unpad_qkv_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor: + r"""Maybe unpad the head dimension. 
+ x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, + unpadded tensor (B, S_GLOBAL, H_LOCAL, D) + """ + rank = dist.get_rank(group=group) + world_size = dist.get_world_size(group=group) + # Only the last rank may have padding + if H_PAD > 0 and rank == world_size - 1: + x = x[:, :, :-H_PAD, :] + return x.contiguous() + + +def _maybe_pad_o_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]: + r"""Maybe pad the head dimension to be divisible by world_size. + x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) H: int, original global head num return: Tuple[torch.Tensor, int], + padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD + """ + if H is None: + return x, 0 + + rank = dist.get_rank(group=group) + world_size = dist.get_world_size(group=group) + H_PAD = 0 + # Only the last rank may need padding + if H % world_size != 0: + # We need to broadcast H_PAD to all ranks to keep consistency + # in unpadding step later for all ranks. + H_PAD = world_size - (H % world_size) + NEW_H_LOCAL = (H + H_PAD) // world_size + assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}" + if rank == world_size - 1: + x = F.pad(x, (0, 0, 0, H_PAD)).contiguous() + return x, H_PAD + + +def _maybe_unpad_o_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor: + r"""Maybe unpad the head dimension. + x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, + unpadded tensor (B, S_LOCAL, H_GLOBAL, D) + """ + if H_PAD > 0: + x = x[:, :, :-H_PAD, :] + return x.contiguous() + + +def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict: + # query: (B, S_LOCAL, H_GLOBAL, D) + assert len(query.shape) == 4, "Query tensor must be 4-dimensional of shape (B, S_LOCAL, H_GLOBAL, D)" + extra_kwargs = {} + extra_kwargs["NUM_QO_HEAD"] = query.shape[2] + extra_kwargs["Q_S_LOCAL"] = query.shape[1] + # Add other kwargs if needed in future + return extra_kwargs + + +@maybe_allow_in_graph +def all_to_all_single_any_qkv_async( + x: torch.Tensor, group: dist.ProcessGroup, **kwargs +) -> Callable[..., torch.Tensor]: + r""" + x: torch.Tensor, shape (B, S_LOCAL, H, D) return: Callable that returns (B, S_GLOBAL, H_LOCAL, D) + """ + world_size = dist.get_world_size(group=group) + B, S_LOCAL, H, D = x.shape + x, H_PAD = _maybe_pad_qkv_head(x, H, group) + H_LOCAL = (H + H_PAD) // world_size + # (world_size, S_LOCAL, B, H_LOCAL, D) + x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + + input_split_sizes = [S_LOCAL] * world_size + # S_LOCAL maybe not equal for all ranks in dynamic shape case, + # since we don't know the actual shape before this timing, thus, + # we have to use all gather to collect the S_LOCAL first. 
+ output_split_sizes = _gather_size_by_comm(S_LOCAL, group) + x = x.flatten(0, 1) # (world_size * S_LOCAL, B, H_LOCAL, D) + x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group) + + def wait() -> torch.Tensor: + nonlocal x, H_PAD + x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D) + # (S_GLOBAL, B, H_LOCAL, D) + # -> (B, S_GLOBAL, H_LOCAL, D) + x = x.permute(1, 0, 2, 3).contiguous() + x = _maybe_unpad_qkv_head(x, H_PAD, group) + return x + + return wait + + +@maybe_allow_in_graph +def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **kwargs) -> Callable[..., torch.Tensor]: + r""" + x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) return: Callable that returns (B, S_LOCAL, H_GLOBAL, D) + """ + # Assume H is provided in kwargs, since we can't infer H from x's shape. + # The padding logic needs H to determine if padding is necessary. + H = kwargs.get("NUM_QO_HEAD", None) + world_size = dist.get_world_size(group=group) + + x, H_PAD = _maybe_pad_o_head(x, H, group) + shape = x.shape # (B, S_GLOBAL, H_LOCAL, D) + (B, S_GLOBAL, H_LOCAL, D) = shape + + # input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..] + # output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..] + + # WARN: In some cases, e.g, joint attn in Qwen-Image, the S_LOCAL can not infer + # from tensor split due to: if c = torch.cat((a, b)), world_size=4, then, + # c.tensor_split(4)[0].shape[1] may != to (a.tensor_split(4)[0].shape[1] + + # b.tensor_split(4)[0].shape[1]) + + S_LOCAL = kwargs.get("Q_S_LOCAL") + input_split_sizes = _gather_size_by_comm(S_LOCAL, group) + x = x.permute(1, 0, 2, 3).contiguous() # (S_GLOBAL, B, H_LOCAL, D) + output_split_sizes = [S_LOCAL] * world_size + x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group) + + def wait() -> torch.Tensor: + nonlocal x, H_PAD + x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D) + x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D) + x = x.permute(2, 1, 0, 3, 4).contiguous() + x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D) + x = _maybe_unpad_o_head(x, H_PAD, group) + return x + + return wait + + class TemplatedRingAttention(torch.autograd.Function): @staticmethod def forward( @@ -1501,6 +1653,82 @@ def backward( return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None +class TemplatedUlyssesAnythingAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + **kwargs, + ): + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + group = ulysses_mesh.get_group() + + ctx.forward_op = forward_op + ctx.backward_op = backward_op + ctx._parallel_config = _parallel_config + + metadata = ulysses_anything_metadata(query) + query_wait = all_to_all_single_any_qkv_async(query, group, **metadata) + key_wait = all_to_all_single_any_qkv_async(key, group, **metadata) + value_wait = all_to_all_single_any_qkv_async(value, group, **metadata) + + query = query_wait() # type: torch.Tensor + key = key_wait() # type: torch.Tensor + value = value_wait() # type: torch.Tensor + + out = forward_op( + ctx, + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + 
_save_ctx=False, # ulysses anything only support forward pass now. + _parallel_config=_parallel_config, + ) + if return_lse: + out, lse, *_ = out + + # out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D) + out_wait = all_to_all_single_any_o_async(out, group, **metadata) + + if return_lse: + # lse: (B, S_Q_GLOBAL, H_LOCAL) + lse = lse.unsqueeze(-1) # (B, S_Q_GLOBAL, H_LOCAL, D=1) + lse_wait = all_to_all_single_any_o_async(lse, group, **metadata) + out = out_wait() # type: torch.Tensor + lse = lse_wait() # type: torch.Tensor + lse = lse.squeeze(-1).contiguous() # (B, S_Q_LOCAL, H_GLOBAL) + else: + out = out_wait() # type: torch.Tensor + lse = None + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + ): + raise NotImplementedError("Backward pass for Ulysses Anything Attention in diffusers is not implemented yet.") + + def _templated_unified_attention( query: torch.Tensor, key: torch.Tensor, @@ -1618,20 +1846,37 @@ def _templated_context_parallel_attention( _parallel_config, ) elif _parallel_config.context_parallel_config.ulysses_degree > 1: - return TemplatedUlyssesAttention.apply( - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - enable_gqa, - return_lse, - forward_op, - backward_op, - _parallel_config, - ) + if _parallel_config.context_parallel_config.ulysses_anything: + # For Any sequence lengths and Any head num support + return TemplatedUlyssesAnythingAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + else: + return TemplatedUlyssesAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) else: raise ValueError("Reaching this branch of code is unexpected. Please report a bug.")
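
---

For reference, a minimal end-to-end launch script matching the documentation section above might look like the sketch below. It is not part of the PR: the `ContextParallelConfig(ulysses_degree=2, ulysses_anything=True)` call, `enable_parallelism`, and the `"cpu:gloo,cuda:nccl"` backend string come from this diff, while the checkpoint, prompt, resolution, step count, seed, and file names are illustrative assumptions.

```py
# Minimal sketch; launch with: torchrun --nproc_per_node=2 run_flux_cp.py
import torch
import torch.distributed as dist

from diffusers import ContextParallelConfig, FluxPipeline

# Register gloo alongside NCCL so the small size-exchange collectives used by
# Ulysses Anything can run on CPU and avoid forced H2D/D2H synchronizations.
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
rank = dist.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(device)
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=2, ulysses_anything=True)
)

# 1008x1008 produces a latent sequence length that plain Ulysses attention
# cannot shard evenly (see the 1008x1008 rows in the benchmark table above).
image = pipeline(
    "a photo of a cat",
    height=1008,
    width=1008,
    num_inference_steps=28,
    generator=torch.Generator(device=device).manual_seed(0),
).images[0]

if rank == 0:
    image.save("flux_cp.png")
dist.destroy_process_group()
```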
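The sequence-to-head redistribution in `all_to_all_single_any_qkv_async` is easier to follow in isolation. The toy below (an assumption-laden sketch, not library code) reproduces its core idea with plain `torch.distributed` calls: exchange the per-rank local sequence lengths first, then run `all_to_all_single` with explicit, uneven split sizes. Shapes and the script name are illustrative.

```py
# Run with: torchrun --nproc_per_node=2 uneven_a2a.py  (2 GPUs, NCCL)
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)
device = torch.device("cuda", rank)

B, H, D = 1, 4, 8                         # batch, global head count, head dim
H_LOCAL = H // world_size
S_GLOBAL = 9                              # deliberately not divisible by world_size
s_local = len(torch.arange(S_GLOBAL).tensor_split(world_size)[rank])  # 5 on rank 0, 4 on rank 1
x = torch.randn(B, s_local, H, D, device=device)                      # local Q/K/V shard

# 1) Every rank learns every other rank's local sequence length
#    (the role played by _gather_size_by_comm in this diff).
sizes = [torch.zeros(1, dtype=torch.int64, device=device) for _ in range(world_size)]
dist.all_gather(sizes, torch.tensor([s_local], dtype=torch.int64, device=device))
sizes = [int(s.item()) for s in sizes]

# 2) Put the destination rank on dim 0, then exchange: send s_local rows to every
#    peer and receive sizes[j] rows from rank j, reassembling the full sequence.
send = x.reshape(B, s_local, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
send = send.flatten(0, 1)                            # (world_size * s_local, B, H_LOCAL, D)
recv = send.new_empty((sum(sizes), B, H_LOCAL, D))   # (S_GLOBAL, B, H_LOCAL, D)
dist.all_to_all_single(recv, send, output_split_sizes=sizes, input_split_sizes=[s_local] * world_size)

out = recv.permute(1, 0, 2, 3)  # (B, S_GLOBAL, H_LOCAL, D): full sequence, local heads
print(f"rank {rank}: local seq {s_local} -> global seq {out.shape[1]}, local heads {out.shape[2]}")
dist.destroy_process_group()
```

The output-projection path (`all_to_all_single_any_o_async`) is the mirror image: the input and output split sizes swap roles so each rank gets back its original local sequence length with the full set of heads.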
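Two quick single-process checks clarify details the sharding and padding helpers in this diff rely on: `chunk` can silently return fewer pieces than requested (hence `tensor_split` in the sharders), and head padding is only accepted when the pad stays smaller than the per-rank head count. The `head_padding` function below is hypothetical, written only to show the arithmetic; the 30-head cases mirror the ones quoted in the `_maybe_pad_qkv_head` comments.

```py
import torch

# chunk() may return fewer pieces than requested; tensor_split() never does.
x = torch.arange(9)
print([len(c) for c in x.chunk(4)])         # [3, 3, 3]    -> only 3 chunks
print([len(c) for c in x.tensor_split(4)])  # [3, 2, 2, 2] -> always 4 chunks

def head_padding(H, world_size):
    """Pad H up to the next multiple of world_size; return (H_PAD, heads per rank)."""
    H_PAD = (world_size - H % world_size) % world_size
    return H_PAD, (H + H_PAD) // world_size

print(head_padding(30, 8))   # (2, 4): allowed, padding < heads per rank
print(head_padding(30, 16))  # (2, 2): rejected by the assert, the last rank would hold only padded heads
```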