From f732fe4053060be64fdb770185afb7ec107d3caa Mon Sep 17 00:00:00 2001 From: Christian Reetz Date: Sat, 7 Feb 2026 22:11:04 -0800 Subject: [PATCH 1/4] Support explicit dependencies for callable dataset columns --- README.md | 18 ++++++ src/chatan/__init__.py | 3 +- src/chatan/dataset.py | 90 ++++++++++++++++++++++++++--- tests/test_dataset_comprehensive.py | 46 ++++++++++++++- 4 files changed, 146 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index c1fad43..1189522 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,24 @@ async def main(): df = asyncio.run(main()) ``` +### Explicit Dependencies For Callables + +For plain callables, declare dependencies explicitly when one column needs another: + +```python +from chatan import dataset, depends_on + +ds = dataset({ + "file_path": lambda ctx: get_random_filepath(), + "file_content": depends_on( + lambda ctx: get_file_content(ctx["file_path"]), + "file_path", + ), +}) +``` + +You can also use tuple syntax: `"file_content": (callable_fn, ["file_path"])`. + ## Generator Options ### API-based Generators (included in base install) diff --git a/src/chatan/__init__.py b/src/chatan/__init__.py index defa38c..fdeeb5b 100644 --- a/src/chatan/__init__.py +++ b/src/chatan/__init__.py @@ -2,7 +2,7 @@ __version__ = "0.3.0" -from .dataset import dataset +from .dataset import dataset, depends_on from .evaluate import eval, evaluate from .generator import generator from .sampler import sample @@ -10,6 +10,7 @@ __all__ = [ "dataset", + "depends_on", "generator", "sample", "generate_with_viewer", diff --git a/src/chatan/dataset.py b/src/chatan/dataset.py index 31ecba9..3340060 100644 --- a/src/chatan/dataset.py +++ b/src/chatan/dataset.py @@ -1,7 +1,7 @@ """Dataset creation and manipulation with async generation.""" import asyncio -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Union import pandas as pd from datasets import Dataset as HFDataset @@ -156,7 +156,7 @@ async def _generate_column_value( await completion_events[dep].wait() # Generate the value - func = self.schema[column] + func = self._resolve_column_callable(self.schema[column]) if isinstance(func, GeneratorFunction): # Use async generator @@ -165,11 +165,8 @@ async def _generate_column_value( # Samplers are sync but fast value = func(row) elif callable(func): - # Check if it's an async callable - if asyncio.iscoroutinefunction(func): - value = await func(row) - else: - value = func(row) + result = func(row) + value = await result if asyncio.iscoroutine(result) else result else: # Static value value = func @@ -188,19 +185,69 @@ def _build_dependency_graph(self) -> Dict[str, List[str]]: for column, func in self.schema.items(): deps = [] + explicit_deps = self._extract_explicit_dependencies(func) + if explicit_deps: + deps.extend(explicit_deps) # Extract dependencies from generator functions if hasattr(func, "prompt_template"): import re template = getattr(func, "prompt_template", "") - deps = re.findall(r"\{(\w+)\}", template) + deps.extend(re.findall(r"\{(\w+)\}", template)) + + # Keep dependency order stable and remove duplicates. + ordered_deps = [] + for dep in deps: + if dep not in ordered_deps: + ordered_deps.append(dep) # Only include dependencies that are in the schema - dependencies[column] = [dep for dep in deps if dep in self.schema] + dependencies[column] = [dep for dep in ordered_deps if dep in self.schema] return dependencies + @staticmethod + def _resolve_column_callable(func: Any) -> Any: + """Unwrap schema value to the executable callable/value.""" + if isinstance(func, DependentCallable): + return func.func + + if ( + isinstance(func, tuple) + and len(func) == 2 + and callable(func[0]) + ): + return func[0] + + return func + + @staticmethod + def _extract_explicit_dependencies(func: Any) -> List[str]: + """Extract explicit dependency declarations from schema value.""" + deps = [] + + if isinstance(func, DependentCallable): + deps = func.dependencies + elif ( + isinstance(func, tuple) + and len(func) == 2 + and callable(func[0]) + ): + deps = func[1] + elif hasattr(func, "dependencies"): + deps = getattr(func, "dependencies") + elif hasattr(func, "depends_on"): + deps = getattr(func, "depends_on") + + if deps is None: + return [] + if isinstance(deps, str): + return [deps] + if isinstance(deps, Iterable): + return [dep for dep in deps if isinstance(dep, str)] + return [] + def _topological_sort(self, dependencies: Dict[str, List[str]]) -> List[str]: """Topologically sort columns by dependencies.""" visited = set() @@ -272,3 +319,28 @@ def dataset(schema: Union[Dict[str, Any], str], n: int = 100) -> Dataset: >>> df = asyncio.run(main()) """ return Dataset(schema, n) + + +class DependentCallable: + """Wrapper for callables with explicit column dependencies.""" + + def __init__(self, func: Callable[[Dict[str, Any]], Any], dependencies: List[str]): + self.func = func + self.dependencies = dependencies + + def __call__(self, context: Dict[str, Any]) -> Any: + return self.func(context) + + +def depends_on( + func: Callable[[Dict[str, Any]], Any], *dependencies: str +) -> DependentCallable: + """Declare explicit column dependencies for callable schema entries. + + Example: + schema = { + "file_path": lambda ctx: "...", + "file_content": depends_on(lambda ctx: load(ctx["file_path"]), "file_path"), + } + """ + return DependentCallable(func, list(dependencies)) diff --git a/tests/test_dataset_comprehensive.py b/tests/test_dataset_comprehensive.py index 843aaa5..88b436d 100644 --- a/tests/test_dataset_comprehensive.py +++ b/tests/test_dataset_comprehensive.py @@ -7,7 +7,7 @@ from unittest.mock import Mock from datasets import Dataset as HFDataset -from chatan.dataset import Dataset, dataset +from chatan.dataset import Dataset, dataset, depends_on from chatan.generator import GeneratorFunction, BaseGenerator from chatan.sampler import ChoiceSampler, UUIDSampler @@ -132,6 +132,20 @@ def test_dependency_outside_schema(self): # external_col should be filtered out assert dependencies["col2"] == ["col1"] + def test_explicit_callable_dependencies(self): + """Test explicit dependencies for callables.""" + schema = { + "col1": ChoiceSampler(["A"]), + "col2": depends_on(lambda ctx: f"v:{ctx['col1']}", "col1"), + "col3": (lambda ctx: f"w:{ctx['col2']}", ["col2"]), + } + ds = Dataset(schema, n=2) + + dependencies = ds._build_dependency_graph() + assert dependencies["col1"] == [] + assert dependencies["col2"] == ["col1"] + assert dependencies["col3"] == ["col2"] + @pytest.mark.asyncio class TestDataGeneration: @@ -209,6 +223,36 @@ async def test_complex_dependency_chain(self): assert row["d"] == row["b"] + row["c"] assert row["e"] == row["a"] + row["d"] + async def test_callable_depends_on_wrapper(self): + """Test callable dependencies via depends_on wrapper.""" + schema = { + "file_path": lambda ctx: "src/main.ts", + "file_content": depends_on( + lambda ctx: f"content:{ctx['file_path']}", + "file_path", + ), + } + ds = Dataset(schema, n=5) + df = await ds.generate() + + assert len(df) == 5 + assert all(df["file_content"] == "content:src/main.ts") + + async def test_callable_tuple_dependency_spec(self): + """Test callable dependencies via tuple spec.""" + schema = { + "file_path": lambda ctx: "src/app.ts", + "file_content": ( + lambda ctx: f"content:{ctx['file_path']}", + ["file_path"], + ), + } + ds = Dataset(schema, n=5) + df = await ds.generate() + + assert len(df) == 5 + assert all(df["file_content"] == "content:src/app.ts") + async def test_override_sample_count(self): """Test overriding sample count in generate().""" schema = {"col": ChoiceSampler(["A"])} From a6b2eb82fe4e80330c339be774a4b90f23cd1bbf Mon Sep 17 00:00:00 2001 From: Christian Reetz Date: Sat, 7 Feb 2026 22:16:27 -0800 Subject: [PATCH 2/4] Add call() wrapper API for callable dataset columns --- README.md | 9 ++-- src/chatan/__init__.py | 3 +- src/chatan/dataset.py | 64 ++++++++++++++++++++++++----- tests/test_dataset_comprehensive.py | 44 ++++++++++++++++---- 4 files changed, 97 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 1189522..97e0591 100644 --- a/README.md +++ b/README.md @@ -50,18 +50,19 @@ df = asyncio.run(main()) For plain callables, declare dependencies explicitly when one column needs another: ```python -from chatan import dataset, depends_on +from chatan import call, dataset ds = dataset({ - "file_path": lambda ctx: get_random_filepath(), - "file_content": depends_on( + "file_path": call(lambda: get_random_filepath()), + "file_content": call( lambda ctx: get_file_content(ctx["file_path"]), - "file_path", + with_=["file_path"], ), }) ``` You can also use tuple syntax: `"file_content": (callable_fn, ["file_path"])`. +(`with` is a Python keyword, so use `with_` in normal calls.) ## Generator Options diff --git a/src/chatan/__init__.py b/src/chatan/__init__.py index fdeeb5b..8e865b9 100644 --- a/src/chatan/__init__.py +++ b/src/chatan/__init__.py @@ -2,7 +2,7 @@ __version__ = "0.3.0" -from .dataset import dataset, depends_on +from .dataset import call, dataset, depends_on from .evaluate import eval, evaluate from .generator import generator from .sampler import sample @@ -10,6 +10,7 @@ __all__ = [ "dataset", + "call", "depends_on", "generator", "sample", diff --git a/src/chatan/dataset.py b/src/chatan/dataset.py index 3340060..ed8d7a4 100644 --- a/src/chatan/dataset.py +++ b/src/chatan/dataset.py @@ -1,6 +1,7 @@ """Dataset creation and manipulation with async generation.""" import asyncio +import inspect from typing import Any, Callable, Dict, Iterable, List, Optional, Union import pandas as pd @@ -210,9 +211,6 @@ def _build_dependency_graph(self) -> Dict[str, List[str]]: @staticmethod def _resolve_column_callable(func: Any) -> Any: """Unwrap schema value to the executable callable/value.""" - if isinstance(func, DependentCallable): - return func.func - if ( isinstance(func, tuple) and len(func) == 2 @@ -324,23 +322,67 @@ def dataset(schema: Union[Dict[str, Any], str], n: int = 100) -> Dataset: class DependentCallable: """Wrapper for callables with explicit column dependencies.""" - def __init__(self, func: Callable[[Dict[str, Any]], Any], dependencies: List[str]): + def __init__(self, func: Callable[..., Any], dependencies: List[str]): self.func = func self.dependencies = dependencies + self._accepts_context = _callable_accepts_context(func) def __call__(self, context: Dict[str, Any]) -> Any: - return self.func(context) + if self._accepts_context: + return self.func(context) + return self.func() -def depends_on( - func: Callable[[Dict[str, Any]], Any], *dependencies: str +def call( + func: Callable[..., Any], *dependencies: str, with_: Optional[List[str]] = None, **kwargs ) -> DependentCallable: - """Declare explicit column dependencies for callable schema entries. + """Declare callable schema entries and optional explicit dependencies. Example: schema = { - "file_path": lambda ctx: "...", - "file_content": depends_on(lambda ctx: load(ctx["file_path"]), "file_path"), + "file_path": call(lambda: random_path()), + "file_content": call( + lambda ctx: load(ctx["file_path"]), + with_=["file_path"], + ), } """ - return DependentCallable(func, list(dependencies)) + with_deps = kwargs.pop("with", None) + if kwargs: + unexpected = ", ".join(kwargs.keys()) + raise TypeError(f"Unexpected keyword argument(s): {unexpected}") + + explicit = list(dependencies) + if with_: + explicit.extend(with_) + if with_deps: + if isinstance(with_deps, str): + explicit.append(with_deps) + else: + explicit.extend(with_deps) + + return DependentCallable(func, explicit) + + +def depends_on(func: Callable[..., Any], *dependencies: str) -> DependentCallable: + """Backward-compatible alias for explicit callable dependencies.""" + return call(func, *dependencies) + + +def _callable_accepts_context(func: Callable[..., Any]) -> bool: + """Return True when callable can accept a context argument.""" + try: + signature = inspect.signature(func) + except (TypeError, ValueError): + # Builtins/c-extensions without inspect metadata: keep legacy behavior. + return True + + params = list(signature.parameters.values()) + if not params: + return False + + for param in params: + if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + return True + + return True diff --git a/tests/test_dataset_comprehensive.py b/tests/test_dataset_comprehensive.py index 88b436d..3a9608c 100644 --- a/tests/test_dataset_comprehensive.py +++ b/tests/test_dataset_comprehensive.py @@ -7,7 +7,7 @@ from unittest.mock import Mock from datasets import Dataset as HFDataset -from chatan.dataset import Dataset, dataset, depends_on +from chatan.dataset import Dataset, call, dataset, depends_on from chatan.generator import GeneratorFunction, BaseGenerator from chatan.sampler import ChoiceSampler, UUIDSampler @@ -136,7 +136,7 @@ def test_explicit_callable_dependencies(self): """Test explicit dependencies for callables.""" schema = { "col1": ChoiceSampler(["A"]), - "col2": depends_on(lambda ctx: f"v:{ctx['col1']}", "col1"), + "col2": call(lambda ctx: f"v:{ctx['col1']}", with_=["col1"]), "col3": (lambda ctx: f"w:{ctx['col2']}", ["col2"]), } ds = Dataset(schema, n=2) @@ -223,13 +223,13 @@ async def test_complex_dependency_chain(self): assert row["d"] == row["b"] + row["c"] assert row["e"] == row["a"] + row["d"] - async def test_callable_depends_on_wrapper(self): - """Test callable dependencies via depends_on wrapper.""" + async def test_callable_call_wrapper(self): + """Test callable dependencies via call wrapper.""" schema = { - "file_path": lambda ctx: "src/main.ts", - "file_content": depends_on( + "file_path": call(lambda: "src/main.ts"), + "file_content": call( lambda ctx: f"content:{ctx['file_path']}", - "file_path", + with_=["file_path"], ), } ds = Dataset(schema, n=5) @@ -238,6 +238,36 @@ async def test_callable_depends_on_wrapper(self): assert len(df) == 5 assert all(df["file_content"] == "content:src/main.ts") + async def test_callable_call_wrapper_with_keyword_alias(self): + """Test call wrapper supports 'with' keyword alias via kwargs expansion.""" + schema = { + "file_path": call(lambda: "src/alias.ts"), + "file_content": call( + lambda ctx: f"content:{ctx['file_path']}", + **{"with": ["file_path"]}, + ), + } + ds = Dataset(schema, n=3) + df = await ds.generate() + + assert len(df) == 3 + assert all(df["file_content"] == "content:src/alias.ts") + + async def test_depends_on_backwards_compatible(self): + """Test depends_on still works as alias.""" + schema = { + "file_path": lambda ctx: "src/legacy.ts", + "file_content": depends_on( + lambda ctx: f"content:{ctx['file_path']}", + "file_path", + ), + } + ds = Dataset(schema, n=3) + df = await ds.generate() + + assert len(df) == 3 + assert all(df["file_content"] == "content:src/legacy.ts") + async def test_callable_tuple_dependency_spec(self): """Test callable dependencies via tuple spec.""" schema = { From 7beb91960028904569c3dc174ae4952e6074b453 Mon Sep 17 00:00:00 2001 From: Christian Reetz Date: Sat, 7 Feb 2026 23:41:31 -0800 Subject: [PATCH 3/4] Use requires keyword for callable dependencies --- README.md | 3 +-- src/chatan/dataset.py | 13 +++++++++++-- tests/test_dataset_comprehensive.py | 21 ++++++++++++++++++--- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 97e0591..611cab3 100644 --- a/README.md +++ b/README.md @@ -56,13 +56,12 @@ ds = dataset({ "file_path": call(lambda: get_random_filepath()), "file_content": call( lambda ctx: get_file_content(ctx["file_path"]), - with_=["file_path"], + requires=["file_path"], ), }) ``` You can also use tuple syntax: `"file_content": (callable_fn, ["file_path"])`. -(`with` is a Python keyword, so use `with_` in normal calls.) ## Generator Options diff --git a/src/chatan/dataset.py b/src/chatan/dataset.py index ed8d7a4..d19953b 100644 --- a/src/chatan/dataset.py +++ b/src/chatan/dataset.py @@ -334,7 +334,11 @@ def __call__(self, context: Dict[str, Any]) -> Any: def call( - func: Callable[..., Any], *dependencies: str, with_: Optional[List[str]] = None, **kwargs + func: Callable[..., Any], + *dependencies: str, + requires: Optional[List[str]] = None, + with_: Optional[List[str]] = None, + **kwargs, ) -> DependentCallable: """Declare callable schema entries and optional explicit dependencies. @@ -343,7 +347,7 @@ def call( "file_path": call(lambda: random_path()), "file_content": call( lambda ctx: load(ctx["file_path"]), - with_=["file_path"], + requires=["file_path"], ), } """ @@ -353,6 +357,11 @@ def call( raise TypeError(f"Unexpected keyword argument(s): {unexpected}") explicit = list(dependencies) + if requires: + if isinstance(requires, str): + explicit.append(requires) + else: + explicit.extend(requires) if with_: explicit.extend(with_) if with_deps: diff --git a/tests/test_dataset_comprehensive.py b/tests/test_dataset_comprehensive.py index 3a9608c..d2f80c5 100644 --- a/tests/test_dataset_comprehensive.py +++ b/tests/test_dataset_comprehensive.py @@ -136,7 +136,7 @@ def test_explicit_callable_dependencies(self): """Test explicit dependencies for callables.""" schema = { "col1": ChoiceSampler(["A"]), - "col2": call(lambda ctx: f"v:{ctx['col1']}", with_=["col1"]), + "col2": call(lambda ctx: f"v:{ctx['col1']}", requires=["col1"]), "col3": (lambda ctx: f"w:{ctx['col2']}", ["col2"]), } ds = Dataset(schema, n=2) @@ -229,7 +229,7 @@ async def test_callable_call_wrapper(self): "file_path": call(lambda: "src/main.ts"), "file_content": call( lambda ctx: f"content:{ctx['file_path']}", - with_=["file_path"], + requires=["file_path"], ), } ds = Dataset(schema, n=5) @@ -238,7 +238,7 @@ async def test_callable_call_wrapper(self): assert len(df) == 5 assert all(df["file_content"] == "content:src/main.ts") - async def test_callable_call_wrapper_with_keyword_alias(self): + async def test_callable_call_wrapper_with_with_keyword_alias(self): """Test call wrapper supports 'with' keyword alias via kwargs expansion.""" schema = { "file_path": call(lambda: "src/alias.ts"), @@ -253,6 +253,21 @@ async def test_callable_call_wrapper_with_keyword_alias(self): assert len(df) == 3 assert all(df["file_content"] == "content:src/alias.ts") + async def test_callable_call_wrapper_supports_requires_string(self): + """Test call wrapper supports requires as a single string.""" + schema = { + "file_path": call(lambda: "src/single.ts"), + "file_content": call( + lambda ctx: f"content:{ctx['file_path']}", + requires="file_path", + ), + } + ds = Dataset(schema, n=2) + df = await ds.generate() + + assert len(df) == 2 + assert all(df["file_content"] == "content:src/single.ts") + async def test_depends_on_backwards_compatible(self): """Test depends_on still works as alias.""" schema = { From b9d8d356ec18498dd75d4706bf9d3705d3c52ded Mon Sep 17 00:00:00 2001 From: Christian Reetz Date: Sun, 8 Feb 2026 00:03:39 -0800 Subject: [PATCH 4/4] Infer callable dependencies from function signatures --- README.md | 12 ++++ src/chatan/dataset.py | 88 ++++++++++++++++++++++++----- tests/test_dataset_comprehensive.py | 40 +++++++++++++ 3 files changed, 127 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 611cab3..b36e905 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,18 @@ ds = dataset({ You can also use tuple syntax: `"file_content": (callable_fn, ["file_path"])`. +If your callable argument names match column names, dependencies are inferred automatically: + +```python +def get_file_chunk(file_path): + return load_chunk(file_path) + +ds = dataset({ + "file_path": call(get_random_filepath), + "file_chunk": call(get_file_chunk), # infers dependency on file_path +}) +``` + ## Generator Options ### API-based Generators (included in base install) diff --git a/src/chatan/dataset.py b/src/chatan/dataset.py index d19953b..4531f20 100644 --- a/src/chatan/dataset.py +++ b/src/chatan/dataset.py @@ -12,6 +12,8 @@ from .generator import GeneratorFunction from .sampler import SampleFunction +CONTEXT_ARG_NAMES = {"ctx", "context"} + class Dataset: """Async dataset generator with dependency-aware execution.""" @@ -166,7 +168,7 @@ async def _generate_column_value( # Samplers are sync but fast value = func(row) elif callable(func): - result = func(row) + result = _invoke_with_context(func, row) value = await result if asyncio.iscoroutine(result) else result else: # Static value @@ -189,6 +191,7 @@ def _build_dependency_graph(self) -> Dict[str, List[str]]: explicit_deps = self._extract_explicit_dependencies(func) if explicit_deps: deps.extend(explicit_deps) + deps.extend(self._infer_signature_dependencies(func, column)) # Extract dependencies from generator functions if hasattr(func, "prompt_template"): @@ -211,6 +214,9 @@ def _build_dependency_graph(self) -> Dict[str, List[str]]: @staticmethod def _resolve_column_callable(func: Any) -> Any: """Unwrap schema value to the executable callable/value.""" + if isinstance(func, DependentCallable): + return func.func + if ( isinstance(func, tuple) and len(func) == 2 @@ -246,6 +252,37 @@ def _extract_explicit_dependencies(func: Any) -> List[str]: return [dep for dep in deps if isinstance(dep, str)] return [] + def _infer_signature_dependencies(self, func: Any, current_column: str) -> List[str]: + """Infer dependencies from callable parameter names.""" + target = self._resolve_column_callable(func) + if not callable(target): + return [] + if isinstance(target, (GeneratorFunction, SampleFunction)): + return [] + + try: + signature = inspect.signature(target) + except (TypeError, ValueError): + return [] + + inferred = [] + for param in signature.parameters.values(): + if param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + continue + if param.name in CONTEXT_ARG_NAMES: + continue + if param.default is not inspect.Parameter.empty: + continue + if param.name == current_column: + continue + if param.name in self.schema: + inferred.append(param.name) + + return inferred + def _topological_sort(self, dependencies: Dict[str, List[str]]) -> List[str]: """Topologically sort columns by dependencies.""" visited = set() @@ -325,12 +362,9 @@ class DependentCallable: def __init__(self, func: Callable[..., Any], dependencies: List[str]): self.func = func self.dependencies = dependencies - self._accepts_context = _callable_accepts_context(func) def __call__(self, context: Dict[str, Any]) -> Any: - if self._accepts_context: - return self.func(context) - return self.func() + return _invoke_with_context(self.func, context) def call( @@ -378,20 +412,48 @@ def depends_on(func: Callable[..., Any], *dependencies: str) -> DependentCallabl return call(func, *dependencies) -def _callable_accepts_context(func: Callable[..., Any]) -> bool: - """Return True when callable can accept a context argument.""" +def _invoke_with_context(func: Callable[..., Any], context: Dict[str, Any]) -> Any: + """Invoke callable using context-aware argument mapping.""" try: signature = inspect.signature(func) except (TypeError, ValueError): - # Builtins/c-extensions without inspect metadata: keep legacy behavior. - return True + # Keep legacy behavior for callables without inspect metadata. + return func(context) params = list(signature.parameters.values()) if not params: - return False + return func() + + args = [] + kwargs = {} + missing_required = False for param in params: - if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): - return True + if param.kind == inspect.Parameter.VAR_POSITIONAL: + continue + if param.kind == inspect.Parameter.VAR_KEYWORD: + continue + + if param.name in CONTEXT_ARG_NAMES: + value = context + has_value = True + else: + has_value = param.name in context + value = context.get(param.name) + + if not has_value and param.default is inspect.Parameter.empty: + missing_required = True + continue + if not has_value: + continue + + if param.kind == inspect.Parameter.POSITIONAL_ONLY: + args.append(value) + else: + kwargs[param.name] = value + + if not missing_required: + return func(*args, **kwargs) - return True + # Backward compatibility for legacy callables that expect `ctx`. + return func(context) diff --git a/tests/test_dataset_comprehensive.py b/tests/test_dataset_comprehensive.py index d2f80c5..c6ef37d 100644 --- a/tests/test_dataset_comprehensive.py +++ b/tests/test_dataset_comprehensive.py @@ -146,6 +146,27 @@ def test_explicit_callable_dependencies(self): assert dependencies["col2"] == ["col1"] assert dependencies["col3"] == ["col2"] + def test_signature_inferred_callable_dependencies(self): + """Test dependencies inferred from callable argument names.""" + + def col2(col1): + return f"v:{col1}" + + def col3(col2): + return f"w:{col2}" + + schema = { + "col1": ChoiceSampler(["A"]), + "col2": call(col2), + "col3": call(col3), + } + ds = Dataset(schema, n=2) + + dependencies = ds._build_dependency_graph() + assert dependencies["col1"] == [] + assert dependencies["col2"] == ["col1"] + assert dependencies["col3"] == ["col2"] + @pytest.mark.asyncio class TestDataGeneration: @@ -283,6 +304,25 @@ async def test_depends_on_backwards_compatible(self): assert len(df) == 3 assert all(df["file_content"] == "content:src/legacy.ts") + async def test_callable_call_wrapper_infers_dependencies_from_signature(self): + """Test call wrapper infers dependencies from function signature names.""" + + def file_path(): + return "src/inferred.ts" + + def file_content(file_path): + return f"content:{file_path}" + + schema = { + "file_path": call(file_path), + "file_content": call(file_content), + } + ds = Dataset(schema, n=3) + df = await ds.generate() + + assert len(df) == 3 + assert all(df["file_content"] == "content:src/inferred.ts") + async def test_callable_tuple_dependency_spec(self): """Test callable dependencies via tuple spec.""" schema = {