# Provide an OpenAI Client training example with reinforcement learning #435
# OpenAI Client Example

This is a minimal example demonstrating how to use the OpenAI client to query an LLM endpoint and train a model with reinforcement learning using `verl`.

The dataset used is **GSM8K**, and the model is **Qwen2.5-1.5B-Instruct**. The script can be run on a single **A100 80GB GPU**.

## Quick Start

First, start a Ray cluster with the following command, replacing `XXXXX` with your own Weights & Biases (wandb) API key:

```bash
ray stop
env WANDB_API_KEY=XXXXX RAY_DEBUG=legacy HYDRA_FULL_ERROR=1 VLLM_USE_V1=1 ray start --head --dashboard-host=0.0.0.0
```

Then start the training:

```bash
python train.py
```

All LLM queries made by `gsm8k_agent` will be automatically recorded and used for training with the emitted rewards.
The training script (`train.py`, per the Quick Start above):
```python
import random
import re
from typing import TypedDict, cast

from datasets import load_dataset
from openai import AsyncOpenAI

import agentlightning as agl

verl_config = {
    "algorithm": {
        "adv_estimator": "grpo",
        "use_kl_in_reward": False,
    },
    "data": {
        "train_batch_size": 64,
        "max_prompt_length": 512,
        "max_response_length": 1024,
    },
    "actor_rollout_ref": {
        "rollout": {
            "tensor_model_parallel_size": 1,
            "n": 8,
            "log_prob_micro_batch_size_per_gpu": 4,
            "multi_turn": {"format": "hermes"},
            "name": "vllm",
            "gpu_memory_utilization": 0.6,
            "engine_kwargs": {
                "vllm": {
                    "enable_auto_tool_choice": True,
                    "tool_call_parser": "hermes",
                }
            },
        },
        "actor": {
            "ppo_mini_batch_size": 32,
            "ppo_micro_batch_size_per_gpu": 8,
            "optim": {"lr": 1e-6},
            "use_kl_loss": False,
            "kl_loss_coef": 0.0,
            "entropy_coeff": 0,
            "clip_ratio_low": 0.2,
            "clip_ratio_high": 0.28,
            "fsdp_config": {
                "param_offload": True,
                "optimizer_offload": True,
            },
        },
        "ref": {
            "log_prob_micro_batch_size_per_gpu": 8,
            "fsdp_config": {"param_offload": True},
        },
        "model": {
            "path": "Qwen/Qwen2.5-1.5B-Instruct",
            "use_remove_padding": True,
            "enable_gradient_checkpointing": True,
        },
    },
    "trainer": {
        "n_gpus_per_node": 1,
        "val_before_train": True,
        "critic_warmup": 0,
        "logger": ["console", "wandb"],
        "project_name": "AgentLightning",
        "experiment_name": "mini_rl_gsm8k",
        "nnodes": 1,
        "save_freq": 500,
        "test_freq": 25,
        "total_epochs": 2,
    },
}


class Gsm8kProblem(TypedDict):
    question: str
    answer: str


prompt_template = """
You are given the following question:

{}

Please think step by step and put your final answer after ####.

Output example:

<thinking process>
#### <your answer>
""".strip()


@agl.rollout
async def gsm8k_agent(task: Gsm8kProblem, llm: agl.LLM) -> None:
    # Collect LLM endpoint information.
    # Temperature will differ between rollout and validation.
    model = llm.model
    openai_base_url = llm.endpoint
    temperature = llm.sampling_parameters.get("temperature", 1.0)

    client = AsyncOpenAI(
        # ... (the diff is truncated here; the rest of the 162-line file is not shown)
```
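The diff cuts off before the answer-extraction and reward logic. As a rough, dependency-free sketch (an assumption, not the PR's actual code) of how a binary reward could be computed from the `####` format the prompt requests, with the regex compiled once at module level as the review below also recommends:

```python
import re
from typing import Optional

# Hypothetical pattern: capture whatever follows "####" on the same line.
ANSWER_PATTERN = re.compile(r"####\s*([^\n]+)")


def extract_answer(text: str) -> Optional[str]:
    """Return the text after '####', or None if the marker is absent."""
    match = ANSWER_PATTERN.search(text)
    return match.group(1).strip() if match else None


def compute_reward(model_output: str, ground_truth: str) -> float:
    """Binary reward: 1.0 when the extracted answers match, else 0.0."""
    predicted = extract_answer(model_output)
    expected = extract_answer(ground_truth)
    if predicted is None or expected is None:
        return 0.0
    # Normalize common formatting differences such as thousands separators.
    return 1.0 if predicted.replace(",", "") == expected.replace(",", "") else 0.0
```

GSM8K ground-truth answers also end in `#### <number>`, which is why the same extraction can be applied to both sides.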
## Review comments (Copilot, Dec 26, 2025)

**Suggested change** — document why a dummy API key is acceptable at the client creation:

```python
client = AsyncOpenAI(
    # Using a dummy API key is fine here because we are querying a local LLM proxy
    # endpoint that does not require a real OpenAI API key.
```

**Client reuse.** The `AsyncOpenAI` client is instantiated on every function call. Consider moving client creation outside the function, or reusing a single client instance, to avoid the overhead of creating new clients repeatedly.
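A dependency-free sketch of the reuse pattern the reviewer suggests: cache one client per endpoint URL so repeated rollouts share it. `_FakeClient` is a stand-in for `AsyncOpenAI` to keep the example self-contained.

```python
from functools import lru_cache


class _FakeClient:
    """Stand-in for AsyncOpenAI so this sketch has no external dependencies."""

    def __init__(self, base_url: str, api_key: str) -> None:
        self.base_url = base_url
        self.api_key = api_key


@lru_cache(maxsize=None)
def get_client(base_url: str) -> _FakeClient:
    # One client per endpoint URL; repeated calls reuse the cached instance.
    return _FakeClient(base_url=base_url, api_key="dummy")
```

With this pattern, the agent would call `get_client(llm.endpoint)` instead of constructing a new client inside every rollout.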
**Regex compilation.** The regex pattern is compiled on every function call. For better performance, consider defining it as a module-level constant, since it is static and reused across all invocations.
**Broad exception catching.** Catching `except Exception as e` can mask unexpected errors and make debugging difficult. Consider catching more specific exceptions (e.g., OpenAI-specific exceptions or network errors) to handle different failure modes appropriately.
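A dependency-free sketch of the narrower-catch pattern. `ConnectionError` and `TimeoutError` stand in here for the client's real failure modes; in practice the `openai` library's own exception classes would be caught instead.

```python
def call_with_handling(fn):
    # Catch only the failure modes expected from a network call, instead of a
    # bare `except Exception`, so programming errors still surface loudly.
    try:
        return fn()
    except (ConnectionError, TimeoutError) as e:
        print(f"Transient network failure: {e}")
        return None
```

A genuine bug such as a `ValueError` inside `fn` still propagates, which is exactly the behavior the reviewer is asking for.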
**Error-message context.** The error message only prints the exception string, without logging which step failed or what input caused the failure. Consider including context such as the question being processed. Suggested change:

```python
# Before
print("Failure:", str(e))
# After
print(
    f"Failure while processing question: {task['question']!r}. Error: {e}"
)
```
**Unchecked `re.search` on the ground truth.** `re.search` returns `None` when the pattern is not found in the ground-truth answer, which would cause an `AttributeError` on `.group(1)`. Handle this the same way the model-answer extraction is handled. Suggested change:

```python
# Before
gt_answer = re.search(regex_pattern, task["answer"]).group(1)
# After
gt_match = re.search(regex_pattern, task["answer"])
if gt_match:
    gt_answer = gt_match.group(1)
else:
    gt_answer = task["answer"]
```
**Possible `UnboundLocalError`.** If an exception is raised before `last_message` is first assigned inside the `try` block, the `except` handler's print statement will reference an undefined `last_message`, potentially raising an `UnboundLocalError`. Initialize the variable before the `try` block.
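A small sketch of the fix, assuming a simplified shape of the agent's try/except: define `last_message` before the `try` so the handler and any later code can always reference it.

```python
def run_step(query):
    last_message = "None"  # initialized before the try, so the except path is safe
    try:
        last_message = query()
    except (ConnectionError, ValueError) as e:
        # last_message is guaranteed to exist here, even if query() failed early.
        print(f"Failure: {e}")
    return last_message
```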
**Capitalization of `verl`.** The reference to `verl` in the description is lowercase, but it appears to be a proper noun referring to the VERL reinforcement learning framework. Consider using consistent capitalization (VERL) throughout the documentation.