Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,36 @@ async def main():
df = asyncio.run(main())
```

### Explicit Dependencies For Callables

For plain callables, declare dependencies explicitly when one column needs another:

```python
from chatan import call, dataset

ds = dataset({
"file_path": call(lambda: get_random_filepath()),
"file_content": call(
lambda ctx: get_file_content(ctx["file_path"]),
requires=["file_path"],
),
})
```

You can also use tuple syntax: `"file_content": (callable_fn, ["file_path"])`.

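If a callable's required parameter names match column names, those dependencies are inferred automatically (parameters named `ctx`/`context` and parameters with default values are not treated as dependencies):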

```python
def get_file_chunk(file_path):
return load_chunk(file_path)

ds = dataset({
"file_path": call(get_random_filepath),
"file_chunk": call(get_file_chunk), # infers dependency on file_path
})
```

## Generator Options

### API-based Generators (included in base install)
Expand Down
4 changes: 3 additions & 1 deletion src/chatan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

__version__ = "0.3.0"

from .dataset import dataset
from .dataset import call, dataset, depends_on
from .evaluate import eval, evaluate
from .generator import generator
from .sampler import sample
from .viewer import generate_with_viewer

__all__ = [
"dataset",
"call",
"depends_on",
"generator",
"sample",
"generate_with_viewer",
Expand Down
203 changes: 194 additions & 9 deletions src/chatan/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Dataset creation and manipulation with async generation."""

import asyncio
from typing import Any, Dict, List, Optional, Union
import inspect
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import pandas as pd
from datasets import Dataset as HFDataset
Expand All @@ -11,6 +12,8 @@
from .generator import GeneratorFunction
from .sampler import SampleFunction

CONTEXT_ARG_NAMES = {"ctx", "context"}


class Dataset:
"""Async dataset generator with dependency-aware execution."""
Expand Down Expand Up @@ -156,7 +159,7 @@ async def _generate_column_value(
await completion_events[dep].wait()

# Generate the value
func = self.schema[column]
func = self._resolve_column_callable(self.schema[column])

if isinstance(func, GeneratorFunction):
# Use async generator
Expand All @@ -165,11 +168,8 @@ async def _generate_column_value(
# Samplers are sync but fast
value = func(row)
elif callable(func):
# Check if it's an async callable
if asyncio.iscoroutinefunction(func):
value = await func(row)
else:
value = func(row)
result = _invoke_with_context(func, row)
value = await result if asyncio.iscoroutine(result) else result
else:
# Static value
value = func
Expand All @@ -188,19 +188,101 @@ def _build_dependency_graph(self) -> Dict[str, List[str]]:

for column, func in self.schema.items():
deps = []
explicit_deps = self._extract_explicit_dependencies(func)
if explicit_deps:
deps.extend(explicit_deps)
deps.extend(self._infer_signature_dependencies(func, column))

# Extract dependencies from generator functions
if hasattr(func, "prompt_template"):
import re

template = getattr(func, "prompt_template", "")
deps = re.findall(r"\{(\w+)\}", template)
deps.extend(re.findall(r"\{(\w+)\}", template))

# Keep dependency order stable and remove duplicates.
ordered_deps = []
for dep in deps:
if dep not in ordered_deps:
ordered_deps.append(dep)

# Only include dependencies that are in the schema
dependencies[column] = [dep for dep in deps if dep in self.schema]
dependencies[column] = [dep for dep in ordered_deps if dep in self.schema]

return dependencies

@staticmethod
def _resolve_column_callable(func: Any) -> Any:
    """Return the executable callable (or static value) behind a schema entry.

    Args:
        func: A schema value. It may be a ``DependentCallable``, a
            ``(callable, dependencies)`` 2-tuple, or anything else.

    Returns:
        The wrapped callable for the two wrapper forms; otherwise ``func``
        itself, unchanged.
    """
    if isinstance(func, DependentCallable):
        return func.func

    # Tuple syntax: only a 2-tuple whose first item is callable is unwrapped.
    is_callable_pair = (
        isinstance(func, tuple) and len(func) == 2 and callable(func[0])
    )
    return func[0] if is_callable_pair else func

@staticmethod
def _extract_explicit_dependencies(func: Any) -> List[str]:
    """Collect the dependencies a schema entry declares explicitly.

    Sources are checked in this order: a ``DependentCallable`` wrapper, a
    ``(callable, dependencies)`` 2-tuple, a ``dependencies`` attribute, and
    finally a ``depends_on`` attribute.

    Args:
        func: The schema value to inspect.

    Returns:
        A list of dependency names. A single string becomes a one-item
        list; non-string items of an iterable are dropped; ``None`` or any
        non-iterable declaration yields an empty list.
    """
    if isinstance(func, DependentCallable):
        declared = func.dependencies
    elif isinstance(func, tuple) and len(func) == 2 and callable(func[0]):
        declared = func[1]
    elif hasattr(func, "dependencies"):
        declared = func.dependencies
    elif hasattr(func, "depends_on"):
        declared = func.depends_on
    else:
        return []

    # A bare string is one dependency, not an iterable of characters.
    if isinstance(declared, str):
        return [declared]
    if declared is None or not isinstance(declared, Iterable):
        return []
    return [name for name in declared if isinstance(name, str)]

def _infer_signature_dependencies(self, func: Any, current_column: str) -> List[str]:
    """Infer column dependencies from a callable's parameter names.

    Args:
        func: The schema value for ``current_column`` (possibly wrapped).
        current_column: Name of the column being analysed; a parameter with
            this name is never reported as a dependency.

    Returns:
        Names of parameters that are also schema columns, in signature
        order. ``*args``/``**kwargs``, context parameters, and parameters
        with a default value are excluded. Returns an empty list for
        non-callables, generator/sampler functions, and callables whose
        signature cannot be inspected.
    """
    target = self._resolve_column_callable(func)
    if not callable(target) or isinstance(
        target, (GeneratorFunction, SampleFunction)
    ):
        return []

    try:
        parameters = inspect.signature(target).parameters
    except (TypeError, ValueError):
        return []

    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    return [
        name
        for name, param in parameters.items()
        if param.kind not in variadic_kinds
        and name not in CONTEXT_ARG_NAMES
        and param.default is inspect.Parameter.empty
        and name != current_column
        and name in self.schema
    ]

def _topological_sort(self, dependencies: Dict[str, List[str]]) -> List[str]:
"""Topologically sort columns by dependencies."""
visited = set()
Expand Down Expand Up @@ -272,3 +354,106 @@ def dataset(schema: Union[Dict[str, Any], str], n: int = 100) -> Dataset:
>>> df = asyncio.run(main())
"""
return Dataset(schema, n)


class DependentCallable:
    """Pair a callable with the column names it explicitly depends on."""

    def __init__(self, func: Callable[..., Any], dependencies: List[str]):
        """Store the wrapped callable and its declared dependencies.

        Args:
            func: The callable to run when the column is generated.
            dependencies: Column names declared as dependencies of ``func``.
        """
        # Both values are stored as given (no copy, no validation).
        self.func, self.dependencies = func, dependencies

    def __call__(self, context: Dict[str, Any]) -> Any:
        """Invoke the wrapped callable against ``context``.

        Args:
            context: Mapping handed to ``_invoke_with_context`` together
                with the wrapped callable.

        Returns:
            Whatever ``_invoke_with_context`` returns for the wrapped callable.
        """
        target = self.func
        return _invoke_with_context(target, context)


def call(
    func: Callable[..., Any],
    *dependencies: str,
    requires: Optional[Union[str, Iterable[str]]] = None,
    with_: Optional[Union[str, Iterable[str]]] = None,
    **kwargs,
) -> DependentCallable:
    """Declare callable schema entries and optional explicit dependencies.

    Dependencies may be given positionally, via ``requires=``, via
    ``with_=``, or via the reserved-word spelling ``**{"with": ...}``.
    Each keyword form accepts either a single column name or an iterable
    of names. The resulting list keeps that order and is not de-duplicated.

    Args:
        func: The callable that produces the column value.
        *dependencies: Column names ``func`` depends on.
        requires: One column name, or an iterable of column names.
        with_: One column name, or an iterable of column names.
        **kwargs: Only the key ``"with"`` is accepted (same meaning as
            ``with_``).

    Returns:
        A ``DependentCallable`` wrapping ``func`` with the collected names.

    Raises:
        TypeError: If any keyword argument other than ``with`` is passed.

    Example:
        schema = {
            "file_path": call(lambda: random_path()),
            "file_content": call(
                lambda ctx: load(ctx["file_path"]),
                requires=["file_path"],
            ),
        }
    """
    with_deps = kwargs.pop("with", None)
    if kwargs:
        unexpected = ", ".join(kwargs.keys())
        raise TypeError(f"Unexpected keyword argument(s): {unexpected}")

    explicit = list(dependencies)
    for declared in (requires, with_, with_deps):
        if not declared:
            continue
        # A bare string is a single column name; extending with it would
        # add one dependency per character.
        if isinstance(declared, str):
            explicit.append(declared)
        else:
            explicit.extend(declared)

    return DependentCallable(func, explicit)


def depends_on(func: Callable[..., Any], *dependencies: str) -> DependentCallable:
    """Backward-compatible alias for ``call`` with positional dependencies.

    Args:
        func: The callable that produces the column value.
        *dependencies: Column names ``func`` depends on.

    Returns:
        The ``DependentCallable`` produced by ``call(func, *dependencies)``.
    """
    wrapped = call(func, *dependencies)
    return wrapped


def _invoke_with_context(func: Callable[..., Any], context: Dict[str, Any]) -> Any:
    """Invoke callable using context-aware argument mapping.

    Parameters named in ``CONTEXT_ARG_NAMES`` receive the whole ``context``;
    any other parameter receives ``context[name]`` when that key exists.
    ``*args``/``**kwargs`` parameters are never filled.

    Args:
        func: The callable to invoke.
        context: Mapping of already-generated column values for the row.

    Returns:
        The result of calling ``func``. The call takes one of three forms:
        ``func()`` when it declares no parameters, ``func(context)`` when
        its signature cannot be inspected or a required parameter has no
        matching context key (legacy single-argument behavior), and the
        mapped arguments otherwise.
    """
    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError):
        # Keep legacy behavior for callables without inspect metadata.
        return func(context)

    params = list(signature.parameters.values())
    if not params:
        return func()

    args = []
    kwargs = {}

    for param in params:
        if param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            continue

        positional_only = param.kind == inspect.Parameter.POSITIONAL_ONLY

        if param.name in CONTEXT_ARG_NAMES:
            value = context
        elif param.name in context:
            value = context[param.name]
        elif param.default is inspect.Parameter.empty:
            # Backward compatibility for legacy callables that expect `ctx`.
            return func(context)
        elif positional_only:
            # Pass the default explicitly so that later positional-only
            # arguments are not shifted into this parameter's slot.
            value = param.default
        else:
            # Optional keyword-capable parameter: let its default apply.
            continue

        if positional_only:
            args.append(value)
        else:
            kwargs[param.name] = value

    return func(*args, **kwargs)
Loading