Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions examples/ultrarag_search_r1/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# UltraRAG Search-R1 Example

## Overview
This example trains a Search-R1 style agent using the UltraRAG pipeline inside Agent Lightning. It reuses the `examples/search_r1` dataset and shows how to run end-to-end RL with Ray + vLLM.

## Included Files
| File/Directory | Description |
| --- | --- |
| `train.sh` | Launch RL training (Ray + vLLM) |
| `ultrarag_adapter.py` | UltraRAG-aware agent adapter |
| `search_r1_rl.yaml` | UltraRAG pipeline config for RL |
| `search_r1_rl_parameter.yaml` | UltraRAG parameter config |
| `requirements-ultrarag.txt` | Notes on installing deps via groups |

---

## Prepare Environment
From repo root:
```bash
uv pip install -e . --group torch-gpu-stable --group ultrarag
```
Data: expected under `examples/search_r1/data` (train/val parquet).
Base model: set `BASE_MODEL` (e.g., Llama-3.2-3B-Instruct).

---

## Run Training
1) Start Ray
```bash
bash scripts/restart_ray.sh
```
2) Run training
```bash
cd examples/ultrarag_search_r1
bash train.sh
```
Env overrides: `BASE_MODEL`, `DATA_DIR`, `RAY_ADDRESS`, `CUDA_VISIBLE_DEVICES`, `VLLM_PORT`, etc.

Optional sanity check (adapter import only):
```bash
cd examples/ultrarag_search_r1
python ultrarag_adapter.py
```

---

## Notes
- Validation runs before training and every `test_freq` steps (see `train.sh`).
- Checkpoints and validation results are written under `checkpoints/ultrarag_search_r1_checkpoints/`.
134 changes: 134 additions & 0 deletions examples/ultrarag_search_r1/qa_em.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) Microsoft. All rights reserved.

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import re
import string
from typing import Mapping, Optional, Sequence, Union


def normalize_answer(s: str) -> str:
    """Lowercase, remove punctuation/articles, and normalize whitespace.

    The steps run in this exact order: lowercase, strip punctuation,
    drop the articles a/an/the, then collapse runs of whitespace.
    """
    lowered = s.lower()
    # Strip every ASCII punctuation character.
    no_punct = "".join(ch for ch in lowered if ch not in string.punctuation)
    # Replace standalone articles with a space; the final join cleans up.
    no_articles = re.sub(r"\b(a|an|the)\b", " ", no_punct)
    return " ".join(no_articles.split())


def em_check(prediction: str, golden_answers: Union[str, Sequence[str]]) -> int:
    """Return 1 if the normalized prediction exactly equals any golden answer, else 0."""
    if isinstance(golden_answers, str):
        golden_answers = [golden_answers]
    pred_norm = normalize_answer(prediction)
    # int() keeps the original 0/1 integer return type.
    return int(any(normalize_answer(candidate) == pred_norm for candidate in golden_answers))


def subem_check(prediction: str, golden_answers: Union[str, Sequence[str]]) -> int:
    """Return 1 if any normalized golden answer occurs as a substring of the normalized prediction, else 0."""
    if isinstance(golden_answers, str):
        golden_answers = [golden_answers]
    pred_norm = normalize_answer(prediction)
    # Substring containment, not equality — this is the "sub-EM" relaxation.
    return int(any(normalize_answer(candidate) in pred_norm for candidate in golden_answers))


def extract_solution(solution_str: str) -> Optional[str]:
    """Extract the last <answer>...</answer> span from a solution string.

    Returns None if fewer than two such spans are present, to match
    original behavior.
    """
    # re.DOTALL lets an answer span multiple lines.
    spans = re.findall(r"<answer>(.*?)</answer>", solution_str, re.DOTALL)

    # Fewer than two spans (zero or one) -> no extractable answer.
    if len(spans) < 2:
        return None

    return spans[-1].strip()


