Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
31c5693
add allocate_session and feedback API
pan-x-c Sep 22, 2025
ece144e
fix comments
pan-x-c Sep 22, 2025
7ed876c
add api test
pan-x-c Sep 23, 2025
bc3d72e
merge main
pan-x-c Oct 14, 2025
20857a7
add tests
pan-x-c Oct 14, 2025
5295688
process group for serve model
pan-x-c Oct 15, 2025
fa6604a
fix tests
pan-x-c Oct 15, 2025
1154605
Merge branch 'main' into feature/user_feedback
pan-x-c Oct 15, 2025
ef7278e
clean code
pan-x-c Oct 15, 2025
f401491
Merge branch 'main' into feature/user_feedback
pan-x-c Oct 21, 2025
dc45cc0
merge main
pan-x-c Dec 11, 2025
be94f13
fix serve tests
pan-x-c Dec 16, 2025
ce01983
fix tests
pan-x-c Dec 16, 2025
1692a23
Merge branch 'main' into feature/user_feedback
pan-x-c Dec 16, 2025
0552b8d
add recorder
pan-x-c Dec 17, 2025
fabd757
Merge branch 'main' into feature/user_feedback
pan-x-c Dec 18, 2025
27eee07
fix proxy
pan-x-c Dec 18, 2025
7fa496c
fix server
pan-x-c Dec 18, 2025
02ce2dc
fix serve trainer test
pan-x-c Dec 18, 2025
5507e5f
add tests
pan-x-c Dec 18, 2025
7c04e27
merge main
pan-x-c Dec 18, 2025
ed69403
fix tests
pan-x-c Dec 18, 2025
4050d1d
fix model version
pan-x-c Dec 18, 2025
73466c4
fix comments
pan-x-c Dec 18, 2025
c1cee23
fix comments
pan-x-c Dec 18, 2025
5122366
fix tests
pan-x-c Dec 18, 2025
4094490
fix replay
pan-x-c Dec 19, 2025
eb63d9a
fix synchronizer
pan-x-c Dec 19, 2025
3489b04
fix tests
pan-x-c Dec 19, 2025
e8b0dd5
fix pre-commit
pan-x-c Dec 19, 2025
c96593f
fix comments
pan-x-c Dec 19, 2025
f813128
fix serve mode
pan-x-c Dec 19, 2025
8101740
fix buffer test
pan-x-c Dec 19, 2025
fe33799
fix pre-commit
pan-x-c Dec 19, 2025
3df76a8
fix megatron training
pan-x-c Dec 19, 2025
63fa3be
fix vllm prefix caching
pan-x-c Dec 19, 2025
e0973a3
fix benchmark
pan-x-c Dec 22, 2025
ed8edef
add client side timeout
pan-x-c Dec 22, 2025
290e36f
fix comments
pan-x-c Dec 22, 2025
c0b7b67
update default test setting
pan-x-c Dec 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions benchmark/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def check_taskset_path(dataset_name: str, taskset_path: str) -> str:
raise AttributeError(f"{script_filename} is missing 'DEFAULT_DATA_PATH'")
taskset_path = module.DEFAULT_DATA_PATH
taskset_path = os.path.realpath(taskset_path)
if os.path.exists(taskset_path):
return taskset_path

# For frozenlake, check if train.parquet and test.parquet already exist
if dataset_name == "frozenlake":
Expand Down
7 changes: 4 additions & 3 deletions tests/buffer/experience_storage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tests.tools import RayUnittestBaseAysnc
from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import ExperienceBufferConfig
from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig
from trinity.common.constants import StorageType
from trinity.common.experience import EID, Experience

Expand Down Expand Up @@ -95,6 +95,7 @@ async def test_sql_experience_buffer(self):
max_read_timeout=3,
path=f"sqlite:///{DB_PATH}",
batch_size=self.train_batch_size,
replay_buffer=ReplayBufferConfig(enable=True),
)
config = config.to_storage_config()
writer = SQLWriter(config)
Expand All @@ -118,7 +119,7 @@ async def test_sql_experience_buffer(self):
exps = reader.read()
self.assertEqual(len(exps), self.train_batch_size)
for exp in exps:
self.assertEqual(exp.eid.task, cnt)
self.assertEqual(exp.eid.task, str(cnt))
cnt -= 1

