Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions examples/jax/collective_gemm/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,6 @@ def _initialize_distributed(args):
)

_distributed_initialized = True
jax.clear_caches()
jax.config.update(
"jax_use_shardy_partitioner", False
) # CollectiveGEMM does not work with Shardy yet

assert jax.local_device_count() == 1, (
f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found"
Expand Down
2 changes: 0 additions & 2 deletions examples/jax/collective_gemm/test_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_shard
def run_gemm_tests(args, mesh=None):
"""Execute GEMM tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)

# Initialize distributed with provided arguments
_initialize_distributed(args)
Expand Down
2 changes: 0 additions & 2 deletions examples/jax/collective_gemm/test_layernorm_mlp_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ def _value_and_grad_layernorm_mlp(
def run_layernorm_mlp_grad_tests(args, mesh=None):
"""Execute Dense Gradient tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)

# Initialize distributed with provided arguments
_initialize_distributed(args)
Expand Down
4 changes: 0 additions & 4 deletions examples/jax/encoder/run_test_multiprocessing_encoder.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ TEST_CASES=(
"test_te_current_scaling_fp8"
"test_te_mxfp8"
"test_te_nvfp4"
"test_te_bf16_shardy"
"test_te_delayed_scaling_fp8_shardy"
"test_te_current_scaling_fp8_shardy"
"test_te_nvfp4_shardy"
)

: ${TE_PATH:=/opt/transformerengine}
Expand Down
68 changes: 0 additions & 68 deletions examples/jax/encoder/test_model_parallel_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,6 @@ def check_fp8(state, var_collect, inputs, masks, labels):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)

train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

Expand Down Expand Up @@ -474,9 +473,6 @@ def encoder_parser(args):
parser.add_argument(
"--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
)
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)

return parser.parse_args(args)

Expand Down Expand Up @@ -559,70 +555,6 @@ def test_te_nvfp4_with_sp(self):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.40 and actual[1] > 0.82

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
self.args.enable_shardy = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.36 and actual[1] > 0.84

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.362 and actual[1] > 0.84

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8 + SP"""
self.args.enable_shardy = True
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.362 and actual[1] > 0.84

@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_shardy(self):
"""Test Transformer Engine with MXFP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.36 and actual[1] > 0.84

@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4_shardy(self):
"""Test Transformer Engine with NVFP4"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.40 and actual[1] > 0.82

@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp_shardy(self):
"""Test Transformer Engine with MXFP8 + SP"""
self.args.enable_shardy = True
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.36 and actual[1] > 0.84

@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4_with_sp_shardy(self):
"""Test Transformer Engine with NVFP4"""
self.args.enable_shardy = True
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.40 and actual[1] > 0.82


if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
47 changes: 0 additions & 47 deletions examples/jax/encoder/test_multigpu_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,6 @@ def replace_params(x):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

num_gpu = jax.local_device_count()
Expand Down Expand Up @@ -438,9 +437,6 @@ def encoder_parser(args):
default="DelayedScaling",
help="Use FP8 recipe (default: DelayedScaling)",
)
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)

return parser.parse_args(args)

Expand Down Expand Up @@ -494,49 +490,6 @@ def test_te_nvfp4(self):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
self.args.enable_shardy = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.51 and actual[1] > 0.75

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.51 and actual[1] > 0.75

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8_shardy(self):
"""Test Transformer Engine with CurrentScaling FP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.51 and actual[1] > 0.749

@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_shardy(self):
"""Test Transformer Engine with MXFP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.51 and actual[1] > 0.75

@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4_shardy(self):
"""Test Transformer Engine with NVFP4"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74


if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
45 changes: 1 addition & 44 deletions examples/jax/encoder/test_multiprocessing_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ def replace_params(x):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
if args.process_id == 0:
nltk.download("punkt_tab")

Expand Down Expand Up @@ -605,9 +604,6 @@ def encoder_parser(args):
default=0,
help="the ID number of the current process (default: 0)",
)
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)

return parser.parse_args(args)

Expand All @@ -616,7 +612,7 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
def exec(self, use_fp8, fp8_recipe):
"""Run 5 epochs for testing"""
args = encoder_parser(["--epochs", "5"])

Expand All @@ -632,7 +628,6 @@ def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
args.num_process = num_gpu
args.process_id = self.process_id
args.fp8_recipe = fp8_recipe
args.enable_shardy = enable_shardy

return train_and_evaluate(args)

Expand Down Expand Up @@ -674,44 +669,6 @@ def test_te_nvfp4(self):
result = self.exec(True, "NVFP4BlockScaling")
assert result[0] < 0.451 and result[1] > 0.787

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
result = self.exec(False, None, enable_shardy=True)
assert result[0] < 0.43 and result[1] > 0.80

@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
)
def test_te_delayed_scaling_fp8_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8"""
result = self.exec(True, "DelayedScaling", enable_shardy=True)
assert result[0] < 0.43 and result[1] > 0.80

@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
)
def test_te_current_scaling_fp8_shardy(self):
"""Test Transformer Engine with CurrentScaling FP8"""
result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
assert result[0] < 0.432 and result[1] > 0.80

@unittest.skipIf(
not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
)
def test_te_mxfp8_shardy(self):
"""Test Transformer Engine with MXFP8"""
result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True)
assert result[0] < 0.43 and result[1] > 0.80

@unittest.skipIf(
not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4"
)
def test_te_nvfp4_shardy(self):
"""Test Transformer Engine with NVFP4"""
result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True)
assert result[0] < 0.451 and result[1] > 0.787


if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
Loading
Loading