49 changes: 39 additions & 10 deletions README_Training.md
@@ -8,20 +8,49 @@ uv run src/build_dataset.py --output ../data/

## Train Model

Training is driven by Hydra configs under `config/` (defaults in `config/base.yaml`). Run `src.train` and select an experiment via `experiment=...`.
Most configuration options follow the SkyRL config schema; see the [SkyRL configuration docs](https://skyrl.readthedocs.io/en/latest/configuration/config.html) for the available fields and the [Hydra docs](https://hydra.cc/docs/intro/) for composition and override syntax.
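
Any value in the composed config can be overridden from the command line with Hydra's dotted-key syntax. A hypothetical example (key names taken from `config/base.yaml`, values purely illustrative):

```
uv run -m src.train model=Qwen/Qwen3-0.6B trainer.epochs=5 generator.sampling_params.temperature=0.7
```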

### Minimal for local testing

```
uv run -m src.train model=Qwen/Qwen3-0.6B
```

### Async training

```
uv run -m src.train model=Qwen/Qwen3-4B training=async
```

### Async training with LoRA experiment config

```
bash scripts/run_training.sh -m Qwen/Qwen3-0.6B -d <Absolute Path to Data>
uv run -m src.train model=Qwen/Qwen3-4B training=async experiment=lora
```

### Multirun
Note: the `none` experiment is an empty placeholder config so that a multirun can include the baseline setting alongside real experiments.

```
DATA_PATH=<Absolute Path to Data>
bash scripts/run_async_training.sh -m Qwen/Qwen3-4B -d $DATA_PATH 2>&1 | tee training.log
uv run -m src.train -m model=Qwen/Qwen3-4B training=async experiment=none,lora
```
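
Sweeps are not limited to `experiment`; any comma-separated override is expanded into separate jobs. A hypothetical sweep over models:

```
uv run -m src.train -m model=Qwen/Qwen3-0.6B,Qwen/Qwen3-4B training=async
```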

### Slurm example (Submitit launcher)

```
DATA_PATH=<Absolute Path to Data>
bash scripts/run_async_training.sh \
-m Qwen/Qwen3-4B \
-o "+generator.exp_config=configs/skyrl-experiments/read-only.yaml" \
-d $DATA_PATH \
2>&1 | tee training.log
```

Hydra launchers take effect in multirun mode, hence the extra `-m` flag:

```
uv run -m src.train -m model=Qwen/Qwen3-4B training=async platform=babel hydra/launcher=slurm
```
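
The launcher's Slurm parameters live in `config/hydra/launcher/slurm.yaml` and can be overridden per run, e.g. (illustrative value):

```
uv run -m src.train -m model=Qwen/Qwen3-4B training=async platform=babel hydra/launcher=slurm hydra.launcher.timeout_min=720
```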

### Add a new experiment

Create `config/experiment/my_exp.yaml`:

```yaml
# @package _global_
trainer:
logger: wandb
```

Then run with `experiment=my_exp`.
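
For example, combined with the minimal local setup:

```
uv run -m src.train model=Qwen/Qwen3-0.6B experiment=my_exp
```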
93 changes: 93 additions & 0 deletions config/base.yaml
@@ -0,0 +1,93 @@
defaults:
- ppo_base_config
- training: sync
- platform: local
- experiment: none
- override hydra/launcher: local
- _self_

hydra:
searchpath:
- pkg://skyrl_train.config
run:
dir: ${oc.env:PWD}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
sweep:
dir: ${oc.env:PWD}/outputs/${now:%Y-%m-%d}
subdir: ${now:%H-%M-%S}/${hydra.job.num}

model: Qwen/Qwen3-4B
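# ${replace:...} (and ${eval:...} used elsewhere) are custom OmegaConf resolvers,
# presumably registered by register_resolvers() from src/utils/config.py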
model_alias: ${replace:${model},/,-}
ckpt_path: ${oc.env:PWD}/ckpts/${model_alias}
data_path: data/swe_smith

data:
train_data:
- ${data_path}/train.parquet
val_data:
- ${data_path}/validation.parquet

trainer:
project_name: code_search
run_name: code_search_${model_alias}
logger: wandb
ckpt_path: ${ckpt_path}

algorithm:
advantage_estimator: grpo

placement:
policy_num_nodes: 1
ref_num_nodes: 1
policy_num_gpus_per_node: ${platform.gpus_per_node}
ref_num_gpus_per_node: ${platform.gpus_per_node}

policy:
model:
path: ${model}

eval_before_train: false
epochs: 20
eval_batch_size: 4
eval_interval: 100
update_epochs_per_batch: 1
dump_data_batch: true
export_path: ${ckpt_path}/exported_model/
hf_save_interval: 5
ckpt_interval: 5
max_prompt_length: 4096

generator:
backend: vllm
run_engines_locally: true

enable_http_endpoint: true
http_endpoint_host: "0.0.0.0"
http_endpoint_port: 8080
weight_sync_backend: nccl

async_engine: true
inference_engine_tensor_parallel_size: 1
num_inference_engines: 1

traj_dir: ${ckpt_path}/trajectories/

engine_init_kwargs:
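    # forwarded to the vLLM engine; hermes tool-call parsing and the qwen3 reasoning parser match Qwen3's output format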
enable_auto_tool_choice: true
tool_call_parser: hermes
reasoning_parser: qwen3

sampling_params:
temperature: 1.0

max_input_length: 24000
max_num_batched_tokens: 48000
max_turns: 20

prompts:
system_prompt: "templates/system_prompt.j2"
user_prompt: "templates/file_localization.j2"
reward:
- fn: tool_use_reward
- fn: turn_efficiency
tools:
- terminal
11 changes: 11 additions & 0 deletions config/experiment/lora.yaml
@@ -0,0 +1,11 @@
# @package _global_
trainer:
policy:
model:
lora:
rank: 32
alpha: 32
dropout: 0
lora_sync_path: "/tmp/skyrl_lora_sync"
target_modules: "all-linear"
exclude_modules: null
Empty file added config/experiment/none.yaml
Empty file.
12 changes: 12 additions & 0 deletions config/hydra/launcher/local.yaml
@@ -0,0 +1,12 @@
# @package hydra.launcher
_target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.LocalLauncher

submitit_folder: ${hydra.sweep.dir}/.submitit/%j

timeout_min: 60
cpus_per_task: 1
gpus_per_node: 1
tasks_per_node: 1
mem_gb: 4
nodes: 1
name: ${hydra.job.name}
16 changes: 16 additions & 0 deletions config/hydra/launcher/slurm.yaml
@@ -0,0 +1,16 @@
# @package hydra.launcher
_target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher

submitit_folder: ${hydra.sweep.dir}/.submitit/%j

name: cso
partition: general
nodes: 1
tasks_per_node: 1
cpus_per_task: 32
mem_gb: 512
timeout_min: 2880 # 2-00:00:00
exclude: babel-q5-28,babel-o5-20
additional_parameters:
gres: gpu:A100:2

7 changes: 7 additions & 0 deletions config/platform/babel.yaml
@@ -0,0 +1,7 @@
# @package _global_
platform:
gpus_per_node: 2

generator:
sampling_params:
max_generate_length: 8192
28 changes: 28 additions & 0 deletions config/platform/local.yaml
@@ -0,0 +1,28 @@
# @package _global_
platform:
gpus_per_node: 1

trainer:
train_batch_size: 4

generator:
sampling_params:
max_generate_length: 2048
max_input_length: 4096
max_num_batched_tokens: 16384
max_turns: 2
max_num_seqs: 16
n_samples_per_prompt: 2
gpu_memory_utilization: 0.5

hydra:
job:
env_set:
CUDA_LAUNCH_BLOCKING: "1"
VLLM_FLASH_ATTN_VERSION: "2"
TORCH_USE_CUDA_DSA: "1"
RAY_OVERRIDE_RESOURCES: '{"GPU": 1, "TPU": 0}'




41 changes: 41 additions & 0 deletions config/training/async.yaml
@@ -0,0 +1,41 @@
# @package _global_
run_async_trainer: true

trainer:
placement:
colocate_all: false
colocate_policy_ref: true
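    # split each node's GPUs evenly: half for policy/ref training, half for the
    # inference engines (num_inference_engines below uses the other half)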
policy_num_gpus_per_node: ${eval:'${platform.gpus_per_node} // 2'}
ref_num_gpus_per_node: ${eval:'${platform.gpus_per_node} // 2'}

strategy: fsdp2

policy:
fsdp_config:
cpu_offload: true
reshard_after_forward: true
fsdp_size: -1
sequence_parallel_size: 1

fully_async:
num_parallel_generation_workers: 16



train_batch_size: 8
policy_mini_batch_size: 8
micro_forward_batch_size_per_gpu: 1
micro_train_batch_size_per_gpu: 1

step_wise_training: true
resume_mode: latest
max_ckpts_to_keep: 3

generator:
batched: false
n_samples_per_prompt: 8
gpu_memory_utilization: 0.75
enforce_eager: false

num_inference_engines: ${eval:'${platform.gpus_per_node} // 2'}

24 changes: 24 additions & 0 deletions config/training/sync.yaml
@@ -0,0 +1,24 @@
# @package _global_
run_async_trainer: false

trainer:
placement:
colocate_all: true
colocate_policy_ref: true

strategy: fsdp2

policy_mini_batch_size: 4
micro_forward_batch_size_per_gpu: 2
micro_train_batch_size_per_gpu: 2
resume_mode: null

policy:
sequence_parallel_size: ${platform.gpus_per_node}
ref:
sequence_parallel_size: ${platform.gpus_per_node}

generator:
batched: true
n_samples_per_prompt: 4
gpu_memory_utilization: 0.6
1 change: 1 addition & 0 deletions pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
"seaborn>=0.13.2",
"gcsfs>=2025.3.0",
"lmcache",
"hydra-submitit-launcher>=1.2.0",
]

[build-system]
46 changes: 7 additions & 39 deletions src/train.py
@@ -1,17 +1,16 @@
import hydra
from omegaconf import DictConfig, OmegaConf, open_dict
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg
from omegaconf import DictConfig, OmegaConf
from skyrl_train.entrypoints.main_base import BasePPOExp
from skyrl_train.utils import initialize_ray
import ray

import asyncio

from src.tools import tool_exists
from src.generator.code_search_generator import CodeSearchGenerator
from src.async_trainer import CustomFullyAsyncRayPPOTrainer as FullyAsyncRayPPOTrainer
from src.utils.config import validate_cfg, register_resolvers
# from skyrl_train.fully_async_trainer import FullyAsyncRayPPOTrainer


class CodeSearchPPOExp(BasePPOExp):
def get_generator(self, cfg, tokenizer, inference_engine_client):
generator = CodeSearchGenerator(
@@ -61,47 +60,16 @@ def skyrl_entrypoint(cfg: DictConfig):
else:
print("Running sync trainer")
exp = CodeSearchPPOExp(cfg)

register_resolvers()

exp.run()


@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
@hydra.main(config_path="../config", config_name="base", version_base=None)
def main(cfg: DictConfig) -> None:
# validate the arguments
validate_cfg(cfg)

# check cfg.generator.exp_config if it exists or not
if hasattr(cfg.generator, "exp_config"):
# Open yaml file and print its contents
with open(cfg.generator.exp_config, "r") as f:
exp_cfg = OmegaConf.load(f)

with open_dict(cfg):
cfg.generator.reward = exp_cfg.reward
cfg.generator.tools = exp_cfg.tools
# Parse prompts if they exist in the exp config
if hasattr(exp_cfg, "prompts"):
cfg.generator.prompts = exp_cfg.prompts
else:
with open_dict(cfg):
cfg.generator.reward = [
{"fn": "multilevel_localization_f1_reward"},
]
cfg.generator.tools = [
"terminal",
]

# Check if the tool exists in the registry
for tool in cfg.generator.tools:
if not tool_exists(tool):
raise ValueError(f"Tool {tool} does not exist in the registry")

# Set default prompts if not specified
if not hasattr(cfg.generator, "prompts"):
with open_dict(cfg):
cfg.generator.prompts = {
"system_prompt": "templates/system_prompt.j2",
"user_prompt": "templates/file_module_parallel_tools.j2"
}

initialize_ray(cfg)
ray.get(skyrl_entrypoint.remote(cfg))