# experience buffer support experience reuse
Expand All @@ -127,7 +128,7 @@ async def test_sql_experience_buffer(self):
exps = reader.read()
self.assertEqual(len(exps), self.train_batch_size)
for exp in exps:
self.assertEqual(exp.eid.task, cnt)
self.assertEqual(exp.eid.task, str(cnt))
cnt -= 1
self.assertEqual(await writer.release(), 0)

Expand Down
29 changes: 26 additions & 3 deletions tests/buffer/sql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,31 @@

import ray
import torch
from parameterized import parameterized

from tests.tools import RayUnittestBaseAysnc
from trinity.buffer import get_buffer_reader
from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import ExperienceBufferConfig, TasksetConfig
from trinity.common.config import (
ExperienceBufferConfig,
ReplayBufferConfig,
TasksetConfig,
)
from trinity.common.constants import StorageType
from trinity.common.experience import Experience

db_path = os.path.join(os.path.dirname(__file__), "test.db")


class TestSQLBuffer(RayUnittestBaseAysnc):
async def test_sql_exp_buffer_read_write(self) -> None:
@parameterized.expand(
[
(True,),
(False,),
]
)
async def test_sql_exp_buffer_read_write(self, enable_replay: bool) -> None:
total_num = 8
put_batch_size = 2
read_batch_size = 4
Expand All @@ -25,7 +36,10 @@ async def test_sql_exp_buffer_read_write(self) -> None:
path=f"sqlite:///{db_path}",
storage_type=StorageType.SQL.value,
batch_size=read_batch_size,
max_read_timeout=3,
)
if enable_replay:
config.replay_buffer = ReplayBufferConfig(enable=True)
sql_writer = SQLWriter(config.to_storage_config())
sql_reader = SQLReader(config.to_storage_config())
exps = [
Expand Down Expand Up @@ -53,13 +67,22 @@ async def test_sql_exp_buffer_read_write(self) -> None:
reward=float(i),
logprobs=torch.tensor([0.1]),
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
info={"model_version": i},
info={"model_version": i + put_batch_size},
)
for i in range(1, put_batch_size * 2 + 1)
]
)
exps = sql_reader.read(batch_size=put_batch_size * 2)
self.assertEqual(len(exps), put_batch_size * 2)
for exp in exps:
self.assertTrue(exp.info["model_version"] > put_batch_size)
if enable_replay:
# support replay, so we can read all again
exps = sql_reader.read(batch_size=(put_batch_size * 2 + total_num))
self.assertEqual(len(exps), (put_batch_size * 2 + total_num))
# if read more than available, will wait until timeout
with self.assertRaises(StopIteration):
exps = sql_reader.read(batch_size=(put_batch_size * 3 + total_num))
db_wrapper = ray.get_actor("sql-test_buffer")
self.assertIsNotNone(db_wrapper)
self.assertEqual(await sql_writer.release(), 0)
Expand Down
47 changes: 24 additions & 23 deletions tests/explorer/explorer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
import os
import random
import shutil
import unittest
from datetime import datetime

import httpx
import openai
import ray

