provide conceptual walkthrough #547
AakashKumarNain wants to merge 2 commits into OpenPipe:main from AakashKumarNain:update_docs

---
title: "ART 101"
description: "Conceptual walkthrough"
icon: "academic-cap"
---

This section introduces the minimal conceptual vocabulary you need to train models with GRPO using ART and the Serverless API. Serverless aims to provide a far better developer experience for experienced practitioners and RL newcomers alike. While we want to make training models with RL as easy as possible, we also do not want it to feel like a black box. This walkthrough gives you an overview of the building blocks required to train RL models with the Serverless API.

## 1. Client-server training loop

ART's functionality is divided into a client and a server.

A **client** (your Python process):
- runs your agent workflow,
- requests completions from an OpenAI-compatible API,
- stores the interactions in trajectories,
- computes rewards and submits them for training.

A **server** (GPU-backed):
- accepts the requests (a new batch of samples) submitted by the client to train the model,
- trains the model on those samples using GRPO; the server's state is initialized from the latest checkpoint, or from scratch on the first iteration,
- uses LoRA to train the model; inference is blocked while the model is training on the latest batch,
- saves and loads the updated LoRA checkpoint, then unblocks inference,
- serves the latest checkpoint for inference.

**Note:** The Serverless API uses the [W&B Inference](https://docs.wandb.ai/inference) cluster, so you must set your `WANDB_API_KEY` before using it.

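One way to supply the key is through the environment before any ART or W&B code runs; a minimal sketch (the key value below is a placeholder, not a real credential):

```python
import os

# Placeholder value; substitute the real key from your W&B account.
# setdefault leaves an already-exported key untouched.
os.environ.setdefault("WANDB_API_KEY", "your-wandb-api-key")
```

Exporting `WANDB_API_KEY` in your shell profile works equally well; the only requirement is that it is set before the backend is initialized.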
## 2. Model selection, setup, and registration

As of today, the Serverless API supports the following models:
- [OpenPipe Qwen 3 14B Instruct](https://huggingface.co/OpenPipe/Qwen3-14B-Instruct)
- [Qwen 3 30B A3B Instruct](https://huggingface.co/Qwen/Qwen3-30B-A3B)

Depending on your use case and needs, you can choose either model for your workflow. A trainable model is an instance of the `art.TrainableModel(...)` class, which takes the following arguments:
- *project*: the workflow is initialized inside the W&B workspace with this name.
- *name*: the name of the final fine-tuned model. The scope of this name is similar to the scope of W&B run names within a project.
- *base_model*: the base model to start from. Choose either of the two models above.

Here is an example demonstrating how to instantiate and register a model with the serverless backend.

```python
import art
from art.serverless.backend import ServerlessBackend
import wandb

PROJECT_NAME = "math-agent"
MODEL_NAME = "qwen3-math-001"
BASE_MODEL = "OpenPipe/Qwen3-14B-Instruct"

# Init your project in wandb
wandb.init(project=PROJECT_NAME)

# Set up the trainable model and the training backend in ART
model = art.TrainableModel(
    name=MODEL_NAME,
    project=PROJECT_NAME,
    base_model=BASE_MODEL,
)

# We will use the serverless backend to train our model.
backend = ServerlessBackend()

# Register the model with the backend (run inside an async context)
await model.register(backend)
```

## 3. Trajectory and TrajectoryGroup

When training a model, we generate rollouts at every step. A rollout, in the LLM world, is the response (or group of responses) generated by an LLM for a given input (prompt). A `Trajectory` object contains the system, user, and assistant messages produced by the model or the agent in a single rollout. The messages are stored in a `messages_and_choices` sequence. A trajectory can also carry additional information, such as the reward assigned to the rollout, the training step number as metadata, or optional [additional histories](https://art.openpipe.ai/features/additional-histories).

A `TrajectoryGroup` object contains the set of `Trajectory` objects corresponding to a single input example. For example, if you generate four output samples per input at training step `t`, each rollout is represented by a `Trajectory` object, and the `TrajectoryGroup` stores all four of them.

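To make the bookkeeping concrete, here is a minimal sketch of the input-to-group mapping, with plain dicts standing in for real `Trajectory` objects (the data below is illustrative only):

```python
# Hypothetical setup: 3 input samples, 4 rollouts per sample.
inputs = ["q1", "q2", "q3"]
rollouts_per_group = 4

# One group per input; each group holds one trajectory per rollout.
groups = [
    [{"input": q, "rollout": i} for i in range(rollouts_per_group)]
    for q in inputs
]

print(len(groups))     # number of groups == number of inputs -> 3
print(len(groups[0]))  # trajectories per group == rollouts_per_group -> 4
```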
Here is an example showing how to store a rollout in a `Trajectory` object.

```python
import art
import openai

client = openai.AsyncOpenAI(
    # We need to provide the endpoint where the model is hosted
    base_url=model.inference_base_url,
    api_key=model.inference_api_key,
)

async def rollout(client, model, task_input, step, completion_args):
    """Generates a rollout for a given input.

    Args:
        client: OpenAI-compatible client required to generate responses
            from a model endpoint
        model: An instance of the (trainable) model
        task_input: Input text for the model
        step: Current training step
        completion_args: (Optional) Any other completion args compatible
            with the OpenAI client API

    Returns:
        A `Trajectory` instance containing the rollout and any other
        metadata
    """
    traj = art.Trajectory(
        # We assign the actual reward value to a trajectory after evaluating
        # it. This is why it is set to 0 during the rollout phase.
        reward=0.0,
        messages_and_choices=[
            # SYSTEM_PROMPT is assumed to be defined elsewhere in your script.
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": task_input.question},
        ],
        metadata={"step": step},
    )

    response = await client.chat.completions.create(
        model=model.get_inference_name(),
        messages=traj.messages(),  # Extract the messages sequence
        **completion_args,
    )
    traj.messages_and_choices.append(response.choices[0])
    return traj
```

## 4. Reward calculation and assignment

Once we have generated the trajectories for an input sample, we want to evaluate how good they are and assign rewards accordingly. How rewards are defined depends on the use case at hand. The main constraint to be aware of is that RL on LLMs works best with verifiable rewards; there are ongoing research efforts to bridge the gap between verifiable and non-verifiable rewards.

Here is an example of how we can define rewards for our math agent. You can modify these for your use case.

```python
import re

# Tags the model is expected to wrap its final answer in.
solution_start = "<solution>"
solution_end = "</solution>"


def normalize_number_str(num):
    """Removes commas from a number represented as a string,
    e.g. "2,00,000" -> "200000".
    """
    if num is None:
        return None
    return num.strip().replace(",", "")


def extract_guess(text):
    """Extracts the answer within the solution tags,
    e.g. <solution>34</solution> -> 34.
    """
    match = re.search(rf"{solution_start}(.*?){solution_end}", text, re.DOTALL)
    if match:
        content = match.group(1).strip()
        nums = re.findall(r"[-+]?\d[\d,]*(?:\.\d+)?", content)
        if nums:
            return normalize_number_str(nums[-1].strip())

    # Fall back to the last number anywhere in the response.
    nums = re.findall(r"[-+]?\d[\d,]*(?:\.\d+)?", text)
    if nums:
        return normalize_number_str(nums[-1].strip())

    return None


def score_on_format(text):
    """Checks the correctness of the format of the generated response.

    Here we are interested in ensuring that the model generates
    responses with solution_start and solution_end tags, e.g.
    <solution>...</solution>. A correctly formatted response gets a
    reward of 1.0, and 0.0 otherwise.
    """
    has_solution_start = text.count(solution_start) == 1
    has_solution_end = text.count(solution_end) == 1

    if has_solution_start and has_solution_end:
        solution_start_pos = text.find(solution_start)
        solution_end_pos = text.find(solution_end)
        if solution_start_pos < solution_end_pos:
            return 1.0
    return 0.0


async def evaluate_rollouts(groups, questions, ground_truths):
    """Evaluates the correctness of the guessed answer w.r.t. the ground truth.

    Here we evaluate each rollout in a batch against the corresponding
    ground truth. If it matches exactly, we give a high score, say 10.0.
    If it does not match exactly but is an approximate match, we still
    give a positive, albeit lower, score.
    """
    parsed = 0
    for group in groups:
        for trajectory in group.trajectories:
            true_answer = ground_truths[parsed]
            response = trajectory.messages_and_choices[-1].message.content

            format_score = score_on_format(response)
            guess = extract_guess(response)

            if guess is None:
                answer_score = 0.0
            elif guess == true_answer or guess.strip() == str(true_answer).strip():
                answer_score = 10.0
            else:
                try:
                    guess_num = float(normalize_number_str(guess))
                    true_num = float(normalize_number_str(str(true_answer)))
                    if true_num != 0:
                        ratio = guess_num / true_num
                        if 0.9 <= ratio <= 1.1:
                            answer_score = 5.0
                        elif 0.8 <= ratio <= 1.2:
                            answer_score = 2.0
                        else:
                            answer_score = 0.0
                    else:
                        answer_score = 0.0
                except (TypeError, ValueError):
                    answer_score = 0.0

            # Once we evaluate a trajectory, we update the reward value
            # associated with it. This is used to update the policy.
            trajectory.reward = format_score + answer_score
            parsed += 1
    return groups
```

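As a quick sanity check, here is how the extraction logic behaves on a toy response. This is a condensed, self-contained restatement of the `extract_guess` idea (it omits one fallback branch of the full version above), with the same `<solution>` tags:

```python
import re

solution_start = "<solution>"
solution_end = "</solution>"

def normalize_number_str(num):
    # Strip surrounding whitespace and drop comma separators.
    return None if num is None else num.strip().replace(",", "")

def extract_guess(text):
    # Prefer a number inside the solution tags; otherwise fall back to
    # the last number anywhere in the text.
    match = re.search(rf"{solution_start}(.*?){solution_end}", text, re.DOTALL)
    search_space = match.group(1).strip() if match else text
    nums = re.findall(r"[-+]?\d[\d,]*(?:\.\d+)?", search_space)
    return normalize_number_str(nums[-1]) if nums else None

response = "The total is <solution>2,00,000</solution>."
print(extract_guess(response))   # -> "200000"
print(extract_guess("I think the answer is 42"))  # -> "42" (fallback path)
```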
## 5. Training and (optional) validation loops

This is the last piece we need to put together to train our model. We have registered our model with the backend, defined a rollout function to generate responses for an input sample, and defined a function to evaluate a batch of rollouts. A training loop involves the following steps:

1. Sample the next batch from the dataset. A batch contains `n` samples.
2. For each sample in the batch, generate a `TrajectoryGroup` containing `m` rollouts, each wrapped in a `Trajectory` object.
3. Score the trajectories using the reward function for your use case.
4. Train the model on these scored trajectories.
5. (Optional) Delete old checkpoints to save space.

Similar to the training loop, we can kick off validation at some interval. For example, you could validate your model's performance on your validation data after every 50 steps of the training loop.

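A minimal sketch of that interval gating, assuming a 50-step cadence (the helper name and interval are illustrative choices, not part of the ART API):

```python
EVAL_EVERY = 50  # validate every 50 training steps (an arbitrary choice)

def should_validate(step, interval=EVAL_EVERY):
    # Trigger validation at step `interval`, `2 * interval`, and so on.
    return step > 0 and step % interval == 0

# Over 200 training steps, validation would run four times.
triggered = [step for step in range(1, 201) if should_validate(step)]
print(triggered)  # -> [50, 100, 150, 200]
```

Inside the training loop below, you would call your validation rollout-and-score routine whenever `should_validate(batch.step)` is true.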
Here is how the training loop looks when training with ART and the Serverless API:

```python
from art.utils import iterate_dataset

max_training_steps = 100  # Assumed stopping point; tune for your use case

# Create a dataset iterator for your dataset
iterator = iterate_dataset(
    training_inputs,
    groups_per_step=4,
    num_epochs=1,
    initial_step=0,
)

for batch in iterator:
    # 1. Sample the next batch
    print(f"\nTraining step {batch.step}")
    print(f"Batch contains {len(batch.items)} inputs")

    input_questions = []
    ground_truths = []
    train_groups = []
    rollouts_per_group = 4  # Number of rollouts per sample within a batch

    # 2. For each sample in the batch, generate the trajectories/rollouts
    for task_input in batch.items:
        for _ in range(rollouts_per_group):
            input_questions.append(task_input.question)
            ground_truths.append(task_input.answer)

        train_groups.append(
            art.TrajectoryGroup(
                rollout(client, model, task_input, batch.step, completion_args=completion_args)
                for _ in range(rollouts_per_group)
            )
        )
    finished_groups = await art.gather_trajectory_groups(
        train_groups,
        pbar_desc="Generating rollouts",
        max_exceptions=rollouts_per_group * len(batch.items),
    )

    # 3. Score the trajectories using the reward functions defined for your environment
    finished_groups = await evaluate_rollouts(
        finished_groups,
        ground_truths=ground_truths,
        questions=input_questions,
    )

    # 4. Train the model on the current set of trajectories
    await model.train(finished_groups, config=art.TrainConfig(learning_rate=3e-5))

    # 5. Delete old checkpoints
    await model.delete_checkpoints(best_checkpoint_metric="train/reward")
    if batch.step >= max_training_steps:
        print("Training complete!")
        break
```

## Next steps

For more resources and code examples, refer to these [notebooks](./notebooks.mdx).