diff --git a/benchmark/bench.py b/benchmark/bench.py
index 07a481c312..0614a8f532 100644
--- a/benchmark/bench.py
+++ b/benchmark/bench.py
@@ -144,6 +144,8 @@ def check_taskset_path(dataset_name: str, taskset_path: str) -> str:
raise AttributeError(f"{script_filename} is missing 'DEFAULT_DATA_PATH'")
taskset_path = module.DEFAULT_DATA_PATH
taskset_path = os.path.realpath(taskset_path)
+ if os.path.exists(taskset_path):
+ return taskset_path
# For frozenlake, check if train.parquet and test.parquet already exist
if dataset_name == "frozenlake":
diff --git a/tests/buffer/experience_storage_test.py b/tests/buffer/experience_storage_test.py
index f308e1ee10..0071fabf39 100644
--- a/tests/buffer/experience_storage_test.py
+++ b/tests/buffer/experience_storage_test.py
@@ -10,7 +10,7 @@
from tests.tools import RayUnittestBaseAysnc
from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
-from trinity.common.config import ExperienceBufferConfig
+from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig
from trinity.common.constants import StorageType
from trinity.common.experience import EID, Experience
@@ -95,6 +95,7 @@ async def test_sql_experience_buffer(self):
max_read_timeout=3,
path=f"sqlite:///{DB_PATH}",
batch_size=self.train_batch_size,
+ replay_buffer=ReplayBufferConfig(enable=True),
)
config = config.to_storage_config()
writer = SQLWriter(config)
@@ -118,7 +119,7 @@ async def test_sql_experience_buffer(self):
exps = reader.read()
self.assertEqual(len(exps), self.train_batch_size)
for exp in exps:
- self.assertEqual(exp.eid.task, cnt)
+ self.assertEqual(exp.eid.task, str(cnt))
cnt -= 1
# experience buffer support experience reuse
@@ -127,7 +128,7 @@ async def test_sql_experience_buffer(self):
exps = reader.read()
self.assertEqual(len(exps), self.train_batch_size)
for exp in exps:
- self.assertEqual(exp.eid.task, cnt)
+ self.assertEqual(exp.eid.task, str(cnt))
cnt -= 1
self.assertEqual(await writer.release(), 0)
diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py
index 1e742a54bc..b2cdd5712d 100644
--- a/tests/buffer/sql_test.py
+++ b/tests/buffer/sql_test.py
@@ -2,12 +2,17 @@
import ray
import torch
+from parameterized import parameterized
from tests.tools import RayUnittestBaseAysnc
from trinity.buffer import get_buffer_reader
from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
-from trinity.common.config import ExperienceBufferConfig, TasksetConfig
+from trinity.common.config import (
+ ExperienceBufferConfig,
+ ReplayBufferConfig,
+ TasksetConfig,
+)
from trinity.common.constants import StorageType
from trinity.common.experience import Experience
@@ -15,7 +20,13 @@
class TestSQLBuffer(RayUnittestBaseAysnc):
- async def test_sql_exp_buffer_read_write(self) -> None:
+ @parameterized.expand(
+ [
+ (True,),
+ (False,),
+ ]
+ )
+ async def test_sql_exp_buffer_read_write(self, enable_replay: bool) -> None:
total_num = 8
put_batch_size = 2
read_batch_size = 4
@@ -25,7 +36,10 @@ async def test_sql_exp_buffer_read_write(self) -> None:
path=f"sqlite:///{db_path}",
storage_type=StorageType.SQL.value,
batch_size=read_batch_size,
+ max_read_timeout=3,
)
+ if enable_replay:
+ config.replay_buffer = ReplayBufferConfig(enable=True)
sql_writer = SQLWriter(config.to_storage_config())
sql_reader = SQLReader(config.to_storage_config())
exps = [
@@ -53,13 +67,22 @@ async def test_sql_exp_buffer_read_write(self) -> None:
reward=float(i),
logprobs=torch.tensor([0.1]),
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
- info={"model_version": i},
+ info={"model_version": i + put_batch_size},
)
for i in range(1, put_batch_size * 2 + 1)
]
)
exps = sql_reader.read(batch_size=put_batch_size * 2)
self.assertEqual(len(exps), put_batch_size * 2)
+ for exp in exps:
+ self.assertTrue(exp.info["model_version"] > put_batch_size)
+ if enable_replay:
+ # support replay, so we can read all again
+ exps = sql_reader.read(batch_size=(put_batch_size * 2 + total_num))
+ self.assertEqual(len(exps), (put_batch_size * 2 + total_num))
+ # if read more than available, will wait until timeout
+ with self.assertRaises(StopIteration):
+ exps = sql_reader.read(batch_size=(put_batch_size * 3 + total_num))
db_wrapper = ray.get_actor("sql-test_buffer")
self.assertIsNotNone(db_wrapper)
self.assertEqual(await sql_writer.release(), 0)
diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index b1282d7c7a..299bbf66b8 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -5,11 +5,9 @@
import os
import random
import shutil
-import unittest
from datetime import datetime
import httpx
-import openai
import ray
from tests.tools import (
@@ -27,6 +25,7 @@
from trinity.common.config import ExperienceBufferConfig, InferenceModelConfig
from trinity.common.constants import StorageType
from trinity.explorer.explorer import Explorer
+from trinity.explorer.proxy.client import TrinityClient
from trinity.manager.state_manager import StateManager
@@ -158,8 +157,9 @@ def run_serve(config):
run_stage(config)
-def run_agent(base_url, model_path: str):
- client = openai.Client(base_url=base_url, api_key="testkey")
+def run_agent(proxy_url, model_path: str):
+ proxy_client = TrinityClient(proxy_url=proxy_url)
+ openai_client = proxy_client.get_openai_client()
contents = [
"Hello, how are you?",
"What is the capital of China?",
@@ -172,10 +172,11 @@ def run_agent(base_url, model_path: str):
"What is the best way to learn programming?",
"Describe the process of photosynthesis.",
]
- response = client.chat.completions.create(
+ response = openai_client.chat.completions.create(
model=model_path,
messages=[{"role": "user", "content": random.choice(contents)}],
)
+ proxy_client.feedback(reward=2.0, msg_ids=[response.id])
return response.choices[0].message.content
@@ -191,7 +192,7 @@ def setUp(self):
self.config.explorer.rollout_model.engine_num = 4
self.config.explorer.rollout_model.enable_openai_api = True
self.config.checkpoint_root_dir = get_checkpoint_path()
- self.config.explorer.api_port = 8010
+ self.config.explorer.proxy_port = 8010
self.config.explorer.service_status_check_interval = 30
self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
name="experience_buffer",
@@ -201,7 +202,6 @@ def setUp(self):
if multiprocessing.get_start_method(allow_none=True) != "spawn":
multiprocessing.set_start_method("spawn", force=True)
- @unittest.skip("Require improvement for agent mode")
async def test_serve(self): # noqa: C901
serve_process = multiprocessing.Process(target=run_serve, args=(self.config,))
serve_process.start()
@@ -238,7 +238,7 @@ async def test_serve(self): # noqa: C901
apps = []
for i in range(task_num):
app_process = multiprocessing.Process(
- target=run_agent, args=(server_url + "/v1", self.config.model.model_path)
+ target=run_agent, args=(server_url, self.config.model.model_path)
)
apps.append(app_process)
app_process.start()
@@ -248,22 +248,20 @@ async def test_serve(self): # noqa: C901
self.assertFalse(app.is_alive())
finish_step = None
-
+ proxy_client = TrinityClient(proxy_url=server_url)
for i in range(20):
- async with httpx.AsyncClient() as client:
- response = await client.get(f"{server_url}/metrics")
- self.assertEqual(response.status_code, 200)
- metrics = response.json()
- metrics_keys = list(metrics.keys())
- self.assertIn("explore_step_num", metrics_keys)
- self.assertIn("rollout/total_experience_count", metrics_keys)
- self.assertIn("rollout/model_0/total_request_count", metrics_keys)
- self.assertIn("rollout/model_3/model_version", metrics_keys)
- if not finish_step and metrics["rollout/total_experience_count"] == task_num:
- finish_step = metrics["explore_step_num"]
- if finish_step and metrics["explore_step_num"] >= finish_step + 1:
- # wait for one more step to ensure all data are written to buffer
- break
+ metrics = await proxy_client.get_metrics_async()
+ metrics_keys = list(metrics.keys())
+ self.assertIn("explore_step_num", metrics_keys)
+ self.assertIn("rollout/total_experience_count", metrics_keys)
+ self.assertIn("rollout/model_0/total_request_count", metrics_keys)
+ self.assertIn("rollout/model_3/model_version", metrics_keys)
+ if not finish_step and metrics["rollout/total_experience_count"] == task_num:
+ finish_step = metrics["explore_step_num"]
+ await proxy_client.commit_async()
+ if finish_step and metrics["explore_step_num"] >= finish_step + 1:
+ # wait for one more step to ensure all data are written to buffer
+ break
await asyncio.sleep(3)
serve_process.terminate()
@@ -277,6 +275,9 @@ async def test_serve(self): # noqa: C901
exps = await buffer_reader.read_async(batch_size=10)
for exp in exps:
self.assertTrue(len(exp.tokens) > 0)
+ self.assertTrue(len(exp.logprobs) > 0)
+ self.assertTrue(exp.prompt_length > 0)
+ self.assertTrue(exp.reward == 2.0)
self.assertEqual(len(exps), task_num)
def tearDown(self):
diff --git a/tests/explorer/proxy_test.py b/tests/explorer/proxy_test.py
new file mode 100644
index 0000000000..e819c246aa
--- /dev/null
+++ b/tests/explorer/proxy_test.py
@@ -0,0 +1,86 @@
+import os
+import unittest
+import uuid
+from typing import List
+
+import torch
+
+from trinity.common.experience import EID, Experience
+from trinity.explorer.proxy.recorder import HistoryRecorder
+
+
+def get_dummy_experience(num: int) -> List[Experience]:
+ return [
+ Experience(
+ eid=EID(suffix=uuid.uuid4().hex[:6]),
+ tokens=torch.zeros(5),
+ prompt_length=2,
+ info={
+ "model_version": 0,
+ },
+ )
+ for _ in range(num)
+ ]
+
+
+db_path = os.path.join(os.path.dirname(__file__), "test_recorder.db")
+
+
+class RecorderTest(unittest.TestCase):
+ def setUp(self) -> None:
+ if os.path.exists(db_path):
+ os.remove(db_path)
+
+ def tearDown(self) -> None:
+ if os.path.exists(db_path):
+ os.remove(db_path)
+
+ def test_recorder(self):
+ recorder = HistoryRecorder(
+ # in memory sqlite for testing
+ db_url="sqlite:///" + db_path,
+ table_name="experience",
+ )
+ self.assertIsInstance(recorder, HistoryRecorder)
+ # test record history
+
+ experiences_1 = get_dummy_experience(3)
+ recorder.record_history(experiences_1)
+ # test update reward
+ msg_ids_1 = [exp.eid.suffix for exp in experiences_1]
+ experiences_2 = get_dummy_experience(2)
+ recorder.record_history(experiences_2)
+ updated_experiences = recorder.update_reward(
+ reward=1.0, msg_ids=msg_ids_1, run_id=1, task_id="test_task"
+ )
+ self.assertEqual(len(updated_experiences), 3)
+ for exp in updated_experiences:
+ self.assertEqual(exp.reward, 1.0)
+ self.assertEqual(exp.eid.run, 1)
+ self.assertEqual(str(exp.eid.task), "test_task")
+ # test update reward with non-existing msg_ids
+ updated_experiences_empty = recorder.update_reward(
+ reward=2.0, msg_ids=["non_existing_id"], run_id=1, task_id="test_task"
+ )
+ self.assertEqual(len(updated_experiences_empty), 0)
+ # test record history with empty experiences
+ recorder.record_history([]) # should not raise any exception
+ # test update reward multiple times
+ updated_experiences_2 = recorder.update_reward(
+ reward=3.0,
+ msg_ids=[exp.eid.suffix for exp in experiences_2],
+ run_id=2,
+ task_id="test_task_2",
+ )
+ self.assertEqual(len(updated_experiences_2), 2)
+ for exp in updated_experiences_2:
+ self.assertEqual(exp.reward, 3.0)
+ self.assertEqual(exp.eid.run, 2)
+ self.assertEqual(str(exp.eid.task), "test_task_2")
+ updated_experiences_3 = recorder.update_reward(
+ reward=4.0,
+ msg_ids=[exp.eid.suffix for exp in experiences_2],
+ run_id=3,
+ task_id="test_task_3",
+ )
+ self.assertEqual(len(updated_experiences_3), 0) # already consumed
diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py
index 16bb0b7bbb..a28e71ca4d 100644
--- a/tests/explorer/workflow_test.py
+++ b/tests/explorer/workflow_test.py
@@ -553,7 +553,7 @@ async def test_adapter(self):
try:
from agentscope.model import TrinityChatModel
except ImportError:
- self.skipTest("agentscope >= 0.1.6 is not installed")
+ self.skipTest("agentscope >= 1.0.9 is not installed")
async def as_workflow_func(task, model) -> float:
self.assertIsInstance(task, dict)
diff --git a/tests/manager/synchronizer_test.py b/tests/manager/synchronizer_test.py
index 114e396aca..6ab488986e 100644
--- a/tests/manager/synchronizer_test.py
+++ b/tests/manager/synchronizer_test.py
@@ -80,32 +80,23 @@ async def new_finish_explore_step(self, step: int, model_version: int) -> None:
def run_trainer(config: Config, max_steps: int, intervals: List[int]) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
- try:
- trainer_monkey_patch(config, max_steps, intervals)
- train(config)
- finally:
- ray.shutdown(_exiting_interpreter=True)
+ trainer_monkey_patch(config, max_steps, intervals)
+ train(config)
def run_explorer(config: Config, max_steps: int, intervals: List[int]) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
- try:
- explorer_monkey_patch(config, max_steps, intervals)
- explore(config)
- finally:
- ray.shutdown(_exiting_interpreter=True)
+ explorer_monkey_patch(config, max_steps, intervals)
+ explore(config)
def run_both(
config: Config, max_steps: int, trainer_intervals: List[int], explorer_intervals: List[int]
) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
- try:
- trainer_monkey_patch(config, max_steps, trainer_intervals)
- explorer_monkey_patch(config, max_steps, explorer_intervals)
- both(config)
- finally:
- ray.shutdown(_exiting_interpreter=True)
+ trainer_monkey_patch(config, max_steps, trainer_intervals)
+ explorer_monkey_patch(config, max_steps, explorer_intervals)
+ both(config)
class BaseTestSynchronizer(unittest.TestCase):
@@ -115,6 +106,7 @@ def setUp(self):
def tearDown(self):
checkpoint_path = get_checkpoint_path()
+ ray.shutdown(_exiting_interpreter=True)
shutil.rmtree(os.path.join(checkpoint_path, "unittest"))
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index 745dcc50a1..0bf1c2eaf8 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -38,8 +38,6 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
dtype: bfloat16
seed: 42
trainer:
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index d6d736d73d..972071b99c 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -1,5 +1,6 @@
"""Tests for trainer."""
+import asyncio
import json
import multiprocessing
import os
@@ -15,6 +16,7 @@
from tests.tools import (
RayUnittestBase,
+ RayUnittestBaseAysnc,
TensorBoardParser,
get_checkpoint_path,
get_lora_config,
@@ -23,7 +25,8 @@
get_unittest_dataset_config,
get_vision_language_model_path,
)
-from trinity.cli.launcher import bench, both, explore, run, train
+from trinity.buffer import get_buffer_reader
+from trinity.cli.launcher import bench, both, explore, run, serve, train
from trinity.common.config import (
AlgorithmConfig,
BufferConfig,
@@ -42,6 +45,7 @@
SyncStyle,
)
from trinity.common.models.utils import get_checkpoint_dir_with_step_num
+from trinity.explorer.proxy.client import TrinityClient
from trinity.manager.state_manager import StateManager
@@ -91,10 +95,11 @@ def test_trainer(self):
eval_tasksets[1].repeat_times = 4
self.config.trainer.save_interval = 4
self.config.trainer.save_hf_checkpoint = "always"
+ if self.strategy == "megatron":
+ self.config.trainer.trainer_strategy = "megatron"
self.config.check_and_update()
_trainer_config = self.config.trainer.trainer_config
if self.strategy == "megatron":
- _trainer_config.actor_rollout_ref.actor.strategy = "megatron"
_trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2
_trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2
_trainer_config.critic.strategy = "megatron"
@@ -502,6 +507,19 @@ def run_both(config: Config) -> None:
both(config)
+def run_serve(config: Config) -> None:
+ ray.init(
+ namespace=config.ray_namespace,
+ runtime_env={
+ "env_vars": {
+ LOG_DIR_ENV_VAR: config.log.save_dir,
+ LOG_LEVEL_ENV_VAR: "INFO",
+ }
+ },
+ )
+ serve(config)
+
+
@parameterized_class(
("use_priority_queue", "strategy"),
[(False, "fsdp"), (True, "fsdp"), (True, "megatron")],
@@ -534,10 +552,11 @@ def test_fully_async_mode(self):
trainer_config = deepcopy(config)
trainer_config.mode = "train"
trainer_config.buffer.train_batch_size = 4
+ if self.strategy == "megatron":
+ trainer_config.trainer.trainer_strategy = "megatron"
trainer_config.check_and_update()
if self.strategy == "megatron":
_trainer_config = trainer_config.trainer.trainer_config
- _trainer_config.actor_rollout_ref.actor.strategy = "megatron"
_trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2
_trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2
_trainer_config.critic.strategy = "megatron"
@@ -892,6 +911,175 @@ def tearDown(self):
shutil.rmtree(self.config.checkpoint_job_dir)
+async def run_math_workflow(serve_url: str, task: dict):
+ from trinity.common.rewards.math_reward import MathRewardFn
+
+ proxy_client = TrinityClient(serve_url)
+ openai_client = proxy_client.get_openai_async_client()
+
+ query = task["question"]
+ truth = task["answer"]
+
+ reward_fn = MathRewardFn()
+
+ system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e.,
+ reasoning process here
+ answer here .
+"""
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": query},
+ ]
+
+ models = await openai_client.models.list()
+ model = models.data[0].id
+
+ response = await openai_client.chat.completions.create(
+ model=model,
+ messages=messages,
+ )
+ answer = response.choices[0].message.content
+ reward = reward_fn(response=answer, truth=truth, prompt=query)
+ await proxy_client.feedback_async(sum(reward.values()), [response.id])
+
+
+class TestServeWithTrainer(RayUnittestBaseAysnc):
+ def setUp(self):
+ if multiprocessing.get_start_method(allow_none=True) != "spawn":
+ multiprocessing.set_start_method("spawn", force=True)
+ checkpoint_path = get_checkpoint_path()
+ shutil.rmtree(os.path.join(checkpoint_path, "unittest"), ignore_errors=True)
+
+ async def test_serve_with_trainer(self): # noqa: C901
+ config = get_template_config()
+ config.project = "unittest"
+ config.name = f"serve_with_trainer_{datetime.now().strftime('%Y%m%d%H%M%S')}"
+ config.checkpoint_root_dir = get_checkpoint_path()
+ config.model.model_path = get_model_path()
+ config.buffer.batch_size = 4
+ config.buffer.train_batch_size = 4
+ config.algorithm.algorithm_type = "ppo"
+ config.algorithm.repeat_times = 1
+ config.cluster.gpu_per_node = 2
+ config.cluster.node_num = 1
+ config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
+ name="exp_buffer",
+ storage_type=StorageType.SQL.value,
+ schema_type="experience",
+ )
+ config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+ config.buffer.total_steps = 3
+ config.trainer.save_interval = 1
+ config.synchronizer.sync_interval = 1
+ config.synchronizer.sync_method = SyncMethod.CHECKPOINT
+ config.explorer.rollout_model.engine_num = 2
+ config.explorer.rollout_model.enable_openai_api = True
+ config.explorer.rollout_model.tensor_parallel_size = 1
+ config.explorer.service_status_check_interval = 5
+
+ trainer_config = deepcopy(config)
+ trainer_config.mode = "train"
+ trainer_config.check_and_update()
+
+ trainer_process = multiprocessing.Process(target=run_trainer, args=(trainer_config,))
+ trainer_process.start()
+
+ await asyncio.sleep(5)
+ serve_config = deepcopy(config)
+ serve_config.mode = "serve"
+ serve_config.check_and_update()
+ serve_process = multiprocessing.Process(target=run_serve, args=(serve_config,))
+ serve_process.start()
+
+ ray.init(ignore_reinit_error=True)
+ while True:
+ try:
+ ray.get_actor("sql-exp_buffer", namespace=trainer_config.ray_namespace)
+ break
+ except ValueError:
+ print("waiting for trainer to start.")
+ await asyncio.sleep(5)
+
+ state_manager = StateManager(
+ path=serve_config.checkpoint_job_dir,
+ explorer_name=serve_config.explorer.name,
+ )
+
+ # wait for explorer initialization
+ for i in range(30):
+ try:
+ server_url = state_manager.load_explorer_server_url()
+ except Exception:
+ server_url = None
+ if server_url:
+ break
+ await asyncio.sleep(3)
+ if not server_url:
+ raise RuntimeError("Explorer server URL not found.")
+ proxy_client = TrinityClient(server_url)
+ # wait for server setup
+ for i in range(10):
+ if proxy_client.alive():
+ print("Proxy server is alive.")
+ break
+ await asyncio.sleep(2)
+
+ config.buffer.explorer_input.taskset.batch_size = 4
+ reader = get_buffer_reader(config.buffer.explorer_input.taskset)
+
+ try:
+ for i in range(3):
+ tasks = reader.read()
+ await asyncio.gather(
+ *(run_math_workflow(server_url, task.raw_task) for task in tasks)
+ )
+ await proxy_client.commit_async()
+ # wait for synchronizer started
+ end_time = time.time()
+ find_checkpoint = False
+ while time.time() - end_time < 100:
+ _, step_num = get_checkpoint_dir_with_step_num(
+ checkpoint_root_path=serve_config.checkpoint_job_dir,
+ raise_error=False,
+ )
+ if step_num >= i + 1: # checkpoint has been generated
+ find_checkpoint = True
+ break
+ await asyncio.sleep(1)
+ self.assertTrue(find_checkpoint, f"Checkpoint at step {i + 1} not found in time.")
+ metrics = await proxy_client.get_metrics_async()
+ self.assertTrue(metrics["rollout/total_experience_count"] == 4 * (i + 1))
+ self.assertTrue(metrics["rollout/ready_experience_count"] == 4 * (i + 1))
+ self.assertTrue(metrics["rollout/model_0/total_request_count"] > 0)
+ self.assertTrue(metrics["rollout/model_1/total_request_count"] > 0)
+ if i > 1:
+ self.assertTrue(metrics["rollout/model_0/model_version"] > 0)
+ self.assertTrue(metrics["rollout/model_1/model_version"] > 0)
+ metrics = await proxy_client.get_metrics_async()
+ self.assertEqual(metrics["rollout/total_experience_count"], 12)
+ self.assertEqual(metrics["rollout/ready_experience_count"], 12)
+ self.assertTrue(metrics["rollout/model_0/total_request_count"] > 0)
+ self.assertTrue(metrics["rollout/model_1/total_request_count"] > 0)
+ self.assertEqual(
+ metrics["rollout/model_0/total_request_count"]
+ + metrics["rollout/model_1/total_request_count"],
+ metrics["rollout/total_experience_count"],
+ )
+ # at least updated to version 2
+ self.assertTrue(metrics["rollout/model_0/model_version"] >= 2)
+ self.assertTrue(metrics["rollout/model_1/model_version"] >= 2)
+ # check final checkpoint
+ _, step_num = get_checkpoint_dir_with_step_num(
+ checkpoint_root_path=serve_config.checkpoint_job_dir,
+ step_num=3,
+ )
+ finally:
+ serve_process.terminate()
+ trainer_process.terminate()
+ serve_process.join()
+ trainer_process.join()
+
+
class TestMultiModalGRPO(BaseTrainerCase):
@unittest.skip("Require specific vllm/transformers version")
def test_trainer(self):
diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py
index 997c661a23..7f22e5370e 100644
--- a/trinity/buffer/schema/sql_schema.py
+++ b/trinity/buffer/schema/sql_schema.py
@@ -2,7 +2,17 @@
from typing import Dict, Optional, Tuple
-from sqlalchemy import JSON, Column, Float, Integer, LargeBinary, Text, create_engine
+from sqlalchemy import (
+ JSON,
+ Column,
+ DateTime,
+ Float,
+ Integer,
+ LargeBinary,
+ String,
+ create_engine,
+ func,
+)
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import declarative_base
from sqlalchemy.pool import NullPool
@@ -32,32 +42,36 @@ class ExperienceModel(Base): # type: ignore
__abstract__ = True
id = Column(Integer, primary_key=True, autoincrement=True)
- # for single turn
- prompt = Column(Text, nullable=True)
- response = Column(Text, nullable=True)
- # for multi turn
- message_list = Column(JSON, nullable=True)
- reward = Column(Float, nullable=True)
- # for step info
- model_version = Column(Integer, nullable=True, index=True)
+ timestamp = Column(DateTime, server_default=func.now())
+ task_id = Column(String(64), nullable=True, index=True) # associated task id
+ run_id = Column(Integer, nullable=True, index=True) # associated run id
+ msg_id = Column(String(64), nullable=True, index=True) # associated message id
# serialized experience object
+ model_version = Column(Integer, nullable=True, index=True)
experience_bytes = Column(LargeBinary, nullable=True)
+ reward = Column(Float, nullable=True)
consumed = Column(Integer, default=0, index=True)
def to_experience(self) -> Experience:
"""Load the experience from the database."""
- return Experience.deserialize(self.experience_bytes)
+ exp = Experience.deserialize(self.experience_bytes)
+ exp.eid.task = self.task_id
+ exp.eid.run = self.run_id
+ exp.eid.suffix = self.msg_id
+ exp.reward = self.reward
+ exp.info["model_version"] = self.model_version
+ return exp
@classmethod
def from_experience(cls, experience: Experience):
"""Save the experience to database."""
return cls(
- prompt=experience.prompt_text,
- response=experience.response_text,
- message_list=experience.messages,
+ experience_bytes=experience.serialize(),
reward=experience.reward,
+ task_id=str(experience.eid.task),
+ run_id=experience.eid.run,
+ msg_id=str(experience.eid.suffix),
model_version=experience.info["model_version"],
- experience_bytes=experience.serialize(),
)
@@ -107,7 +121,7 @@ def from_experience(cls, experience: Experience):
)
-def init_engine(db_url: str, table_name, schema_type: Optional[str]) -> Tuple:
+def init_engine(db_url: str, table_name: str, schema_type: Optional[str]) -> Tuple:
"""Get the sqlalchemy engine."""
logger = get_logger(__name__)
engine = create_engine(db_url, poolclass=NullPool)
@@ -128,6 +142,7 @@ def init_engine(db_url: str, table_name, schema_type: Optional[str]) -> Tuple:
try:
Base.metadata.create_all(engine, checkfirst=True)
+ logger.info(f"Created table {table_name} for schema type {schema_type}.")
except OperationalError:
logger.warning(f"Failed to create table {table_name}, assuming it already exists.")
diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py
index 08ff06fb8c..04f0c20bda 100644
--- a/trinity/buffer/storage/sql.py
+++ b/trinity/buffer/storage/sql.py
@@ -97,6 +97,7 @@ def __init__(self, config: StorageConfig) -> None:
super().__init__(config)
self.max_timeout = config.max_read_timeout
self.batch_size = config.batch_size
+ self.enable_replay = config.replay_buffer is not None and config.replay_buffer.enable
# TODO: optimize the following logic
if config.schema_type == "experience":
# NOTE: consistent with the old version of experience buffer
@@ -161,6 +162,8 @@ def _read_priority(self, batch_size: int, min_model_version: int = 0) -> List[Ex
query = session.query(self.table_model_cls)
if min_model_version > 0:
query = query.filter(self.table_model_cls.model_version >= min_model_version)
+ if not self.enable_replay:
+ query = query.filter(self.table_model_cls.consumed == 0)
experiences = (
query.order_by(
asc(self.table_model_cls.consumed), desc(self.table_model_cls.id)
diff --git a/trinity/buffer/utils.py b/trinity/buffer/utils.py
index b27ba662bc..689a8cce3d 100644
--- a/trinity/buffer/utils.py
+++ b/trinity/buffer/utils.py
@@ -5,7 +5,7 @@
@contextmanager
-def retry_session(session_maker, max_retry_times: int, max_retry_interval: float):
+def retry_session(session_maker, max_retry_times: int = 2, max_retry_interval: float = 1.0):
"""A Context manager for retrying session."""
logger = get_logger(__name__)
for attempt in range(max_retry_times):
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 883e71752d..aee3a44fae 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -483,7 +483,7 @@ class InferenceModelConfig:
tensor_parallel_size: int = 1
use_v1: bool = True
enforce_eager: bool = False
- enable_prefix_caching: bool = False
+ enable_prefix_caching: bool = True
enable_chunked_prefill: bool = True
gpu_memory_utilization: float = 0.9
dtype: str = "bfloat16"
@@ -669,14 +669,17 @@ class ExplorerConfig:
# for benchmark
bench_on_latest_checkpoint: bool = False # only benchmark the latest checkpoint
- # for serve mode
- api_port: int = 8010
+ # for serve mode proxy
+ proxy_port: int = 8010
# listen on all interfaces by default
listen_address: str = "0.0.0.0"
# check the running status of the server every 60 seconds
service_status_check_interval: int = 60
# keep at least 1 model in running status
min_running_model_num: int = 1
+ # db url for proxy history recorder, if not set, use proxy_history.db in buffer cache dir
+ db_url: Optional[str] = None
+
# Experimental feature
over_rollout: OverRolloutConfig = field(default_factory=OverRolloutConfig)
dynamic_timeout: DynamicTimeoutConfig = field(default_factory=DynamicTimeoutConfig)
@@ -688,7 +691,7 @@ class ExplorerConfig:
class TrainerConfig:
name: str = TRAINER_NAME
trainer_type: str = "verl"
- trainer_strategy: str = "fsdp"
+ trainer_strategy: str = "fsdp" # "fsdp", "fsdp2" or "megatron"
save_interval: int = 0
enable_preview: bool = True # enable rollout preview in wandb
total_steps: Optional[
@@ -1396,7 +1399,7 @@ def check_and_update(self) -> Config: # noqa: C901
# check buffer
self._check_buffer()
# check and update trainer
- if self.mode in ["train", "both", "bench"]:
+ if self.mode in ["train", "both", "bench"] or self.trainer.trainer_strategy == "megatron":
if self.trainer.trainer_type == "verl":
if self.trainer.trainer_config:
from trinity.common.verl_config import veRLConfig
diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py
index b140848903..190be581cb 100644
--- a/trinity/common/models/__init__.py
+++ b/trinity/common/models/__init__.py
@@ -84,6 +84,9 @@ def create_inference_models(
allocator = _BundleAllocator(node_bundle_map)
namespace = ray.get_runtime_context().namespace
# create rollout models
+ # in 'serve' mode, we always enable openai api for rollout model
+ if config.mode == "serve":
+ config.explorer.rollout_model.enable_openai_api = True
for i in range(config.explorer.rollout_model.engine_num):
bundles_for_engine = allocator.allocate(config.explorer.rollout_model.tensor_parallel_size)
config.explorer.rollout_model.bundle_indices = ",".join(
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index 110db0a3b7..3fe6f2bf37 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -379,7 +379,7 @@ async def get_current_load(self) -> int:
raise ValueError(
"API server is not enabled for this model. Load metrics is unavailable."
)
- with httpx.AsyncClient() as client:
+ async with httpx.AsyncClient() as client:
response = await client.get(f"{self.api_address}/load")
data = response.json()
return data["server_load"]
diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py
index d33a2b424e..445f1f63ba 100644
--- a/trinity/common/models/utils.py
+++ b/trinity/common/models/utils.py
@@ -165,6 +165,34 @@ def get_checkpoint_dir_with_step_num(
raise NotImplementedError(f"Unsupported trainer type {trainer_type}")
+def get_latest_state_dict(
+ checkpoint_root_path: str,
+ trainer_type: str = "verl",
+) -> Tuple[str, int]:
+ """Get the latest state dict from a root checkpoint directory.
+
+ Args:
+ checkpoint_root_path (str): The root checkpoint directory.
+
+ Returns:
+ Tuple[str, int]: The state dict path and the iteration of the state dict.
+ If the state dict does not exist, return (None, 0).
+ """
+ if trainer_type != "verl":
+ raise NotImplementedError(f"Unsupported trainer type {trainer_type}")
+ latest_state_dict_iteration_path = os.path.join(
+ checkpoint_root_path, "latest_state_dict_iteration.txt"
+ )
+ if os.path.exists(latest_state_dict_iteration_path):
+ with open(latest_state_dict_iteration_path, "r", encoding="utf-8") as f:
+ iteration = f.read().strip()
+ state_dict_path = os.path.join(
+ checkpoint_root_path, f"global_step_{iteration}", "actor"
+ )
+ return state_dict_path, int(iteration)
+ return None, 0 # type: ignore
+
+
def load_state_dict(checkpoint_dir: str, config: TrainerConfig) -> Union[dict, Tuple[str, str]]:
"""Load state dict from a checkpoint dir.
@@ -173,11 +201,11 @@ def load_state_dict(checkpoint_dir: str, config: TrainerConfig) -> Union[dict, T
trainer_type (str): The trainer type. Only support "verl" for now.
"""
if config.trainer_type == "verl":
- actor_config = config.trainer_config.actor_rollout_ref.actor
- strategy = actor_config.strategy
+ strategy = config.trainer_strategy
if strategy in {"fsdp", "fsdp2"}:
return load_fsdp_state_dict_from_verl_checkpoint(checkpoint_dir)
elif strategy == "megatron":
+ actor_config = config.trainer_config.actor_rollout_ref.actor
if (
actor_config.megatron.use_dist_checkpointing
or not actor_config.megatron.use_mbridge
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index d1f1d3f61f..c6f85b48f8 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -109,11 +109,11 @@ def __init__(
distributed_executor_backend=("uni" if config.tensor_parallel_size == 1 else "ray"),
max_model_len=max_model_len,
enable_prefix_caching=config.enable_prefix_caching,
+ enable_chunked_prefill=config.enable_chunked_prefill,
dtype=config.dtype,
trust_remote_code=True,
task="generate",
gpu_memory_utilization=config.gpu_memory_utilization,
- # max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage
override_generation_config={ # TODO: find a way to unittest this
"temperature": config.temperature,
"top_p": config.top_p,
@@ -522,6 +522,7 @@ async def sync_model(self, model_version: int) -> int:
await self.async_llm.add_lora(self.get_lora_request(self.default_lora_path))
self.model_version = model_version
return model_version
+ await self.async_llm.reset_prefix_cache()
await self._collective_rpc("update_weight")
self.logger.info("Sync model weights to vLLM successfully.")
self.model_version = model_version
diff --git a/trinity/explorer/api/api.py b/trinity/explorer/api/api.py
deleted file mode 100644
index b702c39f5a..0000000000
--- a/trinity/explorer/api/api.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import traceback
-
-import httpx
-import uvicorn
-from fastapi import FastAPI, Request
-from fastapi.responses import JSONResponse, Response
-
-app = FastAPI()
-
-
-# Forward openAI requests to a model instance
-
-
-@app.post("/v1/chat/completions")
-async def chat_completions(request: Request):
- # Currently, we do not support streaming chat completions
- body = await request.json()
- if "return_token_ids" not in body:
- body["return_token_ids"] = True
- url = await request.app.state.service.allocate_model()
- try:
- async with httpx.AsyncClient(timeout=request.app.state.inference_timeout) as client:
- resp = await client.post(f"{url}/v1/chat/completions", json=body)
- except Exception:
- return Response(
- status_code=500,
- content=f"Error forwarding request to model at {url}: {traceback.format_exc()}",
- )
- resp_data = resp.json()
- await request.app.state.service.record_experience(resp_data)
- return JSONResponse(content=resp_data)
-
-
-@app.get("/v1/models")
-async def show_available_models(request: Request):
- body = await request.json()
- url = await request.app.state.service.allocate_model(increase_count=False)
- async with httpx.AsyncClient() as client:
- resp = await client.get(f"{url}/v1/models", json=body)
- return JSONResponse(content=resp.json())
-
-
-@app.get("/health")
-async def health(request: Request) -> Response:
- """Health check."""
- return Response(status_code=200)
-
-
-@app.get("/metrics")
-async def metrics(request: Request):
- """Get the metrics of the service."""
- metrics = request.app.state.service.collect_metrics()
- metrics["explore_step_num"] = request.app.state.service.explorer.explore_step_num
- return JSONResponse(content=metrics)
-
-
-async def serve_http(app: FastAPI, host: str, port: int = None):
- config = uvicorn.Config(app, host=host, port=port)
- server = uvicorn.Server(config)
- await server.serve()
-
-
-async def run_app(service, listen_address: str, port: int = None) -> FastAPI:
- app.state.service = service
- app.state.inference_timeout = service.explorer.config.synchronizer.sync_timeout
- print(f"API server running on {listen_address}:{port}")
- await serve_http(app, listen_address, port)
diff --git a/trinity/explorer/api/service.py b/trinity/explorer/api/service.py
deleted file mode 100644
index ffdf2cfd9a..0000000000
--- a/trinity/explorer/api/service.py
+++ /dev/null
@@ -1,160 +0,0 @@
-import asyncio
-import time
-from collections import deque
-from typing import Dict, List
-
-import torch
-
-from trinity.common.constants import RunningStatus
-from trinity.common.experience import Experience
-from trinity.common.models.model import ModelWrapper
-from trinity.explorer.explorer import Explorer
-from trinity.utils.log import get_logger
-
-
-class ExplorerService:
- def __init__(self, explorer: Explorer, listen_address: str = "localhost", port: int = 8010):
- self.logger = get_logger(__name__)
- self.explorer = explorer
- self.app = None
- self.port = port
- self.listen_address = listen_address
- self.running = False
- self.models: List[ModelWrapper] = [ModelWrapper(model) for model in explorer.models]
- self.min_running_model_num = explorer.config.explorer.min_running_model_num
- self.check_interval = explorer.config.explorer.service_status_check_interval
- self.max_timeout = explorer.config.explorer.max_timeout
- self.running_models: deque[int] = deque() # indices of running models
- self.sync_task_map: Dict[asyncio.Future, int] = {} # sync task -> model index
- self.latest_model_version = 0
- self.experience_queue = asyncio.Queue()
- self.experience_count = 0
-
- async def serve(self):
- from trinity.explorer.api.api import run_app
-
- if self.running:
- self.logger.warning("Server is already running.")
- return
-
- self.running = True
- await asyncio.gather(*[model.prepare() for model in self.models])
-
- for i, _ in enumerate(self.models):
- self.running_models.append(i)
-
- self.serve_task = asyncio.create_task(
- run_app(service=self, listen_address=self.listen_address, port=self.port)
- )
- self.sync_model_weights_task = asyncio.create_task(self.model_weights_sync_loop())
-
- async def model_weights_sync_loop(self):
- self.logger.info("Starting model weights synchronization loop.")
- while self.running:
- for idx in list(self.running_models):
- if (
- len(self.running_models) > self.explorer.config.explorer.min_running_model_num
- and self.models[idx].model_version < self.latest_model_version
- ):
- self.running_models.remove(idx)
- self.models[idx].status = RunningStatus.REQUIRE_SYNC
- self.logger.info(f"Model {idx} scheduled for synchronization.")
- future = asyncio.create_task(self._wait_for_sync_start(idx))
- self.sync_task_map[future] = idx
- future.add_done_callback(self._sync_model_weights)
- # wait half interval
- await asyncio.sleep(self.check_interval / 2)
- self.logger.info("Model weights synchronization loop stopped.")
-
- def set_latest_model_version(self, version: int) -> None:
- if version > self.latest_model_version:
- self.latest_model_version = version
- self.logger.info(f"Updated latest model version to {version}.")
-
- async def _wait_for_sync_start(self, index: int):
- start_time = time.time()
- while time.time() - start_time < self.max_timeout:
- current_load = await self.models[index].get_current_load()
- if current_load == 0:
- self.models[index].status = RunningStatus.WAITING_SYNC
- self.logger.info(f"Model {index} begins synchronization.")
- return
- else:
- await asyncio.sleep(2)
- raise asyncio.TimeoutError(
- f"Timeout waiting for model {index} to be free for synchronization. Current load: {current_load}"
- )
-
- async def _sync_model_weights(self, task: asyncio.Future):
- index = self.sync_task_map.pop(task)
- latest_version = self.latest_model_version # capture the latest version
- if task.cancelled():
- self.logger.warning(f"Synchronization of model {index} was cancelled.")
- elif task.exception():
- self.logger.error(f"Error during synchronization of model {index}: {task.exception()}")
- else:
- await self.models[index].sync_model_weights(latest_version)
- self.logger.info(f"Model {index} synchronized to version {latest_version}.")
- self.running_models.append(index)
- self.models[index].status = RunningStatus.RUNNING
-
- async def allocate_model(self, increase_count: bool = True) -> str:
- model = self.models[self.running_models[0]]
- if increase_count:
- model.request_count += 1
- self.running_models.rotate(-1)
- return model.api_address
-
- def collect_metrics(self) -> Dict:
- metrics = {}
- for i, model in enumerate(self.models):
- metrics[f"rollout/model_{i}/total_request_count"] = model.request_count
- metrics[f"rollout/model_{i}/model_version"] = model.model_version
- metrics["rollout/total_experience_count"] = self.experience_count
- return metrics
-
- async def check_requiring_sync_models(self):
- if not self.running:
- self.logger.warning("Server is not running.")
- return
- await asyncio.gather(
- *[self._sync_model_weights(idx) for idx in list(self.requiring_sync_models)]
- )
-
- async def record_experience(self, response):
- experiences = []
- for choice in response["choices"]:
- exp = Experience(
- tokens=torch.cat(
- (
- torch.tensor(response["prompt_token_ids"], dtype=torch.int32),
- torch.tensor(choice["token_ids"], dtype=torch.int32),
- )
- ),
- logprobs=choice.get("logprobs", None),
- prompt_length=len(response["prompt_token_ids"]),
- response_text=choice.get("message", {}).get("content", ""),
- )
- experiences.append(exp)
- self.experience_count += len(experiences)
- for exp in experiences:
- await self.experience_queue.put(exp)
-
- async def get_all_experiences(self) -> List:
- experiences = []
- while not self.experience_queue.empty():
- experiences.append(await self.experience_queue.get())
- return experiences
-
- async def shutdown(self):
- if not self.running:
- self.logger.warning("Server is not running.")
- return
- self.sync_model_weights_task.cancel()
- self.serve_task.cancel()
- try:
- await self.serve_task
- except asyncio.CancelledError:
- pass
- self.running = False
- self.logger.info("API server shut down.")
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index c690bc407d..767d30a5ed 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -25,7 +25,6 @@
SyncStyle,
)
from trinity.common.models import create_inference_models
-from trinity.common.models.utils import get_checkpoint_dir_with_step_num
from trinity.explorer.scheduler import Scheduler
from trinity.manager.state_manager import StateManager
from trinity.manager.synchronizer import Synchronizer
@@ -91,7 +90,6 @@ def __init__(self, config: Config):
async def setup_weight_sync_group(
self, master_address: str, master_port: int, state_dict_meta: List = None
):
- # In checkpoint mode, we use explorer to store the model weights which has no rank
base_offset = 1 if self.use_nccl_sync else 0
world_size = (
len(self.models) * self.config.explorer.rollout_model.tensor_parallel_size + base_offset
@@ -118,6 +116,30 @@ async def setup_weight_sync_group(
]
await asyncio.gather(*refs)
+ async def setup_model_level_weight_sync_group(self):
+ """Setup process group for each model, only used in serve mode."""
+ refs = []
+ world_size = self.config.explorer.rollout_model.tensor_parallel_size
+ for model in self.models:
+ master_address, master_port = await model.get_available_address.remote()
+ self.logger.info(
+ f"Initialize process group for model weight synchronization, "
+ f"master_address={master_address}, master_port={master_port}, "
+ f"world_size={world_size}"
+ )
+ refs.append(
+ model.init_process_group.remote(
+ master_address=master_address,
+ master_port=master_port,
+ rank_offset=0,
+ world_size=world_size,
+ group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
+ explorer_name=self.config.explorer.name,
+ timeout=self.config.synchronizer.sync_timeout,
+ )
+ )
+ await asyncio.gather(*refs)
+
async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int:
self.logger.info(f"Start to update model weights from checkpoint at step {step_num}.")
step_num = await self.synchronizer.set_model_state_dict_with_step_num.remote(step_num)
@@ -174,8 +196,14 @@ async def prepare(self) -> None:
self.logger.info("All models are ready.")
if not self.use_nccl_sync:
- master_address, master_port = await self.models[0].get_available_address.remote()
- await self.setup_weight_sync_group(master_address, master_port)
+ if self.config.mode == "serve":
+ # In serving mode, each engine will setup its own process group
+ await self.setup_model_level_weight_sync_group()
+ else:
+ master_address, master_port = await self.models[
+ 0
+ ].get_available_address.remote()
+ await self.setup_weight_sync_group(master_address, master_port)
if self.config.mode != "serve":
self.scheduler = Scheduler(self.config, self.models, self.auxiliary_models)
await self.scheduler.start()
@@ -233,9 +261,11 @@ async def explore_step(self) -> bool:
await self.save_checkpoint(sync_weight=False)
await self.synchronizer.set_explorer_status.remote(
RunningStatus.STOPPED,
- old_status=RunningStatus.RUNNING
- if self.last_sync_successful
- else RunningStatus.REQUIRE_SYNC,
+ old_status=(
+ RunningStatus.RUNNING
+ if self.last_sync_successful
+ else RunningStatus.REQUIRE_SYNC
+ ),
)
await self.shutdown()
return False
@@ -466,32 +496,26 @@ async def serve(self) -> None:
messages=[{"role": "user", "content": "Hello!"}]
)
"""
- from trinity.explorer.api.service import ExplorerService
+ from trinity.explorer.proxy.service import ExplorerService
self.service = ExplorerService(
self,
listen_address=self.config.explorer.listen_address,
- port=self.config.explorer.api_port,
+ port=self.config.explorer.proxy_port,
)
await self.service.serve()
self.server_url = f"http://{ray.util.get_node_ip_address()}:{self.service.port}"
self.logger.info(
- f"Explorer API Server is started on {self.server_url} and listening to {self.service.listen_address}."
+ "======================================================\n"
+ f"Starting Trinity Service on {self.server_url}\n"
+ "======================================================"
)
self.state.save_explorer_server_url(self.server_url)
while True:
- self.explore_step_num += 1
await asyncio.sleep(self.config.explorer.service_status_check_interval)
- # process experiences generated in the last interval
- exps = await self.service.get_all_experiences()
- metrics = await self.experience_pipeline.process.remote(exps)
- metrics.update(self.service.collect_metrics())
- self.monitor.log(metrics, self.explore_step_num)
# get the latest checkpoint
- _, step_num = get_checkpoint_dir_with_step_num(
- self.config.checkpoint_job_dir, raise_error=False
- )
- self.service.set_latest_model_version(step_num)
+ model_version = await self.synchronizer.get_latest_model_version.remote()
+ self.service.set_latest_model_version(model_version)
@classmethod
def get_actor(cls, config: Config):
diff --git a/trinity/explorer/explorer_client.py b/trinity/explorer/explorer_client.py
deleted file mode 100644
index 311b310038..0000000000
--- a/trinity/explorer/explorer_client.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from functools import partial
-
-import httpx
-import openai
-import requests
-
-
-class ExplorerClient:
- def __init__(self, base_url: str):
- self.base_url = base_url
- self.session_id = self.init_session()
-
- def init_session(self) -> str:
- response = requests.post(f"{self.base_url}/allocate")
- data = response.json()
- return data["session_id"]
-
- def get_openai_client(self) -> openai.OpenAI:
- client = openai.OpenAI(
- base_url=self.base_url + "/v1",
- api_key="EMPTY",
- )
- client.chat.completions.create = partial(
- client.chat.completions.create, extra_body={"session_id": self.session_id}
- )
- return client
-
- def get_openai_async_client(self) -> openai.AsyncOpenAI:
- client = openai.AsyncOpenAI(
- base_url=self.base_url + "/v1",
- api_key="EMPTY",
- )
- client.chat.completions.create = partial(
- client.chat.completions.create, extra_body={"session_id": self.session_id}
- )
- return client
-
- def feedback(self, reward: float):
- response = requests.post(
- f"{self.base_url}/feedback", json={"session_id": self.session_id, "reward": reward}
- )
- return response.json()
-
- async def feedback_async(self, reward: float):
- async with httpx.AsyncClient() as client:
- response = await client.post(
- f"{self.base_url}/feedback", json={"session_id": self.session_id, "reward": reward}
- )
- return response.json()
diff --git a/trinity/explorer/api/__init__.py b/trinity/explorer/proxy/__init__.py
similarity index 100%
rename from trinity/explorer/api/__init__.py
rename to trinity/explorer/proxy/__init__.py
diff --git a/trinity/explorer/proxy/app.py b/trinity/explorer/proxy/app.py
new file mode 100644
index 0000000000..84cb05e058
--- /dev/null
+++ b/trinity/explorer/proxy/app.py
@@ -0,0 +1,124 @@
+import traceback
+from contextlib import asynccontextmanager
+
+import httpx
+import uvicorn
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import JSONResponse, Response
+
+http_client: httpx.AsyncClient = None
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ global http_client
+ http_client = httpx.AsyncClient(
+ timeout=httpx.Timeout(300.0, connect=10.0),
+ limits=httpx.Limits(max_keepalive_connections=20, max_connections=100),
+ )
+ yield
+ await http_client.aclose()
+
+
+app = FastAPI(lifespan=lifespan)
+
+
+# Forward OpenAI requests to a model instance
+@app.post("/v1/chat/completions")
+async def chat_completions(request: Request):
+ # Currently, we do not support streaming chat completions
+ try:
+ request_data = await request.json()
+ except Exception as e:
+ raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
+
+ forward_headers = {
+ key: value
+ for key, value in request.headers.items()
+ if key.lower() not in ["host", "content-length", "transfer-encoding"]
+ }
+ # for experience data recording, we need to return token ids and logprobs
+ request_data["return_token_ids"] = True
+ request_data["logprobs"] = True
+ # temperature must be set from config, ignore user's input
+ request_data["temperature"] = request.app.state.temperature
+ url, model_version = await request.app.state.service.allocate_model()
+ try:
+ async with httpx.AsyncClient(timeout=request.app.state.inference_timeout) as client:
+ resp = await client.post(
+ f"{url}/v1/chat/completions", json=request_data, headers=forward_headers
+ )
+ except Exception:
+ return Response(
+ status_code=500,
+ content=f"Error forwarding request to model at {url}: {traceback.format_exc()}",
+ )
+ resp_data = resp.json()
+ await request.app.state.service.record_experience(resp_data, model_version)
+ return JSONResponse(content=resp_data)
+
+
+@app.get("/v1/models")
+async def show_available_models(request: Request):
+ if hasattr(request.app.state, "models"):
+ return JSONResponse(content=request.app.state.models)
+ url, _ = await request.app.state.service.allocate_model(increase_count=False)
+ async with httpx.AsyncClient() as client:
+ print(f"Fetching models from {url}/v1/models")
+ resp = await client.get(f"{url}/v1/models")
+ request.app.state.models = resp.json()
+ return JSONResponse(content=resp.json())
+
+
+@app.get("/health")
+async def health(request: Request) -> Response:
+ """Health check."""
+ return Response(status_code=200)
+
+
+@app.get("/metrics")
+async def metrics(request: Request):
+ """Get the metrics of the service."""
+ metrics = request.app.state.service.collect_metrics()
+ metrics["explore_step_num"] = request.app.state.service.explorer.explore_step_num
+ return JSONResponse(content=metrics)
+
+
+@app.post("/feedback")
+async def feedback(request: Request):
+ """Receive feedback for the current session."""
+ body = await request.json()
+ reward = body.get("reward")
+ msg_ids = body.get("msg_ids")
+ task_id = body.get("task_id")
+ run_id = body.get("run_id", 0)
+ if msg_ids is None or reward is None:
+ return JSONResponse(status_code=400, content={"error": "msg_ids and reward are required"})
+ if not isinstance(msg_ids, list) or not isinstance(reward, (int, float)):
+ return JSONResponse(
+ status_code=400, content={"error": "msg_ids must be a list and reward must be a number"}
+ )
+ await request.app.state.service.record_feedback(
+ reward=reward, msg_ids=msg_ids, task_id=task_id, run_id=run_id
+ )
+ return JSONResponse(content={"status": "success"})
+
+
+@app.post("/commit")
+async def commit(request: Request):
+ """Commit the current experiences."""
+ await request.app.state.service.submit_experiences()
+ return JSONResponse(content={"status": "success"})
+
+
+async def serve_http(app: FastAPI, host: str, port: int) -> None:
+ config = uvicorn.Config(app, host=host, port=port)
+ server = uvicorn.Server(config)
+ await server.serve()
+
+
+async def run_app(service, listen_address: str, port: int) -> None:
+ app.state.service = service
+ app.state.temperature = service.explorer.config.model.temperature
+ app.state.inference_timeout = service.explorer.config.synchronizer.sync_timeout
+ await serve_http(app, listen_address, port)
diff --git a/trinity/explorer/proxy/client.py b/trinity/explorer/proxy/client.py
new file mode 100644
index 0000000000..e3e5a594a8
--- /dev/null
+++ b/trinity/explorer/proxy/client.py
@@ -0,0 +1,71 @@
+import uuid
+
+import httpx
+import openai
+import requests
+
+
+class TrinityClient:
+ def __init__(self, proxy_url: str):
+ self.proxy_url = proxy_url
+ self.openai_base_url = f"{self.proxy_url}/v1"
+ self.feedback_url = f"{self.proxy_url}/feedback"
+ self.task_id = uuid.uuid4().hex[:6]
+
+ def alive(self) -> bool:
+ try:
+ response = requests.get(f"{self.proxy_url}/health", timeout=2)
+ return response.status_code == 200
+ except requests.RequestException:
+ return False
+
+ def get_openai_client(self) -> openai.OpenAI:
+ client = openai.OpenAI(
+ base_url=self.openai_base_url,
+ api_key="EMPTY",
+ )
+ return client
+
+ def get_openai_async_client(self) -> openai.AsyncOpenAI:
+ client = openai.AsyncOpenAI(
+ base_url=self.openai_base_url,
+ api_key="EMPTY",
+ )
+ return client
+
+ def feedback(self, reward: float, msg_ids: list[str], timeout: float = 10) -> dict:
+ response = requests.post(
+ self.feedback_url,
+ json={"reward": reward, "msg_ids": msg_ids, "task_id": self.task_id},
+ timeout=timeout,
+ )
+ return response.json()
+
+ async def feedback_async(self, reward: float, msg_ids: list[str], timeout: float = 10) -> dict:
+ async with httpx.AsyncClient() as client:
+ response = await client.post(
+ self.feedback_url,
+ json={"reward": reward, "msg_ids": msg_ids, "task_id": self.task_id},
+ timeout=timeout,
+ )
+ return response.json()
+
+ def commit(self, timeout: float = 10) -> dict:
+ response = requests.post(f"{self.proxy_url}/commit", timeout=timeout)
+ return response.json()
+
+ async def commit_async(self, timeout: float = 10) -> dict:
+ async with httpx.AsyncClient() as client:
+ response = await client.post(f"{self.proxy_url}/commit", timeout=timeout)
+ return response.json()
+
+ def get_metrics(self, timeout: float = 5) -> dict:
+ response = requests.get(f"{self.proxy_url}/metrics", timeout=timeout)
+ return response.json()
+
+ async def get_metrics_async(self, timeout: float = 5) -> dict:
+ async with httpx.AsyncClient() as client:
+ response = await client.get(f"{self.proxy_url}/metrics", timeout=timeout)
+ if response.status_code != 200:
+ raise ValueError(f"Failed to get metrics: {response.text}")
+ return response.json()
diff --git a/trinity/explorer/proxy/recorder.py b/trinity/explorer/proxy/recorder.py
new file mode 100644
index 0000000000..d5eaded27e
--- /dev/null
+++ b/trinity/explorer/proxy/recorder.py
@@ -0,0 +1,75 @@
+from typing import List
+
+from sqlalchemy.orm import sessionmaker
+
+from trinity.buffer.schema import init_engine
+from trinity.buffer.utils import retry_session
+from trinity.common.experience import Experience
+from trinity.utils.log import get_logger
+
+
+# TODO: implement an async version in the future
+class HistoryRecorder:
+ """Record chat history into the database."""
+
+ def __init__(self, db_url: str, table_name: str):
+ self.logger = get_logger()
+ self.engine, self.table_model_cls = init_engine(
+ db_url=db_url,
+ table_name=table_name,
+ schema_type="experience",
+ )
+ self.logger.info(f"Init SQL storage at {db_url}")
+ self.session = sessionmaker(bind=self.engine)
+
+ def record_history(self, experiences: List[Experience]) -> None:
+ """Save experience to the database."""
+ with retry_session(self.session) as db:
+ exps = [self.table_model_cls.from_experience(exp) for exp in experiences]
+ db.add_all(exps)
+
+ def update_reward(
+ self, reward: float, msg_ids: list, run_id: int, task_id: str
+ ) -> List[Experience]:
+ """Update reward for given response IDs and return the updated experiences.
+
+ Args:
+ reward (float): The reward value to be updated.
+ msg_ids (list): List of message IDs to update.
+ run_id (int): The run ID associated with the experiences.
+ task_id (str): The task ID associated with the experiences.
+
+ Returns:
+ List[Experience]: List of updated experiences.
+
+ Note:
+ Only experiences that have not been consumed (consumed == 0) will be returned.
+ For example, if you call this method multiple times with the same msg_ids, only
+ the first call will return the updated experiences; subsequent calls will return
+ an empty list.
+ """
+ with retry_session(self.session) as db:
+ # Lock and retrieve records that have not been consumed yet.
+ records = (
+ db.query(self.table_model_cls)
+ .filter(
+ self.table_model_cls.msg_id.in_(msg_ids),
+ self.table_model_cls.consumed == 0,
+ )
+ .with_for_update()
+ .all()
+ )
+
+ if not records:
+ return []
+
+ # Update records in memory
+ for record in records:
+ record.reward = reward
+ record.run_id = run_id
+ record.task_id = task_id
+ record.consumed += 1
+
+ # The session commit is handled by the `retry_session` context manager.
+ updated_experiences = [record.to_experience() for record in records]
+ return updated_experiences
diff --git a/trinity/explorer/proxy/service.py b/trinity/explorer/proxy/service.py
new file mode 100644
index 0000000000..868c733c45
--- /dev/null
+++ b/trinity/explorer/proxy/service.py
@@ -0,0 +1,199 @@
+import asyncio
+import time
+from collections import deque
+from typing import Dict, List, Tuple
+
+import torch
+
+from trinity.common.constants import RunningStatus
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.explorer.explorer import Explorer
+from trinity.explorer.proxy.recorder import HistoryRecorder
+from trinity.utils.log import get_logger
+
+
+class ExplorerService:
+ """Manages the lifecycle and operations of the Explorer API service."""
+
+ def __init__(self, explorer: Explorer, listen_address: str = "localhost", port: int = 8010):
+ self.logger = get_logger(__name__)
+ self.explorer = explorer
+ self.app = None
+ self.port = port
+ self.listen_address = listen_address
+ self.running = False
+ self.models: List[ModelWrapper] = [ModelWrapper(model) for model in explorer.models]
+ self.min_running_model_num = explorer.config.explorer.min_running_model_num
+ self.check_interval = explorer.config.explorer.service_status_check_interval
+ self.max_timeout = explorer.config.explorer.max_timeout
+ self.running_model_ids: deque[int] = deque() # indices of running models
+ self.model_version_map: Dict[int, int] = {} # model index -> model version
+ self.sync_task_map: Dict[asyncio.Future, int] = {} # sync task -> model index
+ self.latest_model_version = 0
+ self.session_level_experience_queue: Dict[int, deque[Experience]] = {}
+ self.commit_lock = asyncio.Lock()
+ self.ready_experiences = deque()
+ self.recorder = HistoryRecorder(
+ db_url=explorer.config.explorer.db_url
+ or f"sqlite:///{explorer.config.buffer.cache_dir}/proxy_history.db",
+ table_name="proxy_history",
+ )
+ self.total_experience_count = 0
+ self.ready_experience_count = 0
+
+ async def serve(self) -> None:
+ from trinity.explorer.proxy.app import run_app
+
+ if self.running:
+ self.logger.warning("Server is already running.")
+ return
+
+ self.running = True
+ await asyncio.gather(*[model.prepare() for model in self.models])
+
+ for i, _ in enumerate(self.models):
+ self.running_model_ids.append(i)
+
+ self.serve_task = asyncio.create_task(
+ run_app(service=self, listen_address=self.listen_address, port=self.port)
+ )
+ self.sync_model_weights_task = asyncio.create_task(self.model_weights_sync_loop())
+
+ async def model_weights_sync_loop(self) -> None:
+ self.logger.info("Starting model weights synchronization loop.")
+ while self.running:
+ for idx in list(self.running_model_ids):
+ self.model_version_map[idx] = await self.models[idx].model_version_async
+ if (
+ len(self.running_model_ids)
+ > self.explorer.config.explorer.min_running_model_num
+ and self.model_version_map[idx] < self.latest_model_version
+ ):
+ self.logger.info(f"Model {idx} scheduled for synchronization.")
+ self.models[idx].status = RunningStatus.REQUIRE_SYNC
+ self.running_model_ids.remove(idx)
+ asyncio.create_task(self._sync_model_weights(idx))
+ # wait half interval
+ await asyncio.sleep(self.check_interval / 2)
+ self.logger.info("Model weights synchronization loop stopped.")
+
+ def set_latest_model_version(self, version: int) -> None:
+ if version > self.latest_model_version:
+ self.latest_model_version = version
+ self.logger.info(f"Updated latest model version to {version}.")
+
+ async def _sync_model_weights(self, index: int) -> None:
+ """Synchronize model weights for the given model index."""
+ # wait until the model is free
+ start_time = time.time()
+ timeout_flag = True
+ current_load = -1
+ while time.time() - start_time < self.max_timeout:
+ current_load = await self.models[index].get_current_load()
+ if current_load == 0:
+ self.models[index].status = RunningStatus.WAITING_SYNC
+ self.logger.info(f"Model {index} begins synchronization.")
+ timeout_flag = False
+ break
+ else:
+ self.logger.info(
+ "Waiting for model %d to be free. Current load: %d", index, current_load
+ )
+ await asyncio.sleep(1)
+ if timeout_flag:
+ raise asyncio.TimeoutError(
+ f"Timeout waiting for model {index} to be free for synchronization. Current load: {current_load}"
+ )
+ latest_version = self.latest_model_version # capture the latest version
+ # perform synchronization
+ await self.models[index].sync_model_weights(latest_version)
+ self.logger.info(f"Model {index} synchronized to version {latest_version}.")
+ self.model_version_map[index] = await self.models[index].model_version_async
+ self.models[index].status = RunningStatus.RUNNING
+ self.running_model_ids.append(index)
+
+ async def allocate_model(self, increase_count: bool = True) -> Tuple[str, int]:
+ """Allocate a model for handling a request.
+
+ Returns:
+ A tuple of (model_api_address, model_version).
+ """
+ model_id = self.running_model_ids[0]
+ model = self.models[model_id]
+ if increase_count:
+ model.request_count += 1
+ self.running_model_ids.rotate(-1)
+ if model.api_address is None:
+ raise ValueError(
+ "Model does not have a valid API address, please set `enable_openai_api` to `True`."
+ )
+ return model.api_address, self.model_version_map[model_id]
+
+ def collect_metrics(self) -> Dict:
+ metrics = {}
+ for i, model in enumerate(self.models):
+ metrics[f"rollout/model_{i}/total_request_count"] = model.request_count
+ metrics[f"rollout/model_{i}/model_version"] = model.model_version
+ metrics["rollout/total_experience_count"] = self.total_experience_count
+ metrics["rollout/ready_experience_count"] = self.ready_experience_count
+ return metrics
+
+ async def record_experience(self, response, model_version: int) -> None:
+ experiences = []
+ for choice in response["choices"]:
+ exp = Experience(
+ tokens=torch.cat(
+ (
+ torch.tensor(response["prompt_token_ids"], dtype=torch.int32),
+ torch.tensor(choice["token_ids"], dtype=torch.int32),
+ )
+ ),
+ logprobs=(
+ torch.tensor(
+ [logprob["logprob"] for logprob in choice["logprobs"]["content"]],
+ dtype=torch.float32,
+ )
+ if "logprobs" in choice and choice["logprobs"] is not None
+ else torch.tensor([], dtype=torch.float32)
+ ),
+ prompt_length=len(response["prompt_token_ids"]),
+ )
+ exp.eid.suffix = response["id"]
+ exp.info["model_version"] = model_version
+ experiences.append(exp)
+
+ self.total_experience_count += len(experiences)
+ self.recorder.record_history(experiences)
+
+ async def submit_experiences(self) -> None:
+ async with self.commit_lock:
+ experiences = list(self.ready_experiences)
+ self.ready_experiences.clear()
+ metrics = await self.explorer.experience_pipeline.process.remote(experiences)
+ metrics.update(self.collect_metrics())
+ self.explorer.explore_step_num += 1
+ self.explorer.monitor.log(metrics, self.explorer.explore_step_num)
+
+ async def record_feedback(self, reward: float, msg_ids: List[str], task_id: str, run_id: int):
+ exps = self.recorder.update_reward(
+ reward=reward,
+ msg_ids=msg_ids,
+ task_id=task_id,
+ run_id=run_id,
+ )
+ self.ready_experience_count += len(exps)
+ self.ready_experiences.extend(exps)
+
+ async def shutdown(self):
+ if not self.running:
+ self.logger.warning("Server is not running.")
+ return
+ self.sync_model_weights_task.cancel()
+ self.serve_task.cancel()
+ try:
+ await self.serve_task
+ except asyncio.CancelledError:
+ pass
+ self.running = False
+ self.logger.info("API server shutdown.")
diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py
index e5adb86f1c..8157ad088d 100644
--- a/trinity/manager/synchronizer.py
+++ b/trinity/manager/synchronizer.py
@@ -274,6 +274,16 @@ async def wait_new_model_state_dict(self, current_version: int, no_wait: bool =
)
return self.model_version
+ async def get_latest_model_version(self) -> int:
+ """
+ Get the latest model version available in the synchronizer.
+
+ Returns:
+ The current model version.
+ """
+ async with self._ready_condition:
+ return self.model_version
+
async def ready_to_nccl_sync(
self, module: str, trainer_step: Optional[int] = None
) -> Union[int, None]: