25 changes: 25 additions & 0 deletions examples/openai_client/README.md
@@ -0,0 +1,25 @@
# OpenAI Client Example

This is a minimal example demonstrating how to use the OpenAI client to query an LLM endpoint and train a model with reinforcement learning using `verl`.
Copilot AI (Dec 26, 2025): The reference to 'verl' in the description is lowercase, but it appears to be a proper noun referring to the VERL reinforcement learning framework. Consider using consistent capitalization (VERL) throughout the documentation.

Suggested change:
-This is a minimal example demonstrating how to use the OpenAI client to query an LLM endpoint and train a model with reinforcement learning using `verl`.
+This is a minimal example demonstrating how to use the OpenAI client to query an LLM endpoint and train a model with reinforcement learning using `VERL`.

The dataset used is **GSM8K**, and the model is **Qwen2.5-1.5B-Instruct**.
The script can be run on a single **A100 80GB GPU**.


## Quick Start

First, start a Ray cluster with the following command.
Replace `XXXXX` with your own Weights & Biases (wandb) API key.

```bash
ray stop
env WANDB_API_KEY=XXXXX RAY_DEBUG=legacy HYDRA_FULL_ERROR=1 VLLM_USE_V1=1 ray start --head --dashboard-host=0.0.0.0
```

Copilot AI (Dec 26, 2025): The ray start command uses `--dashboard-host=0.0.0.0`, which exposes the Ray dashboard on all network interfaces and can allow anyone on the network to access and control the Ray cluster (including running arbitrary code) if additional protections are not in place. In environments where this command is copied as-is (e.g., shared clusters or cloud VMs), this creates a real risk of remote compromise. Consider binding the dashboard to 127.0.0.1 by default, or explicitly documenting that 0.0.0.0 should only be used behind proper network access controls (e.g., firewall, SSH tunnel, or authenticated proxy).
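A local-only variant of the same command, as one way to follow that advice (a sketch; `XXXXX` remains a placeholder for your wandb key):

```bash
ray stop
env WANDB_API_KEY=XXXXX RAY_DEBUG=legacy HYDRA_FULL_ERROR=1 VLLM_USE_V1=1 ray start --head --dashboard-host=127.0.0.1
```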

Then start the training:

```bash
python train.py
```

All LLM queries made by `gsm8k_agent` will be automatically recorded and used for training with the emitted rewards.
162 changes: 162 additions & 0 deletions examples/openai_client/train.py
@@ -0,0 +1,162 @@
import random
import re
from typing import TypedDict, cast

from datasets import load_dataset
from openai import AsyncOpenAI

import agentlightning as agl

verl_config = {
"algorithm": {
"adv_estimator": "grpo",
"use_kl_in_reward": False,
},
"data": {
"train_batch_size": 64,
"max_prompt_length": 512,
"max_response_length": 1024,
},
"actor_rollout_ref": {
"rollout": {
"tensor_model_parallel_size": 1,
"n": 8,
"log_prob_micro_batch_size_per_gpu": 4,
"multi_turn": {"format": "hermes"},
"name": "vllm",
"gpu_memory_utilization": 0.6,
"engine_kwargs": {
"vllm": {
"enable_auto_tool_choice": True,
"tool_call_parser": "hermes",
}
},
},
"actor": {
"ppo_mini_batch_size": 32,
"ppo_micro_batch_size_per_gpu": 8,
"optim": {"lr": 1e-6},
"use_kl_loss": False,
"kl_loss_coef": 0.0,
"entropy_coeff": 0,
"clip_ratio_low": 0.2,
"clip_ratio_high": 0.28,
"fsdp_config": {
"param_offload": True,
"optimizer_offload": True,
},
},
"ref": {
"log_prob_micro_batch_size_per_gpu": 8,
"fsdp_config": {"param_offload": True},
},
"model": {
"path": "Qwen/Qwen2.5-1.5B-Instruct",
"use_remove_padding": True,
"enable_gradient_checkpointing": True,
},
},
"trainer": {
"n_gpus_per_node": 1,
"val_before_train": True,
"critic_warmup": 0,
"logger": ["console", "wandb"],
"project_name": "AgentLightning",
"experiment_name": "mini_rl_gsm8k",
"nnodes": 1,
"save_freq": 500,
"test_freq": 25,
"total_epochs": 2,
},
}


class Gsm8kProblem(TypedDict):
question: str
answer: str


prompt_template = """
You are given the following question:

{}

Please think step by step and put your final answer after ####.

Output example:

<thinking process>
#### <your answer>
""".strip()


@agl.rollout
async def gsm8k_agent(task: Gsm8kProblem, llm: agl.LLM) -> None:
# Collect llm endpoint information
# Temperature will be different for rollout and validation.
model = llm.model
openai_base_url = llm.endpoint
temperature = llm.sampling_parameters.get("temperature", 1.0)

client = AsyncOpenAI(
        api_key="dummy",
        base_url=openai_base_url,
    )

Copilot AI (Dec 26, 2025): Hardcoding 'dummy' as the API key could be misleading for users who might expect they need to provide a valid API key. Consider adding a comment explaining why this is acceptable in this context (because it's querying a local endpoint).

Suggested change:
-    client = AsyncOpenAI(
+    client = AsyncOpenAI(
+        # Using a dummy API key is fine here because we are querying a local LLM proxy
+        # endpoint that does not require a real OpenAI API key.
Copilot AI (Dec 26, 2025), on the `AsyncOpenAI(...)` call: The AsyncOpenAI client is instantiated on every function call. Consider moving client creation outside the function or reusing a single client instance to avoid the overhead of creating new clients repeatedly.
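One possible shape for that reuse (a sketch, not part of the PR; `get_client` is a hypothetical helper, and it assumes the endpoint URL is stable enough to cache per base URL):

```python
from functools import lru_cache

from openai import AsyncOpenAI


@lru_cache(maxsize=None)
def get_client(base_url: str) -> AsyncOpenAI:
    # Build one AsyncOpenAI client (and its connection pool) per endpoint and
    # reuse it across rollouts instead of constructing a new client each call.
    return AsyncOpenAI(api_key="dummy", base_url=base_url)
```

Inside the rollout, `client = get_client(openai_base_url)` would then replace the per-call construction.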
regex_pattern = r"####\s*(.+)(\s*|$)"
Copilot AI (Dec 26, 2025): The regex pattern is compiled on every function call. For better performance, consider defining this pattern as a module-level constant since it's static and reused across all invocations.
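A sketch of that suggestion (`ANSWER_PATTERN` is a hypothetical name; note `re` also caches compiled patterns internally, so the gain is mostly clarity):

```python
import re

# Compile the answer-extraction pattern once at module level; every rollout
# then reuses the same compiled regex via ANSWER_PATTERN.search(...).
ANSWER_PATTERN = re.compile(r"####\s*(.+)(\s*|$)")
```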

# Query LLM endpoint. All queries will be automatically tracked by LLM proxy
try:
prompt = prompt_template.format(task["question"])
messages = [{"role": "user", "content": prompt}]
response = await client.chat.completions.create(
model=model,
temperature=temperature,
messages=messages,
)
last_message = response.choices[0].message.content

answer = re.search(regex_pattern, last_message)
if answer:
answer = answer.group(1)
else:
answer = last_message
except Exception as e:
Copilot AI (Dec 26, 2025): Using broad exception catching with 'except Exception as e' can mask unexpected errors and make debugging difficult. Consider catching more specific exceptions (e.g., OpenAI-specific exceptions, network errors) to handle different failure modes appropriately.
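A narrower `except` clause might look like this (a sketch showing only the call and the error handling, assuming the `openai>=1.0` exception hierarchy; anything outside these classes would propagate):

```python
import openai

try:
    response = await client.chat.completions.create(
        model=model, temperature=temperature, messages=messages
    )
except (openai.APIConnectionError, openai.RateLimitError, openai.APIStatusError) as e:
    # Known transport/API failures: log and fall back to a sentinel response.
    print("Failure:", str(e))
    last_message = "None"
    answer = "None"
```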
print("Failure:", str(e))
Copilot AI (Dec 26, 2025): The error message only prints the exception string without logging which step failed or what input caused the failure. Consider improving the error message to include context such as the question being processed for better debugging.

Suggested change:
-        print("Failure:", str(e))
+        print(
+            f"Failure while processing question: {task['question']!r}. Error: {e}"
+        )
last_message = "None"
answer = "None"
gt_answer = re.search(regex_pattern, task["answer"]).group(1)
Copilot AI (Dec 26, 2025): The regular expression search can return None if the pattern is not found in the ground truth answer, which would cause an AttributeError when calling .group(1). This should be handled similarly to how the answer extraction is handled on lines 118-122.

Suggested change:
-    gt_answer = re.search(regex_pattern, task["answer"]).group(1)
+    gt_match = re.search(regex_pattern, task["answer"])
+    if gt_match:
+        gt_answer = gt_match.group(1)
+    else:
+        gt_answer = task["answer"]

# Exact matching for verifiable rewards
if gt_answer == answer:
reward = 1
else:
reward = 0

# This reward will be tracked automatically
agl.emit_reward(reward)

# Log some responses for better clarity
if random.random() < 0.01:
print(
f"--------\nQuestion: {task['question']}\nResponse: {last_message}\nGround Truth: {gt_answer}\nReward: {reward}\n"
        )

Copilot AI (Dec 26, 2025): The variable 'last_message' may not be defined if an exception occurs before line 116. When the exception is caught on line 123, the code on line 140 will reference 'last_message' in the print statement, potentially causing an UnboundLocalError.
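A defensive variant that sidesteps the concern entirely (a sketch; assign sentinel defaults before the `try` so later reads can never hit an unbound local):

```python
    # Sentinel defaults assigned before the try block: even if the request
    # fails at any point, last_message and answer are always defined below.
    last_message = "None"
    answer = "None"
```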


if __name__ == "__main__":
# Create dataset for training and validation
ds = load_dataset("openai/gsm8k", "main")
train_dataset = cast(agl.Dataset[Gsm8kProblem], ds["train"].to_list())
val_dataset = cast(agl.Dataset[Gsm8kProblem], ds["test"].to_list())

algorithm = agl.VERL(verl_config)
# Number of agents launched in parallel to query the LLM.
# This parameter strongly affects throughput and efficiency:
# higher parallelism improves utilization but increases GPU overhead.
n_runners = 32
    # This tracer is a placeholder; tracing is currently handled by the LLM proxy.
tracer = agl.OtelTracer()
adapter = agl.LlmProxyTraceToTriplet()
# Set store=None to use managed store
trainer = agl.Trainer(algorithm=algorithm, n_runners=n_runners, store=None, tracer=tracer, adapter=adapter)

trainer.fit(gsm8k_agent, train_dataset, val_dataset=val_dataset)