diff --git a/README_Training.md b/README_Training.md
index 86df63f..192ad45 100644
--- a/README_Training.md
+++ b/README_Training.md
@@ -8,20 +8,49 @@ uv run src/build_dataset.py --output ../data/
 
 ## Train Model
 
+Training is driven by Hydra configs under `config/` (see `config/base.yaml` defaults). Run `src.train` and select an experiment via `experiment=...`.
+Most configuration options follow the SkyRL config schema; see [SkyRL configuration docs](https://skyrl.readthedocs.io/en/latest/configuration/config.html) for details.
+See [Hydra docs](https://hydra.cc/docs/intro/) for more details.
+
+### Minimal for local testing
+
+```
+uv run -m src.train model=Qwen/Qwen3-0.6B
+```
+
+### Async training
+
+```
+uv run -m src.train model=Qwen/Qwen3-4B training=async
+```
+
+### Async training with LoRA experiment config
+
 ```
-bash scripts run_training.sh -m Qwen/Qwen3-0.6B -d
+uv run -m src.train model=Qwen/Qwen3-4B training=async experiment=lora
 ```
 
+### Multirun
+Note: the `none` experiment is a placeholder to allow multiruns including the baseline setting.
+
 ```
-DATA_PATH=
-bash scripts/run_async_training.sh -m Qwen/Qwen3-4B -d $DATA_PATH 2>&1 | tee training.log
+uv run -m src.train -m model=Qwen/Qwen3-4B training=async experiment=none,lora
 ```
 
+### Slurm example (Submitit launcher)
+
 ```
-DATA_PATH=
-bash scripts/run_async_training.sh \
-    -m Qwen/Qwen3-4B \
-    -o "+generator.exp_config=configs/skyrl-experiments/read-only.yaml" \
-    -d $DATA_PATH \
-    2>&1 | tee training.log
-```
\ No newline at end of file
+uv run -m src.train model=Qwen/Qwen3-4B training=async platform=babel hydra/launcher=slurm
+```
+
+### Add a new experiment
+
+Create `config/experiment/my_exp.yaml`:
+
+```yaml
+# @package _global_
+trainer:
+  logger: wandb
+```
+
+Then run with `experiment=my_exp`.
\ No newline at end of file
diff --git a/config/base.yaml b/config/base.yaml
new file mode 100644
index 0000000..d5d0fe6
--- /dev/null
+++ b/config/base.yaml
@@ -0,0 +1,93 @@
+defaults:
+  - ppo_base_config
+  - training: sync
+  - platform: local
+  - experiment: none
+  - override hydra/launcher: local
+  - _self_
+
+hydra:
+  searchpath:
+    - pkg://skyrl_train.config
+  run:
+    dir: ${oc.env:PWD}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+  sweep:
+    dir: ${oc.env:PWD}/outputs/${now:%Y-%m-%d}
+    subdir: ${now:%H-%M-%S}/${hydra.job.num}
+
+model: Qwen/Qwen3-4B
+model_alias: ${replace:${model},/,-}
+ckpt_path: ${oc.env:PWD}/ckpts/${model_alias}
+data_path: data/swe_smith
+
+data:
+  train_data:
+    - ${data_path}/train.parquet
+  val_data:
+    - ${data_path}/validation.parquet
+
+trainer:
+  project_name: code_search
+  run_name: code_search_${model_alias}
+  logger: wandb
+  ckpt_path: ${ckpt_path}
+
+  algorithm:
+    advantage_estimator: grpo
+
+  placement:
+    policy_num_nodes: 1
+    ref_num_nodes: 1
+    policy_num_gpus_per_node: ${platform.gpus_per_node}
+    ref_num_gpus_per_node: ${platform.gpus_per_node}
+
+  policy:
+    model:
+      path: ${model}
+
+  eval_before_train: false
+  epochs: 20
+  eval_batch_size: 4
+  eval_interval: 100
+  update_epochs_per_batch: 1
+  dump_data_batch: true
+  export_path: ${ckpt_path}/exported_model/
+  hf_save_interval: 5
+  ckpt_interval: 5
+  max_prompt_length: 4096
+
+generator:
+  backend: vllm
+  run_engines_locally: true
+
+  enable_http_endpoint: true
+  http_endpoint_host: "0.0.0.0"
+  http_endpoint_port: 8080
+  weight_sync_backend: nccl
+
+  async_engine: true
+  inference_engine_tensor_parallel_size: 1
+  num_inference_engines: 1
+
+  traj_dir: ${ckpt_path}/trajectories/
+
+  engine_init_kwargs:
+    enable_auto_tool_choice: true
+    tool_call_parser: hermes
+    reasoning_parser: qwen3
+
+  sampling_params:
+    temperature: 1.0
+
+  max_input_length: 24000
+  max_num_batched_tokens: 48000
+  max_turns: 20
+
+  prompts:
+    system_prompt: "templates/system_prompt.j2"
+    user_prompt: "templates/file_localization.j2"
+  reward:
+    - fn: tool_use_reward
+    - fn: turn_efficiency
+  tools:
+    - terminal
diff --git a/config/experiment/lora.yaml b/config/experiment/lora.yaml
new file mode 100644
index 0000000..c0d4430
--- /dev/null
+++ b/config/experiment/lora.yaml
@@ -0,0 +1,11 @@
+# @package _global_
+trainer:
+  policy:
+    model:
+      lora:
+        rank: 32
+        alpha: 32
+        dropout: 0
+        lora_sync_path: "/tmp/skyrl_lora_sync"
+        target_modules: "all-linear"
+        exclude_modules: null
\ No newline at end of file
diff --git a/config/experiment/none.yaml b/config/experiment/none.yaml
new file mode 100644
index 0000000..e69de29
diff --git a/config/hydra/launcher/local.yaml b/config/hydra/launcher/local.yaml
new file mode 100644
index 0000000..b320f64
--- /dev/null
+++ b/config/hydra/launcher/local.yaml
@@ -0,0 +1,12 @@
+# @package hydra.launcher
+_target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.LocalLauncher
+
+submitit_folder: ${hydra.sweep.dir}/.submitit/%j
+
+timeout_min: 60
+cpus_per_task: 1
+gpus_per_node: 1
+tasks_per_node: 1
+mem_gb: 4
+nodes: 1
+name: ${hydra.job.name}
diff --git a/config/hydra/launcher/slurm.yaml b/config/hydra/launcher/slurm.yaml
new file mode 100644
index 0000000..bb9ca04
--- /dev/null
+++ b/config/hydra/launcher/slurm.yaml
@@ -0,0 +1,16 @@
+# @package hydra.launcher
+_target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher
+
+submitit_folder: ${hydra.sweep.dir}/.submitit/%j
+
+name: cso
+partition: general
+nodes: 1
+tasks_per_node: 1
+cpus_per_task: 32
+mem_gb: 512
+timeout_min: 2880 # 2-00:00:00
+exclude: babel-q5-28,babel-o5-20
+additional_parameters:
+  gres: gpu:A100:2
+
diff --git a/config/platform/babel.yaml b/config/platform/babel.yaml
new file mode 100644
index 0000000..314c6ad
--- /dev/null
+++ b/config/platform/babel.yaml
@@ -0,0 +1,7 @@
+# @package _global_
+platform:
+  gpus_per_node: 2
+
+generator:
+  sampling_params:
+    max_generate_length: 8192
\ No newline at end of file
diff --git a/config/platform/local.yaml b/config/platform/local.yaml
new file mode 100644
index 0000000..50682f0
--- /dev/null
+++ b/config/platform/local.yaml
@@ -0,0 +1,28 @@
+# @package _global_
+platform:
+  gpus_per_node: 1
+
+trainer:
+  train_batch_size: 4
+
+generator:
+  sampling_params:
+    max_generate_length: 2048
+  max_input_length: 4096
+  max_num_batched_tokens: 16384
+  max_turns: 2
+  max_num_seqs: 16
+  n_samples_per_prompt: 2
+  gpu_memory_utilization: 0.5
+
+hydra:
+  job:
+    env_set:
+      CUDA_LAUNCH_BLOCKING: "1"
+      VLLM_FLASH_ATTN_VERSION: "2"
+      TORCH_USE_CUDA_DSA: "1"
+      RAY_OVERRIDE_RESOURCES: '{"GPU": 1, "TPU": 0}'
+
+
+
+
diff --git a/config/training/async.yaml b/config/training/async.yaml
new file mode 100644
index 0000000..1fe5a1d
--- /dev/null
+++ b/config/training/async.yaml
@@ -0,0 +1,41 @@
+# @package _global_
+run_async_trainer: true
+
+trainer:
+  placement:
+    colocate_all: false
+    colocate_policy_ref: true
+    policy_num_gpus_per_node: ${eval:'${platform.gpus_per_node} // 2'}
+    ref_num_gpus_per_node: ${eval:'${platform.gpus_per_node} // 2'}
+
+  strategy: fsdp2
+
+  policy:
+    fsdp_config:
+      cpu_offload: true
+      reshard_after_forward: true
+      fsdp_size: -1
+    sequence_parallel_size: 1
+
+  fully_async:
+    num_parallel_generation_workers: 16
+
+
+
+  train_batch_size: 8
+  policy_mini_batch_size: 8
+  micro_forward_batch_size_per_gpu: 1
+  micro_train_batch_size_per_gpu: 1
+
+  step_wise_training: true
+  resume_mode: latest
+  max_ckpts_to_keep: 3
+
+generator:
+  batched: false
+  n_samples_per_prompt: 8
+  gpu_memory_utilization: 0.75
+  enforce_eager: false
+
+  num_inference_engines: ${eval:'${platform.gpus_per_node} // 2'}
+
diff --git a/config/training/sync.yaml b/config/training/sync.yaml
new file mode 100644
index 0000000..c3d6a63
--- /dev/null
+++ b/config/training/sync.yaml
@@ -0,0 +1,24 @@
+# @package _global_
+run_async_trainer: false
+
+trainer:
+  placement:
+    colocate_all: true
+    colocate_policy_ref: true
+
+  strategy: fsdp2
+
+  policy_mini_batch_size: 4
+  micro_forward_batch_size_per_gpu: 2
+  micro_train_batch_size_per_gpu: 2
+  resume_mode: null
+
+  policy:
+    sequence_parallel_size: ${platform.gpus_per_node}
+  ref:
+    sequence_parallel_size: ${platform.gpus_per_node}
+
+generator:
+  batched: true
+  n_samples_per_prompt: 4
+  gpu_memory_utilization: 0.6
diff --git a/pyproject.toml b/pyproject.toml
index d2dbc08..92ef839 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
     "seaborn>=0.13.2",
     "gcsfs>=2025.3.0",
     "lmcache",
+    "hydra-submitit-launcher>=1.2.0",
 ]
 
 [build-system]
diff --git a/src/train.py b/src/train.py
index 9c405e6..21ea67a 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,17 +1,16 @@
 import hydra
-from omegaconf import DictConfig, OmegaConf, open_dict
-from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg
+from omegaconf import DictConfig, OmegaConf
+from skyrl_train.entrypoints.main_base import BasePPOExp
 from skyrl_train.utils import initialize_ray
 import ray
 import asyncio
-from src.tools import tool_exists
 from src.generator.code_search_generator import CodeSearchGenerator
 from src.async_trainer import CustomFullyAsyncRayPPOTrainer as FullyAsyncRayPPOTrainer
+from src.utils.config import validate_cfg, register_resolvers
 
 # from skyrl_train.fully_async_trainer import FullyAsyncRayPPOTrainer
 
-
 class CodeSearchPPOExp(BasePPOExp):
     def get_generator(self, cfg, tokenizer, inference_engine_client):
         generator = CodeSearchGenerator(
@@ -61,47 +60,16 @@ def skyrl_entrypoint(cfg: DictConfig):
     else:
         print("Running sync trainer")
         exp = CodeSearchPPOExp(cfg)
+
+    register_resolvers()
+
     exp.run()
 
 
-@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
+@hydra.main(config_path="../config", config_name="base", version_base=None)
 def main(cfg: DictConfig) -> None:
     # validate the arguments
     validate_cfg(cfg)
-
-    # check cfg.generator.exp_config if it exists or not
-    if hasattr(cfg.generator, "exp_config"):
-        # Open yaml file and print its contents
-        with open(cfg.generator.exp_config, "r") as f:
-            exp_cfg = OmegaConf.load(f)
-
-        with open_dict(cfg):
-            cfg.generator.reward = exp_cfg.reward
-            cfg.generator.tools = exp_cfg.tools
-            # Parse prompts if they exist in the exp config
-            if hasattr(exp_cfg, "prompts"):
-                cfg.generator.prompts = exp_cfg.prompts
-    else:
-        with open_dict(cfg):
-            cfg.generator.reward = [
-                {"fn": "multilevel_localization_f1_reward"},
-            ]
-            cfg.generator.tools = [
-                "terminal",
-            ]
-
-    # Check if the tool exists in the registry
-    for tool in cfg.generator.tools:
-        if not tool_exists(tool):
-            raise ValueError(f"Tool {tool} does not exist in the registry")
-
-    # Set default prompts if not specified
-    if not hasattr(cfg.generator, "prompts"):
-        with open_dict(cfg):
-            cfg.generator.prompts = {
-                "system_prompt": "templates/system_prompt.j2",
-                "user_prompt": "templates/file_module_parallel_tools.j2"
-            }
 
     initialize_ray(cfg)
     ray.get(skyrl_entrypoint.remote(cfg))
diff --git a/src/utils/config.py b/src/utils/config.py
new file mode 100644
index 0000000..9c3951f
--- /dev/null
+++ b/src/utils/config.py
@@ -0,0 +1,39 @@
+from omegaconf import DictConfig, OmegaConf
+from skyrl_train.utils import validate_cfg as validate_cfg_skyrl
+
+from src.tools import DEFAULT_OPENHANDS_TOOLS, TOOL_REGISTRY
+
+def register_resolvers():
+    OmegaConf.register_new_resolver("replace", lambda s,old,new: s.replace(old, new))
+    OmegaConf.register_new_resolver("eval", lambda x: eval(x))
+
+def validate_cfg(config: DictConfig):
+    validate_cfg_skyrl(config)
+
+    assert hasattr(
+        config, "generator"
+    ), "Missing `generator` config block"
+
+    assert hasattr(
+        config.generator, "tools"
+    ), "Missing `generator.tools` (pick via Hydra `agent=...`)"
+
+    assert (
+        config.generator.tools is not None and len(config.generator.tools) > 0
+    ), "`generator.tools` must be a non-empty list"
+
+    for tool in config.generator.tools:
+        assert tool in DEFAULT_OPENHANDS_TOOLS or tool in TOOL_REGISTRY, f"Tool {tool} does not exist in the registry"
+
+    assert hasattr(
+        config.generator, "prompts"
+    ), "Missing `generator.prompts` (pick via Hydra `prompt=...`)"
+
+    assert hasattr(
+        config.generator.prompts, "system_prompt"
+    ), "Missing `generator.prompts.system_prompt`"
+
+    assert hasattr(
+        config.generator.prompts, "user_prompt"
+    ), "Missing `generator.prompts.user_prompt`"
+