Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions dev/run_yes_no_maybe_kl_advantage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Launch yes-no-maybe-kl-advantage training on SkyPilot (Kubernetes).

Usage:
uv run dev/run_yes_no_maybe_kl_advantage.py
uv run dev/run_yes_no_maybe_kl_advantage.py --fast
uv run dev/run_yes_no_maybe_kl_advantage.py --base-model Qwen/Qwen2.5-7B-Instruct
"""

import argparse
import os
import textwrap

from dotenv import load_dotenv
import sky
from sky import ClusterStatus

load_dotenv()

parser = argparse.ArgumentParser(
description="Launch yes-no-maybe KL advantage training on SkyPilot."
)
parser.add_argument(
"--fast", action="store_true", help="Skip setup (for re-runs on existing cluster)."
)
parser.add_argument(
"--base-model", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct"
)
parser.add_argument("--num-steps", type=int, default=20)
parser.add_argument("--kl-penalty-coef", type=float, default=0.1)
parser.add_argument("--accelerator", type=str, default="H200:1")
parser.add_argument("--cluster-name", type=str, default=None)
parser.add_argument(
"--kl-ref-step",
type=int,
default=None,
help="Checkpoint step of training model to use as KL reference",
)
parser.add_argument(
"--kl-ref-adapter-path",
type=str,
default=None,
help="Path to LoRA adapter checkpoint to use as KL reference",
)
args = parser.parse_args()

cluster_name = args.cluster_name or f"ynm-kl-{args.kl_penalty_coef}"
cluster_prefix = os.environ.get("CLUSTER_PREFIX")
if cluster_prefix:
cluster_name = f"{cluster_prefix}-{cluster_name}"

setup_script = textwrap.dedent("""\
echo 'Setting up environment...'
apt install -y nvtop
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
""")

kl_ref_env = ""
if args.kl_ref_step is not None:
kl_ref_env = f"KL_REF_STEP={args.kl_ref_step} "
elif args.kl_ref_adapter_path is not None:
kl_ref_env = f"KL_REF_ADAPTER_PATH={args.kl_ref_adapter_path} "

run_script = textwrap.dedent(f"""\
source $HOME/.local/bin/env
cd ~/sky_workdir
{kl_ref_env}BASE_MODEL={args.base_model} NUM_STEPS={args.num_steps} KL_PENALTY_COEF={args.kl_penalty_coef} uv run --python 3.11 --extra backend dev/yes-no-maybe-kl-advantage.py
""")

task = sky.Task(
name="yes-no-maybe-kl-advantage",
setup=setup_script,
run=run_script,
workdir=".",
)
task.set_resources(
sky.Resources(accelerators=args.accelerator, cloud=sky.clouds.Kubernetes())
)
task.set_file_mounts(
{
"~/sky_workdir/.env": ".env",
}
)

print(f"Launching on cluster: {cluster_name}")
print(f" base_model: {args.base_model}")
print(f" accelerator: {args.accelerator}")
print(f" num_steps: {args.num_steps}")
print(f" kl_penalty_coef: {args.kl_penalty_coef}")
if args.kl_ref_step is not None:
print(f" kl_ref_step: {args.kl_ref_step}")
if args.kl_ref_adapter_path is not None:
print(f" kl_ref_adapter_path: {args.kl_ref_adapter_path}")

# Cancel any existing jobs on this cluster
cluster_status = sky.stream_and_get(sky.status(cluster_names=[cluster_name]))
if len(cluster_status) > 0 and cluster_status[0]["status"] == ClusterStatus.UP:
print(f"Cluster {cluster_name} is UP. Canceling any active jobs...")
sky.stream_and_get(sky.cancel(cluster_name, all=True))

job_id, _ = sky.stream_and_get(
sky.launch(
task,
cluster_name=cluster_name,
retry_until_up=True,
idle_minutes_to_autostop=60,
down=True,
fast=args.fast,
)
)

print(f"Job submitted (ID: {job_id}). Streaming logs...")
exit_code = sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True)
print(f"Job {job_id} finished with exit code {exit_code}.")
110 changes: 110 additions & 0 deletions dev/yes-no-maybe-kl-advantage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Yes-no-maybe training with KL-penalized advantage adjustment.

Demonstrates the kl_penalty_coef feature: tokens where the policy has drifted
more from the reference model get reduced advantages, while tokens that have
drifted less get increased advantages.

Uses meta-llama/Meta-Llama-3.1-8B-Instruct as the base model (trained locally).
"""

