diff --git a/python/packages/autogen-core/src/autogen_core/logging.py b/python/packages/autogen-core/src/autogen_core/logging.py
index 3f371a6f3bca..2f31dbabe55a 100644
--- a/python/packages/autogen-core/src/autogen_core/logging.py
+++ b/python/packages/autogen-core/src/autogen_core/logging.py
@@ -1,6 +1,6 @@
import json
from enum import Enum
-from typing import Any, Dict, List, cast
+from typing import Any, Dict, List, Optional, cast
from ._agent_id import AgentId
from ._message_handler_context import MessageHandlerContext
@@ -15,6 +15,8 @@ def __init__(
response: Dict[str, Any],
prompt_tokens: int,
completion_tokens: int,
+ latency_ms: Optional[float] = None,
+ tokens_per_second: Optional[float] = None,
**kwargs: Any,
) -> None:
"""To be used by model clients to log the call to the LLM.
@@ -24,6 +26,8 @@ def __init__(
response (Dict[str, Any]): The response of the call. Must be json serializable.
prompt_tokens (int): Number of tokens used in the prompt.
completion_tokens (int): Number of tokens used in the completion.
+ latency_ms (Optional[float]): Total call duration in milliseconds.
+ tokens_per_second (Optional[float]): Completion tokens divided by latency in seconds.
Example:
@@ -45,6 +49,10 @@ def __init__(
self.kwargs["response"] = response
self.kwargs["prompt_tokens"] = prompt_tokens
self.kwargs["completion_tokens"] = completion_tokens
+ if latency_ms is not None:
+ self.kwargs["latency_ms"] = latency_ms
+ if tokens_per_second is not None:
+ self.kwargs["tokens_per_second"] = tokens_per_second
try:
agent_id = MessageHandlerContext.agent_id()
except RuntimeError:
@@ -59,6 +67,14 @@ def prompt_tokens(self) -> int:
def completion_tokens(self) -> int:
return cast(int, self.kwargs["completion_tokens"])
+ @property
+ def latency_ms(self) -> Optional[float]:
+ return self.kwargs.get("latency_ms")
+
+ @property
+ def tokens_per_second(self) -> Optional[float]:
+ return self.kwargs.get("tokens_per_second")
+
# This must output the event in a json serializable format
def __str__(self) -> str:
return json.dumps(self.kwargs)
@@ -111,6 +127,9 @@ def __init__(
response: Dict[str, Any],
prompt_tokens: int,
completion_tokens: int,
+ latency_ms: Optional[float] = None,
+ tokens_per_second: Optional[float] = None,
+ ttft_ms: Optional[float] = None,
**kwargs: Any,
) -> None:
"""To be used by model clients to log the end of a stream.
@@ -119,6 +138,9 @@ def __init__(
response (Dict[str, Any]): The response of the call. Must be json serializable.
prompt_tokens (int): Number of tokens used in the prompt.
completion_tokens (int): Number of tokens used in the completion.
+ latency_ms (Optional[float]): Total stream duration in milliseconds.
+ tokens_per_second (Optional[float]): Completion tokens divided by latency in seconds.
+ ttft_ms (Optional[float]): Time to first token in milliseconds.
Example:
@@ -138,6 +160,12 @@ def __init__(
self.kwargs["response"] = response
self.kwargs["prompt_tokens"] = prompt_tokens
self.kwargs["completion_tokens"] = completion_tokens
+ if latency_ms is not None:
+ self.kwargs["latency_ms"] = latency_ms
+ if tokens_per_second is not None:
+ self.kwargs["tokens_per_second"] = tokens_per_second
+ if ttft_ms is not None:
+ self.kwargs["ttft_ms"] = ttft_ms
try:
agent_id = MessageHandlerContext.agent_id()
except RuntimeError:
@@ -152,6 +180,18 @@ def prompt_tokens(self) -> int:
def completion_tokens(self) -> int:
return cast(int, self.kwargs["completion_tokens"])
+ @property
+ def latency_ms(self) -> Optional[float]:
+ return self.kwargs.get("latency_ms")
+
+ @property
+ def tokens_per_second(self) -> Optional[float]:
+ return self.kwargs.get("tokens_per_second")
+
+ @property
+ def ttft_ms(self) -> Optional[float]:
+ return self.kwargs.get("ttft_ms")
+
# This must output the event in a json serializable format
def __str__(self) -> str:
return json.dumps(self.kwargs)
diff --git a/python/packages/autogen-core/tests/test_logging_events.py b/python/packages/autogen-core/tests/test_logging_events.py
new file mode 100644
index 000000000000..3fd56c503861
--- /dev/null
+++ b/python/packages/autogen-core/tests/test_logging_events.py
@@ -0,0 +1,107 @@
+"""Tests for LLMCallEvent and LLMStreamEndEvent timing fields.
+
+These tests verify the performance telemetry fields added in Issue #5790.
+"""
+
+import json
+
+from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent
+
+
+def test_llm_call_event_timing_fields() -> None:
+ """Test that LLMCallEvent correctly stores and serializes timing fields."""
+ event = LLMCallEvent(
+ messages=[{"role": "user", "content": "Hello"}],
+ response={"content": "Hi there!"},
+ prompt_tokens=10,
+ completion_tokens=20,
+ latency_ms=150.5,
+ tokens_per_second=133.33,
+ )
+
+ # Check property accessors
+ assert event.prompt_tokens == 10
+ assert event.completion_tokens == 20
+ assert event.latency_ms == 150.5
+ assert event.tokens_per_second == 133.33
+
+ # Check JSON serialization includes timing fields
+ json_str = str(event)
+ data = json.loads(json_str)
+ assert data["latency_ms"] == 150.5
+ assert data["tokens_per_second"] == 133.33
+
+
+def test_llm_call_event_timing_fields_optional() -> None:
+ """Test that timing fields are optional and not included when not provided."""
+ event = LLMCallEvent(
+ messages=[{"role": "user", "content": "Hello"}],
+ response={"content": "Hi there!"},
+ prompt_tokens=10,
+ completion_tokens=20,
+ )
+
+ # Check that missing fields return None
+ assert event.latency_ms is None
+ assert event.tokens_per_second is None
+
+ # Check JSON serialization excludes missing timing fields
+ json_str = str(event)
+ data = json.loads(json_str)
+ assert "latency_ms" not in data
+ assert "tokens_per_second" not in data
+
+
+def test_llm_stream_end_event_timing_fields() -> None:
+ """Test that LLMStreamEndEvent correctly stores and serializes timing fields including TTFT."""
+ event = LLMStreamEndEvent(
+ response={"content": "Hello, world!"},
+ prompt_tokens=10,
+ completion_tokens=25,
+ latency_ms=200.0,
+ tokens_per_second=125.0,
+ ttft_ms=50.5,
+ )
+
+ # Check property accessors
+ assert event.prompt_tokens == 10
+ assert event.completion_tokens == 25
+ assert event.latency_ms == 200.0
+ assert event.tokens_per_second == 125.0
+ assert event.ttft_ms == 50.5
+
+ # Check JSON serialization includes all timing fields
+ json_str = str(event)
+ data = json.loads(json_str)
+ assert data["latency_ms"] == 200.0
+ assert data["tokens_per_second"] == 125.0
+ assert data["ttft_ms"] == 50.5
+
+
+def test_llm_stream_end_event_timing_fields_optional() -> None:
+ """Test that streaming timing fields are optional."""
+ event = LLMStreamEndEvent(
+ response={"content": "Hello, world!"},
+ prompt_tokens=10,
+ completion_tokens=25,
+ )
+
+ # Check that missing fields return None
+ assert event.latency_ms is None
+ assert event.tokens_per_second is None
+ assert event.ttft_ms is None
+
+ # Check JSON serialization excludes missing timing fields
+ json_str = str(event)
+ data = json.loads(json_str)
+ assert "latency_ms" not in data
+ assert "tokens_per_second" not in data
+ assert "ttft_ms" not in data
+
+
+if __name__ == "__main__":
+ test_llm_call_event_timing_fields()
+ test_llm_call_event_timing_fields_optional()
+ test_llm_stream_end_event_timing_fields()
+ test_llm_stream_end_event_timing_fields_optional()
+ print("All tests passed!")
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
index 16d56b57f956..08c19201aca5 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
@@ -1,6 +1,7 @@
import asyncio
import logging
import re
+import time
from asyncio import Task
from inspect import getfullargspec
from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Union, cast
@@ -65,7 +66,9 @@
from .._utils.parse_r1_content import parse_r1_content
create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
-AzureMessage = Union[AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage]
+AzureMessage = Union[
+ AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage
+]
logger = logging.getLogger(EVENT_LOGGER_NAME)
@@ -74,7 +77,9 @@ def _is_github_model(endpoint: str) -> bool:
return endpoint == GITHUB_MODELS_ENDPOINT
-def convert_tools(tools: Sequence[Tool | ToolSchema]) -> List[ChatCompletionsToolDefinition]:
+def convert_tools(
+ tools: Sequence[Tool | ToolSchema],
+) -> List[ChatCompletionsToolDefinition]:
result: List[ChatCompletionsToolDefinition] = []
for tool in tools:
if isinstance(tool, Tool):
@@ -125,7 +130,13 @@ def _user_message_to_azure(message: UserMessage) -> AzureUserMessage:
elif isinstance(part, Image):
# TODO: support url based images
# TODO: support specifying details
- parts.append(ImageContentItem(image_url=ImageUrl(url=part.data_uri, detail=ImageDetailLevel.AUTO)))
+ parts.append(
+ ImageContentItem(
+ image_url=ImageUrl(
+ url=part.data_uri, detail=ImageDetailLevel.AUTO
+ )
+ )
+ )
else:
raise ValueError(f"Unknown content type: {message.content}")
return AzureUserMessage(content=parts)
@@ -141,8 +152,13 @@ def _assistant_message_to_azure(message: AssistantMessage) -> AzureAssistantMess
return AzureAssistantMessage(content=message.content)
-def _tool_message_to_azure(message: FunctionExecutionResultMessage) -> Sequence[AzureToolMessage]:
- return [AzureToolMessage(content=x.content, tool_call_id=x.call_id) for x in message.content]
+def _tool_message_to_azure(
+ message: FunctionExecutionResultMessage,
+) -> Sequence[AzureToolMessage]:
+ return [
+ AzureToolMessage(content=x.content, tool_call_id=x.call_id)
+ for x in message.content
+ ]
def to_azure_message(message: LLMMessage) -> Sequence[AzureMessage]:
@@ -172,7 +188,9 @@ def assert_valid_name(name: str) -> str:
For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API.
"""
if not re.match(r"^[a-zA-Z0-9_-]+$", name):
- raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
+ raise ValueError(
+ f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed."
+ )
if len(name) > 64:
raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
return name
@@ -306,11 +324,15 @@ def _validate_config(config: Dict[str, Any]) -> AzureAIChatCompletionClientConfi
raise ValueError("model_info is required for AzureAIChatCompletionClient")
validate_model_info(config["model_info"])
if _is_github_model(config["endpoint"]) and "model" not in config:
- raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient")
+ raise ValueError(
+ "model is required for when using a Github model with AzureAIChatCompletionClient"
+ )
return cast(AzureAIChatCompletionClientConfig, config)
@staticmethod
- def _create_client(config: AzureAIChatCompletionClientConfig) -> ChatCompletionsClient:
+ def _create_client(
+ config: AzureAIChatCompletionClientConfig,
+ ) -> ChatCompletionsClient:
# Only pass the parameters that ChatCompletionsClient accepts
# Remove 'model_info' and other client-specific parameters
client_config = {k: v for k, v in config.items() if k not in ("model_info",)}
@@ -337,8 +359,12 @@ def _validate_model_info(
if self.model_info["vision"] is False:
for message in messages:
if isinstance(message, UserMessage):
- if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
- raise ValueError("Model does not support vision and image was provided")
+ if isinstance(message.content, list) and any(
+ isinstance(x, Image) for x in message.content
+ ):
+ raise ValueError(
+ "Model does not support vision and image was provided"
+ )
if json_output is not None:
if self.model_info["json_output"] is False and json_output is True:
@@ -346,7 +372,9 @@ def _validate_model_info(
if isinstance(json_output, type):
# TODO: we should support this in the future.
- raise ValueError("Structured output is not currently supported for AzureAIChatCompletionClient")
+ raise ValueError(
+ "Structured output is not currently supported for AzureAIChatCompletionClient"
+ )
if json_output is True and "response_format" not in create_args:
create_args["response_format"] = "json_object"
@@ -368,7 +396,9 @@ async def create(
) -> CreateResult:
extra_create_args_keys = set(extra_create_args.keys())
if not create_kwargs.issuperset(extra_create_args_keys):
- raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
+ raise ValueError(
+ f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}"
+ )
# Copy the create args and overwrite anything in extra_create_args
create_args = self._create_args.copy()
@@ -380,11 +410,14 @@ async def create(
azure_messages = [item for sublist in azure_messages_nested for item in sublist]
task: Task[ChatCompletions]
+ start_time = time.perf_counter()
if len(tools) > 0:
if isinstance(tool_choice, Tool):
create_args["tool_choice"] = ChatCompletionsNamedToolChoice(
- function=ChatCompletionsNamedToolChoiceFunction(name=tool_choice.name)
+ function=ChatCompletionsNamedToolChoiceFunction(
+ name=tool_choice.name
+ )
)
else:
create_args["tool_choice"] = tool_choice
@@ -404,18 +437,30 @@ async def create(
cancellation_token.link_future(task)
result: ChatCompletions = await task
+ end_time = time.perf_counter()
usage = RequestUsage(
prompt_tokens=result.usage.prompt_tokens if result.usage else 0,
completion_tokens=result.usage.completion_tokens if result.usage else 0,
)
+ # Calculate performance metrics
+ latency_seconds = end_time - start_time
+ latency_ms = latency_seconds * 1000
+ tokens_per_second = (
+ usage.completion_tokens / latency_seconds
+ if latency_seconds > 0 and usage.completion_tokens > 0
+ else None
+ )
+
logger.info(
LLMCallEvent(
messages=[m.as_dict() for m in azure_messages],
response=result.as_dict(),
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
+ latency_ms=latency_ms,
+ tokens_per_second=tokens_per_second,
)
)
@@ -470,7 +515,9 @@ async def create_stream(
) -> AsyncGenerator[Union[str, CreateResult], None]:
extra_create_args_keys = set(extra_create_args.keys())
if not create_kwargs.issuperset(extra_create_args_keys):
- raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
+ raise ValueError(
+ f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}"
+ )
create_args: Dict[str, Any] = self._create_args.copy()
create_args.update(extra_create_args)
@@ -484,16 +531,27 @@ async def create_stream(
if len(tools) > 0:
if isinstance(tool_choice, Tool):
create_args["tool_choice"] = ChatCompletionsNamedToolChoice(
- function=ChatCompletionsNamedToolChoiceFunction(name=tool_choice.name)
+ function=ChatCompletionsNamedToolChoiceFunction(
+ name=tool_choice.name
+ )
)
else:
create_args["tool_choice"] = tool_choice
converted_tools = convert_tools(tools)
task = asyncio.create_task(
- self._client.complete(messages=azure_messages, tools=converted_tools, stream=True, **create_args)
+ self._client.complete(
+ messages=azure_messages,
+ tools=converted_tools,
+ stream=True,
+ **create_args,
+ )
)
else:
- task = asyncio.create_task(self._client.complete(messages=azure_messages, stream=True, **create_args))
+ task = asyncio.create_task(
+ self._client.complete(
+ messages=azure_messages, stream=True, **create_args
+ )
+ )
if cancellation_token is not None:
cancellation_token.link_future(task)
@@ -509,6 +567,10 @@ async def create_stream(
first_chunk = True
thought = None
+ # Performance timing variables
+ start_time = time.perf_counter()
+ first_token_time: Optional[float] = None
+
async for chunk in await task: # type: ignore
if first_chunk:
first_chunk = False
@@ -527,17 +589,29 @@ async def create_stream(
if choice.finish_reason is CompletionsFinishReason.TOOL_CALLS:
finish_reason = "function_calls"
else:
- if choice.finish_reason in ["stop", "length", "function_calls", "content_filter", "unknown"]:
+ if choice.finish_reason in [
+ "stop",
+ "length",
+ "function_calls",
+ "content_filter",
+ "unknown",
+ ]:
finish_reason = choice.finish_reason # type: ignore
else:
- raise ValueError(f"Unexpected finish reason: {choice.finish_reason}")
+ raise ValueError(
+ f"Unexpected finish reason: {choice.finish_reason}"
+ )
# We first try to load the content
if choice and choice.delta.content is not None:
+ if first_token_time is None:
+ first_token_time = time.perf_counter()
content_deltas.append(choice.delta.content)
yield choice.delta.content
# Otherwise, we try to load the tool calls
if choice and choice.delta.tool_calls is not None:
+ if first_token_time is None:
+ first_token_time = time.perf_counter()
for tool_call_chunk in choice.delta.tool_calls:
# print(tool_call_chunk)
if "index" in tool_call_chunk:
@@ -545,7 +619,9 @@ async def create_stream(
else:
idx = tool_call_chunk.id
if idx not in full_tool_calls:
- full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")
+ full_tool_calls[idx] = FunctionCall(
+ id="", arguments="", name=""
+ )
full_tool_calls[idx].id += tool_call_chunk.id
full_tool_calls[idx].name += tool_call_chunk.function.name
@@ -587,12 +663,30 @@ async def create_stream(
thought=thought,
)
+ # Calculate performance metrics
+ end_time = time.perf_counter()
+ latency_seconds = end_time - start_time
+ latency_ms = latency_seconds * 1000
+ ttft_ms = (
+ (first_token_time - start_time) * 1000
+ if first_token_time is not None
+ else None
+ )
+ tokens_per_second = (
+ usage.completion_tokens / latency_seconds
+ if latency_seconds > 0 and usage.completion_tokens > 0
+ else None
+ )
+
# Log the end of the stream.
logger.info(
LLMStreamEndEvent(
response=result.model_dump(),
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
+ latency_ms=latency_ms,
+ tokens_per_second=tokens_per_second,
+ ttft_ms=ttft_ms,
)
)
@@ -609,10 +703,14 @@ def actual_usage(self) -> RequestUsage:
def total_usage(self) -> RequestUsage:
return self._total_usage
- def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
+ def count_tokens(
+ self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []
+ ) -> int:
return 0
- def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
+ def remaining_tokens(
+ self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []
+ ) -> int:
return 0
@property
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/nvidia/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/__init__.py
new file mode 100644
index 000000000000..7b07868e3c73
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""NVIDIA NIM Speculative Reasoning Execution (SRE) Client for AutoGen 0.4.
+
+This package provides a ChatCompletionClient that bridges NVIDIA NIM
+high-performance inference with Microsoft AutoGen orchestration, enabling
+parallel tool execution during LLM reasoning to reduce "Time to Action" latency.
+"""
+
+from ._nvidia_speculative_client import NvidiaSpeculativeClient
+from ._reasoning_sniffer import ReasoningSniffer, ToolIntent
+from ._speculative_cache import SpeculativeCache
+
+__all__ = [
+ "NvidiaSpeculativeClient",
+ "ReasoningSniffer",
+ "ToolIntent",
+ "SpeculativeCache",
+]
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_nvidia_speculative_client.py b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_nvidia_speculative_client.py
new file mode 100644
index 000000000000..a66f331991fd
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_nvidia_speculative_client.py
@@ -0,0 +1,480 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""NVIDIA NIM Speculative Reasoning Execution Client for AutoGen 0.4.
+
+This module provides the main ChatCompletionClient implementation that bridges
+NVIDIA NIM inference with AutoGen orchestration. It enables parallel tool
+execution during LLM reasoning to dramatically reduce "Time to Action" latency.
+
+Key Features:
+- Streams DeepSeek-R1 style <think>...</think> reasoning blocks
+- Uses ReasoningSniffer to detect tool intents during reasoning
+- Fires speculative prewarm tasks via asyncio.create_task (non-blocking)
+- Logs high-precision latency metrics for benchmarking
+
+The "Shark" Architecture:
+ Model thinks → Sniffer detects intent → Prewarm fires (background)
+ → Model keeps thinking
+ → Tool actually called → Cache HIT
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import time
+import warnings
+from dataclasses import dataclass
+from typing import (
+ Any,
+ AsyncGenerator,
+ Awaitable,
+ Callable,
+ Dict,
+ List,
+ Literal,
+ Mapping,
+ Optional,
+ Sequence,
+ Union,
+)
+
+from autogen_core import CancellationToken
+from autogen_core.models import (
+ AssistantMessage,
+ ChatCompletionClient,
+ CreateResult,
+ LLMMessage,
+ ModelInfo,
+ RequestUsage,
+)
+from autogen_core.models._model_client import ModelFamily
+from autogen_core.tools import Tool, ToolSchema
+
+from ._reasoning_sniffer import ReasoningSniffer, ToolIntent
+from ._speculative_cache import SpeculativeCache
+
+# Event logger for LLM events
+EVENT_LOGGER_NAME = "autogen_core.events"
+logger = logging.getLogger(EVENT_LOGGER_NAME)
+trace_logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SpeculativePrewarmEvent:
+ """Event emitted when the ReasoningSniffer detects a tool intent.
+
+ This event is logged to the event bus so that orchestrators can
+ trigger speculative tool execution.
+ """
+
+ tool_type: str
+ """The type/name of the tool detected."""
+
+ query_hint: str
+ """Extracted query or argument hint from the reasoning."""
+
+ confidence: float
+ """Confidence score (0.0 to 1.0)."""
+
+ detected_at_ms: float
+ """Timestamp when the intent was detected."""
+
+ reasoning_context: str
+ """The reasoning text that triggered detection."""
+
+ def __str__(self) -> str:
+ return (
+ f"SpeculativePrewarm: tool={self.tool_type}, "
+ f"hint='{self.query_hint[:50]}...', confidence={self.confidence:.2f}"
+ )
+
+
+@dataclass
+class SpeculativeHitEvent:
+ """Event emitted when a cached speculative result is used."""
+
+ tool_name: str
+ """Name of the tool that had a cache hit."""
+
+ latency_saved_ms: float
+ """Estimated latency saved by using the cached result."""
+
+ def __str__(self) -> str:
+ return (
+ f"SpeculativeHit: tool={self.tool_name}, "
+ f"latency_saved={self.latency_saved_ms:.1f}ms"
+ )
+
+
+@dataclass
+class PerformanceMetrics:
+ """Performance tracking for the speculative execution pipeline."""
+
+ stream_start_time: float = 0.0
+ first_token_time: Optional[float] = None
+ reasoning_start_time: Optional[float] = None
+ reasoning_end_time: Optional[float] = None
+ first_intent_detected_time: Optional[float] = None
+ tool_call_time: Optional[float] = None
+ stream_end_time: float = 0.0
+
+ intents_detected: int = 0
+ prewarms_triggered: int = 0
+
+ @property
+ def ttft_ms(self) -> Optional[float]:
+ """Time to first token in milliseconds."""
+ if self.first_token_time is None:
+ return None
+ return (self.first_token_time - self.stream_start_time) * 1000
+
+ @property
+ def reasoning_duration_ms(self) -> Optional[float]:
+ """Total reasoning duration in milliseconds."""
+ if self.reasoning_start_time is None or self.reasoning_end_time is None:
+ return None
+ return (self.reasoning_end_time - self.reasoning_start_time) * 1000
+
+ @property
+ def speculative_delta_ms(self) -> Optional[float]:
+ """The 'speculative delta' - how early we detected intent vs tool call.
+
+ Negative = we detected before the model called (good!)
+ Positive = we missed the window (speculation opportunity lost)
+ """
+ if self.first_intent_detected_time is None or self.tool_call_time is None:
+ return None
+ return (self.tool_call_time - self.first_intent_detected_time) * 1000
+
+ def summary(self) -> Dict[str, Any]:
+ """Get a summary of all performance metrics."""
+ return {
+ "ttft_ms": self.ttft_ms,
+ "reasoning_duration_ms": self.reasoning_duration_ms,
+ "speculative_delta_ms": self.speculative_delta_ms,
+ "intents_detected": self.intents_detected,
+ "prewarms_triggered": self.prewarms_triggered,
+ "total_duration_ms": (self.stream_end_time - self.stream_start_time) * 1000,
+ }
+
+
+# Type alias for the prewarm callback
+PrewarmCallback = Callable[[str, str, Dict[str, Any]], Awaitable[Any]]
+
+
+class NvidiaSpeculativeClient(ChatCompletionClient):
+ """NVIDIA NIM-compatible ChatCompletionClient with Speculative Reasoning Execution.
+
+ This client wraps an existing OpenAI-compatible client (which connects to
+ NVIDIA NIM, vLLM, or any OpenAI-compatible endpoint) and adds speculative
+ execution capabilities.
+
+ The key innovation is that during the model's reasoning phase (the <think> block),
+ the ReasoningSniffer monitors for tool-call intents. When detected, it triggers
+ a prewarm callback to speculatively execute the tool in the background.
+
+ Args:
+ inner_client: An existing ChatCompletionClient (e.g., OpenAIChatCompletionClient
+ configured for NIM endpoint).
+ sniffer: Optional custom ReasoningSniffer. Uses default patterns if not provided.
+ prewarm_callback: Async function called when tool intent is detected.
+ Signature: async def callback(tool_type: str, query_hint: str, context: dict) -> Any
+ enable_speculation: Whether to enable speculative execution (default: True).
+ min_confidence: Minimum confidence threshold to trigger prewarm (default: 0.7).
+
+ Example:
+ >>> from autogen_ext.models.openai import OpenAIChatCompletionClient
+ >>> from autogen_ext.models.nvidia import NvidiaSpeculativeClient
+ >>>
+ >>> # Create inner client pointing to NIM
+ >>> inner = OpenAIChatCompletionClient(
+ ... model="deepseek-r1",
+ ... base_url="http://wulver:8000/v1",
+ ... api_key="token"
+ ... )
+ >>>
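+ >>> # Hypothetical prewarm callback (illustrative sketch matching the documented
+ >>> # signature, not a library helper): run the tool speculatively and return a
+ >>> # result so it can be cached.
+ >>> async def my_prewarm_function(tool_type: str, query_hint: str, context: dict) -> str:
+ ...     return f"prewarmed {tool_type} result for {query_hint!r}"
+ >>>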
+ >>> # Wrap with speculative execution
+ >>> client = NvidiaSpeculativeClient(
+ ... inner_client=inner,
+ ... prewarm_callback=my_prewarm_function
+ ... )
+ """
+
+ component_type = "model"
+ component_config_schema = None # TODO: Add proper config schema
+
+ def __init__(
+ self,
+ inner_client: ChatCompletionClient,
+ *,
+ sniffer: Optional[ReasoningSniffer] = None,
+ prewarm_callback: Optional[PrewarmCallback] = None,
+ enable_speculation: bool = True,
+ min_confidence: float = 0.7,
+ sniff_all_content: bool = False,
+ ) -> None:
+ self._inner_client = inner_client
+ self._sniffer = sniffer or ReasoningSniffer()
+ self._prewarm_callback = prewarm_callback
+ self._enable_speculation = enable_speculation
+ self._min_confidence = min_confidence
+ self._sniff_all_content = sniff_all_content # For distilled models without <think> tags
+ self._cache = SpeculativeCache.get_instance()
+
+ # Track running prewarm tasks
+ self._prewarm_tasks: List[asyncio.Task[Any]] = []
+
+ # DEDUPLICATION: Track triggered intents to avoid firing repeatedly for the same tool
+ self._triggered_intents: set[str] = set()
+
+ # Performance tracking
+ self._last_metrics: Optional[PerformanceMetrics] = None
+
+ @property
+ def model_info(self) -> ModelInfo:
+ """Get model info from the inner client."""
+ return self._inner_client.model_info
+
+ @property
+ def capabilities(self) -> Any:
+ """Deprecated. Use model_info instead."""
+ warnings.warn(
+ "capabilities is deprecated, use model_info instead",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self._inner_client.capabilities
+
+ def actual_usage(self) -> RequestUsage:
+ """Get actual token usage."""
+ return self._inner_client.actual_usage()
+
+ def total_usage(self) -> RequestUsage:
+ """Get total token usage."""
+ return self._inner_client.total_usage()
+
+ def count_tokens(
+ self,
+ messages: Sequence[LLMMessage],
+ *,
+ tools: Sequence[Tool | ToolSchema] = [],
+ ) -> int:
+ """Count tokens for the given messages."""
+ return self._inner_client.count_tokens(messages, tools=tools)
+
+ def remaining_tokens(
+ self,
+ messages: Sequence[LLMMessage],
+ *,
+ tools: Sequence[Tool | ToolSchema] = [],
+ ) -> int:
+ """Get remaining tokens for the model's context."""
+ return self._inner_client.remaining_tokens(messages, tools=tools)
+
+ async def close(self) -> None:
+ """Close the client and cancel any pending prewarm tasks."""
+ # Cancel any running prewarm tasks
+ for task in self._prewarm_tasks:
+ if not task.done():
+ task.cancel()
+ self._prewarm_tasks.clear()
+
+ await self._inner_client.close()
+
+ async def create(
+ self,
+ messages: Sequence[LLMMessage],
+ *,
+ tools: Sequence[Tool | ToolSchema] = [],
+ tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
+ json_output: Optional[bool | type] = None,
+ extra_create_args: Mapping[str, Any] = {},
+ cancellation_token: Optional[CancellationToken] = None,
+ ) -> CreateResult:
+ """Create a non-streaming response.
+
+ Note: Speculative execution is most effective with streaming.
+ This method delegates directly to the inner client.
+ """
+ return await self._inner_client.create(
+ messages,
+ tools=tools,
+ tool_choice=tool_choice,
+ json_output=json_output,
+ extra_create_args=extra_create_args,
+ cancellation_token=cancellation_token,
+ )
+
+ async def create_stream(
+ self,
+ messages: Sequence[LLMMessage],
+ *,
+ tools: Sequence[Tool | ToolSchema] = [],
+ tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
+ json_output: Optional[bool | type] = None,
+ extra_create_args: Mapping[str, Any] = {},
+ cancellation_token: Optional[CancellationToken] = None,
+ ) -> AsyncGenerator[Union[str, CreateResult], None]:
+ """Create a streaming response with speculative reasoning execution.
+
+ This is the core method that implements the SRE pattern:
+ 1. Stream tokens from the inner client
+ 2. Monitor reasoning content with the sniffer
+ 3. Fire prewarm tasks when intents are detected (non-blocking)
+ 4. Continue streaming without interruption
+
+ Yields:
+ String chunks during streaming, ending with a CreateResult.
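+
+ Example (illustrative; `client` is the wrapped client from the class docstring
+ and `user_message` is an assumed LLMMessage):
+
+ >>> async for item in client.create_stream([user_message]):
+ ...     if isinstance(item, CreateResult):
+ ...         result = item
+ ...     else:
+ ...         print(item, end="")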
+ """
+ # Initialize performance metrics
+ metrics = PerformanceMetrics()
+ metrics.stream_start_time = time.perf_counter()
+
+ # Reset sniffer and triggered intents for fresh session
+ self._sniffer.reset()
+ self._triggered_intents.clear()
+
+ # Track reasoning state
+ is_in_reasoning = False
+ reasoning_buffer = ""
+
+ # Get the inner stream
+ inner_stream = self._inner_client.create_stream(
+ messages,
+ tools=tools,
+ tool_choice=tool_choice,
+ json_output=json_output,
+ extra_create_args=extra_create_args,
+ cancellation_token=cancellation_token,
+ )
+
+ async for chunk in inner_stream:
+ # Track first token time
+ if metrics.first_token_time is None:
+ metrics.first_token_time = time.perf_counter()
+
+ # If this is the final CreateResult, process it
+ if isinstance(chunk, CreateResult):
+ metrics.stream_end_time = time.perf_counter()
+
+ # Check for tool calls and record timing before summarizing,
+ # so speculative_delta_ms is reflected in the logged summary.
+ if chunk.content and isinstance(chunk.content, list):
+ metrics.tool_call_time = time.perf_counter()
+
+ # Log performance summary
+ self._last_metrics = metrics
+ trace_logger.info(f"SpeculativeClient metrics: {metrics.summary()}")
+
+ yield chunk
+ return
+
+ # Process string chunks
+ chunk_str = str(chunk)
+
+ # Detect reasoning block boundaries
+ if "" in chunk_str:
+ is_in_reasoning = True
+ metrics.reasoning_start_time = time.perf_counter()
+
+ if "" in chunk_str:
+ is_in_reasoning = False
+ metrics.reasoning_end_time = time.perf_counter()
+ reasoning_buffer = ""
+
+ # Run sniffer on content
+ # For distilled models without <think> tags, sniff all content
+ # For full R1 models, only sniff inside <think> blocks
+ should_sniff = self._enable_speculation and (
+ self._sniff_all_content or is_in_reasoning
+ )
+
+ if should_sniff:
+ reasoning_buffer += chunk_str
+
+ # Sniff for tool intents
+ intent = self._sniffer.sniff(chunk_str)
+
+ if intent and intent.confidence >= self._min_confidence:
+ metrics.intents_detected += 1
+
+ if metrics.first_intent_detected_time is None:
+ metrics.first_intent_detected_time = time.perf_counter()
+
+ # DEDUPLICATION: Only fire once per tool type per session
+ # Use tool_type as the dedup key (could also include query_hint for finer granularity)
+ dedup_key = intent.tool_type
+
+ if dedup_key not in self._triggered_intents:
+ self._triggered_intents.add(dedup_key)
+
+ # Log the prewarm event
+ event = SpeculativePrewarmEvent(
+ tool_type=intent.tool_type,
+ query_hint=intent.query_hint,
+ confidence=intent.confidence,
+ detected_at_ms=intent.detected_at_ms,
+ reasoning_context=intent.reasoning_context,
+ )
+ logger.info(event)
+
+ # SHARK MOVE: Fire and forget the prewarm task
+ if self._prewarm_callback is not None:
+ task = asyncio.create_task(self._trigger_prewarm(intent))
+ self._prewarm_tasks.append(task)
+ metrics.prewarms_triggered += 1
+
+ yield chunk
+
+ async def _trigger_prewarm(self, intent: ToolIntent) -> None:
+ """Trigger speculative tool prewarm in the background.
+
+ This runs as a fire-and-forget task. If the speculation is wrong,
+ the result is simply not used. If it's right, the cache will have
+ the result ready.
+ """
+ if self._prewarm_callback is None:
+ return
+
+ try:
+ trace_logger.debug(
+ f"Triggering prewarm for {intent.tool_type}: {intent.query_hint}"
+ )
+
+ context = {
+ "confidence": intent.confidence,
+ "reasoning_context": intent.reasoning_context,
+ }
+
+ result = await self._prewarm_callback(
+ intent.tool_type,
+ intent.query_hint,
+ context,
+ )
+
+ # Store result in cache
+ if result is not None:
+ self._cache.store(
+ tool_name=intent.tool_type,
+ args={"query_hint": intent.query_hint},
+ result=result,
+ ttl=30.0, # 30 second TTL
+ )
+ trace_logger.info(
+ f"Prewarm complete for {intent.tool_type}, result cached"
+ )
+
+ except Exception as e:
+ # Speculation failure is silent - don't crash the main stream
+ trace_logger.warning(f"Prewarm failed for {intent.tool_type}: {e}")
+
+ @property
+ def last_metrics(self) -> Optional[PerformanceMetrics]:
+ """Get the performance metrics from the last create_stream call."""
+ return self._last_metrics
+
+ @property
+ def cache(self) -> SpeculativeCache:
+ """Access the speculative cache for inspection or manual operations."""
+ return self._cache
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_reasoning_sniffer.py b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_reasoning_sniffer.py
new file mode 100644
index 000000000000..67768f1d1ad5
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_reasoning_sniffer.py
@@ -0,0 +1,172 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""High-speed regex-based Intent Sniffer for Speculative Reasoning Execution.
+
+This module provides zero-latency heuristic detection of tool-call intents
+within streaming reasoning content from DeepSeek-R1 and similar models.
+The sniffer runs on every chunk without blocking the stream.
+"""
+
+import re
+import time
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Pattern
+
+
+@dataclass
+class ToolIntent:
+ """Represents a detected tool-call intent from reasoning content."""
+
+ tool_type: str
+ """The type/name of the tool detected (e.g., 'web_search', 'database_query')."""
+
+ query_hint: str
+ """Extracted query or argument hint from the reasoning."""
+
+ confidence: float
+ """Confidence score (0.0 to 1.0) based on pattern match strength."""
+
+ detected_at_ms: float
+ """Timestamp when the intent was detected (perf_counter * 1000)."""
+
+ reasoning_context: str
+ """The chunk of reasoning text that triggered the detection."""
+
+
+@dataclass
+class ReasoningSniffer:
+ """Detects tool-call intents in streaming reasoning content.
+
+ Uses high-speed regex patterns to identify when the model is planning
+ to use a specific tool. Designed to run on every streaming chunk
+ without introducing latency.
+
+ Example:
+ >>> sniffer = ReasoningSniffer()
+ >>> intent = sniffer.sniff("I will search for Python documentation")
+ >>> if intent:
+ ... print(f"Detected {intent.tool_type}: {intent.query_hint}")
+ Detected web_search: Python documentation
+ """
+
+ # High-signal patterns for DeepSeek-R1's "Action Intent"
+ # These are tuned for common reasoning patterns in R1-style models
+ # Enhanced to catch contractions (I'll, I'd) and more natural phrasings
+ PATTERNS: Dict[str, str] = field(
+ default_factory=lambda: {
+ # Web search intents - extended with contractions and more verbs
+ "web_search": r"(?:I(?:'ll| will| need to| should| must| have to)|Let me|I'd like to|need to)?\s*(?:search|look up|find|query|google|browse|look for|check online|research|find out|look into)\s+(?:for\s+|about\s+|the\s+|on\s+)?(?:information\s+(?:about|on)\s+)?['\"]?(.+?)(?:['\"]|\\.|,|;|$)",
+ # Database query intents
+ "database_query": r"(?:I(?:'ll| will| need to| should)|Let me)\s*(?:query|check|access|look up in|fetch from|retrieve from)\s+(?:the\s+)?(?:database|db|table|records?)\s+(?:for\s+)?['\"]?(.+?)(?:['\"]|\\.|,|$)",
+ # Calculation intents
+ "calculate": r"(?:I(?:'ll| will| need to| should)|Let me)\s*(?:calculate|compute|evaluate|work out|figure out|determine)\s+(.+?)(?:\\.|,|$)",
+ # API call intents
+ "api_call": r"(?:I(?:'ll| will| need to| should)|Let me)\s*(?:call|invoke|use|hit|query|fetch from)\s+(?:the\s+)?(?:API|endpoint|service|REST)\s+(?:for\s+)?['\"]?(.+?)(?:['\"]|\\.|,|$)",
+ # File/document lookup intents
+ "file_lookup": r"(?:I(?:'ll| will| need to| should)|Let me)\s*(?:read|open|check|look at|load|parse)\s+(?:the\s+)?(?:file|document|config|data)\s+['\"]?(.+?)(?:['\"]|\\.|,|$)",
+ # Generic tool usage
+ "tool_use": r"(?:I(?:'ll| will| need to| should)|Let me)\s*(?:use|invoke|call|execute)\s+(?:the\s+)?[`'\"]?(\w+)[`'\"]?\s+(?:tool|function)",
+ }
+ )
+
+ # Compiled patterns for performance
+ _compiled_patterns: Dict[str, Pattern[str]] = field(default_factory=dict)
+
+ # Buffer for accumulating reasoning context
+ _context_buffer: List[str] = field(default_factory=list)
+ _max_buffer_size: int = 500 # Characters to keep for context
+
+ def __post_init__(self) -> None:
+ """Compile regex patterns for maximum performance."""
+ self._compiled_patterns = {
+ tool_type: re.compile(pattern, re.IGNORECASE | re.DOTALL)
+ for tool_type, pattern in self.PATTERNS.items()
+ }
+
+ def add_pattern(self, tool_type: str, pattern: str) -> None:
+ """Add a custom pattern for detecting a specific tool type.
+
+ Args:
+ tool_type: The name/type of the tool to detect.
+ pattern: Regex pattern with a capture group for the query hint.
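+
+ Example (illustrative; the tool name and pattern below are assumptions,
+ not built-in defaults):
+
+ >>> sniffer = ReasoningSniffer()
+ >>> sniffer.add_pattern(
+ ...     "weather_lookup",
+ ...     r"(?:I(?:'ll| will| need to)|Let me)\s*check the weather\s+(?:in|for)\s+(.+?)(?:\.|,|$)",
+ ... )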
+ """
+ self.PATTERNS[tool_type] = pattern
+ self._compiled_patterns[tool_type] = re.compile(
+ pattern, re.IGNORECASE | re.DOTALL
+ )
+
+ def sniff(self, text: str) -> Optional[ToolIntent]:
+ """Analyze a chunk of reasoning text for tool-call intents.
+
+ This method is designed to be called on every streaming chunk.
+ It maintains an internal buffer to provide context for pattern matching.
+
+ Args:
+ text: A chunk of reasoning text from the model's thought stream.
+
+ Returns:
+ ToolIntent if an intent was detected, None otherwise.
+ """
+ if not text or not text.strip():
+ return None
+
+ # Add to context buffer
+ self._context_buffer.append(text)
+
+ # Keep buffer size manageable
+ full_context = "".join(self._context_buffer)
+ if len(full_context) > self._max_buffer_size:
+ # Trim from the beginning
+ excess = len(full_context) - self._max_buffer_size
+ self._context_buffer = [full_context[excess:]]
+
+ # Scan the current chunk and recent context
+ search_text = full_context[-self._max_buffer_size :]
+
+ for tool_type, compiled_pattern in self._compiled_patterns.items():
+ match = compiled_pattern.search(search_text)
+ if match:
+ query_hint = match.group(1).strip() if match.lastindex else ""
+ # Calculate confidence based on match quality
+ confidence = self._calculate_confidence(match, search_text)
+
+ return ToolIntent(
+ tool_type=tool_type,
+ query_hint=query_hint,
+ confidence=confidence,
+ detected_at_ms=time.perf_counter() * 1000,
+ reasoning_context=search_text[-200:], # Last 200 chars for context
+ )
+
+ return None
+
+ def _calculate_confidence(self, match: re.Match[str], text: str) -> float:
+ """Calculate confidence score based on match characteristics.
+
+ Args:
+ match: The regex match object.
+ text: The full text being searched.
+
+ Returns:
+ Confidence score between 0.0 and 1.0.
+ """
+ confidence = 0.7 # Base confidence for any match
+
+ # Boost if match is recent (in the last 100 chars)
+ if match.end() > len(text) - 100:
+ confidence += 0.15
+
+ # Boost if query hint is substantial
+ if match.lastindex and len(match.group(1).strip()) > 5:
+ confidence += 0.1
+
+ # Boost if explicit tool mention
+ if "tool" in text.lower() or "function" in text.lower():
+ confidence += 0.05
+
+ return min(confidence, 1.0)
+
+ def reset(self) -> None:
+ """Reset the context buffer. Call this between separate reasoning sessions."""
+ self._context_buffer.clear()
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_speculative_cache.py b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_speculative_cache.py
new file mode 100644
index 000000000000..0694f5c33c5b
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_speculative_cache.py
@@ -0,0 +1,263 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Thread-safe Speculative Cache for pre-executed tool results.
+
+This module provides a singleton cache that stores results from speculatively
+executed tools. When the model formally requests a tool, the cache is checked
+first to provide near-instantaneous results.
+"""
+
+import asyncio
+import hashlib
+import json
+import logging
+import time
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class CacheEntry:
+ """A single cached result with metadata."""
+
+ tool_name: str
+ args_hash: str
+ result: Any
+ created_at: float
+ ttl: float
+ hit_count: int = 0
+
+ def is_expired(self) -> bool:
+ """Check if this entry has expired."""
+ return time.time() - self.created_at > self.ttl
+
+
+class SpeculativeCache:
+ """Thread-safe singleton cache for speculative tool execution results.
+
+ This cache stores pre-executed tool results that were triggered by the
+ ReasoningSniffer during the model's reasoning phase. When a tool is
+ formally called, the cache is checked first.
+
+ The cache uses argument hashing to match pre-warmed results with actual
+ tool calls, allowing for partial matches when full arguments aren't known
+ during speculation.
+
+ Example:
+ >>> cache = SpeculativeCache.get_instance()
+ >>> cache.store("web_search", {"query": "python docs"}, "search results")
+ >>> result = cache.get("web_search", {"query": "python docs"})
+ >>> print(result)
+ search results
+ """
+
+ _instance: Optional["SpeculativeCache"] = None
+ _lock: asyncio.Lock = asyncio.Lock()
+
+ def __new__(cls) -> "SpeculativeCache":
+ """Ensure singleton pattern."""
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ cls._instance._cache: Dict[str, CacheEntry] = {}
+ cls._instance._stats = {
+ "hits": 0,
+ "misses": 0,
+ "stores": 0,
+ "evictions": 0,
+ }
+ return cls._instance
+
+ @classmethod
+ def get_instance(cls) -> "SpeculativeCache":
+ """Get the singleton cache instance."""
+ return cls()
+
+ @classmethod
+ def reset_instance(cls) -> None:
+ """Reset the singleton instance. Mainly for testing."""
+ cls._instance = None
+
+ @staticmethod
+ def _hash_args(args: Dict[str, Any]) -> str:
+ """Create a consistent hash for tool arguments.
+
+ Args:
+ args: Dictionary of tool arguments.
+
+ Returns:
+ A hex digest hash of the arguments.
+ """
+ # Sort keys for consistent ordering
+ serialized = json.dumps(args, sort_keys=True, default=str)
+ return hashlib.md5(serialized.encode()).hexdigest()
+
+ def store(
+ self,
+ tool_name: str,
+ args: Dict[str, Any],
+ result: Any,
+ ttl: float = 30.0,
+ ) -> None:
+ """Store a speculatively executed tool result.
+
+ Args:
+ tool_name: Name of the tool that was executed.
+ args: Arguments used for execution.
+ result: The result of the tool execution.
+ ttl: Time-to-live in seconds (default: 30s).
+ """
+ args_hash = self._hash_args(args)
+ cache_key = f"{tool_name}:{args_hash}"
+
+ entry = CacheEntry(
+ tool_name=tool_name,
+ args_hash=args_hash,
+ result=result,
+ created_at=time.time(),
+ ttl=ttl,
+ )
+
+ self._cache[cache_key] = entry
+ self._stats["stores"] += 1
+
+ logger.debug(
+ f"SpeculativeCache: Stored {tool_name} (hash={args_hash[:8]}..., ttl={ttl}s)"
+ )
+
+ def get(self, tool_name: str, args: Dict[str, Any]) -> Optional[Any]:
+ """Retrieve a cached result if available and not expired.
+
+ Args:
+ tool_name: Name of the tool being called.
+ args: Arguments for the tool call.
+
+ Returns:
+ The cached result if found and valid, None otherwise.
+ """
+ args_hash = self._hash_args(args)
+ cache_key = f"{tool_name}:{args_hash}"
+
+ entry = self._cache.get(cache_key)
+
+ if entry is None:
+ self._stats["misses"] += 1
+ return None
+
+ if entry.is_expired():
+ # Clean up expired entry
+ del self._cache[cache_key]
+ self._stats["evictions"] += 1
+ self._stats["misses"] += 1
+ logger.debug(f"SpeculativeCache: Expired entry for {tool_name}")
+ return None
+
+ # Cache hit!
+ entry.hit_count += 1
+ self._stats["hits"] += 1
+ logger.info(
+ f"SpeculativeCache: HIT for {tool_name} (latency saved, hit #{entry.hit_count})"
+ )
+ return entry.result
+
+ def get_fuzzy(self, tool_name: str, query_hint: str) -> Optional[Any]:
+ """Attempt fuzzy matching for speculative results.
+
+ This is used when the exact arguments aren't known but we have a
+ query hint from the ReasoningSniffer.
+
+ Args:
+ tool_name: Name of the tool.
+ query_hint: Partial query string detected during reasoning.
+
+ Returns:
+ The cached result if a fuzzy match is found, None otherwise.
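+
+ Example (illustrative values; assumes the cached result text contains the hint):
+
+ >>> cache = SpeculativeCache.get_instance()
+ >>> cache.store("web_search", {"q": "python docs"}, "python docs: https://docs.python.org")
+ >>> cache.get_fuzzy("web_search", "python docs")
+ 'python docs: https://docs.python.org'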
+ """
+ query_hint_lower = query_hint.lower()
+
+ for cache_key, entry in self._cache.items():
+ if not cache_key.startswith(f"{tool_name}:"):
+ continue
+
+ if entry.is_expired():
+ continue
+
+ # Check if the query hint might match this entry
+ # This is a heuristic - the cached result text might contain the hint
+ if hasattr(entry, "result") and isinstance(entry.result, str):
+ if query_hint_lower in entry.result.lower():
+ entry.hit_count += 1
+ self._stats["hits"] += 1
+ logger.info(
+ f"SpeculativeCache: FUZZY HIT for {tool_name} "
+ f"(hint='{query_hint[:30]}...')"
+ )
+ return entry.result
+
+ self._stats["misses"] += 1
+ return None
+
+ def invalidate(self, tool_name: str) -> int:
+ """Invalidate all cached entries for a specific tool.
+
+ Args:
+ tool_name: Name of the tool to invalidate.
+
+ Returns:
+ Number of entries invalidated.
+ """
+ keys_to_remove = [key for key in self._cache if key.startswith(f"{tool_name}:")]
+
+ for key in keys_to_remove:
+ del self._cache[key]
+
+ self._stats["evictions"] += len(keys_to_remove)
+ return len(keys_to_remove)
+
+ def clear(self) -> None:
+ """Clear all cached entries."""
+ count = len(self._cache)
+ self._cache.clear()
+ self._stats["evictions"] += count
+ logger.info(f"SpeculativeCache: Cleared {count} entries")
+
+ def cleanup_expired(self) -> int:
+ """Remove all expired entries from the cache.
+
+ Returns:
+ Number of entries removed.
+ """
+ expired_keys = [key for key, entry in self._cache.items() if entry.is_expired()]
+
+ for key in expired_keys:
+ del self._cache[key]
+
+ self._stats["evictions"] += len(expired_keys)
+ return len(expired_keys)
+
+ @property
+ def stats(self) -> Dict[str, int]:
+ """Get cache statistics.
+
+ Returns:
+ Dictionary with hits, misses, stores, and evictions counts.
+ """
+ return dict(self._stats)
+
+ @property
+ def hit_rate(self) -> float:
+ """Calculate the cache hit rate.
+
+ Returns:
+ Hit rate as a float between 0.0 and 1.0.
+ """
+ total = self._stats["hits"] + self._stats["misses"]
+ if total == 0:
+ return 0.0
+ return self._stats["hits"] / total
+
+ def __len__(self) -> int:
+ """Return the number of cached entries."""
+ return len(self._cache)
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
index a80e912534ab..e70e2c432425 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
@@ -5,6 +5,7 @@
import math
import os
import re
+import time
import warnings
from asyncio import Task
from dataclasses import dataclass
@@ -93,9 +94,9 @@
openai_init_kwargs = set(inspect.getfullargspec(AsyncOpenAI.__init__).kwonlyargs)
aopenai_init_kwargs = set(inspect.getfullargspec(AsyncAzureOpenAI.__init__).kwonlyargs)
-create_kwargs = set(completion_create_params.CompletionCreateParamsBase.__annotations__.keys()) | set(
- ("timeout", "stream", "extra_body")
-)
+create_kwargs = set(
+ completion_create_params.CompletionCreateParamsBase.__annotations__.keys()
+) | set(("timeout", "stream", "extra_body"))
# Only single choice allowed
disallowed_create_args = set(["stream", "messages", "function_call", "functions", "n"])
required_create_args: Set[str] = set(["model"])
@@ -138,9 +139,13 @@ def _create_args_from_config(config: Mapping[str, Any]) -> Dict[str, Any]:
create_args = {k: v for k, v in config.items() if k in create_kwargs}
create_args_keys = set(create_args.keys())
if not required_create_args.issubset(create_args_keys):
- raise ValueError(f"Required create args are missing: {required_create_args - create_args_keys}")
+ raise ValueError(
+ f"Required create args are missing: {required_create_args - create_args_keys}"
+ )
if disallowed_create_args.intersection(create_args_keys):
- raise ValueError(f"Disallowed create args are present: {disallowed_create_args.intersection(create_args_keys)}")
+ raise ValueError(
+ f"Disallowed create args are present: {disallowed_create_args.intersection(create_args_keys)}"
+ )
return create_args
@@ -175,12 +180,14 @@ def to_oai_type(
}
transformers = get_transformer("openai", model, model_family)
- def raise_value_error(message: LLMMessage, context: Dict[str, Any]) -> Sequence[ChatCompletionMessageParam]:
+ def raise_value_error(
+ message: LLMMessage, context: Dict[str, Any]
+ ) -> Sequence[ChatCompletionMessageParam]:
raise ValueError(f"Unknown message type: {type(message)}")
- transformer: Callable[[LLMMessage, Dict[str, Any]], Sequence[ChatCompletionMessageParam]] = transformers.get(
- type(message), raise_value_error
- )
+ transformer: Callable[
+ [LLMMessage, Dict[str, Any]], Sequence[ChatCompletionMessageParam]
+ ] = transformers.get(type(message), raise_value_error)
result = transformer(message, context)
return result
@@ -257,11 +264,19 @@ def convert_tools(
type="function",
function=FunctionDefinition(
name=tool_schema["name"],
- description=(tool_schema["description"] if "description" in tool_schema else ""),
+ description=(
+ tool_schema["description"]
+ if "description" in tool_schema
+ else ""
+ ),
parameters=(
- cast(FunctionParameters, tool_schema["parameters"]) if "parameters" in tool_schema else {}
+ cast(FunctionParameters, tool_schema["parameters"])
+ if "parameters" in tool_schema
+ else {}
+ ),
+ strict=(
+ tool_schema["strict"] if "strict" in tool_schema else False
),
- strict=(tool_schema["strict"] if "strict" in tool_schema else False),
),
)
)
@@ -293,7 +308,9 @@ def convert_tool_choice(tool_choice: Tool | Literal["auto", "required", "none"])
if isinstance(tool_choice, Tool):
return {"type": "function", "function": {"name": tool_choice.schema["name"]}}
else:
- raise ValueError(f"tool_choice must be a Tool object, 'auto', 'required', or 'none', got {type(tool_choice)}")
+ raise ValueError(
+ f"tool_choice must be a Tool object, 'auto', 'required', or 'none', got {type(tool_choice)}"
+ )
def normalize_name(name: str) -> str:
@@ -339,14 +356,18 @@ def count_tokens_openai(
continue
if isinstance(message, UserMessage) and isinstance(value, list):
- typed_message_value = cast(List[ChatCompletionContentPartParam], value)
+ typed_message_value = cast(
+ List[ChatCompletionContentPartParam], value
+ )
assert len(typed_message_value) == len(
message.content
), "Mismatch in message content and typed message value"
# We need image properties that are only in the original message
- for part, content_part in zip(typed_message_value, message.content, strict=False):
+ for part, content_part in zip(
+ typed_message_value, message.content, strict=False
+ ):
if isinstance(content_part, Image):
# TODO: add detail parameter
num_tokens += calculate_vision_tokens(content_part)
@@ -357,13 +378,17 @@ def count_tokens_openai(
serialized_part = json.dumps(part)
num_tokens += len(encoding.encode(serialized_part))
except TypeError:
- trace_logger.warning(f"Could not convert {part} to string, skipping.")
+ trace_logger.warning(
+ f"Could not convert {part} to string, skipping."
+ )
else:
if not isinstance(value, str):
try:
value = json.dumps(value)
except TypeError:
- trace_logger.warning(f"Could not convert {value} to string, skipping.")
+ trace_logger.warning(
+ f"Could not convert {value} to string, skipping."
+ )
continue
num_tokens += len(encoding.encode(value))
if key == "name":
@@ -389,26 +414,38 @@ def count_tokens_openai(
for field in v: # pyright: ignore
if field == "type":
tool_tokens += 2
- tool_tokens += len(encoding.encode(v["type"])) # pyright: ignore
+ tool_tokens += len(
+ encoding.encode(v["type"])
+ ) # pyright: ignore
elif field == "description":
tool_tokens += 2
- tool_tokens += len(encoding.encode(v["description"])) # pyright: ignore
+ tool_tokens += len(
+ encoding.encode(v["description"])
+ ) # pyright: ignore
elif field == "anyOf":
tool_tokens -= 3
for o in v["anyOf"]: # type: ignore
tool_tokens += 3
- tool_tokens += len(encoding.encode(str(o["type"]))) # pyright: ignore
+ tool_tokens += len(
+ encoding.encode(str(o["type"]))
+ ) # pyright: ignore
elif field == "default":
tool_tokens += 2
- tool_tokens += len(encoding.encode(json.dumps(v["default"])))
+ tool_tokens += len(
+ encoding.encode(json.dumps(v["default"]))
+ )
elif field == "title":
tool_tokens += 2
- tool_tokens += len(encoding.encode(str(v["title"]))) # pyright: ignore
+ tool_tokens += len(
+ encoding.encode(str(v["title"]))
+ ) # pyright: ignore
elif field == "enum":
tool_tokens -= 3
for o in v["enum"]: # pyright: ignore
tool_tokens += 3
- tool_tokens += len(encoding.encode(o)) # pyright: ignore
+ tool_tokens += len(
+ encoding.encode(o)
+ ) # pyright: ignore
else:
trace_logger.warning(f"Not supported field {field}")
tool_tokens += 11
@@ -447,7 +484,9 @@ def __init__(
try:
self._model_info = _model_info.get_info(create_args["model"])
except KeyError as err:
- raise ValueError("model_info is required when model name is not a valid OpenAI model") from err
+ raise ValueError(
+ "model_info is required when model name is not a valid OpenAI model"
+ ) from err
elif model_capabilities is not None and model_info is not None:
raise ValueError("model_capabilities and model_info are mutually exclusive")
elif model_capabilities is not None and model_info is None:
@@ -487,7 +526,9 @@ def __init__(
def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient:
return OpenAIChatCompletionClient(**config)
- def _rstrip_last_assistant_message(self, messages: Sequence[LLMMessage]) -> Sequence[LLMMessage]:
+ def _rstrip_last_assistant_message(
+ self, messages: Sequence[LLMMessage]
+ ) -> Sequence[LLMMessage]:
"""
Remove the last assistant message if it is empty.
"""
@@ -509,7 +550,9 @@ def _process_create_args(
# Make sure all extra_create_args are valid
extra_create_args_keys = set(extra_create_args.keys())
if not create_kwargs.issuperset(extra_create_args_keys):
- raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
+ raise ValueError(
+ f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}"
+ )
# Copy the create args and overwrite anything in extra_create_args
create_args = self._create_args.copy()
@@ -541,7 +584,9 @@ def _process_create_args(
raise ValueError("Model does not support JSON output.")
if json_output is True:
# JSON mode.
- create_args["response_format"] = ResponseFormatJSONObject(type="json_object")
+ create_args["response_format"] = ResponseFormatJSONObject(
+ type="json_object"
+ )
elif json_output is False:
# Text mode.
create_args["response_format"] = ResponseFormatText(type="text")
@@ -555,7 +600,9 @@ def _process_create_args(
# Beta client mode with Pydantic model class.
response_format_value = json_output
else:
- raise ValueError(f"json_output must be a boolean or a Pydantic model class, got {type(json_output)}")
+ raise ValueError(
+ f"json_output must be a boolean or a Pydantic model class, got {type(json_output)}"
+ )
if response_format_value is not None and "response_format" in create_args:
warnings.warn(
@@ -573,8 +620,12 @@ def _process_create_args(
if self.model_info["vision"] is False:
for message in messages:
if isinstance(message, UserMessage):
- if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
- raise ValueError("Model does not support vision and image was provided")
+ if isinstance(message.content, list) and any(
+ isinstance(x, Image) for x in message.content
+ ):
+ raise ValueError(
+ "Model does not support vision and image was provided"
+ )
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output.")
@@ -646,7 +697,9 @@ def _process_create_args(
# tool_choice is a single Tool object
tool_name = tool_choice.schema["name"]
if tool_name not in tool_names_available:
- raise ValueError(f"tool_choice references '{tool_name}' but it's not in the provided tools")
+ raise ValueError(
+ f"tool_choice references '{tool_name}' but it's not in the provided tools"
+ )
if len(converted_tools) > 0:
# Convert to OpenAI format and add to create_args
@@ -678,12 +731,17 @@ async def create(
extra_create_args,
)
future: Union[Task[ParsedChatCompletion[BaseModel]], Task[ChatCompletion]]
+ start_time = time.perf_counter()
if create_params.response_format is not None:
# Use beta client if response_format is not None
future = asyncio.ensure_future(
self._client.beta.chat.completions.parse(
messages=create_params.messages,
- tools=(create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN),
+ tools=(
+ create_params.tools
+ if len(create_params.tools) > 0
+ else NOT_GIVEN
+ ),
response_format=create_params.response_format,
**create_params.create_args,
)
@@ -694,7 +752,11 @@ async def create(
self._client.chat.completions.create(
messages=create_params.messages,
stream=False,
- tools=(create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN),
+ tools=(
+ create_params.tools
+ if len(create_params.tools) > 0
+ else NOT_GIVEN
+ ),
**create_params.create_args,
)
)
@@ -702,6 +764,7 @@ async def create(
if cancellation_token is not None:
cancellation_token.link_future(future)
result: Union[ParsedChatCompletion[BaseModel], ChatCompletion] = await future
+ end_time = time.perf_counter()
if create_params.response_format is not None:
result = cast(ParsedChatCompletion[Any], result)
@@ -709,8 +772,25 @@ async def create(
# even when result.usage is not None
usage = RequestUsage(
# TODO backup token counting
- prompt_tokens=getattr(result.usage, "prompt_tokens", 0) if result.usage is not None else 0,
- completion_tokens=getattr(result.usage, "completion_tokens", 0) if result.usage is not None else 0,
+ prompt_tokens=(
+ getattr(result.usage, "prompt_tokens", 0)
+ if result.usage is not None
+ else 0
+ ),
+ completion_tokens=(
+ getattr(result.usage, "completion_tokens", 0)
+ if result.usage is not None
+ else 0
+ ),
+ )
+
+ # Calculate performance metrics
+ latency_seconds = end_time - start_time
+ latency_ms = latency_seconds * 1000
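+ # Report tokens_per_second only when there is a measurable duration and at
+ # least one completion token; otherwise leave it as None.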
+ tokens_per_second = (
+ usage.completion_tokens / latency_seconds
+ if latency_seconds > 0 and usage.completion_tokens > 0
+ else None
)
logger.info(
@@ -719,6 +799,8 @@ async def create(
response=result.model_dump(),
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
+ latency_ms=latency_ms,
+ tokens_per_second=tokens_per_second,
tools=create_params.tools,
)
)
@@ -733,15 +815,21 @@ async def create(
)
# Limited to a single choice currently.
- choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], Choice] = result.choices[0]
+ choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], Choice] = (
+ result.choices[0]
+ )
# Detect whether it is a function call or not.
# We don't rely on choice.finish_reason as it is not always accurate, depending on the API used.
content: Union[str, List[FunctionCall]]
thought: str | None = None
if choice.message.function_call is not None:
- raise ValueError("function_call is deprecated and is not supported by this model client.")
- elif choice.message.tool_calls is not None and len(choice.message.tool_calls) > 0:
+ raise ValueError(
+ "function_call is deprecated and is not supported by this model client."
+ )
+ elif (
+ choice.message.tool_calls is not None and len(choice.message.tool_calls) > 0
+ ):
if choice.finish_reason != "tool_calls":
warnings.warn(
f"Finish reason mismatch: {choice.finish_reason} != tool_calls "
@@ -763,7 +851,9 @@ async def create(
stacklevel=2,
)
if isinstance(tool_call.function.arguments, dict):
- tool_call.function.arguments = json.dumps(tool_call.function.arguments)
+ tool_call.function.arguments = json.dumps(
+ tool_call.function.arguments
+ )
content.append(
FunctionCall(
id=tool_call.id,
@@ -788,14 +878,21 @@ async def create(
ChatCompletionTokenLogprob(
token=x.token,
logprob=x.logprob,
- top_logprobs=[TopLogprob(logprob=y.logprob, bytes=y.bytes) for y in x.top_logprobs],
+ top_logprobs=[
+ TopLogprob(logprob=y.logprob, bytes=y.bytes)
+ for y in x.top_logprobs
+ ],
bytes=x.bytes,
)
for x in choice.logprobs.content
]
# This is for local R1 models.
- if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1 and thought is None:
+ if (
+ isinstance(content, str)
+ and self._model_info["family"] == ModelFamily.R1
+ and thought is None
+ ):
thought, content = parse_r1_content(content)
response = CreateResult(
@@ -858,7 +955,10 @@ async def create_stream(
if include_usage is not None:
if "stream_options" in create_params.create_args:
stream_options = create_params.create_args["stream_options"]
- if "include_usage" in stream_options and stream_options["include_usage"] != include_usage:
+ if (
+ "include_usage" in stream_options
+ and stream_options["include_usage"] != include_usage
+ ):
raise ValueError(
"include_usage and extra_create_args['stream_options']['include_usage'] are both set, but differ in value."
)
@@ -904,6 +1004,10 @@ async def create_stream(
first_chunk = True
is_reasoning = False
+ # Performance timing variables
+ start_time = time.perf_counter()
+ first_token_time: Optional[float] = None
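+ # first_token_time is stamped on the first content or tool-call delta and
+ # feeds the ttft_ms (time to first token) metric logged at the end of the stream.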
+
# Process the stream of chunks.
async for chunk in chunks:
if first_chunk:
@@ -922,7 +1026,10 @@ async def create_stream(
# https://github.com/microsoft/autogen/issues/4213
if len(chunk.choices) == 0:
empty_chunk_count += 1
- if not empty_chunk_warning_has_been_issued and empty_chunk_count >= empty_chunk_warning_threshold:
+ if (
+ not empty_chunk_warning_has_been_issued
+ and empty_chunk_count >= empty_chunk_warning_threshold
+ ):
empty_chunk_warning_has_been_issued = True
warnings.warn(
f"Received more than {empty_chunk_warning_threshold} consecutive empty chunks. Empty chunks are being ignored.",
@@ -945,11 +1052,18 @@ async def create_stream(
# for liteLLM chunk usage, do the following hack keeping the previous chunk.stop_reason (if set).
# set the stop_reason for the usage chunk to the prior stop_reason
- stop_reason = choice.finish_reason if chunk.usage is None and stop_reason is None else stop_reason
+ stop_reason = (
+ choice.finish_reason
+ if chunk.usage is None and stop_reason is None
+ else stop_reason
+ )
maybe_model = chunk.model
reasoning_content: str | None = None
- if choice.delta.model_extra is not None and "reasoning_content" in choice.delta.model_extra:
+ if (
+ choice.delta.model_extra is not None
+ and "reasoning_content" in choice.delta.model_extra
+ ):
# If there is a reasoning_content field, then we populate the thought field. This is for models such as R1.
reasoning_content = choice.delta.model_extra.get("reasoning_content")
@@ -969,6 +1083,8 @@ async def create_stream(
# First try get content
if choice.delta.content:
+ if first_token_time is None:
+ first_token_time = time.perf_counter()
content_deltas.append(choice.delta.content)
if len(choice.delta.content) > 0:
yield choice.delta.content
@@ -977,11 +1093,15 @@ async def create_stream(
continue
# Otherwise, get tool calls
if choice.delta.tool_calls is not None:
+ if first_token_time is None:
+ first_token_time = time.perf_counter()
for tool_call_chunk in choice.delta.tool_calls:
idx = tool_call_chunk.index
if idx not in full_tool_calls:
# We ignore the type hint here because we want to fill in type when the delta provides it
- full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")
+ full_tool_calls[idx] = FunctionCall(
+ id="", arguments="", name=""
+ )
if tool_call_chunk.id is not None:
full_tool_calls[idx].id += tool_call_chunk.id
@@ -990,13 +1110,18 @@ async def create_stream(
if tool_call_chunk.function.name is not None:
full_tool_calls[idx].name += tool_call_chunk.function.name
if tool_call_chunk.function.arguments is not None:
- full_tool_calls[idx].arguments += tool_call_chunk.function.arguments
+ full_tool_calls[
+ idx
+ ].arguments += tool_call_chunk.function.arguments
if choice.logprobs and choice.logprobs.content:
logprobs = [
ChatCompletionTokenLogprob(
token=x.token,
logprob=x.logprob,
- top_logprobs=[TopLogprob(logprob=y.logprob, bytes=y.bytes) for y in x.top_logprobs],
+ top_logprobs=[
+ TopLogprob(logprob=y.logprob, bytes=y.bytes)
+ for y in x.top_logprobs
+ ],
bytes=x.bytes,
)
for x in choice.logprobs.content
@@ -1050,7 +1175,11 @@ async def create_stream(
thought = "".join(thought_deltas).lstrip("").rstrip("")
# This is for local R1 models whose reasoning content is within the content string.
- if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1 and thought is None:
+ if (
+ isinstance(content, str)
+ and self._model_info["family"] == ModelFamily.R1
+ and thought is None
+ ):
thought, content = parse_r1_content(content)
# Create the result.
@@ -1063,12 +1192,30 @@ async def create_stream(
thought=thought,
)
+ # Calculate performance metrics
+ end_time = time.perf_counter()
+ latency_seconds = end_time - start_time
+ latency_ms = latency_seconds * 1000
+ ttft_ms = (
+ (first_token_time - start_time) * 1000
+ if first_token_time is not None
+ else None
+ )
+ tokens_per_second = (
+ usage.completion_tokens / latency_seconds
+ if latency_seconds > 0 and usage.completion_tokens > 0
+ else None
+ )
+
# Log the end of the stream.
logger.info(
LLMStreamEndEvent(
response=result.model_dump(),
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
+ latency_ms=latency_ms,
+ tokens_per_second=tokens_per_second,
+ ttft_ms=ttft_ms,
)
)
@@ -1118,7 +1265,9 @@ async def _create_stream_chunks_beta_client(
async with self._client.beta.chat.completions.stream(
messages=oai_messages,
tools=tool_params if len(tool_params) > 0 else NOT_GIVEN,
- response_format=(response_format if response_format is not None else NOT_GIVEN),
+ response_format=(
+ response_format if response_format is not None else NOT_GIVEN
+ ),
**create_args_no_response_format,
) as stream:
while True:
@@ -1148,7 +1297,9 @@ def actual_usage(self) -> RequestUsage:
def total_usage(self) -> RequestUsage:
return self._total_usage
- def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
+ def count_tokens(
+ self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []
+ ) -> int:
return count_tokens_openai(
messages,
self._create_args["model"],
@@ -1158,7 +1309,9 @@ def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool |
include_name_in_message=self._include_name_in_message,
)
- def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
+ def remaining_tokens(
+ self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []
+ ) -> int:
token_limit = _model_info.get_token_limit(self._create_args["model"])
return token_limit - self.count_tokens(messages, tools=tools)
@@ -1176,7 +1329,9 @@ def model_info(self) -> ModelInfo:
return self._model_info
-class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenAIClientConfigurationConfigModel]):
+class OpenAIChatCompletionClient(
+ BaseOpenAIChatCompletionClient, Component[OpenAIClientConfigurationConfigModel]
+):
"""Chat completion client for OpenAI hosted models.
To use this client, you must install the `openai` extra:
@@ -1675,7 +1830,9 @@ class AzureOpenAIChatCompletionClient(
component_type = "model"
component_config_schema = AzureOpenAIClientConfigurationConfigModel
- component_provider_override = "autogen_ext.models.openai.AzureOpenAIChatCompletionClient"
+ component_provider_override = (
+ "autogen_ext.models.openai.AzureOpenAIChatCompletionClient"
+ )
def __init__(self, **kwargs: Unpack[AzureOpenAIClientConfiguration]):
model_capabilities: Optional[ModelCapabilities] = None # type: ignore
@@ -1723,11 +1880,17 @@ def _to_config(self) -> AzureOpenAIClientConfigurationConfigModel:
copied_config = self._raw_config.copy()
if "azure_ad_token_provider" in copied_config:
- if not isinstance(copied_config["azure_ad_token_provider"], AzureTokenProvider):
- raise ValueError("azure_ad_token_provider must be a AzureTokenProvider to be component serialized")
+ if not isinstance(
+ copied_config["azure_ad_token_provider"], AzureTokenProvider
+ ):
+ raise ValueError(
+ "azure_ad_token_provider must be a AzureTokenProvider to be component serialized"
+ )
copied_config["azure_ad_token_provider"] = (
- copied_config["azure_ad_token_provider"].dump_component().model_dump(exclude_none=True)
+ copied_config["azure_ad_token_provider"]
+ .dump_component()
+ .model_dump(exclude_none=True)
)
return AzureOpenAIClientConfigurationConfigModel(**copied_config)
@@ -1743,8 +1906,10 @@ def _from_config(cls, config: AzureOpenAIClientConfigurationConfigModel) -> Self
copied_config["api_key"] = config.api_key.get_secret_value()
if "azure_ad_token_provider" in copied_config:
- copied_config["azure_ad_token_provider"] = AzureTokenProvider.load_component(
- copied_config["azure_ad_token_provider"]
+ copied_config["azure_ad_token_provider"] = (
+ AzureTokenProvider.load_component(
+ copied_config["azure_ad_token_provider"]
+ )
)
return cls(**copied_config)
diff --git a/python/packages/autogen-ext/tests/models/test_nvidia_speculative_client.py b/python/packages/autogen-ext/tests/models/test_nvidia_speculative_client.py
new file mode 100644
index 000000000000..81b21074c2b2
--- /dev/null
+++ b/python/packages/autogen-ext/tests/models/test_nvidia_speculative_client.py
@@ -0,0 +1,501 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Unit tests for NVIDIA Speculative Reasoning Execution components.
+
+Tests the ReasoningSniffer, SpeculativeCache, and NvidiaSpeculativeClient
+for detecting tool-call intents in LLM reasoning streams and managing
+speculative pre-warming of tools.
+"""
+
+import asyncio
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from autogen_core.models import (
+ CreateResult,
+ RequestUsage,
+ UserMessage,
+)
+
+from autogen_ext.models.nvidia import (
+ NvidiaSpeculativeClient,
+ ReasoningSniffer,
+ SpeculativeCache,
+)
+
+
+class TestReasoningSniffer:
+ """Tests for the ReasoningSniffer component."""
+
+ def test_sniffer_initialization(self) -> None:
+ """Test that sniffer initializes with default patterns."""
+ sniffer = ReasoningSniffer()
+ assert sniffer is not None
+ assert len(sniffer.PATTERNS) > 0
+ assert "web_search" in sniffer.PATTERNS
+
+ def test_sniffer_detects_web_search_intent(self) -> None:
+ """Test detection of web search intent patterns."""
+ sniffer = ReasoningSniffer()
+
+ intent = sniffer.sniff("I will search for Python documentation")
+
+ assert intent is not None
+ assert intent.tool_type == "web_search"
+ assert "Python documentation" in intent.query_hint
+ assert intent.confidence > 0.5
+
+ def test_sniffer_detects_database_query_intent(self) -> None:
+ """Test detection of database query intent patterns."""
+ sniffer = ReasoningSniffer()
+
+ intent = sniffer.sniff("I need to check the database for user records")
+
+ assert intent is not None
+ assert intent.tool_type == "database_query"
+
+ def test_sniffer_detects_calculate_intent(self) -> None:
+ """Test detection of calculation intent patterns."""
+ sniffer = ReasoningSniffer()
+
+ intent = sniffer.sniff("I will calculate the total revenue")
+
+ assert intent is not None
+ assert intent.tool_type == "calculate"
+
+ def test_sniffer_returns_none_for_non_matching_text(self) -> None:
+ """Test that sniffer returns None when no intent is detected."""
+ sniffer = ReasoningSniffer()
+
+ intent = sniffer.sniff("The weather is nice today")
+
+ assert intent is None
+
+ def test_sniffer_accumulates_context(self) -> None:
+ """Test that sniffer accumulates context across chunks."""
+ sniffer = ReasoningSniffer()
+
+ # Send partial text across chunks
+ sniffer.sniff("I will ")
+ intent = sniffer.sniff("search for documentation")
+
+ assert intent is not None
+ assert intent.tool_type == "web_search"
+
+ def test_sniffer_reset_clears_context(self) -> None:
+ """Test that reset clears the accumulated context."""
+ sniffer = ReasoningSniffer()
+
+ sniffer.sniff("I will search for ")
+ sniffer.reset()
+ intent = sniffer.sniff("documentation")
+
+ # Should not match after reset since "I will search for" was cleared
+ assert intent is None
+
+ def test_sniffer_custom_pattern(self) -> None:
+ """Test adding and using custom patterns."""
+ sniffer = ReasoningSniffer()
+
+ sniffer.add_pattern(
+ "custom_tool",
+ r"(?:I will|Let me)\s+run\s+custom\s+analysis\s+on\s+(.+)"
+ )
+
+ intent = sniffer.sniff("I will run custom analysis on the dataset")
+
+ assert intent is not None
+ assert intent.tool_type == "custom_tool"
+
+
+class TestSpeculativeCache:
+ """Tests for the SpeculativeCache component."""
+
+ def setup_method(self) -> None:
+ """Reset cache singleton before each test."""
+ SpeculativeCache.reset_instance()
+
+ def test_cache_singleton_pattern(self) -> None:
+ """Test that cache follows singleton pattern."""
+ cache1 = SpeculativeCache.get_instance()
+ cache2 = SpeculativeCache.get_instance()
+
+ assert cache1 is cache2
+
+ def test_cache_reset_instance(self) -> None:
+ """Test that reset_instance creates a new singleton."""
+ cache1 = SpeculativeCache.get_instance()
+ cache1.store("web_search", {"query": "test"}, "value")
+
+ SpeculativeCache.reset_instance()
+ cache2 = SpeculativeCache.get_instance()
+
+ assert cache1 is not cache2
+ assert cache2.get("web_search", {"query": "test"}) is None
+
+ def test_cache_store_and_retrieve(self) -> None:
+ """Test basic store and retrieve functionality."""
+ cache = SpeculativeCache.get_instance()
+
+ cache.store("web_search", {"query": "test"}, "test_value")
+ result = cache.get("web_search", {"query": "test"})
+
+ assert result == "test_value"
+
+ def test_cache_returns_none_for_missing_key(self) -> None:
+ """Test that get returns None for missing keys."""
+ cache = SpeculativeCache.get_instance()
+
+ result = cache.get("nonexistent_tool", {"query": "missing"})
+
+ assert result is None
+
+ def test_cache_tracks_hits_and_misses(self) -> None:
+ """Test that cache tracks hit and miss statistics."""
+ cache = SpeculativeCache.get_instance()
+
+ cache.store("web_search", {"query": "test"}, "value")
+ cache.get("web_search", {"query": "test"}) # Hit
+ cache.get("web_search", {"query": "other"}) # Miss
+
+ assert cache.stats["hits"] == 1
+ assert cache.stats["misses"] == 1
+ assert cache.stats["stores"] == 1
+
+ def test_cache_argument_hashing(self) -> None:
+ """Test that cache generates consistent hashes from arguments."""
+ cache = SpeculativeCache.get_instance()
+
+ # Store two entries with same args - should overwrite
+ cache.store("web_search", {"query": "test"}, "result1")
+ cache.store("web_search", {"query": "test"}, "result2") # Overwrites
+ cache.store("web_search", {"query": "different"}, "result3")
+
+ # Should get latest value for same args
+ assert cache.get("web_search", {"query": "test"}) == "result2"
+ assert cache.get("web_search", {"query": "different"}) == "result3"
+
+ def test_cache_len(self) -> None:
+ """Test that len returns number of cached entries."""
+ cache = SpeculativeCache.get_instance()
+
+ assert len(cache) == 0
+
+ cache.store("tool1", {"arg": "a"}, "value1")
+ cache.store("tool2", {"arg": "b"}, "value2")
+
+ assert len(cache) == 2
+
+
+class TestNvidiaSpeculativeClient:
+ """Tests for the NvidiaSpeculativeClient wrapper."""
+
+ @pytest.fixture
+ def mock_inner_client(self) -> MagicMock:
+ """Create a mock inner client for testing."""
+ client = MagicMock()
+ client.model_info = {
+ "vision": False,
+ "function_calling": True,
+ "json_output": True,
+ "family": "r1",
+ }
+ client.capabilities = MagicMock()
+ client.capabilities.vision = False
+ client.capabilities.function_calling = True
+ client.capabilities.json_output = True
+
+ async def mock_close() -> None:
+ pass
+
+ client.close = AsyncMock(side_effect=mock_close)
+ return client
+
+ def test_client_initialization(self, mock_inner_client: MagicMock) -> None:
+ """Test that client initializes correctly."""
+ SpeculativeCache.reset_instance()
+
+ client = NvidiaSpeculativeClient(inner_client=mock_inner_client)
+
+ assert client is not None
+ assert client._inner_client is mock_inner_client
+ assert client._enable_speculation is True
+
+ def test_client_initialization_with_custom_sniffer(
+ self, mock_inner_client: MagicMock
+ ) -> None:
+ """Test initialization with custom sniffer."""
+ SpeculativeCache.reset_instance()
+ custom_sniffer = ReasoningSniffer()
+
+ client = NvidiaSpeculativeClient(
+ inner_client=mock_inner_client,
+ sniffer=custom_sniffer,
+ )
+
+ assert client._sniffer is custom_sniffer
+
+ def test_client_initialization_with_speculation_disabled(
+ self, mock_inner_client: MagicMock
+ ) -> None:
+ """Test that speculation can be disabled."""
+ SpeculativeCache.reset_instance()
+
+ client = NvidiaSpeculativeClient(
+ inner_client=mock_inner_client,
+ enable_speculation=False,
+ )
+
+ assert client._enable_speculation is False
+
+ def test_client_exposes_cache(self, mock_inner_client: MagicMock) -> None:
+ """Test that cache is accessible via property."""
+ SpeculativeCache.reset_instance()
+
+ client = NvidiaSpeculativeClient(inner_client=mock_inner_client)
+
+ assert client.cache is not None
+ assert isinstance(client.cache, SpeculativeCache)
+
+ @pytest.mark.asyncio
+ async def test_client_create_delegates_to_inner(
+ self, mock_inner_client: MagicMock
+ ) -> None:
+ """Test that create() delegates to inner client."""
+ SpeculativeCache.reset_instance()
+
+ expected_result = CreateResult(
+ content="Test response",
+ usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
+ finish_reason="stop",
+ cached=False,
+ )
+ mock_inner_client.create = AsyncMock(return_value=expected_result)
+
+ client = NvidiaSpeculativeClient(inner_client=mock_inner_client)
+
+ result = await client.create(
+ messages=[UserMessage(content="Hello", source="user")]
+ )
+
+ assert result == expected_result
+ mock_inner_client.create.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_client_close_closes_inner_client(
+ self, mock_inner_client: MagicMock
+ ) -> None:
+ """Test that close() closes the inner client."""
+ SpeculativeCache.reset_instance()
+
+ client = NvidiaSpeculativeClient(inner_client=mock_inner_client)
+
+ await client.close()
+
+ mock_inner_client.close.assert_called_once()
+
+
+class TestNvidiaSpeculativeClientStreaming:
+ """Tests for streaming functionality of NvidiaSpeculativeClient."""
+
+ @pytest.fixture
+ def mock_inner_client(self) -> MagicMock:
+ """Create a mock inner client with streaming support."""
+ client = MagicMock()
+ client.model_info = {
+ "vision": False,
+ "function_calling": True,
+ "json_output": True,
+ "family": "r1",
+ }
+
+ async def mock_close() -> None:
+ pass
+
+ client.close = AsyncMock(side_effect=mock_close)
+ return client
+
+ @pytest.mark.asyncio
+ async def test_stream_yields_chunks_and_final_result(
+ self, mock_inner_client: MagicMock
+ ) -> None:
+ """Test that create_stream yields chunks and final CreateResult."""
+ SpeculativeCache.reset_instance()
+
+ async def mock_stream(
+ messages: Any, **kwargs: Any
+ ) -> AsyncGenerator[Union[str, CreateResult], None]:
+ yield "Hello"
+ yield " World"
+ yield CreateResult(
+ content="Hello World",
+ usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
+ finish_reason="stop",
+ cached=False,
+ )
+
+ mock_inner_client.create_stream = mock_stream
+
+ client = NvidiaSpeculativeClient(inner_client=mock_inner_client)
+
+ chunks: List[Any] = []
+ async for chunk in client.create_stream(
+ messages=[UserMessage(content="Hi", source="user")]
+ ):
+ chunks.append(chunk)
+
+ assert len(chunks) == 3
+ assert chunks[0] == "Hello"
+ assert chunks[1] == " World"
+ assert isinstance(chunks[2], CreateResult)
+
+ @pytest.mark.asyncio
+ async def test_stream_triggers_prewarm_on_intent_detection(
+ self, mock_inner_client: MagicMock
+ ) -> None:
+ """Test that prewarm callback is triggered when intent is detected."""
+ SpeculativeCache.reset_instance()
+
+ prewarm_called = False
+ prewarm_tool_type: Optional[str] = None
+
+ async def prewarm_callback(
+ tool_type: str, query_hint: str, context: Dict[str, Any]
+ ) -> Optional[str]:
+ nonlocal prewarm_called, prewarm_tool_type
+ prewarm_called = True
+ prewarm_tool_type = tool_type
+ return "mock_result"
+
+ async def mock_stream(
+ messages: Any, **kwargs: Any
+ ) -> AsyncGenerator[Union[str, CreateResult], None]:
+ # Simulate reasoning content with tool intent
+ yield ""
+ yield "I will search for Python documentation"
+ yield ""
+ yield CreateResult(
+ content="Here's the documentation",
+ usage=RequestUsage(prompt_tokens=10, completion_tokens=20),
+ finish_reason="stop",
+ cached=False,
+ )
+
+ mock_inner_client.create_stream = mock_stream
+
+ client = NvidiaSpeculativeClient(
+ inner_client=mock_inner_client,
+ prewarm_callback=prewarm_callback,
+ min_confidence=0.5,
+ )
+
+ chunks: List[Any] = []
+ async for chunk in client.create_stream(
+ messages=[UserMessage(content="Find docs", source="user")]
+ ):
+ chunks.append(chunk)
+
+ # Allow background tasks to complete
+ await asyncio.sleep(0.2)
+
+ assert prewarm_called is True
+ assert prewarm_tool_type == "web_search"
+
+ @pytest.mark.asyncio
+ async def test_stream_respects_min_confidence(
+ self, mock_inner_client: MagicMock
+ ) -> None:
+ """Test that prewarm is only triggered above min_confidence."""
+ SpeculativeCache.reset_instance()
+
+ prewarm_called = False
+
+ async def prewarm_callback(
+ tool_type: str, query_hint: str, context: Dict[str, Any]
+ ) -> Optional[str]:
+ nonlocal prewarm_called
+ prewarm_called = True
+ return None
+
+ async def mock_stream(
+ messages: Any, **kwargs: Any
+ ) -> AsyncGenerator[Union[str, CreateResult], None]:
+ yield "Maybe I should search" # Low confidence phrase
+ yield CreateResult(
+ content="Done",
+ usage=RequestUsage(prompt_tokens=5, completion_tokens=2),
+ finish_reason="stop",
+ cached=False,
+ )
+
+ mock_inner_client.create_stream = mock_stream
+
+ client = NvidiaSpeculativeClient(
+ inner_client=mock_inner_client,
+ prewarm_callback=prewarm_callback,
+ min_confidence=0.99, # Very high threshold
+ )
+
+ async for _ in client.create_stream(
+ messages=[UserMessage(content="Test", source="user")]
+ ):
+ pass
+
+ await asyncio.sleep(0.1)
+
+ # Should not trigger because confidence is below threshold
+ assert prewarm_called is False
+
+ @pytest.mark.asyncio
+ async def test_stream_deduplicates_intents(
+ self, mock_inner_client: MagicMock
+ ) -> None:
+ """Test that same intent is only fired once per session."""
+ SpeculativeCache.reset_instance()
+
+ prewarm_count = 0
+
+ async def prewarm_callback(
+ tool_type: str, query_hint: str, context: Dict[str, Any]
+ ) -> Optional[str]:
+ nonlocal prewarm_count
+ prewarm_count += 1
+ return "result"
+
+ async def mock_stream(
+ messages: Any, **kwargs: Any
+ ) -> AsyncGenerator[Union[str, CreateResult], None]:
+ # Multiple search intents in same stream
+ yield "I will search for Python docs. "
+ yield "Let me search for more info. "
+ yield "I need to search again. "
+ yield CreateResult(
+ content="Done",
+ usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
+ finish_reason="stop",
+ cached=False,
+ )
+
+ mock_inner_client.create_stream = mock_stream
+
+ client = NvidiaSpeculativeClient(
+ inner_client=mock_inner_client,
+ prewarm_callback=prewarm_callback,
+ min_confidence=0.5,
+ sniff_all_content=True,
+ )
+
+ async for _ in client.create_stream(
+ messages=[UserMessage(content="Test", source="user")]
+ ):
+ pass
+
+ await asyncio.sleep(0.2)
+
+ # Should only fire once due to deduplication
+ assert prewarm_count == 1
diff --git a/python/packages/autogen-ext/tests/models/test_sre_integration.py b/python/packages/autogen-ext/tests/models/test_sre_integration.py
new file mode 100644
index 000000000000..0b7ba75803b2
--- /dev/null
+++ b/python/packages/autogen-ext/tests/models/test_sre_integration.py
@@ -0,0 +1,344 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Integration tests for NVIDIA Speculative Reasoning Execution.
+
+These tests verify the end-to-end flow of speculative execution with
+simulated DeepSeek-R1 style reasoning streams. They demonstrate the
+latency savings achievable by parallelizing tool execution with
+model reasoning.
+"""
+
+import asyncio
+import time
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from unittest.mock import MagicMock
+
+import pytest
+from autogen_core.models import (
+ CreateResult,
+ RequestUsage,
+ UserMessage,
+)
+
+from autogen_ext.models.nvidia import (
+ NvidiaSpeculativeClient,
+ ReasoningSniffer,
+ SpeculativeCache,
+)
+
+
+# Simulated DeepSeek-R1 reasoning stream with tool intent
+MOCK_R1_STREAM_CONTENT = [
+ "",
+ "Let me analyze this problem step by step.",
+ " First, I need to understand what the user is asking.",
+ " The user wants information about Python documentation.",
+ " I will search for Python 3.12 documentation to find the answer.",
+ " Let me also check for any related tutorials.",
+ "",
+ "\n\nBased on my research, here are the key Python 3.12 features:\n",
+ "1. Pattern matching improvements\n",
+ "2. Type parameter syntax\n",
+ "3. Performance optimizations",
+]
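+
+# Chunks between the <think> and </think> markers stand in for the visible
+# reasoning trace that DeepSeek-R1 style models stream; the remaining chunks are
+# the final answer that follows the reasoning block.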
+
+
+class TestSpeculativeExecutionIntegration:
+ """Integration tests for the full speculative execution flow."""
+
+ @pytest.fixture(autouse=True)
+ def reset_cache(self) -> None:
+ """Reset the singleton cache before each test."""
+ SpeculativeCache.reset_instance()
+
+ @pytest.fixture
+ def mock_client_with_r1_stream(self) -> MagicMock:
+ """Create a mock client that simulates R1-style reasoning stream."""
+ client = MagicMock()
+ client.model_info = {
+ "vision": False,
+ "function_calling": True,
+ "json_output": True,
+ "family": "r1",
+ }
+
+ async def mock_close() -> None:
+ pass
+
+ client.close = MagicMock(side_effect=mock_close)
+
+ async def mock_stream(
+ messages: Any, **kwargs: Any
+ ) -> AsyncGenerator[Union[str, CreateResult], None]:
+ for chunk in MOCK_R1_STREAM_CONTENT:
+ await asyncio.sleep(0.05) # Simulate token generation latency
+ yield chunk
+
+ yield CreateResult(
+ content="".join(MOCK_R1_STREAM_CONTENT),
+ usage=RequestUsage(prompt_tokens=50, completion_tokens=100),
+ finish_reason="stop",
+ cached=False,
+ )
+
+ client.create_stream = mock_stream
+ return client
+
+ @pytest.mark.asyncio
+ async def test_end_to_end_speculative_execution(
+ self, mock_client_with_r1_stream: MagicMock
+ ) -> None:
+ """Test complete speculative execution flow with timing verification."""
+ prewarm_events: List[Dict[str, Any]] = []
+ stream_start_time: Optional[float] = None
+
+ async def prewarm_callback(
+ tool_type: str, query_hint: str, context: Dict[str, Any]
+ ) -> Optional[str]:
+ prewarm_time = time.perf_counter()
+ prewarm_events.append({
+ "tool_type": tool_type,
+ "query_hint": query_hint,
+ "time": prewarm_time,
+ "offset_ms": (prewarm_time - stream_start_time) * 1000
+ if stream_start_time
+ else 0,
+ })
+ # Simulate tool execution
+ await asyncio.sleep(0.1)
+ return f"Mock result for {tool_type}"
+
+ client = NvidiaSpeculativeClient(
+ inner_client=mock_client_with_r1_stream,
+ prewarm_callback=prewarm_callback,
+ min_confidence=0.6,
+ )
+
+ stream_start_time = time.perf_counter()
+ chunks: List[Any] = []
+
+ async for chunk in client.create_stream(
+ messages=[UserMessage(content="Tell me about Python 3.12", source="user")]
+ ):
+ chunks.append(chunk)
+
+ stream_end_time = time.perf_counter()
+
+ # Allow background tasks to complete
+ await asyncio.sleep(0.2)
+ await client.close()
+
+ # Verify prewarm was triggered
+ assert len(prewarm_events) > 0, "Expected at least one prewarm event"
+
+ # Verify prewarm fired before the stream ended
+ first_prewarm = prewarm_events[0]
+ assert first_prewarm["tool_type"] == "web_search"
+ assert first_prewarm["time"] < stream_end_time
+
+ # Verify stream completed successfully
+ assert isinstance(chunks[-1], CreateResult)
+
+ @pytest.mark.asyncio
+ async def test_speculative_delta_timing(
+ self, mock_client_with_r1_stream: MagicMock
+ ) -> None:
+ """Test that speculative execution provides timing advantage."""
+ prewarm_time: Optional[float] = None
+ reasoning_end_time: Optional[float] = None
+
+ async def prewarm_callback(
+ tool_type: str, query_hint: str, context: Dict[str, Any]
+ ) -> Optional[str]:
+ nonlocal prewarm_time
+ prewarm_time = time.perf_counter()
+ await asyncio.sleep(0.05)
+ return "result"
+
+ client = NvidiaSpeculativeClient(
+ inner_client=mock_client_with_r1_stream,
+ prewarm_callback=prewarm_callback,
+ min_confidence=0.6,
+ )
+
+ stream_start = time.perf_counter()
+
+ async for chunk in client.create_stream(
+ messages=[UserMessage(content="Test", source="user")]
+ ):
+ chunk_str = str(chunk)
+ if "" in chunk_str:
+ reasoning_end_time = time.perf_counter()
+
+ await asyncio.sleep(0.2)
+ await client.close()
+
+ # Verify timing relationship
+ assert prewarm_time is not None, "Prewarm should have been triggered"
+ assert reasoning_end_time is not None, "Should have detected reasoning end"
+
+ # Prewarm should fire before reasoning ends (negative delta = savings)
+ prewarm_offset = (prewarm_time - stream_start) * 1000
+ reasoning_end_offset = (reasoning_end_time - stream_start) * 1000
+ speculative_delta = prewarm_offset - reasoning_end_offset
+
+ assert speculative_delta < 0, (
+ f"Speculative execution should provide timing advantage. "
+ f"Delta: {speculative_delta:.0f}ms"
+ )
+
+ @pytest.mark.asyncio
+ async def test_no_prewarm_when_speculation_disabled(
+ self, mock_client_with_r1_stream: MagicMock
+ ) -> None:
+ """Test that prewarming is skipped when speculation is disabled."""
+ prewarm_called = False
+
+ async def prewarm_callback(
+ tool_type: str, query_hint: str, context: Dict[str, Any]
+ ) -> Optional[str]:
+ nonlocal prewarm_called
+ prewarm_called = True
+ return None
+
+ client = NvidiaSpeculativeClient(
+ inner_client=mock_client_with_r1_stream,
+ prewarm_callback=prewarm_callback,
+ enable_speculation=False, # Disabled
+ )
+
+ async for _ in client.create_stream(
+ messages=[UserMessage(content="Test", source="user")]
+ ):
+ pass
+
+ await asyncio.sleep(0.1)
+ await client.close()
+
+ assert prewarm_called is False
+
+ @pytest.mark.asyncio
+ async def test_cache_stores_prewarm_results(
+ self, mock_client_with_r1_stream: MagicMock
+ ) -> None:
+ """Test that prewarm results are stored in cache."""
+ async def prewarm_callback(
+ tool_type: str, query_hint: str, context: Dict[str, Any]
+ ) -> Optional[str]:
+ return f"cached_result_for_{tool_type}"
+
+ client = NvidiaSpeculativeClient(
+ inner_client=mock_client_with_r1_stream,
+ prewarm_callback=prewarm_callback,
+ min_confidence=0.5,
+ )
+
+ async for _ in client.create_stream(
+ messages=[UserMessage(content="Test", source="user")]
+ ):
+ pass
+
+ await asyncio.sleep(0.2)
+ await client.close()
+
+ # Verify cache has entries
+ cache = client.cache
+ assert len(cache) > 0
+ assert cache.stats["stores"] > 0
+
+
+class TestSnifferPatternMatching:
+ """Integration tests for sniffer pattern matching accuracy."""
+
+ @pytest.fixture(autouse=True)
+ def reset_cache(self) -> None:
+ """Reset cache before each test."""
+ SpeculativeCache.reset_instance()
+
+ @pytest.mark.parametrize(
+ "text,expected_tool,should_match",
+ [
+ ("I will search for Python documentation", "web_search", True),
+ ("Let me look up the latest news", "web_search", True),
+ ("I need to check the database for user records", "database_query", True),
+ ("I'll calculate the total revenue", "calculate", True),
+ ("The weather is nice", None, False),
+ ("Just thinking about the problem", None, False),
+ ],
+ )
+ def test_intent_detection_patterns(
+ self, text: str, expected_tool: Optional[str], should_match: bool
+ ) -> None:
+ """Test various intent detection patterns."""
+ sniffer = ReasoningSniffer()
+
+ intent = sniffer.sniff(text)
+
+ if should_match:
+ assert intent is not None, f"Expected match for: {text}"
+ assert intent.tool_type == expected_tool
+ else:
+ assert intent is None, f"Expected no match for: {text}"
+
+ def test_contraction_patterns(self) -> None:
+ """Test that contractions like I'll are properly detected."""
+ sniffer = ReasoningSniffer()
+
+ # Test I'll patterns
+ intent = sniffer.sniff("I'll search for the latest updates")
+ assert intent is not None
+ assert intent.tool_type == "web_search"
+
+ sniffer.reset()
+
+ intent = sniffer.sniff("I'll need to look up some information")
+ assert intent is not None
+
+
+class TestPerformanceMetrics:
+ """Tests for performance metric tracking."""
+
+ @pytest.fixture(autouse=True)
+ def reset_cache(self) -> None:
+ """Reset cache before each test."""
+ SpeculativeCache.reset_instance()
+
+ @pytest.mark.asyncio
+ async def test_metrics_are_tracked(self) -> None:
+ """Test that performance metrics are properly tracked."""
+ client_mock = MagicMock()
+ client_mock.model_info = {"family": "r1"}
+
+ async def mock_close() -> None:
+ pass
+
+ client_mock.close = MagicMock(side_effect=mock_close)
+
+ async def mock_stream(
+ messages: Any, **kwargs: Any
+ ) -> AsyncGenerator[Union[str, CreateResult], None]:
+ yield "Test response"
+ yield CreateResult(
+ content="Test",
+ usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
+ finish_reason="stop",
+ cached=False,
+ )
+
+ client_mock.create_stream = mock_stream
+
+ client = NvidiaSpeculativeClient(inner_client=client_mock)
+
+ async for _ in client.create_stream(
+ messages=[UserMessage(content="Test", source="user")]
+ ):
+ pass
+
+ await client.close()
+
+ # Verify metrics were captured
+ metrics = client.last_metrics
+ assert metrics is not None
+ assert metrics.stream_start_time is not None
+ assert metrics.stream_end_time is not None
+ assert metrics.first_token_time is not None