From b64991fdee2b74770ab8aa0791c84efcae90c197 Mon Sep 17 00:00:00 2001 From: arcticfly Date: Thu, 12 Feb 2026 16:36:00 -0800 Subject: [PATCH 1/2] Add KL-penalized advantage adjustment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces a new mechanism that adjusts per-token advantages based on KL divergence from a reference model. Tokens where the policy has drifted more get reduced advantages, while tokens that drifted less get increased advantages. The adjustment is zero-mean (centered) across tokens. New parameters on LocalBackend.train(): - kl_penalty_coef: coefficient for the adjustment (0.0 = disabled) - kl_penalty_reference_step: use a specific checkpoint step as reference - kl_ref_adapter_path: use an arbitrary LoRA adapter path as reference Also fixes a pre-existing bug in preprocessing/inputs.py where warmup config used incorrect field names (lr → learning_rate, kl_coef → kl_penalty_coef). Co-Authored-By: Claude Opus 4.6 --- dev/run_yes_no_maybe_kl_advantage.py | 104 +++++++++++++++++++++++++ dev/yes-no-maybe-kl-advantage.py | 108 ++++++++++++++++++++++++++ src/art/dev/train.py | 2 + src/art/local/backend.py | 28 ++++++- src/art/loss.py | 10 +++ src/art/preprocessing/inputs.py | 4 +- src/art/test/test_kl_advantage.py | 112 +++++++++++++++++++++++++++ src/art/types.py | 1 + src/art/unsloth/train.py | 52 ++++++++++--- 9 files changed, 407 insertions(+), 14 deletions(-) create mode 100644 dev/run_yes_no_maybe_kl_advantage.py create mode 100644 dev/yes-no-maybe-kl-advantage.py create mode 100644 src/art/test/test_kl_advantage.py diff --git a/dev/run_yes_no_maybe_kl_advantage.py b/dev/run_yes_no_maybe_kl_advantage.py new file mode 100644 index 000000000..c19b0a54c --- /dev/null +++ b/dev/run_yes_no_maybe_kl_advantage.py @@ -0,0 +1,104 @@ +"""Launch yes-no-maybe-kl-advantage training on SkyPilot (Kubernetes). + +Usage: + uv run dev/run_yes_no_maybe_kl_advantage.py + uv run dev/run_yes_no_maybe_kl_advantage.py --fast + uv run dev/run_yes_no_maybe_kl_advantage.py --base-model Qwen/Qwen2.5-7B-Instruct +""" + +import argparse +import os +import textwrap + +import sky +from dotenv import load_dotenv +from sky import ClusterStatus + +load_dotenv() + +parser = argparse.ArgumentParser( + description="Launch yes-no-maybe KL advantage training on SkyPilot." +) +parser.add_argument( + "--fast", action="store_true", help="Skip setup (for re-runs on existing cluster)." +) +parser.add_argument( + "--base-model", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct" +) +parser.add_argument("--num-steps", type=int, default=20) +parser.add_argument("--kl-penalty-coef", type=float, default=0.1) +parser.add_argument("--accelerator", type=str, default="H200:1") +parser.add_argument("--cluster-name", type=str, default=None) +parser.add_argument("--kl-ref-step", type=int, default=None, help="Checkpoint step of training model to use as KL reference") +parser.add_argument("--kl-ref-adapter-path", type=str, default=None, help="Path to LoRA adapter checkpoint to use as KL reference") +args = parser.parse_args() + +cluster_name = args.cluster_name or f"ynm-kl-{args.kl_penalty_coef}" +cluster_prefix = os.environ.get("CLUSTER_PREFIX") +if cluster_prefix: + cluster_name = f"{cluster_prefix}-{cluster_name}" + +setup_script = textwrap.dedent("""\ + echo 'Setting up environment...' 
+ apt install -y nvtop + curl -LsSf https://astral.sh/uv/install.sh | sh + source $HOME/.local/bin/env +""") + +kl_ref_env = "" +if args.kl_ref_step is not None: + kl_ref_env = f"KL_REF_STEP={args.kl_ref_step} " +elif args.kl_ref_adapter_path is not None: + kl_ref_env = f"KL_REF_ADAPTER_PATH={args.kl_ref_adapter_path} " + +run_script = textwrap.dedent(f"""\ + source $HOME/.local/bin/env + cd ~/sky_workdir + {kl_ref_env}BASE_MODEL={args.base_model} NUM_STEPS={args.num_steps} KL_PENALTY_COEF={args.kl_penalty_coef} uv run --python 3.11 --extra backend dev/yes-no-maybe-kl-advantage.py +""") + +task = sky.Task( + name="yes-no-maybe-kl-advantage", + setup=setup_script, + run=run_script, + workdir=".", +) +task.set_resources( + sky.Resources(accelerators=args.accelerator, cloud=sky.clouds.Kubernetes()) +) +task.set_file_mounts( + { + "~/sky_workdir/.env": ".env", + } +) + +print(f"Launching on cluster: {cluster_name}") +print(f" base_model: {args.base_model}") +print(f" accelerator: {args.accelerator}") +print(f" num_steps: {args.num_steps}") +print(f" kl_penalty_coef: {args.kl_penalty_coef}") +if args.kl_ref_step is not None: + print(f" kl_ref_step: {args.kl_ref_step}") +if args.kl_ref_adapter_path is not None: + print(f" kl_ref_adapter_path: {args.kl_ref_adapter_path}") + +# Cancel any existing jobs on this cluster +cluster_status = sky.stream_and_get(sky.status(cluster_names=[cluster_name])) +if len(cluster_status) > 0 and cluster_status[0]["status"] == ClusterStatus.UP: + print(f"Cluster {cluster_name} is UP. Canceling any active jobs...") + sky.stream_and_get(sky.cancel(cluster_name, all=True)) + +job_id, _ = sky.stream_and_get( + sky.launch( + task, + cluster_name=cluster_name, + retry_until_up=True, + idle_minutes_to_autostop=60, + down=True, + fast=args.fast, + ) +) + +print(f"Job submitted (ID: {job_id}). Streaming logs...") +exit_code = sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True) +print(f"Job {job_id} finished with exit code {exit_code}.") diff --git a/dev/yes-no-maybe-kl-advantage.py b/dev/yes-no-maybe-kl-advantage.py new file mode 100644 index 000000000..bf3e5efd8 --- /dev/null +++ b/dev/yes-no-maybe-kl-advantage.py @@ -0,0 +1,108 @@ +"""Yes-no-maybe training with KL-penalized advantage adjustment. + +Demonstrates the kl_penalty_coef feature: tokens where the policy has drifted +more from the reference model get reduced advantages, while tokens that have +drifted less get increased advantages. + +Uses meta-llama/Meta-Llama-3.1-8B-Instruct as the base model (trained locally). 
+""" + +import asyncio +from itertools import permutations +import os + +from dotenv import load_dotenv +import openai + +import art +from art.local import LocalBackend + + +async def rollout( + client: openai.AsyncOpenAI, model: art.TrainableModel, prompt: str +) -> art.Trajectory: + messages: art.Messages = [ + { + "role": "user", + "content": prompt, + } + ] + chat_completion = await client.chat.completions.create( + messages=messages, model=model.get_inference_name(), max_tokens=100, timeout=100 + ) + choice = chat_completion.choices[0] + content = choice.message.content + assert isinstance(content, str) + if content == "yes": + reward = 0.5 + elif content == "no": + reward = 0.75 + elif content == "maybe": + reward = 1.0 + else: + reward = 0.0 + return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward) + + +def with_quotes(w: str) -> str: + return f"'{w}'" + + +async def main(): + load_dotenv() + + backend = LocalBackend() + base_model = os.environ.get("BASE_MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct") + kl_penalty_coef = float(os.environ.get("KL_PENALTY_COEF", "0.1")) + model = art.TrainableModel( + name=os.environ.get("MODEL_NAME", f"kl-{kl_penalty_coef}"), + project="yes-no-maybe", + base_model=base_model, + ) + await model.register(backend) + + kl_penalty_reference_step: int | None = ( + int(os.environ["KL_REF_STEP"]) if os.environ.get("KL_REF_STEP") is not None else None + ) + kl_ref_adapter_path: str | None = os.environ.get("KL_REF_ADAPTER_PATH") or None + + prompts = [ + f"{prefix} with {', '.join([with_quotes(w) if use_quotes else w for w in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}" + for prefix in ["respond", "just respond"] + for use_quotes in [True, False] + for words in ( + list(p) for n in [3, 2] for p in permutations(["yes", "no", "maybe"], n) + ) + ] + + openai_client = model.openai_client() + max_steps = int(os.environ.get("NUM_STEPS", "20")) + start_step = await model.get_step() + for step in range(start_step, start_step + max_steps): + train_groups = await art.gather_trajectory_groups( + ( + art.TrajectoryGroup( + rollout(openai_client, model, prompt) for _ in range(32) + ) + for prompt in prompts + ) + ) + result = await backend.train( + model, + train_groups, + learning_rate=1e-4, + kl_penalty_coef=kl_penalty_coef, + kl_penalty_reference_step=kl_penalty_reference_step, + kl_ref_adapter_path=kl_ref_adapter_path, + ) + await model.log( + train_groups, + metrics=result.metrics, + step=result.step, + split="train", + ) + print(f"step {result.step}: {result.metrics}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/art/dev/train.py b/src/art/dev/train.py index bd4150740..55a749047 100644 --- a/src/art/dev/train.py +++ b/src/art/dev/train.py @@ -17,6 +17,8 @@ class TrainConfig(TypedDict, total=False): "token", "sequence", "average", "geometric_average" ] kimi_k2_tau: float | None + kl_penalty_coef: float + kl_ref_adapter_path: str | None logprob_calculation_chunk_size: int mask_prob_ratio: bool max_negative_advantage_importance_sampling_weight: float diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 11ac1111c..261bab087 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -387,6 +387,10 @@ async def train( # type: ignore[override] # Core training parameters learning_rate: float = 5e-6, beta: float = 0.0, + # KL-penalized advantage adjustment + kl_penalty_coef: float = 0.0, + kl_penalty_reference_step: int | None = None, + 
kl_ref_adapter_path: str | None = None, # RL algorithm settings ppo: bool = False, epsilon: float | None = None, @@ -425,7 +429,16 @@ async def train( # type: ignore[override] model: The trainable model to train. trajectory_groups: Batches of trajectories to train on. learning_rate: Learning rate for training. Defaults to 5e-6. - beta: KL penalty coefficient. Defaults to 0.0. + beta: KL penalty coefficient added to the loss. Defaults to 0.0. + kl_penalty_coef: Coefficient for KL-penalized advantage adjustment. + Tokens diverging more from the reference get reduced advantages. + Defaults to 0.0 (disabled). + kl_penalty_reference_step: Checkpoint step of the training model to + use as the KL reference. If None, uses the base model (LoRA + disabled) as reference. + kl_ref_adapter_path: Direct filesystem path to a LoRA adapter + checkpoint to use as the KL reference. Alternative to + kl_penalty_reference_step. ppo: Whether to use PPO clipping. Defaults to False. epsilon: Clip epsilon for importance sampling. Defaults based on ppo. epsilon_high: Asymmetric upper clip bound. Defaults to epsilon. @@ -472,11 +485,14 @@ async def train( # type: ignore[override] groups_list = list(trajectory_groups) # Build config objects from explicit kwargs - config = TrainConfig(learning_rate=learning_rate, beta=beta) + config = TrainConfig( + learning_rate=learning_rate, beta=beta, kl_penalty_coef=kl_penalty_coef + ) dev_config: dev.TrainConfig = { "advantage_balance": advantage_balance, "allow_training_without_logprobs": allow_training_without_logprobs, "importance_sampling_level": importance_sampling_level, + "kl_penalty_coef": kl_penalty_coef, "mask_prob_ratio": mask_prob_ratio, "plot_tensors": plot_tensors, "ppo": ppo, @@ -499,6 +515,14 @@ async def train( # type: ignore[override] dev_config["kimi_k2_tau"] = kimi_k2_tau if truncated_importance_sampling is not None: dev_config["truncated_importance_sampling"] = truncated_importance_sampling + if kl_ref_adapter_path is not None: + dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path + elif kl_penalty_reference_step is not None: + ref_checkpoint_dir = get_step_checkpoint_dir( + get_model_dir(model=model, art_path=self._path), + kl_penalty_reference_step, + ) + dev_config["kl_ref_adapter_path"] = ref_checkpoint_dir # Collect metrics from training training_metrics: list[dict[str, float]] = [] diff --git a/src/art/loss.py b/src/art/loss.py index 79154fde9..a22cca3ff 100644 --- a/src/art/loss.py +++ b/src/art/loss.py @@ -18,6 +18,7 @@ class Loss(BaseModel): mean_entropy: torch.Tensor | None policy_loss_sum: torch.Tensor probs_corr: torch.Tensor + kl_policy_ref: torch.Tensor | None = None def loss_fn( @@ -92,6 +93,14 @@ def loss_fn( ) if tau := experimental_config.get("kimi_k2_tau", None): advantages -= tau * logprob_diff.detach() + kl_policy_ref: torch.Tensor | None = None + kl_penalty_coef = experimental_config.get("kl_penalty_coef", 0.0) + if kl_penalty_coef > 0 and ref_logprobs is not None: + kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask + avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6) + kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask + advantages = advantages + kl_penalty + kl_policy_ref = avg_kl if ppo: policy_loss = -torch.min( prob_ratio * advantages, @@ -139,6 +148,7 @@ def loss_fn( mean_entropy=mean_entropy, policy_loss_sum=policy_loss.sum(), probs_corr=probs_corr, + kl_policy_ref=kl_policy_ref, ) diff --git a/src/art/preprocessing/inputs.py b/src/art/preprocessing/inputs.py index 
996c15f20..9e5a7a54c 100644 --- a/src/art/preprocessing/inputs.py +++ b/src/art/preprocessing/inputs.py @@ -41,7 +41,9 @@ def create_train_inputs( [None] if warmup else packed_tensors["image_grid_thw"][offset : offset + 1] ), config=( - config.model_copy(update={"lr": 1e-9, "beta": 0.0, "kl_coef": 0.0}) + config.model_copy( + update={"learning_rate": 1e-9, "beta": 0.0, "kl_penalty_coef": 0.0} + ) if warmup else config ), diff --git a/src/art/test/test_kl_advantage.py b/src/art/test/test_kl_advantage.py new file mode 100644 index 000000000..8a44489e1 --- /dev/null +++ b/src/art/test/test_kl_advantage.py @@ -0,0 +1,112 @@ +"""Tests for KL-penalized advantage adjustment in loss_fn.""" + +import torch + +from art.loss import loss_fn, Loss + + +def _make_inputs( + batch_size: int = 1, + seq_len: int = 8, + advantages: list[float] | None = None, +): + """Create minimal TrainInputs-like dict for loss_fn.""" + if advantages is None: + advantages = [1.0] * seq_len + adv_tensor = torch.tensor([advantages], dtype=torch.float32) + tokens = torch.zeros(batch_size, seq_len, dtype=torch.long) + logprobs = torch.zeros(batch_size, seq_len) + assistant_mask = torch.ones(batch_size, seq_len, dtype=torch.bool) + # First token is not assistant (shifted) + assistant_mask[:, 0] = False + weights = torch.ones(batch_size, seq_len) + group_ids = torch.ones(batch_size, seq_len, dtype=torch.long) + parent_ids = torch.zeros(batch_size, seq_len, dtype=torch.long) + return { + "tokens": tokens, + "logprobs": logprobs, + "advantages": adv_tensor, + "assistant_mask": assistant_mask, + "weights": weights, + "group_ids": group_ids, + "parent_ids": parent_ids, + } + + +def test_kl_advantage_no_effect_when_disabled(): + """When kl_penalty_coef=0, advantages should not be modified.""" + inputs = _make_inputs() + new_logprobs = torch.zeros(1, 8) + ref_logprobs = torch.full((1, 8), -1.0) # different from new_logprobs + + loss_no_kl = loss_fn(inputs, new_logprobs, ref_logprobs, None, {"kl_penalty_coef": 0.0}) + loss_without_ref = loss_fn(inputs, new_logprobs, None, None, {}) + + assert loss_no_kl.kl_policy_ref is None + assert loss_without_ref.kl_policy_ref is None + + +def test_kl_advantage_enabled(): + """When kl_penalty_coef>0 and ref_logprobs provided, kl_policy_ref should be set.""" + inputs = _make_inputs() + new_logprobs = torch.zeros(1, 8) + ref_logprobs = torch.full((1, 8), -0.5) + + loss = loss_fn(inputs, new_logprobs, ref_logprobs, None, {"kl_penalty_coef": 0.1}) + + assert loss.kl_policy_ref is not None + assert loss.kl_policy_ref.item() > 0 # KL should be positive when logprobs differ + + +def test_kl_advantage_zero_mean_penalty(): + """The KL penalty should be zero-mean across assistant tokens.""" + inputs = _make_inputs(seq_len=16) + # Create varying logprobs to produce non-uniform KL + new_logprobs = torch.randn(1, 16) * 0.5 + ref_logprobs = torch.randn(1, 16) * 0.5 + + kl_penalty_coef = 0.1 + assistant_mask = torch.nn.functional.pad( + inputs["assistant_mask"][:, 1:].float(), (0, 1), value=0.0 + ) + + # Compute what the penalty should be + kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask + avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6) + kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask + + # Sum of penalty across tokens should be ~0 + penalty_sum = kl_penalty.sum().item() + assert abs(penalty_sum) < 1e-5, f"Penalty sum should be ~0, got {penalty_sum}" + + +def test_kl_advantage_direction(): + """Tokens with higher KL (more drift) should get reduced 
advantages.""" + # Create inputs where token 2 has high drift and token 5 has low drift + seq_len = 8 + inputs = _make_inputs(seq_len=seq_len, advantages=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]) + new_logprobs = torch.zeros(1, seq_len) + ref_logprobs = torch.zeros(1, seq_len) + + # Make token at position 2 (which after shifting = position 1 in shifted space) + # have high divergence + new_logprobs[0, 2] = 0.0 + ref_logprobs[0, 2] = -2.0 # large gap = high KL + + # Token at position 5 has low divergence + new_logprobs[0, 5] = -0.1 + ref_logprobs[0, 5] = -0.1 # no gap = low KL + + loss = loss_fn(inputs, new_logprobs, ref_logprobs, None, {"kl_penalty_coef": 1.0}) + + # The metric should exist + assert loss.kl_policy_ref is not None + + +def test_kl_advantage_does_not_affect_when_no_ref(): + """When ref_logprobs is None, kl_penalty_coef should have no effect.""" + inputs = _make_inputs() + new_logprobs = torch.zeros(1, 8) + + loss = loss_fn(inputs, new_logprobs, None, None, {"kl_penalty_coef": 0.5}) + assert loss.kl_policy_ref is None diff --git a/src/art/types.py b/src/art/types.py index df81d6842..c2ccd3bf4 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -17,6 +17,7 @@ class TrainConfig(pydantic.BaseModel): learning_rate: float = 5e-6 beta: float = 0.0 + kl_penalty_coef: float = 0.0 Verbosity = Literal[0, 1, 2] diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index e5d229537..34dbc5cdb 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -1,6 +1,6 @@ import asyncio from collections import defaultdict -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext import gc import os from typing import TYPE_CHECKING, Callable, cast @@ -138,7 +138,8 @@ def compute_loss( ) if return_new_logprobs: return torch.nn.functional.pad(new_logprobs[:, :-1], (1, 0), value=0.0) - if config.beta > 0.0: + if config.beta > 0.0 or config.kl_penalty_coef > 0.0: + ref_adapter = _config.get("kl_ref_adapter_path") ref_logprobs, _ = calculate_logprobs( dtype_for_autocasting, trainer, @@ -148,9 +149,13 @@ def compute_loss( next_input_ids, lm_head_t, chunk_size=chunk_size, - inference_mode=True, - no_grad=False, + # Can't use inference_mode with a custom adapter — inference + # tensors don't track version counters, which breaks unsloth's + # LoRA kernels. Use no_grad instead. 
+ inference_mode=ref_adapter is None, + no_grad=ref_adapter is not None, reference_logprobs=True, + reference_adapter_name=ref_adapter, ) else: ref_logprobs = None @@ -170,6 +175,8 @@ def compute_loss( trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item()) if config.beta > 0.0: trainer._metrics["train"]["kl_div"].append(loss.mean_kl.item()) + if loss.kl_policy_ref is not None: + trainer._metrics["train"]["kl_policy_ref"].append(loss.kl_policy_ref.item()) return loss.mean_policy_loss + config.beta * loss.mean_kl return compute_loss @@ -248,6 +255,26 @@ def calculate_mask( return mask +@contextmanager +def _use_adapter(trainer: "GRPOTrainer", adapter_path: str): + """Context manager that switches to a named LoRA adapter, then restores the original.""" + # Sanitize the path to a valid module name (no dots allowed by PyTorch) + safe_name = adapter_path.replace(".", "_").replace("/", "_") + peft_model = trainer.accelerator.unwrap_model( + trainer.model, keep_fp32_wrapper=False + ) + if safe_name not in peft_model.peft_config: + peft_model.load_adapter(adapter_path, adapter_name=safe_name) + previous_adapter = peft_model.active_adapter + if isinstance(previous_adapter, list): + previous_adapter = previous_adapter[0] + peft_model.set_adapter(safe_name) + try: + yield + finally: + peft_model.set_adapter(previous_adapter) + + def calculate_logprobs( dtype_for_autocast: torch.dtype, trainer: "GRPOTrainer", @@ -260,19 +287,22 @@ def calculate_logprobs( inference_mode: bool, no_grad: bool, reference_logprobs: bool, + reference_adapter_name: str | None = None, ) -> tuple[ torch.Tensor, torch.Tensor ]: # Returns (log_probs, entropy) both shape [B, S] + if reference_logprobs and reference_adapter_name is not None: + adapter_ctx = _use_adapter(trainer, reference_adapter_name) + elif reference_logprobs: + adapter_ctx = trainer.accelerator.unwrap_model( + trainer.model, keep_fp32_wrapper=False + ).disable_adapter() + else: + adapter_ctx = nullcontext() with ( torch.inference_mode() if inference_mode else nullcontext(), torch.no_grad() if no_grad else nullcontext(), - ( - trainer.accelerator.unwrap_model( - trainer.model, keep_fp32_wrapper=False - ).disable_adapter() - if reference_logprobs - else nullcontext() - ), + adapter_ctx, torch.amp.autocast_mode.autocast(device_type="cuda", dtype=dtype_for_autocast), ): hidden_states = trainer.model( # type: ignore From 99ee7a4b840a5fb5819ff57cbd097109f8594146 Mon Sep 17 00:00:00 2001 From: arcticfly Date: Sun, 15 Feb 2026 01:29:06 -0700 Subject: [PATCH 2/2] Fix import sorting and formatting Co-Authored-By: Claude Opus 4.6 --- dev/run_yes_no_maybe_kl_advantage.py | 16 +++++++++++++--- dev/yes-no-maybe-kl-advantage.py | 4 +++- src/art/test/test_kl_advantage.py | 10 +++++++--- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/dev/run_yes_no_maybe_kl_advantage.py b/dev/run_yes_no_maybe_kl_advantage.py index c19b0a54c..c6601be03 100644 --- a/dev/run_yes_no_maybe_kl_advantage.py +++ b/dev/run_yes_no_maybe_kl_advantage.py @@ -10,8 +10,8 @@ import os import textwrap -import sky from dotenv import load_dotenv +import sky from sky import ClusterStatus load_dotenv() @@ -29,8 +29,18 @@ parser.add_argument("--kl-penalty-coef", type=float, default=0.1) parser.add_argument("--accelerator", type=str, default="H200:1") parser.add_argument("--cluster-name", type=str, default=None) -parser.add_argument("--kl-ref-step", type=int, default=None, help="Checkpoint step of training model to use as KL reference") -parser.add_argument("--kl-ref-adapter-path", 
type=str, default=None, help="Path to LoRA adapter checkpoint to use as KL reference") +parser.add_argument( + "--kl-ref-step", + type=int, + default=None, + help="Checkpoint step of training model to use as KL reference", +) +parser.add_argument( + "--kl-ref-adapter-path", + type=str, + default=None, + help="Path to LoRA adapter checkpoint to use as KL reference", +) args = parser.parse_args() cluster_name = args.cluster_name or f"ynm-kl-{args.kl_penalty_coef}" diff --git a/dev/yes-no-maybe-kl-advantage.py b/dev/yes-no-maybe-kl-advantage.py index bf3e5efd8..41ce0b119 100644 --- a/dev/yes-no-maybe-kl-advantage.py +++ b/dev/yes-no-maybe-kl-advantage.py @@ -62,7 +62,9 @@ async def main(): await model.register(backend) kl_penalty_reference_step: int | None = ( - int(os.environ["KL_REF_STEP"]) if os.environ.get("KL_REF_STEP") is not None else None + int(os.environ["KL_REF_STEP"]) + if os.environ.get("KL_REF_STEP") is not None + else None ) kl_ref_adapter_path: str | None = os.environ.get("KL_REF_ADAPTER_PATH") or None diff --git a/src/art/test/test_kl_advantage.py b/src/art/test/test_kl_advantage.py index 8a44489e1..d944efc62 100644 --- a/src/art/test/test_kl_advantage.py +++ b/src/art/test/test_kl_advantage.py @@ -2,7 +2,7 @@ import torch -from art.loss import loss_fn, Loss +from art.loss import Loss, loss_fn def _make_inputs( @@ -39,7 +39,9 @@ def test_kl_advantage_no_effect_when_disabled(): new_logprobs = torch.zeros(1, 8) ref_logprobs = torch.full((1, 8), -1.0) # different from new_logprobs - loss_no_kl = loss_fn(inputs, new_logprobs, ref_logprobs, None, {"kl_penalty_coef": 0.0}) + loss_no_kl = loss_fn( + inputs, new_logprobs, ref_logprobs, None, {"kl_penalty_coef": 0.0} + ) loss_without_ref = loss_fn(inputs, new_logprobs, None, None, {}) assert loss_no_kl.kl_policy_ref is None @@ -84,7 +86,9 @@ def test_kl_advantage_direction(): """Tokens with higher KL (more drift) should get reduced advantages.""" # Create inputs where token 2 has high drift and token 5 has low drift seq_len = 8 - inputs = _make_inputs(seq_len=seq_len, advantages=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]) + inputs = _make_inputs( + seq_len=seq_len, advantages=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0] + ) new_logprobs = torch.zeros(1, seq_len) ref_logprobs = torch.zeros(1, seq_len)
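
Notes on the advantage adjustment: the loss.py hunk in patch 1 is compact
enough to re-derive standalone. The sketch below mirrors its arithmetic with
plain tensors; the function name and the [B, S] shapes are assumptions of the
illustration, not library API.

import torch

def kl_adjusted_advantages(
    advantages: torch.Tensor,      # [B, S] per-token advantages
    new_logprobs: torch.Tensor,    # [B, S] logprobs under the current policy
    ref_logprobs: torch.Tensor,    # [B, S] logprobs under the reference model
    assistant_mask: torch.Tensor,  # [B, S] 1.0 on trainable assistant tokens
    kl_penalty_coef: float,
) -> torch.Tensor:
    # Per-token drift estimate (the simple k1 KL estimator), detached so the
    # adjustment itself contributes no gradient path.
    kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask
    avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6)
    # Centered adjustment: tokens above the mean KL are pushed down, tokens
    # below it are pushed up, and the sum over masked tokens is ~0 (the
    # zero-mean property that test_kl_advantage_zero_mean_penalty checks).
    kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask
    return advantages + kl_penalty

A two-token example makes the direction concrete: with per-token KL estimates
of 2.0 and 0.0 the mean is 1.0, so at kl_penalty_coef=0.1 the high-drift token
loses 0.1 of advantage while the low-drift token gains 0.1:

adv = torch.tensor([[1.0, 1.0]])
new = torch.tensor([[0.0, 0.0]])
ref = torch.tensor([[-2.0, 0.0]])  # token 0 has drifted, token 1 has not
mask = torch.ones(1, 2)
print(kl_adjusted_advantages(adv, new, ref, mask, 0.1))
# ~ tensor([[0.9000, 1.1000]])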
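
On the caller side, the three new train() kwargs compose as in
dev/yes-no-maybe-kl-advantage.py above. A minimal call is sketched below
(model and trajectory-group setup elided, inside an async context; the
checkpoint step 5 is a hypothetical value). Per the backend.py hunk,
kl_ref_adapter_path takes precedence over kl_penalty_reference_step when both
are given, and omitting both falls back to the base model with LoRA disabled
as the reference.

result = await backend.train(
    model,
    train_groups,
    learning_rate=1e-4,
    kl_penalty_coef=0.1,          # 0.0 disables the adjustment entirely
    kl_penalty_reference_step=5,  # hypothetical checkpoint step; or pass
                                  # kl_ref_adapter_path to point at any
                                  # LoRA adapter directory
)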