Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions src/art/megatron/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,13 @@ def _create_identity_lora(self, lora_path: str) -> None:
from transformers import AutoModelForCausalLM

lora_config = self._default_lora_adapter_config()
# Load on CPU to avoid claiming GPU memory — this runs before vLLM
# starts, and we only need the model to generate adapter config/weights
# on disk, not for any GPU computation.
model = AutoModelForCausalLM.from_pretrained(
self.base_model,
torch_dtype=torch.bfloat16,
device_map="auto",
device_map="cpu",
trust_remote_code=True,
)
peft_model = get_peft_model(model, lora_config)
Expand All @@ -111,9 +114,8 @@ def _create_identity_lora(self, lora_path: str) -> None:
os.makedirs(lora_path, exist_ok=True)
peft_model.save_pretrained(lora_path)
del peft_model, model
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
import gc
gc.collect()

def _ensure_identity_lora(self, lora_path: str) -> None:
if self._adapter_has_weights(lora_path):
Expand Down Expand Up @@ -186,8 +188,9 @@ async def _ensure_megatron_running(self) -> None:
num_gpus = torch.cuda.device_count()
os.environ["MODEL_IDENTIFIER"] = self.base_model

runner = "torchrun" if os.environ.get("RL_DOCKER") == "1" else "uv run torchrun"
command = (
f"{setup_cmd}uv run torchrun --nproc_per_node {num_gpus} {train_script}"
f"{setup_cmd}{runner} --nproc_per_node {num_gpus} {train_script}"
)
self._megatron_process = await asyncio.create_subprocess_shell(command)

Expand Down Expand Up @@ -236,7 +239,7 @@ async def train(
self._is_sleeping = True
gc_and_empty_cuda_cache()

# Start Megatron after vLLM has freed GPU memory.
# Start Megatron on non-vLLM GPUs (vLLM keeps GPU 0 reserved).
await self._ensure_megatron_running()

lora_path = get_last_checkpoint_dir(self.output_dir)
Expand Down Expand Up @@ -292,15 +295,27 @@ async def train(
)
self._ensure_lora_adapter_config(new_checkpoint_dir, source_path=lora_path)

wake_lock_path = "/tmp/megatron_vllm_waking"
try:
with open(wake_lock_path, "w") as lock_file:
lock_file.write("waking vllm\n")
if os.environ.get("MEGATRON_EXIT_AFTER_JOB") == "1":
# Wait for the Megatron process to fully exit so all GPU memory
# (including CUDA context / NCCL buffers) is released before vLLM wakes.
if self._megatron_process is not None:
await self._megatron_process.wait()
self._megatron_process = None
await run_on_workers(llm, do_wake_up)
self._is_sleeping = False
finally:
if os.path.exists(wake_lock_path):
os.remove(wake_lock_path)
else:
# Megatron stays alive (offloaded to CPU). Use a wake lock so it
# waits for vLLM to finish reclaiming GPU memory before starting
# the next job.
wake_lock_path = "/tmp/megatron_vllm_waking"
try:
with open(wake_lock_path, "w") as lock_file:
lock_file.write("waking vllm\n")
await run_on_workers(llm, do_wake_up)
self._is_sleeping = False
finally:
if os.path.exists(wake_lock_path):
os.remove(wake_lock_path)

await self._add_lora_aliases(llm, next_step, new_checkpoint_dir)
await llm.resume_generation()
Expand Down
7 changes: 7 additions & 0 deletions src/art/megatron/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,10 @@ def calculate_mask(
with open("/tmp/megatron_training_log.jsonl", "a+") as log_file:
log_file.write("all done\n")
shutil.rmtree(job.disk_packed_tensors["dir"])
if os.environ.get("MEGATRON_EXIT_AFTER_JOB") == "1":
# Exit after each job so that all GPU memory (including CUDA context)
# is fully freed before vLLM wakes up. The service restarts us for
# the next step.
torch.distributed.barrier()
torch.distributed.destroy_process_group()
break