From 85ad2f3729c52182da6116f98b8efc026f0155ee Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Mon, 19 Jan 2026 00:52:26 +0000 Subject: [PATCH 1/2] refactor(chat): extract reusable chat and chat_stream functions for MCP server - Create new vlmrun/client/chat.py module with reusable functions: - chat(): Non-streaming chat completion - chat_stream(): Streaming chat completion - collect_stream(): Helper to collect streaming chunks - ChatResponse and ChatStreamChunk dataclasses - ChatError exception class - Helper functions for file upload, message building, validation - Refactor CLI chat command to use the new reusable functions - CLI now focuses on presentation layer (Rich output, progress spinners) - Core logic delegated to vlmrun.client.chat module - Expose new functions in vlmrun.client.__init__.py for easy import: - chat, chat_stream, collect_stream - ChatResponse, ChatStreamChunk, ChatError - AVAILABLE_MODELS, DEFAULT_MODEL, extract_artifact_refs - Update tests to use new module paths Co-Authored-By: Sudeep Pillai --- tests/cli/test_cli_chat.py | 4 +- vlmrun/cli/_cli/chat.py | 481 +++++++++++++++---------------------- vlmrun/client/__init__.py | 11 + vlmrun/client/chat.py | 415 ++++++++++++++++++++++++++++++++ 4 files changed, 619 insertions(+), 292 deletions(-) create mode 100644 vlmrun/client/chat.py diff --git a/tests/cli/test_cli_chat.py b/tests/cli/test_cli_chat.py index 5279fc8..36197df 100644 --- a/tests/cli/test_cli_chat.py +++ b/tests/cli/test_cli_chat.py @@ -12,6 +12,8 @@ from vlmrun.cli._cli.chat import ( resolve_prompt, build_messages, +) +from vlmrun.client.chat import ( extract_artifact_refs, AVAILABLE_MODELS, DEFAULT_MODEL, @@ -244,7 +246,7 @@ def test_chat_unsupported_file_type( assert result.exit_code == 1 assert "Unsupported file type" in result.stdout - @patch("vlmrun.cli._cli.chat.upload_files") + @patch("vlmrun.client.chat._upload_files") def test_chat_with_file_json_output( self, mock_upload, diff --git a/vlmrun/cli/_cli/chat.py b/vlmrun/cli/_cli/chat.py index 7790a4f..7080c1d 100644 --- a/vlmrun/cli/_cli/chat.py +++ b/vlmrun/cli/_cli/chat.py @@ -4,13 +4,11 @@ import json import os -import re import sys import threading import time -from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Any, Dict, List, Optional, Set +from typing import Any, Dict, List, Optional import typer from openai import APIConnectionError, APIError, AuthenticationError, RateLimitError @@ -21,7 +19,15 @@ from rich.tree import Tree from vlmrun.client import VLMRun -from vlmrun.client.types import FileResponse +from vlmrun.client.chat import ( + AVAILABLE_MODELS, + DEFAULT_MODEL, + ChatError, + ChatResponse, + chat as chat_core, + chat_stream as chat_stream_core, + extract_artifact_refs, +) from vlmrun.constants import ( DEFAULT_BASE_URL, SUPPORTED_INPUT_FILETYPES, @@ -81,6 +87,16 @@ def __exit__(self, exc_type, exc_val, exc_tb): ) ) raise typer.Exit(1) + elif issubclass(exc_type, ChatError): + console.print( + Panel( + str(exc_val), + title="[red]Chat Error[/red]", + title_align="left", + border_style="red", + ) + ) + raise typer.Exit(1) elif issubclass(exc_type, Exception): console.print( Panel( @@ -129,15 +145,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): no_args_is_help=True, ) -# Available models -AVAILABLE_MODELS = [ - "vlmrun-orion-1:fast", - "vlmrun-orion-1:auto", - "vlmrun-orion-1:pro", -] - -DEFAULT_MODEL = "vlmrun-orion-1:auto" - def read_prompt_from_stdin() -> 
Optional[str]: """Read prompt from stdin if available (piped input).""" @@ -229,89 +236,18 @@ def format_file_size(size_bytes: int) -> str: return f"{size_bytes / (1024 * 1024 * 1024):.1f}GB" -def upload_files( - client: VLMRun, files: List[Path], show_progress: bool = True -) -> List[FileResponse]: - """Upload files concurrently and return their file responses.""" - file_responses: List[FileResponse] = [] - - if not show_progress: - # JSON output: upload without progress display - with ThreadPoolExecutor(max_workers=min(len(files), 4)) as executor: - futures = { - executor.submit(client.files.upload, file=f, purpose="assistants"): f - for f in files - } - for future in as_completed(futures): - file_path = futures[future] - try: - file_response = future.result() - file_responses.append(file_response) - except Exception as e: - console.print(f"[red]Error uploading {file_path.name}:[/] {e}") - raise typer.Exit(1) from e - else: - # Rich output - upload with status spinner - with Status("Uploading...", console=console, spinner="dots") as status: - with ThreadPoolExecutor(max_workers=min(len(files), 4)) as executor: - futures = { - executor.submit( - client.files.upload, file=f, purpose="assistants" - ): f - for f in files - } - for future in as_completed(futures): - file_path = futures[future] - try: - file_response = future.result() - file_responses.append(file_response) - status.update(f"Uploading {file_path.name}...") - except Exception as e: - console.print(f"[red]Error uploading {file_path.name}:[/] {e}") - raise typer.Exit(1) from e - - return file_responses - - def build_messages( - prompt: str, file_responses: Optional[List[FileResponse]] = None + prompt: str, file_responses: Optional[List[Any]] = None ) -> List[Dict[str, Any]]: """Build OpenAI-style messages with optional file attachments.""" - # Add files first using file IDs - content = [ + content: List[Dict[str, Any]] = [ {"type": "input_file", "file_id": file_response.id} for file_response in file_responses or [] ] - # Add text prompt after files content.append({"type": "text", "text": prompt}) return [{"role": "user", "content": content}] -def extract_artifact_refs(response_content: str) -> List[str]: - """Extract artifact reference IDs from response content. - - Looks for patterns like img_XXXXXX, aud_XXXXXX, vid_XXXXXX, doc_XXXXXX, - recon_XXXXXX, arr_XXXXXX, url_XXXXXX in the response text. 
- """ - # Reference patterns from vlmrun/types/refs.py - patterns = [ - r"\bimg_\w{6}\b", # ImageRef - r"\baud_\w{6}\b", # AudioRef - r"\bvid_\w{6}\b", # VideoRef - r"\bdoc_\w{6}\b", # DocumentRef - r"\brecon_\w{6}\b", # ReconRef - r"\barr_\w{6}\b", # ArrayRef - r"\burl_\w{6}\b", # UrlRef - ] - - refs: Set[str] = set() - for pattern in patterns: - matches = re.findall(pattern, response_content) - refs.update(matches) - - return sorted(list(refs)) - - def download_artifact( client: VLMRun, session_id: str, ref_id: str, output_dir: Path ) -> Path: @@ -371,10 +307,10 @@ def __init__(self, message: str, console: Console, spinner: str = "dots"): self.base_message = message self.console = console self.spinner = spinner - self.start_time = None - self.status = None + self.start_time: Optional[float] = None + self.status: Optional[Status] = None self._stop_event = threading.Event() - self._timer_thread = None + self._timer_thread: Optional[threading.Thread] = None def __enter__(self): self.start_time = time.time() @@ -410,43 +346,39 @@ def _update_timer(self): def print_rich_output( - content: str, - model: str, - latency_s: float, - usage: Optional[Dict[str, Any]] = None, + response: ChatResponse, artifacts: Optional[List[Dict[str, str]]] = None, artifact_dir: Optional[Path] = None, - session_id: Optional[str] = None, ) -> None: """Print rich-formatted output with panels.""" # Build subtitle with stats - stats = [model] - if usage: - prompt_tokens = usage.get("prompt_tokens", 0) - completion_tokens = usage.get("completion_tokens", 0) - total_tokens = usage.get("total_tokens", 0) + stats = [response.model] + if response.usage: + prompt_tokens = response.usage.get("prompt_tokens", 0) + completion_tokens = response.usage.get("completion_tokens", 0) + total_tokens = response.usage.get("total_tokens", 0) if total_tokens and False: # TODO: Add back in when usage is implemented stats.append( f"P:{prompt_tokens} / C:{completion_tokens} / T:{total_tokens} tokens" ) # Add credits if available - credits = usage.get("credits_used") + credits = response.usage.get("credits_used") if credits is not None: stats.append(f"{credits} credit(s)") - stats.append(f"{format_time(latency_s)}") + stats.append(f"{format_time(response.latency_s)}") subtitle = " ยท ".join(stats) # Build title with optional session_id - if session_id: - title = f"[bold]Response[/bold] [dim](id={session_id})[/dim]" + if response.session_id: + title = f"[bold]Response[/bold] [dim](id={response.session_id})[/dim]" else: title = "[bold]Response[/bold]" # Main response panel console.print( Panel( - Markdown(content), + Markdown(response.content), title=title, title_align="left", subtitle=f"[dim][white]{subtitle}[/white][/dim]", @@ -460,6 +392,92 @@ def print_rich_output( console.print(f" [dim]Artifacts:[/dim] {artifact_dir}") +def _handle_artifact_download( + client: VLMRun, + response: ChatResponse, + output_dir: Optional[Path], + output_json: bool, +) -> Optional[List[Path]]: + """Handle artifact download from a chat response. 
+ + Args: + client: VLMRun client instance + response: The chat response containing artifact refs + output_dir: Optional output directory for artifacts + output_json: Whether JSON output mode is enabled + + Returns: + List of downloaded file paths, or None if no artifacts + """ + artifact_refs = response.artifact_refs + if not artifact_refs: + return None + + if not response.session_id: + console.print( + "[yellow]Warning:[/] No session_id available, artifacts download skipped" + ) + return None + + _session_id = response.session_id + + # Set up output directory + if output_dir: + artifact_dir = output_dir + else: + artifact_dir = VLMRUN_ARTIFACTS_CACHE_DIR / _session_id + artifact_dir.mkdir(parents=True, exist_ok=True) + + downloaded_files: List[Path] = [] + + # Download artifacts with status spinner + if not output_json: + with Status("Downloading artifacts...", console=console, spinner="dots"): + for ref_id in artifact_refs: + try: + output_path = download_artifact( + client, + _session_id, + ref_id, + artifact_dir, + ) + downloaded_files.append(output_path) + except Exception as e: + console.print(f"[red]Failed to download {ref_id}: {e}[/]") + else: + # JSON output mode - download without progress + for ref_id in artifact_refs: + try: + output_path = download_artifact( + client, + _session_id, + ref_id, + artifact_dir, + ) + downloaded_files.append(output_path) + except Exception as e: + console.print(f"[red]Failed to download {ref_id}: {e}[/]") + + if not output_json and downloaded_files: + # Create elegant tree view of artifacts + tree = Tree(f"{artifact_dir}", guide_style="dim") + for file_path in sorted(downloaded_files): + # Get file size + size_str = format_file_size(file_path.stat().st_size) + tree.add(f"{file_path.name} [dim]({size_str})[/dim]") + + console.print( + Panel( + tree, + title=f"Downloaded {len(downloaded_files)} artifact(s)", + title_align="left", + border_style="dim", + ) + ) + + return downloaded_files + + @app.command() def chat( ctx: typer.Context, @@ -567,42 +585,24 @@ def chat( raise typer.Exit(1) try: - # Upload input files concurrently if provided - file_responses: List[FileResponse] = [] - if input_files: - file_responses = upload_files( - client, input_files, show_progress=not output_json - ) - - if not output_json: - # Create tree view of uploaded files - tree = Tree("", guide_style="dim", hide_root=True) - for file_path in input_files: - size_str = format_file_size(file_path.stat().st_size) - tree.add(f"{file_path.name} [dim]({size_str})[/dim]") - console.print( - Panel( - tree, - title=f"Uploaded {len(file_responses)} file(s)", - title_align="left", - border_style="dim", - ) + # Show file upload info for rich output + if input_files and not output_json: + # Create tree view of files to be uploaded + tree = Tree("", guide_style="dim", hide_root=True) + for file_path in input_files: + size_str = format_file_size(file_path.stat().st_size) + tree.add(f"{file_path.name} [dim]({size_str})[/dim]") + console.print( + Panel( + tree, + title=f"Uploading {len(input_files)} file(s)", + title_align="left", + border_style="dim", ) - - # Build messages for the chat completion - messages = build_messages( - final_prompt, file_responses if file_responses else None - ) - - # Call the OpenAI-compatible chat completions API - response_content = "" - usage_data: Optional[Dict[str, Any]] = None - response_id: Optional[str] = None - - start_time = time.time() + ) if no_stream: - # Non-streaming mode + # Non-streaming mode - use the reusable chat function if not output_json: 
with ( TimedStatus( @@ -612,65 +612,39 @@ def chat( ), handle_api_errors(), ): - response = client.agent.completions.create( + response = chat_core( + client=client, + prompt=final_prompt, + inputs=input_files, model=model, - messages=messages, - stream=False, session_id=session_id, ) else: - # JSON output: no status messages, just make the API call + # JSON output: no status messages with handle_api_errors(): - response = client.agent.completions.create( + response = chat_core( + client=client, + prompt=final_prompt, + inputs=input_files, model=model, - messages=messages, - stream=False, session_id=session_id, ) - latency_s = time.time() - start_time - response_content = response.choices[0].message.content or "" - response_id = response.session_id - - if response.usage: - usage_data = { - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, - } - # Add credits if available (may be in custom fields) - if ( - hasattr(response.usage, "credits_used") - and response.usage.credits_used is not None - ): - usage_data["credits_used"] = response.usage.credits_used - elif ( - hasattr(response.usage, "credits") - and response.usage.credits is not None - ): - usage_data["credits_used"] = response.usage.credits - if output_json: output = { "id": response.id, "session_id": response.session_id, "model": response.model, - "content": response_content, - "latency_s": latency_s, - "usage": usage_data, + "content": response.content, + "latency_s": response.latency_s, + "usage": response.usage, } print(json.dumps(output, indent=2, default=str)) else: - print_rich_output( - response_content, - model, - latency_s, - usage_data, - session_id=response_id, - ) + print_rich_output(response) else: - # Streaming mode + # Streaming mode - use the reusable chat_stream function with ( TimedStatus( f"Processing ([bold]{model}[/bold])...", @@ -679,138 +653,63 @@ def chat( ), handle_api_errors(), ): - stream = client.agent.completions.create( + stream = chat_stream_core( + client=client, + prompt=final_prompt, + inputs=input_files, model=model, - messages=messages, - stream=True, session_id=session_id, ) - # Collect streaming content and usage data - chunks = [] - stream_usage_data = None + # Collect streaming content + chunks: List[str] = [] + response_session_id: Optional[str] = None + stream_usage_data: Optional[Dict[str, Any]] = None + + start_time = time.time() for chunk in stream: - # Capture session_id from first chunk - if not response_id and hasattr(chunk, "session_id"): - response_id = chunk.session_id - if ( - chunk.choices - and chunk.choices[0].delta - and chunk.choices[0].delta.content - ): - chunks.append(chunk.choices[0].delta.content) - # Capture usage data from the final chunk - if hasattr(chunk, "usage") and chunk.usage: - stream_usage_data = { - "prompt_tokens": chunk.usage.prompt_tokens, - "completion_tokens": chunk.usage.completion_tokens, - "total_tokens": chunk.usage.total_tokens, - } - # Add credits if available - if ( - hasattr(chunk.usage, "credits_used") - and chunk.usage.credits_used is not None - ): - stream_usage_data["credits_used"] = chunk.usage.credits_used - elif ( - hasattr(chunk.usage, "credits") - and chunk.usage.credits is not None - ): - stream_usage_data["credits_used"] = chunk.usage.credits - - response_content = "".join(chunks) - - latency_s = time.time() - start_time + if chunk.content: + chunks.append(chunk.content) + if chunk.session_id: + response_session_id = chunk.session_id + if 
chunk.usage: + stream_usage_data = chunk.usage + + latency_s = time.time() - start_time + + response_content = "".join(chunks) + artifact_refs = extract_artifact_refs(response_content) + + # Create a ChatResponse for consistent handling + response = ChatResponse( + content=response_content, + session_id=response_session_id, + model=model, + latency_s=latency_s, + usage=stream_usage_data, + artifact_refs=artifact_refs, + ) # Display the complete response if output_json: output = { - "content": response_content, - "latency_s": latency_s, + "content": response.content, + "session_id": response.session_id, + "latency_s": response.latency_s, } if stream_usage_data: output["usage"] = stream_usage_data print(json.dumps(output, indent=2, default=str)) else: - print_rich_output( - response_content, - model, - latency_s, - stream_usage_data, - session_id=response_id, - ) + print_rich_output(response) # Extract and download artifacts if present if not no_download: - artifact_refs = extract_artifact_refs(response_content) - if artifact_refs: - # Use response_id as _session_id if available - if not response_id: - console.print( - "[yellow]Warning:[/] No session_id available, artifacts download skipped" - ) - return - else: - _session_id = response_id - - # Set up output directory - if output_dir: - artifact_dir = output_dir - else: - artifact_dir = VLMRUN_ARTIFACTS_CACHE_DIR / _session_id - artifact_dir.mkdir(parents=True, exist_ok=True) - - downloaded_files = [] - - # Download artifacts with status spinner - if not output_json: - with Status( - "Downloading artifacts...", console=console, spinner="dots" - ): - for ref_id in artifact_refs: - try: - output_path = download_artifact( - client, - _session_id, - ref_id, - artifact_dir, - ) - downloaded_files.append(output_path) - except Exception as e: - console.print( - f"[red]Failed to download {ref_id}: {e}[/]" - ) - else: - # JSON output mode - download without progress - for ref_id in artifact_refs: - try: - output_path = download_artifact( - client, - _session_id, - ref_id, - artifact_dir, - ) - downloaded_files.append(output_path) - except Exception as e: - console.print(f"[red]Failed to download {ref_id}: {e}[/]") - - if not output_json and downloaded_files: - # Create elegant tree view of artifacts - tree = Tree(f"{artifact_dir}", guide_style="dim") - for file_path in sorted(downloaded_files): - # Get file size - size_str = format_file_size(file_path.stat().st_size) - tree.add(f"{file_path.name} [dim]({size_str})[/dim]") - - console.print( - Panel( - tree, - title=f"Downloaded {len(downloaded_files)} artifact(s)", - title_align="left", - border_style="dim", - ) - ) + _handle_artifact_download(client, response, output_dir, output_json) + except ChatError as e: + console.print(f"[red]Error:[/] {e}") + raise typer.Exit(1) except Exception as e: console.print(f"[red]Error:[/] {e}") raise typer.Exit(1) diff --git a/vlmrun/client/__init__.py b/vlmrun/client/__init__.py index f69f413..0a7214f 100644 --- a/vlmrun/client/__init__.py +++ b/vlmrun/client/__init__.py @@ -1 +1,12 @@ from .client import VLMRun # noqa: F401 +from .chat import ( # noqa: F401 + chat, + chat_stream, + collect_stream, + ChatResponse, + ChatStreamChunk, + ChatError, + AVAILABLE_MODELS, + DEFAULT_MODEL, + extract_artifact_refs, +) diff --git a/vlmrun/client/chat.py b/vlmrun/client/chat.py new file mode 100644 index 0000000..738bff2 --- /dev/null +++ b/vlmrun/client/chat.py @@ -0,0 +1,415 @@ +"""Reusable chat functions for VLM Run API. 
+ +This module provides `chat` and `chat_stream` functions that can be used +by the CLI, MCP server, or any other client code. +""" + +from __future__ import annotations + +import re +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional, Set, TYPE_CHECKING + +if TYPE_CHECKING: + from vlmrun.client import VLMRun + +from vlmrun.client.types import FileResponse +from vlmrun.constants import SUPPORTED_INPUT_FILETYPES + + +# Available models +AVAILABLE_MODELS = [ + "vlmrun-orion-1:fast", + "vlmrun-orion-1:auto", + "vlmrun-orion-1:pro", +] + +DEFAULT_MODEL = "vlmrun-orion-1:auto" + + +@dataclass +class ChatResponse: + """Response from a non-streaming chat completion.""" + + content: str + session_id: Optional[str] + model: str + latency_s: float + usage: Optional[Dict[str, Any]] = None + id: Optional[str] = None + artifact_refs: List[str] = field(default_factory=list) + + +@dataclass +class ChatStreamChunk: + """A chunk from a streaming chat completion.""" + + content: str + session_id: Optional[str] = None + usage: Optional[Dict[str, Any]] = None + is_final: bool = False + + +class ChatError(Exception): + """Error raised during chat operations.""" + + pass + + +def extract_artifact_refs(response_content: str) -> List[str]: + """Extract artifact reference IDs from response content. + + Looks for patterns like img_XXXXXX, aud_XXXXXX, vid_XXXXXX, doc_XXXXXX, + recon_XXXXXX, arr_XXXXXX, url_XXXXXX in the response text. + + Args: + response_content: The response text to search for artifact references. + + Returns: + A sorted list of unique artifact reference IDs found in the content. + """ + patterns = [ + r"\bimg_\w{6}\b", # ImageRef + r"\baud_\w{6}\b", # AudioRef + r"\bvid_\w{6}\b", # VideoRef + r"\bdoc_\w{6}\b", # DocumentRef + r"\brecon_\w{6}\b", # ReconRef + r"\barr_\w{6}\b", # ArrayRef + r"\burl_\w{6}\b", # UrlRef + ] + + refs: Set[str] = set() + for pattern in patterns: + matches = re.findall(pattern, response_content) + refs.update(matches) + + return sorted(list(refs)) + + +def _validate_model(model: str) -> None: + """Validate that the model is supported. + + Args: + model: The model name to validate. + + Raises: + ChatError: If the model is not supported. + """ + if model not in AVAILABLE_MODELS: + raise ChatError( + f"Invalid model '{model}'. Available models: {', '.join(AVAILABLE_MODELS)}" + ) + + +def _validate_inputs(inputs: Optional[List[Path]]) -> None: + """Validate that input files have supported file types. + + Args: + inputs: List of input file paths to validate. + + Raises: + ChatError: If any file has an unsupported file type. + """ + if not inputs: + return + + for file_path in inputs: + suffix = file_path.suffix.lower() + if suffix not in SUPPORTED_INPUT_FILETYPES: + raise ChatError( + f"Unsupported file type: {suffix}. " + f"Supported types: {', '.join(SUPPORTED_INPUT_FILETYPES)}" + ) + + +def _upload_files( + client: "VLMRun", files: List[Path], max_workers: int = 4 +) -> List[FileResponse]: + """Upload files concurrently and return their file responses. + + Args: + client: VLMRun client instance. + files: List of file paths to upload. + max_workers: Maximum number of concurrent uploads. + + Returns: + List of FileResponse objects for the uploaded files. + + Raises: + ChatError: If any file upload fails. 
+ """ + file_responses: List[FileResponse] = [] + + with ThreadPoolExecutor(max_workers=min(len(files), max_workers)) as executor: + futures = { + executor.submit(client.files.upload, file=f, purpose="assistants"): f + for f in files + } + for future in as_completed(futures): + file_path = futures[future] + try: + file_response = future.result() + file_responses.append(file_response) + except Exception as e: + raise ChatError(f"Error uploading {file_path.name}: {e}") from e + + return file_responses + + +def _build_messages( + prompt: str, file_responses: Optional[List[FileResponse]] = None +) -> List[Dict[str, Any]]: + """Build OpenAI-style messages with optional file attachments. + + Args: + prompt: The text prompt for the chat. + file_responses: Optional list of FileResponse objects for file attachments. + + Returns: + List of message dictionaries in OpenAI format. + """ + content: List[Dict[str, Any]] = [ + {"type": "input_file", "file_id": file_response.id} + for file_response in file_responses or [] + ] + content.append({"type": "text", "text": prompt}) + return [{"role": "user", "content": content}] + + +def _extract_usage_data(usage: Any) -> Optional[Dict[str, Any]]: + """Extract usage data from a response or chunk. + + Args: + usage: The usage object from the API response. + + Returns: + Dictionary with usage data or None if not available. + """ + if not usage: + return None + + usage_data = { + "prompt_tokens": usage.prompt_tokens, + "completion_tokens": usage.completion_tokens, + "total_tokens": usage.total_tokens, + } + + if hasattr(usage, "credits_used") and usage.credits_used is not None: + usage_data["credits_used"] = usage.credits_used + elif hasattr(usage, "credits") and usage.credits is not None: + usage_data["credits_used"] = usage.credits + + return usage_data + + +def chat( + client: "VLMRun", + prompt: str, + inputs: Optional[List[Path]] = None, + model: str = DEFAULT_MODEL, + session_id: Optional[str] = None, +) -> ChatResponse: + """Perform a non-streaming chat completion. + + This function uploads any input files, sends the prompt to the VLM Run API, + and returns the complete response. + + Args: + client: VLMRun client instance. + prompt: The text prompt for the chat. + inputs: Optional list of input file paths (images, videos, documents). + model: The model to use for the chat. Defaults to "vlmrun-orion-1:auto". + session_id: Optional session ID for persisting chat history. + + Returns: + ChatResponse containing the response content, session ID, usage data, etc. + + Raises: + ChatError: If the model is invalid, file types are unsupported, or API call fails. 
+ + Example: + ```python + from vlmrun.client import VLMRun + from vlmrun.client.chat import chat + from pathlib import Path + + client = VLMRun(api_key="your-api-key") + response = chat( + client=client, + prompt="Describe this image", + inputs=[Path("image.jpg")], + ) + print(response.content) + ``` + """ + _validate_model(model) + _validate_inputs(inputs) + + start_time = time.time() + + file_responses: List[FileResponse] = [] + if inputs: + file_responses = _upload_files(client, inputs) + + messages = _build_messages(prompt, file_responses if file_responses else None) + + response = client.agent.completions.create( + model=model, + messages=messages, + stream=False, + session_id=session_id, + ) + + latency_s = time.time() - start_time + response_content = response.choices[0].message.content or "" + response_id = getattr(response, "session_id", None) + usage_data = _extract_usage_data(response.usage) + artifact_refs = extract_artifact_refs(response_content) + + return ChatResponse( + content=response_content, + session_id=response_id, + model=model, + latency_s=latency_s, + usage=usage_data, + id=response.id, + artifact_refs=artifact_refs, + ) + + +def chat_stream( + client: "VLMRun", + prompt: str, + inputs: Optional[List[Path]] = None, + model: str = DEFAULT_MODEL, + session_id: Optional[str] = None, +) -> Iterator[ChatStreamChunk]: + """Perform a streaming chat completion. + + This function uploads any input files, sends the prompt to the VLM Run API, + and yields response chunks as they arrive. + + Args: + client: VLMRun client instance. + prompt: The text prompt for the chat. + inputs: Optional list of input file paths (images, videos, documents). + model: The model to use for the chat. Defaults to "vlmrun-orion-1:auto". + session_id: Optional session ID for persisting chat history. + + Yields: + ChatStreamChunk objects containing response content chunks. + The final chunk will have is_final=True and may contain usage data. + + Raises: + ChatError: If the model is invalid, file types are unsupported, or API call fails. 
+ + Example: + ```python + from vlmrun.client import VLMRun + from vlmrun.client.chat import chat_stream + from pathlib import Path + + client = VLMRun(api_key="your-api-key") + for chunk in chat_stream( + client=client, + prompt="Describe this image", + inputs=[Path("image.jpg")], + ): + print(chunk.content, end="", flush=True) + if chunk.is_final: + print(f"\\nUsage: {chunk.usage}") + ``` + """ + _validate_model(model) + _validate_inputs(inputs) + + file_responses: List[FileResponse] = [] + if inputs: + file_responses = _upload_files(client, inputs) + + messages = _build_messages(prompt, file_responses if file_responses else None) + + stream = client.agent.completions.create( + model=model, + messages=messages, + stream=True, + session_id=session_id, + ) + + response_session_id: Optional[str] = None + chunks_content: List[str] = [] + + for chunk in stream: + if not response_session_id and hasattr(chunk, "session_id"): + response_session_id = chunk.session_id + + content = "" + if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + chunks_content.append(content) + + usage_data = None + is_final = False + + if hasattr(chunk, "usage") and chunk.usage: + usage_data = _extract_usage_data(chunk.usage) + is_final = True + + yield ChatStreamChunk( + content=content, + session_id=response_session_id, + usage=usage_data, + is_final=is_final, + ) + + +def collect_stream(stream: Iterator[ChatStreamChunk]) -> ChatResponse: + """Collect all chunks from a streaming response into a single ChatResponse. + + This is a convenience function for cases where you want to use streaming + internally but return a complete response. + + Args: + stream: An iterator of ChatStreamChunk objects from chat_stream(). + + Returns: + ChatResponse containing the complete response. + + Example: + ```python + from vlmrun.client.chat import chat_stream, collect_stream + + stream = chat_stream(client, "Hello") + response = collect_stream(stream) + print(response.content) + ``` + """ + chunks: List[str] = [] + session_id: Optional[str] = None + usage: Optional[Dict[str, Any]] = None + + start_time = time.time() + + for chunk in stream: + if chunk.content: + chunks.append(chunk.content) + if chunk.session_id: + session_id = chunk.session_id + if chunk.usage: + usage = chunk.usage + + latency_s = time.time() - start_time + content = "".join(chunks) + artifact_refs = extract_artifact_refs(content) + + return ChatResponse( + content=content, + session_id=session_id, + model=DEFAULT_MODEL, + latency_s=latency_s, + usage=usage, + artifact_refs=artifact_refs, + ) From 4b877b00e349eac7c777aab0445aa2e9cb185b61 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Mon, 19 Jan 2026 00:55:38 +0000 Subject: [PATCH 2/2] fix(test): use sys.modules to access chat module for mocking Fix test failure on Python 3.10 where the patch path 'vlmrun.client.chat._upload_files' was resolving to the 'chat' function instead of the 'chat' module due to the naming collision in vlmrun.client.__init__.py. Use sys.modules to explicitly get the chat module and patch.object to mock the _upload_files function. 
Co-Authored-By: Sudeep Pillai --- tests/cli/test_cli_chat.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/cli/test_cli_chat.py b/tests/cli/test_cli_chat.py index 36197df..8b89c88 100644 --- a/tests/cli/test_cli_chat.py +++ b/tests/cli/test_cli_chat.py @@ -246,10 +246,8 @@ def test_chat_unsupported_file_type( assert result.exit_code == 1 assert "Unsupported file type" in result.stdout - @patch("vlmrun.client.chat._upload_files") def test_chat_with_file_json_output( self, - mock_upload, runner, config_file, mock_client, @@ -257,9 +255,13 @@ def test_chat_with_file_json_output( tmp_path, ): """Test chat with file and JSON output.""" + import sys + + # Get the actual chat module (not the function) + chat_module = sys.modules["vlmrun.client.chat"] # Setup mocks - return FileResponse - mock_upload.return_value = [ + mock_file_response = [ FileResponse( id="file-123", filename="test.jpg", @@ -276,7 +278,10 @@ def test_chat_with_file_json_output( test_file = tmp_path / "test.jpg" test_file.write_bytes(b"fake image data") - with patch.object(mock_client, "openai", mock_client.openai): + with ( + patch.object(chat_module, "_upload_files", return_value=mock_file_response), + patch.object(mock_client, "openai", mock_client.openai), + ): _result = runner.invoke( app, [