Skip to content

Conversation

@gulsumgudukbay
Copy link
Collaborator

Description

This PR is the second part of the decoupling support. It adds logic for decoupling support, along with some test modifications for decoupling to be enabled.

Details:

  1. Update decoupled_base_test.yml
  2. Add decoupling locig to src/MaxText/decode.py, src/MaxText/elastic_train.py, src/MaxText/experimental/rl/grpo_trainer.py, src/MaxText/gcp_workload_monitor.py, src/MaxText/max_utils.py, src/MaxText/maxengine.py, src/MaxText/maxengine_config.py, src/MaxText/maxengine_server.py, src/MaxText/metric_logger.py, src/MaxText/prefill_packing.py, src/MaxText/profiler.py, src/MaxText/sft/hooks.py, src/MaxText/sft/sft_trainer.py, src/MaxText/train.py, src/MaxText/utils/gcs_utils.py, src/MaxText/utils/goodput_utils.py, src/MaxText/vertex_tensorboard.py
  3. Update src/MaxText/gcloud_stub.py to add IS_STUB variables, and add google_cloud_mldiagnostics stub
  4. Update tests to support decoupled mode (add markers, update file paths, make them use decoupled_base_test.yml config file).

Tests

All unit tests pass in decoupled mode.
UT results:
== 306 passed, 170 skipped, 25 deselected, 6588 warnings in 975.16s (0:16:15) ==

Train test:
python -m MaxText.train MaxText/configs/base.yml run_name=test hardware=gpu steps=5 model_name=llama2-7b attention=cudnn_flash_te enable_checkpointing=False ici_expert_parallelism=1 ici_fsdp_parallelism=-1 ici_data_parallelism=1 remat_policy=minimal scan_layers=True dataset_type=synthetic logits_dot_in_fp32=False dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_target_length=2048 shardy=False

works.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

gulsumgudukbay and others added 25 commits December 21, 2025 06:16
(cherry picked from commit e8cc951)
(cherry picked from commit 0b58e96)
(cherry picked from commit 14f0508)
(cherry picked from commit e43e370)
(cherry picked from commit 1c14d6c)
…ck, todo: remove this after updating jax. Configure ICI data parallelism for decoupled mode
__init__.py Outdated
@@ -0,0 +1,14 @@
"""Top-level shim for importing test_utils
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

License header please.

__init__.py Outdated
@@ -0,0 +1,14 @@
"""Top-level shim for importing test_utils
This shim lets test modules import `maxtext.tests`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this __init__.py be moved inside /tests, since it contains the logic relevant for testing only?


def test_tiny_config(self):
test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable
dataset_path = get_test_dataset_path()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be moved to a setUp() method, and it will be called automatically before each individual test methods.

decoupled = is_decoupled()
dataset_path = get_test_dataset_path()
base_output_directory = (
os.environ.get("LOCAL_BASE_OUTPUT", get_test_base_output_directory())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we use this if logic in train_smoke_test.py?

from MaxText.globals import MAXTEXT_PKG_DIR
from maxtext.tests.test_utils import get_test_config_path

pytestmark = [pytest.mark.tpu_only]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests are suppose to run on CPUs. Why are we adding tpu_only marker?

from MaxText.input_pipeline import input_pipeline_interface
from maxtext.tests.test_utils import get_test_config_path, get_test_dataset_path, get_test_base_output_directory

MAXTEXT_ASSETS_ROOT = os.path.join("src", MAXTEXT_PKG_DIR, "assets")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_defects: list[str] = []


def _import(name: str):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is not called anywhere, I guess?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants