diff --git a/vlmrun/client/agent.py b/vlmrun/client/agent.py index 56f07cb..3c02411 100644 --- a/vlmrun/client/agent.py +++ b/vlmrun/client/agent.py @@ -3,7 +3,7 @@ from __future__ import annotations import warnings from functools import cached_property -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from pydantic import BaseModel @@ -16,6 +16,7 @@ AgentExecutionConfig, AgentCreationConfig, AgentCreationResponse, + AgentToolset, ) from vlmrun.client.exceptions import DependencyError @@ -169,6 +170,7 @@ def execute( metadata: Optional[RequestMetadata] = None, callback_url: Optional[str] = None, model: str = "vlmrun-orion-1:auto", + toolsets: Optional[List[AgentToolset]] = None, ) -> AgentExecutionResponse: """Execute an agent with the given arguments. @@ -180,6 +182,11 @@ def execute( metadata: Optional request metadata callback_url: Optional URL to call when execution is complete model: VLM Run Agent model to use for execution (default: "vlmrun-orion-1:auto") + toolsets: Optional list of tool categories to enable for this execution. + Available categories: core, image, image-gen, 3d_reconstruction, + viz, document, video, web. + When specified, only tools from these categories will be available. + If None, defaults to 'core' tools only. Returns: AgentExecutionResponse: Agent execution response @@ -203,6 +210,9 @@ def execute( if callback_url: data["callback_url"] = callback_url + if toolsets is not None: + data["toolsets"] = toolsets + response, status_code, headers = self._requestor.request( method="POST", url="agent/execute", diff --git a/vlmrun/client/types.py b/vlmrun/client/types.py index fb32e38..bc869c3 100644 --- a/vlmrun/client/types.py +++ b/vlmrun/client/types.py @@ -12,6 +12,18 @@ JobStatus = Literal["enqueued", "pending", "running", "completed", "failed", "paused"] +# AgentToolset type - tool categories available for agent execution +AgentToolset = Literal[ + "core", + "image", + "image-gen", + "3d_reconstruction", + "viz", + "document", + "video", + "web", +] + @dataclass class APIError(Exception):