[JAX] Deprecate GSPMD: remove infer_sharding_from_operands and GSPMD tests #2702

Open
phu0ngng wants to merge 5 commits into NVIDIA:main from phu0ngng:rm_gspmd

@phu0ngng
Collaborator

Description

GSPMD sharding propagation is being deprecated in favour of Shardy, which is now the default JAX partitioner. This commit removes all GSPMD-related code paths and tests:

  • Drop the infer_sharding_from_operands abstract method from BasePrimitive and remove it from def_partition() registration
  • Remove all infer_sharding_from_operands implementations across cpp_extensions: activation, amax, attention, gemm, normalization, quantization, and softmax primitives
  • Remove stale "Keep in sync with infer_sharding_from_operands" comments from FusedAttn shardy_sharding_rule methods
  • Drop all use_shardy=False (GSPMD) distributed test paths and the jax.config.update("jax_use_shardy_partitioner", ...) config calls
  • Consolidate paired GSPMD/Shardy test functions into single tests and strip _shardy suffixes from test names
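
For readers unfamiliar with the registration pattern, the first bullet can be pictured roughly as follows. This is a minimal, library-free sketch: `PrimitiveSketch`, `register_partitioning`, `SoftmaxSketch`, and the returned dict are hypothetical stand-ins, not TransformerEngine's or JAX's actual code, and the rule string is only illustrative. It shows the intended shape after this change, where a primitive supplies a `partition` callback and a `shardy_sharding_rule` but no `infer_sharding_from_operands` callback.

```python
from abc import ABC, abstractmethod


class PrimitiveSketch(ABC):
    """Hypothetical stand-in for a base primitive class (not the real BasePrimitive)."""

    @staticmethod
    @abstractmethod
    def partition(mesh, arg_infos, result_infos):
        """Describe how the op is lowered on each shard."""

    @staticmethod
    @abstractmethod
    def shardy_sharding_rule(*args):
        """Return an einsum-like rule relating input and output dimensions."""


def register_partitioning(cls):
    """Collect the callbacks a Shardy-only registration would pass along.

    Hypothetical helper: in real code these would be handed to a
    def_partition() call; here they are returned in a dict so the
    sketch runs without JAX installed.
    """
    callbacks = {
        "partition": cls.partition,
        "sharding_rule": cls.shardy_sharding_rule,
    }
    # Deliberately no "infer_sharding_from_operands" entry:
    # that was the GSPMD-only hook this PR removes.
    return callbacks


class SoftmaxSketch(PrimitiveSketch):
    """Toy primitive whose output is sharded like its input."""

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        # Placeholder body: simply echo the inputs back.
        return mesh, arg_infos, result_infos

    @staticmethod
    def shardy_sharding_rule(*args):
        # Illustrative rule string: the last dimension maps to itself.
        return "... i -> ... i"


callbacks = register_partitioning(SoftmaxSketch)
```

Under these assumptions, a subclass that still defined `infer_sharding_from_operands` would simply have that method ignored by the registration step.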

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…tests

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@greptile-apps
Contributor

greptile-apps bot commented Feb 24, 2026

Greptile Summary

This PR removes all GSPMD-related code paths from TransformerEngine JAX, since GSPMD is being deprecated in favor of Shardy, which is now the default JAX partitioner.

Core changes:

  • Removes the abstract method infer_sharding_from_operands from BasePrimitive and its registration in def_partition()
  • Removes all infer_sharding_from_operands implementations across the cpp_extensions primitives (activation, amax, attention, gemm, normalization, quantization, softmax)
  • Removes stale "Keep in sync with infer_sharding_from_operands" comments from FusedAttn shardy_sharding_rule methods
  • Removes all jax.config.update("jax_use_shardy_partitioner", ...) calls throughout tests and examples
  • Consolidates duplicate test methods by removing 23+ _shardy suffixed tests and removing use_shardy parameters
  • Removes workarounds that disabled Shardy for CollectiveGEMM (no longer needed)
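
The test consolidation mentioned above can be pictured with a small sketch. The test names below are hypothetical, and `consolidate_test_names` is an illustrative helper rather than anything in the PR: it merges each GSPMD/Shardy pair into a single name by stripping the `_shardy` suffix.

```python
def consolidate_test_names(names, suffix="_shardy"):
    """Merge paired test names by dropping the suffix, keeping first-seen order."""
    seen = set()
    merged = []
    for name in names:
        # Strip the suffix only when it is actually at the end of the name.
        base = name[: -len(suffix)] if name.endswith(suffix) else name
        if base not in seen:
            seen.add(base)
            merged.append(base)
    return merged


# Hypothetical names, for illustration only.
before = [
    "test_layernorm",
    "test_layernorm_shardy",
    "test_softmax",
    "test_softmax_shardy",
]
after = consolidate_test_names(before)
```

Each pair collapses to one entry, which mirrors the idea of keeping a single Shardy-only test per case.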

Impact:

  • Breaking change: removes GSPMD support entirely
  • Total reduction: ~994 lines of deprecated code
  • Test coverage maintained through consolidated Shardy-only tests

Confidence Score: 5/5

  • This PR is safe to merge - it's a clean, well-scoped removal of deprecated GSPMD code
  • The changes are systematic and complete: all GSPMD references are cleanly removed, test coverage is preserved through consolidated Shardy tests, and the PR aligns with JAX's official deprecation of GSPMD in favor of Shardy
  • No files require special attention

Important Files Changed

Filename | Overview
transformer_engine/jax/cpp_extensions/base.py | Removes abstract method infer_sharding_from_operands from BasePrimitive and its registration in def_partition()
transformer_engine/jax/cpp_extensions/activation.py | Removes infer_sharding_from_operands implementations from ActLuPrimitive and BaseDActLuDBiasQuantizePrimitive (151 lines)
transformer_engine/jax/cpp_extensions/attention.py | Removes infer_sharding_from_operands from FusedAttnFwdPrimitive and FusedAttnBwdPrimitive, removes stale "Keep in sync" comments
transformer_engine/jax/cpp_extensions/gemm.py | Removes 51 lines of infer_sharding_from_operands implementation from GemmPrimitive
tests/jax/test_distributed_fused_attn.py | Removes use_shardy parameter and config updates, removes duplicate _shardy test methods, consolidates tests
tests/jax/test_distributed_layernorm_mlp.py | Removes use_shardy parameter, consolidates duplicate _shardy test methods into single tests (89 lines removed)
examples/jax/collective_gemm/common.py | Removes workaround that disabled Shardy for CollectiveGEMM, now unnecessary as Shardy is the only partitioner

Last reviewed commit: f6a06e5

Contributor

@greptile-apps greptile-apps bot left a comment

12 files reviewed, no comments

@phu0ngng
Collaborator Author

/te-ci JAX L1

Contributor

@greptile-apps greptile-apps bot left a comment

12 files reviewed, no comments

phu0ngng and others added 2 commits February 24, 2026 12:53
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

17 files reviewed, no comments

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

19 files reviewed, no comments
