Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .env
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Example environment file (sourced before launching training).
# Every variable uses the ${VAR:-default} expansion, so values already
# present in the caller's environment take precedence over these defaults.

# Scratch directories: point Ray and the standard temp variables at the
# workspace volume instead of the (often small) local /tmp.
export RAY_TMPDIR="${RAY_TMPDIR:-/tmpworkspace/}"
export TMPDIR="${TMPDIR:-/tmpworkspace/tmp}"
export TMP="${TMP:-$TMPDIR}"
export TEMP="${TEMP:-$TMPDIR}"
# Root under which per-instance testbed checkouts are created.
export TESTBED_ROOT="${TESTBED_ROOT:-/tmpworkspace/testbed}"
mkdir -p "$RAY_TMPDIR" "$TMPDIR" "$TESTBED_ROOT"

# Training data location (relative to the repo root) and an MMDD date
# stamp used to tag log files.
export DATA_PATH=data/swe_gym
export LOG_DATE=$(date +%m%d)
mkdir -p logs

# Weights & Biases settings. Fill in WANDB_API_KEY before running.
export PROJECT_NAME=code_search_baselines
export WANDB_API_KEY=
# NOTE(review): RUN_NAME references BASE_MODEL, which is not defined in
# this file — presumably exported by the launching script; verify before use.
export RUN_NAME=swegym-${BASE_MODEL}-grpo
27 changes: 21 additions & 6 deletions README_Training.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,41 @@

## Build Dataset

