diff --git a/examples/ultrarag_search_r1/README.md b/examples/ultrarag_search_r1/README.md
new file mode 100644
index 000000000..43b883a22
--- /dev/null
+++ b/examples/ultrarag_search_r1/README.md
@@ -0,0 +1,49 @@
+# UltraRAG Search-R1 Example
+
+## Overview
+This example trains a Search-R1 style agent using the UltraRAG pipeline inside Agent Lightning. It reuses the `examples/search_r1` dataset and shows how to run end-to-end RL with Ray + vLLM.
+
+## Included Files
+| File/Directory | Description |
+| --- | --- |
+| `train.sh` | Launch RL training (Ray + vLLM) |
+| `ultrarag_adapter.py` | UltraRAG-aware agent adapter |
+| `qa_em.py` | Exact-match (EM) reward scoring used by the adapter |
+| `search_r1_rl.yaml` | UltraRAG pipeline config for RL |
+| `search_r1_rl_parameter.yaml` | UltraRAG parameter config |
+| `requirements-ultrarag.txt` | Notes on installing deps via groups |
+
+---
+
+## Prepare Environment
+From repo root:
+```bash
+uv pip install -e . --group torch-gpu-stable --group ultrarag
+```
+Data: expected under `examples/search_r1/data` (`train.parquet` and `test_100.parquet`, as referenced by `train.sh`).
+Base model: set `BASE_MODEL` to a local model path (e.g., Llama-3.2-3B-Instruct).
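+
+For example (the paths below are illustrative; adjust them to your setup):
+```bash
+export BASE_MODEL=/path/to/Llama-3.2-3B-Instruct
+export DATA_DIR="$(pwd)/examples/search_r1/data"
+ls "$DATA_DIR"   # verify the parquet files are present
+```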
+
+---
+
+## Run Training
+1) Start Ray
+```bash
+bash scripts/restart_ray.sh
+```
+2) Run training
+```bash
+cd examples/ultrarag_search_r1
+bash train.sh
+```
+Env overrides: `BASE_MODEL`, `DATA_DIR`, `RAY_ADDRESS`, `CUDA_VISIBLE_DEVICES`, `VLLM_PORT`, etc.
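+
+Most of these fall back to `${VAR:-default}` defaults inside `train.sh`, so one-off values can be passed inline (placeholder values shown):
+```bash
+CUDA_VISIBLE_DEVICES=0,1 N_GPUS=2 BASE_MODEL=/models/Llama-3.2-3B-Instruct bash train.sh
+```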
+
+Optional sanity check (runs the adapter's `__main__`, which expects an Agent Lightning server at `http://localhost:9999/`):
+```bash
+cd examples/ultrarag_search_r1
+python ultrarag_adapter.py
+```
+
+---
+
+## Notes
+- Validation runs before training and every `test_freq` steps (see `train.sh`).
+- Checkpoints and validation results are written under `checkpoints/ultrarag_search_r1_checkpoints/`.
diff --git a/examples/ultrarag_search_r1/qa_em.py b/examples/ultrarag_search_r1/qa_em.py
new file mode 100644
index 000000000..48617605f
--- /dev/null
+++ b/examples/ultrarag_search_r1/qa_em.py
@@ -0,0 +1,134 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import re
+import string
+from typing import Mapping, Optional, Sequence, Union
+
+
+def normalize_answer(s: str) -> str:
+ """Lowercase, remove punctuation/articles, and normalize whitespace."""
+
+ def remove_articles(text: str) -> str:
+ return re.sub(r"\b(a|an|the)\b", " ", text)
+
+ def white_space_fix(text: str) -> str:
+ return " ".join(text.split())
+
+ def remove_punc(text: str) -> str:
+ exclude = set(string.punctuation)
+ return "".join(ch for ch in text if ch not in exclude)
+
+ def lower(text: str) -> str:
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def em_check(prediction: str, golden_answers: Union[str, Sequence[str]]) -> int:
+ if isinstance(golden_answers, str):
+ golden_answers = [golden_answers]
+ normalized_prediction = normalize_answer(prediction)
+ score = 0
+ for golden_answer in golden_answers:
+ golden_answer = normalize_answer(golden_answer)
+ if golden_answer == normalized_prediction:
+ score = 1
+ break
+ return score
+
+
+def subem_check(prediction: str, golden_answers: Union[str, Sequence[str]]) -> int:
+ if isinstance(golden_answers, str):
+ golden_answers = [golden_answers]
+ normalized_prediction = normalize_answer(prediction)
+ score = 0
+ for golden_answer in golden_answers:
+ golden_answer = normalize_answer(golden_answer)
+ if golden_answer in normalized_prediction:
+ score = 1
+ break
+ return score
+
+
+def extract_solution(solution_str: str) -> Optional[str]:
+ """Extract the last ... span from a solution string.
+
+ Returns None if fewer than two such spans are present, to match original behavior.
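+
+    Example: "<answer>Paris</answer> ... <answer>London</answer>" -> "London";
+    a string with a single <answer> span -> None.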
+ """
+    answer_pattern = r"<answer>(.*?)</answer>"
+ match_iter = re.finditer(answer_pattern, solution_str, re.DOTALL)
+ matches = list(match_iter)
+
+ # If there are 0 or exactly 1 matches, return None
+ if len(matches) <= 1:
+ return None
+
+ # If there are 2 or more matches, return the last one
+ return matches[-1].group(1).strip()
+
+
+def compute_score_em(
+ solution_str: str,
+ ground_truth: Union[str, Sequence[str]],
+ method: str = "strict",
+ format_score: float = 0.0,
+ score: float = 1.0,
+) -> float:
+ """Scoring function for exact match (EM)."""
+ answer = extract_solution(solution_str=solution_str)
+ do_print = random.randint(1, 64) == 1
+
+ if do_print:
+ print(f"--------------------------------")
+ print(f"Golden answers: {ground_truth}")
+ print(f"Extracted answer: {answer}")
+ print(f"Solution string: {solution_str}")
+
+ if answer is None:
+ return 0.0
+ else:
+ if em_check(answer, ground_truth):
+ return score
+ else:
+ return format_score
+
+
+def compute_score_subem(
+ solution_str: str,
+ ground_truth: Mapping[str, Union[str, Sequence[str]]],
+ method: str = "strict",
+ format_score: float = 0.0,
+ score: float = 1.0,
+) -> float:
+ """Scoring function for substring exact match (EM)."""
+ answer = extract_solution(solution_str=solution_str)
+ do_print = random.randint(1, 64) == 1
+
+ if do_print:
+ print(f"--------------------------------")
+ print(f"Golden answers: {ground_truth['target']}")
+ print(f"Extracted answer: {answer}")
+ print(f"Solution string: {solution_str}")
+
+ if answer is None:
+ return 0.0
+ else:
+ if subem_check(answer, ground_truth["target"]):
+ return score
+ else:
+ return format_score
diff --git a/examples/ultrarag_search_r1/search_r1_rl_parameter.yaml b/examples/ultrarag_search_r1/search_r1_rl_parameter.yaml
new file mode 100644
index 000000000..70d48e451
--- /dev/null
+++ b/examples/ultrarag_search_r1/search_r1_rl_parameter.yaml
@@ -0,0 +1,55 @@
+custom: {}
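+# NOTE: model_name_or_path, corpus_path, and index_path below are placeholders/defaults;
+# point them at your local model, corpus, and index files.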
+generation:
+ backend: vllm
+ backend_configs:
+ vllm:
+ model_name_or_path: /path/to/model
+ trust_remote_code: true
+ dtype: auto
+ gpu_ids: 0
+ gpu_memory_utilization: 0.4
+ sampling_params:
+ max_tokens: 4096
+ temperature: 0.7
+ top_p: 0.8
+ system_prompt: ''
+prompt:
+ template: prompt/qa_boxed.jinja
+ r1_searcher_gen_template: prompt/r1_searcher_append.jinja
+retriever:
+ backend: sentence_transformers
+ backend_configs:
+ sentence_transformers:
+ trust_remote_code: true
+ sentence_transformers_encode:
+ encode_chunk_size: 256
+ normalize_embeddings: false
+ psg_prompt_name: document
+ psg_task: null
+ q_prompt_name: query
+ q_task: null
+ openai:
+ api_key: ''
+ base_url: https://api.openai.com/v1
+ model_name: text-embedding-3-small
+ infinity:
+ bettertransformer: false
+ model_warmup: false
+ pooling_method: auto
+ trust_remote_code: true
+ bm25:
+ lang: en
+ save_path: index/bm25
+ batch_size: 8
+ corpus_path: data/wiki18_ultra.jsonl
+ gpu_ids: 0
+ index_backend: faiss
+ index_backend_configs:
+ faiss:
+ index_chunk_size: 256
+ index_path: data/e5_Flat.index
+ index_use_gpu: true
+ is_multimodal: false
+ model_name_or_path: intfloat/e5-base-v2
+ query_instruction: ''
+ top_k: 5
diff --git a/examples/ultrarag_search_r1/search_r1_rl_server.yaml b/examples/ultrarag_search_r1/search_r1_rl_server.yaml
new file mode 100644
index 000000000..ee4ae8d1e
--- /dev/null
+++ b/examples/ultrarag_search_r1/search_r1_rl_server.yaml
@@ -0,0 +1,31 @@
+# Search-R1 RL Training Pipeline (no benchmark/eval)
+
+# MCP Server
+servers:
+ generation: servers/generation
+ retriever: servers/retriever
+ prompt: servers/prompt
+ router: servers/router
+ custom: servers/custom
+
+# MCP Client Pipeline (RL mode)
+pipeline:
+- retriever.retriever_init
+- generation.generation_init
+- prompt.qa_boxed
+- generation.generate
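+# Up to 8 follow-up rounds: router.search_r1_check routes each round either to
+# another query-extract -> retrieve -> generate pass (incomplete) or stops (complete).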
+- loop:
+ times: 8
+ steps:
+ - branch:
+ router:
+ - router.search_r1_check
+ branches:
+ incomplete:
+ - custom.search_r1_query_extract
+ - retriever.retriever_search:
+ input:
+ query_list: extract_query_list
+ - prompt.search_r1_gen
+ - generation.generate
+ complete: []
diff --git a/examples/ultrarag_search_r1/train.sh b/examples/ultrarag_search_r1/train.sh
new file mode 100755
index 000000000..33c7be2a1
--- /dev/null
+++ b/examples/ultrarag_search_r1/train.sh
@@ -0,0 +1,74 @@
+#!/bin/bash
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+AGL_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"
+
+if [ -f "${AGL_ROOT}/.venv/bin/python" ]; then
+ PYTHON="${AGL_ROOT}/.venv/bin/python"
+ echo "Using uv virtual environment: ${PYTHON}"
+else
+ PYTHON="python"
+ echo "Warning: uv virtual environment not found at ${AGL_ROOT}/.venv/bin/python"
+ echo "Using system python. Make sure all dependencies are installed."
+fi
+
+export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-3,4,5}
+export N_GPUS=${N_GPUS:-2}
+export VLLM_PORT=${VLLM_PORT:-8001}
+export VLLM_HOST=${VLLM_HOST:-127.0.0.1}
+export PYTORCH_CUDA_ALLOC_CONF=${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}
+
+export BASE_MODEL=${BASE_MODEL:-/path/to/llama-3.2-3b-instruct}
+export DATA_DIR=${DATA_DIR:-${SCRIPT_DIR}/../search_r1/data}
+export EXPERIMENT_NAME=${EXPERIMENT_NAME:-ultrarag_search_r1}
+export PROJECT_NAME=${PROJECT_NAME:-AgentLightning-ultrarag}
+
+echo "Using GPUs: $CUDA_VISIBLE_DEVICES"
+echo "Number of GPUs: $N_GPUS"
+echo "Data dir: $DATA_DIR"
+echo "Base model: $BASE_MODEL"
+
+cd "${SCRIPT_DIR}"
+PYTHONPATH="${SCRIPT_DIR}" ${PYTHON} -m agentlightning.verl \
+ algorithm.adv_estimator=grpo \
+ data.train_files=${DATA_DIR}/train.parquet \
+ data.val_files=${DATA_DIR}/test_100.parquet \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ trainer.n_gpus_per_node=${N_GPUS} \
+ data.train_batch_size=32 \
+ actor_rollout_ref.rollout.n=2 \
+ actor_rollout_ref.actor.ppo_mini_batch_size=128 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.model.path=${BASE_MODEL} \
+ data.max_prompt_length=4096 \
+ data.max_response_length=4096 \
+ data.truncation='error' \
+ trainer.val_before_train=True \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.95 \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.use_kl_loss=false \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.clip_ratio_low=0.2 \
+ actor_rollout_ref.actor.clip_ratio_high=0.3 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=True \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.mode=async \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ algorithm.use_kl_in_reward=False \
+ trainer.critic_warmup=0 \
+ trainer.logger=['console','wandb'] \
+ trainer.default_local_dir=checkpoints/ultrarag_search_r1_checkpoints/$EXPERIMENT_NAME \
+ trainer.project_name=${PROJECT_NAME} \
+ trainer.experiment_name=${EXPERIMENT_NAME} \
+ trainer.nnodes=1 \
+ trainer.save_freq=10 \
+ trainer.test_freq=10 \
+ trainer.total_epochs=15 \
+ trainer.total_training_steps=300
diff --git a/examples/ultrarag_search_r1/ultrarag_adapter.py b/examples/ultrarag_search_r1/ultrarag_adapter.py
new file mode 100755
index 000000000..26f27828c
--- /dev/null
+++ b/examples/ultrarag_search_r1/ultrarag_adapter.py
@@ -0,0 +1,556 @@
+
+"""
+UltraRAG adapter for Agent Lightning.
+Uses UltraRAG components (retrieval/generation) with the AGL training interface.
+"""
+
+import asyncio
+import os
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, cast
+import logging
+
+from agentlightning import LLM, LitAgent, NamedResources, Trainer, setup_logging
+from agentlightning.reward import reward
+
+try:
+ from ultrarag.client import UltraData, Configuration
+ from fastmcp import Client
+ ULTRARAG_AVAILABLE = True
+except ImportError:
+ print("Warning: UltraRAG components are unavailable")
+ ULTRARAG_AVAILABLE = False
+
+from qa_em import compute_score_em, em_check
+
+setup_logging()
+
+
+@reward
+async def eval(prediction: str, ground_truth: List[str]) -> float:
+ has_answer_tag = "" in prediction
+ if not has_answer_tag:
+ reward_score = float(em_check(prediction, ground_truth))
+ else:
+ reward_score = float(compute_score_em(prediction, ground_truth))
+ print(f"pred: {prediction} | gold_answer: {ground_truth} | has_tag: {has_answer_tag} | res: {reward_score}")
+ return reward_score
+
+
+def extract_answer_from_response(response_text: str) -> str:
+ """Extract the final answer from the response."""
+ pattern = r"(.*?)"
+ matches = re.findall(pattern, response_text, re.DOTALL)
+ if matches:
+ return matches[-1].strip()
+ return response_text.strip()
+
+
+class UltraRAGPipelineExecutor:
+ """
+ UltraRAG Pipeline executor.
+
+ Wraps the UltraRAG pipeline execution into a single-query interface.
+ """
+
+ def __init__(
+ self,
+ config_path: str,
+ param_path: Optional[str] = None,
+ generation_endpoint: Optional[str] = None,
+ generation_model: Optional[str] = None,
+ ):
+ """
+ Args:
+ config_path: UltraRAG pipeline config path.
+ param_path: UltraRAG parameter config path.
+ generation_endpoint: Generation model API endpoint (override config).
+ generation_model: Generation model name (override config).
+ """
+ self.config_path = Path(config_path)
+ self.param_path = Path(param_path) if param_path else None
+ self.generation_endpoint = generation_endpoint
+ self.generation_model = generation_model
+
+ self.cfg = Configuration()
+ self.pipeline_config = None
+ self.param_config = None
+ self._load_configs()
+
+ def _load_configs(self):
+ """Load configuration files."""
+ if not self.config_path.exists():
+ raise FileNotFoundError(f"Config file not found: {self.config_path}")
+
+ self.pipeline_config = self.cfg.load_config(str(self.config_path))
+
+ if self.param_path and self.param_path.exists():
+ self.param_config = self.cfg.load_parameter_config(str(self.param_path))
+ else:
+ param_file = self.config_path.parent / "parameter" / f"{self.config_path.stem}_parameter.yaml"
+ if param_file.exists():
+ self.param_config = self.cfg.load_parameter_config(str(param_file))
+
+ async def execute_single_query(
+ self,
+ question: str,
+ llm_endpoint: Optional[str] = None,
+ llm_model: Optional[str] = None,
+ temperature: float = 1.0,
+ max_iterations: int = 8,
+ ) -> Dict[str, Any]:
+ """
+ Run a single query via the UltraRAG pipeline.
+
+ Args:
+ question: Question text.
+ llm_endpoint: LLM API endpoint (override).
+ llm_model: LLM model name (override).
+ temperature: Sampling temperature.
+ max_iterations: Max iterations (from pipeline config or this argument).
+
+ Returns:
+ Dict with answer, reasoning steps, and retrieved nodes.
+ """
+ if not ULTRARAG_AVAILABLE:
+ raise ImportError("UltraRAG components unavailable; full pipeline mode disabled.")
+
+ if not self.pipeline_config:
+ raise ValueError("Pipeline config not loaded")
+
+ import ultrarag.client as ultrarag_client
+ if getattr(ultrarag_client, "logger", None) is None:
+ ultrarag_client.logger = logging.getLogger("ultrarag")
+ logger = ultrarag_client.logger
+ logger.setLevel(logging.INFO)
+ if not logger.handlers:
+ logger.addHandler(logging.StreamHandler())
+
+ from ultrarag.client import run as ultrarag_run
+        import copy
+        import tempfile
+ import json
+ import yaml
+
+ temp_param_file = None
+ temp_benchmark_file = None
+
+ try:
+            # Deep-copy so per-call overrides don't mutate the shared parameter config.
+            param_config = copy.deepcopy(self.param_config) if self.param_config else {}
+
+ if llm_endpoint:
+ if "generation" not in param_config:
+ param_config["generation"] = {}
+ if "backend_configs" not in param_config["generation"]:
+ param_config["generation"]["backend_configs"] = {}
+ if "openai" not in param_config["generation"]["backend_configs"]:
+ param_config["generation"]["backend_configs"]["openai"] = {}
+
+ param_config["generation"]["backend_configs"]["openai"]["base_url"] = llm_endpoint
+ if llm_model:
+ param_config["generation"]["backend_configs"]["openai"]["model_name"] = llm_model
+ param_config["generation"]["backend_configs"]["openai"]["use_completions"] = True
+
+ if "generation" not in param_config:
+ param_config["generation"] = {}
+ if "sampling_params" not in param_config["generation"]:
+ param_config["generation"]["sampling_params"] = {}
+ param_config["generation"]["sampling_params"]["temperature"] = temperature
+
+ pipeline_steps = self.pipeline_config.get("pipeline", [])
+ has_benchmark = any(
+ (isinstance(step, str) and step == "benchmark.get_data") or
+ (isinstance(step, dict) and "benchmark" in str(step))
+ for step in pipeline_steps
+ )
+
+ temp_benchmark_file = None
+ if has_benchmark:
+ temp_benchmark_file = tempfile.NamedTemporaryFile(
+ mode='w', suffix='.jsonl', delete=False, encoding='utf-8'
+ )
+ benchmark_data = {
+ "question": question,
+ "golden_answers": []
+ }
+ temp_benchmark_file.write(json.dumps(benchmark_data, ensure_ascii=False) + "\n")
+ temp_benchmark_file.close()
+
+ if "benchmark" not in param_config:
+ param_config["benchmark"] = {}
+ param_config["benchmark"]["benchmark"] = {
+ "path": temp_benchmark_file.name,
+ "limit": 1,
+ "key_map": {
+ "q_ls": "question",
+ "gt_ls": "golden_answers"
+ }
+ }
+
+ temp_param_file = tempfile.NamedTemporaryFile(
+ mode='w', suffix='.yaml', delete=False, encoding='utf-8'
+ )
+ yaml.dump(param_config, temp_param_file, allow_unicode=True)
+ temp_param_file.close()
+
+ result = await ultrarag_run(
+ str(self.config_path),
+ param_path=temp_param_file.name,
+ return_all=True # return all intermediate results for reasoning steps
+ )
+
+ all_results = result.get("all_results", [])
+ final_result = result.get("final_result", None)
+
+ answer = ""
+ if final_result:
+ if isinstance(final_result, dict):
+ ans_ls = final_result.get("ans_ls", [])
+ answer = ans_ls[0] if ans_ls else ""
+ elif isinstance(final_result, list) and final_result:
+ answer = final_result[0] if isinstance(final_result[0], str) else str(final_result[0])
+ else:
+ answer = str(final_result) if final_result else ""
+
+ reasoning_steps = []
+ retrieved_nodes = []
+ rollout_content_parts = []
+
+ for snapshot in all_results:
+ if "ans_ls" in snapshot:
+ ans_list = snapshot.get("ans_ls", [])
+ if ans_list and ans_list[-1]:
+ reasoning_steps.append(str(ans_list[-1]))
+ rollout_content_parts.append(str(ans_list[-1]))
+
+ if "retrieved_docs" in snapshot:
+ retrieved = snapshot.get("retrieved_docs", [])
+ if retrieved:
+ retrieved_nodes.extend(retrieved)
+
+ rollout_content = "\n".join(rollout_content_parts) if rollout_content_parts else answer
+
+ return {
+ "answer": answer,
+ "response": rollout_content, # full response (for RL training)
+ "steps": reasoning_steps,
+ "retrieved_nodes": retrieved_nodes,
+ }
+
+ except Exception as e:
+ print(f"UltraRAG pipeline execution error: {e}")
+ import traceback
+ traceback.print_exc()
+ return {
+ "answer": "",
+ "response": "",
+ "steps": [],
+ "retrieved_nodes": [],
+ }
+ finally:
+ if temp_param_file:
+ try:
+ os.unlink(temp_param_file.name)
+                except OSError:
+ pass
+ if temp_benchmark_file:
+ try:
+ os.unlink(temp_benchmark_file.name)
+                except OSError:
+ pass
+
+
+class UltraRAGAgent(LitAgent[Any]):
+ """UltraRAG Agent for UltraRAG + AGL training."""
+ """
+ Agent that integrates UltraRAG with Agent Lightning.
+
+ Uses UltraRAG core components to process queries.
+ """
+
+ def __init__(
+ self,
+ ultrarag_config_path: Optional[str] = None,
+ ultrarag_param_path: Optional[str] = None,
+ use_simplified_interface: bool = False,
+ ):
+ """
+ Args:
+ ultrarag_config_path: UltraRAG pipeline config path.
+ ultrarag_param_path: UltraRAG parameter config path.
+ use_simplified_interface: Whether to use the simplified interface (direct retrieve/generate without full pipeline).
+ """
+ super().__init__()
+ self.use_simplified_interface = use_simplified_interface
+
+ if ultrarag_config_path:
+ self.ultrarag_config_path = Path(ultrarag_config_path)
+ else:
+ default_path = Path(__file__).parent / "search_r1_rl.yaml"
+ if not default_path.exists():
+ default_path = Path(__file__).parent / "r1_searcher.yaml"
+ self.ultrarag_config_path = default_path if default_path.exists() else None
+
+ if ultrarag_param_path:
+ self.ultrarag_param_path = Path(ultrarag_param_path)
+ else:
+ default_param = Path(__file__).parent / "search_r1_rl_parameter.yaml"
+ if not default_param.exists():
+ default_param = Path(__file__).parent / "r1_searcher_parameter.yaml"
+ self.ultrarag_param_path = default_param if default_param.exists() else None
+
+ if not use_simplified_interface and self.ultrarag_config_path:
+ self.pipeline_executor = UltraRAGPipelineExecutor(
+ str(self.ultrarag_config_path),
+ str(self.ultrarag_param_path) if self.ultrarag_param_path else None,
+ generation_endpoint=None, # set at runtime
+ generation_model=None, # set at runtime
+ )
+ else:
+ self.pipeline_executor = None
+
+ async def _act_with_ultrarag(
+ self,
+ question: str,
+ llm_endpoint: str,
+ llm_model: str,
+ temperature: float = 1.0,
+ ) -> Dict[str, Any]:
+ """
+ Execute query with UltraRAG core components.
+
+ Return format:
+ {
+ "response": str, # full response (for RL training)
+ "steps": List[str], # reasoning steps (for RL reward)
+ "retrieved_nodes": List[Dict], # retrieved nodes
+ }
+ """
+ if self.pipeline_executor and not self.use_simplified_interface:
+ result = await self.pipeline_executor.execute_single_query(
+ question,
+ llm_endpoint=llm_endpoint,
+ llm_model=llm_model,
+ temperature=temperature,
+ )
+ return {
+ "response": result.get("response", ""), # Use full response, not just the answer.
+ "steps": result.get("steps", []),
+ "retrieved_nodes": result.get("retrieved_nodes", []),
+ }
+ else:
+ return await self._act_with_simplified_interface(
+ question, llm_endpoint, llm_model, temperature
+ )
+
+ async def _act_with_simplified_interface(
+ self,
+ question: str,
+ llm_endpoint: str,
+ llm_model: str,
+ temperature: float = 1.0,
+ ) -> Dict[str, Any]:
+ """
+ Execute query with the simplified interface.
+
+ Similar to search_r1_agent but keeps hooks for full UltraRAG pipeline.
+ """
+ from openai import AsyncOpenAI
+ import requests
+
+        INSTRUCTION_FORMAT = """Answer the given question. You must conduct reasoning inside <think> and </think> first every time you get new information. After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search> and it will return the top searched results between <information> and </information>. You can search as many times as your want. If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>. Question: """
+
+ client = AsyncOpenAI(
+ base_url=llm_endpoint,
+ api_key=os.environ.get("OPENAI_API_KEY", "token-abc123"),
+ )
+
+ async def call_llm(content: str, max_tokens: int = 500) -> str:
+ """Call LLM via chat.completions (Instruct models support chat templates).
+
+            Uses AsyncOpenAI so that AgentOpsTracer can capture the calls and generate triplets.
+ """
+ response = await client.chat.completions.create(
+ model=llm_model,
+ messages=[
+ {"role": "user", "content": content}
+ ],
+ temperature=temperature,
+ max_tokens=max_tokens,
+ extra_body={"return_token_ids": True}, # vLLM needs this to return token_ids.
+ )
+ return response.choices[0].message.content or ""
+
+ def extract_action(response: str) -> Tuple[Optional[str], str]:
+ pattern = r"<(search|answer)>(.*?)\1>"
+ match = re.search(pattern, response, re.DOTALL)
+ if match:
+ return match.group(1), match.group(2).strip()
+ return None, ""
+
+ def postprocess_response(response: str) -> str:
+ if "" in response:
+ return response.split("")[0] + ""
+ elif "" in response:
+ return response.split("")[0] + ""
+ return response
+
+ retrieval_endpoint = os.environ.get("RETRIEVAL_ENDPOINT", "http://127.0.0.1:8002/retrieve")
+
+ prompt = INSTRUCTION_FORMAT + question
+ rollout_content = ""
+ reasoning_steps = []
+ retrieved_nodes = []
+ turn_id = 0
+ finished = False
+
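+        # Up to 4 interaction turns: each turn the model either issues a <search>
+        # query (results are appended inside <information> tags), emits a final
+        # <answer>, or is asked to retry after an invalid action.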
+ while turn_id < 4 and not finished:
+ turn_id += 1
+ turn_response = await call_llm(prompt + rollout_content)
+ valid_response = postprocess_response(turn_response)
+ reasoning_steps.append(valid_response)
+
+ action, content = extract_action(valid_response)
+ if action == "answer":
+ finished = True
+ rollout_content += valid_response
+ elif action == "search":
+ payload = {"queries": [content], "topk": 3, "return_scores": True}
+ try:
+ resp = requests.post(retrieval_endpoint, json=payload, timeout=10)
+ resp.raise_for_status()
+ json_resp = resp.json()
+ retrieval_result = json_resp["result"][0]
+ retrieved_nodes.extend(retrieval_result)
+
+ format_ref = ""
+ for idx, doc_item in enumerate(retrieval_result):
+ doc = doc_item.get("document", doc_item)
+ content_str = doc.get("contents", str(doc)) if isinstance(doc, dict) else str(doc)
+ lines = content_str.split("\n")
+ title = lines[0] if lines else ""
+ text = "\n".join(lines[1:]) if len(lines) > 1 else content_str
+ format_ref += f"Doc {idx+1}(Title: {title}) {text}\n"
+
+ env_feedback = f"\n\n{format_ref}\n\n"
+ except Exception as e:
+ print(f"Retrieval error: {e}")
+ env_feedback = "\n\nretrieval failed\n\n"
+
+ rollout_content += valid_response + env_feedback
+ else:
+ error_msg = "\nMy previous action is invalid. If I want to search, I should put the query between and . If I want to give the final answer, I should put the answer between and . Let me try again.\n"
+ rollout_content += valid_response + error_msg
+
+ if not finished:
+ final_response = await call_llm(prompt + rollout_content)
+ rollout_content += final_response
+ reasoning_steps.append(final_response)
+
+ return {
+ "response": rollout_content,
+ "steps": reasoning_steps,
+ "retrieved_nodes": retrieved_nodes,
+ }
+
+ async def training_rollout_async(
+ self,
+ task: Any,
+ resources: NamedResources,
+ rollout: Any,
+ temperature: float = 1.0,
+ ) -> Any:
+ question = task["question"]
+ answer_list: List[str] = cast(List[str], task["golden_answers"])
+ llm: LLM = cast(LLM, resources.get("main_llm"))
+
+ result = await self._act_with_ultrarag(
+ question, llm.endpoint, llm.model, temperature
+ )
+
+ pred_answer = extract_answer_from_response(result["response"])
+
+ reward_score = await eval(pred_answer, answer_list)
+ print(
+ f"question: {question} "
+ f"pred_answer: {pred_answer} "
+ f"ground_truth: {answer_list} "
+ f"reward: {reward_score}"
+ )
+
+ return reward_score
+
+ async def validation_rollout_async(
+ self,
+ task: Any,
+ resources: NamedResources,
+ rollout: Any,
+ ) -> Any:
+ reward_score = await self._validation_with_save(task, resources, rollout)
+ return reward_score
+
+ async def _validation_with_save(
+ self,
+ task: Any,
+ resources: NamedResources,
+ rollout: Any,
+ ) -> float:
+ """Run validation and save results."""
+ import json
+ import os
+ from pathlib import Path
+ from datetime import datetime
+
+ question = task["question"]
+ answer_list: List[str] = cast(List[str], task["golden_answers"])
+ llm: LLM = cast(LLM, resources.get("main_llm"))
+
+ result = await self._act_with_ultrarag(
+ question, llm.endpoint, llm.model, temperature=0.0
+ )
+
+ pred_answer = extract_answer_from_response(result["response"])
+
+ reward_score = await eval(pred_answer, answer_list)
+
+ try:
+ checkpoint_dir = os.environ.get("CHECKPOINT_DIR", "checkpoints/ultrarag_agl_checkpoints/ultrarag_agl")
+ step = int(os.environ.get("CURRENT_STEP", "0"))
+ is_val_before_train = (step == 0)
+
+ if is_val_before_train:
+ val_dir = Path(checkpoint_dir) / "val_before_train"
+ else:
+ val_dir = Path(checkpoint_dir) / f"validation_step_{step}"
+ val_dir.mkdir(parents=True, exist_ok=True)
+
+ result_file = val_dir / "results.jsonl"
+ validation_result = {
+ "question": question,
+ "golden_answers": answer_list,
+ "prediction": pred_answer, # extracted final answer
+ "rollout_content": result["response"], # full reasoning trace
+ "steps": result.get("steps", []), # reasoning steps
+ "retrieved_nodes": result.get("retrieved_nodes", []), # retrieved nodes
+ "reward": float(reward_score),
+ "step": step,
+ "timestamp": datetime.now().isoformat(),
+ }
+
+ with open(result_file, "a", encoding="utf-8") as f:
+ f.write(json.dumps(validation_result, ensure_ascii=False) + "\n")
+ except Exception as e:
+ print(f"Error while saving validation results: {e}")
+
+ return reward_score
+
+
+if __name__ == "__main__":
+ Trainer(n_workers=128).fit_v0(
+ UltraRAGAgent(use_simplified_interface=False),
+ "http://localhost:9999/"
+ )
+