Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ typecheck-pyright:
PYRIGHT_PYTHON_IGNORE_WARNINGS=1 uv run pyright test-services/

typecheck-mypy:
uv run -m mypy --check-untyped-defs --ignore-missing-imports python/
uv run -m mypy --check-untyped-defs --ignore-missing-imports examples/
uv run -m mypy --check-untyped-defs --ignore-missing-imports tests/
uv run -m mypy --check-untyped-defs --ignore-missing-imports --implicit-optional python/
uv run -m mypy --check-untyped-defs --ignore-missing-imports --implicit-optional examples/
uv run -m mypy --check-untyped-defs --ignore-missing-imports --implicit-optional tests/

typecheck: typecheck-pyright typecheck-mypy

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ serde = ["dacite", "pydantic", "msgspec"]
client = ["httpx[http2]"]
adk = ["google-adk>=1.20.0"]
openai = ["openai-agents>=0.6.1"]
pydantic_ai = ["pydantic-ai-slim>=1.35.0"]

[build-system]
requires = ["maturin>=1.6,<2.0"]
Expand Down
12 changes: 12 additions & 0 deletions python/restate/ext/pydantic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from ._agent import RestateAgent
from ._model import RestateModelWrapper
from ._serde import PydanticTypeAdapter
from ._toolset import RestateContextRunToolSet


__all__ = [
"RestateModelWrapper",
"RestateAgent",
"PydanticTypeAdapter",
"RestateContextRunToolSet",
]
380 changes: 380 additions & 0 deletions python/restate/ext/pydantic/_agent.py

Large diffs are not rendered by default.

117 changes: 117 additions & 0 deletions python/restate/ext/pydantic/_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any

from restate import RunOptions, SdkInternalBaseException
from restate.ext.pydantic._utils import current_state
from restate.extensions import current_context
from restate.ext.turnstile import Turnstile

from pydantic_ai.agent.abstract import EventStreamHandler
from pydantic_ai.exceptions import UserError
from pydantic_ai.messages import ModelMessage, ModelResponse, ModelResponseStreamEvent
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
from pydantic_ai.models.wrapper import WrapperModel
from pydantic_ai.settings import ModelSettings
from pydantic_ai.tools import AgentDepsT, RunContext
from pydantic_ai.usage import RequestUsage


from ._serde import PydanticTypeAdapter

MODEL_RESPONSE_SERDE = PydanticTypeAdapter(ModelResponse)


class RestateStreamedResponse(StreamedResponse):
def __init__(self, model_request_parameters: ModelRequestParameters, response: ModelResponse):
super().__init__(model_request_parameters)
self.response = response

async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
return
# noinspection PyUnreachableCode
yield

def get(self) -> ModelResponse:
return self.response

def usage(self) -> RequestUsage:
return self.response.usage # pragma: no cover

@property
def model_name(self) -> str:
return self.response.model_name or "" # pragma: no cover

@property
def provider_name(self) -> str:
return self.response.provider_name or "" # pragma: no cover

@property
def timestamp(self) -> datetime:
return self.response.timestamp # pragma: no cover

@property
def provider_url(self) -> str | None:
return None


class RestateModelWrapper(WrapperModel):
def __init__(
self,
wrapped: Model,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
max_attempts: int | None = None,
):
super().__init__(wrapped)
self._options = RunOptions(serde=MODEL_RESPONSE_SERDE, max_attempts=max_attempts)
self._event_stream_handler = event_stream_handler

async def request(self, *args: Any, **kwargs: Any) -> ModelResponse:
context = current_context()
if context is None:
raise UserError(
"A model cannot be used without a Restate context. Make sure to run it within an agent or a run context."
)
try:
res = await context.run_typed("Model call", self.wrapped.request, self._options, *args, **kwargs)
ids = [c.tool_call_id for c in res.tool_calls]
current_state().turnstile = Turnstile(ids)
return res
except SdkInternalBaseException as e:
raise Exception("Internal error during model call") from e

