12 changes: 10 additions & 2 deletions src/chatan/__init__.py
@@ -4,8 +4,16 @@

from .dataset import dataset
from .evaluate import eval, evaluate
from .generator import generator
from .generator import generator, async_generator
from .sampler import sample
from .viewer import generate_with_viewer

__all__ = ["dataset", "generator", "sample", "generate_with_viewer", "evaluate", "eval"]
__all__ = [
"dataset",
"generator",
"async_generator",
"sample",
"generate_with_viewer",
"evaluate",
"eval",
]
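
With `async_generator` exported at the package root, callers can import it directly. A minimal sketch (the key is a placeholder, not a value from this PR):

from chatan import async_generator

gen = async_generator("openai", api_key="sk-...")  # hypothetical API key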
83 changes: 78 additions & 5 deletions src/chatan/dataset.py
@@ -1,13 +1,15 @@
"""Dataset creation and manipulation."""

from typing import Any, Callable, Dict, List, Optional, Union
import asyncio
import inspect
from typing import Any, Dict, List, Optional, Union

import pandas as pd
from datasets import Dataset as HFDataset
from tqdm import tqdm

from .evaluate import DatasetEvaluator, EvaluationFunction
from .generator import GeneratorFunction
from .generator import GeneratorFunction, AsyncGeneratorFunction
from .sampler import SampleFunction


@@ -97,7 +99,7 @@ def _build_dependency_graph(self) -> Dict[str, List[str]]:

for column, func in self.schema.items():
deps = []
if isinstance(func, GeneratorFunction):
if isinstance(func, (GeneratorFunction, AsyncGeneratorFunction)):
# Extract column references from prompt template
import re

@@ -136,14 +138,85 @@ def _generate_value(self, column: str, context: Dict[str, Any]) -> Any:
"""Generate a single value for a column."""
func = self.schema[column]

if isinstance(func, AsyncGeneratorFunction):
return self._resolve_sync(func(context))

if isinstance(func, (GeneratorFunction, SampleFunction)):
return func(context)
return self._resolve_sync(func(context))
elif callable(func):
return func(context)
return self._resolve_sync(func(context))
else:
# Static value
return func

async def _generate_value_async(self, column: str, context: Dict[str, Any]) -> Any:
"""Generate a single value for a column within an async context."""
func = self.schema[column]

if isinstance(func, AsyncGeneratorFunction):
return await func(context)

if isinstance(func, (GeneratorFunction, SampleFunction)):
result = func(context)
elif callable(func):
result = func(context)
else:
return func

if inspect.isawaitable(result):
return await result
return result

@staticmethod
def _resolve_sync(value: Any) -> Any:
"""Resolve awaitables when running in synchronous context."""

if inspect.isawaitable(value):
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(value)

raise RuntimeError(
"Encountered awaitable while generating dataset inside an active "
"event loop. Use `await Dataset.generate_async(...)` instead."
)

return value

async def generate_async(
self,
n: Optional[int] = None,
progress: bool = True,
) -> pd.DataFrame:
"""Asynchronously generate the dataset.

This method mirrors :meth:`generate` but awaits async generators and
callables directly, making it safe to invoke from within an existing
event loop.
"""

num_samples = n or self.n
show_progress = progress

dependencies = self._build_dependency_graph()
execution_order = self._topological_sort(dependencies)

data = []
iterator = range(num_samples)
if show_progress:
iterator = tqdm(iterator, desc="Generating", leave=False)

for _ in iterator:
row = {}
for column in execution_order:
value = await self._generate_value_async(column, row)
row[column] = value
data.append(row)

self._data = pd.DataFrame(data)
return self._data

def to_pandas(self) -> pd.DataFrame:
"""Convert to pandas DataFrame."""
if self._data is None:
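
The split between the two entry points is deliberate: `_resolve_sync` refuses to block inside a running event loop and directs callers to `generate_async`, which awaits async generators natively. A usage sketch under stated assumptions (the `dataset(...)` factory is assumed to accept a schema mapping plus a row count `n`; the API key is a placeholder):

import asyncio

from chatan import async_generator, dataset

gen = async_generator("openai", api_key="sk-...")  # hypothetical key

ds = dataset(
    {
        "topic": "databases",  # static value, passed through unchanged
        "question": gen("Write one question about {topic}"),
    },
    n=10,
)

async def main():
    # Safe inside an existing event loop; each AsyncGeneratorFunction is awaited.
    df = await ds.generate_async()
    print(df.head())

asyncio.run(main())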
194 changes: 192 additions & 2 deletions src/chatan/generator.py
@@ -1,9 +1,10 @@
"""LLM generators with CPU fallback and aggressive memory management."""
"""LLM generators with CPU fallback, async support, and aggressive memory management."""

import asyncio
import gc
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Iterable, Optional

import anthropic
import openai
@@ -26,6 +27,15 @@ def generate(self, prompt: str, **kwargs) -> str:
pass


class AsyncBaseGenerator(ABC):
"""Base class for async LLM generators."""

@abstractmethod
async def generate(self, prompt: str, **kwargs) -> str:
"""Asynchronously generate content from a prompt."""
pass


class OpenAIGenerator(BaseGenerator):
"""OpenAI GPT generator."""

@@ -46,6 +56,33 @@ def generate(self, prompt: str, **kwargs) -> str:
return response.choices[0].message.content.strip()


class AsyncOpenAIGenerator(AsyncBaseGenerator):
"""Async OpenAI GPT generator."""

def __init__(self, api_key: str, model: str = "gpt-3.5-turbo", **kwargs):
async_client_cls = getattr(openai, "AsyncOpenAI", None)
if async_client_cls is None:
raise ImportError(
"Async OpenAI client is not available. Upgrade the `openai` package "
"to a version that provides `AsyncOpenAI`."
)

self.client = async_client_cls(api_key=api_key)
self.model = model
self.default_kwargs = kwargs

async def generate(self, prompt: str, **kwargs) -> str:
"""Generate content using OpenAI API asynchronously."""
merged_kwargs = {**self.default_kwargs, **kwargs}

response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
**merged_kwargs,
)
return response.choices[0].message.content.strip()


class AnthropicGenerator(BaseGenerator):
"""Anthropic Claude generator."""

@@ -67,6 +104,39 @@ def generate(self, prompt: str, **kwargs) -> str:
return response.content[0].text.strip()


class AsyncAnthropicGenerator(AsyncBaseGenerator):
"""Async Anthropic Claude generator."""

def __init__(
self,
api_key: str,
model: str = "claude-3-sonnet-20240229",
**kwargs,
):
async_client_cls = getattr(anthropic, "AsyncAnthropic", None)
if async_client_cls is None:
raise ImportError(
"Async Anthropic client is not available. Upgrade the `anthropic` package "
"to a version that provides `AsyncAnthropic`."
)

self.client = async_client_cls(api_key=api_key)
self.model = model
self.default_kwargs = kwargs

async def generate(self, prompt: str, **kwargs) -> str:
"""Generate content using Anthropic API asynchronously."""
merged_kwargs = {**self.default_kwargs, **kwargs}

response = await self.client.messages.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
max_tokens=merged_kwargs.pop("max_tokens", 1000),
**merged_kwargs,
)
return response.content[0].text.strip()


class TransformersGenerator(BaseGenerator):
"""Local HuggingFace/transformers generator with aggressive memory management."""

@@ -202,6 +272,89 @@ def __call__(self, context: Dict[str, Any]) -> str:
return result.strip() if isinstance(result, str) else result


class AsyncGeneratorFunction:
"""Callable async generator function with high concurrency support."""

def __init__(
self,
generator: AsyncBaseGenerator,
prompt_template: str,
variables: Optional[Dict[str, Any]] = None,
):
self.generator = generator
self.prompt_template = prompt_template
self.variables = variables or {}

async def __call__(self, context: Dict[str, Any], **kwargs) -> str:
"""Generate content with context substitution asynchronously."""
merged = dict(context)
for key, value in self.variables.items():
merged[key] = value(context) if callable(value) else value

prompt = self.prompt_template.format(**merged)
result = await self.generator.generate(prompt, **kwargs)
return result.strip() if isinstance(result, str) else result

async def stream(
self,
contexts: Iterable[Dict[str, Any]],
*,
concurrency: int = 5,
return_exceptions: bool = False,
**kwargs,
):
"""Asynchronously yield results for many contexts with bounded concurrency."""

if concurrency < 1:
raise ValueError("concurrency must be at least 1")

contexts_list = list(contexts)
if not contexts_list:
return

semaphore = asyncio.Semaphore(concurrency)

async def worker(index: int, ctx: Dict[str, Any]):
async with semaphore:
try:
result = await self(ctx, **kwargs)
return index, result, None
except Exception as exc: # pragma: no cover - exercised via return_exceptions
return index, None, exc

tasks = [
asyncio.create_task(worker(index, ctx))
for index, ctx in enumerate(contexts_list)
]

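# Tasks finish out of order; buffer results and yield them in submission order.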
next_index = 0
buffer: Dict[int, Any] = {}

try:
for coro in asyncio.as_completed(tasks):
index, value, error = await coro

if error is not None and not return_exceptions:
for task in tasks:
if not task.done():
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise error

buffer[index] = error if error is not None else value

while next_index in buffer:
item = buffer.pop(next_index)
next_index += 1
yield item

finally:
for task in tasks:
if not task.done():
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
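
The index buffer above means results come back in the same order as the input contexts even though requests complete out of order. A minimal consumption sketch (prompt and contexts are illustrative; `fn` stands for any AsyncGeneratorFunction):

async def run_batch(fn):
    contexts = [{"topic": t} for t in ("rust", "go", "zig")]
    async for item in fn.stream(contexts, concurrency=2, return_exceptions=True):
        # With return_exceptions=True, failures arrive in-line instead of aborting the batch.
        if isinstance(item, Exception):
            print("failed:", item)
        else:
            print(item)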


class GeneratorClient:
"""Main interface for creating generators."""

@@ -251,6 +404,34 @@ def __call__(self, prompt_template: str, **variables) -> GeneratorFunction:
return GeneratorFunction(self._generator, prompt_template, variables)


class AsyncGeneratorClient:
"""Async interface for creating generators with concurrent execution."""

def __init__(self, provider: str, api_key: Optional[str] = None, **kwargs):
provider_lower = provider.lower()
try:
if provider_lower == "openai":
if api_key is None:
raise ValueError("API key is required for OpenAI")
self._generator = AsyncOpenAIGenerator(api_key, **kwargs)
elif provider_lower == "anthropic":
if api_key is None:
raise ValueError("API key is required for Anthropic")
self._generator = AsyncAnthropicGenerator(api_key, **kwargs)
else:
raise ValueError(f"Unsupported provider for async generator: {provider}")

except Exception as e:
raise ValueError(
f"Failed to initialize async generator for provider '{provider}'. "
f"Check your configuration and try again. Original error: {str(e)}"
) from e

def __call__(self, prompt_template: str, **variables) -> AsyncGeneratorFunction:
"""Create an async generator function."""
return AsyncGeneratorFunction(self._generator, prompt_template, variables)


# Factory function
def generator(
provider: str = "openai", api_key: Optional[str] = None, **kwargs
@@ -259,3 +440,12 @@ def generator(
if provider.lower() in {"openai", "anthropic"} and api_key is None:
raise ValueError("API key is required")
return GeneratorClient(provider, api_key, **kwargs)


def async_generator(
provider: str = "openai", api_key: Optional[str] = None, **kwargs
) -> AsyncGeneratorClient:
"""Create an async generator client."""
if provider.lower() in {"openai", "anthropic"} and api_key is None:
raise ValueError("API key is required")
return AsyncGeneratorClient(provider, api_key, **kwargs)
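
Taken together, `async_generator` mirrors the synchronous `generator` factory. An end-to-end sketch, with the key and prompt as placeholders rather than values the PR prescribes:

import asyncio

from chatan.generator import async_generator

async def main():
    gen = async_generator("anthropic", api_key="sk-ant-...")  # hypothetical key
    qa = gen("Answer briefly: {question}")
    print(await qa({"question": "What does a semaphore bound?"}))

asyncio.run(main())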