Skip to content

Conversation

@copybara-service
Copy link

Add find_intermediate_sharding to optimize resharding.

This change introduces find_intermediate_sharding, which attempts to find an intermediate sharding to optimize resharding between two NamedShardings. This is particularly useful when resharding from a sharding with a larger total number of shards to one with a smaller total number of shards, where a direct resharding might involve an expensive all-gather. The intermediate sharding is constructed by splitting certain sharded dimensions into a "split" axis and a "replica" axis, allowing for a more efficient two-step resharding process. Helper functions are added to analyze sharding dimensions and build the intermediate mesh and PartitionSpec. Tests are included for the new functionality.

This change introduces `find_intermediate_sharding`, which attempts to find an intermediate sharding to optimize resharding between two NamedShardings. This is particularly useful when resharding from a sharding with a larger total number of shards to one with a smaller total number of shards, where a direct resharding might involve an expensive all-gather. The intermediate sharding is constructed by splitting certain sharded dimensions into a "split" axis and a "replica" axis, allowing for a more efficient two-step resharding process. Helper functions are added to analyze sharding dimensions and build the intermediate mesh and PartitionSpec. Tests are included for the new functionality.

PiperOrigin-RevId: 846807164
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant