diff --git a/Justfile b/Justfile index 93c3702..49e2442 100644 --- a/Justfile +++ b/Justfile @@ -23,9 +23,9 @@ typecheck-pyright: PYRIGHT_PYTHON_IGNORE_WARNINGS=1 uv run pyright test-services/ typecheck-mypy: - uv run -m mypy --check-untyped-defs --ignore-missing-imports python/ - uv run -m mypy --check-untyped-defs --ignore-missing-imports examples/ - uv run -m mypy --check-untyped-defs --ignore-missing-imports tests/ + uv run -m mypy --check-untyped-defs --ignore-missing-imports --implicit-optional python/ + uv run -m mypy --check-untyped-defs --ignore-missing-imports --implicit-optional examples/ + uv run -m mypy --check-untyped-defs --ignore-missing-imports --implicit-optional tests/ typecheck: typecheck-pyright typecheck-mypy diff --git a/pyproject.toml b/pyproject.toml index 140c233..d01b149 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ serde = ["dacite", "pydantic", "msgspec"] client = ["httpx[http2]"] adk = ["google-adk>=1.20.0"] openai = ["openai-agents>=0.6.1"] +pydantic_ai = ["pydantic-ai-slim>=1.35.0"] [build-system] requires = ["maturin>=1.6,<2.0"] diff --git a/python/restate/ext/pydantic/__init__.py b/python/restate/ext/pydantic/__init__.py new file mode 100644 index 0000000..3f54d32 --- /dev/null +++ b/python/restate/ext/pydantic/__init__.py @@ -0,0 +1,12 @@ +from ._agent import RestateAgent +from ._model import RestateModelWrapper +from ._serde import PydanticTypeAdapter +from ._toolset import RestateContextRunToolSet + + +__all__ = [ + "RestateModelWrapper", + "RestateAgent", + "PydanticTypeAdapter", + "RestateContextRunToolSet", +] diff --git a/python/restate/ext/pydantic/_agent.py b/python/restate/ext/pydantic/_agent.py new file mode 100644 index 0000000..10baa6b --- /dev/null +++ b/python/restate/ext/pydantic/_agent.py @@ -0,0 +1,380 @@ +from __future__ import annotations + +from collections.abc import AsyncIterable, AsyncIterator, Iterator, Sequence +from contextlib import AbstractAsyncContextManager, asynccontextmanager, 
class RestateAgent(WrapperAgent[AgentDepsT, OutputDataT]):
    """An agent that integrates with the Restate framework for building resilient applications.

    This agent wraps an existing agent with Restate context capabilities, providing
    automatic retries and durable execution for all operations. By default, tool calls
    are automatically wrapped with Restate's execution model.

    Example:
    ...

    weather = restate.Service('weather')

    agent = RestateAgent(weather_agent)

    @weather.handler()
    async def get_weather(ctx: restate.Context, city: str):
        result = await agent.run(f'What is the weather in {city}?')
        return result.output
    ...

    For advanced scenarios, you can disable automatic tool wrapping by setting
    `disable_auto_wrapping_tools=True`. This allows direct usage of the Restate context
    within your tools for features like RPC calls, timers, and multi-step operations.

    When automatic wrapping is disabled, function tools will NOT be automatically executed
    within Restate's `ctx.run()` context, giving you full control over how the
    Restate context is used within your tool implementations.
    Model calls and MCP tool calls are still wrapped automatically.

    Example:
    ...

    @dataclass
    class WeatherDeps:
        ...
        restate_context: Context

    weather_agent = Agent(..., deps_type=WeatherDeps, ...)

    @weather_agent.tool
    async def get_lat_lng(ctx: RunContext[WeatherDeps], location_description: str) -> LatLng:
        restate_context = ctx.deps.restate_context
        lat = await restate_context.run(...)  # <---- note the direct usage of the restate context
        lng = await restate_context.run(...)
        return LatLng(lat, lng)


    agent = RestateAgent(weather_agent, disable_auto_wrapping_tools=True)

    weather = restate.Service('weather')

    @weather.handler()
    async def get_weather(ctx: restate.Context, city: str):
        result = await agent.run(f'What is the weather in {city}?', deps=WeatherDeps(restate_context=ctx, ...))
        return result.output
    ...

    """

    def __init__(
        self,
        wrapped: AbstractAgent[AgentDepsT, OutputDataT],
        *,
        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
        disable_auto_wrapping_tools: bool = False,
    ):
        """Wrap `wrapped` so that its model, tool and MCP calls run through Restate.

        Args:
            wrapped: The agent to wrap. Its `model` must already be a `Model` instance.
            event_stream_handler: Optional handler for stream events; when omitted, the
                handler configured on the wrapped agent (if any) is used.
            disable_auto_wrapping_tools: When true, function toolsets are left unwrapped so
                tools can use the Restate context directly.

        Raises:
            TerminalError: If the wrapped agent has no concrete `Model` configured.
        """
        super().__init__(wrapped)
        if not isinstance(wrapped.model, Model):
            raise TerminalError(
                "An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time."
            )

        self._event_stream_handler = event_stream_handler
        self._disable_auto_wrapping_tools = disable_auto_wrapping_tools
        # The model wrapper needs the handler that is actually in effect: the agent streams
        # whenever *any* handler is configured, including one inherited from the wrapped agent.
        self._model = RestateModelWrapper(
            wrapped.model,
            event_stream_handler=event_stream_handler or super().event_stream_handler,
            max_attempts=3,
        )

        def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
            """Set the Restate context for the toolset, wrapping tools if needed."""
            if isinstance(toolset, FunctionToolset) and not disable_auto_wrapping_tools:
                return RestateContextRunToolSet(toolset)
            try:
                from pydantic_ai.mcp import MCPServer

                from ._toolset import RestateMCPServer
            except ImportError:
                # MCP support is optional; without it there is nothing to wrap.
                pass
            else:
                if isinstance(toolset, MCPServer):
                    return RestateMCPServer(toolset)

            return toolset

        self._toolsets = [toolset.visit_and_replace(set_context) for toolset in wrapped.toolsets]

    @contextmanager
    def _restate_overrides(self) -> Iterator[None]:
        """Temporarily install the Restate model/toolsets and a fresh per-run state."""
        with super().override(model=self._model, toolsets=self._toolsets, tools=[]), state_context():
            yield

    @property
    def model(self) -> models.Model | models.KnownModelName | str | None:
        """The Restate-wrapped model used for every run."""
        return self._model

    @property
    def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None:
        """The handler the agent should call for stream events, or `None` if none is configured."""
        handler = self._event_stream_handler or super().event_stream_handler
        if handler is None:
            return None
        if self._disable_auto_wrapping_tools:
            return handler
        return self.wrapped_event_stream_handler

    async def wrapped_event_stream_handler(
        self, ctx: RunContext[AgentDepsT], stream: AsyncIterable[AgentStreamEvent]
    ) -> None:
        """Deliver each stream event to the configured handler inside its own `run_typed` step.

        Raises:
            UserError: If there is no current Restate context.
        """
        # Use the same resolution as the `event_stream_handler` property so a handler
        # configured on the wrapped agent is not silently dropped.
        handler = self._event_stream_handler or super().event_stream_handler
        if handler is None:
            return
        context = current_context()
        if context is None:
            raise UserError("No Restate context found for RestateAgent event stream handler.")

        async for event in stream:

            # Bind the event eagerly so the generator never observes a later loop value.
            async def single_event(event: AgentStreamEvent = event) -> AsyncIterator[AgentStreamEvent]:
                yield event

            await context.run_typed("run event", lambda: handler(ctx, single_event()))

    @property
    def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
        """The toolsets as seen with the Restate overrides applied."""
        with self._restate_overrides():
            return super().toolsets

    @overload
    async def run(
        self,
        user_prompt: str | Sequence[UserContent] | None = None,
        *,
        output_type: None = None,
        message_history: Sequence[ModelMessage] | None = None,
        deferred_tool_results: DeferredToolResults | None = None,
        model: models.Model | models.KnownModelName | str | None = None,
        instructions: Instructions[AgentDepsT] = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: UsageLimits | None = None,
        usage: RunUsage | None = None,
        infer_name: bool = True,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
        builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
    ) -> AgentRunResult[OutputDataT]: ...

    @overload
    async def run(
        self,
        user_prompt: str | Sequence[UserContent] | None = None,
        *,
        output_type: OutputSpec[RunOutputDataT],
        message_history: Sequence[ModelMessage] | None = None,
        deferred_tool_results: DeferredToolResults | None = None,
        model: models.Model | models.KnownModelName | str | None = None,
        instructions: Instructions[AgentDepsT] = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: UsageLimits | None = None,
        usage: RunUsage | None = None,
        infer_name: bool = True,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
        builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
    ) -> AgentRunResult[RunOutputDataT]: ...

    async def run(
        self,
        user_prompt: str | Sequence[UserContent] | None = None,
        *,
        output_type: OutputSpec[RunOutputDataT] | None = None,
        message_history: Sequence[ModelMessage] | None = None,
        deferred_tool_results: DeferredToolResults | None = None,
        model: models.Model | models.KnownModelName | str | None = None,
        instructions: Instructions[AgentDepsT] = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: UsageLimits | None = None,
        usage: RunUsage | None = None,
        infer_name: bool = True,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
        builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
    ) -> AgentRunResult[Any]:
        """Run the agent with a user prompt in async mode.

        This method builds an internal agent graph (using system prompts, tools and output schemas) and then
        runs the graph to completion. The result of the run is returned.

        Example:
        ```python
        from pydantic_ai import Agent

        agent = Agent('openai:gpt-4o')

        async def main():
            agent_run = await agent.run('What is the capital of France?')
            print(agent_run.output)
            #> The capital of France is Paris.
        ```

        Args:
            user_prompt: User input to start/continue the conversation.
            output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
                output validators since output validators would expect an argument that matches the agent's output type.
            message_history: History of the conversation so far.
            deferred_tool_results: Optional results for deferred tool calls in the message history.
            model: Must be left as `None`. The model has to be configured on the wrapped agent;
                overriding it at run time is rejected.
            instructions: Optional additional instructions to use for this run.
            deps: Optional dependencies to use for this run.
            model_settings: Optional settings to use for this model's request.
            usage_limits: Optional limits on model request count or token usage.
            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
            toolsets: Optional additional toolsets for this run.
            event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run.
            builtin_tools: Optional additional builtin tools for this run.

        Returns:
            The result of the run.

        Raises:
            TerminalError: If `model` is passed.
        """
        if model is not None:
            raise TerminalError(
                "An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time."
            )
        with self._restate_overrides():
            # Deliberately skip `WrapperAgent.run` in the MRO so the run is driven by this
            # agent (with the overrides above) instead of being delegated to the wrapped one.
            return await super(WrapperAgent, self).run(
                user_prompt=user_prompt,
                output_type=output_type,
                message_history=message_history,
                instructions=instructions,
                deferred_tool_results=deferred_tool_results,
                model=model,
                deps=deps,
                model_settings=model_settings,
                usage_limits=usage_limits,
                usage=usage,
                infer_name=infer_name,
                toolsets=toolsets,
                builtin_tools=builtin_tools,
                event_stream_handler=event_stream_handler,
            )

    @overload
    def run_stream(
        self,
        user_prompt: str | Sequence[UserContent] | None = None,
        *,
        output_type: None = None,
        message_history: Sequence[ModelMessage] | None = None,
        deferred_tool_results: DeferredToolResults | None = None,
        model: models.Model | models.KnownModelName | str | None = None,
        instructions: Instructions[AgentDepsT] = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: UsageLimits | None = None,
        usage: RunUsage | None = None,
        infer_name: bool = True,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
        builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
    ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, OutputDataT]]: ...

    @overload
    def run_stream(
        self,
        user_prompt: str | Sequence[UserContent] | None = None,
        *,
        output_type: OutputSpec[RunOutputDataT],
        message_history: Sequence[ModelMessage] | None = None,
        deferred_tool_results: DeferredToolResults | None = None,
        model: models.Model | models.KnownModelName | str | None = None,
        instructions: Instructions[AgentDepsT] = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: UsageLimits | None = None,
        usage: RunUsage | None = None,
        infer_name: bool = True,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
        builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
    ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...

    @asynccontextmanager
    async def run_stream(
        self,
        user_prompt: str | Sequence[UserContent] | None = None,
        *,
        output_type: OutputSpec[RunOutputDataT] | None = None,
        message_history: Sequence[ModelMessage] | None = None,
        deferred_tool_results: DeferredToolResults | None = None,
        model: models.Model | models.KnownModelName | str | None = None,
        instructions: Instructions[AgentDepsT] = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: UsageLimits | None = None,
        usage: RunUsage | None = None,
        infer_name: bool = True,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
        builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
        **_deprecated_kwargs: Any,
    ) -> AsyncIterator[StreamedRunResult[AgentDepsT, Any]]:
        """Not supported: entering the returned context manager always raises `UserError`.

        The signature mirrors the wrapped agent's `run_stream` so the class stays a drop-in
        replacement; none of the arguments are used. Set an `event_stream_handler` and call
        `run()` to observe streamed events instead.

        Raises:
            UserError: Always.
        """
        raise UserError("RestateAgent does not support run_stream currently.")
        # Unreachable, but required so this function is an async generator.
        yield
class RestateModelWrapper(WrapperModel):
    """A model wrapper that executes every model call inside a Restate `run_typed` step."""

    def __init__(
        self,
        wrapped: Model,
        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
        max_attempts: int | None = None,
    ):
        """Wrap `wrapped`.

        Args:
            wrapped: The model that performs the actual requests.
            event_stream_handler: Handler invoked with the live stream by `request_stream`.
            max_attempts: Forwarded to `RunOptions` for each model call.
        """
        super().__init__(wrapped)
        self._options = RunOptions(serde=MODEL_RESPONSE_SERDE, max_attempts=max_attempts)
        self._event_stream_handler = event_stream_handler

    async def request(self, *args: Any, **kwargs: Any) -> ModelResponse:
        """Run the wrapped model's `request` via `run_typed` and return its response.

        After the call, a new `Turnstile` built from the response's tool call ids is
        stored on the current state.

        Raises:
            UserError: If there is no current Restate context.
            Exception: If the SDK raises an `SdkInternalBaseException` during the call.
        """
        context = current_context()
        if context is None:
            raise UserError(
                "A model cannot be used without a Restate context. Make sure to run it within an agent or a run context."
            )
        try:
            res = await context.run_typed("Model call", self.wrapped.request, self._options, *args, **kwargs)
            ids = [c.tool_call_id for c in res.tool_calls]
            current_state().turnstile = Turnstile(ids)
            return res
        except SdkInternalBaseException as e:
            raise Exception("Internal error during model call") from e

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[AgentDepsT] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        """Stream from the wrapped model inside one `run_typed` step and yield the final response.

        The live stream is handed to the event stream handler and then drained inside the
        step; the caller receives a `RestateStreamedResponse` holding only the final
        `ModelResponse`.

        Raises:
            UserError: If `run_context` is missing, no event stream handler is configured,
                or there is no current Restate context.
            Exception: If the SDK raises an `SdkInternalBaseException` during the call.
        """
        if run_context is None:
            raise UserError(
                "A model cannot be used with `pydantic_ai.direct.model_request_stream()` as it requires a `run_context`. Set an `event_stream_handler` on the agent and use `agent.run()` instead."
            )

        fn = self._event_stream_handler
        if fn is None:
            # An explicit error instead of `assert`, which is stripped under `python -O`.
            raise UserError("A model cannot stream without an `event_stream_handler`. Set one on the agent.")

        # The wrapper never stores a context; resolve it per call exactly like `request`.
        context = current_context()
        if context is None:
            raise UserError(
                "A model cannot be used without a Restate context. Make sure to run it within an agent or a run context."
            )

        async def request_stream_run():
            async with self.wrapped.request_stream(
                messages,
                model_settings,
                model_request_parameters,
                run_context,
            ) as streamed_response:
                await fn(run_context, streamed_response)  # type: ignore[arg-type]

                # Drain whatever the handler left unread so `get()` sees the full response.
                async for _ in streamed_response:
                    pass
            return streamed_response.get()

        try:
            response = await context.run_typed("Model stream call", request_stream_run, self._options)
            # Mirror `request`: install the turnstile for the tool calls of this response.
            ids = [c.tool_call_id for c in response.tool_calls]
            current_state().turnstile = Turnstile(ids)
            yield RestateStreamedResponse(model_request_parameters, response)
        except SdkInternalBaseException as e:
            raise Exception("Internal error during model stream call") from e
class RestateContextRunToolSet(WrapperToolset[AgentDepsT]):
    """A toolset that automatically wraps tool calls with restate's `ctx.run_typed()`."""

    def __init__(self, wrapped: AbstractToolset[AgentDepsT]):
        """Wrap `wrapped`; tool results are serialized with `CONTEXT_RUN_SERDE`."""
        super().__init__(wrapped)
        self.options = RunOptions[RestateContextRunResult](serde=CONTEXT_RUN_SERDE)

    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Call the wrapped tool inside a `run_typed` step, ordered by the current turnstile.

        Returns:
            The tool's output.

        Raises:
            TerminalError: If the run context carries no tool call id, or the tool raised `UserError`.
            UserError: If there is no current Restate context.
            ModelRetry, CallDeferred, ApprovalRequired: Re-raised outside the step when the tool signalled them.
            RuntimeError: If the SDK raises an `SdkInternalException`.
        """

        async def action() -> RestateContextRunResult:
            try:
                # A tool may raise ModelRetry, CallDeferred, ApprovalRequired, or UserError
                # to signal special conditions to the caller.
                # Since restate's ctx.run() would retry on these exceptions, we convert them
                # to a return value and handle them outside of the ctx.run().
                output = await self.wrapped.call_tool(name, tool_args, ctx, tool)
                return RestateContextRunResult(kind="output", output=output, error=None)
            except ModelRetry as e:
                return RestateContextRunResult(kind="model_retry", output=None, error=e.message)
            except CallDeferred:
                return RestateContextRunResult(kind="call_deferred", output=None, error=None)
            except ApprovalRequired:
                return RestateContextRunResult(kind="approval_required", output=None, error=None)
            except UserError as e:
                raise TerminalError(str(e)) from e

        tool_call_id = ctx.tool_call_id
        if tool_call_id is None:
            raise TerminalError("Tool call ID is required for turnstile synchronization.")
        context = current_context()
        if context is None:
            raise UserError(
                "A tool cannot be used without a Restate context. Make sure to run it within an agent or a run context."
            )
        turnstile = current_state().turnstile
        try:
            await turnstile.wait_for(tool_call_id)
            res = await context.run_typed(f"Calling {name}", action, self.options)

            if res.kind == "call_deferred":
                raise CallDeferred()
            elif res.kind == "approval_required":
                raise ApprovalRequired()
            elif res.kind == "model_retry":
                assert res.error is not None
                raise ModelRetry(res.error)
            else:
                assert res.kind == "output"
                turnstile.allow_next_after(tool_call_id)
                return res.output
        except Exception:
            # Any failure (including the signals re-raised above) cancels the calls queued behind this one.
            turnstile.cancel_all_after(tool_call_id)
            raise
        except SdkInternalException as e:
            # NOTE(review): this branch is only reachable if SdkInternalException is not an
            # `Exception` subclass — confirm against restate.exceptions.
            turnstile.cancel_all_after(tool_call_id)
            raise RuntimeError("Internal error during tool call") from e

    def visit_and_replace(
        self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
    ) -> AbstractToolset[AgentDepsT]:
        """Apply `visitor` to this wrapper itself rather than descending into the wrapped toolset."""
        return visitor(self)
Make sure to run it within an agent or a run context." + ) + + try: + tool_defs = await context.run_typed("get mcp tools", get_tools_in_context, options) + return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.output.items()} + except SdkInternalBaseException as e: + raise Exception("Internal error during get_tools call") from e + + def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]: + assert isinstance(self.wrapped, MCPServer) + return self.wrapped.tool_for_tool_def(tool_def) + + async def call_tool( + self, + name: str, + tool_args: dict[str, Any], + ctx: RunContext[AgentDepsT], + tool: ToolsetTool[AgentDepsT], + ) -> ToolResult: + async def call_tool_in_context() -> RestateMCPToolRunResult: + res = await self._wrapped.call_tool(name, tool_args, ctx, tool) + return RestateMCPToolRunResult(output=res) + + options = RunOptions(serde=MCP_RUN_SERDE) + context = current_context() + if context is None: + raise UserError( + "A toolset cannot be used without a Restate context. Make sure to run it within an agent or a run context." 
class State:
    """Mutable per-run state; currently only the turnstile that orders tool calls."""

    __slots__ = ("turnstile",)

    def __init__(self) -> None:
        # Start with an empty turnstile (no tool call ids registered).
        self.turnstile = Turnstile([])


# NOTE(review): the default is a single State instance created at import time, so any
# caller that mutates `current_state()` outside `state_context()` shares it — confirm
# that this is intended.
restate_state_var = ContextVar("restate_state_var", default=State())


def current_state() -> State:
    """Return the State bound to the current context (the shared default if none was set)."""
    return restate_state_var.get()


def set_current_state(state: State) -> None:
    """Bind `state` to the current context."""
    restate_state_var.set(state)


@contextmanager
def state_context():
    """Bind a fresh State for the duration of the `with` block, restoring the previous one on exit."""
    reset_token = restate_state_var.set(State())
    try:
        yield
    finally:
        restate_state_var.reset(reset_token)
"https://files.pythonhosted.org/packages/e4/1e/1d51238dd164dde10c4e3be6ad2d8f26dd34dd262117c277440e2b5dc7c0/genai_prices-0.0.49-py3-none-any.whl", hash = "sha256:dd3efbebcd865d89cd849793530729e7f7e1ca59d2b17a091ad1aa6aa76daf0d", size = 60433, upload-time = "2025-12-17T10:47:28.3Z" }, +] + [[package]] name = "google-adk" version = "1.20.0" @@ -1340,6 +1353,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" }, ] +[[package]] +name = "logfire-api" +version = "4.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/d9/d5f5e276371d5c8cde559d558de44b8378641231a23f3a632ebfe4b05c9b/logfire_api-4.16.0.tar.gz", hash = "sha256:0efa62f5e73abdea670b5e9384c841b544474207110a089536a0fa8704f9e386", size = 57702, upload-time = "2025-12-04T16:16:40.725Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/6e/6d500ce6352c54566d03c65d92a8f3fc7045645814de046707b105dda2a6/logfire_api-4.16.0-py3-none-any.whl", hash = "sha256:7351153c35cb61f0f89d2d4123ebf99b5469d70ef34c613a5ce56f85bf1b14fb", size = 95247, upload-time = "2025-12-04T16:16:38.007Z" }, +] + [[package]] name = "mako" version = "1.3.10" @@ -2077,6 +2099,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/87/b70ad306ebb6f9b585f114d0ac2137d792b48be34d732d60e597c2f8465a/pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d", size = 463580, upload-time = "2025-11-26T15:11:44.605Z" }, ] +[[package]] +name = "pydantic-ai-slim" +version = "1.35.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "genai-prices" 
}, + { name = "griffe" }, + { name = "httpx" }, + { name = "opentelemetry-api" }, + { name = "pydantic" }, + { name = "pydantic-graph" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dc/fd/4919262b1645ad3870ff7c59f9f8aa4aedb04803e60903a4ef503a49d00d/pydantic_ai_slim-1.35.0.tar.gz", hash = "sha256:5cd2184ecc2799a5f378abea1e0f1846dd6487b800c5be84a0b84a18e4213d20", size = 348494, upload-time = "2025-12-18T00:15:05.586Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/74/7a152ff16b5b7924217217c256b0c9de99dea8b97157e42c7a4cd807a8c1/pydantic_ai_slim-1.35.0-py3-none-any.whl", hash = "sha256:d0c1d9ea4de0e13ad2918811719cd36989794439ce33ee87cb3530c2c058d5ca", size = 454728, upload-time = "2025-12-18T00:14:58.251Z" }, +] + [[package]] name = "pydantic-core" version = "2.41.5" @@ -2195,6 +2236,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/c7/cfc8e811f061c841d7990b0201912c3556bfeb99cdcb7ed24adc8d6f8704/pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51", size = 2145302, upload-time = "2025-11-04T13:43:46.64Z" }, ] +[[package]] +name = "pydantic-graph" +version = "1.35.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, + { name = "logfire-api" }, + { name = "pydantic" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/36/f9/15d6169190d49ffb9da3337556e63dec46b3a7f17584425b7a67bd855c0b/pydantic_graph-1.35.0.tar.gz", hash = "sha256:775a4ff0d650e158bc42e97d0ab121a59b0748efe1a32760c28f0ab53f14d4da", size = 58455, upload-time = "2025-12-18T00:15:07.805Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/49/b615dd668cc48cfddb8f8b403e23762ae69a083b6db48e348d7549df2161/pydantic_graph-1.35.0-py3-none-any.whl", hash = "sha256:24c0653a47d21a51e76776b3ddbfcff28d3934d2e895b985f2638b0214acd3d4", size = 72325, 
upload-time = "2025-12-18T00:15:01.5Z" }, +] + [[package]] name = "pydantic-settings" version = "2.12.0" @@ -2441,6 +2497,9 @@ lint = [ openai = [ { name = "openai-agents" }, ] +pydantic-ai = [ + { name = "pydantic-ai-slim" }, +] serde = [ { name = "dacite" }, { name = "msgspec" }, @@ -2465,12 +2524,13 @@ requires-dist = [ { name = "mypy", marker = "extra == 'lint'", specifier = ">=1.11.2" }, { name = "openai-agents", marker = "extra == 'openai'", specifier = ">=0.6.1" }, { name = "pydantic", marker = "extra == 'serde'" }, + { name = "pydantic-ai-slim", marker = "extra == 'pydantic-ai'", specifier = ">=1.35.0" }, { name = "pyright", marker = "extra == 'lint'", specifier = ">=1.1.390" }, { name = "pytest", marker = "extra == 'test'" }, { name = "ruff", marker = "extra == 'lint'", specifier = ">=0.6.9" }, { name = "testcontainers", marker = "extra == 'harness'" }, ] -provides-extras = ["test", "lint", "harness", "serde", "client", "adk", "openai"] +provides-extras = ["test", "lint", "harness", "serde", "client", "adk", "openai", "pydantic-ai"] [[package]] name = "rpds-py"