Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 15 additions & 20 deletions src/orcapod/core/pods.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import hashlib
import logging
import sys
from abc import abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast

from orcapod import contexts
from orcapod.core.datagrams import (
Expand Down Expand Up @@ -224,54 +223,50 @@ def track_invocation(self, *streams: cp.Stream, label: str | None = None) -> Non
self._tracker_manager.record_pod_invocation(self, streams, label=label)


class CallableWithPod(Protocol):
def __call__(self, *args, **kwargs) -> Any: ...

@property
def pod(self) -> "FunctionPod": ...


def function_pod(
output_keys: str | Collection[str] | None = None,
function_name: str | None = None,
version: str = "v0.0",
label: str | None = None,
**kwargs,
) -> Callable[..., "FunctionPod"]:
) -> Callable[..., CallableWithPod]:
"""
Decorator that wraps a function in a FunctionPod instance.
Decorator that attaches FunctionPod as pod attribute.

Args:
output_keys: Keys for the function output(s)
function_name: Name of the function pod; if None, defaults to the function name
**kwargs: Additional keyword arguments to pass to the FunctionPod constructor. Please refer to the FunctionPod documentation for details.

Returns:
FunctionPod instance wrapping the decorated function
CallableWithPod: Decorated function with `pod` attribute holding the FunctionPod instance
"""

def decorator(func) -> FunctionPod:
def decorator(func: Callable) -> CallableWithPod:
if func.__name__ == "<lambda>":
raise ValueError("Lambda functions cannot be used with function_pod")

if not hasattr(func, "__module__") or func.__module__ is None:
raise ValueError(
f"Function {func.__name__} must be defined at module level"
)

# Store the original function in the module for pickling purposes
# and make sure to change the name of the function
module = sys.modules[func.__module__]
base_function_name = func.__name__
new_function_name = f"_original_{func.__name__}"
setattr(module, new_function_name, func)
# rename the function to be consistent and make it pickleable
setattr(func, "__name__", new_function_name)
setattr(func, "__qualname__", new_function_name)

# Create a simple typed function pod
pod = FunctionPod(
function=func,
output_keys=output_keys,
function_name=function_name or base_function_name,
function_name=function_name or func.__name__,
version=version,
label=label,
**kwargs,
)
return pod
setattr(func, "pod", pod)
return cast(CallableWithPod, func)

return decorator

Expand Down
Loading