diff --git a/python/packages/autogen-core/src/autogen_core/logging.py b/python/packages/autogen-core/src/autogen_core/logging.py index 3f371a6f3bca..2f31dbabe55a 100644 --- a/python/packages/autogen-core/src/autogen_core/logging.py +++ b/python/packages/autogen-core/src/autogen_core/logging.py @@ -1,6 +1,6 @@ import json from enum import Enum -from typing import Any, Dict, List, cast +from typing import Any, Dict, List, Optional, cast from ._agent_id import AgentId from ._message_handler_context import MessageHandlerContext @@ -15,6 +15,8 @@ def __init__( response: Dict[str, Any], prompt_tokens: int, completion_tokens: int, + latency_ms: Optional[float] = None, + tokens_per_second: Optional[float] = None, **kwargs: Any, ) -> None: """To be used by model clients to log the call to the LLM. @@ -24,6 +26,8 @@ def __init__( response (Dict[str, Any]): The response of the call. Must be json serializable. prompt_tokens (int): Number of tokens used in the prompt. completion_tokens (int): Number of tokens used in the completion. + latency_ms (Optional[float]): Total call duration in milliseconds. + tokens_per_second (Optional[float]): Completion tokens divided by latency in seconds. Example: @@ -45,6 +49,10 @@ def __init__( self.kwargs["response"] = response self.kwargs["prompt_tokens"] = prompt_tokens self.kwargs["completion_tokens"] = completion_tokens + if latency_ms is not None: + self.kwargs["latency_ms"] = latency_ms + if tokens_per_second is not None: + self.kwargs["tokens_per_second"] = tokens_per_second try: agent_id = MessageHandlerContext.agent_id() except RuntimeError: @@ -59,6 +67,14 @@ def prompt_tokens(self) -> int: def completion_tokens(self) -> int: return cast(int, self.kwargs["completion_tokens"]) + @property + def latency_ms(self) -> Optional[float]: + return self.kwargs.get("latency_ms") + + @property + def tokens_per_second(self) -> Optional[float]: + return self.kwargs.get("tokens_per_second") + # This must output the event in a json serializable format def __str__(self) -> str: return json.dumps(self.kwargs) @@ -111,6 +127,9 @@ def __init__( response: Dict[str, Any], prompt_tokens: int, completion_tokens: int, + latency_ms: Optional[float] = None, + tokens_per_second: Optional[float] = None, + ttft_ms: Optional[float] = None, **kwargs: Any, ) -> None: """To be used by model clients to log the end of a stream. @@ -119,6 +138,9 @@ def __init__( response (Dict[str, Any]): The response of the call. Must be json serializable. prompt_tokens (int): Number of tokens used in the prompt. completion_tokens (int): Number of tokens used in the completion. + latency_ms (Optional[float]): Total stream duration in milliseconds. + tokens_per_second (Optional[float]): Completion tokens divided by latency in seconds. + ttft_ms (Optional[float]): Time to first token in milliseconds. 
Example: @@ -138,6 +160,12 @@ def __init__( self.kwargs["response"] = response self.kwargs["prompt_tokens"] = prompt_tokens self.kwargs["completion_tokens"] = completion_tokens + if latency_ms is not None: + self.kwargs["latency_ms"] = latency_ms + if tokens_per_second is not None: + self.kwargs["tokens_per_second"] = tokens_per_second + if ttft_ms is not None: + self.kwargs["ttft_ms"] = ttft_ms try: agent_id = MessageHandlerContext.agent_id() except RuntimeError: @@ -152,6 +180,18 @@ def prompt_tokens(self) -> int: def completion_tokens(self) -> int: return cast(int, self.kwargs["completion_tokens"]) + @property + def latency_ms(self) -> Optional[float]: + return self.kwargs.get("latency_ms") + + @property + def tokens_per_second(self) -> Optional[float]: + return self.kwargs.get("tokens_per_second") + + @property + def ttft_ms(self) -> Optional[float]: + return self.kwargs.get("ttft_ms") + # This must output the event in a json serializable format def __str__(self) -> str: return json.dumps(self.kwargs) diff --git a/python/packages/autogen-core/tests/test_logging_events.py b/python/packages/autogen-core/tests/test_logging_events.py new file mode 100644 index 000000000000..3fd56c503861 --- /dev/null +++ b/python/packages/autogen-core/tests/test_logging_events.py @@ -0,0 +1,107 @@ +"""Tests for LLMCallEvent and LLMStreamEndEvent timing fields. + +These tests verify the performance telemetry fields added in Issue #5790. +""" + +import json + +from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent + + +def test_llm_call_event_timing_fields() -> None: + """Test that LLMCallEvent correctly stores and serializes timing fields.""" + event = LLMCallEvent( + messages=[{"role": "user", "content": "Hello"}], + response={"content": "Hi there!"}, + prompt_tokens=10, + completion_tokens=20, + latency_ms=150.5, + tokens_per_second=133.33, + ) + + # Check property accessors + assert event.prompt_tokens == 10 + assert event.completion_tokens == 20 + assert event.latency_ms == 150.5 + assert event.tokens_per_second == 133.33 + + # Check JSON serialization includes timing fields + json_str = str(event) + data = json.loads(json_str) + assert data["latency_ms"] == 150.5 + assert data["tokens_per_second"] == 133.33 + + +def test_llm_call_event_timing_fields_optional() -> None: + """Test that timing fields are optional and not included when not provided.""" + event = LLMCallEvent( + messages=[{"role": "user", "content": "Hello"}], + response={"content": "Hi there!"}, + prompt_tokens=10, + completion_tokens=20, + ) + + # Check that missing fields return None + assert event.latency_ms is None + assert event.tokens_per_second is None + + # Check JSON serialization excludes missing timing fields + json_str = str(event) + data = json.loads(json_str) + assert "latency_ms" not in data + assert "tokens_per_second" not in data + + +def test_llm_stream_end_event_timing_fields() -> None: + """Test that LLMStreamEndEvent correctly stores and serializes timing fields including TTFT.""" + event = LLMStreamEndEvent( + response={"content": "Hello, world!"}, + prompt_tokens=10, + completion_tokens=25, + latency_ms=200.0, + tokens_per_second=125.0, + ttft_ms=50.5, + ) + + # Check property accessors + assert event.prompt_tokens == 10 + assert event.completion_tokens == 25 + assert event.latency_ms == 200.0 + assert event.tokens_per_second == 125.0 + assert event.ttft_ms == 50.5 + + # Check JSON serialization includes all timing fields + json_str = str(event) + data = json.loads(json_str) + assert data["latency_ms"] == 
200.0 + assert data["tokens_per_second"] == 125.0 + assert data["ttft_ms"] == 50.5 + + +def test_llm_stream_end_event_timing_fields_optional() -> None: + """Test that streaming timing fields are optional.""" + event = LLMStreamEndEvent( + response={"content": "Hello, world!"}, + prompt_tokens=10, + completion_tokens=25, + ) + + # Check that missing fields return None + assert event.latency_ms is None + assert event.tokens_per_second is None + assert event.ttft_ms is None + + # Check JSON serialization excludes missing timing fields + json_str = str(event) + data = json.loads(json_str) + assert "latency_ms" not in data + assert "tokens_per_second" not in data + assert "ttft_ms" not in data + + +if __name__ == "__main__": + test_llm_call_event_timing_fields() + test_llm_call_event_timing_fields_optional() + test_llm_stream_end_event_timing_fields() + test_llm_stream_end_event_timing_fields_optional() + print("All tests passed!") diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py index 16d56b57f956..08c19201aca5 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py @@ -1,6 +1,7 @@ import asyncio import logging import re +import time from asyncio import Task from inspect import getfullargspec from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Union, cast @@ -65,7 +66,9 @@ from .._utils.parse_r1_content import parse_r1_content create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs) -AzureMessage = Union[AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage] +AzureMessage = Union[ + AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage +] logger = logging.getLogger(EVENT_LOGGER_NAME) @@ -74,7 +77,9 @@ def _is_github_model(endpoint: str) -> bool: return endpoint == GITHUB_MODELS_ENDPOINT -def convert_tools(tools: Sequence[Tool | ToolSchema]) -> List[ChatCompletionsToolDefinition]: +def convert_tools( + tools: Sequence[Tool | ToolSchema], +) -> List[ChatCompletionsToolDefinition]: result: List[ChatCompletionsToolDefinition] = [] for tool in tools: if isinstance(tool, Tool): @@ -125,7 +130,13 @@ def _user_message_to_azure(message: UserMessage) -> AzureUserMessage: elif isinstance(part, Image): # TODO: support url based images # TODO: support specifying details - parts.append(ImageContentItem(image_url=ImageUrl(url=part.data_uri, detail=ImageDetailLevel.AUTO))) + parts.append( + ImageContentItem( + image_url=ImageUrl( + url=part.data_uri, detail=ImageDetailLevel.AUTO + ) + ) + ) else: raise ValueError(f"Unknown content type: {message.content}") return AzureUserMessage(content=parts) @@ -141,8 +152,13 @@ def _assistant_message_to_azure(message: AssistantMessage) -> AzureAssistantMess return AzureAssistantMessage(content=message.content) -def _tool_message_to_azure(message: FunctionExecutionResultMessage) -> Sequence[AzureToolMessage]: - return [AzureToolMessage(content=x.content, tool_call_id=x.call_id) for x in message.content] +def _tool_message_to_azure( + message: FunctionExecutionResultMessage, +) -> Sequence[AzureToolMessage]: + return [ + AzureToolMessage(content=x.content, tool_call_id=x.call_id) + for x in message.content + ] def to_azure_message(message: LLMMessage) -> Sequence[AzureMessage]: @@ -172,7 +188,9 @@ def assert_valid_name(name: str) -> str: For munging 
LLM responses use _normalize_name to ensure LLM specified names don't break the API. """ if not re.match(r"^[a-zA-Z0-9_-]+$", name): - raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.") + raise ValueError( + f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed." + ) if len(name) > 64: raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.") return name @@ -306,11 +324,15 @@ def _validate_config(config: Dict[str, Any]) -> AzureAIChatCompletionClientConfi raise ValueError("model_info is required for AzureAIChatCompletionClient") validate_model_info(config["model_info"]) if _is_github_model(config["endpoint"]) and "model" not in config: - raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient") + raise ValueError( + "model is required for when using a Github model with AzureAIChatCompletionClient" + ) return cast(AzureAIChatCompletionClientConfig, config) @staticmethod - def _create_client(config: AzureAIChatCompletionClientConfig) -> ChatCompletionsClient: + def _create_client( + config: AzureAIChatCompletionClientConfig, + ) -> ChatCompletionsClient: # Only pass the parameters that ChatCompletionsClient accepts # Remove 'model_info' and other client-specific parameters client_config = {k: v for k, v in config.items() if k not in ("model_info",)} @@ -337,8 +359,12 @@ def _validate_model_info( if self.model_info["vision"] is False: for message in messages: if isinstance(message, UserMessage): - if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content): - raise ValueError("Model does not support vision and image was provided") + if isinstance(message.content, list) and any( + isinstance(x, Image) for x in message.content + ): + raise ValueError( + "Model does not support vision and image was provided" + ) if json_output is not None: if self.model_info["json_output"] is False and json_output is True: @@ -346,7 +372,9 @@ def _validate_model_info( if isinstance(json_output, type): # TODO: we should support this in the future. 
- raise ValueError("Structured output is not currently supported for AzureAIChatCompletionClient") + raise ValueError( + "Structured output is not currently supported for AzureAIChatCompletionClient" + ) if json_output is True and "response_format" not in create_args: create_args["response_format"] = "json_object" @@ -368,7 +396,9 @@ async def create( ) -> CreateResult: extra_create_args_keys = set(extra_create_args.keys()) if not create_kwargs.issuperset(extra_create_args_keys): - raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}") + raise ValueError( + f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}" + ) # Copy the create args and overwrite anything in extra_create_args create_args = self._create_args.copy() @@ -380,11 +410,14 @@ async def create( azure_messages = [item for sublist in azure_messages_nested for item in sublist] task: Task[ChatCompletions] + start_time = time.perf_counter() if len(tools) > 0: if isinstance(tool_choice, Tool): create_args["tool_choice"] = ChatCompletionsNamedToolChoice( - function=ChatCompletionsNamedToolChoiceFunction(name=tool_choice.name) + function=ChatCompletionsNamedToolChoiceFunction( + name=tool_choice.name + ) ) else: create_args["tool_choice"] = tool_choice @@ -404,18 +437,30 @@ async def create( cancellation_token.link_future(task) result: ChatCompletions = await task + end_time = time.perf_counter() usage = RequestUsage( prompt_tokens=result.usage.prompt_tokens if result.usage else 0, completion_tokens=result.usage.completion_tokens if result.usage else 0, ) + # Calculate performance metrics + latency_seconds = end_time - start_time + latency_ms = latency_seconds * 1000 + tokens_per_second = ( + usage.completion_tokens / latency_seconds + if latency_seconds > 0 and usage.completion_tokens > 0 + else None + ) + logger.info( LLMCallEvent( messages=[m.as_dict() for m in azure_messages], response=result.as_dict(), prompt_tokens=usage.prompt_tokens, completion_tokens=usage.completion_tokens, + latency_ms=latency_ms, + tokens_per_second=tokens_per_second, ) ) @@ -470,7 +515,9 @@ async def create_stream( ) -> AsyncGenerator[Union[str, CreateResult], None]: extra_create_args_keys = set(extra_create_args.keys()) if not create_kwargs.issuperset(extra_create_args_keys): - raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}") + raise ValueError( + f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}" + ) create_args: Dict[str, Any] = self._create_args.copy() create_args.update(extra_create_args) @@ -484,16 +531,27 @@ async def create_stream( if len(tools) > 0: if isinstance(tool_choice, Tool): create_args["tool_choice"] = ChatCompletionsNamedToolChoice( - function=ChatCompletionsNamedToolChoiceFunction(name=tool_choice.name) + function=ChatCompletionsNamedToolChoiceFunction( + name=tool_choice.name + ) ) else: create_args["tool_choice"] = tool_choice converted_tools = convert_tools(tools) task = asyncio.create_task( - self._client.complete(messages=azure_messages, tools=converted_tools, stream=True, **create_args) + self._client.complete( + messages=azure_messages, + tools=converted_tools, + stream=True, + **create_args, + ) ) else: - task = asyncio.create_task(self._client.complete(messages=azure_messages, stream=True, **create_args)) + task = asyncio.create_task( + self._client.complete( + messages=azure_messages, stream=True, **create_args + ) + ) if cancellation_token is not None: cancellation_token.link_future(task) @@ -509,6 
+567,10 @@ async def create_stream( first_chunk = True thought = None + # Performance timing variables + start_time = time.perf_counter() + first_token_time: Optional[float] = None + async for chunk in await task: # type: ignore if first_chunk: first_chunk = False @@ -527,17 +589,29 @@ async def create_stream( if choice.finish_reason is CompletionsFinishReason.TOOL_CALLS: finish_reason = "function_calls" else: - if choice.finish_reason in ["stop", "length", "function_calls", "content_filter", "unknown"]: + if choice.finish_reason in [ + "stop", + "length", + "function_calls", + "content_filter", + "unknown", + ]: finish_reason = choice.finish_reason # type: ignore else: - raise ValueError(f"Unexpected finish reason: {choice.finish_reason}") + raise ValueError( + f"Unexpected finish reason: {choice.finish_reason}" + ) # We first try to load the content if choice and choice.delta.content is not None: + if first_token_time is None: + first_token_time = time.perf_counter() content_deltas.append(choice.delta.content) yield choice.delta.content # Otherwise, we try to load the tool calls if choice and choice.delta.tool_calls is not None: + if first_token_time is None: + first_token_time = time.perf_counter() for tool_call_chunk in choice.delta.tool_calls: # print(tool_call_chunk) if "index" in tool_call_chunk: @@ -545,7 +619,9 @@ async def create_stream( else: idx = tool_call_chunk.id if idx not in full_tool_calls: - full_tool_calls[idx] = FunctionCall(id="", arguments="", name="") + full_tool_calls[idx] = FunctionCall( + id="", arguments="", name="" + ) full_tool_calls[idx].id += tool_call_chunk.id full_tool_calls[idx].name += tool_call_chunk.function.name @@ -587,12 +663,30 @@ async def create_stream( thought=thought, ) + # Calculate performance metrics + end_time = time.perf_counter() + latency_seconds = end_time - start_time + latency_ms = latency_seconds * 1000 + ttft_ms = ( + (first_token_time - start_time) * 1000 + if first_token_time is not None + else None + ) + tokens_per_second = ( + usage.completion_tokens / latency_seconds + if latency_seconds > 0 and usage.completion_tokens > 0 + else None + ) + # Log the end of the stream. logger.info( LLMStreamEndEvent( response=result.model_dump(), prompt_tokens=usage.prompt_tokens, completion_tokens=usage.completion_tokens, + latency_ms=latency_ms, + tokens_per_second=tokens_per_second, + ttft_ms=ttft_ms, ) ) @@ -609,10 +703,14 @@ def actual_usage(self) -> RequestUsage: def total_usage(self) -> RequestUsage: return self._total_usage - def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: + def count_tokens( + self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = [] + ) -> int: return 0 - def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: + def remaining_tokens( + self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = [] + ) -> int: return 0 @property diff --git a/python/packages/autogen-ext/src/autogen_ext/models/nvidia/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/__init__.py new file mode 100644 index 000000000000..7b07868e3c73 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""NVIDIA NIM Speculative Reasoning Execution (SRE) Client for AutoGen 0.4. 
+ +This package provides a ChatCompletionClient that bridges NVIDIA NIM +high-performance inference with Microsoft AutoGen orchestration, enabling +parallel tool execution during LLM reasoning to reduce "Time to Action" latency. +""" + +from ._nvidia_speculative_client import NvidiaSpeculativeClient +from ._reasoning_sniffer import ReasoningSniffer, ToolIntent +from ._speculative_cache import SpeculativeCache + +__all__ = [ + "NvidiaSpeculativeClient", + "ReasoningSniffer", + "ToolIntent", + "SpeculativeCache", +] diff --git a/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_nvidia_speculative_client.py b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_nvidia_speculative_client.py new file mode 100644 index 000000000000..a66f331991fd --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_nvidia_speculative_client.py @@ -0,0 +1,480 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""NVIDIA NIM Speculative Reasoning Execution Client for AutoGen 0.4. + +This module provides the main ChatCompletionClient implementation that bridges +NVIDIA NIM inference with AutoGen orchestration. It enables parallel tool +execution during LLM reasoning to dramatically reduce "Time to Action" latency. + +Key Features: +- Streams DeepSeek-R1 style ... reasoning blocks +- Uses ReasoningSniffer to detect tool intents during reasoning +- Fires speculative prewarm tasks via asyncio.create_task (non-blocking) +- Logs high-precision latency metrics for benchmarking + +The "Shark" Architecture: + Model thinks → Sniffer detects intent → Prewarm fires (background) + → Model keeps thinking + → Tool actually called → Cache HIT +""" + +from __future__ import annotations + +import asyncio +import logging +import time +import warnings +from dataclasses import dataclass, field +from typing import ( + Any, + AsyncGenerator, + Awaitable, + Callable, + Dict, + List, + Literal, + Mapping, + Optional, + Sequence, + Union, +) + +from autogen_core import CancellationToken +from autogen_core.models import ( + AssistantMessage, + ChatCompletionClient, + CreateResult, + LLMMessage, + ModelInfo, + RequestUsage, +) +from autogen_core.models._model_client import ModelFamily +from autogen_core.tools import Tool, ToolSchema + +from ._reasoning_sniffer import ReasoningSniffer, ToolIntent +from ._speculative_cache import SpeculativeCache + +# Event logger for LLM events +EVENT_LOGGER_NAME = "autogen_core.events" +logger = logging.getLogger(EVENT_LOGGER_NAME) +trace_logger = logging.getLogger(__name__) + + +@dataclass +class SpeculativePrewarmEvent: + """Event emitted when the ReasoningSniffer detects a tool intent. + + This event is logged to the event bus so that orchestrators can + trigger speculative tool execution. 
+ """ + + tool_type: str + """The type/name of the tool detected.""" + + query_hint: str + """Extracted query or argument hint from the reasoning.""" + + confidence: float + """Confidence score (0.0 to 1.0).""" + + detected_at_ms: float + """Timestamp when the intent was detected.""" + + reasoning_context: str + """The reasoning text that triggered detection.""" + + def __str__(self) -> str: + return ( + f"SpeculativePrewarm: tool={self.tool_type}, " + f"hint='{self.query_hint[:50]}...', confidence={self.confidence:.2f}" + ) + + +@dataclass +class SpeculativeHitEvent: + """Event emitted when a cached speculative result is used.""" + + tool_name: str + """Name of the tool that had a cache hit.""" + + latency_saved_ms: float + """Estimated latency saved by using the cached result.""" + + def __str__(self) -> str: + return ( + f"SpeculativeHit: tool={self.tool_name}, " + f"latency_saved={self.latency_saved_ms:.1f}ms" + ) + + +@dataclass +class PerformanceMetrics: + """Performance tracking for the speculative execution pipeline.""" + + stream_start_time: float = 0.0 + first_token_time: Optional[float] = None + reasoning_start_time: Optional[float] = None + reasoning_end_time: Optional[float] = None + first_intent_detected_time: Optional[float] = None + tool_call_time: Optional[float] = None + stream_end_time: float = 0.0 + + intents_detected: int = 0 + prewarms_triggered: int = 0 + + @property + def ttft_ms(self) -> Optional[float]: + """Time to first token in milliseconds.""" + if self.first_token_time is None: + return None + return (self.first_token_time - self.stream_start_time) * 1000 + + @property + def reasoning_duration_ms(self) -> Optional[float]: + """Total reasoning duration in milliseconds.""" + if self.reasoning_start_time is None or self.reasoning_end_time is None: + return None + return (self.reasoning_end_time - self.reasoning_start_time) * 1000 + + @property + def speculative_delta_ms(self) -> Optional[float]: + """The 'speculative delta' - how early we detected intent vs tool call. + + Negative = we detected before the model called (good!) + Positive = we missed the window (speculation opportunity lost) + """ + if self.first_intent_detected_time is None or self.tool_call_time is None: + return None + return (self.tool_call_time - self.first_intent_detected_time) * 1000 + + def summary(self) -> Dict[str, Any]: + """Get a summary of all performance metrics.""" + return { + "ttft_ms": self.ttft_ms, + "reasoning_duration_ms": self.reasoning_duration_ms, + "speculative_delta_ms": self.speculative_delta_ms, + "intents_detected": self.intents_detected, + "prewarms_triggered": self.prewarms_triggered, + "total_duration_ms": (self.stream_end_time - self.stream_start_time) * 1000, + } + + +# Type alias for the prewarm callback +PrewarmCallback = Callable[[str, str, Dict[str, Any]], Awaitable[Any]] + + +class NvidiaSpeculativeClient(ChatCompletionClient): + """NVIDIA NIM-compatible ChatCompletionClient with Speculative Reasoning Execution. + + This client wraps an existing OpenAI-compatible client (which connects to + NVIDIA NIM, vLLM, or any OpenAI-compatible endpoint) and adds speculative + execution capabilities. + + The key innovation is that during the model's reasoning phase ( block), + the ReasoningSniffer monitors for tool-call intents. When detected, it triggers + a prewarm callback to speculatively execute the tool in the background. + + Args: + inner_client: An existing ChatCompletionClient (e.g., OpenAIChatCompletionClient + configured for NIM endpoint). 
+ sniffer: Optional custom ReasoningSniffer. Uses default patterns if not provided. + prewarm_callback: Async function called when tool intent is detected. + Signature: async def callback(tool_type: str, query_hint: str, context: dict) -> Any + enable_speculation: Whether to enable speculative execution (default: True). + min_confidence: Minimum confidence threshold to trigger prewarm (default: 0.7). + + Example: + >>> from autogen_ext.models.openai import OpenAIChatCompletionClient + >>> from autogen_ext.models.nvidia import NvidiaSpeculativeClient + >>> + >>> # Create inner client pointing to NIM + >>> inner = OpenAIChatCompletionClient( + ... model="deepseek-r1", + ... base_url="http://wulver:8000/v1", + ... api_key="token" + ... ) + >>> + >>> # Wrap with speculative execution + >>> client = NvidiaSpeculativeClient( + ... inner_client=inner, + ... prewarm_callback=my_prewarm_function + ... ) + """ + + component_type = "model" + component_config_schema = None # TODO: Add proper config schema + + def __init__( + self, + inner_client: ChatCompletionClient, + *, + sniffer: Optional[ReasoningSniffer] = None, + prewarm_callback: Optional[PrewarmCallback] = None, + enable_speculation: bool = True, + min_confidence: float = 0.7, + sniff_all_content: bool = False, + ) -> None: + self._inner_client = inner_client + self._sniffer = sniffer or ReasoningSniffer() + self._prewarm_callback = prewarm_callback + self._enable_speculation = enable_speculation + self._min_confidence = min_confidence + self._sniff_all_content = sniff_all_content # For distilled models without tags + self._cache = SpeculativeCache.get_instance() + + # Track running prewarm tasks + self._prewarm_tasks: List[asyncio.Task[Any]] = [] + + # DEDUPLICATION: Track triggered intents to avoid firing 78 times + self._triggered_intents: set[str] = set() + + # Performance tracking + self._last_metrics: Optional[PerformanceMetrics] = None + + @property + def model_info(self) -> ModelInfo: + """Get model info from the inner client.""" + return self._inner_client.model_info + + @property + def capabilities(self) -> Any: + """Deprecated. 
Use model_info instead.""" + warnings.warn( + "capabilities is deprecated, use model_info instead", + DeprecationWarning, + stacklevel=2, + ) + return self._inner_client.capabilities + + def actual_usage(self) -> RequestUsage: + """Get actual token usage.""" + return self._inner_client.actual_usage() + + def total_usage(self) -> RequestUsage: + """Get total token usage.""" + return self._inner_client.total_usage() + + def count_tokens( + self, + messages: Sequence[LLMMessage], + *, + tools: Sequence[Tool | ToolSchema] = [], + ) -> int: + """Count tokens for the given messages.""" + return self._inner_client.count_tokens(messages, tools=tools) + + def remaining_tokens( + self, + messages: Sequence[LLMMessage], + *, + tools: Sequence[Tool | ToolSchema] = [], + ) -> int: + """Get remaining tokens for the model's context.""" + return self._inner_client.remaining_tokens(messages, tools=tools) + + async def close(self) -> None: + """Close the client and cancel any pending prewarm tasks.""" + # Cancel any running prewarm tasks + for task in self._prewarm_tasks: + if not task.done(): + task.cancel() + self._prewarm_tasks.clear() + + await self._inner_client.close() + + async def create( + self, + messages: Sequence[LLMMessage], + *, + tools: Sequence[Tool | ToolSchema] = [], + tool_choice: Tool | Literal["auto", "required", "none"] = "auto", + json_output: Optional[bool | type] = None, + extra_create_args: Mapping[str, Any] = {}, + cancellation_token: Optional[CancellationToken] = None, + ) -> CreateResult: + """Create a non-streaming response. + + Note: Speculative execution is most effective with streaming. + This method delegates directly to the inner client. + """ + return await self._inner_client.create( + messages, + tools=tools, + tool_choice=tool_choice, + json_output=json_output, + extra_create_args=extra_create_args, + cancellation_token=cancellation_token, + ) + + async def create_stream( + self, + messages: Sequence[LLMMessage], + *, + tools: Sequence[Tool | ToolSchema] = [], + tool_choice: Tool | Literal["auto", "required", "none"] = "auto", + json_output: Optional[bool | type] = None, + extra_create_args: Mapping[str, Any] = {}, + cancellation_token: Optional[CancellationToken] = None, + ) -> AsyncGenerator[Union[str, CreateResult], None]: + """Create a streaming response with speculative reasoning execution. + + This is the core method that implements the SRE pattern: + 1. Stream tokens from the inner client + 2. Monitor reasoning content with the sniffer + 3. Fire prewarm tasks when intents are detected (non-blocking) + 4. Continue streaming without interruption + + Yields: + String chunks during streaming, ending with a CreateResult. 
+ """ + # Initialize performance metrics + metrics = PerformanceMetrics() + metrics.stream_start_time = time.perf_counter() + + # Reset sniffer and triggered intents for fresh session + self._sniffer.reset() + self._triggered_intents.clear() + + # Track reasoning state + is_in_reasoning = False + reasoning_buffer = "" + + # Get the inner stream + inner_stream = self._inner_client.create_stream( + messages, + tools=tools, + tool_choice=tool_choice, + json_output=json_output, + extra_create_args=extra_create_args, + cancellation_token=cancellation_token, + ) + + async for chunk in inner_stream: + # Track first token time + if metrics.first_token_time is None: + metrics.first_token_time = time.perf_counter() + + # If this is the final CreateResult, process it + if isinstance(chunk, CreateResult): + metrics.stream_end_time = time.perf_counter() + + # Log performance summary + self._last_metrics = metrics + trace_logger.info(f"SpeculativeClient metrics: {metrics.summary()}") + + # Check if there were tool calls - record timing + if chunk.content and isinstance(chunk.content, list): + metrics.tool_call_time = time.perf_counter() + + yield chunk + return + + # Process string chunks + chunk_str = str(chunk) + + # Detect reasoning block boundaries + if "" in chunk_str: + is_in_reasoning = True + metrics.reasoning_start_time = time.perf_counter() + + if "" in chunk_str: + is_in_reasoning = False + metrics.reasoning_end_time = time.perf_counter() + reasoning_buffer = "" + + # Run sniffer on content + # For distilled models without tags, sniff all content + # For full R1 models, only sniff inside blocks + should_sniff = self._enable_speculation and ( + self._sniff_all_content or is_in_reasoning + ) + + if should_sniff: + reasoning_buffer += chunk_str + + # Sniff for tool intents + intent = self._sniffer.sniff(chunk_str) + + if intent and intent.confidence >= self._min_confidence: + metrics.intents_detected += 1 + + if metrics.first_intent_detected_time is None: + metrics.first_intent_detected_time = time.perf_counter() + + # DEDUPLICATION: Only fire once per tool type per session + # Use tool_type as the dedup key (could also include query_hint for finer granularity) + dedup_key = intent.tool_type + + if dedup_key not in self._triggered_intents: + self._triggered_intents.add(dedup_key) + + # Log the prewarm event + event = SpeculativePrewarmEvent( + tool_type=intent.tool_type, + query_hint=intent.query_hint, + confidence=intent.confidence, + detected_at_ms=intent.detected_at_ms, + reasoning_context=intent.reasoning_context, + ) + logger.info(event) + + # SHARK MOVE: Fire and forget the prewarm task + if self._prewarm_callback is not None: + task = asyncio.create_task(self._trigger_prewarm(intent)) + self._prewarm_tasks.append(task) + metrics.prewarms_triggered += 1 + + yield chunk + + async def _trigger_prewarm(self, intent: ToolIntent) -> None: + """Trigger speculative tool prewarm in the background. + + This runs as a fire-and-forget task. If the speculation is wrong, + the result is simply not used. If it's right, the cache will have + the result ready. 
+ """ + if self._prewarm_callback is None: + return + + try: + trace_logger.debug( + f"Triggering prewarm for {intent.tool_type}: {intent.query_hint}" + ) + + context = { + "confidence": intent.confidence, + "reasoning_context": intent.reasoning_context, + } + + result = await self._prewarm_callback( + intent.tool_type, + intent.query_hint, + context, + ) + + # Store result in cache + if result is not None: + self._cache.store( + tool_name=intent.tool_type, + args={"query_hint": intent.query_hint}, + result=result, + ttl=30.0, # 30 second TTL + ) + trace_logger.info( + f"Prewarm complete for {intent.tool_type}, result cached" + ) + + except Exception as e: + # Speculation failure is silent - don't crash the main stream + trace_logger.warning(f"Prewarm failed for {intent.tool_type}: {e}") + + @property + def last_metrics(self) -> Optional[PerformanceMetrics]: + """Get the performance metrics from the last create_stream call.""" + return self._last_metrics + + @property + def cache(self) -> SpeculativeCache: + """Access the speculative cache for inspection or manual operations.""" + return self._cache diff --git a/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_reasoning_sniffer.py b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_reasoning_sniffer.py new file mode 100644 index 000000000000..67768f1d1ad5 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_reasoning_sniffer.py @@ -0,0 +1,172 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""High-speed regex-based Intent Sniffer for Speculative Reasoning Execution. + +This module provides zero-latency heuristic detection of tool-call intents +within streaming reasoning content from DeepSeek-R1 and similar models. +The sniffer runs on every chunk without blocking the stream. +""" + +import re +import time +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Pattern + + +@dataclass +class ToolIntent: + """Represents a detected tool-call intent from reasoning content.""" + + tool_type: str + """The type/name of the tool detected (e.g., 'web_search', 'database_query').""" + + query_hint: str + """Extracted query or argument hint from the reasoning.""" + + confidence: float + """Confidence score (0.0 to 1.0) based on pattern match strength.""" + + detected_at_ms: float + """Timestamp when the intent was detected (perf_counter * 1000).""" + + reasoning_context: str + """The chunk of reasoning text that triggered the detection.""" + + +@dataclass +class ReasoningSniffer: + """Detects tool-call intents in streaming reasoning content. + + Uses high-speed regex patterns to identify when the model is planning + to use a specific tool. Designed to run on every streaming chunk + without introducing latency. + + Example: + >>> sniffer = ReasoningSniffer() + >>> intent = sniffer.sniff("I will search for Python documentation") + >>> if intent: + ... 
print(f"Detected {intent.tool_type}: {intent.query_hint}") + Detected web_search: Python documentation + """ + + # High-signal patterns for DeepSeek-R1's "Action Intent" + # These are tuned for common reasoning patterns in R1-style models + # Enhanced to catch contractions (I'll, I'd) and more natural phrasings + PATTERNS: Dict[str, str] = field( + default_factory=lambda: { + # Web search intents - extended with contractions and more verbs + "web_search": r"(?:I(?:'ll| will| need to| should| must| have to)|Let me|I'd like to|need to)?\s*(?:search|look up|find|query|google|browse|look for|check online|research|find out|look into)\s+(?:for\s+|about\s+|the\s+|on\s+)?(?:information\s+(?:about|on)\s+)?['\"]?(.+?)(?:['\"]|\\.|,|;|$)", + # Database query intents + "database_query": r"(?:I(?:'ll| will| need to| should)|Let me)\s*(?:query|check|access|look up in|fetch from|retrieve from)\s+(?:the\s+)?(?:database|db|table|records?)\s+(?:for\s+)?['\"]?(.+?)(?:['\"]|\\.|,|$)", + # Calculation intents + "calculate": r"(?:I(?:'ll| will| need to| should)|Let me)\s*(?:calculate|compute|evaluate|work out|figure out|determine)\s+(.+?)(?:\\.|,|$)", + # API call intents + "api_call": r"(?:I(?:'ll| will| need to| should)|Let me)\s*(?:call|invoke|use|hit|query|fetch from)\s+(?:the\s+)?(?:API|endpoint|service|REST)\s+(?:for\s+)?['\"]?(.+?)(?:['\"]|\\.|,|$)", + # File/document lookup intents + "file_lookup": r"(?:I(?:'ll| will| need to| should)|Let me)\s*(?:read|open|check|look at|load|parse)\s+(?:the\s+)?(?:file|document|config|data)\s+['\"]?(.+?)(?:['\"]|\\.|,|$)", + # Generic tool usage + "tool_use": r"(?:I(?:'ll| will| need to| should)|Let me)\s*(?:use|invoke|call|execute)\s+(?:the\s+)?[`'\"]?(\w+)[`'\"]?\s+(?:tool|function)", + } + ) + + # Compiled patterns for performance + _compiled_patterns: Dict[str, Pattern[str]] = field(default_factory=dict) + + # Buffer for accumulating reasoning context + _context_buffer: List[str] = field(default_factory=list) + _max_buffer_size: int = 500 # Characters to keep for context + + def __post_init__(self) -> None: + """Compile regex patterns for maximum performance.""" + self._compiled_patterns = { + tool_type: re.compile(pattern, re.IGNORECASE | re.DOTALL) + for tool_type, pattern in self.PATTERNS.items() + } + + def add_pattern(self, tool_type: str, pattern: str) -> None: + """Add a custom pattern for detecting a specific tool type. + + Args: + tool_type: The name/type of the tool to detect. + pattern: Regex pattern with a capture group for the query hint. + """ + self.PATTERNS[tool_type] = pattern + self._compiled_patterns[tool_type] = re.compile( + pattern, re.IGNORECASE | re.DOTALL + ) + + def sniff(self, text: str) -> Optional[ToolIntent]: + """Analyze a chunk of reasoning text for tool-call intents. + + This method is designed to be called on every streaming chunk. + It maintains an internal buffer to provide context for pattern matching. + + Args: + text: A chunk of reasoning text from the model's thought stream. + + Returns: + ToolIntent if an intent was detected, None otherwise. 
+ """ + if not text or not text.strip(): + return None + + # Add to context buffer + self._context_buffer.append(text) + + # Keep buffer size manageable + full_context = "".join(self._context_buffer) + if len(full_context) > self._max_buffer_size: + # Trim from the beginning + excess = len(full_context) - self._max_buffer_size + self._context_buffer = [full_context[excess:]] + + # Scan the current chunk and recent context + search_text = full_context[-self._max_buffer_size :] + + for tool_type, compiled_pattern in self._compiled_patterns.items(): + match = compiled_pattern.search(search_text) + if match: + query_hint = match.group(1).strip() if match.lastindex else "" + # Calculate confidence based on match quality + confidence = self._calculate_confidence(match, search_text) + + return ToolIntent( + tool_type=tool_type, + query_hint=query_hint, + confidence=confidence, + detected_at_ms=time.perf_counter() * 1000, + reasoning_context=search_text[-200:], # Last 200 chars for context + ) + + return None + + def _calculate_confidence(self, match: re.Match[str], text: str) -> float: + """Calculate confidence score based on match characteristics. + + Args: + match: The regex match object. + text: The full text being searched. + + Returns: + Confidence score between 0.0 and 1.0. + """ + confidence = 0.7 # Base confidence for any match + + # Boost if match is recent (in the last 100 chars) + if match.end() > len(text) - 100: + confidence += 0.15 + + # Boost if query hint is substantial + if match.lastindex and len(match.group(1).strip()) > 5: + confidence += 0.1 + + # Boost if explicit tool mention + if "tool" in text.lower() or "function" in text.lower(): + confidence += 0.05 + + return min(confidence, 1.0) + + def reset(self) -> None: + """Reset the context buffer. Call this between separate reasoning sessions.""" + self._context_buffer.clear() diff --git a/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_speculative_cache.py b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_speculative_cache.py new file mode 100644 index 000000000000..0694f5c33c5b --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/models/nvidia/_speculative_cache.py @@ -0,0 +1,263 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Thread-safe Speculative Cache for pre-executed tool results. + +This module provides a singleton cache that stores results from speculatively +executed tools. When the model formally requests a tool, the cache is checked +first to provide near-instantaneous results. +""" + +import asyncio +import hashlib +import json +import logging +import time +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class CacheEntry: + """A single cached result with metadata.""" + + tool_name: str + args_hash: str + result: Any + created_at: float + ttl: float + hit_count: int = 0 + + def is_expired(self) -> bool: + """Check if this entry has expired.""" + return time.time() - self.created_at > self.ttl + + +class SpeculativeCache: + """Thread-safe singleton cache for speculative tool execution results. + + This cache stores pre-executed tool results that were triggered by the + ReasoningSniffer during the model's reasoning phase. When a tool is + formally called, the cache is checked first. 
+ + The cache uses argument hashing to match pre-warmed results with actual + tool calls, allowing for partial matches when full arguments aren't known + during speculation. + + Example: + >>> cache = SpeculativeCache.get_instance() + >>> cache.store("web_search", {"query": "python docs"}, "search results") + >>> result = cache.get("web_search", {"query": "python docs"}) + >>> print(result) + search results + """ + + _instance: Optional["SpeculativeCache"] = None + _lock: asyncio.Lock = asyncio.Lock() + + def __new__(cls) -> "SpeculativeCache": + """Ensure singleton pattern.""" + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._cache: Dict[str, CacheEntry] = {} + cls._instance._stats = { + "hits": 0, + "misses": 0, + "stores": 0, + "evictions": 0, + } + return cls._instance + + @classmethod + def get_instance(cls) -> "SpeculativeCache": + """Get the singleton cache instance.""" + return cls() + + @classmethod + def reset_instance(cls) -> None: + """Reset the singleton instance. Mainly for testing.""" + cls._instance = None + + @staticmethod + def _hash_args(args: Dict[str, Any]) -> str: + """Create a consistent hash for tool arguments. + + Args: + args: Dictionary of tool arguments. + + Returns: + A hex digest hash of the arguments. + """ + # Sort keys for consistent ordering + serialized = json.dumps(args, sort_keys=True, default=str) + return hashlib.md5(serialized.encode()).hexdigest() + + def store( + self, + tool_name: str, + args: Dict[str, Any], + result: Any, + ttl: float = 30.0, + ) -> None: + """Store a speculatively executed tool result. + + Args: + tool_name: Name of the tool that was executed. + args: Arguments used for execution. + result: The result of the tool execution. + ttl: Time-to-live in seconds (default: 30s). + """ + args_hash = self._hash_args(args) + cache_key = f"{tool_name}:{args_hash}" + + entry = CacheEntry( + tool_name=tool_name, + args_hash=args_hash, + result=result, + created_at=time.time(), + ttl=ttl, + ) + + self._cache[cache_key] = entry + self._stats["stores"] += 1 + + logger.debug( + f"SpeculativeCache: Stored {tool_name} (hash={args_hash[:8]}..., ttl={ttl}s)" + ) + + def get(self, tool_name: str, args: Dict[str, Any]) -> Optional[Any]: + """Retrieve a cached result if available and not expired. + + Args: + tool_name: Name of the tool being called. + args: Arguments for the tool call. + + Returns: + The cached result if found and valid, None otherwise. + """ + args_hash = self._hash_args(args) + cache_key = f"{tool_name}:{args_hash}" + + entry = self._cache.get(cache_key) + + if entry is None: + self._stats["misses"] += 1 + return None + + if entry.is_expired(): + # Clean up expired entry + del self._cache[cache_key] + self._stats["evictions"] += 1 + self._stats["misses"] += 1 + logger.debug(f"SpeculativeCache: Expired entry for {tool_name}") + return None + + # Cache hit! + entry.hit_count += 1 + self._stats["hits"] += 1 + logger.info( + f"SpeculativeCache: HIT for {tool_name} (latency saved, hit #{entry.hit_count})" + ) + return entry.result + + def get_fuzzy(self, tool_name: str, query_hint: str) -> Optional[Any]: + """Attempt fuzzy matching for speculative results. + + This is used when the exact arguments aren't known but we have a + query hint from the ReasoningSniffer. + + Args: + tool_name: Name of the tool. + query_hint: Partial query string detected during reasoning. + + Returns: + The cached result if a fuzzy match is found, None otherwise. 
+ """ + query_hint_lower = query_hint.lower() + + for cache_key, entry in self._cache.items(): + if not cache_key.startswith(f"{tool_name}:"): + continue + + if entry.is_expired(): + continue + + # Check if the query hint might match this entry + # This is a heuristic - the args might contain the hint + if hasattr(entry, "result") and isinstance(entry.result, str): + if query_hint_lower in entry.result.lower(): + entry.hit_count += 1 + self._stats["hits"] += 1 + logger.info( + f"SpeculativeCache: FUZZY HIT for {tool_name} " + f"(hint='{query_hint[:30]}...')" + ) + return entry.result + + self._stats["misses"] += 1 + return None + + def invalidate(self, tool_name: str) -> int: + """Invalidate all cached entries for a specific tool. + + Args: + tool_name: Name of the tool to invalidate. + + Returns: + Number of entries invalidated. + """ + keys_to_remove = [key for key in self._cache if key.startswith(f"{tool_name}:")] + + for key in keys_to_remove: + del self._cache[key] + + self._stats["evictions"] += len(keys_to_remove) + return len(keys_to_remove) + + def clear(self) -> None: + """Clear all cached entries.""" + count = len(self._cache) + self._cache.clear() + self._stats["evictions"] += count + logger.info(f"SpeculativeCache: Cleared {count} entries") + + def cleanup_expired(self) -> int: + """Remove all expired entries from the cache. + + Returns: + Number of entries removed. + """ + expired_keys = [key for key, entry in self._cache.items() if entry.is_expired()] + + for key in expired_keys: + del self._cache[key] + + self._stats["evictions"] += len(expired_keys) + return len(expired_keys) + + @property + def stats(self) -> Dict[str, int]: + """Get cache statistics. + + Returns: + Dictionary with hits, misses, stores, and evictions counts. + """ + return dict(self._stats) + + @property + def hit_rate(self) -> float: + """Calculate the cache hit rate. + + Returns: + Hit rate as a float between 0.0 and 1.0. 
+ """ + total = self._stats["hits"] + self._stats["misses"] + if total == 0: + return 0.0 + return self._stats["hits"] / total + + def __len__(self) -> int: + """Return the number of cached entries.""" + return len(self._cache) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py index a80e912534ab..e70e2c432425 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py @@ -5,6 +5,7 @@ import math import os import re +import time import warnings from asyncio import Task from dataclasses import dataclass @@ -93,9 +94,9 @@ openai_init_kwargs = set(inspect.getfullargspec(AsyncOpenAI.__init__).kwonlyargs) aopenai_init_kwargs = set(inspect.getfullargspec(AsyncAzureOpenAI.__init__).kwonlyargs) -create_kwargs = set(completion_create_params.CompletionCreateParamsBase.__annotations__.keys()) | set( - ("timeout", "stream", "extra_body") -) +create_kwargs = set( + completion_create_params.CompletionCreateParamsBase.__annotations__.keys() +) | set(("timeout", "stream", "extra_body")) # Only single choice allowed disallowed_create_args = set(["stream", "messages", "function_call", "functions", "n"]) required_create_args: Set[str] = set(["model"]) @@ -138,9 +139,13 @@ def _create_args_from_config(config: Mapping[str, Any]) -> Dict[str, Any]: create_args = {k: v for k, v in config.items() if k in create_kwargs} create_args_keys = set(create_args.keys()) if not required_create_args.issubset(create_args_keys): - raise ValueError(f"Required create args are missing: {required_create_args - create_args_keys}") + raise ValueError( + f"Required create args are missing: {required_create_args - create_args_keys}" + ) if disallowed_create_args.intersection(create_args_keys): - raise ValueError(f"Disallowed create args are present: {disallowed_create_args.intersection(create_args_keys)}") + raise ValueError( + f"Disallowed create args are present: {disallowed_create_args.intersection(create_args_keys)}" + ) return create_args @@ -175,12 +180,14 @@ def to_oai_type( } transformers = get_transformer("openai", model, model_family) - def raise_value_error(message: LLMMessage, context: Dict[str, Any]) -> Sequence[ChatCompletionMessageParam]: + def raise_value_error( + message: LLMMessage, context: Dict[str, Any] + ) -> Sequence[ChatCompletionMessageParam]: raise ValueError(f"Unknown message type: {type(message)}") - transformer: Callable[[LLMMessage, Dict[str, Any]], Sequence[ChatCompletionMessageParam]] = transformers.get( - type(message), raise_value_error - ) + transformer: Callable[ + [LLMMessage, Dict[str, Any]], Sequence[ChatCompletionMessageParam] + ] = transformers.get(type(message), raise_value_error) result = transformer(message, context) return result @@ -257,11 +264,19 @@ def convert_tools( type="function", function=FunctionDefinition( name=tool_schema["name"], - description=(tool_schema["description"] if "description" in tool_schema else ""), + description=( + tool_schema["description"] + if "description" in tool_schema + else "" + ), parameters=( - cast(FunctionParameters, tool_schema["parameters"]) if "parameters" in tool_schema else {} + cast(FunctionParameters, tool_schema["parameters"]) + if "parameters" in tool_schema + else {} + ), + strict=( + tool_schema["strict"] if "strict" in tool_schema else False ), - strict=(tool_schema["strict"] if "strict" in tool_schema else False), ), 
) ) @@ -293,7 +308,9 @@ def convert_tool_choice(tool_choice: Tool | Literal["auto", "required", "none"]) if isinstance(tool_choice, Tool): return {"type": "function", "function": {"name": tool_choice.schema["name"]}} else: - raise ValueError(f"tool_choice must be a Tool object, 'auto', 'required', or 'none', got {type(tool_choice)}") + raise ValueError( + f"tool_choice must be a Tool object, 'auto', 'required', or 'none', got {type(tool_choice)}" + ) def normalize_name(name: str) -> str: @@ -339,14 +356,18 @@ def count_tokens_openai( continue if isinstance(message, UserMessage) and isinstance(value, list): - typed_message_value = cast(List[ChatCompletionContentPartParam], value) + typed_message_value = cast( + List[ChatCompletionContentPartParam], value + ) assert len(typed_message_value) == len( message.content ), "Mismatch in message content and typed message value" # We need image properties that are only in the original message - for part, content_part in zip(typed_message_value, message.content, strict=False): + for part, content_part in zip( + typed_message_value, message.content, strict=False + ): if isinstance(content_part, Image): # TODO: add detail parameter num_tokens += calculate_vision_tokens(content_part) @@ -357,13 +378,17 @@ def count_tokens_openai( serialized_part = json.dumps(part) num_tokens += len(encoding.encode(serialized_part)) except TypeError: - trace_logger.warning(f"Could not convert {part} to string, skipping.") + trace_logger.warning( + f"Could not convert {part} to string, skipping." + ) else: if not isinstance(value, str): try: value = json.dumps(value) except TypeError: - trace_logger.warning(f"Could not convert {value} to string, skipping.") + trace_logger.warning( + f"Could not convert {value} to string, skipping." + ) continue num_tokens += len(encoding.encode(value)) if key == "name": @@ -389,26 +414,38 @@ def count_tokens_openai( for field in v: # pyright: ignore if field == "type": tool_tokens += 2 - tool_tokens += len(encoding.encode(v["type"])) # pyright: ignore + tool_tokens += len( + encoding.encode(v["type"]) + ) # pyright: ignore elif field == "description": tool_tokens += 2 - tool_tokens += len(encoding.encode(v["description"])) # pyright: ignore + tool_tokens += len( + encoding.encode(v["description"]) + ) # pyright: ignore elif field == "anyOf": tool_tokens -= 3 for o in v["anyOf"]: # type: ignore tool_tokens += 3 - tool_tokens += len(encoding.encode(str(o["type"]))) # pyright: ignore + tool_tokens += len( + encoding.encode(str(o["type"])) + ) # pyright: ignore elif field == "default": tool_tokens += 2 - tool_tokens += len(encoding.encode(json.dumps(v["default"]))) + tool_tokens += len( + encoding.encode(json.dumps(v["default"])) + ) elif field == "title": tool_tokens += 2 - tool_tokens += len(encoding.encode(str(v["title"]))) # pyright: ignore + tool_tokens += len( + encoding.encode(str(v["title"])) + ) # pyright: ignore elif field == "enum": tool_tokens -= 3 for o in v["enum"]: # pyright: ignore tool_tokens += 3 - tool_tokens += len(encoding.encode(o)) # pyright: ignore + tool_tokens += len( + encoding.encode(o) + ) # pyright: ignore else: trace_logger.warning(f"Not supported field {field}") tool_tokens += 11 @@ -447,7 +484,9 @@ def __init__( try: self._model_info = _model_info.get_info(create_args["model"]) except KeyError as err: - raise ValueError("model_info is required when model name is not a valid OpenAI model") from err + raise ValueError( + "model_info is required when model name is not a valid OpenAI model" + ) from err elif 
model_capabilities is not None and model_info is not None: raise ValueError("model_capabilities and model_info are mutually exclusive") elif model_capabilities is not None and model_info is None: @@ -487,7 +526,9 @@ def __init__( def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient: return OpenAIChatCompletionClient(**config) - def _rstrip_last_assistant_message(self, messages: Sequence[LLMMessage]) -> Sequence[LLMMessage]: + def _rstrip_last_assistant_message( + self, messages: Sequence[LLMMessage] + ) -> Sequence[LLMMessage]: """ Remove the last assistant message if it is empty. """ @@ -509,7 +550,9 @@ def _process_create_args( # Make sure all extra_create_args are valid extra_create_args_keys = set(extra_create_args.keys()) if not create_kwargs.issuperset(extra_create_args_keys): - raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}") + raise ValueError( + f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}" + ) # Copy the create args and overwrite anything in extra_create_args create_args = self._create_args.copy() @@ -541,7 +584,9 @@ def _process_create_args( raise ValueError("Model does not support JSON output.") if json_output is True: # JSON mode. - create_args["response_format"] = ResponseFormatJSONObject(type="json_object") + create_args["response_format"] = ResponseFormatJSONObject( + type="json_object" + ) elif json_output is False: # Text mode. create_args["response_format"] = ResponseFormatText(type="text") @@ -555,7 +600,9 @@ def _process_create_args( # Beta client mode with Pydantic model class. response_format_value = json_output else: - raise ValueError(f"json_output must be a boolean or a Pydantic model class, got {type(json_output)}") + raise ValueError( + f"json_output must be a boolean or a Pydantic model class, got {type(json_output)}" + ) if response_format_value is not None and "response_format" in create_args: warnings.warn( @@ -573,8 +620,12 @@ def _process_create_args( if self.model_info["vision"] is False: for message in messages: if isinstance(message, UserMessage): - if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content): - raise ValueError("Model does not support vision and image was provided") + if isinstance(message.content, list) and any( + isinstance(x, Image) for x in message.content + ): + raise ValueError( + "Model does not support vision and image was provided" + ) if self.model_info["json_output"] is False and json_output is True: raise ValueError("Model does not support JSON output.") @@ -646,7 +697,9 @@ def _process_create_args( # tool_choice is a single Tool object tool_name = tool_choice.schema["name"] if tool_name not in tool_names_available: - raise ValueError(f"tool_choice references '{tool_name}' but it's not in the provided tools") + raise ValueError( + f"tool_choice references '{tool_name}' but it's not in the provided tools" + ) if len(converted_tools) > 0: # Convert to OpenAI format and add to create_args @@ -678,12 +731,17 @@ async def create( extra_create_args, ) future: Union[Task[ParsedChatCompletion[BaseModel]], Task[ChatCompletion]] + start_time = time.perf_counter() if create_params.response_format is not None: # Use beta client if response_format is not None future = asyncio.ensure_future( self._client.beta.chat.completions.parse( messages=create_params.messages, - tools=(create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN), + tools=( + create_params.tools + if len(create_params.tools) > 0 + else 
NOT_GIVEN + ), response_format=create_params.response_format, **create_params.create_args, ) @@ -694,7 +752,11 @@ async def create( self._client.chat.completions.create( messages=create_params.messages, stream=False, - tools=(create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN), + tools=( + create_params.tools + if len(create_params.tools) > 0 + else NOT_GIVEN + ), **create_params.create_args, ) ) @@ -702,6 +764,7 @@ async def create( if cancellation_token is not None: cancellation_token.link_future(future) result: Union[ParsedChatCompletion[BaseModel], ChatCompletion] = await future + end_time = time.perf_counter() if create_params.response_format is not None: result = cast(ParsedChatCompletion[Any], result) @@ -709,8 +772,25 @@ async def create( # even when result.usage is not None usage = RequestUsage( # TODO backup token counting - prompt_tokens=getattr(result.usage, "prompt_tokens", 0) if result.usage is not None else 0, - completion_tokens=getattr(result.usage, "completion_tokens", 0) if result.usage is not None else 0, + prompt_tokens=( + getattr(result.usage, "prompt_tokens", 0) + if result.usage is not None + else 0 + ), + completion_tokens=( + getattr(result.usage, "completion_tokens", 0) + if result.usage is not None + else 0 + ), + ) + + # Calculate performance metrics + latency_seconds = end_time - start_time + latency_ms = latency_seconds * 1000 + tokens_per_second = ( + usage.completion_tokens / latency_seconds + if latency_seconds > 0 and usage.completion_tokens > 0 + else None ) logger.info( @@ -719,6 +799,8 @@ async def create( response=result.model_dump(), prompt_tokens=usage.prompt_tokens, completion_tokens=usage.completion_tokens, + latency_ms=latency_ms, + tokens_per_second=tokens_per_second, tools=create_params.tools, ) ) @@ -733,15 +815,21 @@ async def create( ) # Limited to a single choice currently. - choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], Choice] = result.choices[0] + choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], Choice] = ( + result.choices[0] + ) # Detect whether it is a function call or not. # We don't rely on choice.finish_reason as it is not always accurate, depending on the API used. content: Union[str, List[FunctionCall]] thought: str | None = None if choice.message.function_call is not None: - raise ValueError("function_call is deprecated and is not supported by this model client.") - elif choice.message.tool_calls is not None and len(choice.message.tool_calls) > 0: + raise ValueError( + "function_call is deprecated and is not supported by this model client." + ) + elif ( + choice.message.tool_calls is not None and len(choice.message.tool_calls) > 0 + ): if choice.finish_reason != "tool_calls": warnings.warn( f"Finish reason mismatch: {choice.finish_reason} != tool_calls " @@ -763,7 +851,9 @@ async def create( stacklevel=2, ) if isinstance(tool_call.function.arguments, dict): - tool_call.function.arguments = json.dumps(tool_call.function.arguments) + tool_call.function.arguments = json.dumps( + tool_call.function.arguments + ) content.append( FunctionCall( id=tool_call.id, @@ -788,14 +878,21 @@ async def create( ChatCompletionTokenLogprob( token=x.token, logprob=x.logprob, - top_logprobs=[TopLogprob(logprob=y.logprob, bytes=y.bytes) for y in x.top_logprobs], + top_logprobs=[ + TopLogprob(logprob=y.logprob, bytes=y.bytes) + for y in x.top_logprobs + ], bytes=x.bytes, ) for x in choice.logprobs.content ] # This is for local R1 models. 
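# The create() hunk above derives its new metrics from two time.perf_counter() samples taken
# around the awaited completion. A standalone restatement of that arithmetic (illustrative
# sketch only; derive_call_metrics is not part of the patch, and completion_tokens stands in
# for usage.completion_tokens):
from typing import Optional, Tuple

def derive_call_metrics(start: float, end: float, completion_tokens: int) -> Tuple[float, Optional[float]]:
    """Return (latency_ms, tokens_per_second), guarding zero-length calls and empty completions."""
    latency_seconds = end - start
    latency_ms = latency_seconds * 1000
    tokens_per_second = (
        completion_tokens / latency_seconds
        if latency_seconds > 0 and completion_tokens > 0
        else None  # the field is simply omitted from the logged event when it cannot be computed
    )
    return latency_ms, tokens_per_second

# e.g. derive_call_metrics(0.0, 0.1505, 20) -> (150.5, ~132.9 tokens/sec)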
- if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1 and thought is None: + if ( + isinstance(content, str) + and self._model_info["family"] == ModelFamily.R1 + and thought is None + ): thought, content = parse_r1_content(content) response = CreateResult( @@ -858,7 +955,10 @@ async def create_stream( if include_usage is not None: if "stream_options" in create_params.create_args: stream_options = create_params.create_args["stream_options"] - if "include_usage" in stream_options and stream_options["include_usage"] != include_usage: + if ( + "include_usage" in stream_options + and stream_options["include_usage"] != include_usage + ): raise ValueError( "include_usage and extra_create_args['stream_options']['include_usage'] are both set, but differ in value." ) @@ -904,6 +1004,10 @@ async def create_stream( first_chunk = True is_reasoning = False + # Performance timing variables + start_time = time.perf_counter() + first_token_time: Optional[float] = None + # Process the stream of chunks. async for chunk in chunks: if first_chunk: @@ -922,7 +1026,10 @@ async def create_stream( # https://github.com/microsoft/autogen/issues/4213 if len(chunk.choices) == 0: empty_chunk_count += 1 - if not empty_chunk_warning_has_been_issued and empty_chunk_count >= empty_chunk_warning_threshold: + if ( + not empty_chunk_warning_has_been_issued + and empty_chunk_count >= empty_chunk_warning_threshold + ): empty_chunk_warning_has_been_issued = True warnings.warn( f"Received more than {empty_chunk_warning_threshold} consecutive empty chunks. Empty chunks are being ignored.", @@ -945,11 +1052,18 @@ async def create_stream( # for liteLLM chunk usage, do the following hack keeping the pervious chunk.stop_reason (if set). # set the stop_reason for the usage chunk to the prior stop_reason - stop_reason = choice.finish_reason if chunk.usage is None and stop_reason is None else stop_reason + stop_reason = ( + choice.finish_reason + if chunk.usage is None and stop_reason is None + else stop_reason + ) maybe_model = chunk.model reasoning_content: str | None = None - if choice.delta.model_extra is not None and "reasoning_content" in choice.delta.model_extra: + if ( + choice.delta.model_extra is not None + and "reasoning_content" in choice.delta.model_extra + ): # If there is a reasoning_content field, then we populate the thought field. This is for models such as R1. 
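# In the streaming path the clock is started once, before the chunk loop, and first_token_time
# is stamped by the hunks that follow, on the first delta that actually carries content or a
# tool call. The same bookkeeping in isolation, over an arbitrary async stream of text chunks
# (illustrative sketch only; time_stream is not part of the patch):
import time
from typing import AsyncIterator, Optional, Tuple

async def time_stream(chunks: AsyncIterator[str]) -> Tuple[float, Optional[float]]:
    """Return (latency_ms, ttft_ms) for an async stream of text chunks."""
    start = time.perf_counter()
    first_token: Optional[float] = None
    async for chunk in chunks:
        if chunk and first_token is None:
            first_token = time.perf_counter()  # time to first token
        # ... a real client would yield the chunk to its caller here ...
    end = time.perf_counter()
    latency_ms = (end - start) * 1000
    ttft_ms = (first_token - start) * 1000 if first_token is not None else None
    return latency_ms, ttft_ms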
reasoning_content = choice.delta.model_extra.get("reasoning_content") @@ -969,6 +1083,8 @@ async def create_stream( # First try get content if choice.delta.content: + if first_token_time is None: + first_token_time = time.perf_counter() content_deltas.append(choice.delta.content) if len(choice.delta.content) > 0: yield choice.delta.content @@ -977,11 +1093,15 @@ async def create_stream( continue # Otherwise, get tool calls if choice.delta.tool_calls is not None: + if first_token_time is None: + first_token_time = time.perf_counter() for tool_call_chunk in choice.delta.tool_calls: idx = tool_call_chunk.index if idx not in full_tool_calls: # We ignore the type hint here because we want to fill in type when the delta provides it - full_tool_calls[idx] = FunctionCall(id="", arguments="", name="") + full_tool_calls[idx] = FunctionCall( + id="", arguments="", name="" + ) if tool_call_chunk.id is not None: full_tool_calls[idx].id += tool_call_chunk.id @@ -990,13 +1110,18 @@ async def create_stream( if tool_call_chunk.function.name is not None: full_tool_calls[idx].name += tool_call_chunk.function.name if tool_call_chunk.function.arguments is not None: - full_tool_calls[idx].arguments += tool_call_chunk.function.arguments + full_tool_calls[ + idx + ].arguments += tool_call_chunk.function.arguments if choice.logprobs and choice.logprobs.content: logprobs = [ ChatCompletionTokenLogprob( token=x.token, logprob=x.logprob, - top_logprobs=[TopLogprob(logprob=y.logprob, bytes=y.bytes) for y in x.top_logprobs], + top_logprobs=[ + TopLogprob(logprob=y.logprob, bytes=y.bytes) + for y in x.top_logprobs + ], bytes=x.bytes, ) for x in choice.logprobs.content @@ -1050,7 +1175,11 @@ async def create_stream( thought = "".join(thought_deltas).lstrip("").rstrip("") # This is for local R1 models whose reasoning content is within the content string. - if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1 and thought is None: + if ( + isinstance(content, str) + and self._model_info["family"] == ModelFamily.R1 + and thought is None + ): thought, content = parse_r1_content(content) # Create the result. @@ -1063,12 +1192,30 @@ async def create_stream( thought=thought, ) + # Calculate performance metrics + end_time = time.perf_counter() + latency_seconds = end_time - start_time + latency_ms = latency_seconds * 1000 + ttft_ms = ( + (first_token_time - start_time) * 1000 + if first_token_time is not None + else None + ) + tokens_per_second = ( + usage.completion_tokens / latency_seconds + if latency_seconds > 0 and usage.completion_tokens > 0 + else None + ) + # Log the end of the stream. 
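# The new latency_ms / tokens_per_second / ttft_ms fields travel on LLMCallEvent and
# LLMStreamEndEvent, so downstream code can observe them with an ordinary logging handler on
# autogen-core's event logger. Minimal sketch (LLMTimingHandler is illustrative; it assumes the
# event object arrives as record.msg, which is how logger.info(event) is invoked above):
import logging
from autogen_core import EVENT_LOGGER_NAME
from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent

class LLMTimingHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
        event = record.msg
        if isinstance(event, (LLMCallEvent, LLMStreamEndEvent)) and event.latency_ms is not None:
            ttft = getattr(event, "ttft_ms", None)  # only present on stream-end events
            print(
                f"LLM call took {event.latency_ms:.1f} ms"
                + (f", {event.tokens_per_second:.1f} tok/s" if event.tokens_per_second else "")
                + (f", TTFT {ttft:.1f} ms" if ttft is not None else "")
            )

event_logger = logging.getLogger(EVENT_LOGGER_NAME)
event_logger.setLevel(logging.INFO)
event_logger.addHandler(LLMTimingHandler())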
logger.info( LLMStreamEndEvent( response=result.model_dump(), prompt_tokens=usage.prompt_tokens, completion_tokens=usage.completion_tokens, + latency_ms=latency_ms, + tokens_per_second=tokens_per_second, + ttft_ms=ttft_ms, ) ) @@ -1118,7 +1265,9 @@ async def _create_stream_chunks_beta_client( async with self._client.beta.chat.completions.stream( messages=oai_messages, tools=tool_params if len(tool_params) > 0 else NOT_GIVEN, - response_format=(response_format if response_format is not None else NOT_GIVEN), + response_format=( + response_format if response_format is not None else NOT_GIVEN + ), **create_args_no_response_format, ) as stream: while True: @@ -1148,7 +1297,9 @@ def actual_usage(self) -> RequestUsage: def total_usage(self) -> RequestUsage: return self._total_usage - def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: + def count_tokens( + self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = [] + ) -> int: return count_tokens_openai( messages, self._create_args["model"], @@ -1158,7 +1309,9 @@ def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | include_name_in_message=self._include_name_in_message, ) - def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: + def remaining_tokens( + self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = [] + ) -> int: token_limit = _model_info.get_token_limit(self._create_args["model"]) return token_limit - self.count_tokens(messages, tools=tools) @@ -1176,7 +1329,9 @@ def model_info(self) -> ModelInfo: return self._model_info -class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenAIClientConfigurationConfigModel]): +class OpenAIChatCompletionClient( + BaseOpenAIChatCompletionClient, Component[OpenAIClientConfigurationConfigModel] +): """Chat completion client for OpenAI hosted models. 
To use this client, you must install the `openai` extra: @@ -1675,7 +1830,9 @@ class AzureOpenAIChatCompletionClient( component_type = "model" component_config_schema = AzureOpenAIClientConfigurationConfigModel - component_provider_override = "autogen_ext.models.openai.AzureOpenAIChatCompletionClient" + component_provider_override = ( + "autogen_ext.models.openai.AzureOpenAIChatCompletionClient" + ) def __init__(self, **kwargs: Unpack[AzureOpenAIClientConfiguration]): model_capabilities: Optional[ModelCapabilities] = None # type: ignore @@ -1723,11 +1880,17 @@ def _to_config(self) -> AzureOpenAIClientConfigurationConfigModel: copied_config = self._raw_config.copy() if "azure_ad_token_provider" in copied_config: - if not isinstance(copied_config["azure_ad_token_provider"], AzureTokenProvider): - raise ValueError("azure_ad_token_provider must be a AzureTokenProvider to be component serialized") + if not isinstance( + copied_config["azure_ad_token_provider"], AzureTokenProvider + ): + raise ValueError( + "azure_ad_token_provider must be a AzureTokenProvider to be component serialized" + ) copied_config["azure_ad_token_provider"] = ( - copied_config["azure_ad_token_provider"].dump_component().model_dump(exclude_none=True) + copied_config["azure_ad_token_provider"] + .dump_component() + .model_dump(exclude_none=True) ) return AzureOpenAIClientConfigurationConfigModel(**copied_config) @@ -1743,8 +1906,10 @@ def _from_config(cls, config: AzureOpenAIClientConfigurationConfigModel) -> Self copied_config["api_key"] = config.api_key.get_secret_value() if "azure_ad_token_provider" in copied_config: - copied_config["azure_ad_token_provider"] = AzureTokenProvider.load_component( - copied_config["azure_ad_token_provider"] + copied_config["azure_ad_token_provider"] = ( + AzureTokenProvider.load_component( + copied_config["azure_ad_token_provider"] + ) ) return cls(**copied_config) diff --git a/python/packages/autogen-ext/tests/models/test_nvidia_speculative_client.py b/python/packages/autogen-ext/tests/models/test_nvidia_speculative_client.py new file mode 100644 index 000000000000..81b21074c2b2 --- /dev/null +++ b/python/packages/autogen-ext/tests/models/test_nvidia_speculative_client.py @@ -0,0 +1,501 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Unit tests for NVIDIA Speculative Reasoning Execution components. + +Tests the ReasoningSniffer, SpeculativeCache, and NvidiaSpeculativeClient +for detecting tool-call intents in LLM reasoning streams and managing +speculative pre-warming of tools. 
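
The speculative components themselves are not implemented in this diff; from the
tests below they are assumed to expose roughly the following surface (illustrative
sketch only, not the actual API):

    sniffer = ReasoningSniffer()                       # regex PATTERNS keyed by tool type
    intent = sniffer.sniff("I will search for X")      # -> Optional[ToolIntent]
    # a ToolIntent carries .tool_type, .query_hint and .confidence

    cache = SpeculativeCache.get_instance()            # process-wide singleton
    cache.store("web_search", {"query": "X"}, "out")   # keyed by tool type + hashed args
    cache.get("web_search", {"query": "X"})            # -> "out", or None on a miss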
+""" + +import asyncio +import time +from typing import Any, AsyncGenerator, Dict, List, Optional, Union +from unittest.mock import AsyncMock, MagicMock + +import pytest +from autogen_core import CancellationToken +from autogen_core.models import ( + CreateResult, + RequestUsage, + UserMessage, +) + +from autogen_ext.models.nvidia import ( + NvidiaSpeculativeClient, + ReasoningSniffer, + SpeculativeCache, + ToolIntent, +) + + +class TestReasoningSniffer: + """Tests for the ReasoningSniffer component.""" + + def test_sniffer_initialization(self) -> None: + """Test that sniffer initializes with default patterns.""" + sniffer = ReasoningSniffer() + assert sniffer is not None + assert len(sniffer.PATTERNS) > 0 + assert "web_search" in sniffer.PATTERNS + + def test_sniffer_detects_web_search_intent(self) -> None: + """Test detection of web search intent patterns.""" + sniffer = ReasoningSniffer() + + intent = sniffer.sniff("I will search for Python documentation") + + assert intent is not None + assert intent.tool_type == "web_search" + assert "Python documentation" in intent.query_hint + assert intent.confidence > 0.5 + + def test_sniffer_detects_database_query_intent(self) -> None: + """Test detection of database query intent patterns.""" + sniffer = ReasoningSniffer() + + intent = sniffer.sniff("I need to check the database for user records") + + assert intent is not None + assert intent.tool_type == "database_query" + + def test_sniffer_detects_calculate_intent(self) -> None: + """Test detection of calculation intent patterns.""" + sniffer = ReasoningSniffer() + + intent = sniffer.sniff("I will calculate the total revenue") + + assert intent is not None + assert intent.tool_type == "calculate" + + def test_sniffer_returns_none_for_non_matching_text(self) -> None: + """Test that sniffer returns None when no intent is detected.""" + sniffer = ReasoningSniffer() + + intent = sniffer.sniff("The weather is nice today") + + assert intent is None + + def test_sniffer_accumulates_context(self) -> None: + """Test that sniffer accumulates context across chunks.""" + sniffer = ReasoningSniffer() + + # Send partial text across chunks + sniffer.sniff("I will ") + intent = sniffer.sniff("search for documentation") + + assert intent is not None + assert intent.tool_type == "web_search" + + def test_sniffer_reset_clears_context(self) -> None: + """Test that reset clears the accumulated context.""" + sniffer = ReasoningSniffer() + + sniffer.sniff("I will search for ") + sniffer.reset() + intent = sniffer.sniff("documentation") + + # Should not match after reset since "I will search for" was cleared + assert intent is None + + def test_sniffer_custom_pattern(self) -> None: + """Test adding and using custom patterns.""" + sniffer = ReasoningSniffer() + + sniffer.add_pattern( + "custom_tool", + r"(?:I will|Let me)\s+run\s+custom\s+analysis\s+on\s+(.+)" + ) + + intent = sniffer.sniff("I will run custom analysis on the dataset") + + assert intent is not None + assert intent.tool_type == "custom_tool" + + +class TestSpeculativeCache: + """Tests for the SpeculativeCache component.""" + + def setup_method(self) -> None: + """Reset cache singleton before each test.""" + SpeculativeCache.reset_instance() + + def test_cache_singleton_pattern(self) -> None: + """Test that cache follows singleton pattern.""" + cache1 = SpeculativeCache.get_instance() + cache2 = SpeculativeCache.get_instance() + + assert cache1 is cache2 + + def test_cache_reset_instance(self) -> None: + """Test that reset_instance creates a new 
singleton.""" + cache1 = SpeculativeCache.get_instance() + cache1.store("web_search", {"query": "test"}, "value") + + SpeculativeCache.reset_instance() + cache2 = SpeculativeCache.get_instance() + + assert cache1 is not cache2 + assert cache2.get("web_search", {"query": "test"}) is None + + def test_cache_store_and_retrieve(self) -> None: + """Test basic store and retrieve functionality.""" + cache = SpeculativeCache.get_instance() + + cache.store("web_search", {"query": "test"}, "test_value") + result = cache.get("web_search", {"query": "test"}) + + assert result == "test_value" + + def test_cache_returns_none_for_missing_key(self) -> None: + """Test that get returns None for missing keys.""" + cache = SpeculativeCache.get_instance() + + result = cache.get("nonexistent_tool", {"query": "missing"}) + + assert result is None + + def test_cache_tracks_hits_and_misses(self) -> None: + """Test that cache tracks hit and miss statistics.""" + cache = SpeculativeCache.get_instance() + + cache.store("web_search", {"query": "test"}, "value") + cache.get("web_search", {"query": "test"}) # Hit + cache.get("web_search", {"query": "other"}) # Miss + + assert cache.stats["hits"] == 1 + assert cache.stats["misses"] == 1 + assert cache.stats["stores"] == 1 + + def test_cache_argument_hashing(self) -> None: + """Test that cache generates consistent hashes from arguments.""" + cache = SpeculativeCache.get_instance() + + # Store two entries with same args - should overwrite + cache.store("web_search", {"query": "test"}, "result1") + cache.store("web_search", {"query": "test"}, "result2") # Overwrites + cache.store("web_search", {"query": "different"}, "result3") + + # Should get latest value for same args + assert cache.get("web_search", {"query": "test"}) == "result2" + assert cache.get("web_search", {"query": "different"}) == "result3" + + def test_cache_len(self) -> None: + """Test that len returns number of cached entries.""" + cache = SpeculativeCache.get_instance() + + assert len(cache) == 0 + + cache.store("tool1", {"arg": "a"}, "value1") + cache.store("tool2", {"arg": "b"}, "value2") + + assert len(cache) == 2 + + +class TestNvidiaSpeculativeClient: + """Tests for the NvidiaSpeculativeClient wrapper.""" + + @pytest.fixture + def mock_inner_client(self) -> MagicMock: + """Create a mock inner client for testing.""" + client = MagicMock() + client.model_info = { + "vision": False, + "function_calling": True, + "json_output": True, + "family": "r1", + } + client.capabilities = MagicMock() + client.capabilities.vision = False + client.capabilities.function_calling = True + client.capabilities.json_output = True + + async def mock_close() -> None: + pass + + client.close = AsyncMock(side_effect=mock_close) + return client + + def test_client_initialization(self, mock_inner_client: MagicMock) -> None: + """Test that client initializes correctly.""" + SpeculativeCache.reset_instance() + + client = NvidiaSpeculativeClient(inner_client=mock_inner_client) + + assert client is not None + assert client._inner_client is mock_inner_client + assert client._enable_speculation is True + + def test_client_initialization_with_custom_sniffer( + self, mock_inner_client: MagicMock + ) -> None: + """Test initialization with custom sniffer.""" + SpeculativeCache.reset_instance() + custom_sniffer = ReasoningSniffer() + + client = NvidiaSpeculativeClient( + inner_client=mock_inner_client, + sniffer=custom_sniffer, + ) + + assert client._sniffer is custom_sniffer + + def test_client_initialization_with_speculation_disabled( + 
self, mock_inner_client: MagicMock + ) -> None: + """Test that speculation can be disabled.""" + SpeculativeCache.reset_instance() + + client = NvidiaSpeculativeClient( + inner_client=mock_inner_client, + enable_speculation=False, + ) + + assert client._enable_speculation is False + + def test_client_exposes_cache(self, mock_inner_client: MagicMock) -> None: + """Test that cache is accessible via property.""" + SpeculativeCache.reset_instance() + + client = NvidiaSpeculativeClient(inner_client=mock_inner_client) + + assert client.cache is not None + assert isinstance(client.cache, SpeculativeCache) + + @pytest.mark.asyncio + async def test_client_create_delegates_to_inner( + self, mock_inner_client: MagicMock + ) -> None: + """Test that create() delegates to inner client.""" + SpeculativeCache.reset_instance() + + expected_result = CreateResult( + content="Test response", + usage=RequestUsage(prompt_tokens=10, completion_tokens=5), + finish_reason="stop", + cached=False, + ) + mock_inner_client.create = AsyncMock(return_value=expected_result) + + client = NvidiaSpeculativeClient(inner_client=mock_inner_client) + + result = await client.create( + messages=[UserMessage(content="Hello", source="user")] + ) + + assert result == expected_result + mock_inner_client.create.assert_called_once() + + @pytest.mark.asyncio + async def test_client_close_closes_inner_client( + self, mock_inner_client: MagicMock + ) -> None: + """Test that close() closes the inner client.""" + SpeculativeCache.reset_instance() + + client = NvidiaSpeculativeClient(inner_client=mock_inner_client) + + await client.close() + + mock_inner_client.close.assert_called_once() + + +class TestNvidiaSpeculativeClientStreaming: + """Tests for streaming functionality of NvidiaSpeculativeClient.""" + + @pytest.fixture + def mock_inner_client(self) -> MagicMock: + """Create a mock inner client with streaming support.""" + client = MagicMock() + client.model_info = { + "vision": False, + "function_calling": True, + "json_output": True, + "family": "r1", + } + + async def mock_close() -> None: + pass + + client.close = AsyncMock(side_effect=mock_close) + return client + + @pytest.mark.asyncio + async def test_stream_yields_chunks_and_final_result( + self, mock_inner_client: MagicMock + ) -> None: + """Test that create_stream yields chunks and final CreateResult.""" + SpeculativeCache.reset_instance() + + async def mock_stream( + messages: Any, **kwargs: Any + ) -> AsyncGenerator[Union[str, CreateResult], None]: + yield "Hello" + yield " World" + yield CreateResult( + content="Hello World", + usage=RequestUsage(prompt_tokens=10, completion_tokens=5), + finish_reason="stop", + cached=False, + ) + + mock_inner_client.create_stream = mock_stream + + client = NvidiaSpeculativeClient(inner_client=mock_inner_client) + + chunks: List[Any] = [] + async for chunk in client.create_stream( + messages=[UserMessage(content="Hi", source="user")] + ): + chunks.append(chunk) + + assert len(chunks) == 3 + assert chunks[0] == "Hello" + assert chunks[1] == " World" + assert isinstance(chunks[2], CreateResult) + + @pytest.mark.asyncio + async def test_stream_triggers_prewarm_on_intent_detection( + self, mock_inner_client: MagicMock + ) -> None: + """Test that prewarm callback is triggered when intent is detected.""" + SpeculativeCache.reset_instance() + + prewarm_called = False + prewarm_tool_type: Optional[str] = None + + async def prewarm_callback( + tool_type: str, query_hint: str, context: Dict[str, Any] + ) -> Optional[str]: + nonlocal prewarm_called, 
prewarm_tool_type + prewarm_called = True + prewarm_tool_type = tool_type + return "mock_result" + + async def mock_stream( + messages: Any, **kwargs: Any + ) -> AsyncGenerator[Union[str, CreateResult], None]: + # Simulate reasoning content with tool intent + yield "<think>" + yield "I will search for Python documentation" + yield "</think>" + yield CreateResult( + content="Here's the documentation", + usage=RequestUsage(prompt_tokens=10, completion_tokens=20), + finish_reason="stop", + cached=False, + ) + + mock_inner_client.create_stream = mock_stream + + client = NvidiaSpeculativeClient( + inner_client=mock_inner_client, + prewarm_callback=prewarm_callback, + min_confidence=0.5, + ) + + chunks: List[Any] = [] + async for chunk in client.create_stream( + messages=[UserMessage(content="Find docs", source="user")] + ): + chunks.append(chunk) + + # Allow background tasks to complete + await asyncio.sleep(0.2) + + assert prewarm_called is True + assert prewarm_tool_type == "web_search" + + @pytest.mark.asyncio + async def test_stream_respects_min_confidence( + self, mock_inner_client: MagicMock + ) -> None: + """Test that prewarm is only triggered above min_confidence.""" + SpeculativeCache.reset_instance() + + prewarm_called = False + + async def prewarm_callback( + tool_type: str, query_hint: str, context: Dict[str, Any] + ) -> Optional[str]: + nonlocal prewarm_called + prewarm_called = True + return None + + async def mock_stream( + messages: Any, **kwargs: Any + ) -> AsyncGenerator[Union[str, CreateResult], None]: + yield "Maybe I should search" # Low confidence phrase + yield CreateResult( + content="Done", + usage=RequestUsage(prompt_tokens=5, completion_tokens=2), + finish_reason="stop", + cached=False, + ) + + mock_inner_client.create_stream = mock_stream + + client = NvidiaSpeculativeClient( + inner_client=mock_inner_client, + prewarm_callback=prewarm_callback, + min_confidence=0.99, # Very high threshold + ) + + async for _ in client.create_stream( + messages=[UserMessage(content="Test", source="user")] + ): + pass + + await asyncio.sleep(0.1) + + # Should not trigger because confidence is below threshold + assert prewarm_called is False + + @pytest.mark.asyncio + async def test_stream_deduplicates_intents( + self, mock_inner_client: MagicMock + ) -> None: + """Test that same intent is only fired once per session.""" + SpeculativeCache.reset_instance() + + prewarm_count = 0 + + async def prewarm_callback( + tool_type: str, query_hint: str, context: Dict[str, Any] + ) -> Optional[str]: + nonlocal prewarm_count + prewarm_count += 1 + return "result" + + async def mock_stream( + messages: Any, **kwargs: Any + ) -> AsyncGenerator[Union[str, CreateResult], None]: + # Multiple search intents in same stream + yield "I will search for Python docs. " + yield "Let me search for more info. " + yield "I need to search again. 
" + yield CreateResult( + content="Done", + usage=RequestUsage(prompt_tokens=10, completion_tokens=5), + finish_reason="stop", + cached=False, + ) + + mock_inner_client.create_stream = mock_stream + + client = NvidiaSpeculativeClient( + inner_client=mock_inner_client, + prewarm_callback=prewarm_callback, + min_confidence=0.5, + sniff_all_content=True, + ) + + async for _ in client.create_stream( + messages=[UserMessage(content="Test", source="user")] + ): + pass + + await asyncio.sleep(0.2) + + # Should only fire once due to deduplication + assert prewarm_count == 1 diff --git a/python/packages/autogen-ext/tests/models/test_sre_integration.py b/python/packages/autogen-ext/tests/models/test_sre_integration.py new file mode 100644 index 000000000000..0b7ba75803b2 --- /dev/null +++ b/python/packages/autogen-ext/tests/models/test_sre_integration.py @@ -0,0 +1,344 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Integration tests for NVIDIA Speculative Reasoning Execution. + +These tests verify the end-to-end flow of speculative execution with +simulated DeepSeek-R1 style reasoning streams. They demonstrate the +latency savings achievable by parallelizing tool execution with +model reasoning. +""" + +import asyncio +import time +from typing import Any, AsyncGenerator, Dict, List, Optional, Union +from unittest.mock import MagicMock + +import pytest +from autogen_core.models import ( + CreateResult, + RequestUsage, + UserMessage, +) + +from autogen_ext.models.nvidia import ( + NvidiaSpeculativeClient, + ReasoningSniffer, + SpeculativeCache, +) + + +# Simulated DeepSeek-R1 reasoning stream with tool intent +MOCK_R1_STREAM_CONTENT = [ + "<think>", + "Let me analyze this problem step by step.", + " First, I need to understand what the user is asking.", + " The user wants information about Python documentation.", + " I will search for Python 3.12 documentation to find the answer.", + " Let me also check for any related tutorials.", + "</think>", + "\n\nBased on my research, here are the key Python 3.12 features:\n", + "1. Pattern matching improvements\n", + "2. Type parameter syntax\n", + "3. 
Performance optimizations", + ] + + +class TestSpeculativeExecutionIntegration: + """Integration tests for the full speculative execution flow.""" + + @pytest.fixture(autouse=True) + def reset_cache(self) -> None: + """Reset the singleton cache before each test.""" + SpeculativeCache.reset_instance() + + @pytest.fixture + def mock_client_with_r1_stream(self) -> MagicMock: + """Create a mock client that simulates R1-style reasoning stream.""" + client = MagicMock() + client.model_info = { + "vision": False, + "function_calling": True, + "json_output": True, + "family": "r1", + } + + async def mock_close() -> None: + pass + + client.close = MagicMock(side_effect=mock_close) + + async def mock_stream( + messages: Any, **kwargs: Any + ) -> AsyncGenerator[Union[str, CreateResult], None]: + for chunk in MOCK_R1_STREAM_CONTENT: + await asyncio.sleep(0.05) # Simulate token generation latency + yield chunk + + yield CreateResult( + content="".join(MOCK_R1_STREAM_CONTENT), + usage=RequestUsage(prompt_tokens=50, completion_tokens=100), + finish_reason="stop", + cached=False, + ) + + client.create_stream = mock_stream + return client + + @pytest.mark.asyncio + async def test_end_to_end_speculative_execution( + self, mock_client_with_r1_stream: MagicMock + ) -> None: + """Test complete speculative execution flow with timing verification.""" + prewarm_events: List[Dict[str, Any]] = [] + stream_start_time: Optional[float] = None + + async def prewarm_callback( + tool_type: str, query_hint: str, context: Dict[str, Any] + ) -> Optional[str]: + prewarm_time = time.perf_counter() + prewarm_events.append({ + "tool_type": tool_type, + "query_hint": query_hint, + "time": prewarm_time, + "offset_ms": (prewarm_time - stream_start_time) * 1000 + if stream_start_time + else 0, + }) + # Simulate tool execution + await asyncio.sleep(0.1) + return f"Mock result for {tool_type}" + + client = NvidiaSpeculativeClient( + inner_client=mock_client_with_r1_stream, + prewarm_callback=prewarm_callback, + min_confidence=0.6, + ) + + stream_start_time = time.perf_counter() + chunks: List[Any] = [] + + async for chunk in client.create_stream( + messages=[UserMessage(content="Tell me about Python 3.12", source="user")] + ): + chunks.append(chunk) + + stream_end_time = time.perf_counter() + + # Allow background tasks to complete + await asyncio.sleep(0.2) + await client.close() + + # Verify prewarm was triggered + assert len(prewarm_events) > 0, "Expected at least one prewarm event" + + # Verify prewarm fired before stream ended + first_prewarm = prewarm_events[0] + assert first_prewarm["tool_type"] == "web_search" + + # Verify stream completed successfully + assert isinstance(chunks[-1], CreateResult) + + @pytest.mark.asyncio + async def test_speculative_delta_timing( + self, mock_client_with_r1_stream: MagicMock + ) -> None: + """Test that speculative execution provides timing advantage.""" + prewarm_time: Optional[float] = None + reasoning_end_time: Optional[float] = None + + async def prewarm_callback( + tool_type: str, query_hint: str, context: Dict[str, Any] + ) -> Optional[str]: + nonlocal prewarm_time + prewarm_time = time.perf_counter() + await asyncio.sleep(0.05) + return "result" + + client = NvidiaSpeculativeClient( + inner_client=mock_client_with_r1_stream, + prewarm_callback=prewarm_callback, + min_confidence=0.6, + ) + + stream_start = time.perf_counter() + + async for chunk in client.create_stream( + messages=[UserMessage(content="Test", source="user")] + ): + chunk_str = str(chunk) + if "</think>" in chunk_str: + 
reasoning_end_time = time.perf_counter() + + await asyncio.sleep(0.2) + await client.close() + + # Verify timing relationship + assert prewarm_time is not None, "Prewarm should have been triggered" + assert reasoning_end_time is not None, "Should have detected reasoning end" + + # Prewarm should fire before reasoning ends (negative delta = savings) + prewarm_offset = (prewarm_time - stream_start) * 1000 + reasoning_end_offset = (reasoning_end_time - stream_start) * 1000 + speculative_delta = prewarm_offset - reasoning_end_offset + + assert speculative_delta < 0, ( + f"Speculative execution should provide timing advantage. " + f"Delta: {speculative_delta:.0f}ms" + ) + + @pytest.mark.asyncio + async def test_no_prewarm_when_speculation_disabled( + self, mock_client_with_r1_stream: MagicMock + ) -> None: + """Test that prewarming is skipped when speculation is disabled.""" + prewarm_called = False + + async def prewarm_callback( + tool_type: str, query_hint: str, context: Dict[str, Any] + ) -> Optional[str]: + nonlocal prewarm_called + prewarm_called = True + return None + + client = NvidiaSpeculativeClient( + inner_client=mock_client_with_r1_stream, + prewarm_callback=prewarm_callback, + enable_speculation=False, # Disabled + ) + + async for _ in client.create_stream( + messages=[UserMessage(content="Test", source="user")] + ): + pass + + await asyncio.sleep(0.1) + await client.close() + + assert prewarm_called is False + + @pytest.mark.asyncio + async def test_cache_stores_prewarm_results( + self, mock_client_with_r1_stream: MagicMock + ) -> None: + """Test that prewarm results are stored in cache.""" + async def prewarm_callback( + tool_type: str, query_hint: str, context: Dict[str, Any] + ) -> Optional[str]: + return f"cached_result_for_{tool_type}" + + client = NvidiaSpeculativeClient( + inner_client=mock_client_with_r1_stream, + prewarm_callback=prewarm_callback, + min_confidence=0.5, + ) + + async for _ in client.create_stream( + messages=[UserMessage(content="Test", source="user")] + ): + pass + + await asyncio.sleep(0.2) + await client.close() + + # Verify cache has entries + cache = client.cache + assert len(cache) > 0 + assert cache.stats["stores"] > 0 + + +class TestSnifferPatternMatching: + """Integration tests for sniffer pattern matching accuracy.""" + + @pytest.fixture(autouse=True) + def reset_cache(self) -> None: + """Reset cache before each test.""" + SpeculativeCache.reset_instance() + + @pytest.mark.parametrize( + "text,expected_tool,should_match", + [ + ("I will search for Python documentation", "web_search", True), + ("Let me look up the latest news", "web_search", True), + ("I need to check the database for user records", "database_query", True), + ("I'll calculate the total revenue", "calculate", True), + ("The weather is nice", None, False), + ("Just thinking about the problem", None, False), + ], + ) + def test_intent_detection_patterns( + self, text: str, expected_tool: Optional[str], should_match: bool + ) -> None: + """Test various intent detection patterns.""" + sniffer = ReasoningSniffer() + + intent = sniffer.sniff(text) + + if should_match: + assert intent is not None, f"Expected match for: {text}" + assert intent.tool_type == expected_tool + else: + assert intent is None, f"Expected no match for: {text}" + + def test_contraction_patterns(self) -> None: + """Test that contractions like I'll are properly detected.""" + sniffer = ReasoningSniffer() + + # Test I'll patterns + intent = sniffer.sniff("I'll search for the latest updates") + assert intent is not 
None + assert intent.tool_type == "web_search" + + sniffer.reset() + + intent = sniffer.sniff("I'll need to look up some information") + assert intent is not None + + +class TestPerformanceMetrics: + """Tests for performance metric tracking.""" + + @pytest.fixture(autouse=True) + def reset_cache(self) -> None: + """Reset cache before each test.""" + SpeculativeCache.reset_instance() + + @pytest.mark.asyncio + async def test_metrics_are_tracked(self) -> None: + """Test that performance metrics are properly tracked.""" + client_mock = MagicMock() + client_mock.model_info = {"family": "r1"} + + async def mock_close() -> None: + pass + + client_mock.close = MagicMock(side_effect=mock_close) + + async def mock_stream( + messages: Any, **kwargs: Any + ) -> AsyncGenerator[Union[str, CreateResult], None]: + yield "Test response" + yield CreateResult( + content="Test", + usage=RequestUsage(prompt_tokens=10, completion_tokens=5), + finish_reason="stop", + cached=False, + ) + + client_mock.create_stream = mock_stream + + client = NvidiaSpeculativeClient(inner_client=client_mock) + + async for _ in client.create_stream( + messages=[UserMessage(content="Test", source="user")] + ): + pass + + await client.close() + + # Verify metrics were captured + metrics = client.last_metrics + assert metrics is not None + assert metrics.stream_start_time is not None + assert metrics.stream_end_time is not None + assert metrics.first_token_time is not None
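
# Taken together, the tests above imply the following wiring for real use. This is an
# illustrative sketch that mirrors the test fixtures: the constructor parameters come from the
# tests, and the OpenAI client is only a stand-in for "any streaming-capable inner client".
import asyncio
from typing import Any, Dict, Optional

from autogen_core.models import UserMessage
from autogen_ext.models.nvidia import NvidiaSpeculativeClient
from autogen_ext.models.openai import OpenAIChatCompletionClient

async def prewarm(tool_type: str, query_hint: str, context: Dict[str, Any]) -> Optional[str]:
    # Kick off the real tool (web search, DB query, ...) while the model is still reasoning;
    # the returned value lands in the SpeculativeCache for reuse when the tool call arrives.
    return f"prewarmed result for {tool_type}: {query_hint}"

async def main() -> None:
    inner = OpenAIChatCompletionClient(model="gpt-4o")  # stand-in inner client
    client = NvidiaSpeculativeClient(
        inner_client=inner,
        prewarm_callback=prewarm,
        min_confidence=0.6,
    )
    async for chunk in client.create_stream(
        messages=[UserMessage(content="Summarize the latest Python release notes", source="user")]
    ):
        if isinstance(chunk, str):
            print(chunk, end="", flush=True)
    print(client.cache.stats, client.last_metrics)
    await client.close()

asyncio.run(main())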