```
```bash
uv run src/build_dataset.py --output ../data/
```

## Train Model

```
bash scripts run_training.sh -m Qwen/Qwen3-0.6B -d <Absolute Path to Data>
```
### Basic Training

```bash
bash scripts/run_training.sh -m Qwen/Qwen3-0.6B -d <Absolute Path to Data>
```

### Async Training

```bash
DATA_PATH=<Absolute Path to Data>
bash scripts/run_async_training.sh -m Qwen/Qwen3-4B -d $DATA_PATH 2>&1 | tee training.log
```

### With Custom Config

```bash
DATA_PATH=<Absolute Path to Data>
bash scripts/run_async_training.sh \
-m Qwen/Qwen3-4B \
-o "+generator.exp_config=configs/skyrl-experiments/read-only.yaml" \
-d $DATA_PATH \
2>&1 | tee training.log
```

## Model Path Formats

The `-m` parameter supports multiple formats:
- **HuggingFace model ID**: `Qwen/Qwen2.5-7B-Instruct`
- **Absolute path**: `/mnt/models/qwen3-4b`
- **Relative path**: `./models/qwen3-4b` or `~/models/qwen3-4b`

> **Note**: For relative paths without `./` or `~/` prefix (e.g., `models/qwen`), the system will check if the path exists locally; otherwise, it will attempt to download from HuggingFace Hub.
2 changes: 2 additions & 0 deletions README_verifiers.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@ sudo apt-get install ripgrep -y
```bash
uv run vf-eval swe-grep-oss-env --api-base-url http://localhost:8000/v1 --model "Qwen/Qwen3-8B" --num-examples 1 --rollouts-per-example 1
```

> Note: `verifiers` is pinned to `0.1.6.post0` for SDK/API compatibility. If you hit version mismatch issues, re-run `uv sync` and avoid installing `verifiers` via `pip` outside the `uv` environment.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies = [
"openhands-agent-server",
"openhands-workspace",
"vllm==0.11.0",
"verifiers>=0.1.6.post0",
"verifiers==0.1.6.post0",
"datasets>=4.0.0",
"ipykernel>=7.1.0",
"ipywidgets>=8.1.8",
Expand Down Expand Up @@ -64,7 +64,7 @@ explicit = true
flash-attn = ["torch"]

[tool.uv.sources]
skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", rev = "69ca4d9", subdirectory = "skyrl-train" }
skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", rev = "7504d18", subdirectory = "skyrl-train" }
flash-attn = {url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp313-cp313-linux_x86_64.whl"}
openhands-sdk = { workspace = true }
openhands-tools = { workspace = true }
Expand Down
19 changes: 12 additions & 7 deletions scripts/run_async_training.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ NUM_TRAINING_ENGINES="${NUM_TRAINING_ENGINES:-$HALF_NUM_GPUS}"
export VLLM_FLASH_ATTN_VERSION=2
export CUDA_LAUNCH_BLOCKING=1
export TORCH_USE_CUDA_DSA=1
# fully-async constraints:
# - train_batch_size must equal policy_mini_batch_size
# - mini_batch_size <= num_parallel_generation_workers <= mini_batch_size * (max_staleness_steps + 1)
NUM_PARALLEL_GENERATION_WORKERS="${NUM_PARALLEL_GENERATION_WORKERS:-$TRAIN_BATCH_SIZE}"


uv run --isolated -m src.train \
+run_async_trainer=true \
Expand All @@ -49,7 +54,7 @@ uv run --isolated -m src.train \
trainer.policy.fsdp_config.cpu_offload=true \
trainer.policy.fsdp_config.reshard_after_forward=true \
trainer.policy.fsdp_config.fsdp_size=-1 \
trainer.fully_async.num_parallel_generation_workers=16 \
trainer.fully_async.num_parallel_generation_workers="${NUM_PARALLEL_GENERATION_WORKERS}" \
trainer.placement.policy_num_gpus_per_node=${NUM_TRAINING_ENGINES} \
trainer.placement.ref_num_gpus_per_node=${NUM_TRAINING_ENGINES} \
trainer.placement.policy_num_nodes=1 \
Expand All @@ -61,7 +66,7 @@ uv run --isolated -m src.train \
+generator.engine_init_kwargs.enable_auto_tool_choice=true \
+generator.engine_init_kwargs.tool_call_parser=hermes \
+generator.engine_init_kwargs.reasoning_parser=qwen3 \
trainer.epochs=20 \
trainer.epochs=1 \
trainer.eval_batch_size=100 \
trainer.eval_before_train=false \
trainer.eval_interval=100 \
Expand All @@ -72,8 +77,8 @@ uv run --isolated -m src.train \
trainer.micro_train_batch_size_per_gpu=${MICRO_BATCH_SIZE:-1} \
trainer.dump_data_batch=true \
trainer.export_path="${CKPT_PATH}exported_model/" \
trainer.hf_save_interval=5 \
trainer.ckpt_interval=5 \
trainer.hf_save_interval=10 \
trainer.ckpt_interval=10 \
trainer.max_prompt_length=4096 \
generator.sampling_params.max_generate_length=${MAX_LENGTH} \
generator.sampling_params.temperature=1.0 \
Expand All @@ -93,11 +98,11 @@ uv run --isolated -m src.train \
generator.n_samples_per_prompt=${N_ROLLOUTS} \
generator.gpu_memory_utilization=0.75 \
generator.enforce_eager=false \
trainer.step_wise_training=true \
+trainer.step_wise_training=true \
trainer.logger="wandb" \
trainer.project_name="code_search" \
trainer.project_name="${PROJECT_NAME}" \
trainer.run_name=${RUN_NAME} \
trainer.resume_mode=latest \
trainer.ckpt_path="$CKPT_PATH" \
trainer.max_ckpts_to_keep=3 \
$OTHER_OPTION
$OTHER_OPTION 2>&1 | tee logs/${LOG_DATE}_${RUN_NAME}.log
31 changes: 24 additions & 7 deletions src/generator/code_search_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ def init_and_run(

# Avoid collisions in /tmp testbed directories
uuid_str = str(uuid.uuid4())[:8]
workspace = Path(f"/tmp/testbed/{uuid_str}/")
# Allow overriding testbed root to avoid filling local /tmp on constrained machines.
# Defaults to /tmp/testbed to preserve existing behavior.
workspace = Path(os.environ.get("TESTBED_ROOT", "/tmp/testbed")) / uuid_str
workspace.mkdir(parents=True, exist_ok=True)
status, working_dir = clone_instance(repo_name, commit_id, instance_id, workspace)

if training_phase == "eval":
Expand Down Expand Up @@ -123,13 +126,20 @@ def init_and_run(
system_prompt_path = os.path.join(prompts_base_dir, generator_cfg.prompts.system_prompt)
user_prompt_path = os.path.join(prompts_base_dir, generator_cfg.prompts.user_prompt)

# Get max_input_length from config to prevent context overflow
max_input_length = generator_cfg.get("max_input_length", 38400)
# Reserve some tokens for system prompt, tools, and response generation
# Set max_input_tokens to ensure OpenHands handles context length properly
effective_max_input = max_input_length - 2000 # Reserve 2000 tokens for overhead

agent = CustomAgent(
llm=LLM(
usage_id="agent",
model=litellm_model_name,
base_url=litellm_base_url,
api_key="sk-xxx",
temperature=temperature,
max_input_tokens=effective_max_input, # Let OpenHands handle context truncation
litellm_extra_body={
"return_token_ids": True,
"include_stop_str_in_output": True,
Expand Down Expand Up @@ -200,19 +210,26 @@ def __init__(
generator_cfg, skyrl_gym_cfg, inference_engine_client, tokenizer, model_name
)

self.http_endpoint_host = generator_cfg.get(
"http_endpoint_host", "127.0.0.1"
)
# NOTE:
# `http_endpoint_host` is often set to "0.0.0.0" for *binding* the local server,
# but clients should not connect to 0.0.0.0. When we build the OpenAI-compatible
# base_url for LiteLLM/OpenHands, prefer a loopback address in that case.
self.http_endpoint_host = generator_cfg.get("http_endpoint_host", "127.0.0.1")
self.http_endpoint_port = generator_cfg.get(
"http_endpoint_port", 8000
)
self.base_url = f"http://{self.http_endpoint_host}:{self.http_endpoint_port}/v1/"
request_host = self.http_endpoint_host
if request_host in {"0.0.0.0", "::"}:
request_host = "127.0.0.1"
self.base_url = f"http://{request_host}:{self.http_endpoint_port}/v1/"
logger.info(f"Using CodeSearchGenerator with model {model_name} at {self.base_url}")
self.generator_cfg = generator_cfg
self.tokenizer = tokenizer
self.model_name = model_name
# self.litellm_model_name = "openai/" + self.model_name
self.litellm_model_name = "litellm_proxy/" + self.model_name

# LiteLLM model routing: supports both HF model IDs and local paths
# Examples: "Qwen/Qwen2.5-7B" or "/path/to/model" or "./models/qwen"
self.litellm_model_name = f"openai/{self.model_name}"

if self.generator_cfg.chat_template.name_or_path is not None:
raise NotImplementedError(
Expand Down
30 changes: 8 additions & 22 deletions src/metrics/efficiency_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ def compute_tool_call_metrics(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
- avg_tool_calls_per_step: Average tool calls per step
- tool_call_breakdown: Dict mapping tool names to counts
"""
# Find all assistant messages with tool calls
tool_call_count = 0
tool_breakdown = {}
if not messages:
return {
Expand All @@ -84,26 +82,14 @@ def compute_tool_call_metrics(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
"tool_call_breakdown": {},
}

for msg in messages:
# Assistant messages contain tool_calls field
if msg.get("role") == "assistant" and "tool_calls" in msg:
tool_calls = msg.get("tool_calls", [])
tool_call_count += len(tool_calls)

# Track tool types
for tool_call in tool_calls:
# Handle both dict and object-like structures
if isinstance(tool_call, dict):
if "function" in tool_call:
tool_name = tool_call["function"].get("name", "unknown")
else:
tool_name = "unknown"
elif hasattr(tool_call, "function"):
tool_name = getattr(tool_call.function, "name", "unknown")
else:
tool_name = "unknown"

tool_breakdown[tool_name] = tool_breakdown.get(tool_name, 0) + 1
# Count tool calls from ActionEvents (the actual tool invocations)
action_messages = [msg for msg in messages if msg.get("kind") == "ActionEvent"]
tool_call_count = len(action_messages)

# Track tool types from ActionEvents
for action in action_messages:
tool_name = action.get("tool_name", "unknown")
tool_breakdown[tool_name] = tool_breakdown.get(tool_name, 0) + 1

# Calculate average per step
num_steps = compute_step_count(messages)
Expand Down
Loading