pathwaysutils/experimental/reshard.py (+232, -0)
@@ -16,11 +16,21 @@
import base64
import collections
import json
import logging
import math
import operator
from typing import Any, Dict, Sequence

import jax
from pathwaysutils import lru_cache
from pathwaysutils import plugin_executable
from pathwaysutils.experimental import split_by_mesh_axis


_logger = logging.getLogger(__name__)

INTERMEDIATE_SPLIT_SUFFIX = "_intermediate_split"
INTERMEDIATE_REPLICA_SUFFIX = "_intermediate_replica"


class ReshardingPlanWrapper:
@@ -198,3 +208,225 @@ def reshard(
result[idx] = arr

return jax.tree.unflatten(tree_def, result)


class NoIntermediateShardingError(Exception):
"""Raised when no intermediate sharding is found."""


class NoIntermediateShardingNeededError(NoIntermediateShardingError):
"""Raised when no intermediate sharding is needed for optimization."""


def _get_sharding_spec_dims(
    sharding: jax.sharding.NamedSharding,
) -> list[int]:
"""Gets the sharding dimension sizes from a NamedSharding."""
mesh = sharding.mesh
dims = []
for spec in sharding.spec:
if spec is None:
dims.append(1)
elif isinstance(spec, str):
dims.append(mesh.shape[spec])
elif isinstance(spec, (list, tuple)):
dims.append(math.prod([mesh.shape[ax] for ax in spec]))
else:
raise ValueError(f"Unsupported partition spec: {spec}")
return dims
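
# Illustrative sketch of the helper above (assumes at least four devices and
# `numpy` imported as `np`; the axis names are hypothetical):
#
#   mesh = jax.sharding.Mesh(
#       np.asarray(jax.devices()[:4]).reshape(2, 2), axis_names=("fsdp", "tp")
#   )
#   sharding = jax.sharding.NamedSharding(
#       mesh, jax.sharding.PartitionSpec(("fsdp", "tp"), None)
#   )
#   _get_sharding_spec_dims(sharding)  # -> [4, 1]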


def _check_sharding_divisibility(
in_sharding: jax.sharding.NamedSharding,
out_sharding: jax.sharding.NamedSharding,
src_dims: Sequence[int],
dst_dims: Sequence[int],
) -> None:
"""Checks if source and destination shardings are compatible for optimization."""
src_largest_dim = max(src_dims) if src_dims else 1
dst_largest_dim = max(dst_dims) if dst_dims else 1
src_total_dims = math.prod(src_dims)
dst_total_dims = math.prod(dst_dims)

  # Cannot optimize resharding when the largest source shard count is not
  # divisible by the largest destination shard count.
if src_largest_dim % dst_largest_dim != 0:
raise NoIntermediateShardingError(
"Resharding with undividable shardings is not optimized with"
" intermediate sharding."
f" in_sharding={in_sharding}, out_sharding={out_sharding}"
)
if src_total_dims <= dst_total_dims:
raise NoIntermediateShardingError(
"No intermediate sharding is found because the source sharding is not"
" larger than the target sharding."
f" in_sharding={in_sharding}, out_sharding={out_sharding}"
)
if src_total_dims % dst_total_dims != 0:
raise NoIntermediateShardingError(
"No intermediate sharding is found because the source sharding is not"
" divisible by the target sharding."
f" in_sharding={in_sharding}, out_sharding={out_sharding}"
)
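
# Illustrative sketch of the check above: with src_dims=[8, 1] and
# dst_dims=[1, 4], the largest per-dimension shard counts are 8 and 4
# (divisible) and the totals are 8 and 4 (the source is strictly larger and
# divisible), so the check passes. With src_dims=[4, 1] and dst_dims=[1, 4],
# the totals are equal, so NoIntermediateShardingError is raised.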


def _get_split_candidates(
in_sharding: jax.sharding.NamedSharding,
src_dims: Sequence[int],
dst_dims: Sequence[int],
gcd_shards: Sequence[int],
) -> list[tuple[int, str]]:
"""Finds dimensions that are candidates for splitting."""
split_candidates = []
for i, spec in enumerate(in_sharding.spec):
# TODO(b/1234) - Support splitting a dimension that is sharded over multiple
# mesh axes.
if (
gcd_shards[i] == 1
and src_dims[i] > dst_dims[i]
and isinstance(spec, str)
):
split_candidates.append((i, spec))

if not split_candidates:
raise NoIntermediateShardingError(
"No intermediate sharding is found because all of the"
" gcd(src_dim_shards, dst_dim_shards) are 1s, or no suitable"
" dimension to split."
)
return split_candidates
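
# Illustrative sketch of the candidate selection above: for an in_sharding
# over a mesh with axes ("fsdp": 8, "tp": 1), spec PartitionSpec("fsdp", None),
# src_dims=[8, 1], dst_dims=[1, 4], and gcd_shards=[1, 1], dimension 0 is the
# only candidate: it is sharded over the single mesh axis "fsdp" with more
# source shards (8) than destination shards (1). The result is [(0, "fsdp")].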


def _build_intermediate_mesh_and_spec(
src_mesh: jax.sharding.Mesh,
in_spec: jax.sharding.PartitionSpec,
src_dims: Sequence[int],
dst_dims: Sequence[int],
split_candidates: list[tuple[int, str]],
) -> tuple[jax.sharding.Mesh, jax.sharding.PartitionSpec, list[str]]:
"""Builds the intermediate Mesh and PartitionSpec."""
# Build a map of mesh axis to split information: (dim_idx, replicas)
mesh_axis_to_split_info = {}
for dim_idx, mesh_axis in split_candidates:
src_dim = src_dims[dim_idx]
dst_dim = dst_dims[dim_idx]
replicas = src_dim // dst_dim
mesh_axis_to_split_info[mesh_axis] = (dim_idx, replicas)

# Build the intermediate mesh by expanding axes that need splitting.
new_replicated_axis_names = []
new_replicated_mesh_shape = []
new_axis_names = []
new_mesh_shape = []
for axis_name in src_mesh.axis_names:
axis_size = src_mesh.shape[axis_name]
if axis_name in mesh_axis_to_split_info:
dim_idx, replicas = mesh_axis_to_split_info[axis_name]
dst_dim = dst_dims[dim_idx]
split_axis_name = axis_name + INTERMEDIATE_SPLIT_SUFFIX
replica_axis_name = axis_name + INTERMEDIATE_REPLICA_SUFFIX
new_replicated_axis_names.append(replica_axis_name)
new_replicated_mesh_shape.append(replicas)
new_axis_names.append(split_axis_name)
new_mesh_shape.append(dst_dim)
else:
new_axis_names.append(axis_name)
new_mesh_shape.append(axis_size)

new_axis_names = new_replicated_axis_names + new_axis_names
new_mesh_shape = new_replicated_mesh_shape + new_mesh_shape
intermediate_mesh = jax.sharding.Mesh(
src_mesh.devices.reshape(new_mesh_shape),
axis_names=tuple(new_axis_names),
)

# Build the intermediate PartitionSpec.
intermediate_spec_list = list(in_spec)
for dim_idx, mesh_axis in split_candidates:
split_axis_name = mesh_axis + INTERMEDIATE_SPLIT_SUFFIX
intermediate_spec_list[dim_idx] = split_axis_name
intermediate_spec = jax.sharding.PartitionSpec(*intermediate_spec_list)

return intermediate_mesh, intermediate_spec, new_replicated_axis_names
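
# Illustrative note on the construction above: for a source mesh with axes
# ("fsdp", "tp") and a split candidate on "fsdp", the returned mesh renames
# "fsdp" to "fsdp_intermediate_split" (sized to the matching destination
# dimension), prepends a new "fsdp_intermediate_replica" axis holding the
# remaining replication factor, and keeps "tp" unchanged; the PartitionSpec
# entry for the split array dimension is rewritten to
# "fsdp_intermediate_split".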


def find_intermediate_sharding(
in_sharding: jax.sharding.Sharding, out_sharding: jax.sharding.Sharding
) -> tuple[jax.sharding.NamedSharding, list[str]]:
"""Finds an intermediate sharding to reshard to before target sharding.

This function tries to find an intermediate sharding that can be used to
reshard the in_sharding to the out_sharding. This is useful when resharding
from an in_sharding to an out_sharding that requires an all-gather, which can
be expensive.

For example, consider resharding an array from in_sharding (e.g., [fsdp: 8,
tp: 1]) to out_sharding (e.g., [fsdp: 1, tp: 4]). In this case, the source
has a larger sharding factor, 8, than the target's largest sharding factor, 4.
To avoid an expensive all-gather, we introduce an intermediate sharding, e.g.,
  [fsdp_split: 4, tp: 1, fsdp_replica: 2]. This intermediate sharding
allows us to reshard the source array by still sharding along the fsdp
dimension and replicating it on the remaining devices. Then we can reshard any
replica of the source to the target as normal.

Args:
in_sharding: The source sharding.
out_sharding: The target sharding.

Returns:
A tuple containing:
- An intermediate sharding.
- A list of axis names that are replicated in the intermediate sharding.

Raises:
NoIntermediateShardingError: If no intermediate sharding is found.
NoIntermediateShardingNeededError: If no intermediate sharding is needed for
optimization.
"""
if not isinstance(in_sharding, jax.sharding.NamedSharding) or not isinstance(
out_sharding, jax.sharding.NamedSharding
):
raise NoIntermediateShardingError(
"Only NamedSharding is supported for now. Got"
f" in_sharding={in_sharding} and out_sharding={out_sharding}"
)
src_mesh = in_sharding.mesh

if len(in_sharding.spec) != len(out_sharding.spec):
raise NoIntermediateShardingError(
"Source and destination shardings must have the same rank (same"
f" PartitionSpec length). Got in_sharding.spec={in_sharding.spec} and"
f" out_sharding.spec={out_sharding.spec}"
)

src_dims = _get_sharding_spec_dims(in_sharding)
dst_dims = _get_sharding_spec_dims(out_sharding)

_check_sharding_divisibility(in_sharding, out_sharding, src_dims, dst_dims)

gcd_shards = jax.tree.map(math.gcd, src_dims, dst_dims)

  # The intermediate-sharding optimization only applies when every
  # gcd(src_dim_shards, dst_dim_shards) is 1: in that case a direct reshard
  # would need an all-gather, because the single replica of the source cannot
  # be represented by any sharded form on the target devices. Otherwise, no
  # intermediate sharding is needed.
if jax.tree.reduce(operator.mul, gcd_shards, 1) != 1:
raise NoIntermediateShardingNeededError()

try:
split_candidates = _get_split_candidates(
in_sharding, src_dims, dst_dims, gcd_shards
)
except NoIntermediateShardingError as e:
raise NoIntermediateShardingError(
f"{e} in_sharding={in_sharding}, out_sharding={out_sharding}"
) from e

intermediate_mesh, intermediate_spec, replicated_axes = (
_build_intermediate_mesh_and_spec(
src_mesh, in_sharding.spec, src_dims, dst_dims, split_candidates
)
)

intermediate_sharding = jax.sharding.NamedSharding(
intermediate_mesh,
intermediate_spec,
memory_kind=in_sharding.memory_kind,
)
return intermediate_sharding, replicated_axes
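

# Illustrative usage sketch (assumes eight addressable devices and `numpy`
# imported as `np`; the mesh and axis names are hypothetical):
#
#   devices = np.asarray(jax.devices()[:8])
#   src_mesh = jax.sharding.Mesh(
#       devices.reshape(8, 1), axis_names=("fsdp", "tp")
#   )
#   dst_mesh = jax.sharding.Mesh(
#       devices.reshape(2, 4), axis_names=("fsdp", "tp")
#   )
#   in_sharding = jax.sharding.NamedSharding(
#       src_mesh, jax.sharding.PartitionSpec("fsdp", None)
#   )
#   out_sharding = jax.sharding.NamedSharding(
#       dst_mesh, jax.sharding.PartitionSpec(None, "tp")
#   )
#   try:
#     intermediate_sharding, replicated_axes = find_intermediate_sharding(
#         in_sharding, out_sharding
#     )
#   except NoIntermediateShardingNeededError:
#     pass  # A direct reshard is already efficient.
#   except NoIntermediateShardingError:
#     pass  # Fall back to a direct reshard.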