diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index 0d812da057..2965896d07 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -131,10 +131,6 @@ def _initialize_distributed(args): ) _distributed_initialized = True - jax.clear_caches() - jax.config.update( - "jax_use_shardy_partitioner", False - ) # CollectiveGEMM does not work with Shardy yet assert jax.local_device_count() == 1, ( f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found" diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index d2994723bb..9ccf1f560e 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -88,8 +88,6 @@ def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_shard def run_gemm_tests(args, mesh=None): """Execute GEMM tests.""" print(args) - # Collective GEMM requires Shardy partitioner to be disabled - jax.config.update("jax_use_shardy_partitioner", False) # Initialize distributed with provided arguments _initialize_distributed(args) diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 61c960a7aa..84cb011da1 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -119,8 +119,6 @@ def _value_and_grad_layernorm_mlp( def run_layernorm_mlp_grad_tests(args, mesh=None): """Execute Dense Gradient tests.""" print(args) - # Collective GEMM requires Shardy partitioner to be disabled - jax.config.update("jax_use_shardy_partitioner", False) # Initialize distributed with provided arguments _initialize_distributed(args) diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index f2ef33da46..3c1f2ba1fb 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -11,10 +11,6 @@ TEST_CASES=( "test_te_current_scaling_fp8" "test_te_mxfp8" "test_te_nvfp4" -"test_te_bf16_shardy" -"test_te_delayed_scaling_fp8_shardy" -"test_te_current_scaling_fp8_shardy" -"test_te_nvfp4_shardy" ) : ${TE_PATH:=/opt/transformerengine} diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 73b93798a0..4400485f26 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -239,7 +239,6 @@ def check_fp8(state, var_collect, inputs, masks, labels): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) - jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) @@ -474,9 +473,6 @@ def encoder_parser(args): parser.add_argument( "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism." ) - parser.add_argument( - "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." - ) return parser.parse_args(args) @@ -559,70 +555,6 @@ def test_te_nvfp4_with_sp(self): actual = train_and_evaluate(self.args) assert actual[0] < 0.40 and actual[1] > 0.82 - @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - def test_te_bf16_shardy(self): - """Test Transformer Engine with BF16""" - self.args.enable_shardy = True - actual = train_and_evaluate(self.args) - assert actual[0] < 0.36 and actual[1] > 0.84 - - @unittest.skipIf(not is_fp8_supported, fp8_reason) - def test_te_delayed_scaling_fp8_shardy(self): - """Test Transformer Engine with DelayedScaling FP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "DelayedScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.362 and actual[1] > 0.84 - - @unittest.skipIf(not is_fp8_supported, fp8_reason) - def test_te_delayed_scaling_fp8_with_sp_shardy(self): - """Test Transformer Engine with DelayedScaling FP8 + SP""" - self.args.enable_shardy = True - self.args.enable_sp = True - self.args.use_fp8 = True - self.args.fp8_recipe = "DelayedScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.362 and actual[1] > 0.84 - - @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - def test_te_mxfp8_shardy(self): - """Test Transformer Engine with MXFP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "MXFP8BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.36 and actual[1] > 0.84 - - @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) - def test_te_nvfp4_shardy(self): - """Test Transformer Engine with NVFP4""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "NVFP4BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.40 and actual[1] > 0.82 - - @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - def test_te_mxfp8_with_sp_shardy(self): - """Test Transformer Engine with MXFP8 + SP""" - self.args.enable_shardy = True - self.args.enable_sp = True - self.args.use_fp8 = True - self.args.fp8_recipe = "MXFP8BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.36 and actual[1] > 0.84 - - @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) - def test_te_nvfp4_with_sp_shardy(self): - """Test Transformer Engine with NVFP4""" - self.args.enable_shardy = True - self.args.enable_sp = True - self.args.use_fp8 = True - self.args.fp8_recipe = "NVFP4BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.40 and actual[1] > 0.82 - if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 22a89cc0a9..e2edc589b9 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -249,7 +249,6 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) - jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) num_gpu = jax.local_device_count() @@ -438,9 +437,6 @@ def encoder_parser(args): default="DelayedScaling", help="Use FP8 recipe (default: DelayedScaling)", ) - parser.add_argument( - "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." - ) return parser.parse_args(args) @@ -494,49 +490,6 @@ def test_te_nvfp4(self): actual = train_and_evaluate(self.args) assert actual[0] < 0.52 and actual[1] > 0.74 - @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - def test_te_bf16_shardy(self): - """Test Transformer Engine with BF16""" - self.args.enable_shardy = True - actual = train_and_evaluate(self.args) - assert actual[0] < 0.51 and actual[1] > 0.75 - - @unittest.skipIf(not is_fp8_supported, fp8_reason) - def test_te_delayed_scaling_fp8_shardy(self): - """Test Transformer Engine with DelayedScaling FP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "DelayedScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.51 and actual[1] > 0.75 - - @unittest.skipIf(not is_fp8_supported, fp8_reason) - def test_te_current_scaling_fp8_shardy(self): - """Test Transformer Engine with CurrentScaling FP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "Float8CurrentScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.51 and actual[1] > 0.749 - - @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - def test_te_mxfp8_shardy(self): - """Test Transformer Engine with MXFP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "MXFP8BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.51 and actual[1] > 0.75 - - @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) - def test_te_nvfp4_shardy(self): - """Test Transformer Engine with NVFP4""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "NVFP4BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 - if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 0166b60acd..344e7d618b 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -359,7 +359,6 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) - jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) if args.process_id == 0: nltk.download("punkt_tab") @@ -605,9 +604,6 @@ def encoder_parser(args): default=0, help="the ID number of the current process (default: 0)", ) - parser.add_argument( - "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." - ) return parser.parse_args(args) @@ -616,7 +612,7 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): + def exec(self, use_fp8, fp8_recipe): """Run 5 epochs for testing""" args = encoder_parser(["--epochs", "5"]) @@ -632,7 +628,6 @@ def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): args.num_process = num_gpu args.process_id = self.process_id args.fp8_recipe = fp8_recipe - args.enable_shardy = enable_shardy return train_and_evaluate(args) @@ -674,44 +669,6 @@ def test_te_nvfp4(self): result = self.exec(True, "NVFP4BlockScaling") assert result[0] < 0.451 and result[1] > 0.787 - @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - def test_te_bf16_shardy(self): - """Test Transformer Engine with BF16""" - result = self.exec(False, None, enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.80 - - @unittest.skipIf( - not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" - ) - def test_te_delayed_scaling_fp8_shardy(self): - """Test Transformer Engine with DelayedScaling FP8""" - result = self.exec(True, "DelayedScaling", enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.80 - - @unittest.skipIf( - not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" - ) - def test_te_current_scaling_fp8_shardy(self): - """Test Transformer Engine with CurrentScaling FP8""" - result = self.exec(True, "Float8CurrentScaling", enable_shardy=True) - assert result[0] < 0.432 and result[1] > 0.80 - - @unittest.skipIf( - not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" - ) - def test_te_mxfp8_shardy(self): - """Test Transformer Engine with MXFP8""" - result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.80 - - @unittest.skipIf( - not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4" - ) - def test_te_nvfp4_shardy(self): - """Test Transformer Engine with NVFP4""" - result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True) - assert result[0] < 0.451 and result[1] > 0.787 - if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index d5ebe9f261..50c5de1db7 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -68,9 +68,7 @@ def impl_test_self_attn( attn_mask_type, dtype, softmax_type, - use_shardy, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) dropout_prob = 0.0 is_training = True batch, seqlen, num_head, hidden = data_shape @@ -178,48 +176,6 @@ def test_self_attn( attn_mask_type, dtype, softmax_type, - use_shardy=False, - ) - - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize( - "attn_bias_type, bias_shape", - [ - pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), - pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), - ], - ) - @pytest.mark.parametrize( - "softmax_type", - [ - pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"), - pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"), - pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"), - ], - ) - def test_self_attn_shardy( - self, - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - attn_bias_type, - bias_shape, - softmax_type, - ): - data_shape = (32, 512, 12, 64) - self.impl_test_self_attn( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - attn_bias_type, - bias_shape, - AttnMaskType.PADDING_MASK, - jnp.bfloat16, - softmax_type, - use_shardy=True, ) @@ -348,7 +304,6 @@ def impl_test_context_parallel_attn( qkv_layout, load_balanced, cp_strategy, - use_shardy, use_scan_ring=False, window_size=None, stripe_size=None, @@ -366,8 +321,6 @@ def impl_test_context_parallel_attn( os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1" else: os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" - - jax.config.update("jax_use_shardy_partitioner", use_shardy) attn_bias_type = AttnBiasType.NO_BIAS bias_shape = None dropout_prob = 0.0 @@ -452,45 +405,6 @@ def check_has_backend_for_mask(mask_type): runner.test_backward() del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] - @pytest_parametrize_wrapper( - "device_count,mesh_shape,mesh_axes,mesh_resource", - generate_context_parallel_configs_for_attn(), - ) - @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) - @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) - @pytest.mark.parametrize( - "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, - ) - def test_context_parallel_allgather_attn_shardy( - self, - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - attn_mask_type, - dtype, - qkv_layout, - ): - if qkv_layout.is_thd(): - pytest.skip("Only BSHD layout is supported for CP + AG + Dual chunk attention") - kv_groups = 8 - self.impl_test_context_parallel_attn( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - kv_groups, - attn_mask_type, - dtype, - qkv_layout, - load_balanced=True, - cp_strategy=CPStrategy.ALL_GATHER, - use_shardy=True, - ) - @pytest_parametrize_wrapper( "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs_for_attn(), @@ -551,7 +465,6 @@ def test_context_parallel_allgather_striped_attn( qkv_layout, load_balanced, CPStrategy.ALL_GATHER, - use_shardy=False, window_size=window_size, stripe_size=stripe_size, num_segments_per_seq=num_segments_per_seq, @@ -599,7 +512,6 @@ def test_context_parallel_allgather_attn( qkv_layout, load_balanced, CPStrategy.ALL_GATHER, - use_shardy=False, ) @pytest_parametrize_wrapper( @@ -664,53 +576,11 @@ def test_context_parallel_ring_attn( qkv_layout, load_balanced, CPStrategy.RING, - use_shardy=False, use_scan_ring=use_scan, window_size=window_size, stripe_size=stripe_size, ) - @pytest_parametrize_wrapper( - "device_count,mesh_shape,mesh_axes,mesh_resource", - generate_context_parallel_configs_for_attn(), - ) - @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) - @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) - @pytest.mark.parametrize( - "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, - ) - def test_context_parallel_ring_attn_shardy( - self, - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - attn_mask_type, - dtype, - qkv_layout, - ): - kv_groups = 8 - # Set the stripe size to 1 (ring attention only support stripe_size=1) - stripe_size = 1 if qkv_layout.is_thd() else None - self.impl_test_context_parallel_attn( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - kv_groups, - attn_mask_type, - dtype, - qkv_layout, - load_balanced=True, - cp_strategy=CPStrategy.RING, - use_shardy=False, - use_scan_ring=True, - stripe_size=stripe_size, - ) - REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = { "L0": [[]], diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index e9a2fa49e2..bb1f38dcc8 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -87,7 +87,6 @@ def generate_collectives_count_ref( @pytest_parametrize_wrapper("zero_centered_gamma", [False, True]) @pytest_parametrize_wrapper("shard_weights", [False, True]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) - @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_layernorm( self, device_count, @@ -99,9 +98,7 @@ def test_layernorm( zero_centered_gamma, shard_weights, fp8_recipe, - use_shardy, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) epsilon = 1e-6 ln_type = "layernorm" q_dtype = jnp.float8_e4m3fn @@ -178,7 +175,6 @@ def ref_func(x, gamma, beta): @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("shard_weights", [False, True]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) - @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_rmsnorm( self, device_count, @@ -189,9 +185,7 @@ def test_rmsnorm( dtype, shard_weights, fp8_recipe, - use_shardy, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) epsilon = 1e-6 ln_type = "rmsnorm" q_dtype = jnp.float8_e4m3fn diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index d214597cb3..abf579d48e 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -192,10 +192,8 @@ def _test_layernorm_mlp_grad( input_shape, dtype, quantization_recipe, - use_shardy, with_jax_gemm, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config layernorm_type = "rmsnorm" @@ -313,36 +311,6 @@ def test_layernorm_mlp_grad( dtype, quantization_recipe, with_jax_gemm, - ): - if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): - pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") - self._test_layernorm_mlp_grad( - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - quantization_recipe, - use_shardy=False, - with_jax_gemm=with_jax_gemm, - ) - - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) - @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES) - @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_layernorm_mlp_grad_shardy( - self, - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - quantization_recipe, - with_jax_gemm, ): if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") @@ -353,7 +321,6 @@ def test_layernorm_mlp_grad_shardy( input_shape, dtype, quantization_recipe=quantization_recipe, - use_shardy=True, with_jax_gemm=with_jax_gemm, ) @@ -366,10 +333,8 @@ def _test_layernorm_mlp( dtype, use_fp8, quantization_recipe, - use_shardy, with_jax_gemm, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) batch, seqlen, hidden_in = input_shape layernorm_type = "rmsnorm" @@ -481,7 +446,6 @@ def test_layernorm_mlp_layer( dtype, use_fp8=False, quantization_recipe=None, - use_shardy=False, with_jax_gemm=with_jax_gemm, ) @@ -512,58 +476,5 @@ def test_layernorm_mlp_layer_fp8( dtype, use_fp8=True, quantization_recipe=quantization_recipe, - use_shardy=False, - with_jax_gemm=with_jax_gemm, - ) - - @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) - @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) - @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_layernorm_mlp_layer_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm - ): - self._test_layernorm_mlp( - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - use_fp8=False, - quantization_recipe=None, - use_shardy=True, - with_jax_gemm=with_jax_gemm, - ) - - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) - @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES) - @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_layernorm_mlp_layer_fp8_shardy( - self, - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - quantization_recipe, - with_jax_gemm, - ): - if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): - pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") - self._test_layernorm_mlp( - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - use_fp8=True, - quantization_recipe=quantization_recipe, - use_shardy=True, with_jax_gemm=with_jax_gemm, ) diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index 0665baa4e3..ca1dcf1174 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -87,12 +87,9 @@ def impl_test_softmax( dtype, bad_sharding, broadcast_batch_mask, - use_shardy, ): if broadcast_batch_mask and softmax_fusion_type != SoftmaxFusionType.SCALED_MASKED: pytest.skip("Softmax type has no mask.") - - jax.config.update("jax_use_shardy_partitioner", use_shardy) target_func = partial( self.target_func, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type ) @@ -181,35 +178,4 @@ def test_softmax( dtype, bad_sharding, broadcast_batch_mask, - use_shardy=True, - ) - - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize( - "softmax_fusion_type", [SoftmaxFusionType.SCALED, SoftmaxFusionType.SCALED_MASKED] - ) - @pytest.mark.parametrize("bad_sharding", [False, True]) - @pytest.mark.parametrize("broadcast_batch_mask", [False, True]) - def test_softmax_gspmd( - self, - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - softmax_fusion_type, - bad_sharding, - broadcast_batch_mask, - ): - self.impl_test_softmax( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape=[32, 12, 128, 128], - softmax_fusion_type=softmax_fusion_type, - scale_factor=1.0, - dtype=DTYPES[0], - bad_sharding=bad_sharding, - broadcast_batch_mask=broadcast_batch_mask, - use_shardy=False, ) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 8c0edae97e..969e534c50 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -375,77 +375,6 @@ def batcher( out_bdims, ) - @staticmethod - def infer_sharding_from_operands( - out_dtype, - act_enum, - act_len, - scaling_mode, - quantize_layout, - scale_dtype, - act_params, - amax_scope, - transpose_batch_sequence, - output_amax_when_no_scaling, - is_outer, - mesh, - arg_infos, - result_infos, - ): - del ( - out_dtype, - result_infos, - act_enum, - scale_dtype, - act_len, - act_params, - amax_scope, - transpose_batch_sequence, - output_amax_when_no_scaling, - is_outer, - ) # Unused. - x_spec = get_padded_spec(arg_infos[0]) - scale_spec = get_padded_spec(arg_infos[1]) - - out_spec = (*x_spec[:-2], x_spec[-1]) - out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") - - if quantize_layout.is_rowwise_colwise: - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) - else: - colwise_out_spec = out_spec - else: - colwise_out_spec = (None,) - colwise_out_sharding = NamedSharding( - mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out" - ) - - scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: - scale_inv_spec = out_spec - - if quantize_layout.is_rowwise_colwise: - colwise_scale_inv_spec = scale_inv_spec - - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv" - ) - amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax") - colwise_scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv" - ) - - return ( - out_sharding, - colwise_out_sharding, - scale_inv_sharding, - colwise_scale_inv_sharding, - amax_sharding, - ) - @staticmethod def partition( out_dtype, @@ -898,86 +827,6 @@ def batcher( out_bdims, ) - @staticmethod - def infer_sharding_from_operands( - out_dtype, - scaling_mode, - quantize_layout, - scale_dtype, - is_dbias, - act_enum, - act_len, - act_params, - amax_scope, - transpose_batch_sequence, - output_amax_when_no_scaling, - is_outer, - mesh, - arg_infos, - result_infos, - ): - del out_dtype, result_infos, act_enum, act_params, output_amax_when_no_scaling - del scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence - - x_spec = get_padded_spec(arg_infos[1]) - scale_spec = get_padded_spec(arg_infos[2]) - - assert ( - scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value - ), "Partitioned current tensor scaling is not yet supported." - - out_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out" - ) - if quantize_layout.is_rowwise_colwise: - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) - else: - colwise_x_spec = x_spec - else: - colwise_x_spec = (None,) - colwise_out_sharding = NamedSharding( - mesh, - PartitionSpec(*colwise_x_spec), - desc="BaseDActLuDBiasQuantizePrimitive.colwise_out", - ) - - dbias_spec = x_spec[-2:] if is_dbias else (None,) - dbias_sharding = NamedSharding( - mesh, - PartitionSpec(*dbias_spec), - desc="BaseDActLuDBiasQuantizePrimitive.dbias", - ) - - scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: - scale_inv_spec = x_spec - - if quantize_layout.is_rowwise_colwise: - colwise_scale_inv_spec = scale_inv_spec - - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv" - ) - amax_sharding = NamedSharding( - mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax" - ) - colwise_scale_inv_sharding = NamedSharding( - mesh, - PartitionSpec(*colwise_scale_inv_spec), - desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv", - ) - return ( - out_sharding, - colwise_out_sharding, - scale_inv_sharding, - colwise_scale_inv_sharding, - amax_sharding, - dbias_sharding, - ) - @staticmethod def partition( out_dtype, diff --git a/transformer_engine/jax/cpp_extensions/amax.py b/transformer_engine/jax/cpp_extensions/amax.py index 700ba9061c..5f425a284a 100644 --- a/transformer_engine/jax/cpp_extensions/amax.py +++ b/transformer_engine/jax/cpp_extensions/amax.py @@ -96,25 +96,6 @@ def impl( amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,)) return amax - @staticmethod - def infer_sharding_from_operands( - amax_scope, - transpose_batch_sequence, - mesh, - arg_infos, - result_infos, - ): - """ - amax calcuation infer_sharding_from_operands - """ - del (amax_scope, transpose_batch_sequence, arg_infos, result_infos) # Unused. - amax_sharding = NamedSharding( - mesh, - PartitionSpec(None), - desc="AmaxCalculationPrimitive.out_sharding", - ) - return amax_sharding - @staticmethod def partition( amax_scope, @@ -267,36 +248,6 @@ def impl( ) return amax, post_rht_amax - @staticmethod - def infer_sharding_from_operands( - amax_scope, - transpose_batch_sequence, - rht_matrix_random_sign_mask_t, - produce_regular_amax, - flatten_axis, - mesh, - arg_infos, - result_infos, - ): - """ - amax calcuation infer_sharding_from_operands - """ - del ( - amax_scope, - transpose_batch_sequence, - rht_matrix_random_sign_mask_t, - produce_regular_amax, - flatten_axis, - arg_infos, - result_infos, - ) # Unused. - amax_sharding = NamedSharding( - mesh, - PartitionSpec(None), - desc="RHTAmaxCalculationPrimitive.out_sharding", - ) - return amax_sharding, amax_sharding - @staticmethod def partition( amax_scope, diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index e5d75e1501..02f42f0b7b 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -636,56 +636,6 @@ def batcher(batched_args, batch_dims, *, config): out_bdims, ) - @staticmethod - def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): - del result_infos - q_spec = get_padded_spec(arg_infos[0]) - - # when supported softmax_aux shape is (b, s, h, 1) for thd on cudnn 9.6+ - # otherwise softmax_aux shape is (b, h, s, 1) or (b, h, s, max_segments) - is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd() - - if config.qkv_layout.is_qkvpacked(): - # q_spec = (...batch, q_seqlen, 3, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) - if not is_packed_softmax: - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None) - ) - else: - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-4], q_spec[-4], q_spec[-2], None) - ) - elif config.qkv_layout.is_kvpacked(): - # q_spec = (...batch, q_seqlen, head, hidden) - # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - if not is_packed_softmax: - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) - ) - else: - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None) - ) - elif config.qkv_layout.is_separate(): - # q_spec = (...batch, q_seqlen, head, hidden) - # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - if not is_packed_softmax: - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) - ) - else: - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None) - ) - else: - raise ValueError(f"Unsupported {config.qkv_layout=}") - - rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) - return (out_sharding, softmax_aux_sharding, rng_state_sharding) - @staticmethod def partition(config, mesh, arg_infos, result_infos): out_sharding = result_infos[0].sharding @@ -706,7 +656,6 @@ def partition(config, mesh, arg_infos, result_infos): def shardy_sharding_rule(config, mesh, value_types, result_types): del mesh, result_types - # Keep in sync with `infer_sharding_from_operands`. # We only need the first input. Fill up the rest with placeholders. input_spec = [(f"…{x}",) for x in range(len(value_types))] # The RNG state sharding cannot be expressed as a Shardy rule. We use with_sharding_constraint @@ -1091,21 +1040,6 @@ def batcher(batched_args, batch_dims, *, config): out_bdims, ) - @staticmethod - def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): - del config, result_infos - q_spec = get_padded_spec(arg_infos[0]) - k_spec = get_padded_spec(arg_infos[1]) - v_spec = get_padded_spec(arg_infos[2]) - bias_spec = get_padded_spec(arg_infos[3]) - softmax_offset_spec = get_padded_spec(arg_infos[4]) - dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) - dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) - dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec)) - return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding, dsoftmax_offset_sharding) - @staticmethod def partition(config, mesh, arg_infos, result_infos): del result_infos @@ -1187,7 +1121,6 @@ def sharded_impl( @staticmethod def shardy_sharding_rule(config, mesh, value_types, result_types): del config, mesh - # Keep in sync with `infer_sharding_from_operands`. input_spec = tuple((f"…{x}",) for x in range(len(value_types))) output_spec = tuple((f"…{x}",) for x in range(len(result_types))) return SdyShardingRule(input_spec, output_spec) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index b26e01c0c7..42fabeb996 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -143,14 +143,6 @@ def batcher(): """ return NotImplemented - @staticmethod - @abstractmethod - def infer_sharding_from_operands(): - """ - to describe infer_sharding_from_operands for custom_partitioning - """ - return NotImplemented - @staticmethod @abstractmethod def partition(): @@ -209,7 +201,6 @@ def name_of_wrapper_p(): batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) outer_p_lower.def_partition( - infer_sharding_from_operands=cls.infer_sharding_from_operands, partition=cls.partition, sharding_rule=cls.shardy_sharding_rule, ) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 71f133bfc4..887f2dcb95 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -989,51 +989,6 @@ def _parse_operand_output_specs( sequence_dim, ) - @staticmethod - def infer_sharding_from_operands( - out_dtype, - contracting_dims, - scaling_mode, - fuse_bias, - fuse_gelu, - grad, - use_split_accumulator, - transpose_batch_sequence, - sequence_dim, - is_outer, - collective_op, - mesh, - arg_infos, - result_infos, - ): - del ( - out_dtype, - scaling_mode, - use_split_accumulator, - result_infos, - is_outer, - sequence_dim, - ) - - (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( - GemmPrimitive._parse_operand_output_specs( - arg_infos, contracting_dims, transpose_batch_sequence, collective_op - ) - ) - out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) - - # Discard dbias gradient spec if there is no bias and grad fusion - if not (fuse_bias and grad): - dbias_specs = (None,) - dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs)) - - # Discard pre-GeLU output spec if there is no GeLU fusion - if not fuse_gelu: - pre_gelu_specs = (None,) - pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)) - - return [out_sharding, dbias_sharding, pre_gelu_sharding] - @staticmethod def partition( out_dtype, diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 70fdf4c474..da5f34359c 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -426,68 +426,6 @@ def batcher( out_bdims, ) - @staticmethod - def infer_sharding_from_operands( - norm_type, - zero_centered_gamma, - epsilon, - out_dtype, - scaling_mode, - quantize_layout, - scale_dtype, - amax_scope, - transpose_batch_sequence, - output_amax_when_no_scaling, - is_outer, - mesh, - arg_infos, - result_infos, - ): - del zero_centered_gamma, epsilon, out_dtype, result_infos - del scale_dtype, is_outer, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling - x_spec = get_padded_spec(arg_infos[0]) - scale_spec = get_padded_spec(arg_infos[1]) - amax_spec = get_padded_spec(arg_infos[2]) - out_spec = (*x_spec[:-1], None) - if x_spec[-1] is not None: - warnings.warn( - f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " - "Force to not shard the hidden dim, which might introduce extra collective ops, " - "and hurt performance." - ) - - out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") - colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,) - colwise_out_sharding = NamedSharding( - mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" - ) - rsigma_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" - ) - mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) - mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") - - scale_inv_spec = (None,) - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = scale_spec - elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: - scale_inv_spec = out_spec - - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv" - ) - amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax") - output = ( - out_sharding, - colwise_out_sharding, - scale_inv_sharding, # rowwise - scale_inv_sharding, # colwise - amax_sharding, - mu_sharding, - rsigma_sharding, - ) - return output - @staticmethod def partition( norm_type, @@ -801,32 +739,6 @@ def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma): out_bdims, ) - @staticmethod - def infer_sharding_from_operands(norm_type, zero_centered_gamma, mesh, arg_infos, result_infos): - del norm_type, zero_centered_gamma, result_infos - x_spec = get_padded_spec(arg_infos[1]) - if x_spec[-1] is not None: - warnings.warn( - f"Does not support to shard hidden dim in {NormBwdPrimitive.name}! " - "Force to not shard the hidden dim, which might introduce extra collective ops, " - "and hurt performance." - ) - g_b_spec = get_padded_spec(arg_infos[4]) - if g_b_spec[-1] is not None: - warnings.warn( - f"{NormBwdPrimitive.name} does not support sharding of gradients " - "of gamma and beta of " - "Enforcing no sharding of parameters hidden dim! " - ) - - dx_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-1], None), desc="NormBwdPrimitive.dx" - ) - dgamma_sharding = dbeta_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="NormBwdPrimitive.dgamma" - ) - return dx_sharding, dgamma_sharding, dbeta_sharding - @staticmethod def partition(norm_type, zero_centered_gamma, mesh, arg_infos, result_infos): del result_infos diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 1fcecb0e96..e2459abd7e 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -391,93 +391,6 @@ def batcher( out_bdims, ) - @staticmethod - def infer_sharding_from_operands( - out_dtype, - scaling_mode, - q_layout, - flatten_axis, - scale_dtype, - is_dbias, - is_outer, - stochastic_rounding, - use_rht, - mesh, - arg_infos, - result_infos, - ): - del ( - out_dtype, - result_infos, - scale_dtype, - is_outer, - stochastic_rounding, - use_rht, - ) # Unused. - - x_spec = get_padded_spec(arg_infos[0]) - amax_spec = get_padded_spec(arg_infos[2]) - out_sharding = NamedSharding( - mesh, - PartitionSpec(*x_spec), - desc="BaseDBiasQuantizePrimitive.out_sharding", - ) - if q_layout.has_colwise: - if ScalingMode(scaling_mode).is_colwise_transposed: - colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) - else: - colwise_out_spec = x_spec - else: - colwise_out_spec = (None,) - colwise_out_sharding = NamedSharding( - mesh, - PartitionSpec(*colwise_out_spec), - desc="BaseDBiasQuantizePrimitive.colwise_out_sharding", - ) - - dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,) - dbias_sharding = NamedSharding( - mesh, - PartitionSpec(*dbias_spec), - desc="BaseDBiasQuantizePrimitive.dbias_sharding", - ) - - scale_inv_spec = colwise_scale_inv_spec = (None,) - if ScalingMode(scaling_mode).is_block_scaling: - scale_inv_spec = x_spec - - if q_layout.has_colwise: - if ( - ScalingMode(scaling_mode).is_block_scaling - and ScalingMode(scaling_mode).is_colwise_transposed - ): - colwise_scale_inv_spec = multidim_transpose( - scale_inv_spec, transpose_axis=flatten_axis - ) - else: - colwise_scale_inv_spec = scale_inv_spec - - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv" - ) - colwise_scale_inv_sharding = NamedSharding( - mesh, - PartitionSpec(*colwise_scale_inv_spec), - desc="BaseDBiasQuantizePrimitive.colwise_scale_inv", - ) - amax_sharding = NamedSharding( - mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax" - ) - - return ( - out_sharding, - colwise_out_sharding, - scale_inv_sharding, - colwise_scale_inv_sharding, - amax_sharding, - dbias_sharding, - ) - @staticmethod def partition( out_dtype, diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index ff30c9bba3..1aa8531a75 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -136,22 +136,6 @@ def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor): out_bdims = logits_bdim return primitive.bind(logits, scale_factor=scale_factor), out_bdims - @classmethod - def forward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos): - """ - softmax_forward infer_sharding_from_operands - """ - del scale_factor, result_infos # Unused. - logits_spec = get_padded_spec(arg_infos[0]) - if logits_spec[-1] is not None: - warnings.warn( - f"Sharding the hidden dimension is not supported in {cls.name}! " - "Forcing XLA to not shard the hidden dim, which might introduce extra " - "collective ops and hurt performance." - ) - out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None)) - return out_sharding - @classmethod def forward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos): """ @@ -216,22 +200,6 @@ def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor): out_bdims = softmax_out_bdim return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims - @classmethod - def backward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos): - """ - softmax_backward infer_sharding_from_operands - """ - del scale_factor, result_infos # Unused. - dz_spec = get_padded_spec(arg_infos[0]) - if dz_spec[-1] is not None: - warnings.warn( - f"Sharding the hidden dimension is not supported in {cls.name}! " - "Forcing XLA to not shard the hidden dim, which might introduce extra " - "collective ops and hurt performance." - ) - dx_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None)) - return dx_sharding - @classmethod def backward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos): """ @@ -320,12 +288,6 @@ def batcher(batched_args, batch_dims, *, scale_factor): scale_factor=scale_factor, ) - @staticmethod - def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): - return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) - @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledSoftmaxFwdPrimitive.forward_partition( @@ -395,12 +357,6 @@ def batcher(batched_args, batch_dims, *, scale_factor): scale_factor=scale_factor, ) - @staticmethod - def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): - return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) - @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledSoftmaxBwdPrimitive.backward_partition( @@ -525,12 +481,6 @@ def batcher(batched_args, batch_dims, *, scale_factor): out_bdims, ) - @staticmethod - def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): - return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) - @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledMaskedSoftmaxFwdPrimitive.backward_partition( @@ -601,12 +551,6 @@ def batcher(batched_args, batch_dims, *, scale_factor): scale_factor=scale_factor, ) - @staticmethod - def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): - return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) - @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledMaskedSoftmaxBwdPrimitive.backward_partition( @@ -688,12 +632,6 @@ def batcher(batched_args, batch_dims, *, scale_factor): scale_factor=scale_factor, ) - @staticmethod - def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): - return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) - @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition( @@ -772,12 +710,6 @@ def batcher(batched_args, batch_dims, *, scale_factor): scale_factor=scale_factor, ) - @staticmethod - def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): - return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) - @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition(