diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index c335d4e2..d6763cae 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -97,10 +97,13 @@ def _create_identity_lora(self, lora_path: str) -> None: from transformers import AutoModelForCausalLM lora_config = self._default_lora_adapter_config() + # Load on CPU to avoid claiming GPU memory — this runs before vLLM + # starts, and we only need the model to generate adapter config/weights + # on disk, not for any GPU computation. model = AutoModelForCausalLM.from_pretrained( self.base_model, torch_dtype=torch.bfloat16, - device_map="auto", + device_map="cpu", trust_remote_code=True, ) peft_model = get_peft_model(model, lora_config) @@ -111,9 +114,8 @@ def _create_identity_lora(self, lora_path: str) -> None: os.makedirs(lora_path, exist_ok=True) peft_model.save_pretrained(lora_path) del peft_model, model - if torch.cuda.is_available(): - torch.cuda.synchronize() - torch.cuda.empty_cache() + import gc + gc.collect() def _ensure_identity_lora(self, lora_path: str) -> None: if self._adapter_has_weights(lora_path): @@ -186,8 +188,9 @@ async def _ensure_megatron_running(self) -> None: num_gpus = torch.cuda.device_count() os.environ["MODEL_IDENTIFIER"] = self.base_model + runner = "torchrun" if os.environ.get("RL_DOCKER") == "1" else "uv run torchrun" command = ( - f"{setup_cmd}uv run torchrun --nproc_per_node {num_gpus} {train_script}" + f"{setup_cmd}{runner} --nproc_per_node {num_gpus} {train_script}" ) self._megatron_process = await asyncio.create_subprocess_shell(command) @@ -236,7 +239,7 @@ async def train( self._is_sleeping = True gc_and_empty_cuda_cache() - # Start Megatron after vLLM has freed GPU memory. + # Start Megatron on non-vLLM GPUs (vLLM keeps GPU 0 reserved). await self._ensure_megatron_running() lora_path = get_last_checkpoint_dir(self.output_dir) @@ -292,15 +295,27 @@ async def train( ) self._ensure_lora_adapter_config(new_checkpoint_dir, source_path=lora_path) - wake_lock_path = "/tmp/megatron_vllm_waking" - try: - with open(wake_lock_path, "w") as lock_file: - lock_file.write("waking vllm\n") + if os.environ.get("MEGATRON_EXIT_AFTER_JOB") == "1": + # Wait for the Megatron process to fully exit so all GPU memory + # (including CUDA context / NCCL buffers) is released before vLLM wakes. + if self._megatron_process is not None: + await self._megatron_process.wait() + self._megatron_process = None await run_on_workers(llm, do_wake_up) self._is_sleeping = False - finally: - if os.path.exists(wake_lock_path): - os.remove(wake_lock_path) + else: + # Megatron stays alive (offloaded to CPU). Use a wake lock so it + # waits for vLLM to finish reclaiming GPU memory before starting + # the next job. + wake_lock_path = "/tmp/megatron_vllm_waking" + try: + with open(wake_lock_path, "w") as lock_file: + lock_file.write("waking vllm\n") + await run_on_workers(llm, do_wake_up) + self._is_sleeping = False + finally: + if os.path.exists(wake_lock_path): + os.remove(wake_lock_path) await self._add_lora_aliases(llm, next_step, new_checkpoint_dir) await llm.resume_generation() diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index f1083f37..f567f7ce 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -340,3 +340,10 @@ def calculate_mask( with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: log_file.write("all done\n") shutil.rmtree(job.disk_packed_tensors["dir"]) + if os.environ.get("MEGATRON_EXIT_AFTER_JOB") == "1": + # Exit after each job so that all GPU memory (including CUDA context) + # is fully freed before vLLM wakes up. The service restarts us for + # the next step. + torch.distributed.barrier() + torch.distributed.destroy_process_group() + break