def compute_score_em(
    solution_str: str,
    ground_truth: Union[str, Sequence[str]],
    method: str = "strict",
    format_score: float = 0.0,
    score: float = 1.0,
) -> float:
    """Scoring function for exact match (EM).

    Returns 0.0 when no answer span is extractable, `score` on an exact
    match, and `format_score` otherwise. `method` is accepted for
    interface compatibility but not used here.
    """
    extracted = extract_solution(solution_str=solution_str)

    # Log roughly 1-in-64 samples to avoid flooding stdout during training.
    if random.randint(1, 64) == 1:
        print("--------------------------------")
        print(f"Golden answers: {ground_truth}")
        print(f"Extracted answer: {extracted}")
        print(f"Solution string: {solution_str}")

    if extracted is None:
        return 0.0
    return score if em_check(extracted, ground_truth) else format_score


def compute_score_subem(
    solution_str: str,
    ground_truth: Mapping[str, Union[str, Sequence[str]]],
    method: str = "strict",
    format_score: float = 0.0,
    score: float = 1.0,
) -> float:
    """Scoring function for substring exact match (sub-EM).

    Unlike `compute_score_em`, `ground_truth` is a mapping whose "target"
    entry holds the golden answer(s). Returns 0.0 when no answer span is
    extractable, `score` on a substring match, and `format_score`
    otherwise. `method` is accepted for interface compatibility but not
    used here.
    """
    extracted = extract_solution(solution_str=solution_str)

    # Log roughly 1-in-64 samples to avoid flooding stdout during training.
    if random.randint(1, 64) == 1:
        print("--------------------------------")
        print(f"Golden answers: {ground_truth['target']}")
        print(f"Extracted answer: {extracted}")
        print(f"Solution string: {solution_str}")

    if extracted is None:
        return 0.0
    return score if subem_check(extracted, ground_truth["target"]) else format_score
55 changes: 55 additions & 0 deletions examples/ultrarag_search_r1/search_r1_rl_parameter.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# UltraRAG parameter configuration for the Search-R1 RL example.
# NOTE(review): model/corpus/index paths below are placeholders — set them
# to real local paths before running.
custom: {}
# Generation (LLM) backend configuration.
generation:
  backend: vllm
  backend_configs:
    vllm:
      model_name_or_path: /path/to/model  # placeholder — override before use
      trust_remote_code: true
      dtype: auto
      gpu_ids: 0
      gpu_memory_utilization: 0.4
      sampling_params:
        max_tokens: 4096
        temperature: 0.7
        top_p: 0.8
  system_prompt: ''
# Jinja prompt templates used by the pipeline.
prompt:
  template: prompt/qa_boxed.jinja
  r1_searcher_gen_template: prompt/r1_searcher_append.jinja
# Retriever configuration; only the backend selected above
# (sentence_transformers) is active — the others are alternatives.
retriever:
  backend: sentence_transformers
  backend_configs:
    sentence_transformers:
      trust_remote_code: true
      sentence_transformers_encode:
        encode_chunk_size: 256
        normalize_embeddings: false
        psg_prompt_name: document
        psg_task: null
        q_prompt_name: query
        q_task: null
    openai:
      api_key: ''
      base_url: https://api.openai.com/v1
      model_name: text-embedding-3-small
    infinity:
      bettertransformer: false
      model_warmup: false
      pooling_method: auto
      trust_remote_code: true
    bm25:
      lang: en
      save_path: index/bm25
  batch_size: 8
  corpus_path: data/wiki18_ultra.jsonl  # placeholder — point at the retrieval corpus
  gpu_ids: 0
  index_backend: faiss
  index_backend_configs:
    faiss:
      index_chunk_size: 256
  index_path: data/e5_Flat.index  # placeholder — prebuilt FAISS index
  index_use_gpu: true
  is_multimodal: false
  model_name_or_path: intfloat/e5-base-v2
  query_instruction: ''
  top_k: 5  # number of passages returned per query
31 changes: 31 additions & 0 deletions examples/ultrarag_search_r1/search_r1_rl_server.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Search-R1 RL Training Pipeline (no benchmark/eval)

# MCP Server
# Maps each server name to its on-disk server directory.
servers:
  generation: servers/generation
  retriever: servers/retriever
  prompt: servers/prompt
  router: servers/router
  custom: servers/custom

