11 changes: 9 additions & 2 deletions src/chatan/dataset.py
@@ -165,11 +165,18 @@ async def _generate_column_value(
             # Samplers are sync but fast
             value = func(row)
         elif callable(func):
-            # Check if it's an async callable
+            # Check if it's an async callable (function or __call__ method)
             if asyncio.iscoroutinefunction(func):
                 value = await func(row)
+            elif hasattr(func, '__call__') and asyncio.iscoroutinefunction(func.__call__):
+                value = await func(row)
             else:
-                value = func(row)
+                result = func(row)
+                # Handle case where callable returns a coroutine
+                if asyncio.iscoroutine(result):
+                    value = await result
+                else:
+                    value = result
         else:
             # Static value
             value = func
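
For context, a minimal sketch (the class name is hypothetical) of the case the new branches handle: an object whose `__call__` is an async method is not itself detected by `asyncio.iscoroutinefunction`, so the explicit `__call__` check and the coroutine-result fallback are both needed.

import asyncio

class AsyncCallable:
    # The instance is callable and returns a coroutine, but the instance itself
    # is not a coroutine function, so the first iscoroutinefunction check misses it.
    async def __call__(self, row):
        return row["x"] * 2

func = AsyncCallable()
assert not asyncio.iscoroutinefunction(func)
assert asyncio.iscoroutinefunction(func.__call__)
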
3 changes: 2 additions & 1 deletion src/chatan/generator.py
@@ -233,7 +233,8 @@ async def __call__(self, context: Dict[str, Any], **kwargs) -> str:
             merged[key] = value(context) if callable(value) else value
 
         prompt = self.prompt_template.format(**merged)
-        result = await self.generator.generate(prompt, **kwargs)
+        # Pass context for caching (e.g., verifiers integration)
+        result = await self.generator.generate(prompt, _context=context, **kwargs)
         return result.strip() if isinstance(result, str) else result
 
     async def stream(
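
A stand-alone sketch of the consumer side (mirroring `RolloutExtractor.generate` in the new verifiers integration below); the class and caching scheme here are hypothetical, not chatan's real generator API. The point is only that a generator whose `generate` accepts `**kwargs` can read the forwarded row context and use its identity as a cache key.

class ContextCachingGenerator:
    """Hypothetical sketch: reuse results for the same row context."""

    def __init__(self):
        self._cache = {}

    async def generate(self, prompt: str, **kwargs):
        context = kwargs.get("_context", {})  # forwarded by GeneratorFunction.__call__
        key = (id(context), prompt)
        if key not in self._cache:
            self._cache[key] = f"response to: {prompt}"  # stand-in for a real model call
        return self._cache[key]
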
7 changes: 7 additions & 0 deletions src/chatan/integrations/__init__.py
@@ -0,0 +1,7 @@
"""Integrations with external libraries.

Available integrations:
- verifiers: Use `from chatan.integrations.verifiers import rollout_generator`
"""

__all__: list[str] = []
235 changes: 235 additions & 0 deletions src/chatan/integrations/verifiers.py
@@ -0,0 +1,235 @@
"""Verifiers integration for using rollout environments as chatan generators."""

import asyncio
from contextlib import asynccontextmanager
from typing import Any, Dict, Literal

from openai import AsyncOpenAI
from tenacity import retry, stop_after_attempt, wait_random_exponential
from verifiers.envs.environment import Environment

from chatan.generator import BaseGenerator, GeneratorFunction


@asynccontextmanager
async def _null_semaphore():
"""A no-op async context manager that acts as an unlimited semaphore."""
yield


ExtractType = Literal["completion", "reward", "metrics", "trajectory", "full"]


class RolloutResult:
"""Lazy rollout result - runs once, extracts many."""

def __init__(
self,
env: Environment,
client: AsyncOpenAI,
model: str,
prompt_template: str,
answer_template: str | None = None,
**sampling_args,
):
self.env = env
self.client = client
self.model = model
self.prompt_template = prompt_template
self.answer_template = answer_template
self.sampling_args = sampling_args
self._cache: Dict[int, Dict[str, Any]] = {}
self._locks: Dict[int, asyncio.Lock] = {}
self._row_counter = 0

async def _get_result(self, context: Dict[str, Any]) -> Dict[str, Any]:
"""Get or compute the rollout result for this context."""
ctx_id = id(context)

if ctx_id in self._cache:
return self._cache[ctx_id]

# Get or create lock for this context
if ctx_id not in self._locks:
self._locks[ctx_id] = asyncio.Lock()

async with self._locks[ctx_id]:
# Double-check after acquiring lock
if ctx_id in self._cache:
return self._cache[ctx_id]

prompt = self.prompt_template.format(**context)
answer = self.answer_template.format(**context) if self.answer_template else None

input_data = {
"prompt": [{"role": "user", "content": prompt}],
"example_id": self._row_counter,
"task": self.env.env_id or "default",
}
if answer is not None:
input_data["answer"] = answer

self._row_counter += 1

result = await self._run_with_retry(input_data)
self._cache[ctx_id] = result
return result

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(3))
async def _run_with_retry(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
"""Run rollout with retry on transient failures."""
return await self.env.run_rollout(
input_data,
self.client,
self.model,
gen_sampling_args=self.sampling_args,
gen_sem=_null_semaphore(),
)

def _extractor(self, extract: ExtractType) -> "RolloutExtractor":
"""Create an extractor for a specific field."""
return RolloutExtractor(self, extract)

@property
def completion(self) -> "RolloutExtractor":
return self._extractor("completion")

@property
def reward(self) -> "RolloutExtractor":
return self._extractor("reward")

@property
def metrics(self) -> "RolloutExtractor":
return self._extractor("metrics")

@property
def trajectory(self) -> "RolloutExtractor":
return self._extractor("trajectory")

@property
def full(self) -> "RolloutExtractor":
return self._extractor("full")


class RolloutExtractor(BaseGenerator):
"""Extracts a field from a shared RolloutResult."""

def __init__(self, rollout_result: RolloutResult, extract: ExtractType):
self._rollout_result = rollout_result
self._extract = extract

# Make it usable directly in dataset schema
prompt_template = ""

async def __call__(self, context: Dict[str, Any]) -> Any:
"""Called by dataset generation."""
return await self.generate("", _context=context)

async def generate(self, prompt: str, **kwargs) -> Any:
context = kwargs.get("_context", {})
result = await self._rollout_result._get_result(context)
return self._extract_field(result)

def _extract_field(self, result: Dict[str, Any]) -> Any:
if self._extract == "completion":
return self._extract_completion_text(result)
elif self._extract == "reward":
return result.get("reward", 0.0)
elif self._extract == "metrics":
return result.get("metrics", {})
elif self._extract == "trajectory":
return result.get("trajectory", [])
elif self._extract == "full":
return dict(result)
else:
return result.get(self._extract)

def _extract_completion_text(self, result: Dict[str, Any]) -> str:
completion = result.get("completion", [])
if isinstance(completion, list):
for msg in reversed(completion):
if msg.get("role") == "assistant":
content = msg.get("content", "")
return content if isinstance(content, str) else str(content)
return ""
return str(completion)


def rollout_generator(
env: Environment,
client: AsyncOpenAI,
model: str,
**sampling_args,
) -> "RolloutClient":
"""Create a rollout client.

Args:
env: Verifiers environment to use for rollouts.
client: AsyncOpenAI client.
model: Model name to use for generation.
**sampling_args: Additional sampling arguments passed to the model.

Returns:
RolloutClient that can be called to create rollout results.

Example:
>>> from chatan import dataset, sample
>>> from chatan.integrations.verifiers import rollout_generator
>>> from verifiers.utils.env_utils import load_environment
>>> from openai import AsyncOpenAI
>>>
>>> client = AsyncOpenAI(api_key="...")
>>> env = load_environment("gsm8k")
>>> rollout = rollout_generator(env, client, "gpt-4.1-mini")
>>>
>>> row = sample.row(env.get_eval_dataset())
>>> r = rollout("{question}", answer="{ground_truth}")
>>>
>>> ds = dataset({
... "question": row["prompt"],
... "ground_truth": row["answer"],
... "model_answer": r.completion,
... "trajectory": r.trajectory,
... "reward": r.reward,
... })
"""
return RolloutClient(env, client, model, **sampling_args)


class RolloutClient:
"""Factory for creating rollout results."""

def __init__(
self,
env: Environment,
client: AsyncOpenAI,
model: str,
**sampling_args,
):
self.env = env
self.client = client
self.model = model
self.sampling_args = sampling_args

def __call__(
self,
prompt_template: str,
answer: str | None = None,
) -> RolloutResult:
"""Create a rollout result that can be used to extract multiple fields.

Args:
prompt_template: Template string for the prompt.
answer: Template string for the ground truth answer.

Returns:
RolloutResult with .completion, .trajectory, .reward, .metrics, .full extractors.
"""
return RolloutResult(
env=self.env,
client=self.client,
model=self.model,
prompt_template=prompt_template,
answer_template=answer,
**self.sampling_args,
)
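
A minimal usage sketch beyond the docstring example above (reusing the `env` and `client` from it): the extractor properties share one cached rollout per context dict, so reading several fields for the same row triggers a single environment rollout.

async def demo():
    r = rollout_generator(env, client, "gpt-4.1-mini")("{question}", answer="{ground_truth}")
    ctx = {"question": "What is 2 + 2?", "ground_truth": "4"}
    completion = await r.completion(ctx)  # runs the rollout, cached under id(ctx)
    reward = await r.reward(ctx)          # reuses the cached result; no second rollout
    return completion, reward
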
76 changes: 75 additions & 1 deletion src/chatan/sampler.py
@@ -111,6 +111,62 @@ def __call__(self, context: Dict[str, Any] = None) -> Any:
return random.choice(self.values)


class RowSampler:
"""Sample aligned rows from a dataset - all columns from the same row."""

def __init__(self, dataset: Union[pd.DataFrame, HFDataset, Dict]):
if isinstance(dataset, pd.DataFrame):
self._data = dataset
self._len = len(dataset)
elif isinstance(dataset, HFDataset):
self._data = dataset
self._len = len(dataset)
elif isinstance(dataset, dict):
self._data = dataset
first_key = next(iter(dataset.keys()))
self._len = len(dataset[first_key])
else:
raise ValueError("Unsupported dataset type")

self._index_cache: Dict[int, int] = {}

def _get_index(self, context: Dict[str, Any]) -> int:
"""Get or create the sampled row index for this context."""
ctx_id = id(context) if context else 0
if ctx_id not in self._index_cache:
self._index_cache[ctx_id] = random.randint(0, self._len - 1)
return self._index_cache[ctx_id]

def __getitem__(self, column: str) -> "RowColumnSampler":
"""Get a sampler for a specific column."""
return RowColumnSampler(self, column)

def clear_cache(self) -> None:
"""Clear the index cache."""
self._index_cache.clear()


class RowColumnSampler(SampleFunction):
"""Sample a column value from an aligned row."""

def __init__(self, row_sampler: RowSampler, column: str):
self._row_sampler = row_sampler
self._column = column

def __call__(self, context: Dict[str, Any] = None) -> Any:
idx = self._row_sampler._get_index(context)
data = self._row_sampler._data

if isinstance(data, pd.DataFrame):
return data.iloc[idx][self._column]
elif isinstance(data, HFDataset):
return data[idx][self._column]
elif isinstance(data, dict):
return data[self._column][idx]
else:
raise ValueError("Unsupported dataset type")


# Factory functions for the sample namespace
class SampleNamespace:
"""Namespace for sampling functions."""
@@ -150,9 +206,27 @@ def from_dataset(
column: str,
default: Optional[SampleFunction] = None,
) -> DatasetSampler:
"""Sample from existing dataset."""
"""Sample from existing dataset (independent per column)."""
return DatasetSampler(dataset, column, default)

@staticmethod
def row(dataset: Union[pd.DataFrame, HFDataset, Dict]) -> RowSampler:
"""Sample aligned rows from a dataset.

Returns a RowSampler that guarantees all column values drawn for the
same generated row come from a single source row.

Example:
>>> eval_data = load_dataset("squad")
>>> row = sample.row(eval_data)
>>> ds = dataset({
... "question": row["question"],
... "context": row["context"],
... "answer": row["answer"],
... })
"""
return RowSampler(dataset)


# Export the sample namespace
sample = SampleNamespace()
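
A minimal sketch of the difference between the two samplers (the frame and column names are made up): `sample.from_dataset` draws each column independently, so values may come from different source rows, while `sample.row` keeps every column aligned to one source row per generated row.

import pandas as pd
from chatan import sample

src = pd.DataFrame({"question": ["q1", "q2"], "answer": ["a1", "a2"]})

q_only = sample.from_dataset(src, "question")   # independent draw on every call
row = sample.row(src)                           # aligned: one row index per context

ctx = {"row": 0}                                # the per-row context dict is the cache key
print(row["question"](ctx), row["answer"](ctx))  # always a matching pair, e.g. "q2 a2"
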