from tests.tools import (
Expand All @@ -27,6 +25,7 @@
from trinity.common.config import ExperienceBufferConfig, InferenceModelConfig
from trinity.common.constants import StorageType
from trinity.explorer.explorer import Explorer
from trinity.explorer.proxy.client import TrinityClient
from trinity.manager.state_manager import StateManager


Expand Down Expand Up @@ -158,8 +157,9 @@ def run_serve(config):
run_stage(config)


def run_agent(base_url, model_path: str):
client = openai.Client(base_url=base_url, api_key="testkey")
def run_agent(proxy_url, model_path: str):
proxy_client = TrinityClient(proxy_url=proxy_url)
openai_client = proxy_client.get_openai_client()
contents = [
"Hello, how are you?",
"What is the capital of China?",
Expand All @@ -172,10 +172,11 @@ def run_agent(base_url, model_path: str):
"What is the best way to learn programming?",
"Describe the process of photosynthesis.",
]
response = client.chat.completions.create(
response = openai_client.chat.completions.create(
model=model_path,
messages=[{"role": "user", "content": random.choice(contents)}],
)
proxy_client.feedback(reward=2.0, msg_ids=[response.id])
return response.choices[0].message.content


Expand All @@ -191,7 +192,7 @@ def setUp(self):
self.config.explorer.rollout_model.engine_num = 4
self.config.explorer.rollout_model.enable_openai_api = True
self.config.checkpoint_root_dir = get_checkpoint_path()
self.config.explorer.api_port = 8010
self.config.explorer.proxy_port = 8010
self.config.explorer.service_status_check_interval = 30
self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
name="experience_buffer",
Expand All @@ -201,7 +202,6 @@ def setUp(self):
if multiprocessing.get_start_method(allow_none=True) != "spawn":
multiprocessing.set_start_method("spawn", force=True)

@unittest.skip("Require improvement for agent mode")
async def test_serve(self): # noqa: C901
serve_process = multiprocessing.Process(target=run_serve, args=(self.config,))
serve_process.start()
Expand Down Expand Up @@ -238,7 +238,7 @@ async def test_serve(self): # noqa: C901
apps = []
for i in range(task_num):
app_process = multiprocessing.Process(
target=run_agent, args=(server_url + "/v1", self.config.model.model_path)
target=run_agent, args=(server_url, self.config.model.model_path)
)
apps.append(app_process)
app_process.start()
Expand All @@ -248,22 +248,20 @@ async def test_serve(self): # noqa: C901
self.assertFalse(app.is_alive())

finish_step = None

proxy_client = TrinityClient(proxy_url=server_url)
for i in range(20):
async with httpx.AsyncClient() as client:
response = await client.get(f"{server_url}/metrics")
self.assertEqual(response.status_code, 200)
metrics = response.json()
metrics_keys = list(metrics.keys())
self.assertIn("explore_step_num", metrics_keys)
self.assertIn("rollout/total_experience_count", metrics_keys)
self.assertIn("rollout/model_0/total_request_count", metrics_keys)
self.assertIn("rollout/model_3/model_version", metrics_keys)
if not finish_step and metrics["rollout/total_experience_count"] == task_num:
finish_step = metrics["explore_step_num"]
if finish_step and metrics["explore_step_num"] >= finish_step + 1:
# wait for one more step to ensure all data are written to buffer
break
metrics = await proxy_client.get_metrics_async()
metrics_keys = list(metrics.keys())
self.assertIn("explore_step_num", metrics_keys)
self.assertIn("rollout/total_experience_count", metrics_keys)
self.assertIn("rollout/model_0/total_request_count", metrics_keys)
self.assertIn("rollout/model_3/model_version", metrics_keys)
if not finish_step and metrics["rollout/total_experience_count"] == task_num:
finish_step = metrics["explore_step_num"]
await proxy_client.commit_async()
if finish_step and metrics["explore_step_num"] >= finish_step + 1:
# wait for one more step to ensure all data are written to buffer
break
await asyncio.sleep(3)

serve_process.terminate()
Expand All @@ -277,6 +275,9 @@ async def test_serve(self): # noqa: C901
exps = await buffer_reader.read_async(batch_size=10)
for exp in exps:
self.assertTrue(len(exp.tokens) > 0)
self.assertTrue(len(exp.logprobs) > 0)
self.assertTrue(exp.prompt_length > 0)
self.assertTrue(exp.reward == 2.0)
self.assertEqual(len(exps), task_num)

def tearDown(self):
Expand Down
86 changes: 86 additions & 0 deletions tests/explorer/proxy_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import os
import unittest
import uuid
from typing import List

import torch

from trinity.common.experience import EID, Experience
from trinity.explorer.proxy.recorder import HistoryRecorder


def get_dummy_experience(num: int) -> List[Experience]:
    """Build ``num`` placeholder experiences, each with a random 6-char message id.

    Every experience carries a fixed 5-token zero tensor, a prompt length of 2,
    and ``model_version`` 0 in its info dict — just enough for recorder tests.
    """
    experiences: List[Experience] = []
    for _ in range(num):
        dummy = Experience(
            eid=EID(suffix=uuid.uuid4().hex[:6]),
            tokens=torch.zeros(5),
            prompt_length=2,
            info={"model_version": 0},
        )
        experiences.append(dummy)
    return experiences


# Throwaway SQLite file used by RecorderTest; removed in setUp/tearDown.
db_path = os.path.join(os.path.dirname(__file__), "test_recorder.db")


class RecorderTest(unittest.TestCase):
    """Tests for HistoryRecorder backed by a file-based SQLite database."""

    def setUp(self) -> None:
        self._remove_db()

    def tearDown(self) -> None:
        self._remove_db()

    @staticmethod
    def _remove_db() -> None:
        # Each test must start and end with a clean database file.
        if os.path.exists(db_path):
            os.remove(db_path)

    def test_recorder(self):
        # NOTE: this uses a file-backed SQLite DB (not :memory:), hence the
        # file cleanup in setUp/tearDown.
        recorder = HistoryRecorder(
            db_url="sqlite:///" + db_path,
            table_name="experience",
        )
        self.assertIsInstance(recorder, HistoryRecorder)

        # Recorded experiences become addressable by their msg ids (eid suffix).
        first_batch = get_dummy_experience(3)
        recorder.record_history(first_batch)
        first_ids = [exp.eid.suffix for exp in first_batch]
        second_batch = get_dummy_experience(2)
        recorder.record_history(second_batch)

        # update_reward stamps reward/run/task onto the matching records.
        rewarded = recorder.update_reward(
            reward=1.0, msg_ids=first_ids, run_id=1, task_id="test_task"
        )
        self.assertEqual(len(rewarded), 3)
        for exp in rewarded:
            self.assertEqual(exp.reward, 1.0)
            self.assertEqual(exp.eid.run, 1)
            self.assertEqual(str(exp.eid.task), "test_task")

        # Unknown msg ids match nothing and return an empty result.
        no_match = recorder.update_reward(
            reward=2.0, msg_ids=["non_existing_id"], run_id=1, task_id="test_task"
        )
        self.assertEqual(len(no_match), 0)

        # Recording an empty list is a no-op and must not raise.
        recorder.record_history([])

        # Rewarding the second batch consumes those records...
        second_ids = [exp.eid.suffix for exp in second_batch]
        rewarded_again = recorder.update_reward(
            reward=3.0,
            msg_ids=second_ids,
            run_id=2,
            task_id="test_task_2",
        )
        self.assertEqual(len(rewarded_again), 2)
        for exp in rewarded_again:
            self.assertEqual(exp.reward, 3.0)
            self.assertEqual(exp.eid.run, 2)
            self.assertEqual(str(exp.eid.task), "test_task_2")

        # ...so a second reward for the same ids finds nothing.
        leftover = recorder.update_reward(
            reward=4.0,
            msg_ids=second_ids,
            run_id=3,
            task_id="test_task_3",
        )
        self.assertEqual(len(leftover), 0)  # already consumed
2 changes: 1 addition & 1 deletion tests/explorer/workflow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ async def test_adapter(self):
try:
from agentscope.model import TrinityChatModel
except ImportError:
self.skipTest("agentscope >= 0.1.6 is not installed")
self.skipTest("agentscope >= 1.0.9 is not installed")

async def as_workflow_func(task, model) -> float:
self.assertIsInstance(task, dict)
Expand Down
24 changes: 8 additions & 16 deletions tests/manager/synchronizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,32 +80,23 @@ async def new_finish_explore_step(self, step: int, model_version: int) -> None:

def run_trainer(config: Config, max_steps: int, intervals: List[int]) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
try:
trainer_monkey_patch(config, max_steps, intervals)
train(config)
finally:
ray.shutdown(_exiting_interpreter=True)
trainer_monkey_patch(config, max_steps, intervals)
train(config)


def run_explorer(config: Config, max_steps: int, intervals: List[int]) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
try:
explorer_monkey_patch(config, max_steps, intervals)
explore(config)
finally:
ray.shutdown(_exiting_interpreter=True)
explorer_monkey_patch(config, max_steps, intervals)
explore(config)


def run_both(
config: Config, max_steps: int, trainer_intervals: List[int], explorer_intervals: List[int]
) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
try:
trainer_monkey_patch(config, max_steps, trainer_intervals)
explorer_monkey_patch(config, max_steps, explorer_intervals)
both(config)
finally:
ray.shutdown(_exiting_interpreter=True)
trainer_monkey_patch(config, max_steps, trainer_intervals)
explorer_monkey_patch(config, max_steps, explorer_intervals)
both(config)


class BaseTestSynchronizer(unittest.TestCase):
Expand All @@ -115,6 +106,7 @@ def setUp(self):

def tearDown(self):
checkpoint_path = get_checkpoint_path()
ray.shutdown(_exiting_interpreter=True)
shutil.rmtree(os.path.join(checkpoint_path, "unittest"))


Expand Down
2 changes: 0 additions & 2 deletions tests/template/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
trainer:
Expand Down
Loading