From 6a92dcb4bf068f4baee8a572d59a0d7b2f6fdace Mon Sep 17 00:00:00 2001 From: Christian R <117322020+cdreetz@users.noreply.github.com> Date: Tue, 21 Oct 2025 17:26:08 -0700 Subject: [PATCH] Handle async generator functions in dataset --- src/chatan/__init__.py | 12 +- src/chatan/dataset.py | 83 +++++++++++- src/chatan/generator.py | 194 +++++++++++++++++++++++++++- tests/test_dataset_comprehensive.py | 47 ++++++- tests/test_generator.py | 181 +++++++++++++++++++++++++- 5 files changed, 504 insertions(+), 13 deletions(-) diff --git a/src/chatan/__init__.py b/src/chatan/__init__.py index b5c1fa7..a9fad00 100644 --- a/src/chatan/__init__.py +++ b/src/chatan/__init__.py @@ -4,8 +4,16 @@ from .dataset import dataset from .evaluate import eval, evaluate -from .generator import generator +from .generator import generator, async_generator from .sampler import sample from .viewer import generate_with_viewer -__all__ = ["dataset", "generator", "sample", "generate_with_viewer", "evaluate", "eval"] +__all__ = [ + "dataset", + "generator", + "async_generator", + "sample", + "generate_with_viewer", + "evaluate", + "eval", +] diff --git a/src/chatan/dataset.py b/src/chatan/dataset.py index 297a8a8..672419a 100644 --- a/src/chatan/dataset.py +++ b/src/chatan/dataset.py @@ -1,13 +1,15 @@ """Dataset creation and manipulation.""" -from typing import Any, Callable, Dict, List, Optional, Union +import asyncio +import inspect +from typing import Any, Dict, List, Optional, Union import pandas as pd from datasets import Dataset as HFDataset from tqdm import tqdm from .evaluate import DatasetEvaluator, EvaluationFunction -from .generator import GeneratorFunction +from .generator import GeneratorFunction, AsyncGeneratorFunction from .sampler import SampleFunction @@ -97,7 +99,7 @@ def _build_dependency_graph(self) -> Dict[str, List[str]]: for column, func in self.schema.items(): deps = [] - if isinstance(func, GeneratorFunction): + if isinstance(func, (GeneratorFunction, AsyncGeneratorFunction)): # Extract column references from prompt template import re @@ -136,14 +138,85 @@ def _generate_value(self, column: str, context: Dict[str, Any]) -> Any: """Generate a single value for a column.""" func = self.schema[column] + if isinstance(func, AsyncGeneratorFunction): + return self._resolve_sync(func(context)) + if isinstance(func, (GeneratorFunction, SampleFunction)): - return func(context) + return self._resolve_sync(func(context)) elif callable(func): - return func(context) + return self._resolve_sync(func(context)) else: # Static value return func + async def _generate_value_async(self, column: str, context: Dict[str, Any]) -> Any: + """Generate a single value for a column within an async context.""" + func = self.schema[column] + + if isinstance(func, AsyncGeneratorFunction): + return await func(context) + + if isinstance(func, (GeneratorFunction, SampleFunction)): + result = func(context) + elif callable(func): + result = func(context) + else: + return func + + if inspect.isawaitable(result): + return await result + return result + + @staticmethod + def _resolve_sync(value: Any) -> Any: + """Resolve awaitables when running in synchronous context.""" + + if inspect.isawaitable(value): + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(value) + + raise RuntimeError( + "Encountered awaitable while generating dataset inside an active " + "event loop. Use `await Dataset.generate_async(...)` instead." + ) + + return value + + async def generate_async( + self, + n: Optional[int] = None, + progress: bool = True, + ) -> pd.DataFrame: + """Asynchronously generate the dataset. + + This method mirrors :meth:`generate` but awaits async generators and + callables directly, making it safe to invoke from within an existing + event loop. + """ + + num_samples = n or self.n + show_progress = progress + + dependencies = self._build_dependency_graph() + execution_order = self._topological_sort(dependencies) + + data = [] + iterator = range(num_samples) + if show_progress: + iterator = tqdm(iterator, desc="Generating", leave=False) + + for _ in iterator: + row = {} + for column in execution_order: + value = await self._generate_value_async(column, row) + row[column] = value + data.append(row) + + self._data = pd.DataFrame(data) + return self._data + def to_pandas(self) -> pd.DataFrame: """Convert to pandas DataFrame.""" if self._data is None: diff --git a/src/chatan/generator.py b/src/chatan/generator.py index d8bdbee..802bd46 100644 --- a/src/chatan/generator.py +++ b/src/chatan/generator.py @@ -1,9 +1,10 @@ -"""LLM generators with CPU fallback and aggressive memory management.""" +"""LLM generators with CPU fallback, async support, and aggressive memory management.""" +import asyncio import gc import os from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Iterable, Optional import anthropic import openai @@ -26,6 +27,15 @@ def generate(self, prompt: str, **kwargs) -> str: pass +class AsyncBaseGenerator(ABC): + """Base class for async LLM generators.""" + + @abstractmethod + async def generate(self, prompt: str, **kwargs) -> str: + """Asynchronously generate content from a prompt.""" + pass + + class OpenAIGenerator(BaseGenerator): """OpenAI GPT generator.""" @@ -46,6 +56,33 @@ def generate(self, prompt: str, **kwargs) -> str: return response.choices[0].message.content.strip() +class AsyncOpenAIGenerator(AsyncBaseGenerator): + """Async OpenAI GPT generator.""" + + def __init__(self, api_key: str, model: str = "gpt-3.5-turbo", **kwargs): + async_client_cls = getattr(openai, "AsyncOpenAI", None) + if async_client_cls is None: + raise ImportError( + "Async OpenAI client is not available. Upgrade the `openai` package " + "to a version that provides `AsyncOpenAI`." + ) + + self.client = async_client_cls(api_key=api_key) + self.model = model + self.default_kwargs = kwargs + + async def generate(self, prompt: str, **kwargs) -> str: + """Generate content using OpenAI API asynchronously.""" + merged_kwargs = {**self.default_kwargs, **kwargs} + + response = await self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + **merged_kwargs, + ) + return response.choices[0].message.content.strip() + + class AnthropicGenerator(BaseGenerator): """Anthropic Claude generator.""" @@ -67,6 +104,39 @@ def generate(self, prompt: str, **kwargs) -> str: return response.content[0].text.strip() +class AsyncAnthropicGenerator(AsyncBaseGenerator): + """Async Anthropic Claude generator.""" + + def __init__( + self, + api_key: str, + model: str = "claude-3-sonnet-20240229", + **kwargs, + ): + async_client_cls = getattr(anthropic, "AsyncAnthropic", None) + if async_client_cls is None: + raise ImportError( + "Async Anthropic client is not available. Upgrade the `anthropic` package " + "to a version that provides `AsyncAnthropic`." + ) + + self.client = async_client_cls(api_key=api_key) + self.model = model + self.default_kwargs = kwargs + + async def generate(self, prompt: str, **kwargs) -> str: + """Generate content using Anthropic API asynchronously.""" + merged_kwargs = {**self.default_kwargs, **kwargs} + + response = await self.client.messages.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + max_tokens=merged_kwargs.pop("max_tokens", 1000), + **merged_kwargs, + ) + return response.content[0].text.strip() + + class TransformersGenerator(BaseGenerator): """Local HuggingFace/transformers generator with aggressive memory management.""" @@ -202,6 +272,89 @@ def __call__(self, context: Dict[str, Any]) -> str: return result.strip() if isinstance(result, str) else result +class AsyncGeneratorFunction: + """Callable async generator function with high concurrency support.""" + + def __init__( + self, + generator: AsyncBaseGenerator, + prompt_template: str, + variables: Optional[Dict[str, Any]] = None, + ): + self.generator = generator + self.prompt_template = prompt_template + self.variables = variables or {} + + async def __call__(self, context: Dict[str, Any], **kwargs) -> str: + """Generate content with context substitution asynchronously.""" + merged = dict(context) + for key, value in self.variables.items(): + merged[key] = value(context) if callable(value) else value + + prompt = self.prompt_template.format(**merged) + result = await self.generator.generate(prompt, **kwargs) + return result.strip() if isinstance(result, str) else result + + async def stream( + self, + contexts: Iterable[Dict[str, Any]], + *, + concurrency: int = 5, + return_exceptions: bool = False, + **kwargs, + ): + """Asynchronously yield results for many contexts with bounded concurrency.""" + + if concurrency < 1: + raise ValueError("concurrency must be at least 1") + + contexts_list = list(contexts) + if not contexts_list: + return + + semaphore = asyncio.Semaphore(concurrency) + + async def worker(index: int, ctx: Dict[str, Any]): + async with semaphore: + try: + result = await self(ctx, **kwargs) + return index, result, None + except Exception as exc: # pragma: no cover - exercised via return_exceptions + return index, None, exc + + tasks = [ + asyncio.create_task(worker(index, ctx)) + for index, ctx in enumerate(contexts_list) + ] + + next_index = 0 + buffer: Dict[int, Any] = {} + + try: + for coro in asyncio.as_completed(tasks): + index, value, error = await coro + + if error is not None and not return_exceptions: + for task in tasks: + if not task.done(): + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise error + + buffer[index] = error if error is not None else value + + while next_index in buffer: + item = buffer.pop(next_index) + next_index += 1 + yield item + + finally: + for task in tasks: + if not task.done(): + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + + class GeneratorClient: """Main interface for creating generators.""" @@ -251,6 +404,34 @@ def __call__(self, prompt_template: str, **variables) -> GeneratorFunction: return GeneratorFunction(self._generator, prompt_template, variables) +class AsyncGeneratorClient: + """Async interface for creating generators with concurrent execution.""" + + def __init__(self, provider: str, api_key: Optional[str] = None, **kwargs): + provider_lower = provider.lower() + try: + if provider_lower == "openai": + if api_key is None: + raise ValueError("API key is required for OpenAI") + self._generator = AsyncOpenAIGenerator(api_key, **kwargs) + elif provider_lower == "anthropic": + if api_key is None: + raise ValueError("API key is required for Anthropic") + self._generator = AsyncAnthropicGenerator(api_key, **kwargs) + else: + raise ValueError(f"Unsupported provider for async generator: {provider}") + + except Exception as e: + raise ValueError( + f"Failed to initialize async generator for provider '{provider}'. " + f"Check your configuration and try again. Original error: {str(e)}" + ) from e + + def __call__(self, prompt_template: str, **variables) -> AsyncGeneratorFunction: + """Create an async generator function.""" + return AsyncGeneratorFunction(self._generator, prompt_template, variables) + + # Factory function def generator( provider: str = "openai", api_key: Optional[str] = None, **kwargs @@ -259,3 +440,12 @@ def generator( if provider.lower() in {"openai", "anthropic"} and api_key is None: raise ValueError("API key is required") return GeneratorClient(provider, api_key, **kwargs) + + +def async_generator( + provider: str = "openai", api_key: Optional[str] = None, **kwargs +) -> AsyncGeneratorClient: + """Create an async generator client.""" + if provider.lower() in {"openai", "anthropic"} and api_key is None: + raise ValueError("API key is required") + return AsyncGeneratorClient(provider, api_key, **kwargs) diff --git a/tests/test_dataset_comprehensive.py b/tests/test_dataset_comprehensive.py index 49b1e66..8b4d193 100644 --- a/tests/test_dataset_comprehensive.py +++ b/tests/test_dataset_comprehensive.py @@ -1,5 +1,6 @@ """Comprehensive tests for dataset module.""" +import asyncio import pytest import pandas as pd import tempfile @@ -8,7 +9,7 @@ from datasets import Dataset as HFDataset from chatan.dataset import Dataset, dataset -from chatan.generator import GeneratorFunction +from chatan.generator import GeneratorFunction, AsyncGeneratorFunction, AsyncBaseGenerator from chatan.sampler import ChoiceSampler, UUIDSampler @@ -194,6 +195,50 @@ def test_generator_function_with_context(self): assert all(df["content"].str.startswith("Generated: Create content for")) assert mock_generator.generate.call_count == 2 + +class DummyAsyncGenerator(AsyncBaseGenerator): + async def generate(self, prompt: str, **kwargs) -> str: + await asyncio.sleep(0) + return prompt.upper() + + +class TestAsyncIntegration: + """Tests for async generator integration within Dataset.""" + + def test_async_generator_function_sync_usage(self): + """Dataset.generate should resolve async generator results synchronously.""" + + async_func = AsyncGeneratorFunction(DummyAsyncGenerator(), "hello {name}") + schema = {"name": lambda _: "world", "greeting": async_func} + ds = Dataset(schema, n=1) + + df = ds.generate(progress=False) + + assert df.loc[0, "greeting"] == "HELLO WORLD" + + @pytest.mark.asyncio + async def test_generate_async(self): + """Dataset.generate_async should await async generator functions.""" + + async_func = AsyncGeneratorFunction(DummyAsyncGenerator(), "hello {name}") + schema = {"name": lambda _: "world", "greeting": async_func} + ds = Dataset(schema, n=2) + + df = await ds.generate_async(progress=False) + + assert list(df["greeting"]) == ["HELLO WORLD", "HELLO WORLD"] + + @pytest.mark.asyncio + async def test_generate_inside_event_loop_requires_async(self): + """Calling generate within a running loop should raise a helpful error.""" + + async_func = AsyncGeneratorFunction(DummyAsyncGenerator(), "hello {name}") + schema = {"name": lambda _: "world", "greeting": async_func} + ds = Dataset(schema, n=1) + + with pytest.raises(RuntimeError, match="generate_async"): + ds.generate(progress=False) + def test_lambda_function_generation(self): """Test generation with lambda functions.""" schema = { diff --git a/tests/test_generator.py b/tests/test_generator.py index 0727338..d83cf9d 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,14 +1,20 @@ """Comprehensive tests for generator module.""" +import asyncio import pytest -import sys from unittest.mock import Mock, patch, MagicMock from chatan.generator import ( OpenAIGenerator, AnthropicGenerator, + AsyncOpenAIGenerator, + AsyncAnthropicGenerator, + AsyncBaseGenerator, GeneratorFunction, + AsyncGeneratorFunction, GeneratorClient, - generator + AsyncGeneratorClient, + generator, + async_generator, ) # Conditional imports for torch-dependent tests @@ -97,6 +103,33 @@ def test_kwargs_override(self, mock_openai): assert call_args[1]["temperature"] == 0.9 +@pytest.mark.asyncio +class TestAsyncOpenAIGenerator: + """Test async OpenAI generator implementation.""" + + @patch('openai.AsyncOpenAI') + async def test_async_generate_basic(self, mock_async_openai): + """Test asynchronous content generation.""" + + mock_client = MagicMock() + mock_response = MagicMock() + mock_choice = MagicMock() + mock_choice.message.content = "Async content" + mock_response.choices = [mock_choice] + mock_client.chat.completions.create.return_value = asyncio.Future() + mock_client.chat.completions.create.return_value.set_result(mock_response) + mock_async_openai.return_value = mock_client + + gen = AsyncOpenAIGenerator("test-key") + result = await gen.generate("Prompt") + + assert result == "Async content" + mock_client.chat.completions.create.assert_called_once_with( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Prompt"}] + ) + + class TestAnthropicGenerator: """Test Anthropic generator implementation.""" @@ -147,6 +180,35 @@ def test_max_tokens_extraction(self, mock_anthropic): assert call_args[1]["temperature"] == 0.7 +@pytest.mark.asyncio +class TestAsyncAnthropicGenerator: + """Test async Anthropic generator implementation.""" + + @patch('anthropic.AsyncAnthropic') + async def test_async_generate_basic(self, mock_async_anthropic): + """Test asynchronous content generation.""" + + mock_client = MagicMock() + mock_response = MagicMock() + mock_content = MagicMock() + mock_content.text = "Async Claude" + mock_response.content = [mock_content] + future = asyncio.Future() + future.set_result(mock_response) + mock_client.messages.create.return_value = future + mock_async_anthropic.return_value = mock_client + + gen = AsyncAnthropicGenerator("test-key") + result = await gen.generate("Prompt") + + assert result == "Async Claude" + mock_client.messages.create.assert_called_once_with( + model="claude-3-sonnet-20240229", + messages=[{"role": "user", "content": "Prompt"}], + max_tokens=1000 + ) + + @pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available") class TestTransformersGenerator: """Test TransformersGenerator functionality (only when torch is available).""" @@ -202,11 +264,79 @@ def test_extra_context_variables(self): func = GeneratorFunction(mock_generator, "Write about {topic}") result = func({"topic": "AI", "extra": "ignored"}) - + assert result == "Generated" mock_generator.generate.assert_called_once_with("Write about AI") +@pytest.mark.asyncio +class TestAsyncGeneratorFunction: + """Test AsyncGeneratorFunction and its helpers.""" + + async def test_async_call(self): + """Ensure async call formats prompt and strips whitespace.""" + + class DummyAsyncGenerator(AsyncBaseGenerator): + async def generate(self, prompt: str, **kwargs) -> str: + return f" {prompt.upper()} " + + func = AsyncGeneratorFunction(DummyAsyncGenerator(), "Hello {name}") + result = await func({"name": "world"}) + assert result == "HELLO WORLD" + + async def test_stream_concurrency(self): + """Ensure stream runs with bounded concurrency and preserves order.""" + + class ConcurrentGenerator(AsyncBaseGenerator): + def __init__(self): + self.active = 0 + self.max_active = 0 + + async def generate(self, prompt: str, **kwargs) -> str: + self.active += 1 + self.max_active = max(self.max_active, self.active) + try: + await asyncio.sleep(0.01) + return prompt + finally: + self.active -= 1 + + generator = ConcurrentGenerator() + func = AsyncGeneratorFunction(generator, "item {value}") + contexts = [{"value": i} for i in range(4)] + + results = [] + async for value in func.stream(contexts, concurrency=2): + results.append(value) + + assert results == [f"item {i}" for i in range(4)] + assert generator.max_active == 2 + + async def test_stream_exceptions(self): + """Ensure exceptions can be captured or raised.""" + + class FailingGenerator(AsyncBaseGenerator): + async def generate(self, prompt: str, **kwargs) -> str: + if "fail" in prompt: + raise ValueError("boom") + return prompt + + func = AsyncGeneratorFunction(FailingGenerator(), "{value}") + contexts = [{"value": "ok"}, {"value": "fail"}, {"value": "later"}] + + results = [] + async for value in func.stream(contexts, return_exceptions=True): + results.append(value) + + assert isinstance(results[1], ValueError) + assert results[0] == "ok" + assert results[2] == "later" + + with pytest.raises(ValueError): + async for _ in func.stream(contexts): + pass + + class TestGeneratorClient: """Test GeneratorClient interface.""" @@ -251,6 +381,32 @@ def test_callable_returns_generator_function(self, mock_openai_gen): assert func.prompt_template == "Template {var}" +class TestAsyncGeneratorClient: + """Test AsyncGeneratorClient interface.""" + + @patch('chatan.generator.AsyncOpenAIGenerator') + def test_openai_async_client_creation(self, mock_openai_gen): + client = AsyncGeneratorClient("openai", "test-key", temperature=0.2) + mock_openai_gen.assert_called_once_with("test-key", temperature=0.2) + + @patch('chatan.generator.AsyncAnthropicGenerator') + def test_anthropic_async_client_creation(self, mock_anthropic_gen): + client = AsyncGeneratorClient("anthropic", "test-key", model="claude") + mock_anthropic_gen.assert_called_once_with("test-key", model="claude") + + def test_async_unsupported_provider(self): + with pytest.raises(ValueError, match="Unsupported provider"): + AsyncGeneratorClient("invalid", "key") + + @patch('chatan.generator.AsyncOpenAIGenerator') + def test_callable_returns_async_function(self, mock_openai_gen): + client = AsyncGeneratorClient("openai", "test-key") + func = client("Template {var}") + + assert isinstance(func, AsyncGeneratorFunction) + assert func.prompt_template == "Template {var}" + + class TestGeneratorFactory: """Test generator factory function.""" @@ -278,6 +434,25 @@ def test_transformers_provider_no_key(self, mock_client): mock_client.assert_called_once_with("transformers", None, model="gpt2") +class TestAsyncGeneratorFactory: + """Test async generator factory function.""" + + def test_missing_api_key(self): + with pytest.raises(ValueError, match="API key is required"): + async_generator("openai") + + @patch('chatan.generator.AsyncGeneratorClient') + def test_factory_creates_client(self, mock_client): + result = async_generator("openai", "test-key", temperature=0.5) + mock_client.assert_called_once_with("openai", "test-key", temperature=0.5) + assert result is mock_client.return_value + + @patch('chatan.generator.AsyncGeneratorClient') + def test_default_provider(self, mock_client): + async_generator(api_key="test-key") + mock_client.assert_called_once_with("openai", "test-key") + + class TestIntegration: """Integration tests for generator components."""