[DECOUPLED-MODE] Adding Decoupling Logic #2865
base: main
Conversation
(cherry picked from commit e8cc951)
(cherry picked from commit 0b58e96)
(cherry picked from commit 14f0508)
…ts library (cherry picked from commit 6f0b361)
(cherry picked from commit e43e370)
(cherry picked from commit 1c14d6c)
…ck, todo: remove this after updating jax. Configure ICI data parallelism for decoupled mode
train_main(
    [
        None,
        os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
from maxtext.tests.test_utils import get_test_config_path is missing from several test files.
Hi @SurbhiJainUSC, thanks for the comment. I am working on fixing that issue and the linting issues. Once I have all of those set up, I will let you know so you can look at a more developed version of the PR.
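For illustration only, the fix in an affected test file might look like the sketch below; the exact signature of get_test_config_path is not shown in this PR, so the argument and return value here are assumptions:

from maxtext.tests.test_utils import get_test_config_path  # the missing import


def test_tiny_config_loads():
  # Hypothetical usage: resolve the packaged base.yml regardless of install layout.
  config_path = get_test_config_path("base.yml")
  assert config_path.endswith("base.yml")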
def test_tiny_config(self):
  test_tmpdir = os.environ.get("TEST_TMPDIR")  # pylint: disable=unused-variable
  decoupled = is_decoupled()
  dataset_path = (
Can we move this logic to test_utils.py and reuse it in all the other tests?
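One possible shape for that shared helper, sketched under the assumptions that is_decoupled() comes from MaxText.gcloud_stub (as imported elsewhere in this PR) and that the helper name and GCS path are only illustrative:

import os

from MaxText.gcloud_stub import is_decoupled


def get_test_dataset_path(gcs_path="gs://maxtext-dataset"):
  """Return a local tmpdir-backed path in decoupled mode, else the GCS dataset path."""
  if is_decoupled():
    # Decoupled runs have no GCS access; fall back to the test tmpdir.
    return os.environ.get("TEST_TMPDIR", "/tmp")
  return gcs_path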
tests/test_env_smoke.py
@@ -0,0 +1,69 @@
"""Pytest-based environment smoke test for MaxText (used esp for decoupling testing).
Please add a license header.
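For reference, the header would follow the repository's existing Apache 2.0 boilerplate, roughly as below (year and exact wording should match the other test files):

"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""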
return name, None, time.time() - t0, e


def test_environment_core_imports():
These tests can be simplified using parameterization. For example:
@pytest.mark.parametrize("name", CORE_IMPORTS)
def test_environment_core_imports(name):
  importlib.import_module(name)
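A self-contained version of that suggestion could look like this (the module list is illustrative; the real CORE_IMPORTS is defined in the smoke test itself):

import importlib

import pytest

CORE_IMPORTS = ["jax", "flax", "orbax.checkpoint"]  # illustrative subset


@pytest.mark.parametrize("name", CORE_IMPORTS)
def test_environment_core_imports(name):
  importlib.import_module(name)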
skip_jax_distributed_system=True,
)
self.mesh = Mesh(create_device_mesh(self.config), self.config.mesh_axes)
# Use a synthetic dataset for unit tests only when running in decoupled mode so
SFT tests can be marked as external_training.
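Assuming external_training is (or will be) a registered pytest marker in this repository, marking an SFT test module could be as simple as the sketch below:

import pytest

# Applies the marker to every test in the module so decoupled CI can deselect them
# with -m "not external_training". Marker registration in the pytest config is assumed.
pytestmark = pytest.mark.external_training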
from orbax import checkpoint as ocp

from tunix.sft import metrics_logger, peft_trainer, profiler
metrics_logger is used at line 83
if not self.config.using_pipeline_parallelism:
  sharding.assert_params_sufficiently_sharded(params, self.mesh, self.config.sharding_tolerance)
  maxtext_utils.assert_params_sufficiently_sharded(params, self.mesh, self.config.sharding_tolerance)
This function is deprecated: https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/maxtext_utils.py#L56
from cloud_tpu_diagnostics.configuration import debug_configuration
from cloud_tpu_diagnostics.configuration import diagnostic_configuration
from cloud_tpu_diagnostics.configuration import stack_trace_configuration
from MaxText.gcloud_stub import cloud_diagnostics as _cloud_diag, vertex_tensorboard_components, is_decoupled
This is the legacy RL trainer. Is there a need to modify this to work with decoupled mode?
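For context, the decoupling pattern the quoted import suggests is to gate GCP-only dependencies behind MaxText.gcloud_stub. A minimal sketch of that idea follows; the deferred-import structure is an assumption, and only the names are taken from the diff above:

from MaxText.gcloud_stub import is_decoupled

# Assumption: import the GCP-only diagnostics package only when not running decoupled,
# so off-cloud environments do not need cloud_tpu_diagnostics installed.
if not is_decoupled():
  from cloud_tpu_diagnostics.configuration import debug_configuration
  from cloud_tpu_diagnostics.configuration import diagnostic_configuration
  from cloud_tpu_diagnostics.configuration import stack_trace_configuration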
Description
This PR is the second part of the decoupling support. It adds the decoupling logic itself, along with the test modifications needed to enable decoupling.
Details:
Tests
All unit tests pass in decoupled mode.
UT results:
== 306 passed, 170 skipped, 25 deselected, 6588 warnings in 975.16s (0:16:15) ==
Train test:
python -m MaxText.train MaxText/configs/base.yml run_name=test hardware=gpu steps=5 model_name=llama2-7b attention=cudnn_flash_te enable_checkpointing=False ici_expert_parallelism=1 ici_fsdp_parallelism=-1 ici_data_parallelism=1 remat_policy=minimal scan_layers=True dataset_type=synthetic logits_dot_in_fp32=False dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_target_length=2048 shardy=False
works.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.