From b297a50209f1fcaa1cdda49c0dc3490e3a43cfa2 Mon Sep 17 00:00:00 2001 From: Benny Chen Date: Sat, 4 Oct 2025 23:06:10 -0700 Subject: [PATCH] reuse pydantic example for local model picking --- .../default_pydantic_ai_rollout_processor.py | 10 ++++++- tests/chinook/pydantic/agent.py | 1 + .../pydantic/test_pydantic_complex_queries.py | 26 ++++++++++++++++--- 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py b/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py index 47b7b456..3b181c4a 100644 --- a/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py +++ b/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py @@ -46,12 +46,20 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> """Create agent rollout tasks and return them for external handling.""" semaphore = config.semaphore - agent = self._setup_agent(config) + # NOTE: Do not create the agent outside of the semaphore or multiple rows + # will initialize clients and start network calls concurrently. This can + # overwhelm local providers like Ollama where only one request should be + # active at a time. Instead, construct the agent within the semaphore-guarded + # section per row. async def process_row(row: EvaluationRow) -> EvaluationRow: """Process a single row with agent rollout.""" start_time = time.perf_counter() + # Build the agent lazily inside the semaphore guard to ensure we fully + # respect max_concurrent_rollouts across both setup and run phases. + agent = self._setup_agent(config) + tools = [] for toolset in agent.toolsets: if isinstance(toolset, FunctionToolset): diff --git a/tests/chinook/pydantic/agent.py b/tests/chinook/pydantic/agent.py index 2b260fd4..594967a6 100644 --- a/tests/chinook/pydantic/agent.py +++ b/tests/chinook/pydantic/agent.py @@ -61,6 +61,7 @@ def execute_sql(ctx: RunContext, query: str) -> str: return "\n".join(table_lines) except Exception as e: + print("Show exception: ", e) connection.rollback() raise ModelRetry("Please try again with a different query. Here is the error: " + str(e)) diff --git a/tests/chinook/pydantic/test_pydantic_complex_queries.py b/tests/chinook/pydantic/test_pydantic_complex_queries.py index 583c90df..66b0b093 100644 --- a/tests/chinook/pydantic/test_pydantic_complex_queries.py +++ b/tests/chinook/pydantic/test_pydantic_complex_queries.py @@ -2,6 +2,7 @@ from pydantic import BaseModel from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIChatModel, OpenAIChatModelSettings +from pydantic_ai.providers.openai import OpenAIProvider import pytest from eval_protocol.models import EvaluateResult, EvaluationRow @@ -24,10 +25,23 @@ def agent_factory(config: RolloutProcessorConfig) -> Agent: model_name = config.completion_params["model"] - provider = config.completion_params.get("provider") + provider_param = config.completion_params.get("provider") reasoning = config.completion_params.get("reasoning") - settings = OpenAIChatModelSettings(openai_reasoning_effort=reasoning) - model = OpenAIChatModel(model_name, provider=provider or "openai", settings=settings) + # gpt-4o-mini does not support reasoning + if model_name == "gpt-4o-mini": + settings = OpenAIChatModelSettings() + else: + settings = OpenAIChatModelSettings(openai_reasoning_effort=reasoning) + base_url = config.completion_params.get("base_url") + api_key = config.completion_params.get("api_key") or os.getenv("OPENAI_API_KEY") or "dummy" + if base_url or provider_param == "ollama": + provider = OpenAIProvider( + api_key=api_key, + base_url=base_url or os.getenv("OLLAMA_OPENAI_BASE_URL", "http://localhost:11434/v1"), + ) + else: + provider = provider_param or "openai" + model = OpenAIChatModel(model_name, provider=provider, settings=settings) return setup_agent(model) @@ -51,7 +65,11 @@ def agent_factory(config: RolloutProcessorConfig) -> Agent: # "model": "accounts/fireworks/models/kimi-k2-instruct-0905", # "provider": "fireworks", # }, - {"model": "gpt-5"}, + # {"model": "gpt-4o-mini"}, + {"model": "gpt-5-nano-2025-08-07"}, + # {"model": "qwen3:4b", "provider": "ollama", "base_url": os.getenv("OLLAMA_OPENAI_BASE_URL", "http://localhost:11434/v1")}, + # {"model": "qwen3:8b", "provider": "ollama", "base_url": os.getenv("OLLAMA_OPENAI_BASE_URL", "http://localhost:11434/v1")}, + # {"model": "granite4:micro", "provider": "ollama", "base_url": os.getenv("OLLAMA_OPENAI_BASE_URL", "http://localhost:11434/v1")}, # {"model": "gpt-5", "reasoning": "high"}, ], rollout_processor=PydanticAgentRolloutProcessor(agent_factory),