From 31c56930bd3aced8a25a757ad81d9f105396ee53 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 22 Sep 2025 20:21:27 +0800 Subject: [PATCH 01/33] add allocate_session and feedback API --- trinity/explorer/api/api.py | 24 ++++++++++++- trinity/explorer/api/service.py | 56 ++++++++++++++++++++++------- trinity/explorer/explorer_client.py | 20 ++++++----- 3 files changed, 77 insertions(+), 23 deletions(-) diff --git a/trinity/explorer/api/api.py b/trinity/explorer/api/api.py index 66b5e2e97a..f8914e481b 100644 --- a/trinity/explorer/api/api.py +++ b/trinity/explorer/api/api.py @@ -25,7 +25,9 @@ async def chat_completions(request: Request): content=f"Error forwarding request to model at {url}: {traceback.format_exc()}", ) resp_data = resp.json() - await request.app.state.service.record_experience(resp_data) + await request.app.state.service.record_experience( + resp_data, session_id=body.get("session_id", None) + ) return JSONResponse(content=resp_data) @@ -52,6 +54,26 @@ async def metrics(request: Request): return JSONResponse(content=metrics) +@app.get("/allocate") +async def allocate(request: Request): + """Allocate a new session.""" + return JSONResponse(content={"session_id": request.app.state.service.allocate_session()}) + + +@app.post("/feedback") +async def feedback(request: Request): + """Receive feedback for the current session.""" + body = await request.json() + session_id = body.get("session_id", None) + reward = body.get("reward", None) + if session_id is None or reward is None: + return JSONResponse( + status_code=400, content={"error": "session_id and reward are required"} + ) + await request.app.state.service.explorer.record_feedback(session_id, reward) + return JSONResponse(content={"status": "success"}) + + async def serve_http(app: FastAPI, host: str, port: int = None): config = uvicorn.Config(app, host=host, port=port) server = uvicorn.Server(config) diff --git a/trinity/explorer/api/service.py b/trinity/explorer/api/service.py index ffdf2cfd9a..b489b9cb87 100644 --- a/trinity/explorer/api/service.py +++ b/trinity/explorer/api/service.py @@ -1,7 +1,7 @@ import asyncio import time from collections import deque -from typing import Dict, List +from typing import Dict, List, Optional import torch @@ -13,6 +13,8 @@ class ExplorerService: + """Manages the lifecycle and operations of the Explorer API service.""" + def __init__(self, explorer: Explorer, listen_address: str = "localhost", port: int = 8010): self.logger = get_logger(__name__) self.explorer = explorer @@ -27,10 +29,13 @@ def __init__(self, explorer: Explorer, listen_address: str = "localhost", port: self.running_models: deque[int] = deque() # indices of running models self.sync_task_map: Dict[asyncio.Future, int] = {} # sync task -> model index self.latest_model_version = 0 - self.experience_queue = asyncio.Queue() + self.experience_queue: deque[Experience] = deque() + self.session_level_experience_queue: Dict[int, deque[Experience]] = {} + self.queue_lock = asyncio.Lock() self.experience_count = 0 + self.session_count = 0 - async def serve(self): + async def serve(self) -> None: from trinity.explorer.api.api import run_app if self.running: @@ -48,7 +53,7 @@ async def serve(self): ) self.sync_model_weights_task = asyncio.create_task(self.model_weights_sync_loop()) - async def model_weights_sync_loop(self): + async def model_weights_sync_loop(self) -> None: self.logger.info("Starting model weights synchronization loop.") while self.running: for idx in list(self.running_models): @@ -71,7 +76,7 @@ def set_latest_model_version(self, version: int) -> None: self.latest_model_version = version self.logger.info(f"Updated latest model version to {version}.") - async def _wait_for_sync_start(self, index: int): + async def _wait_for_sync_start(self, index: int) -> None: start_time = time.time() while time.time() - start_time < self.max_timeout: current_load = await self.models[index].get_current_load() @@ -85,7 +90,7 @@ async def _wait_for_sync_start(self, index: int): f"Timeout waiting for model {index} to be free for synchronization. Current load: {current_load}" ) - async def _sync_model_weights(self, task: asyncio.Future): + async def _sync_model_weights(self, task: asyncio.Future) -> None: index = self.sync_task_map.pop(task) latest_version = self.latest_model_version # capture the latest version if task.cancelled(): @@ -121,7 +126,7 @@ async def check_requiring_sync_models(self): *[self._sync_model_weights(idx) for idx in list(self.requiring_sync_models)] ) - async def record_experience(self, response): + async def record_experience(self, response, session_id: Optional[int] = None): experiences = [] for choice in response["choices"]: exp = Experience( @@ -137,14 +142,39 @@ async def record_experience(self, response): ) experiences.append(exp) self.experience_count += len(experiences) - for exp in experiences: - await self.experience_queue.put(exp) + + # Store experiences in session-level queue if session_id is provided + if session_id is not None: + async with self.queue_lock: + if session_id not in self.session_level_experience_queue: + self.session_level_experience_queue[session_id] = deque() + self.session_level_experience_queue[session_id].extend(experiences) + else: + async with self.queue_lock: + self.experience_queue.extend(experiences) async def get_all_experiences(self) -> List: - experiences = [] - while not self.experience_queue.empty(): - experiences.append(await self.experience_queue.get()) - return experiences + async with self.queue_lock: + experiences = list(self.experience_queue) + self.experience_queue.clear() + return experiences + + def allocate_session(self) -> int: + self.session_count += 1 + return self.session_count + + async def record_feedback(self, session_id: int, reward: float): + exps = [] + async with self.queue_lock: + if session_id in self.session_level_experience_queue: + exps = list(self.session_level_experience_queue.pop(session_id)) + if not exps: + self.logger.warning(f"No experiences found for session_id {session_id}.") + return + for exp in exps: + exp.reward = reward + async with self.queue_lock: + self.experience_queue.extend(exps) async def shutdown(self): if not self.running: diff --git a/trinity/explorer/explorer_client.py b/trinity/explorer/explorer_client.py index 311b310038..a7ca0f5e85 100644 --- a/trinity/explorer/explorer_client.py +++ b/trinity/explorer/explorer_client.py @@ -6,18 +6,20 @@ class ExplorerClient: - def __init__(self, base_url: str): - self.base_url = base_url + def __init__(self, explorer_api_url: str): + self.explorer_api_url = explorer_api_url + self.openai_base_url = f"{self.explorer_api_url}/v1" + self.feedback_url = f"{self.explorer_api_url}/feedback" self.session_id = self.init_session() def init_session(self) -> str: - response = requests.post(f"{self.base_url}/allocate") + response = requests.post(f"{self.explorer_api_url}/allocate") data = response.json() return data["session_id"] def get_openai_client(self) -> openai.OpenAI: client = openai.OpenAI( - base_url=self.base_url + "/v1", + base_url=self.openai_base_url, api_key="EMPTY", ) client.chat.completions.create = partial( @@ -27,7 +29,7 @@ def get_openai_client(self) -> openai.OpenAI: def get_openai_async_client(self) -> openai.AsyncOpenAI: client = openai.AsyncOpenAI( - base_url=self.base_url + "/v1", + base_url=self.openai_base_url, api_key="EMPTY", ) client.chat.completions.create = partial( @@ -35,15 +37,15 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI: ) return client - def feedback(self, reward: float): + def feedback(self, reward: float) -> dict: response = requests.post( - f"{self.base_url}/feedback", json={"session_id": self.session_id, "reward": reward} + self.feedback_url, json={"session_id": self.session_id, "reward": reward} ) return response.json() - async def feedback_async(self, reward: float): + async def feedback_async(self, reward: float) -> dict: async with httpx.AsyncClient() as client: response = await client.post( - f"{self.base_url}/feedback", json={"session_id": self.session_id, "reward": reward} + self.feedback_url, json={"session_id": self.session_id, "reward": reward} ) return response.json() From ece144e52f337f3d00d28d4dfa2d0dfa0b6c03f7 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 22 Sep 2025 20:32:01 +0800 Subject: [PATCH 02/33] fix comments --- trinity/explorer/api/api.py | 6 +++++- trinity/explorer/api/service.py | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/trinity/explorer/api/api.py b/trinity/explorer/api/api.py index f8914e481b..9cc435aa21 100644 --- a/trinity/explorer/api/api.py +++ b/trinity/explorer/api/api.py @@ -70,7 +70,11 @@ async def feedback(request: Request): return JSONResponse( status_code=400, content={"error": "session_id and reward are required"} ) - await request.app.state.service.explorer.record_feedback(session_id, reward) + if not isinstance(session_id, int) or not isinstance(reward, (int, float)): + return JSONResponse( + status_code=400, content={"error": "session_id must be int and reward must be float"} + ) + await request.app.state.service.record_feedback(session_id, reward) return JSONResponse(content={"status": "success"}) diff --git a/trinity/explorer/api/service.py b/trinity/explorer/api/service.py index b489b9cb87..2dcfe815ef 100644 --- a/trinity/explorer/api/service.py +++ b/trinity/explorer/api/service.py @@ -140,6 +140,8 @@ async def record_experience(self, response, session_id: Optional[int] = None): prompt_length=len(response["prompt_token_ids"]), response_text=choice.get("message", {}).get("content", ""), ) + if session_id is not None: + exp.eid.task = session_id experiences.append(exp) self.experience_count += len(experiences) From 7ed876c04271432c8982ea19a36087ea575b5b6d Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 23 Sep 2025 10:58:38 +0800 Subject: [PATCH 03/33] add api test --- tests/trainer/trainer_test.py | 24 ++++++++++++++++++++++++ trinity/explorer/api/api.py | 3 +++ 2 files changed, 27 insertions(+) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 3dc67b692f..4fffc33585 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -682,6 +682,30 @@ def tearDown(self): shutil.rmtree(self.config.checkpoint_job_dir) +class TestServeWithTrainer(unittest.TestCase): + + def setUp(self): + if multiprocessing.get_start_method(allow_none=True) != "spawn": + multiprocessing.set_start_method("spawn", force=True) + checkpoint_path = get_checkpoint_path() + shutil.rmtree(os.path.join(checkpoint_path, "unittest"), ignore_errors=True) + + + def test_serve_with_trainer(self): + config = get_template_config() + config.project = "unittest" + config.name = f"serve_with_trainer_{datetime.now().strftime('%Y%m%d%H%M%S')}" + config.checkpoint_root_dir = get_checkpoint_path() + config.buffer.batch_size = 4 + config.algorithm.algorithm_type = "ppo" + config.algorithm.repeat_times = 1 + config.trainer.save_interval = 1 + config.buffer.trainer_input.experience_buffer = StorageConfig( + name="exp_buffer", + storage_type=StorageType.SQL, + ) + config.synchronizer.sync_method = SyncMethod.CHECKPOINT + class TestTrainerMultiModal(BaseTrainerCase): @unittest.skip("Require specific vllm/transformers version") def test_trainer(self): diff --git a/trinity/explorer/api/api.py b/trinity/explorer/api/api.py index 9cc435aa21..b84a3c0680 100644 --- a/trinity/explorer/api/api.py +++ b/trinity/explorer/api/api.py @@ -33,10 +33,13 @@ async def chat_completions(request: Request): @app.get("/v1/models") async def show_available_models(request: Request): + if hasattr(request.app.state, "models"): + return JSONResponse(content=request.app.state.models) body = await request.json() url = await request.app.state.service.allocate_model(increase_count=False) async with httpx.AsyncClient() as client: resp = await client.get(f"{url}/v1/models", json=body) + request.app.state.models = resp.json() return JSONResponse(content=resp.json()) From 20857a7d3bf730e1297c605cd563d7859c88535c Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 14 Oct 2025 18:27:28 +0800 Subject: [PATCH 04/33] add tests --- tests/trainer/trainer_test.py | 124 ++++++++++++++++++++++++++- trinity/common/models/utils.py | 22 +++++ trinity/common/models/vllm_worker.py | 4 + trinity/explorer/api/service.py | 2 +- trinity/explorer/explorer.py | 7 +- 5 files changed, 150 insertions(+), 9 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 8854616d48..0de164689f 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -8,6 +8,8 @@ from copy import deepcopy from datetime import datetime from unittest import mock +import asyncio +import httpx import ray from parameterized import parameterized_class @@ -22,7 +24,7 @@ get_unittest_dataset_config, get_vision_language_model_path, ) -from trinity.cli.launcher import bench, both, explore, run, train +from trinity.cli.launcher import bench, both, explore, run, train, serve from trinity.common.config import ( AlgorithmConfig, BufferConfig, @@ -41,6 +43,8 @@ ) from trinity.common.models.utils import get_checkpoint_dir_with_step_num from trinity.manager.state_manager import StateManager +from trinity.buffer import get_buffer_reader +from trinity.explorer.explorer_client import ExplorerClient class BaseTrainerCase(RayUnittestBase): @@ -474,6 +478,18 @@ def run_both(config: Config) -> None: ) both(config) +def run_serve(config: Config) -> None: + ray.init( + namespace=config.ray_namespace, + runtime_env={ + "env_vars": { + LOG_DIR_ENV_VAR: config.log.save_dir, + LOG_LEVEL_ENV_VAR: "INFO", + } + }, + ) + serve(config) + @parameterized_class( ("use_priority_queue", "strategy"), @@ -841,7 +857,38 @@ def tearDown(self): shutil.rmtree(self.config.checkpoint_job_dir) -class TestServeWithTrainer(unittest.TestCase): +async def run_math_workflow(serve_url: str, task: dict): + from trinity.common.rewards.math_reward import MathRewardFn + + explorer_client = ExplorerClient(serve_url) + openai_client = explorer_client.get_openai_async_client() + + query = task["question"] + truth = task["answer"] + + reward_fn = MathRewardFn() + + system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., + reasoning process here + answer here . +""" + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": query}, + ] + + model = openai_client.models.list().data[0].id + + response = await openai_client.chat.completions.create( + model=model, + messages=messages, + ) + answer = response.choices[0].message.content + reward = reward_fn(response=answer, truth=truth, prompt=query) + await explorer_client.feedback_async(reward) + + +class TestServeWithTrainer(unittest.IsolatedAsyncioTestCase): def setUp(self): if multiprocessing.get_start_method(allow_none=True) != "spawn": @@ -850,7 +897,7 @@ def setUp(self): shutil.rmtree(os.path.join(checkpoint_path, "unittest"), ignore_errors=True) - def test_serve_with_trainer(self): + async def test_serve_with_trainer(self): config = get_template_config() config.project = "unittest" config.name = f"serve_with_trainer_{datetime.now().strftime('%Y%m%d%H%M%S')}" @@ -858,12 +905,81 @@ def test_serve_with_trainer(self): config.buffer.batch_size = 4 config.algorithm.algorithm_type = "ppo" config.algorithm.repeat_times = 1 - config.trainer.save_interval = 1 config.buffer.trainer_input.experience_buffer = StorageConfig( name="exp_buffer", storage_type=StorageType.SQL, ) + config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") + config.buffer.train_batch_size = 4 + config.trainer.total_steps = 4 + config.trainer.save_interval = 2 + config.synchronizer.sync_interval = 2 config.synchronizer.sync_method = SyncMethod.CHECKPOINT + config.explorer.rollout_model.engine_num = 2 + config.explorer.rollout_model.tensor_parallel_size = 1 + config.explorer.service_status_check_interval = 10 + + trainer_config = deepcopy(config) + trainer_config.mode = "train" + trainer_config.check_and_update() + + trainer_process = multiprocessing.Process(target=run_trainer, args=(trainer_config,)) + trainer_process.start() + + ray.init(ignore_reinit_error=True) + while True: + try: + ray.get_actor("sql-exp_buffer", namespace=trainer_config.ray_namespace) + break + except ValueError: + print("waiting for trainer to start.") + time.sleep(5) + + serve_config = deepcopy(config) + serve_config.mode = "serve" + serve_config.check_and_update() + serve_process = multiprocessing.Process(target=run_explorer, args=(serve_config,)) + serve_process.start() + + state_manager = StateManager( + path=serve_config.checkpoint_job_dir, + explorer_name=serve_config.explorer.name, + ) + + # wait for explorer initialization + for i in range(30): + try: + server_url = state_manager.load_explorer_server_url() + except Exception: + server_url = None + if server_url: + break + await asyncio.sleep(3) + if not server_url: + raise RuntimeError("Explorer server URL not found.") + # wait for server setup + for i in range(10): + try: + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/health") + if response.status_code == 200: + break + except Exception: + pass + await asyncio.sleep(2) + + reader = get_buffer_reader(serve_config.buffer.explorer_input.taskset, serve_config.buffer) + + for i in range(2): + # generate data for 2 trainer steps + tasks = reader.read(batch_size=8) + await asyncio.gather(*(run_math_workflow(server_url, task.raw_task) for task in tasks)) + + # wait for synchronizer started + end_time = time.time() + while time.time() - end_time < config.explorer.service_status_check_interval: + await asyncio.sleep(1) + class TestMultiModalGRPO(BaseTrainerCase): @unittest.skip("Require specific vllm/transformers version") diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py index d33a2b424e..8476b697f2 100644 --- a/trinity/common/models/utils.py +++ b/trinity/common/models/utils.py @@ -164,6 +164,28 @@ def get_checkpoint_dir_with_step_num( else: raise NotImplementedError(f"Unsupported trainer type {trainer_type}") +def get_latest_state_dict( + checkpoint_root_path: str, + trainer_type: str = "verl", +) -> Tuple[str, int]: + """Get the latest state dict from a root checkpoint directory. + + Args: + checkpoint_root_path (str): The root checkpoint directory. + + Returns: + Tuple[str, int]: The state dict path and the iteration of the state dict. + If the state dict does not exist, return (None, 0). + """ + if trainer_type != "verl": + raise NotImplementedError(f"Unsupported trainer type {trainer_type}") + latest_state_dict_iteration_path = os.path.join(checkpoint_root_path, "latest_state_dict_iteration.txt") + if os.path.exists(latest_state_dict_iteration_path): + with open(latest_state_dict_iteration_path, "r", encoding="utf-8") as f: + iteration = f.read().strip() + state_dict_path = os.path.join(checkpoint_root_path, f"global_step_{iteration}", "actor") + return state_dict_path, int(iteration) + return None, 0 def load_state_dict(checkpoint_dir: str, config: TrainerConfig) -> Union[dict, Tuple[str, str]]: """Load state dict from a checkpoint dir. diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 93d9c0bd48..05fa5ec9d2 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -88,3 +88,7 @@ def update_weight(self): torch.distributed.barrier(group=self._model_update_group) torch.cuda.synchronize() torch.cuda.empty_cache() + + def update_weight_from_checkpoint(self): + """Update weight from checkpoint without broadcasting""" + pass \ No newline at end of file diff --git a/trinity/explorer/api/service.py b/trinity/explorer/api/service.py index 2dcfe815ef..c910eb093e 100644 --- a/trinity/explorer/api/service.py +++ b/trinity/explorer/api/service.py @@ -189,4 +189,4 @@ async def shutdown(self): except asyncio.CancelledError: pass self.running = False - self.logger.info("API server shut down.") + self.logger.info("API server shutdown.") diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index a10b523af2..4e0c9099aa 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -23,7 +23,7 @@ SyncStyle, ) from trinity.common.models import create_inference_models -from trinity.common.models.utils import get_checkpoint_dir_with_step_num +from trinity.common.models.utils import get_latest_state_dict from trinity.explorer.scheduler import Scheduler from trinity.manager.state_manager import StateManager from trinity.manager.synchronizer import Synchronizer @@ -81,7 +81,6 @@ def __init__(self, config: Config): async def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None ): - # In checkpoint mode, we use explorer to store the model weights which has no rank base_offset = 1 if self.use_nccl_sync else 0 world_size = ( len(self.models) * self.config.explorer.rollout_model.tensor_parallel_size + base_offset @@ -445,8 +444,8 @@ async def serve(self) -> None: metrics.update(self.service.collect_metrics()) self.monitor.log(metrics, self.explore_step_num) # get the latest checkpoint - _, step_num = get_checkpoint_dir_with_step_num( - self.config.checkpoint_job_dir, raise_error=False + _, step_num = get_latest_state_dict( + self.config.checkpoint_job_dir, self.config.trainer.trainer_type, ) self.service.set_latest_model_version(step_num) From 5295688af4d2ee80982d2d5299927014862d24fc Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 15 Oct 2025 11:31:25 +0800 Subject: [PATCH 05/33] process group for serve model --- tests/trainer/trainer_test.py | 17 ++++++----- trinity/common/models/utils.py | 12 ++++++-- trinity/common/models/vllm_worker.py | 2 +- trinity/explorer/explorer.py | 45 ++++++++++++++++++++++++---- trinity/manager/synchronizer.py | 2 +- 5 files changed, 59 insertions(+), 19 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 0de164689f..0e2baf56d9 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -1,5 +1,6 @@ """Tests for trainer.""" +import asyncio import multiprocessing import os import shutil @@ -8,9 +9,8 @@ from copy import deepcopy from datetime import datetime from unittest import mock -import asyncio -import httpx +import httpx import ray from parameterized import parameterized_class @@ -24,7 +24,8 @@ get_unittest_dataset_config, get_vision_language_model_path, ) -from trinity.cli.launcher import bench, both, explore, run, train, serve +from trinity.buffer import get_buffer_reader +from trinity.cli.launcher import bench, both, explore, run, serve, train from trinity.common.config import ( AlgorithmConfig, BufferConfig, @@ -42,9 +43,8 @@ SyncStyle, ) from trinity.common.models.utils import get_checkpoint_dir_with_step_num -from trinity.manager.state_manager import StateManager -from trinity.buffer import get_buffer_reader from trinity.explorer.explorer_client import ExplorerClient +from trinity.manager.state_manager import StateManager class BaseTrainerCase(RayUnittestBase): @@ -478,6 +478,7 @@ def run_both(config: Config) -> None: ) both(config) + def run_serve(config: Config) -> None: ray.init( namespace=config.ray_namespace, @@ -885,18 +886,16 @@ async def run_math_workflow(serve_url: str, task: dict): ) answer = response.choices[0].message.content reward = reward_fn(response=answer, truth=truth, prompt=query) - await explorer_client.feedback_async(reward) + await explorer_client.feedback_async(sum(reward.values())) class TestServeWithTrainer(unittest.IsolatedAsyncioTestCase): - def setUp(self): if multiprocessing.get_start_method(allow_none=True) != "spawn": multiprocessing.set_start_method("spawn", force=True) checkpoint_path = get_checkpoint_path() shutil.rmtree(os.path.join(checkpoint_path, "unittest"), ignore_errors=True) - async def test_serve_with_trainer(self): config = get_template_config() config.project = "unittest" @@ -980,6 +979,8 @@ async def test_serve_with_trainer(self): while time.time() - end_time < config.explorer.service_status_check_interval: await asyncio.sleep(1) + # check for trainer new checkpoint + class TestMultiModalGRPO(BaseTrainerCase): @unittest.skip("Require specific vllm/transformers version") diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py index 8476b697f2..10ea510e1b 100644 --- a/trinity/common/models/utils.py +++ b/trinity/common/models/utils.py @@ -164,6 +164,7 @@ def get_checkpoint_dir_with_step_num( else: raise NotImplementedError(f"Unsupported trainer type {trainer_type}") + def get_latest_state_dict( checkpoint_root_path: str, trainer_type: str = "verl", @@ -179,13 +180,18 @@ def get_latest_state_dict( """ if trainer_type != "verl": raise NotImplementedError(f"Unsupported trainer type {trainer_type}") - latest_state_dict_iteration_path = os.path.join(checkpoint_root_path, "latest_state_dict_iteration.txt") + latest_state_dict_iteration_path = os.path.join( + checkpoint_root_path, "latest_state_dict_iteration.txt" + ) if os.path.exists(latest_state_dict_iteration_path): with open(latest_state_dict_iteration_path, "r", encoding="utf-8") as f: iteration = f.read().strip() - state_dict_path = os.path.join(checkpoint_root_path, f"global_step_{iteration}", "actor") + state_dict_path = os.path.join( + checkpoint_root_path, f"global_step_{iteration}", "actor" + ) return state_dict_path, int(iteration) - return None, 0 + return None, 0 # type: ignore + def load_state_dict(checkpoint_dir: str, config: TrainerConfig) -> Union[dict, Tuple[str, str]]: """Load state dict from a checkpoint dir. diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 05fa5ec9d2..f1e5833a8f 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -91,4 +91,4 @@ def update_weight(self): def update_weight_from_checkpoint(self): """Update weight from checkpoint without broadcasting""" - pass \ No newline at end of file + pass diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 4e0c9099aa..78e2e768b0 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -107,6 +107,30 @@ async def setup_weight_sync_group( ] await asyncio.gather(*refs) + async def setup_model_level_weight_sync_group(self): + """Setup process group for each model, only used in serve mode.""" + refs = [] + world_size = self.config.explorer.rollout_model.tensor_parallel_size + for model in self.models: + master_address, master_port = await model.get_available_address.remote() + self.logger.info( + f"Initialize process group for model weight synchronization, " + f"master_address={master_address}, master_port={master_port}, " + f"world_size={world_size}" + ) + refs.append( + model.init_process_group.remote( + master_address=master_address, + master_port=master_port, + rank_offset=0, + world_size=world_size, + group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME, + explorer_name=self.config.explorer.name, + timeout=self.config.synchronizer.sync_timeout, + ) + ) + await asyncio.gather(*refs) + async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int: step_num = await self.synchronizer.set_model_state_dict_with_step_num.remote(step_num) await asyncio.gather(*[model.sync_model.remote(step_num) for model in self.models]) @@ -156,8 +180,14 @@ async def prepare(self) -> None: self.logger.info("All rollout models are ready.") if not self.use_nccl_sync: - master_address, master_port = await self.models[0].get_available_address.remote() - await self.setup_weight_sync_group(master_address, master_port) + if self.config.mode == "serve": + # In serving mode, each engine will setup its own process group + await self.setup_model_level_weight_sync_group() + else: + master_address, master_port = await self.models[ + 0 + ].get_available_address.remote() + await self.setup_weight_sync_group(master_address, master_port) if self.config.mode != "serve": self.scheduler = Scheduler(self.config, self.models, self.auxiliary_models) await self.scheduler.start() @@ -213,9 +243,11 @@ async def explore_step(self) -> bool: await self.save_checkpoint(sync_weight=False) await self.synchronizer.set_explorer_status.remote( RunningStatus.STOPPED, - old_status=RunningStatus.RUNNING - if self.last_sync_successful - else RunningStatus.REQUIRE_SYNC, + old_status=( + RunningStatus.RUNNING + if self.last_sync_successful + else RunningStatus.REQUIRE_SYNC + ), ) await self.shutdown() return False @@ -445,7 +477,8 @@ async def serve(self) -> None: self.monitor.log(metrics, self.explore_step_num) # get the latest checkpoint _, step_num = get_latest_state_dict( - self.config.checkpoint_job_dir, self.config.trainer.trainer_type, + self.config.checkpoint_job_dir, + self.config.trainer.trainer_type, ) self.service.set_latest_model_version(step_num) diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py index 415ebda945..94d11ae7f8 100644 --- a/trinity/manager/synchronizer.py +++ b/trinity/manager/synchronizer.py @@ -77,7 +77,7 @@ async def _check_modules(self) -> None: async def _find_latest_state_dict(self) -> None: assert self.config.trainer.trainer_type == "verl" - default_local_dir = self.config.trainer.trainer_config.trainer.default_local_dir + default_local_dir = self.config.checkpoint_job_dir local_latest_state_dict_iteration = os.path.join( default_local_dir, "latest_state_dict_iteration.txt" ) From fa6604aeb03bca0de58287c978791f3a52fca247 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 15 Oct 2025 18:00:01 +0800 Subject: [PATCH 06/33] fix tests --- tests/template/config.yaml | 4 ++-- tests/trainer/trainer_test.py | 24 +++++++++++++----------- trinity/common/models/__init__.py | 3 +++ trinity/explorer/api/api.py | 4 ++-- trinity/explorer/api/service.py | 13 ++++++++++++- trinity/explorer/explorer_client.py | 2 +- 6 files changed, 33 insertions(+), 17 deletions(-) diff --git a/tests/template/config.yaml b/tests/template/config.yaml index 29d2806c03..1796d28284 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -16,8 +16,8 @@ model: max_response_tokens: 2048 max_model_len: 4096 cluster: # 2 for explorer, 2 for trainer - node_num: 2 - gpu_per_node: 2 + node_num: 1 + gpu_per_node: 4 buffer: total_epochs: 1 batch_size: 4 diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 0e2baf56d9..8f6cd8a33d 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -878,7 +878,8 @@ async def run_math_workflow(serve_url: str, task: dict): {"role": "user", "content": query}, ] - model = openai_client.models.list().data[0].id + models = await openai_client.models.list() + model = models.data[0].id response = await openai_client.chat.completions.create( model=model, @@ -901,6 +902,7 @@ async def test_serve_with_trainer(self): config.project = "unittest" config.name = f"serve_with_trainer_{datetime.now().strftime('%Y%m%d%H%M%S')}" config.checkpoint_root_dir = get_checkpoint_path() + config.model.model_path = get_model_path() config.buffer.batch_size = 4 config.algorithm.algorithm_type = "ppo" config.algorithm.repeat_times = 1 @@ -911,10 +913,11 @@ async def test_serve_with_trainer(self): config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") config.buffer.train_batch_size = 4 config.trainer.total_steps = 4 - config.trainer.save_interval = 2 + config.trainer.save_interval = 4 config.synchronizer.sync_interval = 2 config.synchronizer.sync_method = SyncMethod.CHECKPOINT config.explorer.rollout_model.engine_num = 2 + config.explorer.rollout_model.enable_openai_api = True config.explorer.rollout_model.tensor_parallel_size = 1 config.explorer.service_status_check_interval = 10 @@ -925,6 +928,13 @@ async def test_serve_with_trainer(self): trainer_process = multiprocessing.Process(target=run_trainer, args=(trainer_config,)) trainer_process.start() + await asyncio.sleep(10) + serve_config = deepcopy(config) + serve_config.mode = "serve" + serve_config.check_and_update() + serve_process = multiprocessing.Process(target=run_serve, args=(serve_config,)) + serve_process.start() + ray.init(ignore_reinit_error=True) while True: try: @@ -932,13 +942,7 @@ async def test_serve_with_trainer(self): break except ValueError: print("waiting for trainer to start.") - time.sleep(5) - - serve_config = deepcopy(config) - serve_config.mode = "serve" - serve_config.check_and_update() - serve_process = multiprocessing.Process(target=run_explorer, args=(serve_config,)) - serve_process.start() + await asyncio.sleep(5) state_manager = StateManager( path=serve_config.checkpoint_job_dir, @@ -979,8 +983,6 @@ async def test_serve_with_trainer(self): while time.time() - end_time < config.explorer.service_status_check_interval: await asyncio.sleep(1) - # check for trainer new checkpoint - class TestMultiModalGRPO(BaseTrainerCase): @unittest.skip("Require specific vllm/transformers version") diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index 9b12b48340..9a84f99b3f 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -84,6 +84,9 @@ def create_inference_models( allocator = _BundleAllocator(node_bundle_map) namespace = ray.get_runtime_context().namespace # create rollout models + # in 'serve' mode, we always enable openai api for rollout model + if config.mode == "serve": + config.explorer.rollout_model.enable_openai_api = True for i in range(config.explorer.rollout_model.engine_num): bundles_for_engine = allocator.allocate(config.explorer.rollout_model.tensor_parallel_size) config.explorer.rollout_model.bundle_indices = ",".join( diff --git a/trinity/explorer/api/api.py b/trinity/explorer/api/api.py index b84a3c0680..92a40f16d2 100644 --- a/trinity/explorer/api/api.py +++ b/trinity/explorer/api/api.py @@ -35,10 +35,10 @@ async def chat_completions(request: Request): async def show_available_models(request: Request): if hasattr(request.app.state, "models"): return JSONResponse(content=request.app.state.models) - body = await request.json() url = await request.app.state.service.allocate_model(increase_count=False) async with httpx.AsyncClient() as client: - resp = await client.get(f"{url}/v1/models", json=body) + print(f"Fetching models from {url}/v1/models") + resp = await client.get(f"{url}/v1/models") request.app.state.models = resp.json() return JSONResponse(content=resp.json()) diff --git a/trinity/explorer/api/service.py b/trinity/explorer/api/service.py index c910eb093e..3d348de248 100644 --- a/trinity/explorer/api/service.py +++ b/trinity/explorer/api/service.py @@ -108,6 +108,10 @@ async def allocate_model(self, increase_count: bool = True) -> str: if increase_count: model.request_count += 1 self.running_models.rotate(-1) + if model.api_address is None: + raise ValueError( + "Model does not have a valid API address, please set `enable_openai_api` to `True`." + ) return model.api_address def collect_metrics(self) -> Dict: @@ -136,7 +140,14 @@ async def record_experience(self, response, session_id: Optional[int] = None): torch.tensor(choice["token_ids"], dtype=torch.int32), ) ), - logprobs=choice.get("logprobs", None), + logprobs=( + torch.tensor( + [logprob["logprob"] for logprob in choice["logprobs"]["content"]], + dtype=torch.float32, + ) + if "logprobs" in choice and choice["logprobs"] is not None + else torch.tensor([], dtype=torch.float32) + ), prompt_length=len(response["prompt_token_ids"]), response_text=choice.get("message", {}).get("content", ""), ) diff --git a/trinity/explorer/explorer_client.py b/trinity/explorer/explorer_client.py index a7ca0f5e85..3d44d0dd67 100644 --- a/trinity/explorer/explorer_client.py +++ b/trinity/explorer/explorer_client.py @@ -13,7 +13,7 @@ def __init__(self, explorer_api_url: str): self.session_id = self.init_session() def init_session(self) -> str: - response = requests.post(f"{self.explorer_api_url}/allocate") + response = requests.get(f"{self.explorer_api_url}/allocate") data = response.json() return data["session_id"] From ef7278ed99016b6f1ae94a76369d826562d1ee49 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 15 Oct 2025 19:43:49 +0800 Subject: [PATCH 07/33] clean code --- trinity/common/models/vllm_worker.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index f1e5833a8f..93d9c0bd48 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -88,7 +88,3 @@ def update_weight(self): torch.distributed.barrier(group=self._model_update_group) torch.cuda.synchronize() torch.cuda.empty_cache() - - def update_weight_from_checkpoint(self): - """Update weight from checkpoint without broadcasting""" - pass From be94f13d096166def7cd09a6fb6bffb758f31458 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 16 Dec 2025 16:29:25 +0800 Subject: [PATCH 08/33] fix serve tests --- tests/trainer/trainer_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 4c5f524394..9ec7a937c4 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -959,7 +959,7 @@ async def test_serve_with_trainer(self): config.algorithm.repeat_times = 1 config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( name="exp_buffer", - storage_type=StorageType.SQL, + storage_type=StorageType.SQL.value(), schema_type="experience", ) config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") From ce01983d741ca495c126448ccbbf2706590021c1 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 16 Dec 2025 17:20:47 +0800 Subject: [PATCH 09/33] fix tests --- tests/trainer/trainer_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 9ec7a937c4..3ed10bcd96 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -959,7 +959,7 @@ async def test_serve_with_trainer(self): config.algorithm.repeat_times = 1 config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( name="exp_buffer", - storage_type=StorageType.SQL.value(), + storage_type=StorageType.SQL.value, schema_type="experience", ) config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") From 0552b8d2de10117c17006b8695be7afbb50bbc8f Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 17 Dec 2025 22:01:22 +0800 Subject: [PATCH 10/33] add recorder --- tests/explorer/explorer_test.py | 14 ++-- tests/trainer/trainer_test.py | 2 +- trinity/buffer/schema/sql_schema.py | 39 ++++++--- trinity/buffer/utils.py | 2 +- trinity/common/config.py | 7 +- trinity/explorer/explorer.py | 4 +- trinity/explorer/explorer_client.py | 51 ------------ trinity/explorer/{api => proxy}/__init__.py | 0 trinity/explorer/{api/api.py => proxy/app.py} | 42 ++++++---- trinity/explorer/proxy/client.py | 41 ++++++++++ trinity/explorer/proxy/recorder.py | 53 ++++++++++++ trinity/explorer/{api => proxy}/service.py | 81 +++++++------------ 12 files changed, 197 insertions(+), 139 deletions(-) delete mode 100644 trinity/explorer/explorer_client.py rename trinity/explorer/{api => proxy}/__init__.py (100%) rename trinity/explorer/{api/api.py => proxy/app.py} (71%) create mode 100644 trinity/explorer/proxy/client.py create mode 100644 trinity/explorer/proxy/recorder.py rename trinity/explorer/{api => proxy}/service.py (75%) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index c061099437..4f29bad501 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -8,7 +8,6 @@ from datetime import datetime import httpx -import openai import ray from tests.tools import ( @@ -26,6 +25,7 @@ from trinity.common.config import ExperienceBufferConfig, InferenceModelConfig from trinity.common.constants import StorageType from trinity.explorer.explorer import Explorer +from trinity.explorer.proxy.client import ProxyClient from trinity.manager.state_manager import StateManager @@ -157,8 +157,9 @@ def run_serve(config): run_stage(config) -def run_agent(base_url, model_path: str): - client = openai.Client(base_url=base_url, api_key="testkey") +def run_agent(proxy_url, model_path: str): + proxy_client = ProxyClient(proxy_url=proxy_url) + openai_client = proxy_client.get_openai_client() contents = [ "Hello, how are you?", "What is the capital of China?", @@ -171,10 +172,11 @@ def run_agent(base_url, model_path: str): "What is the best way to learn programming?", "Describe the process of photosynthesis.", ] - response = client.chat.completions.create( + response = openai_client.chat.completions.create( model=model_path, messages=[{"role": "user", "content": random.choice(contents)}], ) + proxy_client.feedback(reward=1.0, msg_ids=[response.id]) return response.choices[0].message.content @@ -190,7 +192,7 @@ def setUp(self): self.config.explorer.rollout_model.engine_num = 4 self.config.explorer.rollout_model.enable_openai_api = True self.config.checkpoint_root_dir = get_checkpoint_path() - self.config.explorer.api_port = 8010 + self.config.explorer.proxy_port = 8010 self.config.explorer.service_status_check_interval = 30 self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( name="experience_buffer", @@ -236,7 +238,7 @@ async def test_serve(self): # noqa: C901 apps = [] for i in range(task_num): app_process = multiprocessing.Process( - target=run_agent, args=(server_url + "/v1", self.config.model.model_path) + target=run_agent, args=(server_url, self.config.model.model_path) ) apps.append(app_process) app_process.start() diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 3ed10bcd96..4deef63bf2 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -1023,7 +1023,7 @@ async def test_serve_with_trainer(self): pass await asyncio.sleep(2) - reader = get_buffer_reader(serve_config.buffer.explorer_input.taskset, serve_config.buffer) + reader = get_buffer_reader(serve_config.buffer.explorer_input.taskset) for i in range(2): # generate data for 2 trainer steps diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index e0df0b1e8e..2329f74973 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -2,7 +2,17 @@ from typing import Dict, Optional, Tuple -from sqlalchemy import JSON, Column, Float, Integer, LargeBinary, Text, create_engine +from sqlalchemy import ( + JSON, + Column, + DateTime, + Float, + Integer, + LargeBinary, + String, + create_engine, + func, +) from sqlalchemy.exc import OperationalError from sqlalchemy.orm import declarative_base from sqlalchemy.pool import NullPool @@ -37,19 +47,23 @@ class ExperienceModel(Base): # type: ignore __abstract__ = True id = Column(Integer, primary_key=True, autoincrement=True) - # for single turn - prompt = Column(Text, nullable=True) - response = Column(Text, nullable=True) - # for multi turn - message_list = Column(JSON, nullable=True) - reward = Column(Float, nullable=True) + timestamp = Column(DateTime, server_default=func.now()) + task_id = Column(String(64), nullable=True, index=True) # associated task id + run_id = Column(Integer, nullable=True, index=True) # associated run id + msg_id = Column(String(64), nullable=True, index=True) # associated message id # serialized experience object experience_bytes = Column(LargeBinary, nullable=True) + reward = Column(Float, nullable=True) consumed = Column(Integer, default=0, index=True) def to_experience(self) -> Experience: """Load the experience from the database.""" - return Experience.deserialize(self.experience_bytes) + exp = Experience.deserialize(self.experience_bytes) + exp.eid.task = self.task_id + exp.eid.run = self.run_id + exp.eid.suffix = self.msg_id + exp.reward = self.reward + return exp @classmethod def from_experience(cls, experience: Experience): @@ -57,9 +71,9 @@ def from_experience(cls, experience: Experience): return cls( experience_bytes=experience.serialize(), reward=experience.reward, - prompt=experience.prompt_text, - response=experience.response_text, - message_list=experience.messages, + task_id=str(experience.eid.task), + run_id=experience.eid.run, + msg_id=str(experience.eid.suffix), ) @@ -111,7 +125,7 @@ def from_experience(cls, experience: Experience): ) -def init_engine(db_url: str, table_name, schema_type: Optional[str]) -> Tuple: +def init_engine(db_url: str, table_name: str, schema_type: Optional[str]) -> Tuple: """Get the sqlalchemy engine.""" logger = get_logger(__name__) engine = create_engine(db_url, poolclass=NullPool) @@ -130,6 +144,7 @@ def init_engine(db_url: str, table_name, schema_type: Optional[str]) -> Tuple: try: Base.metadata.create_all(engine, checkfirst=True) + logger.info(f"Created table {table_name} for schema type {schema_type}.") except OperationalError: logger.warning(f"Failed to create table {table_name}, assuming it already exists.") diff --git a/trinity/buffer/utils.py b/trinity/buffer/utils.py index b27ba662bc..689a8cce3d 100644 --- a/trinity/buffer/utils.py +++ b/trinity/buffer/utils.py @@ -5,7 +5,7 @@ @contextmanager -def retry_session(session_maker, max_retry_times: int, max_retry_interval: float): +def retry_session(session_maker, max_retry_times: int = 2, max_retry_interval: float = 1.0): """A Context manager for retrying session.""" logger = get_logger(__name__) for attempt in range(max_retry_times): diff --git a/trinity/common/config.py b/trinity/common/config.py index f30893a251..8b8cbd241d 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -669,14 +669,17 @@ class ExplorerConfig: # for benchmark bench_on_latest_checkpoint: bool = False # only benchmark the latest checkpoint - # for serve mode - api_port: int = 8010 + # for serve mode proxy + proxy_port: int = 8010 # listen on all interfaces by default listen_address: str = "0.0.0.0" # check the running status of the server every 60 seconds service_status_check_interval: int = 60 # keep at least 1 model in running status min_running_model_num: int = 1 + # db url for proxy history recorder, if not set, use proxy_history.db in buffer cache dir + db_url: Optional[str] = None + # Experimental feature over_rollout: OverRolloutConfig = field(default_factory=OverRolloutConfig) dynamic_timeout: DynamicTimeoutConfig = field(default_factory=DynamicTimeoutConfig) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 98b2251d67..583ebd90f5 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -494,12 +494,12 @@ async def serve(self) -> None: messages=[{"role": "user", "content": "Hello!"}] ) """ - from trinity.explorer.api.service import ExplorerService + from trinity.explorer.proxy.service import ExplorerService self.service = ExplorerService( self, listen_address=self.config.explorer.listen_address, - port=self.config.explorer.api_port, + port=self.config.explorer.proxy_port, ) await self.service.serve() self.server_url = f"http://{ray.util.get_node_ip_address()}:{self.service.port}" diff --git a/trinity/explorer/explorer_client.py b/trinity/explorer/explorer_client.py deleted file mode 100644 index 3d44d0dd67..0000000000 --- a/trinity/explorer/explorer_client.py +++ /dev/null @@ -1,51 +0,0 @@ -from functools import partial - -import httpx -import openai -import requests - - -class ExplorerClient: - def __init__(self, explorer_api_url: str): - self.explorer_api_url = explorer_api_url - self.openai_base_url = f"{self.explorer_api_url}/v1" - self.feedback_url = f"{self.explorer_api_url}/feedback" - self.session_id = self.init_session() - - def init_session(self) -> str: - response = requests.get(f"{self.explorer_api_url}/allocate") - data = response.json() - return data["session_id"] - - def get_openai_client(self) -> openai.OpenAI: - client = openai.OpenAI( - base_url=self.openai_base_url, - api_key="EMPTY", - ) - client.chat.completions.create = partial( - client.chat.completions.create, extra_body={"session_id": self.session_id} - ) - return client - - def get_openai_async_client(self) -> openai.AsyncOpenAI: - client = openai.AsyncOpenAI( - base_url=self.openai_base_url, - api_key="EMPTY", - ) - client.chat.completions.create = partial( - client.chat.completions.create, extra_body={"session_id": self.session_id} - ) - return client - - def feedback(self, reward: float) -> dict: - response = requests.post( - self.feedback_url, json={"session_id": self.session_id, "reward": reward} - ) - return response.json() - - async def feedback_async(self, reward: float) -> dict: - async with httpx.AsyncClient() as client: - response = await client.post( - self.feedback_url, json={"session_id": self.session_id, "reward": reward} - ) - return response.json() diff --git a/trinity/explorer/api/__init__.py b/trinity/explorer/proxy/__init__.py similarity index 100% rename from trinity/explorer/api/__init__.py rename to trinity/explorer/proxy/__init__.py diff --git a/trinity/explorer/api/api.py b/trinity/explorer/proxy/app.py similarity index 71% rename from trinity/explorer/api/api.py rename to trinity/explorer/proxy/app.py index c35e9d4067..e323040020 100644 --- a/trinity/explorer/api/api.py +++ b/trinity/explorer/proxy/app.py @@ -1,20 +1,34 @@ import traceback +from contextlib import asynccontextmanager import httpx import uvicorn from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response -app = FastAPI() +http_client: httpx.AsyncClient = None -# Forward openAI requests to a model instance +@asynccontextmanager +async def lifespan(app: FastAPI): + global http_client + http_client = httpx.AsyncClient( + timeout=httpx.Timeout(300.0, connect=10.0), + limits=httpx.Limits(max_keepalive_connections=20, max_connections=100), + ) + yield + await http_client.aclose() + +app = FastAPI(lifespan=lifespan) + +# Forward OpenAI requests to a model instance @app.post("/v1/chat/completions") async def chat_completions(request: Request): # Currently, we do not support streaming chat completions body = await request.json() + if "return_token_ids" not in body: body["return_token_ids"] = True url = await request.app.state.service.allocate_model() @@ -27,9 +41,7 @@ async def chat_completions(request: Request): content=f"Error forwarding request to model at {url}: {traceback.format_exc()}", ) resp_data = resp.json() - await request.app.state.service.record_experience( - resp_data, session_id=body.get("session_id", None) - ) + await request.app.state.service.record_experience(resp_data) return JSONResponse(content=resp_data) @@ -69,17 +81,19 @@ async def allocate(request: Request): async def feedback(request: Request): """Receive feedback for the current session.""" body = await request.json() - session_id = body.get("session_id", None) - reward = body.get("reward", None) - if session_id is None or reward is None: - return JSONResponse( - status_code=400, content={"error": "session_id and reward are required"} - ) - if not isinstance(session_id, int) or not isinstance(reward, (int, float)): + reward = body.get("reward") + msg_ids = body.get("msg_ids") + task_id = body.get("task_id") + run_id = body.get("run_id", 0) + if msg_ids is None or reward is None: + return JSONResponse(status_code=400, content={"error": "msg_ids and reward are required"}) + if not isinstance(msg_ids, list) or not isinstance(reward, (int, float)): return JSONResponse( - status_code=400, content={"error": "session_id must be int and reward must be float"} + status_code=400, content={"error": "msg_ids must be a list and reward must be a number"} ) - await request.app.state.service.record_feedback(session_id, reward) + await request.app.state.service.record_feedback( + reward=reward, msg_ids=msg_ids, task_id=task_id, run_id=run_id + ) return JSONResponse(content={"status": "success"}) diff --git a/trinity/explorer/proxy/client.py b/trinity/explorer/proxy/client.py new file mode 100644 index 0000000000..43e8de0cab --- /dev/null +++ b/trinity/explorer/proxy/client.py @@ -0,0 +1,41 @@ +import uuid + +import httpx +import openai +import requests + + +class ProxyClient: + def __init__(self, proxy_url: str): + self.proxy_url = proxy_url + self.openai_base_url = f"{self.proxy_url}/v1" + self.feedback_url = f"{self.proxy_url}/feedback" + self.task_id = uuid.uuid4().hex[:6] + + def get_openai_client(self) -> openai.OpenAI: + client = openai.OpenAI( + base_url=self.openai_base_url, + api_key="EMPTY", + ) + return client + + def get_openai_async_client(self) -> openai.AsyncOpenAI: + client = openai.AsyncOpenAI( + base_url=self.openai_base_url, + api_key="EMPTY", + ) + return client + + def feedback(self, reward: float, msg_ids: list[str]) -> dict: + response = requests.post( + self.feedback_url, json={"reward": reward, "msg_ids": msg_ids, "task_id": self.task_id} + ) + return response.json() + + async def feedback_async(self, reward: float, msg_ids: list[str]) -> dict: + async with httpx.AsyncClient() as client: + response = await client.post( + self.feedback_url, + json={"reward": reward, "msg_ids": msg_ids, "task_id": self.task_id}, + ) + return response.json() diff --git a/trinity/explorer/proxy/recorder.py b/trinity/explorer/proxy/recorder.py new file mode 100644 index 0000000000..713f41b6ef --- /dev/null +++ b/trinity/explorer/proxy/recorder.py @@ -0,0 +1,53 @@ +from typing import List + +from sqlalchemy.orm import sessionmaker + +from trinity.buffer.schema import init_engine +from trinity.buffer.utils import retry_session +from trinity.common.experience import Experience +from trinity.utils.log import get_logger + + +class HistoryRecorder: + """Record chat history into the database.""" + + def __init__(self, db_url: str, table_name: str): + self.logger = get_logger() + self.engine, self.table_model_cls = init_engine( + db_url=db_url, + table_name=table_name, + schema_type="experience", + ) + self.logger.info(f"Init SQL storage at {db_url}") + self.session = sessionmaker(bind=self.engine) + + def record_history(self, experiences: List[Experience]) -> None: + """Save experience to the database.""" + with retry_session(self.session) as db: + exps = [self.table_model_cls.from_experience(exp) for exp in experiences] + db.add_all(exps) + + def update_reward( + self, reward: float, msg_ids: list, run_id: int, task_id: str + ) -> List[Experience]: + """Update reward for given response IDs and return the updated experiences.""" + with retry_session(self.session) as db: + db.execute( + self.table_model_cls.__table__.update() + .where((self.table_model_cls.msg_id.in_(msg_ids))) + .values( + reward=reward, + run_id=run_id, + task_id=task_id, + consumed=self.table_model_cls.consumed + 1, + ) + ) + + with retry_session(self.session) as db: + results = db.execute( + self.table_model_cls.__table__.select().where( + (self.table_model_cls.msg_id.in_(msg_ids) & self.table_model_cls.consumed == 1) + ) + ).all() + updated_experiences = [self.table_model_cls.to_experience(row) for row in results] + return updated_experiences diff --git a/trinity/explorer/api/service.py b/trinity/explorer/proxy/service.py similarity index 75% rename from trinity/explorer/api/service.py rename to trinity/explorer/proxy/service.py index 3d348de248..b695ef588c 100644 --- a/trinity/explorer/api/service.py +++ b/trinity/explorer/proxy/service.py @@ -1,7 +1,7 @@ import asyncio import time from collections import deque -from typing import Dict, List, Optional +from typing import Dict, List import torch @@ -9,6 +9,7 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.explorer.explorer import Explorer +from trinity.explorer.proxy.recorder import HistoryRecorder from trinity.utils.log import get_logger @@ -29,14 +30,19 @@ def __init__(self, explorer: Explorer, listen_address: str = "localhost", port: self.running_models: deque[int] = deque() # indices of running models self.sync_task_map: Dict[asyncio.Future, int] = {} # sync task -> model index self.latest_model_version = 0 - self.experience_queue: deque[Experience] = deque() self.session_level_experience_queue: Dict[int, deque[Experience]] = {} self.queue_lock = asyncio.Lock() - self.experience_count = 0 - self.session_count = 0 + self.ready_experiences = deque() + self.recorder = HistoryRecorder( + db_url=explorer.config.explorer.db_url + or f"sqlite:///{explorer.config.buffer.cache_dir}/proxy_history.db", + table_name="proxy_history", + ) + self.total_experience_count = 0 + self.ready_experience_count = 0 async def serve(self) -> None: - from trinity.explorer.api.api import run_app + from trinity.explorer.proxy.app import run_app if self.running: self.logger.warning("Server is already running.") @@ -119,18 +125,11 @@ def collect_metrics(self) -> Dict: for i, model in enumerate(self.models): metrics[f"rollout/model_{i}/total_request_count"] = model.request_count metrics[f"rollout/model_{i}/model_version"] = model.model_version - metrics["rollout/total_experience_count"] = self.experience_count + metrics["rollout/total_experience_count"] = self.total_experience_count + metrics["rollout/ready_experience_count"] = self.ready_experience_count return metrics - async def check_requiring_sync_models(self): - if not self.running: - self.logger.warning("Server is not running.") - return - await asyncio.gather( - *[self._sync_model_weights(idx) for idx in list(self.requiring_sync_models)] - ) - - async def record_experience(self, response, session_id: Optional[int] = None): + async def record_experience(self, response): experiences = [] for choice in response["choices"]: exp = Experience( @@ -149,45 +148,27 @@ async def record_experience(self, response, session_id: Optional[int] = None): else torch.tensor([], dtype=torch.float32) ), prompt_length=len(response["prompt_token_ids"]), - response_text=choice.get("message", {}).get("content", ""), ) - if session_id is not None: - exp.eid.task = session_id + exp.eid.suffix = response["id"] experiences.append(exp) - self.experience_count += len(experiences) - - # Store experiences in session-level queue if session_id is provided - if session_id is not None: - async with self.queue_lock: - if session_id not in self.session_level_experience_queue: - self.session_level_experience_queue[session_id] = deque() - self.session_level_experience_queue[session_id].extend(experiences) - else: - async with self.queue_lock: - self.experience_queue.extend(experiences) + + self.total_experience_count += len(experiences) + self.recorder.record_history(experiences) async def get_all_experiences(self) -> List: - async with self.queue_lock: - experiences = list(self.experience_queue) - self.experience_queue.clear() - return experiences - - def allocate_session(self) -> int: - self.session_count += 1 - return self.session_count - - async def record_feedback(self, session_id: int, reward: float): - exps = [] - async with self.queue_lock: - if session_id in self.session_level_experience_queue: - exps = list(self.session_level_experience_queue.pop(session_id)) - if not exps: - self.logger.warning(f"No experiences found for session_id {session_id}.") - return - for exp in exps: - exp.reward = reward - async with self.queue_lock: - self.experience_queue.extend(exps) + experiences = list(self.ready_experiences) + self.ready_experiences.clear() + return experiences + + async def record_feedback(self, reward: float, msg_ids: List[str], task_id: str, run_id: int): + exps = self.recorder.update_reward( + reward=reward, + msg_ids=msg_ids, + task_id=task_id, + run_id=run_id, + ) + self.ready_experience_count += len(exps) + self.ready_experiences.extend(exps) async def shutdown(self): if not self.running: From 27eee072d01bf1a0e8f477b61847e61569b9df43 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 18 Dec 2025 10:21:55 +0800 Subject: [PATCH 11/33] fix proxy --- tests/explorer/explorer_test.py | 5 +++- trinity/explorer/proxy/app.py | 25 +++++++++++++++----- trinity/explorer/proxy/service.py | 39 ++++++++++++++++++++----------- 3 files changed, 49 insertions(+), 20 deletions(-) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 4f29bad501..64b461bc62 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -176,7 +176,7 @@ def run_agent(proxy_url, model_path: str): model=model_path, messages=[{"role": "user", "content": random.choice(contents)}], ) - proxy_client.feedback(reward=1.0, msg_ids=[response.id]) + proxy_client.feedback(reward=2.0, msg_ids=[response.id]) return response.choices[0].message.content @@ -277,6 +277,9 @@ async def test_serve(self): # noqa: C901 exps = await buffer_reader.read_async(batch_size=10) for exp in exps: self.assertTrue(len(exp.tokens) > 0) + self.assertTrue(len(exp.logprobs) > 0) + self.assertTrue(exp.prompt_length > 0) + self.assertTrue(exp.reward == 2.0) self.assertEqual(len(exps), task_num) def tearDown(self): diff --git a/trinity/explorer/proxy/app.py b/trinity/explorer/proxy/app.py index e323040020..7fc815d732 100644 --- a/trinity/explorer/proxy/app.py +++ b/trinity/explorer/proxy/app.py @@ -3,7 +3,7 @@ import httpx import uvicorn -from fastapi import FastAPI, Request +from fastapi import FastAPI, HTTPException, Request from fastapi.responses import JSONResponse, Response http_client: httpx.AsyncClient = None @@ -27,14 +27,27 @@ async def lifespan(app: FastAPI): @app.post("/v1/chat/completions") async def chat_completions(request: Request): # Currently, we do not support streaming chat completions - body = await request.json() - - if "return_token_ids" not in body: - body["return_token_ids"] = True + try: + request_data = await request.json() + except Exception as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") + + forward_headers = { + key: value + for key, value in request.headers.items() + if key.lower() not in ["host", "content-length", "transfer-encoding"] + } + + if "return_token_ids" not in request_data: + request_data["return_token_ids"] = True + if "logprobs" not in request_data: + request_data["logprobs"] = True url = await request.app.state.service.allocate_model() try: async with httpx.AsyncClient(timeout=request.app.state.inference_timeout) as client: - resp = await client.post(f"{url}/v1/chat/completions", json=body) + resp = await client.post( + f"{url}/v1/chat/completions", json=request_data, headers=forward_headers + ) except Exception: return Response( status_code=500, diff --git a/trinity/explorer/proxy/service.py b/trinity/explorer/proxy/service.py index b695ef588c..ac7294a0dd 100644 --- a/trinity/explorer/proxy/service.py +++ b/trinity/explorer/proxy/service.py @@ -1,7 +1,7 @@ import asyncio import time from collections import deque -from typing import Dict, List +from typing import Dict, List, Tuple import torch @@ -27,7 +27,8 @@ def __init__(self, explorer: Explorer, listen_address: str = "localhost", port: self.min_running_model_num = explorer.config.explorer.min_running_model_num self.check_interval = explorer.config.explorer.service_status_check_interval self.max_timeout = explorer.config.explorer.max_timeout - self.running_models: deque[int] = deque() # indices of running models + self.running_model_ids: deque[int] = deque() # indices of running models + self.model_version_map: Dict[int, int] = {} # model index -> model version self.sync_task_map: Dict[asyncio.Future, int] = {} # sync task -> model index self.latest_model_version = 0 self.session_level_experience_queue: Dict[int, deque[Experience]] = {} @@ -52,7 +53,7 @@ async def serve(self) -> None: await asyncio.gather(*[model.prepare() for model in self.models]) for i, _ in enumerate(self.models): - self.running_models.append(i) + self.running_model_ids.append(i) self.serve_task = asyncio.create_task( run_app(service=self, listen_address=self.listen_address, port=self.port) @@ -62,12 +63,14 @@ async def serve(self) -> None: async def model_weights_sync_loop(self) -> None: self.logger.info("Starting model weights synchronization loop.") while self.running: - for idx in list(self.running_models): + for idx in list(self.running_model_ids): + self.model_version_map[idx] = await self.models[idx].model_version_async if ( - len(self.running_models) > self.explorer.config.explorer.min_running_model_num - and self.models[idx].model_version < self.latest_model_version + len(self.running_model_ids) + > self.explorer.config.explorer.min_running_model_num + and self.model_version_map[idx] < self.latest_model_version ): - self.running_models.remove(idx) + self.running_model_ids.remove(idx) self.models[idx].status = RunningStatus.REQUIRE_SYNC self.logger.info(f"Model {idx} scheduled for synchronization.") future = asyncio.create_task(self._wait_for_sync_start(idx)) @@ -83,6 +86,7 @@ def set_latest_model_version(self, version: int) -> None: self.logger.info(f"Updated latest model version to {version}.") async def _wait_for_sync_start(self, index: int) -> None: + """Wait until the model is free to start synchronization.""" start_time = time.time() while time.time() - start_time < self.max_timeout: current_load = await self.models[index].get_current_load() @@ -97,6 +101,7 @@ async def _wait_for_sync_start(self, index: int) -> None: ) async def _sync_model_weights(self, task: asyncio.Future) -> None: + """A callback to synchronize model weights after waiting.""" index = self.sync_task_map.pop(task) latest_version = self.latest_model_version # capture the latest version if task.cancelled(): @@ -106,19 +111,26 @@ async def _sync_model_weights(self, task: asyncio.Future) -> None: else: await self.models[index].sync_model_weights(latest_version) self.logger.info(f"Model {index} synchronized to version {latest_version}.") - self.running_models.append(index) + self.model_version_map[index] = await self.models[index].model_version_async + self.running_model_ids.append(index) self.models[index].status = RunningStatus.RUNNING - async def allocate_model(self, increase_count: bool = True) -> str: - model = self.models[self.running_models[0]] + async def allocate_model(self, increase_count: bool = True) -> Tuple[str, int]: + """Allocate a model for handling a request. + + Returns: + A tuple of (model_api_address, model_version). + """ + model_id = self.running_model_ids[0] + model = self.models[model_id] if increase_count: model.request_count += 1 - self.running_models.rotate(-1) + self.running_model_ids.rotate(-1) if model.api_address is None: raise ValueError( "Model does not have a valid API address, please set `enable_openai_api` to `True`." ) - return model.api_address + return model.api_address, self.model_version_map[model_id] def collect_metrics(self) -> Dict: metrics = {} @@ -129,7 +141,7 @@ def collect_metrics(self) -> Dict: metrics["rollout/ready_experience_count"] = self.ready_experience_count return metrics - async def record_experience(self, response): + async def record_experience(self, response, model_version: int) -> None: experiences = [] for choice in response["choices"]: exp = Experience( @@ -150,6 +162,7 @@ async def record_experience(self, response): prompt_length=len(response["prompt_token_ids"]), ) exp.eid.suffix = response["id"] + exp.info["model_version"] = model_version experiences.append(exp) self.total_experience_count += len(experiences) From 7fa496c1321502be5ce0958b505e5c30eca74dba Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 18 Dec 2025 10:28:22 +0800 Subject: [PATCH 12/33] fix server --- trinity/explorer/proxy/app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trinity/explorer/proxy/app.py b/trinity/explorer/proxy/app.py index 7fc815d732..b9f9536a17 100644 --- a/trinity/explorer/proxy/app.py +++ b/trinity/explorer/proxy/app.py @@ -42,7 +42,7 @@ async def chat_completions(request: Request): request_data["return_token_ids"] = True if "logprobs" not in request_data: request_data["logprobs"] = True - url = await request.app.state.service.allocate_model() + url, model_version = await request.app.state.service.allocate_model() try: async with httpx.AsyncClient(timeout=request.app.state.inference_timeout) as client: resp = await client.post( @@ -54,7 +54,7 @@ async def chat_completions(request: Request): content=f"Error forwarding request to model at {url}: {traceback.format_exc()}", ) resp_data = resp.json() - await request.app.state.service.record_experience(resp_data) + await request.app.state.service.record_experience(resp_data, model_version) return JSONResponse(content=resp_data) From 02ce2dc66133cb15c143a4cccacba5075dc921d8 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 18 Dec 2025 11:44:59 +0800 Subject: [PATCH 13/33] fix serve trainer tset --- tests/explorer/explorer_test.py | 28 +++++++++++++--------------- tests/trainer/trainer_test.py | 20 ++++++++++++++------ trinity/explorer/explorer.py | 6 ------ trinity/explorer/proxy/app.py | 9 ++++++++- trinity/explorer/proxy/client.py | 20 ++++++++++++++++++++ trinity/explorer/proxy/service.py | 14 +++++++++----- 6 files changed, 64 insertions(+), 33 deletions(-) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 64b461bc62..e037b1ceb4 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -248,22 +248,20 @@ async def test_serve(self): # noqa: C901 self.assertFalse(app.is_alive()) finish_step = None - + proxy_client = ProxyClient(proxy_url=server_url) for i in range(20): - async with httpx.AsyncClient() as client: - response = await client.get(f"{server_url}/metrics") - self.assertEqual(response.status_code, 200) - metrics = response.json() - metrics_keys = list(metrics.keys()) - self.assertIn("explore_step_num", metrics_keys) - self.assertIn("rollout/total_experience_count", metrics_keys) - self.assertIn("rollout/model_0/total_request_count", metrics_keys) - self.assertIn("rollout/model_3/model_version", metrics_keys) - if not finish_step and metrics["rollout/total_experience_count"] == task_num: - finish_step = metrics["explore_step_num"] - if finish_step and metrics["explore_step_num"] >= finish_step + 1: - # wait for one more step to ensure all data are written to buffer - break + metrics = await proxy_client.get_metrics_async() + metrics_keys = list(metrics.keys()) + self.assertIn("explore_step_num", metrics_keys) + self.assertIn("rollout/total_experience_count", metrics_keys) + self.assertIn("rollout/model_0/total_request_count", metrics_keys) + self.assertIn("rollout/model_3/model_version", metrics_keys) + if not finish_step and metrics["rollout/total_experience_count"] == task_num: + finish_step = metrics["explore_step_num"] + await proxy_client.commit_async() + if finish_step and metrics["explore_step_num"] >= finish_step + 1: + # wait for one more step to ensure all data are written to buffer + break await asyncio.sleep(3) serve_process.terminate() diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 4deef63bf2..a42a5c9c7f 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -45,7 +45,7 @@ SyncStyle, ) from trinity.common.models.utils import get_checkpoint_dir_with_step_num -from trinity.explorer.explorer_client import ExplorerClient +from trinity.explorer.proxy.client import ProxyClient from trinity.manager.state_manager import StateManager @@ -912,8 +912,8 @@ def tearDown(self): async def run_math_workflow(serve_url: str, task: dict): from trinity.common.rewards.math_reward import MathRewardFn - explorer_client = ExplorerClient(serve_url) - openai_client = explorer_client.get_openai_async_client() + proxy_client = ProxyClient(serve_url) + openai_client = proxy_client.get_openai_async_client() query = task["question"] truth = task["answer"] @@ -938,7 +938,7 @@ async def run_math_workflow(serve_url: str, task: dict): ) answer = response.choices[0].message.content reward = reward_fn(response=answer, truth=truth, prompt=query) - await explorer_client.feedback_async(sum(reward.values())) + await proxy_client.feedback_async(sum(reward.values()), [response.id]) class TestServeWithTrainer(unittest.IsolatedAsyncioTestCase): @@ -957,6 +957,8 @@ async def test_serve_with_trainer(self): config.buffer.batch_size = 4 config.algorithm.algorithm_type = "ppo" config.algorithm.repeat_times = 1 + config.cluster.gpu_per_node = 2 + config.cluster.node_num = 1 config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( name="exp_buffer", storage_type=StorageType.SQL.value, @@ -971,7 +973,7 @@ async def test_serve_with_trainer(self): config.explorer.rollout_model.engine_num = 2 config.explorer.rollout_model.enable_openai_api = True config.explorer.rollout_model.tensor_parallel_size = 1 - config.explorer.service_status_check_interval = 10 + config.explorer.service_status_check_interval = 5 trainer_config = deepcopy(config) trainer_config.mode = "train" @@ -1025,16 +1027,22 @@ async def test_serve_with_trainer(self): reader = get_buffer_reader(serve_config.buffer.explorer_input.taskset) + proxy_client = ProxyClient(server_url) for i in range(2): # generate data for 2 trainer steps tasks = reader.read(batch_size=8) await asyncio.gather(*(run_math_workflow(server_url, task.raw_task) for task in tasks)) - + await proxy_client.commit_async() # wait for synchronizer started end_time = time.time() while time.time() - end_time < config.explorer.service_status_check_interval: await asyncio.sleep(1) + serve_process.terminate() + trainer_process.terminate() + serve_process.join() + trainer_process.join() + class TestMultiModalGRPO(BaseTrainerCase): @unittest.skip("Require specific vllm/transformers version") diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 583ebd90f5..688a2cb29d 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -508,13 +508,7 @@ async def serve(self) -> None: ) self.state.save_explorer_server_url(self.server_url) while True: - self.explore_step_num += 1 await asyncio.sleep(self.config.explorer.service_status_check_interval) - # process experiences generated in the last interval - exps = await self.service.get_all_experiences() - metrics = await self.experience_pipeline.process.remote(exps) - metrics.update(self.service.collect_metrics()) - self.monitor.log(metrics, self.explore_step_num) # get the latest checkpoint _, step_num = get_latest_state_dict( self.config.checkpoint_job_dir, diff --git a/trinity/explorer/proxy/app.py b/trinity/explorer/proxy/app.py index b9f9536a17..d4a194d917 100644 --- a/trinity/explorer/proxy/app.py +++ b/trinity/explorer/proxy/app.py @@ -62,7 +62,7 @@ async def chat_completions(request: Request): async def show_available_models(request: Request): if hasattr(request.app.state, "models"): return JSONResponse(content=request.app.state.models) - url = await request.app.state.service.allocate_model(increase_count=False) + url, _ = await request.app.state.service.allocate_model(increase_count=False) async with httpx.AsyncClient() as client: print(f"Fetching models from {url}/v1/models") resp = await client.get(f"{url}/v1/models") @@ -110,6 +110,13 @@ async def feedback(request: Request): return JSONResponse(content={"status": "success"}) +@app.post("/commit") +async def commit(request: Request): + """Commit the current experiences.""" + await request.app.state.service.submit_experiences() + return JSONResponse(content={"status": "success"}) + + async def serve_http(app: FastAPI, host: str, port: int = None): config = uvicorn.Config(app, host=host, port=port) server = uvicorn.Server(config) diff --git a/trinity/explorer/proxy/client.py b/trinity/explorer/proxy/client.py index 43e8de0cab..e8a8b61860 100644 --- a/trinity/explorer/proxy/client.py +++ b/trinity/explorer/proxy/client.py @@ -39,3 +39,23 @@ async def feedback_async(self, reward: float, msg_ids: list[str]) -> dict: json={"reward": reward, "msg_ids": msg_ids, "task_id": self.task_id}, ) return response.json() + + def commit(self) -> dict: + response = requests.post(f"{self.proxy_url}/commit") + return response.json() + + async def commit_async(self) -> dict: + async with httpx.AsyncClient() as client: + response = await client.post(f"{self.proxy_url}/commit") + return response.json() + + def get_metrics(self) -> dict: + response = requests.get(f"{self.proxy_url}/metrics") + return response.json() + + async def get_metrics_async(self) -> dict: + async with httpx.AsyncClient() as client: + response = await client.get(f"{self.proxy_url}/metrics") + if response.status_code != 200: + raise ValueError(f"Failed to get metrics: {response.text}") + return response.json() diff --git a/trinity/explorer/proxy/service.py b/trinity/explorer/proxy/service.py index ac7294a0dd..e05e84d6f8 100644 --- a/trinity/explorer/proxy/service.py +++ b/trinity/explorer/proxy/service.py @@ -32,7 +32,7 @@ def __init__(self, explorer: Explorer, listen_address: str = "localhost", port: self.sync_task_map: Dict[asyncio.Future, int] = {} # sync task -> model index self.latest_model_version = 0 self.session_level_experience_queue: Dict[int, deque[Experience]] = {} - self.queue_lock = asyncio.Lock() + self.commit_lock = asyncio.Lock() self.ready_experiences = deque() self.recorder = HistoryRecorder( db_url=explorer.config.explorer.db_url @@ -168,10 +168,14 @@ async def record_experience(self, response, model_version: int) -> None: self.total_experience_count += len(experiences) self.recorder.record_history(experiences) - async def get_all_experiences(self) -> List: - experiences = list(self.ready_experiences) - self.ready_experiences.clear() - return experiences + async def submit_experiences(self) -> None: + async with self.commit_lock: + experiences = list(self.ready_experiences) + self.ready_experiences.clear() + metrics = await self.explorer.experience_pipeline.process.remote(experiences) + metrics.update(self.collect_metrics()) + self.explorer.explore_step_num += 1 + self.explorer.monitor.log(metrics, self.explorer.explore_step_num) async def record_feedback(self, reward: float, msg_ids: List[str], task_id: str, run_id: int): exps = self.recorder.update_reward( From 5507e5f34519e5a635a939122b16b9b0c97b64c8 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 18 Dec 2025 15:06:43 +0800 Subject: [PATCH 14/33] add tests --- tests/trainer/trainer_test.py | 59 +++++++++++++++++------------- trinity/common/models/model.py | 2 +- trinity/explorer/proxy/client.py | 7 ++++ trinity/explorer/proxy/recorder.py | 24 +++++++++++- trinity/explorer/proxy/service.py | 39 +++++++++----------- 5 files changed, 81 insertions(+), 50 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index a42a5c9c7f..98059c38a2 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -11,7 +11,6 @@ from datetime import datetime from unittest import mock -import httpx import ray from parameterized import parameterized_class @@ -966,8 +965,8 @@ async def test_serve_with_trainer(self): ) config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") config.buffer.train_batch_size = 4 - config.trainer.total_steps = 4 - config.trainer.save_interval = 4 + config.buffer.total_steps = 6 + config.trainer.save_interval = 2 config.synchronizer.sync_interval = 2 config.synchronizer.sync_method = SyncMethod.CHECKPOINT config.explorer.rollout_model.engine_num = 2 @@ -1014,34 +1013,44 @@ async def test_serve_with_trainer(self): await asyncio.sleep(3) if not server_url: raise RuntimeError("Explorer server URL not found.") + proxy_client = ProxyClient(server_url) # wait for server setup for i in range(10): - try: - async with httpx.AsyncClient() as client: - response = await client.get(f"{server_url}/health") - if response.status_code == 200: - break - except Exception: - pass + if proxy_client.alive(): + break await asyncio.sleep(2) reader = get_buffer_reader(serve_config.buffer.explorer_input.taskset) - proxy_client = ProxyClient(server_url) - for i in range(2): - # generate data for 2 trainer steps - tasks = reader.read(batch_size=8) - await asyncio.gather(*(run_math_workflow(server_url, task.raw_task) for task in tasks)) - await proxy_client.commit_async() - # wait for synchronizer started - end_time = time.time() - while time.time() - end_time < config.explorer.service_status_check_interval: - await asyncio.sleep(1) - - serve_process.terminate() - trainer_process.terminate() - serve_process.join() - trainer_process.join() + try: + for i in range(3): + # generate data for 2 trainer steps + tasks = reader.read(batch_size=4) + await asyncio.gather( + *(run_math_workflow(server_url, task.raw_task) for task in tasks) + ) + await proxy_client.commit_async() + # wait for synchronizer started + end_time = time.time() + find_checkpoint = False + while time.time() - end_time < 100: + checkpoint_step_dir, step_num = get_checkpoint_dir_with_step_num( + checkpoint_root_path=serve_config.checkpoint_job_dir, + raise_error=False, + ) + if step_num >= 2 * (i + 1): # checkpoint has been generated + find_checkpoint = True + print(f"Found checkpoint at step {step_num}.") + break + await asyncio.sleep(1) + self.assertTrue( + find_checkpoint, f"Checkpoint at step {2 * (i + 1)} not found in time." + ) + finally: + serve_process.terminate() + trainer_process.terminate() + serve_process.join() + trainer_process.join() class TestMultiModalGRPO(BaseTrainerCase): diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 110db0a3b7..3fe6f2bf37 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -379,7 +379,7 @@ async def get_current_load(self) -> int: raise ValueError( "API server is not enabled for this model. Load metrics is unavailable." ) - with httpx.AsyncClient() as client: + async with httpx.AsyncClient() as client: response = await client.get(f"{self.api_address}/load") data = response.json() return data["server_load"] diff --git a/trinity/explorer/proxy/client.py b/trinity/explorer/proxy/client.py index e8a8b61860..9da7870fe6 100644 --- a/trinity/explorer/proxy/client.py +++ b/trinity/explorer/proxy/client.py @@ -12,6 +12,13 @@ def __init__(self, proxy_url: str): self.feedback_url = f"{self.proxy_url}/feedback" self.task_id = uuid.uuid4().hex[:6] + def alive(self) -> bool: + try: + response = requests.get(f"{self.proxy_url}/health", timeout=2) + return response.status_code == 200 + except requests.RequestException: + return False + def get_openai_client(self) -> openai.OpenAI: client = openai.OpenAI( base_url=self.openai_base_url, diff --git a/trinity/explorer/proxy/recorder.py b/trinity/explorer/proxy/recorder.py index 713f41b6ef..0056a866f2 100644 --- a/trinity/explorer/proxy/recorder.py +++ b/trinity/explorer/proxy/recorder.py @@ -8,6 +8,7 @@ from trinity.utils.log import get_logger +# TODO: implement an async version in the future class HistoryRecorder: """Record chat history into the database.""" @@ -30,7 +31,23 @@ def record_history(self, experiences: List[Experience]) -> None: def update_reward( self, reward: float, msg_ids: list, run_id: int, task_id: str ) -> List[Experience]: - """Update reward for given response IDs and return the updated experiences.""" + """Update reward for given response IDs and return the updated experiences. + + Args: + reward (float): The reward value to be updated. + msg_ids (list): List of message IDs to update. + run_id (int): The run ID associated with the experiences. + task_id (str): The task ID associated with the experiences. + + Returns: + List[Experience]: List of updated experiences. + + Note: + Only experiences that have not been consumed (consumed == 0) will be returned. + For example, if you call this method multiple times with the same msg_ids, only + the first call will return the updated experiences; subsequent calls will return + an empty list. + """ with retry_session(self.session) as db: db.execute( self.table_model_cls.__table__.update() @@ -46,7 +63,10 @@ def update_reward( with retry_session(self.session) as db: results = db.execute( self.table_model_cls.__table__.select().where( - (self.table_model_cls.msg_id.in_(msg_ids) & self.table_model_cls.consumed == 1) + ( + self.table_model_cls.msg_id.in_(msg_ids) + & (self.table_model_cls.consumed == 1) + ) ) ).all() updated_experiences = [self.table_model_cls.to_experience(row) for row in results] diff --git a/trinity/explorer/proxy/service.py b/trinity/explorer/proxy/service.py index e05e84d6f8..e284e00882 100644 --- a/trinity/explorer/proxy/service.py +++ b/trinity/explorer/proxy/service.py @@ -70,12 +70,10 @@ async def model_weights_sync_loop(self) -> None: > self.explorer.config.explorer.min_running_model_num and self.model_version_map[idx] < self.latest_model_version ): - self.running_model_ids.remove(idx) - self.models[idx].status = RunningStatus.REQUIRE_SYNC self.logger.info(f"Model {idx} scheduled for synchronization.") - future = asyncio.create_task(self._wait_for_sync_start(idx)) - self.sync_task_map[future] = idx - future.add_done_callback(self._sync_model_weights) + self.models[idx].status = RunningStatus.REQUIRE_SYNC + self.running_model_ids.remove(idx) + asyncio.create_task(self._wait_for_sync_start(idx)) # wait half interval await asyncio.sleep(self.check_interval / 2) self.logger.info("Model weights synchronization loop stopped.") @@ -88,32 +86,29 @@ def set_latest_model_version(self, version: int) -> None: async def _wait_for_sync_start(self, index: int) -> None: """Wait until the model is free to start synchronization.""" start_time = time.time() + timeout_flag = True while time.time() - start_time < self.max_timeout: current_load = await self.models[index].get_current_load() if current_load == 0: self.models[index].status = RunningStatus.WAITING_SYNC self.logger.info(f"Model {index} begins synchronization.") - return + timeout_flag = False + break else: - await asyncio.sleep(2) - raise asyncio.TimeoutError( - f"Timeout waiting for model {index} to be free for synchronization. Current load: {current_load}" - ) - - async def _sync_model_weights(self, task: asyncio.Future) -> None: - """A callback to synchronize model weights after waiting.""" - index = self.sync_task_map.pop(task) + self.logger.info( + "Waiting for model %d to be free. Current load: %d", index, current_load + ) + await asyncio.sleep(1) + if timeout_flag: + raise asyncio.TimeoutError( + f"Timeout waiting for model {index} to be free for synchronization. Current load: {current_load}" + ) latest_version = self.latest_model_version # capture the latest version - if task.cancelled(): - self.logger.warning(f"Synchronization of model {index} was cancelled.") - elif task.exception(): - self.logger.error(f"Error during synchronization of model {index}: {task.exception()}") - else: - await self.models[index].sync_model_weights(latest_version) - self.logger.info(f"Model {index} synchronized to version {latest_version}.") + await self.models[index].sync_model_weights(latest_version) + self.logger.info(f"Model {index} synchronized to version {latest_version}.") self.model_version_map[index] = await self.models[index].model_version_async - self.running_model_ids.append(index) self.models[index].status = RunningStatus.RUNNING + self.running_model_ids.append(index) async def allocate_model(self, increase_count: bool = True) -> Tuple[str, int]: """Allocate a model for handling a request. From ed69403aad327ae93cfb61344b0ec4c86d72e4e8 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 18 Dec 2025 17:33:06 +0800 Subject: [PATCH 15/33] fix tests --- tests/explorer/proxy_test.py | 83 +++++++++++++++++++++++++++++++++++ tests/trainer/trainer_test.py | 54 +++++++++++++++++------ 2 files changed, 123 insertions(+), 14 deletions(-) create mode 100644 tests/explorer/proxy_test.py diff --git a/tests/explorer/proxy_test.py b/tests/explorer/proxy_test.py new file mode 100644 index 0000000000..18b3e96b18 --- /dev/null +++ b/tests/explorer/proxy_test.py @@ -0,0 +1,83 @@ +import os +import unittest +import uuid +from typing import List + +import torch + +from trinity.common.experience import EID, Experience +from trinity.explorer.proxy.recorder import HistoryRecorder + + +def get_dummy_experience(num: int) -> List[Experience]: + return [ + Experience( + eid=EID(suffix=uuid.uuid4().hex[:6]), + tokens=torch.zeros(5), + prompt_length=2, + ) + for _ in range(num) + ] + + +db_path = os.path.join(os.path.dirname(__file__), "test_recorder.db") + + +class RecorderTest(unittest.TestCase): + def setUp(self) -> None: + if os.path.exists(db_path): + os.remove(db_path) + + def tearDown(self) -> None: + if os.path.exists(db_path): + os.remove(db_path) + + def test_recorder(self): + recorder = HistoryRecorder( + # in memory sqlite for testing + db_url="sqlite:///" + db_path, + table_name="experience", + ) + self.assertIsInstance(recorder, HistoryRecorder) + # test record history + + experiences_1 = get_dummy_experience(3) + recorder.record_history(experiences_1) + # test update reward + msg_ids_1 = [exp.eid.suffix for exp in experiences_1] + experiences_2 = get_dummy_experience(2) + recorder.record_history(experiences_2) + updated_experiences = recorder.update_reward( + reward=1.0, msg_ids=msg_ids_1, run_id=1, task_id="test_task" + ) + self.assertEqual(len(updated_experiences), 3) + for exp in updated_experiences: + self.assertEqual(exp.reward, 1.0) + self.assertEqual(exp.eid.run, 1) + self.assertEqual(str(exp.eid.task), "test_task") + # test update reward with non-existing msg_ids + updated_experiences_empty = recorder.update_reward( + reward=2.0, msg_ids=["non_existing_id"], run_id=1, task_id="test_task" + ) + self.assertEqual(len(updated_experiences_empty), 0) + # test record history with empty experiences + recorder.record_history([]) # should not raise any exception + # test update reward multiple times + updated_experiences_2 = recorder.update_reward( + reward=3.0, + msg_ids=[exp.eid.suffix for exp in experiences_2], + run_id=2, + task_id="test_task_2", + ) + self.assertEqual(len(updated_experiences_2), 2) + for exp in updated_experiences_2: + self.assertEqual(exp.reward, 3.0) + self.assertEqual(exp.eid.run, 2) + self.assertEqual(str(exp.eid.task), "test_task_2") + updated_experiences_3 = recorder.update_reward( + reward=4.0, + msg_ids=[exp.eid.suffix for exp in experiences_2], + run_id=3, + task_id="test_task_3", + ) + self.assertEqual(len(updated_experiences_3), 0) # already consumed diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 98059c38a2..bbc215b8ce 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -947,15 +947,18 @@ def setUp(self): checkpoint_path = get_checkpoint_path() shutil.rmtree(os.path.join(checkpoint_path, "unittest"), ignore_errors=True) - async def test_serve_with_trainer(self): + async def test_serve_with_trainer(self): # noqa: C901 config = get_template_config() config.project = "unittest" config.name = f"serve_with_trainer_{datetime.now().strftime('%Y%m%d%H%M%S')}" config.checkpoint_root_dir = get_checkpoint_path() config.model.model_path = get_model_path() config.buffer.batch_size = 4 + config.buffer.train_batch_size = 4 config.algorithm.algorithm_type = "ppo" config.algorithm.repeat_times = 1 + config.algorithm.sample_strategy = "staleness_control" + config.algorithm.sample_strategy_args = {"max_staleness": 1} config.cluster.gpu_per_node = 2 config.cluster.node_num = 1 config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( @@ -964,10 +967,9 @@ async def test_serve_with_trainer(self): schema_type="experience", ) config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") - config.buffer.train_batch_size = 4 - config.buffer.total_steps = 6 - config.trainer.save_interval = 2 - config.synchronizer.sync_interval = 2 + config.buffer.total_steps = 3 + config.trainer.save_interval = 1 + config.synchronizer.sync_interval = 1 config.synchronizer.sync_method = SyncMethod.CHECKPOINT config.explorer.rollout_model.engine_num = 2 config.explorer.rollout_model.enable_openai_api = True @@ -981,7 +983,7 @@ async def test_serve_with_trainer(self): trainer_process = multiprocessing.Process(target=run_trainer, args=(trainer_config,)) trainer_process.start() - await asyncio.sleep(10) + await asyncio.sleep(5) serve_config = deepcopy(config) serve_config.mode = "serve" serve_config.check_and_update() @@ -1017,15 +1019,16 @@ async def test_serve_with_trainer(self): # wait for server setup for i in range(10): if proxy_client.alive(): + print("Proxy server is alive.") break await asyncio.sleep(2) - reader = get_buffer_reader(serve_config.buffer.explorer_input.taskset) + config.buffer.explorer_input.taskset.batch_size = 4 + reader = get_buffer_reader(config.buffer.explorer_input.taskset) try: for i in range(3): - # generate data for 2 trainer steps - tasks = reader.read(batch_size=4) + tasks = reader.read() await asyncio.gather( *(run_math_workflow(server_url, task.raw_task) for task in tasks) ) @@ -1034,18 +1037,41 @@ async def test_serve_with_trainer(self): end_time = time.time() find_checkpoint = False while time.time() - end_time < 100: - checkpoint_step_dir, step_num = get_checkpoint_dir_with_step_num( + _, step_num = get_checkpoint_dir_with_step_num( checkpoint_root_path=serve_config.checkpoint_job_dir, raise_error=False, ) - if step_num >= 2 * (i + 1): # checkpoint has been generated + if step_num >= i + 1: # checkpoint has been generated find_checkpoint = True - print(f"Found checkpoint at step {step_num}.") break await asyncio.sleep(1) - self.assertTrue( - find_checkpoint, f"Checkpoint at step {2 * (i + 1)} not found in time." + self.assertTrue(find_checkpoint, f"Checkpoint at step {i + 1} not found in time.") + metrics = await proxy_client.get_metrics_async() + self.assertTrue(metrics["rollout/total_experience_count"] == 4 * (i + 1)) + self.assertTrue(metrics["rollout/ready_experience_count"] == 4 * (i + 1)) + self.assertTrue(metrics["rollout/model_0/total_request_count"] > 0) + self.assertTrue(metrics["rollout/model_1/total_request_count"] > 0) + if i > 1: + self.assertTrue(metrics["rollout/model_0/model_version"] > 0) + self.assertTrue(metrics["rollout/model_1/model_version"] > 0) + metrics = await proxy_client.get_metrics_async() + self.assertTrue(metrics["rollout/total_experience_count"] == 12) + self.assertTrue(metrics["rollout/ready_experience_count"] == 12) + self.assertTrue( + abs( + metrics["rollout/model_0/total_request_count"] + - metrics["rollout/model_1/total_request_count"] ) + < 4 + ) # balanced requests + # at least updated to version 2 + self.assertTrue(metrics["rollout/model_0/model_version"] >= 2) + self.assertTrue(metrics["rollout/model_1/model_version"] >= 2) + # check final checkpoint + _, step_num = get_checkpoint_dir_with_step_num( + checkpoint_root_path=serve_config.checkpoint_job_dir, + step_num=3, + ) finally: serve_process.terminate() trainer_process.terminate() From 4050d1df6cedf53cfa25372cf8d7bf6e779abd4b Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 18 Dec 2025 17:55:15 +0800 Subject: [PATCH 16/33] fix model version --- tests/explorer/explorer_test.py | 1 - tests/explorer/proxy_test.py | 3 +++ tests/explorer/workflow_test.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 5dcd8fd44a..b368a38610 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -203,7 +203,6 @@ def setUp(self): if multiprocessing.get_start_method(allow_none=True) != "spawn": multiprocessing.set_start_method("spawn", force=True) - @unittest.skip("Require improvement for agent mode") async def test_serve(self): # noqa: C901 serve_process = multiprocessing.Process(target=run_serve, args=(self.config,)) serve_process.start() diff --git a/tests/explorer/proxy_test.py b/tests/explorer/proxy_test.py index 18b3e96b18..7ba6f29398 100644 --- a/tests/explorer/proxy_test.py +++ b/tests/explorer/proxy_test.py @@ -15,6 +15,9 @@ def get_dummy_experience(num: int) -> List[Experience]: eid=EID(suffix=uuid.uuid4().hex[:6]), tokens=torch.zeros(5), prompt_length=2, + info={ + "model_version": 0, + } ) for _ in range(num) ] diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 16bb0b7bbb..a28e71ca4d 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -553,7 +553,7 @@ async def test_adapter(self): try: from agentscope.model import TrinityChatModel except ImportError: - self.skipTest("agentscope >= 0.1.6 is not installed") + self.skipTest("agentscope >= 1.0.9 is not installed") async def as_workflow_func(task, model) -> float: self.assertIsInstance(task, dict) From 73466c4879ed8094d3141de00e6b6a740f42f663 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 18 Dec 2025 18:04:00 +0800 Subject: [PATCH 17/33] fix comments --- trinity/explorer/proxy/app.py | 6 ----- trinity/explorer/proxy/recorder.py | 40 ++++++++++++++++-------------- trinity/explorer/proxy/service.py | 9 ++++--- 3 files changed, 27 insertions(+), 28 deletions(-) diff --git a/trinity/explorer/proxy/app.py b/trinity/explorer/proxy/app.py index d4a194d917..c521c4e456 100644 --- a/trinity/explorer/proxy/app.py +++ b/trinity/explorer/proxy/app.py @@ -84,12 +84,6 @@ async def metrics(request: Request): return JSONResponse(content=metrics) -@app.get("/allocate") -async def allocate(request: Request): - """Allocate a new session.""" - return JSONResponse(content={"session_id": request.app.state.service.allocate_session()}) - - @app.post("/feedback") async def feedback(request: Request): """Receive feedback for the current session.""" diff --git a/trinity/explorer/proxy/recorder.py b/trinity/explorer/proxy/recorder.py index 0056a866f2..d5eaded27e 100644 --- a/trinity/explorer/proxy/recorder.py +++ b/trinity/explorer/proxy/recorder.py @@ -49,25 +49,27 @@ def update_reward( an empty list. """ with retry_session(self.session) as db: - db.execute( - self.table_model_cls.__table__.update() - .where((self.table_model_cls.msg_id.in_(msg_ids))) - .values( - reward=reward, - run_id=run_id, - task_id=task_id, - consumed=self.table_model_cls.consumed + 1, + # Lock and retrieve records that have not been consumed yet. + records = ( + db.query(self.table_model_cls) + .filter( + self.table_model_cls.msg_id.in_(msg_ids), + self.table_model_cls.consumed == 0, ) + .with_for_update() + .all() ) - with retry_session(self.session) as db: - results = db.execute( - self.table_model_cls.__table__.select().where( - ( - self.table_model_cls.msg_id.in_(msg_ids) - & (self.table_model_cls.consumed == 1) - ) - ) - ).all() - updated_experiences = [self.table_model_cls.to_experience(row) for row in results] - return updated_experiences + if not records: + return [] + + # Update records in memory + for record in records: + record.reward = reward + record.run_id = run_id + record.task_id = task_id + record.consumed += 1 + + # The session commit is handled by the `retry_session` context manager. + updated_experiences = [record.to_experience() for record in records] + return updated_experiences diff --git a/trinity/explorer/proxy/service.py b/trinity/explorer/proxy/service.py index e284e00882..868c733c45 100644 --- a/trinity/explorer/proxy/service.py +++ b/trinity/explorer/proxy/service.py @@ -73,7 +73,7 @@ async def model_weights_sync_loop(self) -> None: self.logger.info(f"Model {idx} scheduled for synchronization.") self.models[idx].status = RunningStatus.REQUIRE_SYNC self.running_model_ids.remove(idx) - asyncio.create_task(self._wait_for_sync_start(idx)) + asyncio.create_task(self._sync_model_weights(idx)) # wait half interval await asyncio.sleep(self.check_interval / 2) self.logger.info("Model weights synchronization loop stopped.") @@ -83,10 +83,12 @@ def set_latest_model_version(self, version: int) -> None: self.latest_model_version = version self.logger.info(f"Updated latest model version to {version}.") - async def _wait_for_sync_start(self, index: int) -> None: - """Wait until the model is free to start synchronization.""" + async def _sync_model_weights(self, index: int) -> None: + """Synchronize model weights for the given model index.""" + # wait until the model is free start_time = time.time() timeout_flag = True + current_load = -1 while time.time() - start_time < self.max_timeout: current_load = await self.models[index].get_current_load() if current_load == 0: @@ -104,6 +106,7 @@ async def _wait_for_sync_start(self, index: int) -> None: f"Timeout waiting for model {index} to be free for synchronization. Current load: {current_load}" ) latest_version = self.latest_model_version # capture the latest version + # perform synchronization await self.models[index].sync_model_weights(latest_version) self.logger.info(f"Model {index} synchronized to version {latest_version}.") self.model_version_map[index] = await self.models[index].model_version_async From c1cee2379ab1c9899821ca8efe2c0433592af87d Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 18 Dec 2025 18:07:15 +0800 Subject: [PATCH 18/33] fix comments --- tests/explorer/explorer_test.py | 1 - tests/explorer/proxy_test.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index b368a38610..e037b1ceb4 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -5,7 +5,6 @@ import os import random import shutil -import unittest from datetime import datetime import httpx diff --git a/tests/explorer/proxy_test.py b/tests/explorer/proxy_test.py index 7ba6f29398..e819c246aa 100644 --- a/tests/explorer/proxy_test.py +++ b/tests/explorer/proxy_test.py @@ -17,7 +17,7 @@ def get_dummy_experience(num: int) -> List[Experience]: prompt_length=2, info={ "model_version": 0, - } + }, ) for _ in range(num) ] From 5122366545658b925e107025c5783430912fdb61 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 18 Dec 2025 20:50:35 +0800 Subject: [PATCH 19/33] fix tests --- tests/buffer/experience_storage_test.py | 4 ++-- tests/explorer/explorer_test.py | 6 +++--- tests/trainer/trainer_test.py | 9 +++++---- trinity/explorer/explorer.py | 4 +++- trinity/explorer/proxy/app.py | 16 ++++++++-------- trinity/explorer/proxy/client.py | 2 +- 6 files changed, 22 insertions(+), 19 deletions(-) diff --git a/tests/buffer/experience_storage_test.py b/tests/buffer/experience_storage_test.py index f308e1ee10..c94978c418 100644 --- a/tests/buffer/experience_storage_test.py +++ b/tests/buffer/experience_storage_test.py @@ -118,7 +118,7 @@ async def test_sql_experience_buffer(self): exps = reader.read() self.assertEqual(len(exps), self.train_batch_size) for exp in exps: - self.assertEqual(exp.eid.task, cnt) + self.assertEqual(exp.eid.task, str(cnt)) cnt -= 1 # experience buffer support experience reuse @@ -127,7 +127,7 @@ async def test_sql_experience_buffer(self): exps = reader.read() self.assertEqual(len(exps), self.train_batch_size) for exp in exps: - self.assertEqual(exp.eid.task, cnt) + self.assertEqual(exp.eid.task, str(cnt)) cnt -= 1 self.assertEqual(await writer.release(), 0) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index e037b1ceb4..299bbf66b8 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -25,7 +25,7 @@ from trinity.common.config import ExperienceBufferConfig, InferenceModelConfig from trinity.common.constants import StorageType from trinity.explorer.explorer import Explorer -from trinity.explorer.proxy.client import ProxyClient +from trinity.explorer.proxy.client import TrinityClient from trinity.manager.state_manager import StateManager @@ -158,7 +158,7 @@ def run_serve(config): def run_agent(proxy_url, model_path: str): - proxy_client = ProxyClient(proxy_url=proxy_url) + proxy_client = TrinityClient(proxy_url=proxy_url) openai_client = proxy_client.get_openai_client() contents = [ "Hello, how are you?", @@ -248,7 +248,7 @@ async def test_serve(self): # noqa: C901 self.assertFalse(app.is_alive()) finish_step = None - proxy_client = ProxyClient(proxy_url=server_url) + proxy_client = TrinityClient(proxy_url=server_url) for i in range(20): metrics = await proxy_client.get_metrics_async() metrics_keys = list(metrics.keys()) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index bbc215b8ce..6a96c4df74 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -16,6 +16,7 @@ from tests.tools import ( RayUnittestBase, + RayUnittestBaseAysnc, TensorBoardParser, get_checkpoint_path, get_lora_config, @@ -44,7 +45,7 @@ SyncStyle, ) from trinity.common.models.utils import get_checkpoint_dir_with_step_num -from trinity.explorer.proxy.client import ProxyClient +from trinity.explorer.proxy.client import TrinityClient from trinity.manager.state_manager import StateManager @@ -911,7 +912,7 @@ def tearDown(self): async def run_math_workflow(serve_url: str, task: dict): from trinity.common.rewards.math_reward import MathRewardFn - proxy_client = ProxyClient(serve_url) + proxy_client = TrinityClient(serve_url) openai_client = proxy_client.get_openai_async_client() query = task["question"] @@ -940,7 +941,7 @@ async def run_math_workflow(serve_url: str, task: dict): await proxy_client.feedback_async(sum(reward.values()), [response.id]) -class TestServeWithTrainer(unittest.IsolatedAsyncioTestCase): +class TestServeWithTrainer(RayUnittestBaseAysnc): def setUp(self): if multiprocessing.get_start_method(allow_none=True) != "spawn": multiprocessing.set_start_method("spawn", force=True) @@ -1015,7 +1016,7 @@ async def test_serve_with_trainer(self): # noqa: C901 await asyncio.sleep(3) if not server_url: raise RuntimeError("Explorer server URL not found.") - proxy_client = ProxyClient(server_url) + proxy_client = TrinityClient(server_url) # wait for server setup for i in range(10): if proxy_client.alive(): diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 334ac28bad..0559d44a43 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -507,7 +507,9 @@ async def serve(self) -> None: await self.service.serve() self.server_url = f"http://{ray.util.get_node_ip_address()}:{self.service.port}" self.logger.info( - f"Explorer API Server is started on {self.server_url} and listening to {self.service.listen_address}." + "======================================================\n" + f"Starting Trinity Service on {self.server_url}\n" + "======================================================" ) self.state.save_explorer_server_url(self.server_url) while True: diff --git a/trinity/explorer/proxy/app.py b/trinity/explorer/proxy/app.py index c521c4e456..84cb05e058 100644 --- a/trinity/explorer/proxy/app.py +++ b/trinity/explorer/proxy/app.py @@ -37,11 +37,11 @@ async def chat_completions(request: Request): for key, value in request.headers.items() if key.lower() not in ["host", "content-length", "transfer-encoding"] } - - if "return_token_ids" not in request_data: - request_data["return_token_ids"] = True - if "logprobs" not in request_data: - request_data["logprobs"] = True + # for experience data recording, we need to return token ids and logprobs + request_data["return_token_ids"] = True + request_data["logprobs"] = True + # temperature must be set from config, ignore user's input + request_data["temperature"] = request.app.state.temperature url, model_version = await request.app.state.service.allocate_model() try: async with httpx.AsyncClient(timeout=request.app.state.inference_timeout) as client: @@ -111,14 +111,14 @@ async def commit(request: Request): return JSONResponse(content={"status": "success"}) -async def serve_http(app: FastAPI, host: str, port: int = None): +async def serve_http(app: FastAPI, host: str, port: int) -> None: config = uvicorn.Config(app, host=host, port=port) server = uvicorn.Server(config) await server.serve() -async def run_app(service, listen_address: str, port: int = None) -> FastAPI: +async def run_app(service, listen_address: str, port: int) -> None: app.state.service = service + app.state.temperature = service.explorer.config.model.temperature app.state.inference_timeout = service.explorer.config.synchronizer.sync_timeout - print(f"API server running on {listen_address}:{port}") await serve_http(app, listen_address, port) diff --git a/trinity/explorer/proxy/client.py b/trinity/explorer/proxy/client.py index 9da7870fe6..31177e6709 100644 --- a/trinity/explorer/proxy/client.py +++ b/trinity/explorer/proxy/client.py @@ -5,7 +5,7 @@ import requests -class ProxyClient: +class TrinityClient: def __init__(self, proxy_url: str): self.proxy_url = proxy_url self.openai_base_url = f"{self.proxy_url}/v1" From 4094490c1606567c14ccd2760748a95fd9861dba Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 19 Dec 2025 09:58:36 +0800 Subject: [PATCH 20/33] fix replay --- tests/buffer/sql_test.py | 29 ++++++++++++++++++++++++++--- tests/trainer/trainer_test.py | 2 -- trinity/buffer/storage/sql.py | 3 +++ 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 1e742a54bc..b2cdd5712d 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -2,12 +2,17 @@ import ray import torch +from parameterized import parameterized from tests.tools import RayUnittestBaseAysnc from trinity.buffer import get_buffer_reader from trinity.buffer.reader.sql_reader import SQLReader from trinity.buffer.writer.sql_writer import SQLWriter -from trinity.common.config import ExperienceBufferConfig, TasksetConfig +from trinity.common.config import ( + ExperienceBufferConfig, + ReplayBufferConfig, + TasksetConfig, +) from trinity.common.constants import StorageType from trinity.common.experience import Experience @@ -15,7 +20,13 @@ class TestSQLBuffer(RayUnittestBaseAysnc): - async def test_sql_exp_buffer_read_write(self) -> None: + @parameterized.expand( + [ + (True,), + (False,), + ] + ) + async def test_sql_exp_buffer_read_write(self, enable_replay: bool) -> None: total_num = 8 put_batch_size = 2 read_batch_size = 4 @@ -25,7 +36,10 @@ async def test_sql_exp_buffer_read_write(self) -> None: path=f"sqlite:///{db_path}", storage_type=StorageType.SQL.value, batch_size=read_batch_size, + max_read_timeout=3, ) + if enable_replay: + config.replay_buffer = ReplayBufferConfig(enable=True) sql_writer = SQLWriter(config.to_storage_config()) sql_reader = SQLReader(config.to_storage_config()) exps = [ @@ -53,13 +67,22 @@ async def test_sql_exp_buffer_read_write(self) -> None: reward=float(i), logprobs=torch.tensor([0.1]), action_mask=torch.tensor([j % 2 for j in range(i + 1)]), - info={"model_version": i}, + info={"model_version": i + put_batch_size}, ) for i in range(1, put_batch_size * 2 + 1) ] ) exps = sql_reader.read(batch_size=put_batch_size * 2) self.assertEqual(len(exps), put_batch_size * 2) + for exp in exps: + self.assertTrue(exp.info["model_version"] > put_batch_size) + if enable_replay: + # support replay, so we can read all again + exps = sql_reader.read(batch_size=(put_batch_size * 2 + total_num)) + self.assertEqual(len(exps), (put_batch_size * 2 + total_num)) + # if read more than available, will wait until timeout + with self.assertRaises(StopIteration): + exps = sql_reader.read(batch_size=(put_batch_size * 3 + total_num)) db_wrapper = ray.get_actor("sql-test_buffer") self.assertIsNotNone(db_wrapper) self.assertEqual(await sql_writer.release(), 0) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 6a96c4df74..68ff350680 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -958,8 +958,6 @@ async def test_serve_with_trainer(self): # noqa: C901 config.buffer.train_batch_size = 4 config.algorithm.algorithm_type = "ppo" config.algorithm.repeat_times = 1 - config.algorithm.sample_strategy = "staleness_control" - config.algorithm.sample_strategy_args = {"max_staleness": 1} config.cluster.gpu_per_node = 2 config.cluster.node_num = 1 config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py index 08ff06fb8c..04f0c20bda 100644 --- a/trinity/buffer/storage/sql.py +++ b/trinity/buffer/storage/sql.py @@ -97,6 +97,7 @@ def __init__(self, config: StorageConfig) -> None: super().__init__(config) self.max_timeout = config.max_read_timeout self.batch_size = config.batch_size + self.enable_replay = config.replay_buffer is not None and config.replay_buffer.enable # TODO: optimize the following logic if config.schema_type == "experience": # NOTE: consistent with the old version of experience buffer @@ -161,6 +162,8 @@ def _read_priority(self, batch_size: int, min_model_version: int = 0) -> List[Ex query = session.query(self.table_model_cls) if min_model_version > 0: query = query.filter(self.table_model_cls.model_version >= min_model_version) + if not self.enable_replay: + query = query.filter(self.table_model_cls.consumed == 0) experiences = ( query.order_by( asc(self.table_model_cls.consumed), desc(self.table_model_cls.id) From eb63d9abaa1a240406fda582c51aae82ee03319a Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 19 Dec 2025 10:06:15 +0800 Subject: [PATCH 21/33] fix synchronizer --- trinity/explorer/explorer.py | 8 ++------ trinity/manager/synchronizer.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 0559d44a43..767d30a5ed 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -25,7 +25,6 @@ SyncStyle, ) from trinity.common.models import create_inference_models -from trinity.common.models.utils import get_latest_state_dict from trinity.explorer.scheduler import Scheduler from trinity.manager.state_manager import StateManager from trinity.manager.synchronizer import Synchronizer @@ -515,11 +514,8 @@ async def serve(self) -> None: while True: await asyncio.sleep(self.config.explorer.service_status_check_interval) # get the latest checkpoint - _, step_num = get_latest_state_dict( - self.config.checkpoint_job_dir, - self.config.trainer.trainer_type, - ) - self.service.set_latest_model_version(step_num) + model_version = await self.synchronizer.get_latest_model_version.remote() + self.service.set_latest_model_version(model_version) @classmethod def get_actor(cls, config: Config): diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py index e5adb86f1c..8157ad088d 100644 --- a/trinity/manager/synchronizer.py +++ b/trinity/manager/synchronizer.py @@ -274,6 +274,16 @@ async def wait_new_model_state_dict(self, current_version: int, no_wait: bool = ) return self.model_version + async def get_latest_model_version(self) -> int: + """ + Get the latest model version available in the synchronizer. + + Returns: + The current model version. + """ + async with self._ready_condition: + return self.model_version + async def ready_to_nccl_sync( self, module: str, trainer_step: Optional[int] = None ) -> Union[int, None]: From 3489b044bb1180bc476030cc095f9eb17fa882fd Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 19 Dec 2025 10:16:16 +0800 Subject: [PATCH 22/33] fix tests --- tests/trainer/trainer_test.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 68ff350680..60c0470e8c 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -1056,13 +1056,12 @@ async def test_serve_with_trainer(self): # noqa: C901 metrics = await proxy_client.get_metrics_async() self.assertTrue(metrics["rollout/total_experience_count"] == 12) self.assertTrue(metrics["rollout/ready_experience_count"] == 12) + self.assertTrue(metrics["rollout/model_0/total_request_count"] > 0) + self.assertTrue(metrics["rollout/model_1/total_request_count"] > 0) self.assertTrue( - abs( - metrics["rollout/model_0/total_request_count"] - - metrics["rollout/model_1/total_request_count"] - ) - < 4 - ) # balanced requests + metrics["rollout/model_0/total_request_count"] + metrics["rollout/model_1/total_request_count"] + == metrics["rollout/total_request_count"] + ) # at least updated to version 2 self.assertTrue(metrics["rollout/model_0/model_version"] >= 2) self.assertTrue(metrics["rollout/model_1/model_version"] >= 2) From e8b0dd52b7e7f4669146cac73536c443ed9aff93 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 19 Dec 2025 10:18:29 +0800 Subject: [PATCH 23/33] fix pre-commit --- tests/trainer/trainer_test.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 60c0470e8c..020f5e3a8c 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -1054,13 +1054,14 @@ async def test_serve_with_trainer(self): # noqa: C901 self.assertTrue(metrics["rollout/model_0/model_version"] > 0) self.assertTrue(metrics["rollout/model_1/model_version"] > 0) metrics = await proxy_client.get_metrics_async() - self.assertTrue(metrics["rollout/total_experience_count"] == 12) - self.assertTrue(metrics["rollout/ready_experience_count"] == 12) + self.assertEqual(metrics["rollout/total_experience_count"], 12) + self.assertEqual(metrics["rollout/ready_experience_count"], 12) self.assertTrue(metrics["rollout/model_0/total_request_count"] > 0) self.assertTrue(metrics["rollout/model_1/total_request_count"] > 0) - self.assertTrue( - metrics["rollout/model_0/total_request_count"] + metrics["rollout/model_1/total_request_count"] - == metrics["rollout/total_request_count"] + self.assertEqual( + metrics["rollout/model_0/total_request_count"] + + metrics["rollout/model_1/total_request_count"], + metrics["rollout/total_experience_count"] ) # at least updated to version 2 self.assertTrue(metrics["rollout/model_0/model_version"] >= 2) From c96593f3aa620cd2cb32d0c8b9a1cdcecb01c425 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 19 Dec 2025 10:22:16 +0800 Subject: [PATCH 24/33] fix comments --- tests/trainer/trainer_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 020f5e3a8c..b964d1a8c8 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -1061,7 +1061,7 @@ async def test_serve_with_trainer(self): # noqa: C901 self.assertEqual( metrics["rollout/model_0/total_request_count"] + metrics["rollout/model_1/total_request_count"], - metrics["rollout/total_experience_count"] + metrics["rollout/total_experience_count"], ) # at least updated to version 2 self.assertTrue(metrics["rollout/model_0/model_version"] >= 2) From f813128ab489072e8b12a3e7e6a220bb69b66830 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 19 Dec 2025 11:26:21 +0800 Subject: [PATCH 25/33] fix serve mode --- trinity/common/config.py | 4 ++-- trinity/common/models/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/trinity/common/config.py b/trinity/common/config.py index f3d2d8ad93..3bcba0f71a 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -691,7 +691,7 @@ class ExplorerConfig: class TrainerConfig: name: str = TRAINER_NAME trainer_type: str = "verl" - trainer_strategy: str = "fsdp" + trainer_strategy: str = "fsdp" # "fsdp", "fsdp2" or "megatron" save_interval: int = 0 enable_preview: bool = True # enable rollout preview in wandb total_steps: Optional[ @@ -1399,7 +1399,7 @@ def check_and_update(self) -> Config: # noqa: C901 # check buffer self._check_buffer() # check and update trainer - if self.mode in ["train", "both", "bench"]: + if self.mode in ["train", "both", "bench"] or self.trainer.trainer_strategy == "megatron": if self.trainer.trainer_type == "verl": if self.trainer.trainer_config: from trinity.common.verl_config import veRLConfig diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py index 10ea510e1b..445f1f63ba 100644 --- a/trinity/common/models/utils.py +++ b/trinity/common/models/utils.py @@ -201,11 +201,11 @@ def load_state_dict(checkpoint_dir: str, config: TrainerConfig) -> Union[dict, T trainer_type (str): The trainer type. Only support "verl" for now. """ if config.trainer_type == "verl": - actor_config = config.trainer_config.actor_rollout_ref.actor - strategy = actor_config.strategy + strategy = config.trainer_strategy if strategy in {"fsdp", "fsdp2"}: return load_fsdp_state_dict_from_verl_checkpoint(checkpoint_dir) elif strategy == "megatron": + actor_config = config.trainer_config.actor_rollout_ref.actor if ( actor_config.megatron.use_dist_checkpointing or not actor_config.megatron.use_mbridge From 81017403c1eb6dd8b12dc97afc06e367549f563b Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 19 Dec 2025 11:56:40 +0800 Subject: [PATCH 26/33] fix buffer test --- tests/buffer/experience_storage_test.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/buffer/experience_storage_test.py b/tests/buffer/experience_storage_test.py index c94978c418..5452ae300c 100644 --- a/tests/buffer/experience_storage_test.py +++ b/tests/buffer/experience_storage_test.py @@ -10,7 +10,7 @@ from tests.tools import RayUnittestBaseAysnc from trinity.buffer.reader.sql_reader import SQLReader from trinity.buffer.writer.sql_writer import SQLWriter -from trinity.common.config import ExperienceBufferConfig +from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig from trinity.common.constants import StorageType from trinity.common.experience import EID, Experience @@ -95,6 +95,9 @@ async def test_sql_experience_buffer(self): max_read_timeout=3, path=f"sqlite:///{DB_PATH}", batch_size=self.train_batch_size, + replay_buffer=ReplayBufferConfig( + enable=True + ) ) config = config.to_storage_config() writer = SQLWriter(config) From fe337993a5550258689e08f22ad6d5339d108ae9 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 19 Dec 2025 11:57:07 +0800 Subject: [PATCH 27/33] fix pre-commit --- tests/buffer/experience_storage_test.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/buffer/experience_storage_test.py b/tests/buffer/experience_storage_test.py index 5452ae300c..0071fabf39 100644 --- a/tests/buffer/experience_storage_test.py +++ b/tests/buffer/experience_storage_test.py @@ -95,9 +95,7 @@ async def test_sql_experience_buffer(self): max_read_timeout=3, path=f"sqlite:///{DB_PATH}", batch_size=self.train_batch_size, - replay_buffer=ReplayBufferConfig( - enable=True - ) + replay_buffer=ReplayBufferConfig(enable=True), ) config = config.to_storage_config() writer = SQLWriter(config) From 3df76a81462188e700d9f2035d95ce03ac4c58f5 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 19 Dec 2025 13:47:48 +0800 Subject: [PATCH 28/33] fix megatron training --- tests/trainer/trainer_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index b964d1a8c8..972071b99c 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -95,10 +95,11 @@ def test_trainer(self): eval_tasksets[1].repeat_times = 4 self.config.trainer.save_interval = 4 self.config.trainer.save_hf_checkpoint = "always" + if self.strategy == "megatron": + self.config.trainer.trainer_strategy = "megatron" self.config.check_and_update() _trainer_config = self.config.trainer.trainer_config if self.strategy == "megatron": - _trainer_config.actor_rollout_ref.actor.strategy = "megatron" _trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2 _trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2 _trainer_config.critic.strategy = "megatron" @@ -551,10 +552,11 @@ def test_fully_async_mode(self): trainer_config = deepcopy(config) trainer_config.mode = "train" trainer_config.buffer.train_batch_size = 4 + if self.strategy == "megatron": + trainer_config.trainer.trainer_strategy = "megatron" trainer_config.check_and_update() if self.strategy == "megatron": _trainer_config = trainer_config.trainer.trainer_config - _trainer_config.actor_rollout_ref.actor.strategy = "megatron" _trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2 _trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2 _trainer_config.critic.strategy = "megatron" From 63fa3be78ff63a767f5f1238b6fd46e9f858bdd1 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 19 Dec 2025 13:59:05 +0800 Subject: [PATCH 29/33] fix vllm prefix caching --- trinity/common/models/vllm_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index d1f1d3f61f..8d3c5c01a0 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -522,6 +522,7 @@ async def sync_model(self, model_version: int) -> int: await self.async_llm.add_lora(self.get_lora_request(self.default_lora_path)) self.model_version = model_version return model_version + await self.async_llm.reset_prefix_cache() await self._collective_rpc("update_weight") self.logger.info("Sync model weights to vLLM successfully.") self.model_version = model_version From e0973a30185a6e902414f109aae4205149ae4e88 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 22 Dec 2025 09:46:04 +0800 Subject: [PATCH 30/33] fix benchmark --- benchmark/bench.py | 2 ++ trinity/common/config.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmark/bench.py b/benchmark/bench.py index 07a481c312..0614a8f532 100644 --- a/benchmark/bench.py +++ b/benchmark/bench.py @@ -144,6 +144,8 @@ def check_taskset_path(dataset_name: str, taskset_path: str) -> str: raise AttributeError(f"{script_filename} is missing 'DEFAULT_DATA_PATH'") taskset_path = module.DEFAULT_DATA_PATH taskset_path = os.path.realpath(taskset_path) + if os.path.exists(taskset_path): + return taskset_path # For frozenlake, check if train.parquet and test.parquet already exist if dataset_name == "frozenlake": diff --git a/trinity/common/config.py b/trinity/common/config.py index 3bcba0f71a..aee3a44fae 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -483,7 +483,7 @@ class InferenceModelConfig: tensor_parallel_size: int = 1 use_v1: bool = True enforce_eager: bool = False - enable_prefix_caching: bool = False + enable_prefix_caching: bool = True enable_chunked_prefill: bool = True gpu_memory_utilization: float = 0.9 dtype: str = "bfloat16" From ed8edef9a4bc978ffe693987d6b6b95151acdf1d Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 22 Dec 2025 10:36:58 +0800 Subject: [PATCH 31/33] add client side timeout --- trinity/explorer/proxy/client.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/trinity/explorer/proxy/client.py b/trinity/explorer/proxy/client.py index 31177e6709..e3e5a594a8 100644 --- a/trinity/explorer/proxy/client.py +++ b/trinity/explorer/proxy/client.py @@ -33,36 +33,39 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI: ) return client - def feedback(self, reward: float, msg_ids: list[str]) -> dict: + def feedback(self, reward: float, msg_ids: list[str], timeout: float = 10) -> dict: response = requests.post( - self.feedback_url, json={"reward": reward, "msg_ids": msg_ids, "task_id": self.task_id} + self.feedback_url, + json={"reward": reward, "msg_ids": msg_ids, "task_id": self.task_id}, + timeout=timeout, ) return response.json() - async def feedback_async(self, reward: float, msg_ids: list[str]) -> dict: + async def feedback_async(self, reward: float, msg_ids: list[str], timeout: float = 10) -> dict: async with httpx.AsyncClient() as client: response = await client.post( self.feedback_url, json={"reward": reward, "msg_ids": msg_ids, "task_id": self.task_id}, + timeout=timeout, ) return response.json() - def commit(self) -> dict: - response = requests.post(f"{self.proxy_url}/commit") + def commit(self, timeout: float = 10) -> dict: + response = requests.post(f"{self.proxy_url}/commit", timeout=timeout) return response.json() - async def commit_async(self) -> dict: + async def commit_async(self, timeout: float = 10) -> dict: async with httpx.AsyncClient() as client: - response = await client.post(f"{self.proxy_url}/commit") + response = await client.post(f"{self.proxy_url}/commit", timeout=timeout) return response.json() - def get_metrics(self) -> dict: - response = requests.get(f"{self.proxy_url}/metrics") + def get_metrics(self, timeout: float = 5) -> dict: + response = requests.get(f"{self.proxy_url}/metrics", timeout=timeout) return response.json() - async def get_metrics_async(self) -> dict: + async def get_metrics_async(self, timeout: float = 5) -> dict: async with httpx.AsyncClient() as client: - response = await client.get(f"{self.proxy_url}/metrics") + response = await client.get(f"{self.proxy_url}/metrics", timeout=timeout) if response.status_code != 200: raise ValueError(f"Failed to get metrics: {response.text}") return response.json() From 290e36f1a8c6ce23ae354772f697b74ff96c6a18 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 22 Dec 2025 13:30:12 +0800 Subject: [PATCH 32/33] fix comments --- tests/manager/synchronizer_test.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/tests/manager/synchronizer_test.py b/tests/manager/synchronizer_test.py index 114e396aca..6ab488986e 100644 --- a/tests/manager/synchronizer_test.py +++ b/tests/manager/synchronizer_test.py @@ -80,32 +80,23 @@ async def new_finish_explore_step(self, step: int, model_version: int) -> None: def run_trainer(config: Config, max_steps: int, intervals: List[int]) -> None: ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) - try: - trainer_monkey_patch(config, max_steps, intervals) - train(config) - finally: - ray.shutdown(_exiting_interpreter=True) + trainer_monkey_patch(config, max_steps, intervals) + train(config) def run_explorer(config: Config, max_steps: int, intervals: List[int]) -> None: ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) - try: - explorer_monkey_patch(config, max_steps, intervals) - explore(config) - finally: - ray.shutdown(_exiting_interpreter=True) + explorer_monkey_patch(config, max_steps, intervals) + explore(config) def run_both( config: Config, max_steps: int, trainer_intervals: List[int], explorer_intervals: List[int] ) -> None: ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) - try: - trainer_monkey_patch(config, max_steps, trainer_intervals) - explorer_monkey_patch(config, max_steps, explorer_intervals) - both(config) - finally: - ray.shutdown(_exiting_interpreter=True) + trainer_monkey_patch(config, max_steps, trainer_intervals) + explorer_monkey_patch(config, max_steps, explorer_intervals) + both(config) class BaseTestSynchronizer(unittest.TestCase): @@ -115,6 +106,7 @@ def setUp(self): def tearDown(self): checkpoint_path = get_checkpoint_path() + ray.shutdown(_exiting_interpreter=True) shutil.rmtree(os.path.join(checkpoint_path, "unittest")) From c0b7b67121bc70f3bd3decca4720fd61df526b65 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 23 Dec 2025 11:10:48 +0800 Subject: [PATCH 33/33] update default test setting --- tests/template/config.yaml | 2 -- trinity/common/models/vllm_model.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/template/config.yaml b/tests/template/config.yaml index 745dcc50a1..0bf1c2eaf8 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -38,8 +38,6 @@ explorer: rollout_model: engine_num: 2 tensor_parallel_size: 1 - enable_prefix_caching: false - enforce_eager: true dtype: bfloat16 seed: 42 trainer: diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 8d3c5c01a0..c6f85b48f8 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -109,11 +109,11 @@ def __init__( distributed_executor_backend=("uni" if config.tensor_parallel_size == 1 else "ray"), max_model_len=max_model_len, enable_prefix_caching=config.enable_prefix_caching, + enable_chunked_prefill=config.enable_chunked_prefill, dtype=config.dtype, trust_remote_code=True, task="generate", gpu_memory_utilization=config.gpu_memory_utilization, - # max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage override_generation_config={ # TODO: find a way to unittest this "temperature": config.temperature, "top_p": config.top_p,