# MCP Client Pipeline (RL mode)
# Steps run top to bottom; the loop re-queries the retriever until the
# router reports completion or the iteration cap is hit.
pipeline:
- retriever.retriever_init
- generation.generation_init
- prompt.qa_boxed
- generation.generate
- loop:
    times: 8  # max search/generate iterations per sample
    steps:
    - branch:
        router:
        - router.search_r1_check  # decides whether the answer is complete
        branches:
          incomplete:
          # Extract the pending query, retrieve passages, and generate again.
          - custom.search_r1_query_extract
          - retriever.retriever_search:
              input:
                query_list: extract_query_list
          - prompt.search_r1_gen
          - generation.generate
          complete: []  # nothing more to do — exit the branch
74 changes: 74 additions & 0 deletions examples/ultrarag_search_r1/train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#!/bin/bash
# Launch Search-R1 RL training via the Agent Lightning verl entry point.
# Overridable env vars: BASE_MODEL, DATA_DIR, CUDA_VISIBLE_DEVICES, N_GPUS,
# VLLM_PORT, VLLM_HOST, EXPERIMENT_NAME, PROJECT_NAME.
set -e

# Resolve this script's directory and the repository root (two levels up).
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
AGL_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"

# Prefer the repo's uv-managed virtualenv; fall back to system python.
if [ -f "${AGL_ROOT}/.venv/bin/python" ]; then
    PYTHON="${AGL_ROOT}/.venv/bin/python"
    echo "Using uv virtual environment: ${PYTHON}"
else
    PYTHON="python"
    echo "Warning: uv virtual environment not found at ${AGL_ROOT}/.venv/bin/python"
    echo "Using system python. Make sure all dependencies are installed."
fi

# GPU / server settings (all overridable from the environment).
export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-3,4,5}
export N_GPUS=${N_GPUS:-2}
export VLLM_PORT=${VLLM_PORT:-8001}
export VLLM_HOST=${VLLM_HOST:-127.0.0.1}
export PYTORCH_CUDA_ALLOC_CONF=${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}

# Model, data, and experiment naming (BASE_MODEL default is a placeholder).
export BASE_MODEL=${BASE_MODEL:-/path/to/llama-3.2-3b-instruct}
export DATA_DIR=${DATA_DIR:-${SCRIPT_DIR}/../search_r1/data}
export EXPERIMENT_NAME=${EXPERIMENT_NAME:-ultrarag_search_r1}
export PROJECT_NAME=${PROJECT_NAME:-AgentLightning-ultrarag}

echo "Using GPUs: $CUDA_VISIBLE_DEVICES"
echo "Number of GPUs: $N_GPUS"
echo "Data dir: $DATA_DIR"
echo "Base model: $BASE_MODEL"

# Run training from the example directory; PYTHONPATH exposes the local
# adapter module to the trainer. The single command below configures the
# GRPO trainer (algorithm, data, rollout, FSDP offload, logging, schedule).
cd "${SCRIPT_DIR}"
PYTHONPATH="${SCRIPT_DIR}" ${PYTHON} -m agentlightning.verl \
    algorithm.adv_estimator=grpo \
    data.train_files=${DATA_DIR}/train.parquet \
    data.val_files=${DATA_DIR}/test_100.parquet \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    trainer.n_gpus_per_node=${N_GPUS} \
    data.train_batch_size=32 \
    actor_rollout_ref.rollout.n=2 \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.model.path=${BASE_MODEL} \
    data.max_prompt_length=4096 \
    data.max_response_length=4096 \
    data.truncation='error' \
    trainer.val_before_train=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.95 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.use_kl_loss=false \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.clip_ratio_low=0.2 \
    actor_rollout_ref.actor.clip_ratio_high=0.3 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.default_local_dir=checkpoints/ultrarag_search_r1_checkpoints/$EXPERIMENT_NAME \
    trainer.project_name=${PROJECT_NAME} \
    trainer.experiment_name=${EXPERIMENT_NAME} \
    trainer.nnodes=1 \
    trainer.save_freq=10 \
    trainer.test_freq=10 \
    trainer.total_epochs=15 \
    trainer.total_training_steps=300
Loading