@asynccontextmanager
async def request_stream(
self,
messages: list[ModelMessage],
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
run_context: RunContext[AgentDepsT] | None = None,
) -> AsyncIterator[StreamedResponse]:
if run_context is None:
raise UserError(
"A model cannot be used with `pydantic_ai.direct.model_request_stream()` as it requires a `run_context`. Set an `event_stream_handler` on the agent and use `agent.run()` instead."
)

fn = self._event_stream_handler
assert fn is not None

async def request_stream_run():
async with self.wrapped.request_stream(
messages,
model_settings,
model_request_parameters,
run_context,
) as streamed_response:
await fn(run_context, streamed_response) # type: ignore[arg-type]

async for _ in streamed_response:
pass
return streamed_response.get()

try:
response = await self._context.run_typed("Model stream call", request_stream_run, self._options)
yield RestateStreamedResponse(model_request_parameters, response)
except SdkInternalBaseException as e:
raise Exception("Internal error during model stream call") from e
45 changes: 45 additions & 0 deletions python/restate/ext/pydantic/_serde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import typing

from pydantic import TypeAdapter
from restate.serde import Serde

T = typing.TypeVar("T")


class PydanticTypeAdapter(Serde[T]):
"""A serializer/deserializer for Pydantic models."""

def __init__(self, model_type: type[T]):
"""Initializes a new instance of the PydanticTypeAdaptorSerde class.

Args:
model_type (typing.Type[T]): The Pydantic model type to serialize/deserialize.
"""
self._model_type = TypeAdapter(model_type)

def deserialize(self, buf: bytes) -> T | None:
"""Deserializes a bytearray to a Pydantic model.

Args:
buf (bytearray): The bytearray to deserialize.

Returns:
typing.Optional[T]: The deserialized Pydantic model.
"""
if not buf:
return None
return self._model_type.validate_json(buf.decode("utf-8")) # raises if invalid

def serialize(self, obj: T | None) -> bytes:
"""Serializes a Pydantic model to a bytearray.

Args:
obj (typing.Optional[T]): The Pydantic model to serialize.

Returns:
bytes: The serialized bytearray.
"""
if obj is None:
return b""
tpe = TypeAdapter(type(obj))
return tpe.dump_json(obj)
179 changes: 179 additions & 0 deletions python/restate/ext/pydantic/_toolset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Literal

from restate import RunOptions, SdkInternalBaseException, TerminalError
from restate.exceptions import SdkInternalException
from restate.extensions import current_context

from pydantic_ai import ToolDefinition
from pydantic_ai._run_context import AgentDepsT
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
from pydantic_ai.mcp import MCPServer, ToolResult
from pydantic_ai.tools import RunContext
from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
from pydantic_ai.toolsets.wrapper import WrapperToolset

from ._serde import PydanticTypeAdapter
from ._utils import current_state


@dataclass
class RestateContextRunResult:
"""A simple wrapper for tool results to be used with Restate's run_typed."""

kind: Literal["output", "call_deferred", "approval_required", "model_retry"]
output: Any
error: str | None = None


CONTEXT_RUN_SERDE = PydanticTypeAdapter(RestateContextRunResult)


@dataclass
class RestateMCPGetToolsContextRunResult:
"""A simple wrapper for tool results to be used with Restate's run_typed."""

output: dict[str, ToolDefinition]


MCP_GET_TOOLS_SERDE = PydanticTypeAdapter(RestateMCPGetToolsContextRunResult)


@dataclass
class RestateMCPToolRunResult:
"""A simple wrapper for tool results to be used with Restate's run_typed."""

output: ToolResult


MCP_RUN_SERDE = PydanticTypeAdapter(RestateMCPToolRunResult)


class RestateContextRunToolSet(WrapperToolset[AgentDepsT]):
"""A toolset that automatically wraps tool calls with restate's `ctx.run_typed()`."""

