diff --git a/pathwaysutils/experimental/reshard.py b/pathwaysutils/experimental/reshard.py
index 4f2b19a..e199c60 100644
--- a/pathwaysutils/experimental/reshard.py
+++ b/pathwaysutils/experimental/reshard.py
@@ -16,11 +16,21 @@
 import base64
 import collections
 import json
+import logging
+import math
+import operator
 from typing import Any, Dict, Sequence
 
 import jax
 from pathwaysutils import lru_cache
 from pathwaysutils import plugin_executable
+from pathwaysutils.experimental import split_by_mesh_axis
+
+
+_logger = logging.getLogger(__name__)
+
+INTERMEDIATE_SPLIT_SUFFIX = "_intermediate_split"
+INTERMEDIATE_REPLICA_SUFFIX = "_intermediate_replica"
 
 
 class ReshardingPlanWrapper:
@@ -198,3 +208,225 @@ def reshard(
       result[idx] = arr
 
   return jax.tree.unflatten(tree_def, result)
+
+
+class NoIntermediateShardingError(Exception):
+  """Raised when no intermediate sharding is found."""
+
+
+class NoIntermediateShardingNeededError(NoIntermediateShardingError):
+  """Raised when no intermediate sharding is needed for optimization."""
+
+
+def _get_sharding_spec_dims(sharding: jax.sharding.NamedSharding):
+  """Gets the sharding dimension sizes from a NamedSharding."""
+  mesh = sharding.mesh
+  dims = []
+  for spec in sharding.spec:
+    if spec is None:
+      dims.append(1)
+    elif isinstance(spec, str):
+      dims.append(mesh.shape[spec])
+    elif isinstance(spec, (list, tuple)):
+      dims.append(math.prod([mesh.shape[ax] for ax in spec]))
+    else:
+      raise ValueError(f"Unsupported partition spec: {spec}")
+  return dims
+
+
+def _check_sharding_divisibility(
+    in_sharding: jax.sharding.NamedSharding,
+    out_sharding: jax.sharding.NamedSharding,
+    src_dims: Sequence[int],
+    dst_dims: Sequence[int],
+):
+  """Checks if source and destination shardings are compatible for optimization."""
+  src_largest_dim = max(src_dims) if src_dims else 1
+  dst_largest_dim = max(dst_dims) if dst_dims else 1
+  src_total_dims = math.prod(src_dims)
+  dst_total_dims = math.prod(dst_dims)
+
+  # Cannot handle resharding with indivisible shardings.
+  if src_largest_dim % dst_largest_dim != 0:
+    raise NoIntermediateShardingError(
+        "Resharding with indivisible shardings is not optimized with"
+        " intermediate sharding."
+        f" in_sharding={in_sharding}, out_sharding={out_sharding}"
+    )
+  if src_total_dims <= dst_total_dims:
+    raise NoIntermediateShardingError(
+        "No intermediate sharding is found because the source sharding is not"
+        " larger than the target sharding."
+        f" in_sharding={in_sharding}, out_sharding={out_sharding}"
+    )
+  if src_total_dims % dst_total_dims != 0:
+    raise NoIntermediateShardingError(
+        "No intermediate sharding is found because the source sharding is not"
+        " divisible by the target sharding."
+        f" in_sharding={in_sharding}, out_sharding={out_sharding}"
+    )
+
+
+def _get_split_candidates(
+    in_sharding: jax.sharding.NamedSharding,
+    src_dims: Sequence[int],
+    dst_dims: Sequence[int],
+    gcd_shards: Sequence[int],
+) -> list[tuple[int, str]]:
+  """Finds dimensions that are candidates for splitting."""
+  split_candidates = []
+  for i, spec in enumerate(in_sharding.spec):
+    # TODO(b/1234) - Support splitting a dimension that is sharded over multiple
+    # mesh axes.
+    if (
+        gcd_shards[i] == 1
+        and src_dims[i] > dst_dims[i]
+        and isinstance(spec, str)
+    ):
+      split_candidates.append((i, spec))
+
+  if not split_candidates:
+    raise NoIntermediateShardingError(
+        "No intermediate sharding is found because all of the"
+        " gcd(src_dim_shards, dst_dim_shards) are 1s or there is no"
+        " suitable dimension to split."
+    )
+  return split_candidates
+
+
+def _build_intermediate_mesh_and_spec(
+    src_mesh: jax.sharding.Mesh,
+    in_spec: jax.sharding.PartitionSpec,
+    src_dims: Sequence[int],
+    dst_dims: Sequence[int],
+    split_candidates: list[tuple[int, str]],
+) -> tuple[jax.sharding.Mesh, jax.sharding.PartitionSpec, list[str]]:
+  """Builds the intermediate Mesh and PartitionSpec."""
+  # Build a map of mesh axis to split information: (dim_idx, replicas).
+  mesh_axis_to_split_info = {}
+  for dim_idx, mesh_axis in split_candidates:
+    src_dim = src_dims[dim_idx]
+    dst_dim = dst_dims[dim_idx]
+    replicas = src_dim // dst_dim
+    mesh_axis_to_split_info[mesh_axis] = (dim_idx, replicas)
+
+  # Build the intermediate mesh by expanding axes that need splitting.
+  new_replicated_axis_names = []
+  new_replicated_mesh_shape = []
+  new_axis_names = []
+  new_mesh_shape = []
+  for axis_name in src_mesh.axis_names:
+    axis_size = src_mesh.shape[axis_name]
+    if axis_name in mesh_axis_to_split_info:
+      dim_idx, replicas = mesh_axis_to_split_info[axis_name]
+      dst_dim = dst_dims[dim_idx]
+      split_axis_name = axis_name + INTERMEDIATE_SPLIT_SUFFIX
+      replica_axis_name = axis_name + INTERMEDIATE_REPLICA_SUFFIX
+      new_replicated_axis_names.append(replica_axis_name)
+      new_replicated_mesh_shape.append(replicas)
+      new_axis_names.append(split_axis_name)
+      new_mesh_shape.append(dst_dim)
+    else:
+      new_axis_names.append(axis_name)
+      new_mesh_shape.append(axis_size)
+
+  new_axis_names = new_replicated_axis_names + new_axis_names
+  new_mesh_shape = new_replicated_mesh_shape + new_mesh_shape
+  intermediate_mesh = jax.sharding.Mesh(
+      src_mesh.devices.reshape(new_mesh_shape),
+      axis_names=tuple(new_axis_names),
+  )
+
+  # Build the intermediate PartitionSpec.
+  intermediate_spec_list = list(in_spec)
+  for dim_idx, mesh_axis in split_candidates:
+    split_axis_name = mesh_axis + INTERMEDIATE_SPLIT_SUFFIX
+    intermediate_spec_list[dim_idx] = split_axis_name
+  intermediate_spec = jax.sharding.PartitionSpec(*intermediate_spec_list)
+
+  return intermediate_mesh, intermediate_spec, new_replicated_axis_names
+
+
+def find_intermediate_sharding(
+    in_sharding: jax.sharding.Sharding, out_sharding: jax.sharding.Sharding
+) -> tuple[jax.sharding.NamedSharding, list[str]]:
+  """Finds an intermediate sharding to reshard to before the target sharding.
+
+  This function tries to find an intermediate sharding that can be used to
+  reshard from in_sharding to out_sharding. This is useful when a direct
+  reshard from in_sharding to out_sharding would require an all-gather, which
+  can be expensive.
+
+  For example, consider resharding an array from in_sharding (e.g., [fsdp: 8,
+  tp: 1]) to out_sharding (e.g., [fsdp: 1, tp: 4]). Here the source's largest
+  sharding factor, 8, is larger than the target's largest sharding factor, 4.
+  To avoid an expensive all-gather, we introduce an intermediate sharding, e.g.,
+  [fsdp_split: 4, tp: 1, fsdp_replica: 2]. This intermediate sharding keeps the
+  array sharded along the fsdp dimension while replicating it across the
+  remaining devices. Then we can reshard any replica of the source to the
+  target as normal.
+
+  Args:
+    in_sharding: The source sharding.
+    out_sharding: The target sharding.
+
+  Returns:
+    A tuple containing:
+      - An intermediate sharding.
+      - A list of axis names that are replicated in the intermediate sharding.
+
+  Raises:
+    NoIntermediateShardingError: If no intermediate sharding is found.
+    NoIntermediateShardingNeededError: If no intermediate sharding is needed
+      for optimization.
+  """
+  if not isinstance(in_sharding, jax.sharding.NamedSharding) or not isinstance(
+      out_sharding, jax.sharding.NamedSharding
+  ):
+    raise NoIntermediateShardingError(
+        "Only NamedSharding is supported for now. Got"
+        f" in_sharding={in_sharding} and out_sharding={out_sharding}"
+    )
+  src_mesh = in_sharding.mesh
+
+  if len(in_sharding.spec) != len(out_sharding.spec):
+    raise NoIntermediateShardingError(
+        "Source and destination shardings must have the same rank (same"
+        f" PartitionSpec length). Got in_sharding.spec={in_sharding.spec} and"
+        f" out_sharding.spec={out_sharding.spec}"
+    )
+
+  src_dims = _get_sharding_spec_dims(in_sharding)
+  dst_dims = _get_sharding_spec_dims(out_sharding)
+
+  _check_sharding_divisibility(in_sharding, out_sharding, src_dims, dst_dims)
+
+  gcd_shards = jax.tree.map(math.gcd, src_dims, dst_dims)
+
+  # If all of the gcd(src_dim_shards, dst_dim_shards) are 1, an all-gather
+  # would be needed: a single source replica cannot be represented by any
+  # sharded form on the target devices. Otherwise no intermediate is needed.
+  if jax.tree.reduce(operator.mul, gcd_shards, 1) != 1:
+    raise NoIntermediateShardingNeededError()
+
+  try:
+    split_candidates = _get_split_candidates(
+        in_sharding, src_dims, dst_dims, gcd_shards
+    )
+  except NoIntermediateShardingError as e:
+    raise NoIntermediateShardingError(
+        f"{e} in_sharding={in_sharding}, out_sharding={out_sharding}"
+    ) from e
+
+  intermediate_mesh, intermediate_spec, replicated_axes = (
+      _build_intermediate_mesh_and_spec(
+          src_mesh, in_sharding.spec, src_dims, dst_dims, split_candidates
+      )
+  )
+
+  intermediate_sharding = jax.sharding.NamedSharding(
+      intermediate_mesh,
+      intermediate_spec,
+      memory_kind=in_sharding.memory_kind,
+  )
+  return intermediate_sharding, replicated_axes
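
Below is a hypothetical usage sketch, not part of this change, that exercises
the new find_intermediate_sharding() helper on the [fsdp: 8, tp: 1] ->
[fsdp: 1, tp: 4] example from its docstring. The meshes, device counts, and
fallback handling are illustrative assumptions; it presumes at least 8 JAX
devices (e.g. CPU with XLA_FLAGS=--xla_force_host_platform_device_count=8).

# Illustrative sketch only; the meshes and fallbacks below are assumptions.
import jax
import numpy as np

from pathwaysutils.experimental import reshard

# Source: 8-device mesh, array sharded 8 ways along fsdp.
src_mesh = jax.sharding.Mesh(
    np.asarray(jax.devices()[:8]).reshape(8, 1), axis_names=("fsdp", "tp")
)
in_sharding = jax.sharding.NamedSharding(
    src_mesh, jax.sharding.PartitionSpec("fsdp", "tp")
)

# Target: 4-device mesh, array sharded 4 ways along tp.
dst_mesh = jax.sharding.Mesh(
    np.asarray(jax.devices()[:4]).reshape(1, 4), axis_names=("fsdp", "tp")
)
out_sharding = jax.sharding.NamedSharding(
    dst_mesh, jax.sharding.PartitionSpec("fsdp", "tp")
)

try:
  # On success, the intermediate sharding lives on the source mesh, whose
  # split axis gains the "_intermediate_split" / "_intermediate_replica"
  # suffixes; replicated_axes lists the replica axis names.
  intermediate_sharding, replicated_axes = reshard.find_intermediate_sharding(
      in_sharding, out_sharding
  )
except reshard.NoIntermediateShardingNeededError:
  # Catch the subclass first: the direct reshard path is already fine.
  intermediate_sharding, replicated_axes = None, []
except reshard.NoIntermediateShardingError:
  # No compatible intermediate; fall back to the default reshard path.
  intermediate_sharding, replicated_axes = None, []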