import asyncio
from itertools import permutations
import os

from dotenv import load_dotenv
import openai

import art
from art.local import LocalBackend


async def rollout(
    client: openai.AsyncOpenAI, model: art.TrainableModel, prompt: str
) -> art.Trajectory:
    """Sample one completion for *prompt* and score it.

    Rewards rank the three target replies: 'maybe' (1.0) > 'no' (0.75) >
    'yes' (0.5); anything else earns 0.0.
    """
    messages: art.Messages = [{"role": "user", "content": prompt}]
    completion = await client.chat.completions.create(
        messages=messages, model=model.get_inference_name(), max_tokens=100, timeout=100
    )
    choice = completion.choices[0]
    reply = choice.message.content
    assert isinstance(reply, str)
    # Exact-match reward table; non-matching replies fall through to 0.0.
    reward_table = {"yes": 0.5, "no": 0.75, "maybe": 1.0}
    return art.Trajectory(
        messages_and_choices=[*messages, choice],
        reward=reward_table.get(reply, 0.0),
    )


def with_quotes(w: str) -> str:
    """Wrap *w* in single quotes, e.g. "yes" -> "'yes'"."""
    return "'{}'".format(w)


async def main():
    """Train on the yes/no/maybe task with KL-penalized advantage adjustment.

    All hyperparameters are read from environment variables (BASE_MODEL,
    KL_PENALTY_COEF, MODEL_NAME, KL_REF_STEP, KL_REF_ADAPTER_PATH,
    NUM_STEPS), matching what the SkyPilot launcher script forwards.
    """
    load_dotenv()

    backend = LocalBackend()
    base_model = os.environ.get("BASE_MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
    kl_penalty_coef = float(os.environ.get("KL_PENALTY_COEF", "0.1"))
    model = art.TrainableModel(
        # Default run name encodes the coefficient, e.g. "kl-0.1".
        name=os.environ.get("MODEL_NAME", f"kl-{kl_penalty_coef}"),
        project="yes-no-maybe",
        base_model=base_model,
    )
    await model.register(backend)

    # KL reference selection: either a checkpoint step of this model or a
    # direct adapter path. Both default to None (reference = base model).
    kl_penalty_reference_step: int | None = (
        int(os.environ["KL_REF_STEP"])
        if os.environ.get("KL_REF_STEP") is not None
        else None
    )
    # `or None` normalizes an empty-string env var to None.
    kl_ref_adapter_path: str | None = os.environ.get("KL_REF_ADAPTER_PATH") or None

    # Prompt grid: {respond, just respond} x {quoted, unquoted} x all ordered
    # 3- and 2-permutations of yes/no/maybe.
    # NOTE(review): for 2-word permutations the `use_quotes` flag is never
    # applied (the else-branch ignores it), so every 2-word prompt appears
    # twice in the list — confirm the duplication is intentional.
    prompts = [
        f"{prefix} with {', '.join([with_quotes(w) if use_quotes else w for w in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}"
        for prefix in ["respond", "just respond"]
        for use_quotes in [True, False]
        for words in (
            list(p) for n in [3, 2] for p in permutations(["yes", "no", "maybe"], n)
        )
    ]

    openai_client = model.openai_client()
    max_steps = int(os.environ.get("NUM_STEPS", "20"))
    # Resume from the model's latest checkpoint step, if any.
    start_step = await model.get_step()
    for step in range(start_step, start_step + max_steps):
        # 32 rollouts per prompt; each prompt forms one trajectory group.
        train_groups = await art.gather_trajectory_groups(
            (
                art.TrajectoryGroup(
                    rollout(openai_client, model, prompt) for _ in range(32)
                )
                for prompt in prompts
            )
        )
        result = await backend.train(
            model,
            train_groups,
            learning_rate=1e-4,
            # Enables the KL-penalized advantage adjustment under test.
            kl_penalty_coef=kl_penalty_coef,
            kl_penalty_reference_step=kl_penalty_reference_step,
            kl_ref_adapter_path=kl_ref_adapter_path,
        )
        await model.log(
            train_groups,
            metrics=result.metrics,
            step=result.step,
            split="train",
        )
        print(f"step {result.step}: {result.metrics}")


if __name__ == "__main__":
asyncio.run(main())
2 changes: 2 additions & 0 deletions src/art/dev/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class TrainConfig(TypedDict, total=False):
"token", "sequence", "average", "geometric_average"
]
kimi_k2_tau: float | None
kl_penalty_coef: float
kl_ref_adapter_path: str | None
logprob_calculation_chunk_size: int
mask_prob_ratio: bool
max_negative_advantage_importance_sampling_weight: float
Expand Down
28 changes: 26 additions & 2 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,10 @@ async def train( # type: ignore[override]
# Core training parameters
learning_rate: float = 5e-6,
beta: float = 0.0,
# KL-penalized advantage adjustment
kl_penalty_coef: float = 0.0,
kl_penalty_reference_step: int | None = None,
kl_ref_adapter_path: str | None = None,
# RL algorithm settings
ppo: bool = False,
epsilon: float | None = None,
Expand Down Expand Up @@ -425,7 +429,16 @@ async def train( # type: ignore[override]
model: The trainable model to train.
trajectory_groups: Batches of trajectories to train on.
learning_rate: Learning rate for training. Defaults to 5e-6.
beta: KL penalty coefficient. Defaults to 0.0.
beta: KL penalty coefficient added to the loss. Defaults to 0.0.
kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
Tokens diverging more from the reference get reduced advantages.
Defaults to 0.0 (disabled).
kl_penalty_reference_step: Checkpoint step of the training model to
use as the KL reference. If None, uses the base model (LoRA
disabled) as reference.
kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
checkpoint to use as the KL reference. Alternative to
kl_penalty_reference_step.
ppo: Whether to use PPO clipping. Defaults to False.
epsilon: Clip epsilon for importance sampling. Defaults based on ppo.
epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
Expand Down Expand Up @@ -472,11 +485,14 @@ async def train( # type: ignore[override]
groups_list = list(trajectory_groups)

# Build config objects from explicit kwargs
config = TrainConfig(learning_rate=learning_rate, beta=beta)
config = TrainConfig(
learning_rate=learning_rate, beta=beta, kl_penalty_coef=kl_penalty_coef
)
dev_config: dev.TrainConfig = {
"advantage_balance": advantage_balance,
"allow_training_without_logprobs": allow_training_without_logprobs,
"importance_sampling_level": importance_sampling_level,
"kl_penalty_coef": kl_penalty_coef,
"mask_prob_ratio": mask_prob_ratio,
"plot_tensors": plot_tensors,
"ppo": ppo,
Expand All @@ -499,6 +515,14 @@ async def train( # type: ignore[override]
dev_config["kimi_k2_tau"] = kimi_k2_tau
if truncated_importance_sampling is not None:
dev_config["truncated_importance_sampling"] = truncated_importance_sampling
if kl_ref_adapter_path is not None:
dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path
elif kl_penalty_reference_step is not None:
ref_checkpoint_dir = get_step_checkpoint_dir(
get_model_dir(model=model, art_path=self._path),
kl_penalty_reference_step,
)
dev_config["kl_ref_adapter_path"] = ref_checkpoint_dir

# Collect metrics from training
training_metrics: list[dict[str, float]] = []
Expand Down
10 changes: 10 additions & 0 deletions src/art/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Loss(BaseModel):
mean_entropy: torch.Tensor | None
policy_loss_sum: torch.Tensor
probs_corr: torch.Tensor
kl_policy_ref: torch.Tensor | None = None


def loss_fn(
Expand Down Expand Up @@ -92,6 +93,14 @@ def loss_fn(
)
if tau := experimental_config.get("kimi_k2_tau", None):
advantages -= tau * logprob_diff.detach()
kl_policy_ref: torch.Tensor | None = None
kl_penalty_coef = experimental_config.get("kl_penalty_coef", 0.0)
if kl_penalty_coef > 0 and ref_logprobs is not None:
kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask
avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6)
kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask
advantages = advantages + kl_penalty
kl_policy_ref = avg_kl
if ppo:
policy_loss = -torch.min(
prob_ratio * advantages,
Expand Down Expand Up @@ -139,6 +148,7 @@ def loss_fn(
mean_entropy=mean_entropy,
policy_loss_sum=policy_loss.sum(),
probs_corr=probs_corr,
kl_policy_ref=kl_policy_ref,
)


Expand Down
4 changes: 3 additions & 1 deletion src/art/preprocessing/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def create_train_inputs(
[None] if warmup else packed_tensors["image_grid_thw"][offset : offset + 1]
),
config=(
config.model_copy(update={"lr": 1e-9, "beta": 0.0, "kl_coef": 0.0})
config.model_copy(
update={"learning_rate": 1e-9, "beta": 0.0, "kl_penalty_coef": 0.0}
)
if warmup
else config
),
Expand Down
Loading