def __init__(self, wrapped: AbstractToolset[AgentDepsT]):
super().__init__(wrapped)
self.options = RunOptions[RestateContextRunResult](serde=CONTEXT_RUN_SERDE)

async def call_tool(
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
) -> Any:
async def action() -> RestateContextRunResult:
try:
# A tool may raise ModelRetry, CallDeferred, ApprovalRequired, or UserError
# to signal special conditions to the caller.
# Since, restate ctx.run() will retry this exception we need to convert these exceptions
# to a return value and handle them outside of the ctx.run().
output = await self.wrapped.call_tool(name, tool_args, ctx, tool)
return RestateContextRunResult(kind="output", output=output, error=None)
except ModelRetry as e:
return RestateContextRunResult(kind="model_retry", output=None, error=e.message)
except CallDeferred:
return RestateContextRunResult(kind="call_deferred", output=None, error=None)
except ApprovalRequired:
return RestateContextRunResult(kind="approval_required", output=None, error=None)
except UserError as e:
raise TerminalError(str(e)) from e

id = ctx.tool_call_id
if id is None:
raise TerminalError("Tool call ID is required for turnstile synchronization.")
context = current_context()
if context is None:
raise UserError(
"A tool cannot be used without a Restate context. Make sure to run it within an agent or a run context."
)
turnstile = current_state().turnstile
try:
await turnstile.wait_for(id)
res = await context.run_typed(f"Calling {name}", action, self.options)

if res.kind == "call_deferred":
raise CallDeferred()
elif res.kind == "approval_required":
raise ApprovalRequired()
elif res.kind == "model_retry":
assert res.error is not None
raise ModelRetry(res.error)
else:
assert res.kind == "output"
turnstile.allow_next_after(id)
return res.output
except Exception as e:
turnstile.cancel_all_after(id)
raise e from None
except SdkInternalException as e:
turnstile.cancel_all_after(id)
raise RuntimeError() from e

def visit_and_replace(
self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
) -> AbstractToolset[AgentDepsT]:
return visitor(self)


class RestateMCPServer(WrapperToolset[AgentDepsT]):
"""A wrapper for MCPServer that integrates with restate."""

def __init__(self, wrapped: MCPServer):
super().__init__(wrapped)
self._wrapped = wrapped

def visit_and_replace(
self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
) -> AbstractToolset[AgentDepsT]:
return visitor(self)

async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
async def get_tools_in_context() -> RestateMCPGetToolsContextRunResult:
res = await self._wrapped.get_tools(ctx)
# ToolsetTool is not serializable as it holds a SchemaValidator
# (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
# so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
return RestateMCPGetToolsContextRunResult(output={name: tool.tool_def for name, tool in res.items()})

options = RunOptions(serde=MCP_GET_TOOLS_SERDE)

context = current_context()
if context is None:
raise UserError(
"A toolset cannot be used without a Restate context. Make sure to run it within an agent or a run context."
)

try:
tool_defs = await context.run_typed("get mcp tools", get_tools_in_context, options)
return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.output.items()}
except SdkInternalBaseException as e:
raise Exception("Internal error during get_tools call") from e

def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
assert isinstance(self.wrapped, MCPServer)
return self.wrapped.tool_for_tool_def(tool_def)

async def call_tool(
self,
name: str,
tool_args: dict[str, Any],
ctx: RunContext[AgentDepsT],
tool: ToolsetTool[AgentDepsT],
) -> ToolResult:
async def call_tool_in_context() -> RestateMCPToolRunResult:
res = await self._wrapped.call_tool(name, tool_args, ctx, tool)
return RestateMCPToolRunResult(output=res)

options = RunOptions(serde=MCP_RUN_SERDE)
context = current_context()
if context is None:
raise UserError(
"A toolset cannot be used without a Restate context. Make sure to run it within an agent or a run context."
)
try:
res = await context.run_typed(f"Calling mcp tool {name}", call_tool_in_context, options)
except SdkInternalBaseException as e:
raise Exception("Internal error during tool call") from e

return res.output
Loading