26 changes: 24 additions & 2 deletions README_Training.md
@@ -6,7 +6,7 @@
uv run src/build_dataset.py --output ../data/
```

## Train Model
## Train Model with RLVR

```
bash scripts/run_training.sh -m Qwen/Qwen3-0.6B -d <Absolute Path to Data>
@@ -24,4 +24,26 @@ bash scripts/run_async_training.sh \
-o "+generator.exp_config=configs/skyrl-experiments/read-only.yaml" \
-d $DATA_PATH \
2>&1 | tee training.log
```
```

## Train Model with On-Policy Distillation

```
DATA_PATH=<Absolute Path to Data>
# -m: student model (model to be trained); -r: teacher model (model to distill from)
bash scripts/run_distillation.sh \
  -m Qwen/Qwen3-4B \
  -r Qwen/Qwen3-8B \
  -d $DATA_PATH \
  2>&1 | tee distillation.log
```
The second invocation additionally passes an experiment config via `-o`, which is forwarded to the trainer as an extra Hydra override:

```
DATA_PATH=<Absolute Path to Data>
# -m: student model (model to be trained); -r: teacher model (model to distill from)
bash scripts/run_distillation.sh \
  -m Qwen/Qwen3-4B \
  -r Qwen/Qwen3-8B \
  -o "+generator.exp_config=configs/skyrl-experiments/read-only.yaml" \
  -d $DATA_PATH \
  2>&1 | tee distillation.log
```
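`run_distillation.sh` defaults to reading `train.parquet` and `validation.parquet` under the data path. A quick pre-flight check before launching (the path below is a placeholder for your `build_dataset.py` output):

```shell
# Verify the dataset layout run_distillation.sh expects
DATA_PATH=/abs/path/to/data   # placeholder
for f in train.parquet validation.parquet; do
  [ -f "$DATA_PATH/$f" ] || echo "missing: $DATA_PATH/$f"
done
```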

10 changes: 10 additions & 0 deletions scripts/distill.sh
@@ -0,0 +1,10 @@
#!/bin/bash

# Run the distillation script 10 times in sequence
for i in $(seq 1 10)
do
echo "Run number: $i"
# Watchdog: after 4 hours, kill whatever still holds port 8080 (the generator's HTTP endpoint) so the next run can bind it
( sleep 14400 && fuser -k 8080/tcp ) & \
bash scripts/run_distillation.sh "$@"
done
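The backgrounded `sleep && fuser` pair acts as a four-hour watchdog on the inference port. If simply capping each round's runtime is enough, coreutils `timeout` is a more direct alternative — a sketch, assuming the training process shuts down cleanly on SIGTERM:

```shell
for i in $(seq 1 10); do
  echo "Run number: $i"
  # Cap each round at 4 hours; timeout exits with status 124 if the limit is hit
  timeout 4h bash scripts/run_distillation.sh "$@" || true
done
```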
130 changes: 130 additions & 0 deletions scripts/run_distillation.sh
@@ -0,0 +1,130 @@
#!/bin/bash
#
# Usage: bash scripts/run_distillation.sh \
# -m Qwen/Qwen3-4B \ # Student model (model to be trained)
# -r Qwen/Qwen3-32B \ # Reference/Teacher model (model to distill from)
# -d data/swe_gym \ # Data path
# [-s ckpt_path] [-n n_rollouts] [-i num_inference_engines] [-t num_training_engines]
#

. .env 2>/dev/null || true

while getopts ":m:r:n:d:s:o:i:t:b:" opt; do
case ${opt} in
m ) STUDENT_MODEL=$OPTARG;; # -m: Student model (model to be trained)
r ) TEACHER_MODEL=$OPTARG;; # -r: Reference/Teacher model (model to distill from)
n ) N_ROLLOUTS=$OPTARG;;
d ) DATA_PATH=$OPTARG;;
s ) CKPT_PATH=$OPTARG;;
o ) OTHER_OPTION=$OPTARG;;
i ) NUM_INFERENCE_ENGINES=$OPTARG;;
t ) NUM_TRAINING_ENGINES=$OPTARG;;
b ) MICRO_BATCH_SIZE=$OPTARG;;
\? ) echo "Usage: $0 -m <student_model> -r <teacher_model> [-d data_path] [-s ckpt_path] [-n n_rollouts] [-i num_inference_engines] [-t num_training_engines] [-b micro_batch_size] [-o other_options]"; exit 1;;
esac
done

# Validate required parameters
if [ -z "$STUDENT_MODEL" ]; then
echo "Error: Student model (-m) is required"
echo "Usage: $0 -m <student_model> -r <teacher_model> -d <data_path>"
exit 1
fi
if [ -z "$TEACHER_MODEL" ]; then
echo "Error: Teacher model (-r) is required"
echo "Usage: $0 -m <student_model> -r <teacher_model> -d <data_path>"
exit 1
fi

STUDENT_MODEL_ALIAS=$(echo $STUDENT_MODEL | sed 's/\//-/g')
TEACHER_MODEL_ALIAS=$(echo $TEACHER_MODEL | sed 's/\//-/g')
# Get number of GPUs available
NUM_GPUS=$(nvidia-smi -L | wc -l)
N_ROLLOUTS="${N_ROLLOUTS:-8}"
BATCH_SIZE=16 # Must be <= num_parallel_generation_workers (set to 16 below)
MAX_LENGTH=8192
RUN_NAME="code_search_distillation_${STUDENT_MODEL_ALIAS}_${TEACHER_MODEL_ALIAS}"
set -x

DATA_PATH="${DATA_PATH:-data/swe_smith}"
CKPT_PATH="${CKPT_PATH:-$(pwd)/ckpts/${STUDENT_MODEL_ALIAS}}"
mkdir -p "$CKPT_PATH"

HALF_NUM_GPUS=$((NUM_GPUS / 2))
NUM_INFERENCE_ENGINES="${NUM_INFERENCE_ENGINES:-$NUM_GPUS}"
NUM_TRAINING_ENGINES="${NUM_TRAINING_ENGINES:-$NUM_GPUS}"

export VLLM_FLASH_ATTN_VERSION=2
# Debugging aids: synchronous CUDA launches and device-side assertions (these slow training; drop for throughput)
export CUDA_LAUNCH_BLOCKING=1
export TORCH_USE_CUDA_DSA=1


uv run python -m src.train \
+run_async_trainer=false \
+use_distillation=true \
data.train_data="['$DATA_PATH/train.parquet']" \
data.val_data="['$DATA_PATH/validation.parquet']" \
trainer.algorithm.advantage_estimator="no_op" \
trainer.algorithm.policy_loss_type="importance_sampling" \
trainer.algorithm.use_kl_in_reward=true \
trainer.algorithm.use_kl_loss=false \
trainer.policy.model.path=${STUDENT_MODEL} \
trainer.ref.model.path=${TEACHER_MODEL} \
trainer.placement.colocate_all=true \
trainer.placement.colocate_policy_ref=true \
trainer.strategy=fsdp2 \
trainer.policy.fsdp_config.cpu_offload=true \
trainer.policy.fsdp_config.reshard_after_forward=true \
trainer.policy.fsdp_config.fsdp_size=-1 \
trainer.fully_async.num_parallel_generation_workers=16 \
trainer.placement.policy_num_gpus_per_node=${NUM_TRAINING_ENGINES} \
trainer.placement.ref_num_gpus_per_node=${NUM_TRAINING_ENGINES} \
trainer.placement.policy_num_nodes=1 \
trainer.placement.ref_num_nodes=1 \
trainer.policy.sequence_parallel_size=1 \
generator.num_inference_engines=${NUM_INFERENCE_ENGINES} \
generator.inference_engine_tensor_parallel_size=1 \
+generator.traj_dir=${CKPT_PATH}/trajectories/ \
+generator.engine_init_kwargs.enable_auto_tool_choice=true \
+generator.engine_init_kwargs.tool_call_parser=hermes \
+generator.engine_init_kwargs.reasoning_parser=qwen3 \
trainer.epochs=20 \
trainer.eval_batch_size=100 \
trainer.eval_before_train=false \
trainer.eval_interval=100 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=${BATCH_SIZE} \
trainer.policy_mini_batch_size=${BATCH_SIZE} \
trainer.micro_forward_batch_size_per_gpu=1 \
trainer.micro_train_batch_size_per_gpu=${MICRO_BATCH_SIZE:-1} \
trainer.dump_data_batch=true \
trainer.export_path="${CKPT_PATH}/exported_model/" \
trainer.hf_save_interval=5 \
trainer.ckpt_interval=5 \
trainer.max_prompt_length=4096 \
generator.sampling_params.max_generate_length=${MAX_LENGTH} \
generator.sampling_params.temperature=1.0 \
generator.max_input_length=32768 \
generator.max_num_batched_tokens=131072 \
generator.max_turns=4 \
trainer.policy.optimizer_config.lr=1.0e-6 \
generator.backend=vllm \
generator.run_engines_locally=True \
generator.enable_http_endpoint=True \
generator.http_endpoint_host='0.0.0.0' \
generator.http_endpoint_port=8080 \
generator.weight_sync_backend=nccl \
generator.async_engine=true \
generator.batched=false \
generator.n_samples_per_prompt=${N_ROLLOUTS} \
generator.gpu_memory_utilization=0.75 \
generator.enforce_eager=false \
trainer.step_wise_training=true \
trainer.logger="wandb" \
trainer.project_name="code_search" \
trainer.run_name=${RUN_NAME} \
trainer.resume_mode=latest \
trainer.ckpt_path="$CKPT_PATH" \
trainer.max_ckpts_to_keep=3 \
$OTHER_OPTION
64 changes: 64 additions & 0 deletions src/distiller.py
@@ -0,0 +1,64 @@
import torch

from skyrl_train.trainer import RayPPOTrainer
from skyrl_train.utils.ppo_utils import (
    register_advantage_estimator,
    register_policy_loss,
    reduce_loss,
)
from skyrl_train.training_batch import TrainingInputBatch
from skyrl_train.fully_async_trainer import FullyAsyncRayPPOTrainer

def apply_reward_kl_penalty(data: TrainingInputBatch) -> TrainingInputBatch:
"""Computes the KL penalty and sets the rewards to the KL penalty"""
loss_masks_all: torch.Tensor = data["loss_mask"]
teacher_action_log_probs: torch.Tensor = data["base_action_log_probs"]
action_log_probs: torch.Tensor = data["action_log_probs"]
rewards = -(action_log_probs - teacher_action_log_probs) * loss_masks_all
data["rewards"] = rewards
return data
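A toy numeric check of that reward definition, using made-up per-token log-probabilities in plain Python (independent of the trainer classes):

```python
# Per-token distillation reward: -(student_logp - teacher_logp) * mask,
# i.e. (teacher_logp - student_logp) * mask
def toy_kl_reward(student_logps, teacher_logps, mask):
    return [(t - s) * m for s, t, m in zip(student_logps, teacher_logps, mask)]

rewards = toy_kl_reward(
    student_logps=[-1.0, -2.0, -0.5],  # hypothetical student log-probs
    teacher_logps=[-1.5, -1.0, -0.5],  # hypothetical teacher log-probs
    mask=[1, 1, 0],                    # last position excluded from the loss
)
print(rewards)  # [-0.5, 1.0, 0.0]
```

The student is penalized where it is more confident than the teacher on the sampled token (first position), rewarded where it is less confident (second), and masked positions contribute nothing.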

class OnPolicyDistillationTrainer(RayPPOTrainer):
"""
Custom trainer for On Policy Distillation.

Overrides `apply_reward_kl_penalty` so the rewards are exactly the per-token KL penalty
"""

def apply_reward_kl_penalty(
self,
data: TrainingInputBatch,
) -> TrainingInputBatch:
"""Computes the KL penalty and sets the rewards to the KL penalty"""
return apply_reward_kl_penalty(data)

class FullyAsyncOnPolicyDistillationTrainer(FullyAsyncRayPPOTrainer):
def apply_reward_kl_penalty(
self,
data: TrainingInputBatch,
) -> TrainingInputBatch:
return apply_reward_kl_penalty(data)


# Pass-through advantage estimator: advantages are the token-level rewards themselves
@register_advantage_estimator("no_op")
def compute_no_op_advantage(token_level_rewards: torch.Tensor, **kwargs):
# just pass through the rewards
return token_level_rewards, token_level_rewards


@register_policy_loss("importance_sampling")
def compute_importance_sampling_policy_loss(
log_probs, old_log_probs, advantages, config, loss_mask=None, rollout_logprobs=None, **kwargs
):
# as defined here: https://tinker-docs.thinkingmachines.ai/losses#policy-gradient-importance_sampling
loss = -torch.exp(log_probs - old_log_probs) * advantages

loss = reduce_loss(loss, loss_mask, "seq_mean_token_sum_norm", config.max_seq_len)
# return loss and a dummy clip ratio value as we aren't clipping here
return loss, 0.0
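The registered loss is a plain importance-sampling policy gradient: the per-token term is minus the probability ratio times the advantage. A minimal sketch with hypothetical numbers:

```python
import math

# Per-token importance-sampling loss term: -(pi_new / pi_old) * advantage,
# with the ratio computed from log-probs as exp(log_prob - old_log_prob)
def is_loss_term(log_prob, old_log_prob, advantage):
    return -math.exp(log_prob - old_log_prob) * advantage

# Identical policies give a ratio of 1, so the term is just -advantage
print(is_loss_term(-1.0, -1.0, 2.0))  # -2.0
```

Unlike PPO's clipped surrogate, nothing bounds the ratio here, which is why the function returns a dummy clip-ratio of 0.0.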
54 changes: 38 additions & 16 deletions src/generator/code_search_generator.py
@@ -168,6 +168,15 @@ def init_and_run(
messages = list(map(lambda event: event.model_dump(), conversation.state.events))
final_message = get_agent_final_response(conversation.state.events)

# remove the workspace dir (quote the path so spaces don't break the shell command)
try:
if workspace.exists():
os.system(f'rm -rf "{workspace}"')
logger.info(f"Removed workspace {workspace}")
except Exception as e:
logger.error(f"Error removing workspace {workspace}: {e}", exc_info=True)


conversation.close()
logger.info("Conversation Finished")

@@ -354,31 +363,44 @@ async def code_search_loop(
current_response_ids = current_response_ids[len(current_prompt_ids):]

max_response_len = max_train_len - len(current_prompt_ids)

mask = [1]*len(token_messages[0]["response_token_ids"])
for i in range(1, len(token_messages)):
mask += [0] * (len(token_messages[i]["prompt_token_ids"]) - len(token_messages[i-1]["prompt_token_ids"]) - len(token_messages[i-1]["response_token_ids"]))
mask += [1] * len(token_messages[i]["response_token_ids"])
# Build the loss mask: 0 for chat-template control tokens (<|im_start|>, the role
# token, <|im_end|>) and non-assistant turns, 1 for assistant content tokens
start_token_id = self.tokenizer.convert_tokens_to_ids("<|im_start|>")
end_token_id = self.tokenizer.convert_tokens_to_ids("assistant")
end_of_turn_token_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
mask = []
found_role_switch = False
inside = False
for token_id in current_response_ids:
if token_id == start_token_id:
inside = True
mask.append(0)
elif token_id == end_token_id:
inside = False
mask.append(0)
idx = 0
while idx < len(current_response_ids):
token_id = current_response_ids[idx]
if not inside:
mask.append(1)
idx += 1
if token_id == end_of_turn_token_id:
inside = True
else:
if inside:
if token_id == start_token_id:
inside = True
mask.append(0)
idx += 1
elif token_id == end_token_id and found_role_switch:
inside = False
mask.append(0)
mask.append(0)
idx += 2
else:
mask.append(1)
mask.append(0)
idx += 1

# Zero the mask beyond max_response_len: don't truncate the response,
# just exclude the overflow from the loss
if len(current_response_ids) > max_response_len:
for i in range(max_response_len, len(current_response_ids)):
mask[i] = 0
if token_id == start_token_id:
found_role_switch = True
else:
found_role_switch = False
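The token walk above masks Qwen-style chat-template framing so that loss falls only on assistant content. A simplified sketch of the same idea on string tokens instead of token ids (the real code operates on ids and handles a response that starts mid-turn, so this is illustrative only):

```python
def assistant_loss_mask(tokens):
    """1 for tokens inside an assistant turn (after '<|im_start|>assistant'),
    0 for everything else, including the special tokens themselves."""
    mask, in_assistant = [], False
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok == "<|im_start|>":
            # The role token follows; only assistant turns receive loss
            role = tokens[i + 1] if i + 1 < len(tokens) else None
            mask += [0, 0]
            in_assistant = role == "assistant"
            i += 2
        elif tok == "<|im_end|>":
            mask.append(0)
            in_assistant = False
            i += 1
        else:
            mask.append(1 if in_assistant else 0)
            i += 1
    return mask

tokens = ["<|im_start|>", "user", "hi", "<|im_end|>",
          "<|im_start|>", "assistant", "hello", "<|im_end|>"]
print(assistant_loss_mask(tokens))  # [0, 0, 0, 0, 0, 0, 1, 0]
```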

rollout_list.append(
(
@@ -509,7 +531,7 @@ async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
for step_id in range(len(step_outputs)):
out_trajectory_id = copy.deepcopy(trajectory_ids[i])
out_trajectory_id.step = step_id
out_trajectory_ids.append(out_trajectory_id.instance_id)
out_trajectory_ids.append(out_trajectory_id)
is_last_step.append(step_id == len(step_outputs) - 1)

if not len(responses):
2 changes: 1 addition & 1 deletion src/prompts/prompt_builder.py
@@ -25,7 +25,7 @@ def get_instruction(
"workspace_dir_name": workspace_dir_name,
"working_dir": workspace_path,
}
context["test_instructions"] = ""
# context["test_instructions"] = ""

# Render the instruction
instruction = template.render(context)
2 changes: 1 addition & 1 deletion src/prompts/templates/default.j2
@@ -1,4 +1,4 @@
I have access to a python code repository in the directory {{ instance.repo_path }} .
I have access to a python code repository in the directory {{ working_dir }} .

Consider the following issue description:

Expand Down
2 changes: 1 addition & 1 deletion src/prompts/templates/file_localization.j2
@@ -1,4 +1,4 @@
I have access to a python code repository in the directory {{ instance.repo_path }} .
I have access to a python code repository in the directory {{ working_dir }} .

Consider the following issue description:

Expand Down
2 changes: 1 addition & 1 deletion src/prompts/templates/file_module.j2
@@ -1,4 +1,4 @@
I have access to a python code repository in the directory {{ instance.repo_path }} . Consider the following issue description:
I have access to a python code repository in the directory {{ working_dir }} . Consider the following issue description:

<issue_description>
{{ instance.problem_statement }}
Expand Down
2 changes: 1 addition & 1 deletion src/prompts/templates/file_module_parallel_tools.j2
@@ -1,4 +1,4 @@
I have access to a python code repository in the directory {{ instance.repo_path }} . Consider the following issue description:
I have access to a python code repository in the directory {{ working_dir }} . Consider the following issue description:

<issue_description>
{{ instance.problem_statement }}
Expand Down
5 changes: 3 additions & 2 deletions src/rewards/file_localization/file_localization.py
@@ -6,11 +6,12 @@

def compute_file_f1_score(predicted_files, true_files):
pred, true = set(predicted_files), set(true_files)
if not true:
return 0.0 # return 0 reward if ground truth is empty
tp = len(pred & true)
precision = tp / len(pred) if pred else 0.0
recall = tp / len(true) if true else 0.0
if not pred and not true:
return 1.0

return 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)
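The reward above is the standard set-based F1 between predicted and ground-truth files. A standalone re-implementation mirroring that function, with a worked example:

```python
def file_f1(predicted_files, true_files):
    pred, true = set(predicted_files), set(true_files)
    if not true:
        return 0.0  # no reward when ground truth is empty
    tp = len(pred & true)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(true)
    return 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)

# 1 of 2 predictions correct, 1 of 1 true file found: P=0.5, R=1.0, F1 ~ 0.667
print(round(file_f1(["a.py", "b.py"], ["a.py"]), 3))  # 0.667
```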

# def file_localization_f1_reward(final_message, instance):