diff --git a/.env.template b/.env.template index df704083..d30bbab5 100644 --- a/.env.template +++ b/.env.template @@ -16,8 +16,3 @@ OPEN_ROUTER_API_KEY= # Telemetry ASKUI__VA__TELEMETRY__ENABLED=True # Set to "False" to disable telemetry -# OpenTelemetry Tracing Configuration -#ASKUI__CHAT_API__OTEL__ENABLED=False -#ASKUI__CHAT_API__OTEL__ENDPOINT=http://localhost/v1/traces -#ASKUI__CHAT_API__OTEL__SECRET= -#ASKUI__CHAT_API__OTEL__SERVICE_NAME=chat-api diff --git a/README.md b/README.md index e644c4b9..1321e7af 100644 --- a/README.md +++ b/README.md @@ -234,7 +234,6 @@ Tools are organized by category: `universal/` (work with any agent), `computer/` Aside from our [official documentation](https://docs.askui.com), we also have some additional guides and examples under the [docs](docs) folder that you may find useful, for example: -- **[Chat](docs/chat.md)** - How to interact with agents through a chat - **[Direct Tool Use](docs/direct-tool-use.md)** - How to use the tools, e.g., clipboard, the Agent OS etc. - **[Extracting Data](docs/extracting-data.md)** - How to extract data from the screen and documents - **[MCP](docs/mcp.md)** - How to use MCP servers to extend the capabilities of an agent diff --git a/docs/chat.md b/docs/chat.md deleted file mode 100644 index c1190834..00000000 --- a/docs/chat.md +++ /dev/null @@ -1,393 +0,0 @@ -# AskUI Chat - -**⚠️ Warning:** AskUI Chat is currently in an experimental stage and has several limitations (see below). - -AskUI Chat is a web application that allows interacting with an AskUI Vision Agent similar how it can be -done with `VisionAgent.act()` or `AndroidVisionAgent.act()` but in a more interactive manner that involves less code. - -## Table of Contents - -- [Installation](#installation) -- [Configuration](#configuration) -- [Usage](#usage) -- [Architecture](#architecture) - - [Resources](#resources) - - [Chat Conversation Flow](#chat-conversation-flow) -- [API Reference](#api-reference) -- [API Usage Examples](#api-usage-examples) - - [0. Start Chat API server and prepare environment](#0-start-chat-api-server-and-prepare-environment) - - [1. List available assistants](#1-list-available-assistants) - - [2. Start conversation](#2-start-conversation) - - [3. Continue conversation](#3-continue-conversation) - - [4. Retrieve the whole conversation](#4-retrieve-the-whole-conversation) - -## Installation - -Please follow the [installation instructions](../README.md#installation). - -Instead of installing the `askui[all]` package with all features, you can install the `askui[chat]` package with only the chat features to save some disk space and speed up the installation: - -```bash -pip install askui[chat] -``` - -## Configuration - -To use the chat, configure the following environment variables: - -- `ASKUI_TOKEN`: AskUI Vision Agent behind chat uses currently the AskUI API -- `ASKUI_WORKSPACE_ID`: AskUI Vision Agent behind chat uses currently the AskUI API -- `ASKUI__CHAT_API__DATA_DIR` (optional, defaults to `$(pwd)/chat`): Currently, the AskUI chat stores all data in a directory locally. You can change the default directory by setting this environment variable. -- `ASKUI__CHAT_API__HOST` (optional, defaults to `127.0.0.1`): The host to bind the chat API to. -- `ASKUI__CHAT_API__PORT` (optional, defaults to `9261`): The port to bind the chat API to. -- `ASKUI__CHAT_API__LOG_LEVEL` (optional, defaults to `info`): The log level to use for the chat API. - - -## Usage - -Start the chat API server within a shell: - -```bash -python -m askui.chat -``` - -After the server has started, navigate to the chat in the [AskUI Hub](https://hub.askui.com/). - -## Architecture - -This repository only includes the AskUI Chat API (`src/askui/chat`). The AskUI Chat UI can be accessed through the [AskUI Hub](https://hub.askui.com/) and connects to the local Chat API after it has been started. - -The AskUI Chat provides a comprehensive chat system with assistants, threads, messages, runs, and file management capabilities. -The underlying API is roughly modeled after the [OpenAI Assistants API](https://platform.openai.com/docs/assistants/migration) but also -integrates different concepts and extends it in various ways, e.g., - -- MCP configs for retrieving tools from MCP servers -- messages modeled after [Anthropic's Message API](https://docs.anthropic.com/en/api/messages#body-messages) -- runs enabling the execution of multiple iterations of tool calling loops instead of passing control back to user after each iteration. - -### Resources - -The API is organized around the following core resources: -- **Assistants**: AI agents that take are passed an ongoing conversation (thread) including configuration (tools, limits etc.) and continue the conversation -- **Threads**: Conversation sessions that contain messages -- **Messages**: Individual messages by user or assistants in a thread -- **Runs**: Calling the agent with the thread to continue conversation which results in tool calls and calls of other assistants and messages being added to the thread -- **Files**: Attachments and resources that can be referenced in messages -- **MCP Configs**: Model Context Protocol configurations for AI models to retrieve tools from MCP servers enabling to pluging in custom tools - -```mermaid -classDiagram - class Assistant { - +id: AssistantId - +name: str - +description: str - +avatar: str - +created_at: UnixDatetime - } - - class Thread { - +id: ThreadId - +name: str - +created_at: UnixDatetime - } - - class Message { - +id: MessageId - +role: "user" | "assistant" - +content: str | ContentBlockParam[] - +assistant_id: AssistantId - +run_id: RunId - +thread_id: ThreadId - +created_at: UnixDatetime - } - - class Run { - +id: RunId - +assistant_id: AssistantId - +thread_id: ThreadId - +status: RunStatus - +created_at: UnixDatetime - +expires_at: UnixDatetime - +started_at: UnixDatetime - +completed_at: UnixDatetime - } - - class File { - +id: FileId - +filename: str - +size: int - +media_type: str - +created_at: UnixDatetime - +create(params) - } - - class MCPConfig { - +id: MCPConfigId - +name: str - +created_at: UnixDatetime - +mcp_server: McpServer - } - - Message --* Thread : contained in - Run --* Thread : continues conversation - Run --> MCPConfig : retrieves tools with - Run --> Assistant : executes - Assistant --> Message : generates - File --o Message : referenced in -``` - -### Chat Conversation Flow - -```mermaid -sequenceDiagram - participant Client - participant API - participant ThreadService - participant MessageService - participant FileService - participant RunService - participant AssistantService - participant MCPConfigService - participant MCPServers - participant Model APIs - - Note over Client, Model APIs: Conversation Start/Continue Flow - - Client->>API: POST /runs or POST /threads/{thread_id}/runs - alt New conversation - API->>ThreadService: Create thread - ThreadService-->>API: Thread created - else Existing conversation - API->>ThreadService: Validate thread exists - ThreadService-->>API: Thread validated - end - - API->>MessageService: Get thread messages - MessageService-->>API: Messages list - API->>FileService: Get files for messages (where referenced) - FileService-->>API: Files data - API->>AssistantService: Get assistant details - AssistantService-->>API: Assistant config - API->>RunService: Create run - RunService-->>API: Run created - - Note over RunService, Model APIs: Run execution starts - - RunService->>MCPConfigService: Get MCP configurations - MCPConfigService-->>RunService: MCP configs - RunService->>MCPServers: Build MCP client with tools - MCPServers-->>RunService: MCP client ready - - RunService->>Model APIs: Start agent execution with tools - Model APIs-->>RunService: Agent response - - alt Tool execution required - loop Tool execution loop - alt MCP tool - RunService->>MCPServers: Execute tool via MCP client - MCPServers-->>RunService: Tool result - else In-memory tool - RunService->>RunService: Execute local tool directly - RunService-->>RunService: Tool result - end - RunService->>Model APIs: Continue agent execution with tool result - Model APIs-->>RunService: Next agent response - - Note over RunService, Client: Messages streamed in real-time - RunService->>MessageService: Store message - MessageService-->>RunService: Message stored - RunService->>API: Stream message event - API->>Client: Stream: thread.message.created - end - end - - RunService->>RunService: Update run status to completed - RunService->>API: Stream run completion event - API->>Client: Stream: thread.run.completed - API->>Client: Stream: [DONE] -``` - -## API Reference - -To see the API reference, start the AskUI Chat API server and open the Swagger UI. - -```bash -python -m askui.chat -``` - -Navigate to `http://localhost:9261/docs` in your favorite browser. - -The API reference is interactive and allows you to try out the API with the Swagger UI. - -For most endpoints, you need to specify a `AskUI-Workspace` header which is the workspace id as they are scoped to a workspace. -Navigate to `https://hub.askui.com` and select a workspace through the UI and copy the workspace id from the URL you are directed to. - -## API Usage Examples - -### 0. Start Chat API server and prepare environment - -Start the AskUI Chat API server - -```bash -python -m askui.chat -``` - -In another shell, prepare the environment variables with your workspace id and askui access token (from the [AskUI Hub](https://hub.askui.com)) as well as the base url, e.g., `http://localhost:9261` from the initial logs output by `python -m askui.chat`. - -```bash -export BASE_URL="" -export ASKUI_WORKSPACE="" -export ASKUI_TOKEN="" -export AUTHORIZATION="Basic $(echo -n $ASKUI_TOKEN | base64)" -``` - - -### 1. List available assistants - -```bash -curl -X GET "$BASE_URL/v1/assistants" \ - -H "Authorization: $AUTHORIZATION" \ - -H "AskUI-Workspace: $ASKUI_WORKSPACE" -``` - -**Example Response:** -```json - -{ - "object": "list", - "data": [ - { - "id": "asst_ge3tiojsga3dgnruge3di2u5ov36shedkcslxnmcd", - "name": "AskUI Web Testing Agent", - "description": null, - "avatar": "", - "object": "assistant", - "created_at": 1755848144 - }, - { - "id": "asst_ge3tiojsga3dgnruge3di2u5ov36shedkcslxnmcc", - "name": "AskUI Web Vision Agent", - "description": null, - "avatar": "", - "object": "assistant", - "created_at": 1755848144 - }, - { - "id": "asst_ge3tiojsga3dgnruge3di2u5ov36shedkcslxnmca", - "name": "AskUI Vision Agent", - "description": null, - "avatar": "", - "object": "assistant", - "created_at": 1755848144 - }, - { - "id": "asst_78da09fbf1ed43c7826fb1686f89f541", - "name": "AskUI Android Vision Agent", - "description": null, - "avatar": "", - "object": "assistant", - "created_at": 1755848144 - } - ], - "first_id": "asst_ge3tiojsga3dgnruge3di2u5ov36shedkcslxnmcd", - "last_id": "asst_78da09fbf1ed43c7826fb1686f89f541", - "has_more": false -} -``` - -Choose an assistant and copy the assistants ids to clipboard for the next call to create the conversation. - - -### 2. Start conversation - -Create a conversation (thread) and immediately run the assistant with the thread to continue the conversation with the assistant: - -Make sure to replace `` beforehand. - -```bash -export ASSISTANT_ID="" - -curl -X POST "$BASE_URL/v1/runs" \ - -H "Authorization: $AUTHORIZATION" \ - -H "AskUI-Workspace: $ASKUI_WORKSPACE" \ - -H "Content-Type: application/json" \ - -d "{ - \"assistant_id\": \"$ASSISTANT_ID\", - \"stream\": true, - \"thread\": { - \"name\": \"Quick Chat\", - \"messages\": [ - { - \"role\": \"user\", - \"content\": \"What kind of assistant are you? What can you do for me?\" - } - ] - } - }" -``` - -This will return a Server-Sent Events (SSE) stream with real-time updates. -The `thread.message.created` events contain the assistant's responses. - -**Example Response:** -```bash -event: thread.run.created -data: {"id":"run_68a83b0edc4b2f83ddacba1b","assistant_id":"asst_ge3tiojsga3dgnruge3di2u5ov36shedkcslxnmca","object":"thread.run","thread_id":"thread_68a83b0edc4b2f83ddacba19","created_at":1755855630,"expires_at":1755856230,"started_at":null,"completed_at":null,"failed_at":null,"cancelled_at":null,"tried_cancelling_at":null,"last_error":null,"status":"queued"} - -event: thread.run.queued -data: {"id":"run_68a83b0edc4b2f83ddacba1b","assistant_id":"asst_ge3tiojsga3dgnruge3di2u5ov36shedkcslxnmca","object":"thread.run","thread_id":"thread_68a83b0edc4b2f83ddacba19","created_at":1755855630,"expires_at":1755856230,"started_at":null,"completed_at":null,"failed_at":null,"cancelled_at":null,"tried_cancelling_at":null,"last_error":null,"status":"queued"} - -event: thread.run.in_progress -data: {"id":"run_68a83b0edc4b2f83ddacba1b","assistant_id":"asst_ge3tiojsga3dgnruge3di2u5ov36shedkcslxnmca","object":"thread.run","thread_id":"thread_68a83b0edc4b2f83ddacba19","created_at":1755855630,"expires_at":1755856230,"started_at":1755855630,"completed_at":null,"failed_at":null,"cancelled_at":null,"tried_cancelling_at":null,"last_error":null,"status":"in_progress"} - -event: thread.message.created -data: {"id":"msg_68a83b1fdc4b2f83ddacba1c","role":"assistant","content":[{"thinking":"The user is asking about what kind of assistant I am and what I can do for them. Looking at the functions available to me, I can see I have access to:\n\n1. Computer interaction tools - I can take screenshots, click, type, scroll, and interact with a desktop GUI environment\n2. Display management tools - I can list displays, set active displays, etc.\n3. Exception tool - for error handling\n\nBased on the system capability description, I'm running on a darwin (macOS) machine with arm64 architecture and have internet access. I can help with web tasks by opening browsers, and I have various capabilities for interacting with applications and the desktop environment.\n\nI should explain my capabilities clearly to the user without going into too much technical detail about the specific function names.","type":"thinking"},{"text":"I'm an AI assistant that can help you interact with and control a computer desktop environment. Here's what I can do for you:\n\n## Computer Control & Automation\n- **Take screenshots** to see what's on the screen\n- **Click, type, and navigate** using mouse and keyboard\n- **Open and interact with applications** like browsers, text editors, etc.\n- **Perform web browsing tasks** - search for information, fill out forms, navigate websites\n- **Automate repetitive tasks** on the desktop\n\n## Specific Capabilities\n- **Web research and browsing** - I can open browsers (Firefox, Chrome, Safari) and search for information online\n- **Application interaction** - Navigate menus, click buttons, fill forms\n- **File and document handling** - Open, read, and work with various file types\n- **Multi-display support** - Work across multiple screens if available\n\n## Current Environment\nI'm running on a macOS system with internet access, so I can help with both local desktop tasks and online activities.\n\n## How I Work\nI interact with the computer visually - I take screenshots to see what's happening, then use mouse clicks and keyboard input to perform actions, just like a human would.\n\n**What would you like me to help you with?** I can assist with tasks like:\n- Researching topics online\n- Opening and using specific applications\n- Automating workflows\n- Finding and organizing information\n- And much more!\n\nJust let me know what you need, and I'll get started!","type":"text","cache_control":null,"citations":null}],"stop_reason":null,"assistant_id":"asst_ge3tiojsga3dgnruge3di2u5ov36shedkcslxnmca","run_id":"run_68a83b0edc4b2f83ddacba1b","object":"thread.message","created_at":1755855647,"thread_id":"thread_68a83b0edc4b2f83ddacba19"} - -event: thread.run.completed -data: {"id":"run_68a83b0edc4b2f83ddacba1b","assistant_id":"asst_ge3tiojsga3dgnruge3di2u5ov36shedkcslxnmca","object":"thread.run","thread_id":"thread_68a83b0edc4b2f83ddacba19","created_at":1755855630,"expires_at":1755856230,"started_at":1755855630,"completed_at":1755855647,"failed_at":null,"cancelled_at":null,"tried_cancelling_at":null,"last_error":null,"status":"completed"} - -event: done -data: [DONE] -``` - -### 3. Continue conversation - -To continue the conversation, just add a new message to the thread and run the assistant again: - -Make sure to replace `` beforehand with the thread id from the previous response. - -```bash -export THREAD_ID="" - -curl -X POST "$BASE_URL/v1/threads/$THREAD_ID/messages" \ - -H "Authorization: $AUTHORIZATION" \ - -H "AskUI-Workspace: $ASKUI_WORKSPACE" \ - -H "Content-Type: application/json" \ - -d '{ - "role": "user", - "content": "Can you explain that in more detail?" - }' -``` - -```bash -curl -X POST "$BASE_URL/v1/threads/$THREAD_ID/runs" \ - -H "Authorization: $AUTHORIZATION" \ - -H "AskUI-Workspace: $ASKUI_WORKSPACE" \ - -H "Content-Type: application/json" \ - -d "{ - \"assistant_id\": \"$ASSISTANT_ID\", - \"stream\": true - }" -``` - -This pattern continues for the entire conversation - add messages and create runs to process them. - -### 4. Retrieve the whole conversation - -*Important:* The `order` parameter is required to retrieve the messages in the chronological order. Only the last 20 messages are returned by default. To go through all messages, check the other parameters of the endpoint in the [API reference](http://localhost:9261/docs). - -```bash -curl -X GET "$BASE_URL/v1/threads/$THREAD_ID/messages?order=asc" \ - -H "Authorization: $AUTHORIZATION" \ - -H "AskUI-Workspace: $ASKUI_WORKSPACE" > conversation.json -``` diff --git a/docs/file-support.md b/docs/file-support.md index 38ec40ce..d8bfab5b 100644 --- a/docs/file-support.md +++ b/docs/file-support.md @@ -12,13 +12,6 @@ The AskUI Vision Agent supports the following file formats for data extraction a - **Maximum File Size**: 20MB - **Processing Method**: **Depends on Usage Context** -#### Chat API Usage (File Stored, Extraction Per-Run) - -- **File Storage**: PDF files are stored as PDF files, not as extracted text -- **Per-Run Extraction**: Data extraction happens on every chat run using AskUI’s Gemini models (see [translator.py#L67](https://github.com/askui/vision-agent/blob/7993387387c7a5b9e2e813ca4124f8f613d8107b/src/askui/chat/api/messages/translator.py#L67)) -- **No Caching**: Currently, no caching mechanism exists - PDF content is re-extracted for each chat run -- **Multiple PDFs**: If multiple PDF files are present, all will be re-extracted on every run -- **Architecture**: PDF file stored → Extract on each chat run → Process with Gemini → LLM response stored in chat history (extracted content not stored) #### VisionAgent.get() Usage (No History) @@ -27,23 +20,10 @@ The AskUI Vision Agent supports the following file formats for data extraction a - **No extraction caching**: File is processed fresh for each separate `get()` call - **Architecture**: PDF → Gemini (direct processing) → Return results → No storage - **Model Support**: - - ✅ **Chat API**: AskUI Gemini models extract PDF content on every chat run (no caching) - ✅ **VisionAgent.get()**: AskUI Gemini models process PDF directly for each query #### Processing Workflow for PDF Files -**Chat API Workflow (Per-run extraction):** - -```mermaid -graph TD - A[Upload PDF] --> B[Store PDF file in File Service] - B --> C[New Chat Run Started] - C --> D[Load PDF as PdfSource] - D --> E[Send directly as binary to Gemini] - E --> F[Gemini extracts content for this run] - F --> G[LLM response stored in chat history] - G --> H[Next chat run repeats extraction - no cached content] -``` **VisionAgent.get() Workflow (Per-query processing):** @@ -60,11 +40,10 @@ graph TD - **20MB file size limit** for PDF files - **Processing model restriction**: Only AskUI-hosted Gemini models can process PDFs -- **No caching mechanism**: PDF content is re-extracted on every chat run (both Chat API and VisionAgent.get()) +- **No caching mechanism**: PDF content is re-extracted on every run - **Performance impact**: - - Chat API: PDF re-extracted for each chat run, slower with multiple PDFs - VisionAgent.get(): PDF processed for each individual query -- **Multiple PDF overhead**: All PDF files are re-processed on every chat run +- **Multiple PDF overhead**: All PDF files are re-processed on every run - **Future enhancement**: Caching mechanism may be implemented to avoid repeated extraction ### 📊 Excel Files (.xlsx, .xls) @@ -74,14 +53,6 @@ graph TD - `application/vnd.ms-excel` (.xls) - **Processing Method**: **Depends on Usage Context** -#### Chat API Usage (File Stored, Conversion Per-Run) - -- **File Storage**: Excel files are stored as Excel files, not as extracted markdown text -- **Per-Run Conversion**: Files are converted to markdown on every chat run using [`markitdown`](https://github.com/microsoft/markitdown) library (no AI involved) -- **No Caching**: Currently, no caching mechanism exists - Excel content is re-converted for each chat run -- **Multiple Files**: If multiple Excel files are present, all will be re-converted on every run -- **Architecture**: Excel file stored → Convert to markdown on each chat run → Pass directly to target LLM (e.g., Anthropic for computer use) → LLM response stored in chat history (converted content not stored) - #### VisionAgent.get() Usage (No History) - **Per-Query Processing**: Each `get()` command converts the Excel file to markdown fresh, no history is maintained @@ -95,23 +66,10 @@ graph TD - Deterministic conversion process (same input = same output) - **No AI in conversion**: `markitdown` performs rule-based conversion - **Model Support**: - - ✅ **Chat API**: Converted markdown passed directly to target LLM (e.g., Anthropic for computer use) - no Gemini processing - ✅ **VisionAgent.get()**: Only Gemini models can process converted markdown for each query #### Processing Workflow for Excel Files -**Chat API Workflow (Per-run conversion):** - -```mermaid -graph TD - A[Upload Excel] --> B[Store Excel file in File Service] - B --> C[New Chat Run Started] - C --> D[Load Excel as OfficeDocumentSource] - D --> E[Convert to Markdown using markitdown - NO AI] - E --> F[Pass markdown directly to target LLM - Anthropic for computer use] - F --> G[LLM response stored in chat history] - G --> H[Next chat run repeats conversion - no cached content] -``` **VisionAgent.get() Workflow (Per-query conversion):** @@ -127,18 +85,16 @@ graph TD #### Excel-Specific Limitations - **No specific file size limit** mentioned (limited by general upload constraints) -- **No caching mechanism**: Excel content is re-converted on every chat run (both Chat API and VisionAgent.get()) +- **No caching mechanism**: Excel content is re-converted on every run - Conversion quality depends on [`markitdown`](https://github.com/microsoft/markitdown) library capabilities - Complex formatting may be simplified during markdown conversion - Embedded objects (charts, complex tables) may not preserve all details - **Processing model differences**: - - Chat API: Converted markdown passed directly to target LLM (e.g., Anthropic) - no Gemini processing - VisionAgent.get(): Only Gemini models can process converted content - **No AI in conversion**: Conversion is deterministic and rule-based, not AI-powered - **Performance impact**: - - Chat API: Excel re-converted for each chat run, but no additional AI processing overhead - VisionAgent.get(): Excel converted for each individual query -- **Multiple file overhead**: All Excel files are re-processed on every chat run +- **Multiple file overhead**: All Excel files are re-processed on every run - **Future enhancement**: Caching mechanism may be implemented to avoid repeated conversion ### 📝 Word Documents (.doc, .docx) @@ -148,13 +104,6 @@ graph TD - `application/msword` (.doc) - **Processing Method**: **Depends on Usage Context** -#### Chat API Usage (File Stored, Conversion Per-Run) - -- **File Storage**: Word documents are stored as Word files, not as extracted markdown text -- **Per-Run Conversion**: Files are converted to markdown on every chat run using [`markitdown`](https://github.com/microsoft/markitdown) library (no AI involved) -- **No Caching**: Currently, no caching mechanism exists - Word content is re-converted for each chat run -- **Multiple Files**: If multiple Word files are present, all will be re-converted on every run -- **Architecture**: Word file stored → Convert to markdown on each chat run → Pass directly to target LLM (e.g., Anthropic for computer use) → LLM response stored in chat history (converted content not stored) #### VisionAgent.get() Usage (No History) @@ -169,23 +118,10 @@ graph TD - No AI-generated image descriptions during conversion (handled by `markitdown`) - **No AI in conversion**: `markitdown` performs rule-based conversion - **Model Support**: - - ✅ **Chat API**: Converted markdown passed directly to target LLM (e.g., Anthropic for computer use) - no Gemini processing - ✅ **VisionAgent.get()**: Only Gemini models can process converted markdown for each query #### Processing Workflow for Word Documents -**Chat API Workflow (Per-run conversion):** - -```mermaid -graph TD - A[Upload Word] --> B[Store Word file in File Service] - B --> C[New Chat Run Started] - C --> D[Load Word as OfficeDocumentSource] - D --> E[Convert to Markdown using markitdown - NO AI] - E --> F[Pass markdown directly to target LLM - Anthropic for computer use] - F --> G[LLM response stored in chat history] - G --> H[Next chat run repeats conversion - no cached content] -``` **VisionAgent.get() Workflow (Per-query conversion):** @@ -201,24 +137,21 @@ graph TD #### Word Document-Specific Limitations - **No specific file size limit** mentioned (limited by general upload constraints) -- **No caching mechanism**: Word content is re-converted on every chat run (both Chat API and VisionAgent.get()) +- **No caching mechanism**: Word content is re-converted on every run - Conversion quality depends on [`markitdown`](https://github.com/microsoft/markitdown) library capabilities - Complex formatting may be simplified during markdown conversion - Embedded objects (charts, complex tables) may not preserve all details - **Processing model differences**: - - Chat API: Converted markdown passed directly to target LLM (e.g., Anthropic) - no Gemini processing - VisionAgent.get(): Only Gemini models can process converted content - **No AI in conversion**: Conversion is deterministic and rule-based, not AI-powered - **Performance impact**: - - Chat API: Word re-converted for each chat run, but no additional AI processing overhead - VisionAgent.get(): Word converted for each individual query -- **Multiple file overhead**: All Word files are re-processed on every chat run +- **Multiple file overhead**: All Word files are re-processed on every run - **Future enhancement**: Caching mechanism may be implemented to avoid repeated conversion ### 📈 CSV Files (.csv) - **Status**: **Not directly supported by the backend** -- **Likely Processing**: CSV files are most probably converted to text on the frontend and sent to the chat API as plain text content - **Note**: No specific CSV processing logic was found in the backend codebase, suggesting frontend preprocessing #### Processing Workflow for CSV Files @@ -238,16 +171,6 @@ graph TD - **Text-only processing**: Treated as regular text content by the LLM - **No file size limits**: Since processing happens on frontend, backend file limits don’t apply -## Processing Architecture - -### File Upload and Storage - -1. **File Upload**: Files are uploaded via the Chat API (`/v1/files` endpoint) -2. **Storage**: Files are stored locally in the chat data directory with metadata -3. **Size Limits**: - - **PDF files**: 20MB maximum file size - - **Office documents**: No specific size limit mentioned (limited by general upload constraints) -4. **MIME Type Detection**: Automatic MIME type detection using the `filetype` library ## Technical Implementation Details @@ -271,14 +194,8 @@ graph TD - **Google Gemini API** (`src/askui/models/askui/google_genai_api.py`): - **Initial Processing Only**: Direct PDF processing (binary data) and Office document processing during upload - Currently the only models that support initial document processing during upload - - Once extracted, the text is stored in chat history for all models to use - **Anthropic Models**: Can process extracted document text from chat history, primary LLM for computer use tasks -#### 4. Chat API Integration - -- **File Service** (`src/askui/chat/api/files/service.py`): Handles file upload and storage -- **Message Translator** (`src/askui/chat/api/messages/translator.py`): Converts files to LLM-compatible format during upload -- **Chat Message History**: Stores extracted document text for future reference by any model ### Dependencies @@ -290,10 +207,6 @@ The following key dependencies enable file format support: ## Usage Examples -### Chat API Frontend Implementation - -The Chat API provides comprehensive document processing capabilities on the frontend, including file upload, processing, and interaction flows for a complete user experience from uploading documents to receiving processed responses. - ### Processing Excel Files #### Using VisionAgent.get() (Per-query conversion) @@ -311,15 +224,6 @@ with VisionAgent() as agent: ) ``` -#### Using Chat API (Per-run conversion, no caching) - -```python -# When using the Chat API, Excel files are stored but content is re-converted# on every chat run - no caching mechanism currently exists## Example: If you have 2 Excel files uploaded and start a new chat:# - Both Excel files will be converted to markdown by markitdown again# - This happens for every new chat run# - Future enhancement: caching may be implemented to avoid repeated conversion -``` - -**Frontend Implementation Reference:** -The Chat API provides a complete flow from file upload to processing with comprehensive document handling capabilities. - ### Processing PDF Files #### Using VisionAgent.get() (Per-query processing) @@ -337,12 +241,6 @@ with VisionAgent() as agent: ) ``` -#### Using Chat API (Per-run extraction, no caching) - -```python -# When using the Chat API, PDF files are stored but content is re-extracted# on every chat run - no caching mechanism currently exists## Example: If you have 3 PDF files uploaded and start a new chat:# - All 3 PDFs will be processed by Gemini again# - This happens for every new chat run# - Future enhancement: caching may be implemented to avoid repeated extraction -``` - ### Processing Word Documents #### Using VisionAgent.get() (Per-query conversion) @@ -360,39 +258,31 @@ with VisionAgent() as agent: ) ``` -#### Using Chat API (Per-run conversion, no caching) - -```python -# When using the Chat API, Word files are stored but content is re-converted# on every chat run - no caching mechanism currently exists## Example: If you have 3 Word files uploaded and start a new chat:# - All 3 Word files will be converted to markdown by markitdown again# - This happens for every new chat run# - Future enhancement: caching may be implemented to avoid repeated conversion -``` ## General Limitations and Considerations - **Processing Model Restriction**: Currently, only Gemini models support document processing -- **No Caching Mechanism**: All document files (PDF, Excel, Word) are re-processed on every chat run and VisionAgent.get() call +- **No Caching Mechanism**: All document files (PDF, Excel, Word) are re-processed on every VisionAgent.get() call - **File Storage Only**: Files are stored as original files, not as extracted/converted content - **Runtime Conversion**: Document processing happens on runtime to create LLM-compatible messages -- **Chat History Storage**: LLM responses (e.g., from Anthropic computer use) are stored in chat history, but processed document content is not - **Performance Impact**: Multiple documents mean multiple processing operations on every run ### Performance Considerations -- **No caching**: All document files are re-processed on every chat run and VisionAgent.get() call +- **No caching**: All document files are re-processed on every VisionAgent.get() call - **Multiple file overhead**: Having multiple documents significantly impacts performance as all are re-processed - **Processing types**: - PDFs: Re-extracted by Gemini on every run (AI processing overhead) - Office documents: Re-converted by `markitdown` on every run (fast, deterministic, no AI overhead) -- **Chat API vs VisionAgent.get()**: Both follow the same pattern - no caching, always re-process - Currently limited to Gemini models for document processing, which may create processing bottlenecks ## Model Compatibility Matrix | File Format | AskUI Gemini | Anthropic Claude | OpenRouter | AskUI Inference API | | ------------------- | ---------------------------------------- | ----------------------------------------- | ----------------------------------------- | ----------------------------------------- | -| PDF (.pdf) | ✅ Direct Processing (per-run/per-query) | ❌ Not Supported | ❌ Not Supported | ❌ Not Supported | -| Excel (.xlsx, .xls) | ✅ VisionAgent.get() only | ✅ Chat API (receives converted markdown) | ✅ Chat API (receives converted markdown) | ✅ Chat API (receives converted markdown) | -| Word (.docx, .doc) | ✅ VisionAgent.get() only | ✅ Chat API (receives converted markdown) | ✅ Chat API (receives converted markdown) | ✅ Chat API (receives converted markdown) | -| CSV (.csv) | ⚠️ Frontend Only | ⚠️ Frontend Only | ⚠️ Frontend Only | ⚠️ Frontend Only | +| PDF (.pdf) | ✅ | ❌ | ❌ | ❌ | +| Excel (.xlsx, .xls) | ✅ | ✅ | ✅ | ✅ | +| Word (.docx, .doc) | ✅ | ✅ | ✅ | ✅ | **Legend:** @@ -403,11 +293,9 @@ with VisionAgent() as agent: ## Best Practices 1. **Understand Model Usage**: - - **PDFs**: Only Gemini models can process PDFs (both Chat API and VisionAgent.get()) - - **Office docs (Chat API)**: Converted markdown passed directly to target LLM (e.g., Anthropic) + - **PDFs**: Only Gemini models can process PDFs - **Office docs (VisionAgent.get())**: Only Gemini models can process converted content 2. **Understand Processing Flow**: - - Office docs (Chat API): `markitdown` (non-AI) conversion → Direct to target LLM (e.g., Anthropic) - Office docs (VisionAgent.get()): `markitdown` (non-AI) conversion → Gemini processes converted content - PDFs: Direct binary processing by Gemini on every run/query → No caching 3. **Optimize File Size**: Keep PDF files under 20MB (required limit); Office documents have no specific size limit @@ -418,7 +306,7 @@ with VisionAgent() as agent: Potential areas for improvement: -- **PDF Caching Mechanism**: Implement caching to avoid re-extracting PDF content on every chat run +- **PDF Caching Mechanism**: Implement caching to avoid re-extracting PDF content on every run - **Expand Model Support**: Enable document processing for other models (Anthropic, OpenRouter, etc.) - **Native CSV Support**: Add backend CSV processing capabilities - **Additional Formats**: Support for PowerPoint, RTF, and other document formats diff --git a/docs/migrations.md b/docs/migrations.md deleted file mode 100644 index ffc76df4..00000000 --- a/docs/migrations.md +++ /dev/null @@ -1,213 +0,0 @@ -# Database Migrations - -This document explains how database migrations work in the AskUI Chat system. - -## Overview - -Database migrations are used to manage changes to the database schema and data over time. They ensure that your database structure stays in sync with the application code and handle data transformations when the schema changes. - -## What Are Migrations Used For? - -Migrations in the AskUI Chat system are primarily used for: - -- **Schema Changes**: Creating, modifying, or dropping database tables and columns -- **Data Migrations**: Transforming existing data when the schema changes -- **Persistence Layer Evolution**: Migrating from one persistence format to another (e.g., JSON files to SQLite database) -- **Seed Data**: Populating the database with default data - -### Example Use Cases - -The current migration history shows several real-world examples: - -1. **`4d1e043b4254_create_assistants_table.py`**: Creates the initial `assistants` table with columns for ID, workspace, timestamps, and assistant configuration -2. **`057f82313448_import_json_assistants.py`**: Migrates existing assistant data from JSON files to the new SQLite database -3. **`c35e88ea9595_seed_default_assistants.py`**: Seeds the database with default assistant configurations -4. **`37007a499ca7_remove_assistants_dir.py`**: Cleans up the old JSON-based persistence by removing the assistants directory - -### Our current migration strategy - -#### Until `5e6f7a8b9c0d_import_json_messages.py` - -On Upgrade: -- We migrate from file system persistence to SQLite database persistence. We don't delete any of the files from the file system so rolling back is as easy as just installing an older version of the `askui` library. - -On Downgrade: -- This is mainly to be used by us for debugging and testing new migrations but not a user. -- We export data from database but already existing files take precedence so you may loose some data that was upgraded or deleted between the upgrade and downgrade. Also you may loose some of the data that was not originally available in the schema, e.g., global files (not scoped to workspace). - -## Automatic Migrations on Startup - -By default, migrations are automatically run when the chat API starts up. This ensures that users are always upgraded to the newest database schema version without manual intervention. - -### Configuration - -The automatic migration behavior is controlled by the `auto_migrate` setting in the database configuration: - -```python -class DbSettings(BaseModel): - auto_migrate: bool = Field( - default=True, - description="Whether to run migrations automatically on startup", - ) -``` - -### Environment Variable Override - -You can disable automatic migrations for debugging purposes using the environment variable: - -```bash -export ASKUI__CHAT_API__DB__AUTO_MIGRATE=false -``` - -When disabled, the application will log: -``` -Automatic migrations are disabled. Skipping migrations... -``` - -## Manual Migration Commands - -You can run migrations manually using the Alembic command-line interface: - -```bash -# Run all pending migrations -pdm run alembic upgrade head - -# Run migrations to a specific revision -pdm run alembic upgrade - -# Downgrade to a previous revision -pdm run alembic downgrade - -# Show current migration status -pdm run alembic current - -# Show migration history -pdm run alembic history - -# Generate a new migration -pdm run alembic revision --autogenerate -m "description of changes" -``` - -## Migration Structure - -### Directory Layout - -``` -src/askui/chat/migrations/ -├── alembic.ini # Alembic configuration -├── env.py # Migration environment setup -├── runner.py # Migration runner for programmatic execution -├── script.py.mako # Template for new migration files -├── shared/ # Shared utilities and models for migrations -│ ├── assistants/ # Assistant-related migration utilities -│ ├── models.py # Shared data models -│ └── settings.py # Settings for migrations -└── versions/ # Individual migration files - ├── 4d1e043b4254_create_assistants_table.py - ├── 057f82313448_import_json_assistants.py - ├── c35e88ea9595_seed_default_assistants.py - └── 37007a499ca7_remove_assistants_dir.py -``` - -### Migration File Structure - -Each migration file follows this structure: - -```python -"""migration_description - -Revision ID: -Revises: -Create Date: - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - -# revision identifiers, used by Alembic. -revision: str = "" -down_revision: Union[str, None] = "" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - """Apply the migration changes.""" - # Migration logic here - pass - - -def downgrade() -> None: - """Revert the migration changes.""" - # Rollback logic here - pass -``` - -## Migration Execution Flow - -1. **Startup Check**: When the chat API starts, it checks the `auto_migrate` setting -2. **Migration Runner**: If enabled, calls `run_migrations()` from `runner.py` -3. **Alembic Execution**: Uses Alembic's `upgrade` command to apply all pending migrations -4. **Database Connection**: Connects to the database using settings from `env.py` -5. **Schema Application**: Applies each migration in sequence until reaching the "head" revision - -## Database Configuration - -The migration system uses the same database configuration as the main application: - -- **Database URL**: Configured via `ASKUI__CHAT_API__DB__URL` (defaults to SQLite) -- **Connection**: Uses the same SQLAlchemy engine as the main application -- **Metadata**: Automatically detects schema changes from SQLAlchemy models - -## Best Practices - -### Creating New Migrations - -1. **Use Autogenerate**: Let Alembic detect schema changes automatically: - ```bash - pdm run alembic revision --autogenerate -m "add new column to table" - ``` - -2. **Review Generated Code**: Always review and test autogenerated migrations before applying - -3. **Handle Data Migrations**: For complex data transformations, write custom migration logic - -4. **Test Both Directions**: Ensure both `upgrade()` and `downgrade()` functions work correctly - -### Migration Safety - -1. **Backup First**: Always backup database before running migrations so that it can be easily rolled back if something goes wrong -2. **Test Locally**: Test migrations on a copy of production data -3. **Rollback Plan**: Have a rollback strategy for critical migrations -4. **Batch Operations**: For large data migrations, process data in batches to avoid memory issues -5. **Keep Old Code Around**: Keep old code versioned around so that migrations are independent of the version of AskUI chat - -## Troubleshooting - -### Common Issues - -1. **Migration Conflicts**: If multiple developers create migrations simultaneously, you may need to resolve conflicts manually -2. **Data Loss**: Some migrations (like dropping columns) can cause data loss - always review carefully -3. **Performance**: Large data migrations can be slow - consider running them not during startup but in the background maintaining compatibility with old code for as long as it runs or just disabling certain apis for that period of time - -### Debugging - -1. **Check Migration Status**: - ```bash - pdm run alembic current - ``` - -2. **View Migration History**: - ```bash - pdm run alembic history --verbose - ``` - -3. **Disable Auto-Migration**: Use the environment variable to disable automatic migrations during debugging - -## Related Documentation - -- [Alembic Documentation](https://alembic.sqlalchemy.org/) - Official Alembic migration tool documentation -- [SQLAlchemy Documentation](https://docs.sqlalchemy.org/) - SQLAlchemy ORM and database toolkit -- [Database Models](../src/askui/chat/api/) - Current database schema and models diff --git a/pdm.lock b/pdm.lock index 315f6d85..f4f14cb6 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2,10 +2,10 @@ # It is not intended for manual editing. [metadata] -groups = ["default", "all", "android", "bedrock", "chat", "dev", "pynput", "vertex", "web"] +groups = ["default", "all", "android", "bedrock", "dev", "pynput", "vertex", "web"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:ea8a886d3282bd07778ce1bbdc5782490f78deca9c4a50e22fbebb6a8742eabd" +content_hash = "sha256:0af1559aa395b49a67778024f8c05e96065d40d8c1c8dd0f3aa11d2542ca176a" [[metadata.targets]] requires_python = ">=3.10,<3.14" @@ -21,23 +21,6 @@ files = [ {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"}, ] -[[package]] -name = "alembic" -version = "1.16.5" -requires_python = ">=3.9" -summary = "A database migration tool for SQLAlchemy." -groups = ["all", "chat"] -dependencies = [ - "Mako", - "SQLAlchemy>=1.4.0", - "tomli; python_version < \"3.11\"", - "typing-extensions>=4.12", -] -files = [ - {file = "alembic-1.16.5-py3-none-any.whl", hash = "sha256:e845dfe090c5ffa7b92593ae6687c5cb1a101e91fa53868497dbd79847f9dbe3"}, - {file = "alembic-1.16.5.tar.gz", hash = "sha256:a88bb7f6e513bd4301ecf4c7f2206fe93f9913f9b48dac3b78babde2d6fe765e"}, -] - [[package]] name = "annotated-types" version = "0.7.0" @@ -111,7 +94,7 @@ name = "anyio" version = "4.10.0" requires_python = ">=3.9" summary = "High-level concurrency and networking framework on top of asyncio or Trio" -groups = ["default", "all", "bedrock", "chat", "vertex"] +groups = ["default", "all", "bedrock", "vertex"] dependencies = [ "exceptiongroup>=1.0.2; python_version < \"3.11\"", "idna>=2.8", @@ -168,35 +151,6 @@ files = [ {file = "arrow-1.4.0.tar.gz", hash = "sha256:ed0cc050e98001b8779e84d461b0098c4ac597e88704a655582b21d116e526d7"}, ] -[[package]] -name = "asgi-correlation-id" -version = "4.3.4" -requires_python = "<4.0,>=3.8" -summary = "Middleware correlating project logs to individual requests" -groups = ["all", "chat"] -dependencies = [ - "packaging", - "starlette>=0.18", -] -files = [ - {file = "asgi_correlation_id-4.3.4-py3-none-any.whl", hash = "sha256:36ce69b06c7d96b4acb89c7556a4c4f01a972463d3d49c675026cbbd08e9a0a2"}, - {file = "asgi_correlation_id-4.3.4.tar.gz", hash = "sha256:ea6bc310380373cb9f731dc2e8b2b6fb978a76afe33f7a2384f697b8d6cd811d"}, -] - -[[package]] -name = "asgiref" -version = "3.10.0" -requires_python = ">=3.9" -summary = "ASGI specs, helper code, and adapters" -groups = ["all", "chat"] -dependencies = [ - "typing-extensions>=4; python_version < \"3.11\"", -] -files = [ - {file = "asgiref-3.10.0-py3-none-any.whl", hash = "sha256:aef8a81283a34d0ab31630c9b7dfe70c812c95eba78171367ca8745e88124734"}, - {file = "asgiref-3.10.0.tar.gz", hash = "sha256:d89f2d8cd8b56dada7d52fa7dc8075baa08fb836560710d38c292a7a3f78c04e"}, -] - [[package]] name = "asyncer" version = "0.0.8" @@ -384,7 +338,7 @@ name = "certifi" version = "2025.8.3" requires_python = ">=3.7" summary = "Python package for providing Mozilla's CA Bundle." -groups = ["default", "all", "bedrock", "chat", "vertex"] +groups = ["default", "all", "bedrock", "vertex"] files = [ {file = "certifi-2025.8.3-py3-none-any.whl", hash = "sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5"}, {file = "certifi-2025.8.3.tar.gz", hash = "sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407"}, @@ -491,7 +445,7 @@ name = "charset-normalizer" version = "3.4.3" requires_python = ">=3.7" summary = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." -groups = ["default", "all", "chat", "vertex"] +groups = ["default", "all", "vertex"] files = [ {file = "charset_normalizer-3.4.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fb7f67a1bfa6e40b438170ebdc8158b78dc465a5a67b6dde178a46987b244a72"}, {file = "charset_normalizer-3.4.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cc9370a2da1ac13f0153780040f465839e6cccb4a1e44810124b4e22483c93fe"}, @@ -557,7 +511,7 @@ name = "click" version = "8.3.0" requires_python = ">=3.10" summary = "Composable command line interface toolkit" -groups = ["default", "all", "chat", "dev"] +groups = ["default", "dev"] dependencies = [ "colorama; platform_system == \"Windows\"", ] @@ -582,7 +536,7 @@ name = "colorama" version = "0.4.6" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" summary = "Cross-platform colored terminal text." -groups = ["default", "all", "chat", "dev"] +groups = ["default", "dev"] marker = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, @@ -1072,7 +1026,7 @@ name = "exceptiongroup" version = "1.3.0" requires_python = ">=3.7" summary = "Backport of PEP 654 (exception groups)" -groups = ["default", "all", "bedrock", "chat", "dev", "vertex"] +groups = ["default", "all", "bedrock", "dev", "vertex"] dependencies = [ "typing-extensions>=4.6.0; python_version < \"3.13\"", ] @@ -1455,7 +1409,7 @@ name = "googleapis-common-protos" version = "1.71.0" requires_python = ">=3.7" summary = "Common protobufs used in Google APIs" -groups = ["all", "chat", "vertex"] +groups = ["all", "vertex"] dependencies = [ "protobuf!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<7.0.0,>=3.20.2", ] @@ -1504,7 +1458,7 @@ name = "greenlet" version = "3.2.4" requires_python = ">=3.9" summary = "Lightweight in-process concurrent programming" -groups = ["default", "all", "chat", "dev", "web"] +groups = ["default", "all", "dev", "web"] files = [ {file = "greenlet-3.2.4-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:8c68325b0d0acf8d91dde4e6f930967dd52a5302cd4062932a6b2e7c2969f47c"}, {file = "greenlet-3.2.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:94385f101946790ae13da500603491f04a76b6e4c059dab271b3ce2e283b2590"}, @@ -1721,7 +1675,7 @@ name = "h11" version = "0.16.0" requires_python = ">=3.8" summary = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" -groups = ["default", "all", "bedrock", "chat", "vertex"] +groups = ["default", "all", "bedrock", "vertex"] files = [ {file = "h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86"}, {file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"}, @@ -1831,27 +1785,12 @@ name = "idna" version = "3.10" requires_python = ">=3.6" summary = "Internationalized Domain Names in Applications (IDNA)" -groups = ["default", "all", "bedrock", "chat", "dev", "vertex"] +groups = ["default", "all", "bedrock", "dev", "vertex"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, ] -[[package]] -name = "importlib-metadata" -version = "8.7.0" -requires_python = ">=3.9" -summary = "Read metadata from Python packages" -groups = ["all", "chat"] -dependencies = [ - "typing-extensions>=3.6.4; python_version < \"3.8\"", - "zipp>=3.20", -] -files = [ - {file = "importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd"}, - {file = "importlib_metadata-8.7.0.tar.gz", hash = "sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000"}, -] - [[package]] name = "inflect" version = "7.5.0" @@ -2329,20 +2268,6 @@ files = [ {file = "magika-0.6.2.tar.gz", hash = "sha256:37eb6ae8020f6e68f231bc06052c0a0cbe8e6fa27492db345e8dc867dbceb067"}, ] -[[package]] -name = "mako" -version = "1.3.10" -requires_python = ">=3.8" -summary = "A super-fast templating language that borrows the best ideas from the existing templating languages." -groups = ["all", "chat"] -dependencies = [ - "MarkupSafe>=0.9.2", -] -files = [ - {file = "mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59"}, - {file = "mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28"}, -] - [[package]] name = "mammoth" version = "1.11.0" @@ -2431,7 +2356,7 @@ name = "markupsafe" version = "3.0.2" requires_python = ">=3.9" summary = "Safely add untrusted strings to HTML/XML markup." -groups = ["default", "all", "chat", "dev"] +groups = ["default", "dev"] files = [ {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8"}, {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158"}, @@ -2819,200 +2744,6 @@ files = [ {file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"}, ] -[[package]] -name = "opentelemetry-api" -version = "1.38.0" -requires_python = ">=3.9" -summary = "OpenTelemetry Python API" -groups = ["all", "chat"] -dependencies = [ - "importlib-metadata<8.8.0,>=6.0", - "typing-extensions>=4.5.0", -] -files = [ - {file = "opentelemetry_api-1.38.0-py3-none-any.whl", hash = "sha256:2891b0197f47124454ab9f0cf58f3be33faca394457ac3e09daba13ff50aa582"}, - {file = "opentelemetry_api-1.38.0.tar.gz", hash = "sha256:f4c193b5e8acb0912b06ac5b16321908dd0843d75049c091487322284a3eea12"}, -] - -[[package]] -name = "opentelemetry-exporter-otlp-proto-common" -version = "1.38.0" -requires_python = ">=3.9" -summary = "OpenTelemetry Protobuf encoding" -groups = ["all", "chat"] -dependencies = [ - "opentelemetry-proto==1.38.0", -] -files = [ - {file = "opentelemetry_exporter_otlp_proto_common-1.38.0-py3-none-any.whl", hash = "sha256:03cb76ab213300fe4f4c62b7d8f17d97fcfd21b89f0b5ce38ea156327ddda74a"}, - {file = "opentelemetry_exporter_otlp_proto_common-1.38.0.tar.gz", hash = "sha256:e333278afab4695aa8114eeb7bf4e44e65c6607d54968271a249c180b2cb605c"}, -] - -[[package]] -name = "opentelemetry-exporter-otlp-proto-http" -version = "1.38.0" -requires_python = ">=3.9" -summary = "OpenTelemetry Collector Protobuf over HTTP Exporter" -groups = ["all", "chat"] -dependencies = [ - "googleapis-common-protos~=1.52", - "opentelemetry-api~=1.15", - "opentelemetry-exporter-otlp-proto-common==1.38.0", - "opentelemetry-proto==1.38.0", - "opentelemetry-sdk~=1.38.0", - "requests~=2.7", - "typing-extensions>=4.5.0", -] -files = [ - {file = "opentelemetry_exporter_otlp_proto_http-1.38.0-py3-none-any.whl", hash = "sha256:84b937305edfc563f08ec69b9cb2298be8188371217e867c1854d77198d0825b"}, - {file = "opentelemetry_exporter_otlp_proto_http-1.38.0.tar.gz", hash = "sha256:f16bd44baf15cbe07633c5112ffc68229d0edbeac7b37610be0b2def4e21e90b"}, -] - -[[package]] -name = "opentelemetry-instrumentation" -version = "0.59b0" -requires_python = ">=3.9" -summary = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" -groups = ["all", "chat"] -dependencies = [ - "opentelemetry-api~=1.4", - "opentelemetry-semantic-conventions==0.59b0", - "packaging>=18.0", - "wrapt<2.0.0,>=1.0.0", -] -files = [ - {file = "opentelemetry_instrumentation-0.59b0-py3-none-any.whl", hash = "sha256:44082cc8fe56b0186e87ee8f7c17c327c4c2ce93bdbe86496e600985d74368ee"}, - {file = "opentelemetry_instrumentation-0.59b0.tar.gz", hash = "sha256:6010f0faaacdaf7c4dff8aac84e226d23437b331dcda7e70367f6d73a7db1adc"}, -] - -[[package]] -name = "opentelemetry-instrumentation-asgi" -version = "0.59b0" -requires_python = ">=3.9" -summary = "ASGI instrumentation for OpenTelemetry" -groups = ["all", "chat"] -dependencies = [ - "asgiref~=3.0", - "opentelemetry-api~=1.12", - "opentelemetry-instrumentation==0.59b0", - "opentelemetry-semantic-conventions==0.59b0", - "opentelemetry-util-http==0.59b0", -] -files = [ - {file = "opentelemetry_instrumentation_asgi-0.59b0-py3-none-any.whl", hash = "sha256:ba9703e09d2c33c52fa798171f344c8123488fcd45017887981df088452d3c53"}, - {file = "opentelemetry_instrumentation_asgi-0.59b0.tar.gz", hash = "sha256:2509d6fe9fd829399ce3536e3a00426c7e3aa359fc1ed9ceee1628b56da40e7a"}, -] - -[[package]] -name = "opentelemetry-instrumentation-fastapi" -version = "0.59b0" -requires_python = ">=3.9" -summary = "OpenTelemetry FastAPI Instrumentation" -groups = ["all", "chat"] -dependencies = [ - "opentelemetry-api~=1.12", - "opentelemetry-instrumentation-asgi==0.59b0", - "opentelemetry-instrumentation==0.59b0", - "opentelemetry-semantic-conventions==0.59b0", - "opentelemetry-util-http==0.59b0", -] -files = [ - {file = "opentelemetry_instrumentation_fastapi-0.59b0-py3-none-any.whl", hash = "sha256:0d8d00ff7d25cca40a4b2356d1d40a8f001e0668f60c102f5aa6bb721d660c4f"}, - {file = "opentelemetry_instrumentation_fastapi-0.59b0.tar.gz", hash = "sha256:e8fe620cfcca96a7d634003df1bc36a42369dedcdd6893e13fb5903aeeb89b2b"}, -] - -[[package]] -name = "opentelemetry-instrumentation-httpx" -version = "0.59b0" -requires_python = ">=3.9" -summary = "OpenTelemetry HTTPX Instrumentation" -groups = ["all", "chat"] -dependencies = [ - "opentelemetry-api~=1.12", - "opentelemetry-instrumentation==0.59b0", - "opentelemetry-semantic-conventions==0.59b0", - "opentelemetry-util-http==0.59b0", - "wrapt<2.0.0,>=1.0.0", -] -files = [ - {file = "opentelemetry_instrumentation_httpx-0.59b0-py3-none-any.whl", hash = "sha256:7dc9f66aef4ca3904d877f459a70c78eafd06131dc64d713b9b1b5a7d0a48f05"}, - {file = "opentelemetry_instrumentation_httpx-0.59b0.tar.gz", hash = "sha256:a1cb9b89d9f05a82701cc9ab9cfa3db54fd76932489449778b350bc1b9f0e872"}, -] - -[[package]] -name = "opentelemetry-instrumentation-sqlalchemy" -version = "0.59b0" -requires_python = ">=3.9" -summary = "OpenTelemetry SQLAlchemy instrumentation" -groups = ["all", "chat"] -dependencies = [ - "opentelemetry-api~=1.12", - "opentelemetry-instrumentation==0.59b0", - "opentelemetry-semantic-conventions==0.59b0", - "packaging>=21.0", - "wrapt>=1.11.2", -] -files = [ - {file = "opentelemetry_instrumentation_sqlalchemy-0.59b0-py3-none-any.whl", hash = "sha256:4ef150c49b6d1a8a7328f9d23ff40c285a245b88b0875ed2e5d277a40aa921c8"}, - {file = "opentelemetry_instrumentation_sqlalchemy-0.59b0.tar.gz", hash = "sha256:7647b1e63497deebd41f9525c414699e0d49f19efcadc8a0642b715897f62d32"}, -] - -[[package]] -name = "opentelemetry-proto" -version = "1.38.0" -requires_python = ">=3.9" -summary = "OpenTelemetry Python Proto" -groups = ["all", "chat"] -dependencies = [ - "protobuf<7.0,>=5.0", -] -files = [ - {file = "opentelemetry_proto-1.38.0-py3-none-any.whl", hash = "sha256:b6ebe54d3217c42e45462e2a1ae28c3e2bf2ec5a5645236a490f55f45f1a0a18"}, - {file = "opentelemetry_proto-1.38.0.tar.gz", hash = "sha256:88b161e89d9d372ce723da289b7da74c3a8354a8e5359992be813942969ed468"}, -] - -[[package]] -name = "opentelemetry-sdk" -version = "1.38.0" -requires_python = ">=3.9" -summary = "OpenTelemetry Python SDK" -groups = ["all", "chat"] -dependencies = [ - "opentelemetry-api==1.38.0", - "opentelemetry-semantic-conventions==0.59b0", - "typing-extensions>=4.5.0", -] -files = [ - {file = "opentelemetry_sdk-1.38.0-py3-none-any.whl", hash = "sha256:1c66af6564ecc1553d72d811a01df063ff097cdc82ce188da9951f93b8d10f6b"}, - {file = "opentelemetry_sdk-1.38.0.tar.gz", hash = "sha256:93df5d4d871ed09cb4272305be4d996236eedb232253e3ab864c8620f051cebe"}, -] - -[[package]] -name = "opentelemetry-semantic-conventions" -version = "0.59b0" -requires_python = ">=3.9" -summary = "OpenTelemetry Semantic Conventions" -groups = ["all", "chat"] -dependencies = [ - "opentelemetry-api==1.38.0", - "typing-extensions>=4.5.0", -] -files = [ - {file = "opentelemetry_semantic_conventions-0.59b0-py3-none-any.whl", hash = "sha256:35d3b8833ef97d614136e253c1da9342b4c3c083bbaf29ce31d572a1c3825eed"}, - {file = "opentelemetry_semantic_conventions-0.59b0.tar.gz", hash = "sha256:7a6db3f30d70202d5bf9fa4b69bc866ca6a30437287de6c510fb594878aed6b0"}, -] - -[[package]] -name = "opentelemetry-util-http" -version = "0.59b0" -requires_python = ">=3.9" -summary = "Web util for OpenTelemetry" -groups = ["all", "chat"] -files = [ - {file = "opentelemetry_util_http-0.59b0-py3-none-any.whl", hash = "sha256:6d036a07563bce87bf521839c0671b507a02a0d39d7ea61b88efa14c6e25355d"}, - {file = "opentelemetry_util_http-0.59b0.tar.gz", hash = "sha256:ae66ee91be31938d832f3b4bc4eb8a911f6eddd38969c4a871b1230db2a0a560"}, -] - [[package]] name = "packageurl-python" version = "0.17.6" @@ -3029,7 +2760,7 @@ name = "packaging" version = "25.0" requires_python = ">=3.8" summary = "Core utilities for Python packages" -groups = ["default", "all", "chat", "dev", "vertex"] +groups = ["default", "all", "dev", "vertex"] files = [ {file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"}, {file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"}, @@ -3254,7 +2985,7 @@ name = "playwright" version = "1.55.0" requires_python = ">=3.9" summary = "A high-level API to automate web browsers" -groups = ["all", "chat", "dev", "web"] +groups = ["all", "dev", "web"] dependencies = [ "greenlet<4.0.0,>=3.1.1", "pyee<14,>=13", @@ -3281,32 +3012,6 @@ files = [ {file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"}, ] -[[package]] -name = "prometheus-client" -version = "0.23.1" -requires_python = ">=3.9" -summary = "Python client for the Prometheus monitoring system." -groups = ["all", "chat"] -files = [ - {file = "prometheus_client-0.23.1-py3-none-any.whl", hash = "sha256:dd1913e6e76b59cfe44e7a4b83e01afc9873c1bdfd2ed8739f1e76aeca115f99"}, - {file = "prometheus_client-0.23.1.tar.gz", hash = "sha256:6ae8f9081eaaaf153a2e959d2e6c4f4fb57b12ef76c8c7980202f1e57b48b2ce"}, -] - -[[package]] -name = "prometheus-fastapi-instrumentator" -version = "7.1.0" -requires_python = ">=3.8" -summary = "Instrument your FastAPI app with Prometheus metrics" -groups = ["all", "chat"] -dependencies = [ - "prometheus-client<1.0.0,>=0.8.0", - "starlette<1.0.0,>=0.30.0", -] -files = [ - {file = "prometheus_fastapi_instrumentator-7.1.0-py3-none-any.whl", hash = "sha256:978130f3c0bb7b8ebcc90d35516a6fe13e02d2eb358c8f83887cdef7020c31e9"}, - {file = "prometheus_fastapi_instrumentator-7.1.0.tar.gz", hash = "sha256:be7cd61eeea4e5912aeccb4261c6631b3f227d8924542d79eaf5af3f439cbe5e"}, -] - [[package]] name = "proto-plus" version = "1.26.1" @@ -3326,7 +3031,7 @@ name = "protobuf" version = "6.32.1" requires_python = ">=3.9" summary = "" -groups = ["default", "all", "chat", "dev", "vertex"] +groups = ["default", "all", "dev", "vertex"] files = [ {file = "protobuf-6.32.1-cp310-abi3-win32.whl", hash = "sha256:a8a32a84bc9f2aad712041b8b366190f71dde248926da517bde9e832e4412085"}, {file = "protobuf-6.32.1-cp310-abi3-win_amd64.whl", hash = "sha256:b00a7d8c25fa471f16bc8153d0e53d6c9e827f0953f3c09aaa4331c718cae5e1"}, @@ -3341,7 +3046,7 @@ files = [ name = "pure-python-adb" version = "0.3.0.dev0" summary = "Pure python implementation of the adb client" -groups = ["all", "android", "chat"] +groups = ["all", "android"] files = [ {file = "pure-python-adb-0.3.0.dev0.tar.gz", hash = "sha256:0ecc89d780160cfe03260ba26df2c471a05263b2cad0318363573ee8043fb94d"}, ] @@ -3553,7 +3258,7 @@ name = "pyee" version = "13.0.0" requires_python = ">=3.8" summary = "A rough port of Node.js's EventEmitter to Python with a few tricks of its own" -groups = ["all", "chat", "dev", "web"] +groups = ["all", "dev", "web"] dependencies = [ "typing-extensions", ] @@ -4000,7 +3705,7 @@ name = "requests" version = "2.32.5" requires_python = ">=3.9" summary = "Python HTTP for Humans." -groups = ["default", "all", "chat", "vertex"] +groups = ["default", "all", "vertex"] dependencies = [ "certifi>=2017.4.17", "charset-normalizer<4,>=2", @@ -4387,7 +4092,7 @@ name = "sniffio" version = "1.3.1" requires_python = ">=3.7" summary = "Sniff out which async library your code is running under" -groups = ["default", "all", "bedrock", "chat", "vertex"] +groups = ["default", "all", "bedrock", "vertex"] files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, @@ -4419,7 +4124,7 @@ name = "sqlalchemy" version = "2.0.44" requires_python = ">=3.7" summary = "Database Abstraction Library" -groups = ["default", "all", "chat"] +groups = ["default"] dependencies = [ "greenlet>=1; platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\"", "importlib-metadata; python_version < \"3.8\"", @@ -4529,7 +4234,7 @@ name = "starlette" version = "0.48.0" requires_python = ">=3.9" summary = "The little ASGI library that shines." -groups = ["default", "all", "chat"] +groups = ["default"] dependencies = [ "anyio<5,>=3.6.2", "typing-extensions>=4.10.0; python_version < \"3.13\"", @@ -4539,34 +4244,6 @@ files = [ {file = "starlette-0.48.0.tar.gz", hash = "sha256:7e8cee469a8ab2352911528110ce9088fdc6a37d9876926e73da7ce4aa4c7a46"}, ] -[[package]] -name = "starlette-context" -version = "0.4.0" -requires_python = "<4.0,>=3.9" -summary = "Middleware for Starlette that allows you to store and access the context data of a request. Can be used with logging so logs automatically use request headers such as x-request-id or x-correlation-id." -groups = ["all", "chat"] -dependencies = [ - "starlette>=0.27.0", -] -files = [ - {file = "starlette_context-0.4.0-py3-none-any.whl", hash = "sha256:dbcc11006587f901edd3d0a989a69a628fccf9d00c1ca3c28fab23ab88bd0093"}, - {file = "starlette_context-0.4.0.tar.gz", hash = "sha256:3242417c9354c067a4ac5009aff762dc0b322074216f664825d5d127108553be"}, -] - -[[package]] -name = "structlog" -version = "25.4.0" -requires_python = ">=3.8" -summary = "Structured Logging for Python" -groups = ["all", "chat"] -dependencies = [ - "typing-extensions; python_version < \"3.11\"", -] -files = [ - {file = "structlog-25.4.0-py3-none-any.whl", hash = "sha256:fe809ff5c27e557d14e613f45ca441aabda051d119ee5a0102aaba6ce40eed2c"}, - {file = "structlog-25.4.0.tar.gz", hash = "sha256:186cd1b0a8ae762e29417095664adf1d6a31702160a46dacb7796ea82f7409e4"}, -] - [[package]] name = "sympy" version = "1.14.0" @@ -4598,7 +4275,7 @@ name = "tomli" version = "2.2.1" requires_python = ">=3.8" summary = "A lil' TOML parser" -groups = ["default", "all", "chat", "dev"] +groups = ["default", "dev"] marker = "python_version <= \"3.11\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, @@ -4749,7 +4426,7 @@ name = "typing-extensions" version = "4.15.0" requires_python = ">=3.9" summary = "Backported and Experimental Type Hints for Python 3.9+" -groups = ["default", "all", "bedrock", "chat", "dev", "vertex", "web"] +groups = ["default", "all", "bedrock", "dev", "vertex", "web"] files = [ {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"}, {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"}, @@ -4810,7 +4487,7 @@ name = "urllib3" version = "2.5.0" requires_python = ">=3.9" summary = "HTTP library with thread-safe connection pooling, file post, and more." -groups = ["default", "all", "bedrock", "chat", "dev", "vertex"] +groups = ["default", "all", "bedrock", "dev", "vertex"] files = [ {file = "urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc"}, {file = "urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760"}, @@ -4821,7 +4498,8 @@ name = "uvicorn" version = "0.37.0" requires_python = ">=3.9" summary = "The lightning-fast ASGI server." -groups = ["default", "all", "chat"] +groups = ["default"] +marker = "sys_platform != \"emscripten\"" dependencies = [ "click>=7.0", "h11>=0.8", @@ -4930,57 +4608,6 @@ files = [ {file = "winregistry-2.1.1.tar.gz", hash = "sha256:8233c4261a9d937cd8f0670da0d1e61fd7b86712c39b1af08cb83e91316195a7"}, ] -[[package]] -name = "wrapt" -version = "1.17.3" -requires_python = ">=3.8" -summary = "Module for decorators, wrappers and monkey patching." -groups = ["all", "chat"] -files = [ - {file = "wrapt-1.17.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:88bbae4d40d5a46142e70d58bf664a89b6b4befaea7b2ecc14e03cedb8e06c04"}, - {file = "wrapt-1.17.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6b13af258d6a9ad602d57d889f83b9d5543acd471eee12eb51f5b01f8eb1bc2"}, - {file = "wrapt-1.17.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd341868a4b6714a5962c1af0bd44f7c404ef78720c7de4892901e540417111c"}, - {file = "wrapt-1.17.3-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f9b2601381be482f70e5d1051a5965c25fb3625455a2bf520b5a077b22afb775"}, - {file = "wrapt-1.17.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343e44b2a8e60e06a7e0d29c1671a0d9951f59174f3709962b5143f60a2a98bd"}, - {file = "wrapt-1.17.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:33486899acd2d7d3066156b03465b949da3fd41a5da6e394ec49d271baefcf05"}, - {file = "wrapt-1.17.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e6f40a8aa5a92f150bdb3e1c44b7e98fb7113955b2e5394122fa5532fec4b418"}, - {file = "wrapt-1.17.3-cp310-cp310-win32.whl", hash = "sha256:a36692b8491d30a8c75f1dfee65bef119d6f39ea84ee04d9f9311f83c5ad9390"}, - {file = "wrapt-1.17.3-cp310-cp310-win_amd64.whl", hash = "sha256:afd964fd43b10c12213574db492cb8f73b2f0826c8df07a68288f8f19af2ebe6"}, - {file = "wrapt-1.17.3-cp310-cp310-win_arm64.whl", hash = "sha256:af338aa93554be859173c39c85243970dc6a289fa907402289eeae7543e1ae18"}, - {file = "wrapt-1.17.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:273a736c4645e63ac582c60a56b0acb529ef07f78e08dc6bfadf6a46b19c0da7"}, - {file = "wrapt-1.17.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5531d911795e3f935a9c23eb1c8c03c211661a5060aab167065896bbf62a5f85"}, - {file = "wrapt-1.17.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0610b46293c59a3adbae3dee552b648b984176f8562ee0dba099a56cfbe4df1f"}, - {file = "wrapt-1.17.3-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b32888aad8b6e68f83a8fdccbf3165f5469702a7544472bdf41f582970ed3311"}, - {file = "wrapt-1.17.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cccf4f81371f257440c88faed6b74f1053eef90807b77e31ca057b2db74edb1"}, - {file = "wrapt-1.17.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8a210b158a34164de8bb68b0e7780041a903d7b00c87e906fb69928bf7890d5"}, - {file = "wrapt-1.17.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:79573c24a46ce11aab457b472efd8d125e5a51da2d1d24387666cd85f54c05b2"}, - {file = "wrapt-1.17.3-cp311-cp311-win32.whl", hash = "sha256:c31eebe420a9a5d2887b13000b043ff6ca27c452a9a22fa71f35f118e8d4bf89"}, - {file = "wrapt-1.17.3-cp311-cp311-win_amd64.whl", hash = "sha256:0b1831115c97f0663cb77aa27d381237e73ad4f721391a9bfb2fe8bc25fa6e77"}, - {file = "wrapt-1.17.3-cp311-cp311-win_arm64.whl", hash = "sha256:5a7b3c1ee8265eb4c8f1b7d29943f195c00673f5ab60c192eba2d4a7eae5f46a"}, - {file = "wrapt-1.17.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ab232e7fdb44cdfbf55fc3afa31bcdb0d8980b9b95c38b6405df2acb672af0e0"}, - {file = "wrapt-1.17.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9baa544e6acc91130e926e8c802a17f3b16fbea0fd441b5a60f5cf2cc5c3deba"}, - {file = "wrapt-1.17.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6b538e31eca1a7ea4605e44f81a48aa24c4632a277431a6ed3f328835901f4fd"}, - {file = "wrapt-1.17.3-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:042ec3bb8f319c147b1301f2393bc19dba6e176b7da446853406d041c36c7828"}, - {file = "wrapt-1.17.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3af60380ba0b7b5aeb329bc4e402acd25bd877e98b3727b0135cb5c2efdaefe9"}, - {file = "wrapt-1.17.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b02e424deef65c9f7326d8c19220a2c9040c51dc165cddb732f16198c168396"}, - {file = "wrapt-1.17.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:74afa28374a3c3a11b3b5e5fca0ae03bef8450d6aa3ab3a1e2c30e3a75d023dc"}, - {file = "wrapt-1.17.3-cp312-cp312-win32.whl", hash = "sha256:4da9f45279fff3543c371d5ababc57a0384f70be244de7759c85a7f989cb4ebe"}, - {file = "wrapt-1.17.3-cp312-cp312-win_amd64.whl", hash = "sha256:e71d5c6ebac14875668a1e90baf2ea0ef5b7ac7918355850c0908ae82bcb297c"}, - {file = "wrapt-1.17.3-cp312-cp312-win_arm64.whl", hash = "sha256:604d076c55e2fdd4c1c03d06dc1a31b95130010517b5019db15365ec4a405fc6"}, - {file = "wrapt-1.17.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a47681378a0439215912ef542c45a783484d4dd82bac412b71e59cf9c0e1cea0"}, - {file = "wrapt-1.17.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:54a30837587c6ee3cd1a4d1c2ec5d24e77984d44e2f34547e2323ddb4e22eb77"}, - {file = "wrapt-1.17.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ecf15d6af39246fe33e507105d67e4b81d8f8d2c6598ff7e3ca1b8a37213f7"}, - {file = "wrapt-1.17.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6fd1ad24dc235e4ab88cda009e19bf347aabb975e44fd5c2fb22a3f6e4141277"}, - {file = "wrapt-1.17.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ed61b7c2d49cee3c027372df5809a59d60cf1b6c2f81ee980a091f3afed6a2d"}, - {file = "wrapt-1.17.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:423ed5420ad5f5529db9ce89eac09c8a2f97da18eb1c870237e84c5a5c2d60aa"}, - {file = "wrapt-1.17.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e01375f275f010fcbf7f643b4279896d04e571889b8a5b3f848423d91bf07050"}, - {file = "wrapt-1.17.3-cp313-cp313-win32.whl", hash = "sha256:53e5e39ff71b3fc484df8a522c933ea2b7cdd0d5d15ae82e5b23fde87d44cbd8"}, - {file = "wrapt-1.17.3-cp313-cp313-win_amd64.whl", hash = "sha256:1f0b2f40cf341ee8cc1a97d51ff50dddb9fcc73241b9143ec74b30fc4f44f6cb"}, - {file = "wrapt-1.17.3-cp313-cp313-win_arm64.whl", hash = "sha256:7425ac3c54430f5fc5e7b6f41d41e704db073309acfc09305816bc6a0b26bb16"}, - {file = "wrapt-1.17.3-py3-none-any.whl", hash = "sha256:7171ae35d2c33d326ac19dd8facb1e82e5fd04ef8c6c0e394d7af55a55051c22"}, - {file = "wrapt-1.17.3.tar.gz", hash = "sha256:f66eb08feaa410fe4eebd17f2a2c8e2e46d3476e9f8c783daa8e09e0faa666d0"}, -] - [[package]] name = "xlrd" version = "2.0.2" @@ -4991,14 +4618,3 @@ files = [ {file = "xlrd-2.0.2-py2.py3-none-any.whl", hash = "sha256:ea762c3d29f4cca48d82df517b6d89fbce4db3107f9d78713e48cd321d5c9aa9"}, {file = "xlrd-2.0.2.tar.gz", hash = "sha256:08b5e25de58f21ce71dc7db3b3b8106c1fa776f3024c54e45b45b374e89234c9"}, ] - -[[package]] -name = "zipp" -version = "3.23.0" -requires_python = ">=3.9" -summary = "Backport of pathlib-compatible object wrapper for zip files" -groups = ["all", "chat"] -files = [ - {file = "zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e"}, - {file = "zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166"}, -] diff --git a/pyproject.toml b/pyproject.toml index 137c229d..ca293d57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,6 @@ path = "src/askui/__init__.py" distribution = true [tool.pdm.scripts] -alembic = "alembic -c src/askui/chat/migrations/alembic.ini" test = "pytest -n auto" "test:cov" = "pytest -n auto --cov=src/askui --cov-report=html" "test:cov:view" = "python -m http.server --directory htmlcov" @@ -66,7 +65,6 @@ lint = "ruff check src tests" "lint:fix" = "ruff check --fix src tests" typecheck = "mypy" "typecheck:all" = "mypy ." -"chat:api" = "python -m askui.chat" "generate:SBOM" = "cyclonedx-py environment --pyproject ./pyproject.toml --output-format JSON --output-file bom.json --spec-version 1.6 --gather-license-texts " "qa:fix" = { composite = [ "typecheck:all", @@ -184,7 +182,6 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" [tool.ruff.lint.per-file-ignores] "src/askui/agent.py" = ["E501"] "src/askui/android_agent.py" = ["E501"] -"src/askui/chat/*" = ["E501", "F401", "F403"] "src/askui/locators/locators.py" = ["E501"] "src/askui/locators/relatable.py" = ["E501", "SLF001"] "src/askui/locators/serializers.py" = ["E501", "SLF001"] @@ -217,30 +214,13 @@ known-first-party = ["askui"] known-third-party = ["pytest", "mypy"] [project.optional-dependencies] -all = ["askui[android,bedrock,chat,pynput,vertex,web]"] +all = ["askui[android,bedrock,pynput,vertex,web]"] android = [ "pure-python-adb>=0.3.0.dev0" ] bedrock = [ "anthropic[bedrock]>=0.72.0" ] -chat = [ - "askui[android,web]", - "uvicorn>=0.34.3", - "anyio>=4.10.0", - "structlog>=25.4.0", - "asgi-correlation-id>=4.3.4", - "prometheus-fastapi-instrumentator>=7.1.0", - "starlette-context>=0.4.0", - "sqlalchemy>=2.0.43", - "alembic>=1.16.5", - "opentelemetry-api>=1.38.0", - "opentelemetry-sdk>=1.38.0", - "opentelemetry-instrumentation-fastapi>=0.59b0", - "opentelemetry-exporter-otlp-proto-http>=1.38.0", - "opentelemetry-instrumentation-httpx>=0.59b0", - "opentelemetry-instrumentation-sqlalchemy>=0.59b0", -] pynput = [ "mss>=10.0.0", "pynput>=1.8.1", diff --git a/src/askui/chat/__init__.py b/src/askui/chat/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py deleted file mode 100644 index 9107420b..00000000 --- a/src/askui/chat/__main__.py +++ /dev/null @@ -1,22 +0,0 @@ -import uvicorn - -from askui.chat.api.app import app -from askui.chat.api.dependencies import get_settings -from askui.chat.api.telemetry.integrations.fastapi import instrument -from askui.telemetry.otel import setup_opentelemetry_tracing - -if __name__ == "__main__": - settings = get_settings() - instrument(app, settings.telemetry) - if settings.otel.enabled: - setup_opentelemetry_tracing(app, settings.otel) - - uvicorn.run( - app, - host=settings.host, - port=settings.port, - reload=False, - workers=1, - log_config=None, - timeout_graceful_shutdown=5, - ) diff --git a/src/askui/chat/api/__init__.py b/src/askui/chat/api/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/app.py b/src/askui/chat/api/app.py deleted file mode 100644 index 82711ab5..00000000 --- a/src/askui/chat/api/app.py +++ /dev/null @@ -1,210 +0,0 @@ -import logging -from contextlib import asynccontextmanager -from typing import AsyncGenerator - -from fastapi import APIRouter, FastAPI, HTTPException, Request, status -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse -from fastmcp import FastMCP - -from askui.chat.api.assistants.router import router as assistants_router -from askui.chat.api.db.session import get_session -from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_settings -from askui.chat.api.files.router import router as files_router -from askui.chat.api.health.router import router as health_router -from askui.chat.api.mcp_clients.dependencies import get_mcp_client_manager_manager -from askui.chat.api.mcp_clients.manager import McpServerConnectionError -from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service -from askui.chat.api.mcp_configs.router import router as mcp_configs_router -from askui.chat.api.mcp_servers.android import mcp as android_mcp -from askui.chat.api.mcp_servers.computer import mcp as computer_mcp -from askui.chat.api.mcp_servers.testing import mcp as testing_mcp -from askui.chat.api.mcp_servers.utility import mcp as utility_mcp -from askui.chat.api.messages.router import router as messages_router -from askui.chat.api.runs.router import router as runs_router -from askui.chat.api.scheduled_jobs.router import router as scheduled_jobs_router -from askui.chat.api.scheduled_jobs.scheduler import shutdown_scheduler, start_scheduler -from askui.chat.api.threads.router import router as threads_router -from askui.chat.api.workflows.router import router as workflows_router -from askui.chat.migrations.runner import run_migrations -from askui.utils.api_utils import ( - ConflictError, - FileTooLargeError, - ForbiddenError, - LimitReachedError, - NotFoundError, -) - -logger = logging.getLogger(__name__) - - -settings = get_settings() - - -@asynccontextmanager -async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 - if settings.db.auto_migrate: - run_migrations() - else: - logger.info("Automatic migrations are disabled. Skipping migrations...") - logger.info("Seeding default MCP configurations...") - session = next(get_session()) - mcp_config_service = get_mcp_config_service(session=session, settings=settings) - mcp_config_service.seed() - - # Start the scheduler for scheduled jobs - logger.info("Starting scheduled job scheduler...") - await start_scheduler() - - yield - - # Shutdown scheduler - logger.info("Shutting down scheduled job scheduler...") - await shutdown_scheduler() - - logger.info("Disconnecting all MCP clients...") - await get_mcp_client_manager_manager(mcp_config_service).disconnect_all(force=True) - - -app = FastAPI( - title="AskUI Chat API", - version="0.1.0", - lifespan=lifespan, -) - - -# Include routers -v1_router = APIRouter(prefix="/v1") -v1_router.include_router(assistants_router) -v1_router.include_router(threads_router) -v1_router.include_router(messages_router) -v1_router.include_router(runs_router) -v1_router.include_router(mcp_configs_router) -v1_router.include_router(files_router) -v1_router.include_router(workflows_router) -v1_router.include_router(scheduled_jobs_router) -v1_router.include_router(health_router) -app.include_router(v1_router) - - -mcp = FastMCP.from_fastapi(app=app, name="AskUI Chat MCP") -mcp.mount(computer_mcp) -mcp.mount(android_mcp) -mcp.mount(testing_mcp) -mcp.mount(utility_mcp) - -mcp_app = mcp.http_app("/sse", transport="sse") - - -@asynccontextmanager -async def combined_lifespan(app: FastAPI) -> AsyncGenerator[None, None]: - async with lifespan(app): - async with mcp_app.lifespan(app): - yield - - -app = FastAPI( - title=app.title, - version=app.version, - lifespan=combined_lifespan, - dependencies=[SetEnvFromHeadersDep], -) -app.mount("/mcp", mcp_app) -app.include_router(v1_router) - - -@app.exception_handler(NotFoundError) -def not_found_error_handler( - request: Request, # noqa: ARG001 - exc: NotFoundError, -) -> JSONResponse: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, content={"detail": str(exc)} - ) - - -@app.exception_handler(ConflictError) -def conflict_error_handler( - request: Request, # noqa: ARG001 - exc: ConflictError, -) -> JSONResponse: - return JSONResponse( - status_code=status.HTTP_409_CONFLICT, content={"detail": str(exc)} - ) - - -@app.exception_handler(LimitReachedError) -def limit_reached_error_handler( - request: Request, # noqa: ARG001 - exc: LimitReachedError, -) -> JSONResponse: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, content={"detail": str(exc)} - ) - - -@app.exception_handler(FileTooLargeError) -def file_too_large_error_handler( - request: Request, # noqa: ARG001 - exc: FileTooLargeError, -) -> JSONResponse: - return JSONResponse( - status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, - content={"detail": str(exc)}, - ) - - -@app.exception_handler(ForbiddenError) -def forbidden_error_handler( - request: Request, # noqa: ARG001 - exc: ForbiddenError, -) -> JSONResponse: - return JSONResponse( - status_code=status.HTTP_403_FORBIDDEN, - content={"detail": str(exc)}, - ) - - -@app.exception_handler(ValueError) -def value_error_handler( - request: Request, # noqa: ARG001 - exc: ValueError, -) -> JSONResponse: - return JSONResponse( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - content={"detail": str(exc)}, - ) - - -@app.exception_handler(Exception) -def catch_all_exception_handler( - request: Request, # noqa: ARG001 - exc: Exception, -) -> JSONResponse: - if isinstance(exc, HTTPException): - raise exc - - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"detail": "Internal server error"}, - ) - - -@app.exception_handler(McpServerConnectionError) -def mcp_server_connection_error_handler( - request: Request, # noqa: ARG001 - exc: McpServerConnectionError, -) -> JSONResponse: - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"detail": str(exc)}, - ) - - -app.add_middleware( - CORSMiddleware, - allow_origins=settings.allow_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) diff --git a/src/askui/chat/api/assistants/__init__.py b/src/askui/chat/api/assistants/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/assistants/dependencies.py b/src/askui/chat/api/assistants/dependencies.py deleted file mode 100644 index 3211094b..00000000 --- a/src/askui/chat/api/assistants/dependencies.py +++ /dev/null @@ -1,14 +0,0 @@ -from fastapi import Depends - -from askui.chat.api.assistants.service import AssistantService -from askui.chat.api.db.session import SessionDep - - -def get_assistant_service( - session: SessionDep, -) -> AssistantService: - """Get AssistantService instance.""" - return AssistantService(session) - - -AssistantServiceDep = Depends(get_assistant_service) diff --git a/src/askui/chat/api/assistants/models.py b/src/askui/chat/api/assistants/models.py deleted file mode 100644 index da18a7e3..00000000 --- a/src/askui/chat/api/assistants/models.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel, Field - -from askui.chat.api.models import AssistantId, WorkspaceId, WorkspaceResource -from askui.utils.datetime_utils import UnixDatetime, now -from askui.utils.id_utils import generate_time_ordered_id -from askui.utils.not_given import NOT_GIVEN, BaseModelWithNotGiven, NotGiven - - -class AssistantBase(BaseModel): - """Base assistant model.""" - - name: str | None = None - description: str | None = None - avatar: str | None = None - tools: list[str] = Field(default_factory=list) - system: str | None = None - - -class AssistantCreate(AssistantBase): - """Parameters for creating an assistant.""" - - -class AssistantModify(BaseModelWithNotGiven): - """Parameters for modifying an assistant.""" - - name: str | NotGiven = NOT_GIVEN - description: str | NotGiven = NOT_GIVEN - avatar: str | NotGiven = NOT_GIVEN - tools: list[str] | NotGiven = NOT_GIVEN - system: str | NotGiven = NOT_GIVEN - - -class Assistant(AssistantBase, WorkspaceResource): - """An assistant that can be used in a thread.""" - - id: AssistantId - object: Literal["assistant"] = "assistant" - created_at: UnixDatetime - - @classmethod - def create( - cls, workspace_id: WorkspaceId | None, params: AssistantCreate - ) -> "Assistant": - return cls( - id=generate_time_ordered_id("asst"), - created_at=now(), - workspace_id=workspace_id, - **params.model_dump(), - ) diff --git a/src/askui/chat/api/assistants/orms.py b/src/askui/chat/api/assistants/orms.py deleted file mode 100644 index c59d4ea9..00000000 --- a/src/askui/chat/api/assistants/orms.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Assistant database model.""" - -from datetime import datetime -from uuid import UUID - -from sqlalchemy import JSON, String, Text, Uuid -from sqlalchemy.orm import Mapped, mapped_column - -from askui.chat.api.assistants.models import Assistant -from askui.chat.api.db.orm.base import Base -from askui.chat.api.db.orm.types import UnixDatetime, create_prefixed_id_type - -AssistantId = create_prefixed_id_type("asst") - - -class AssistantOrm(Base): - """Assistant database model.""" - - __tablename__ = "assistants" - - id: Mapped[str] = mapped_column(AssistantId, primary_key=True) - workspace_id: Mapped[UUID | None] = mapped_column(Uuid, nullable=True, index=True) - created_at: Mapped[datetime] = mapped_column(UnixDatetime, nullable=False) - name: Mapped[str | None] = mapped_column(String, nullable=True) - description: Mapped[str | None] = mapped_column(String, nullable=True) - avatar: Mapped[str | None] = mapped_column(Text, nullable=True) - tools: Mapped[list[str]] = mapped_column(JSON, nullable=False) - system: Mapped[str | None] = mapped_column(Text, nullable=True) - - @classmethod - def from_model(cls, model: Assistant) -> "AssistantOrm": - return cls( - **model.model_dump(exclude={"object"}), - ) - - def to_model(self) -> Assistant: - return Assistant.model_validate(self, from_attributes=True) diff --git a/src/askui/chat/api/assistants/router.py b/src/askui/chat/api/assistants/router.py deleted file mode 100644 index 15888257..00000000 --- a/src/askui/chat/api/assistants/router.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Annotated - -from fastapi import APIRouter, Header, status - -from askui.chat.api.assistants.dependencies import AssistantServiceDep -from askui.chat.api.assistants.models import Assistant, AssistantCreate, AssistantModify -from askui.chat.api.assistants.service import AssistantService -from askui.chat.api.dependencies import ListQueryDep -from askui.chat.api.models import AssistantId, WorkspaceId -from askui.utils.api_utils import ListQuery, ListResponse - -router = APIRouter(prefix="/assistants", tags=["assistants"]) - - -@router.get("") -def list_assistants( - askui_workspace: Annotated[WorkspaceId | None, Header()] = None, - query: ListQuery = ListQueryDep, - assistant_service: AssistantService = AssistantServiceDep, -) -> ListResponse[Assistant]: - return assistant_service.list_(workspace_id=askui_workspace, query=query) - - -@router.post("", status_code=status.HTTP_201_CREATED) -def create_assistant( - params: AssistantCreate, - askui_workspace: Annotated[WorkspaceId, Header()], - assistant_service: AssistantService = AssistantServiceDep, -) -> Assistant: - return assistant_service.create(workspace_id=askui_workspace, params=params) - - -@router.get("/{assistant_id}") -def retrieve_assistant( - assistant_id: AssistantId, - askui_workspace: Annotated[WorkspaceId | None, Header()] = None, - assistant_service: AssistantService = AssistantServiceDep, -) -> Assistant: - return assistant_service.retrieve( - workspace_id=askui_workspace, assistant_id=assistant_id - ) - - -@router.post("/{assistant_id}") -def modify_assistant( - assistant_id: AssistantId, - askui_workspace: Annotated[WorkspaceId, Header()], - params: AssistantModify, - assistant_service: AssistantService = AssistantServiceDep, -) -> Assistant: - return assistant_service.modify( - workspace_id=askui_workspace, assistant_id=assistant_id, params=params - ) - - -@router.delete("/{assistant_id}", status_code=status.HTTP_204_NO_CONTENT) -def delete_assistant( - assistant_id: AssistantId, - askui_workspace: Annotated[WorkspaceId, Header()], - assistant_service: AssistantService = AssistantServiceDep, -) -> None: - assistant_service.delete(workspace_id=askui_workspace, assistant_id=assistant_id) diff --git a/src/askui/chat/api/assistants/service.py b/src/askui/chat/api/assistants/service.py deleted file mode 100644 index e963503d..00000000 --- a/src/askui/chat/api/assistants/service.py +++ /dev/null @@ -1,95 +0,0 @@ -from sqlalchemy import or_ -from sqlalchemy.orm import Session - -from askui.chat.api.assistants.models import Assistant, AssistantCreate, AssistantModify -from askui.chat.api.assistants.orms import AssistantOrm -from askui.chat.api.db.queries import list_all -from askui.chat.api.models import AssistantId, WorkspaceId -from askui.utils.api_utils import ForbiddenError, ListQuery, ListResponse, NotFoundError - - -class AssistantService: - def __init__(self, session: Session) -> None: - self._session = session - - def list_( - self, workspace_id: WorkspaceId | None, query: ListQuery - ) -> ListResponse[Assistant]: - q = self._session.query(AssistantOrm).filter( - or_( - AssistantOrm.workspace_id == workspace_id, - AssistantOrm.workspace_id.is_(None), - ), - ) - orms: list[AssistantOrm] - orms, has_more = list_all(q, query, AssistantOrm.id) - data = [orm.to_model() for orm in orms] - return ListResponse( - data=data, - has_more=has_more, - first_id=data[0].id if data else None, - last_id=data[-1].id if data else None, - ) - - def _find_by_id( - self, workspace_id: WorkspaceId | None, assistant_id: AssistantId - ) -> AssistantOrm: - assistant_orm: AssistantOrm | None = ( - self._session.query(AssistantOrm) - .filter( - AssistantOrm.id == assistant_id, - or_( - AssistantOrm.workspace_id == workspace_id, - AssistantOrm.workspace_id.is_(None), - ), - ) - .first() - ) - if assistant_orm is None: - error_msg = f"Assistant {assistant_id} not found" - raise NotFoundError(error_msg) - return assistant_orm - - def retrieve( - self, workspace_id: WorkspaceId | None, assistant_id: AssistantId - ) -> Assistant: - assistant_orm = self._find_by_id(workspace_id, assistant_id) - return assistant_orm.to_model() - - def create( - self, workspace_id: WorkspaceId | None, params: AssistantCreate - ) -> Assistant: - assistant = Assistant.create(workspace_id, params) - assistant_orm = AssistantOrm.from_model(assistant) - self._session.add(assistant_orm) - self._session.commit() - return assistant - - def modify( - self, - workspace_id: WorkspaceId | None, - assistant_id: AssistantId, - params: AssistantModify, - force: bool = False, - ) -> Assistant: - assistant_orm = self._find_by_id(workspace_id, assistant_id) - if assistant_orm.workspace_id is None and not force: - error_msg = f"Default assistant {assistant_id} cannot be modified" - raise ForbiddenError(error_msg) - assistant_orm.update(params.model_dump()) - self._session.commit() - self._session.refresh(assistant_orm) - return assistant_orm.to_model() - - def delete( - self, - workspace_id: WorkspaceId | None, - assistant_id: AssistantId, - force: bool = False, - ) -> None: - assistant_orm = self._find_by_id(workspace_id, assistant_id) - if assistant_orm.workspace_id is None and not force: - error_msg = f"Default assistant {assistant_id} cannot be deleted" - raise ForbiddenError(error_msg) - self._session.delete(assistant_orm) - self._session.commit() diff --git a/src/askui/chat/api/db/__init__.py b/src/askui/chat/api/db/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/db/engine.py b/src/askui/chat/api/db/engine.py deleted file mode 100644 index 53d7d3d8..00000000 --- a/src/askui/chat/api/db/engine.py +++ /dev/null @@ -1,41 +0,0 @@ -import logging -from sqlite3 import Connection as SQLite3Connection -from typing import Any - -from sqlalchemy import create_engine, event - -from askui.chat.api.dependencies import get_settings - -_logger = logging.getLogger(__name__) - -_settings = get_settings() -_connect_args = {"check_same_thread": False} -_echo = _logger.isEnabledFor(logging.DEBUG) - -# Create engine with optimized settings -engine = create_engine( - _settings.db.url, - connect_args=_connect_args, - echo=_echo, -) - - -@event.listens_for(engine, "connect") -def _set_sqlite_pragma(dbapi_conn: SQLite3Connection, connection_record: Any) -> None: # noqa: ARG001 - """ - Configure SQLite pragmas for optimal web application performance. - - Applied on each new connection: - - foreign_keys=ON: Enable foreign key constraint enforcement - - journal_mode=WAL: Write-Ahead Logging for better concurrency (readers don't block writers) - - synchronous=NORMAL: Sync every 1000 pages instead of every write (faster, still durable with WAL) - - busy_timeout=30000: Wait up to 30 seconds for locks instead of failing immediately - """ - cursor = dbapi_conn.cursor() - - cursor.execute("PRAGMA foreign_keys = ON") - cursor.execute("PRAGMA journal_mode = WAL") - cursor.execute("PRAGMA synchronous = NORMAL") - cursor.execute("PRAGMA busy_timeout = 30000") - - cursor.close() diff --git a/src/askui/chat/api/db/orm/__init__.py b/src/askui/chat/api/db/orm/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/db/orm/base.py b/src/askui/chat/api/db/orm/base.py deleted file mode 100644 index 6e5ca243..00000000 --- a/src/askui/chat/api/db/orm/base.py +++ /dev/null @@ -1,13 +0,0 @@ -"""SQLAlchemy declarative base.""" - -from typing import Any - -from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass -from typing_extensions import Self - - -class Base(MappedAsDataclass, DeclarativeBase): - def update(self, values: dict[str, Any]) -> Self: - for key, value in values.items(): - setattr(self, key, value) - return self diff --git a/src/askui/chat/api/db/orm/types.py b/src/askui/chat/api/db/orm/types.py deleted file mode 100644 index 0501ec8f..00000000 --- a/src/askui/chat/api/db/orm/types.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Custom SQLAlchemy types for chat API.""" - -from datetime import datetime, timezone -from typing import Any - -from sqlalchemy import Integer, String, TypeDecorator - - -def create_prefixed_id_type(prefix: str) -> type[TypeDecorator[str]]: - class PrefixedObjectId(TypeDecorator[str]): - impl = String(24) - cache_ok = True - - def process_bind_param(self, value: str | None, dialect: Any) -> str | None: # noqa: ARG002 - if value is None: - return value - return value.removeprefix(f"{prefix}_") - - def process_result_value(self, value: str | None, dialect: Any) -> str | None: # noqa: ARG002 - if value is None: - return value - return f"{prefix}_{value}" - - return PrefixedObjectId - - -# Specialized types for each resource -ThreadId = create_prefixed_id_type("thread") -MessageId = create_prefixed_id_type("msg") -RunId = create_prefixed_id_type("run") -WorkflowId = create_prefixed_id_type("workflow") - - -class UnixDatetime(TypeDecorator[datetime]): - impl = Integer - LOCAL_TIMEZONE = datetime.now().astimezone().tzinfo - - def process_bind_param( - self, - value: datetime | int | None, - dialect: Any, # noqa: ARG002 - ) -> int | None: - if value is None: - return value - if isinstance(value, int): - return value - if value.tzinfo is None: - value = value.astimezone(self.LOCAL_TIMEZONE) - return int(value.astimezone(timezone.utc).timestamp()) - - def process_result_value( - self, - value: int | None, - dialect: Any, # noqa: ARG002 - ) -> datetime | None: - if value is None: - return value - return datetime.fromtimestamp(value, timezone.utc) - - -def create_sentinel_id_type( - prefix: str, sentinel_value: str -) -> type[TypeDecorator[str]]: - """Create a type decorator that converts between a sentinel value and NULL. - - This is useful for self-referential nullable foreign keys where NULL in the database - is represented by a sentinel value in the API (e.g., root nodes in a tree structure). - - Args: - prefix (str): The prefix for the ID (e.g., "msg"). - sentinel_value (str): The sentinel value representing NULL (e.g., "msg_000000000000000000000000"). - - Returns: - type[TypeDecorator[str]]: A TypeDecorator class that handles the transformation. - - Example: - ```python - ParentMessageId = create_sentinel_id_type("msg", ROOT_MESSAGE_PARENT_ID) - parent_id: Mapped[str] = mapped_column(ParentMessageId, nullable=True) - ``` - """ - - class SentinelId(TypeDecorator[str]): - """Type decorator that converts between sentinel value (API) and NULL (database). - - - When writing to DB: sentinel_value → NULL - - When reading from DB: NULL → sentinel_value - """ - - impl = String(24) - cache_ok = ( - False # Disable caching due to closure over prefix and sentinel_value - ) - - def process_bind_param( - self, - value: str | None, - dialect: Any, # noqa: ARG002 - ) -> str | None: - """Convert from API model to database storage.""" - if value is None or value == sentinel_value: - # Both None and sentinel value become NULL in database - return None - # Remove prefix for storage (like regular PrefixedObjectId) - return value.removeprefix(f"{prefix}_") - - def process_result_value( - self, - value: str | None, - dialect: Any, # noqa: ARG002 - ) -> str: - """Convert from database storage to API model.""" - if value is None: - # NULL in database becomes sentinel value in API - return sentinel_value - # Add prefix (like regular PrefixedObjectId) - return f"{prefix}_{value}" - - return SentinelId diff --git a/src/askui/chat/api/db/queries.py b/src/askui/chat/api/db/queries.py deleted file mode 100644 index ffb05550..00000000 --- a/src/askui/chat/api/db/queries.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Shared query building utilities for database operations.""" - -from typing import Any, TypeVar - -from sqlalchemy import desc -from sqlalchemy.orm import InstrumentedAttribute, Query - -from askui.chat.api.db.orm.base import Base -from askui.utils.api_utils import ListQuery - -OrmT = TypeVar("OrmT", bound=Base) - - -def list_all( - db_query: Query[OrmT], - list_query: ListQuery, - id_column: InstrumentedAttribute[Any], -) -> tuple[list[OrmT], bool]: - if list_query.order == "asc": - if list_query.after: - db_query = db_query.filter(id_column > list_query.after) - if list_query.before: - db_query = db_query.filter(id_column < list_query.before) - db_query = db_query.order_by(id_column) - else: - if list_query.after: - db_query = db_query.filter(id_column < list_query.after) - if list_query.before: - db_query = db_query.filter(id_column > list_query.before) - db_query = db_query.order_by(desc(id_column)) - db_query = db_query.limit(list_query.limit + 1) - orms = db_query.all() - has_more = len(orms) > list_query.limit - return orms[: list_query.limit], has_more diff --git a/src/askui/chat/api/db/session.py b/src/askui/chat/api/db/session.py deleted file mode 100644 index 47118718..00000000 --- a/src/askui/chat/api/db/session.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Annotated, Generator - -from fastapi import Depends -from sqlalchemy.orm import Session - -from askui.chat.api.db.engine import engine - - -def get_session() -> Generator[Session, None, None]: - with Session(engine) as session: - yield session - - -SessionDep = Annotated[Session, Depends(get_session)] diff --git a/src/askui/chat/api/dependencies.py b/src/askui/chat/api/dependencies.py deleted file mode 100644 index ca42ece1..00000000 --- a/src/askui/chat/api/dependencies.py +++ /dev/null @@ -1,83 +0,0 @@ -import os -from pathlib import Path -from typing import Annotated, Optional - -from fastapi import Depends, Header -from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer -from pydantic import UUID4 - -from askui.chat.api.models import WorkspaceId -from askui.chat.api.settings import Settings -from askui.utils.api_utils import ListQuery - - -def get_settings() -> Settings: - """Get ChatApiSettings instance.""" - return Settings() - - -SettingsDep = Depends(get_settings) - - -http_bearer = HTTPBearer(scheme_name="Bearer", auto_error=False) -api_key_header = APIKeyHeader( - name="Authorization", auto_error=False, scheme_name="Basic" -) - - -def get_authorization( - bearer_auth: Annotated[ - Optional[HTTPAuthorizationCredentials], Depends(http_bearer) - ] = None, - api_key_auth: Annotated[Optional[str], Depends(api_key_header)] = None, -) -> Optional[str]: - if bearer_auth: - return f"{bearer_auth.scheme} {bearer_auth.credentials}" - if api_key_auth: - return api_key_auth - return None - - -def set_env_from_headers( - authorization: Annotated[Optional[str], Depends(get_authorization)] = None, - askui_workspace: Annotated[UUID4 | None, Header()] = None, -) -> None: - """ - Set environment variables from Authorization and AskUI-Workspace headers. - - Args: - authorization (str | None, optional): Authorization header. - Defaults to `None`. - askui_workspace (UUID4 | None, optional): Workspace ID from AskUI-Workspace header. - Defaults to `None`. - """ - if authorization: - os.environ["ASKUI__AUTHORIZATION"] = authorization - if askui_workspace: - os.environ["ASKUI_WORKSPACE_ID"] = str(askui_workspace) - - -SetEnvFromHeadersDep = Depends(set_env_from_headers) - - -def get_workspace_id( - askui_workspace: Annotated[WorkspaceId | None, Header()] = None, -) -> WorkspaceId | None: - """Get workspace ID from header.""" - return askui_workspace - - -WorkspaceIdDep = Depends(get_workspace_id) - - -def get_workspace_dir( - askui_workspace: Annotated[WorkspaceId, Header()], - settings: Settings = SettingsDep, -) -> Path: - return settings.data_dir / "workspaces" / str(askui_workspace) - - -WorkspaceDirDep = Depends(get_workspace_dir) - - -ListQueryDep = Depends(ListQuery) diff --git a/src/askui/chat/api/files/__init__.py b/src/askui/chat/api/files/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/files/dependencies.py b/src/askui/chat/api/files/dependencies.py deleted file mode 100644 index babeb7e2..00000000 --- a/src/askui/chat/api/files/dependencies.py +++ /dev/null @@ -1,17 +0,0 @@ -from fastapi import Depends - -from askui.chat.api.db.session import SessionDep -from askui.chat.api.dependencies import SettingsDep -from askui.chat.api.files.service import FileService -from askui.chat.api.settings import Settings - - -def get_file_service( - session: SessionDep, - settings: Settings = SettingsDep, -) -> FileService: - """Get FileService instance.""" - return FileService(session, settings.data_dir) - - -FileServiceDep = Depends(get_file_service) diff --git a/src/askui/chat/api/files/models.py b/src/askui/chat/api/files/models.py deleted file mode 100644 index f542f6ae..00000000 --- a/src/askui/chat/api/files/models.py +++ /dev/null @@ -1,42 +0,0 @@ -import mimetypes -from typing import Literal - -from pydantic import BaseModel, Field - -from askui.chat.api.models import FileId, WorkspaceId, WorkspaceResource -from askui.utils.datetime_utils import UnixDatetime, now -from askui.utils.id_utils import generate_time_ordered_id - - -class FileBase(BaseModel): - """Base file model.""" - - size: int = Field(description="In bytes", ge=0) - media_type: str - - -class FileCreate(FileBase): - filename: str | None = None - - -class File(FileBase, WorkspaceResource): - """A file that can be stored and managed.""" - - id: FileId - object: Literal["file"] = "file" - created_at: UnixDatetime - filename: str = Field(min_length=1) - - @classmethod - def create(cls, workspace_id: WorkspaceId | None, params: FileCreate) -> "File": - id_ = generate_time_ordered_id("file") - filename = ( - params.filename or f"{id_}{mimetypes.guess_extension(params.media_type)}" - ) - return cls( - id=id_, - created_at=now(), - workspace_id=workspace_id, - filename=filename, - **params.model_dump(exclude={"filename"}), - ) diff --git a/src/askui/chat/api/files/orms.py b/src/askui/chat/api/files/orms.py deleted file mode 100644 index ab0ffd94..00000000 --- a/src/askui/chat/api/files/orms.py +++ /dev/null @@ -1,35 +0,0 @@ -"""File database model.""" - -from datetime import datetime -from uuid import UUID - -from sqlalchemy import Integer, String, Uuid -from sqlalchemy.orm import Mapped, mapped_column - -from askui.chat.api.db.orm.base import Base -from askui.chat.api.db.orm.types import UnixDatetime, create_prefixed_id_type -from askui.chat.api.files.models import File - -FileId = create_prefixed_id_type("file") - - -class FileOrm(Base): - """File database model.""" - - __tablename__ = "files" - - id: Mapped[str] = mapped_column(FileId, primary_key=True) - workspace_id: Mapped[UUID | None] = mapped_column(Uuid, nullable=True, index=True) - created_at: Mapped[datetime] = mapped_column(UnixDatetime, nullable=False) - filename: Mapped[str] = mapped_column(String, nullable=False) - size: Mapped[int] = mapped_column(Integer, nullable=False) - media_type: Mapped[str] = mapped_column(String, nullable=False) - - @classmethod - def from_model(cls, model: File) -> "FileOrm": - return cls( - **model.model_dump(exclude={"object"}), - ) - - def to_model(self) -> File: - return File.model_validate(self, from_attributes=True) diff --git a/src/askui/chat/api/files/router.py b/src/askui/chat/api/files/router.py deleted file mode 100644 index 7356c85a..00000000 --- a/src/askui/chat/api/files/router.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Annotated - -from fastapi import APIRouter, Header, UploadFile, status -from fastapi.responses import FileResponse - -from askui.chat.api.dependencies import ListQueryDep -from askui.chat.api.files.dependencies import FileServiceDep -from askui.chat.api.files.models import File as FileModel -from askui.chat.api.files.models import FileId -from askui.chat.api.files.service import FileService -from askui.chat.api.models import WorkspaceId -from askui.utils.api_utils import ListQuery, ListResponse - -router = APIRouter(prefix="/files", tags=["files"]) - - -@router.get("") -def list_files( - askui_workspace: Annotated[WorkspaceId | None, Header()] = None, - query: ListQuery = ListQueryDep, - file_service: FileService = FileServiceDep, -) -> ListResponse[FileModel]: - """List all files.""" - return file_service.list_(workspace_id=askui_workspace, query=query) - - -@router.post("", status_code=status.HTTP_201_CREATED) -async def upload_file( - file: UploadFile, - askui_workspace: Annotated[WorkspaceId, Header()], - file_service: FileService = FileServiceDep, -) -> FileModel: - """Upload a new file.""" - return await file_service.upload_file(workspace_id=askui_workspace, file=file) - - -@router.get("/{file_id}") -def retrieve_file( - file_id: FileId, - askui_workspace: Annotated[WorkspaceId | None, Header()] = None, - file_service: FileService = FileServiceDep, -) -> FileModel: - """Get file metadata by ID.""" - return file_service.retrieve(workspace_id=askui_workspace, file_id=file_id) - - -@router.get("/{file_id}/content") -def download_file( - file_id: FileId, - askui_workspace: Annotated[WorkspaceId | None, Header()] = None, - file_service: FileService = FileServiceDep, -) -> FileResponse: - """Retrieve a file by ID.""" - file, file_path = file_service.retrieve_file_content( - workspace_id=askui_workspace, file_id=file_id - ) - return FileResponse(file_path, media_type=file.media_type, filename=file.filename) - - -@router.delete("/{file_id}", status_code=status.HTTP_204_NO_CONTENT) -def delete_file( - file_id: FileId, - askui_workspace: Annotated[WorkspaceId, Header()], - file_service: FileService = FileServiceDep, -) -> None: - """Delete a file by ID.""" - file_service.delete(workspace_id=askui_workspace, file_id=file_id) diff --git a/src/askui/chat/api/files/service.py b/src/askui/chat/api/files/service.py deleted file mode 100644 index e23a7514..00000000 --- a/src/askui/chat/api/files/service.py +++ /dev/null @@ -1,186 +0,0 @@ -import logging -import mimetypes -import shutil -import tempfile -from pathlib import Path - -from fastapi import UploadFile -from sqlalchemy import or_ -from sqlalchemy.orm import Session - -from askui.chat.api.db.queries import list_all -from askui.chat.api.files.models import File, FileCreate -from askui.chat.api.files.orms import FileOrm -from askui.chat.api.models import FileId, WorkspaceId -from askui.utils.api_utils import ( - FileTooLargeError, - ForbiddenError, - ListQuery, - ListResponse, - NotFoundError, -) - -logger = logging.getLogger(__name__) - -# Constants -MAX_FILE_SIZE = 20 * 1024 * 1024 # 20MB supported -CHUNK_SIZE = 1024 * 1024 # 1MB for uploading and downloading - - -class FileService: - """Service for managing File resources with database persistence.""" - - def __init__(self, session: Session, data_dir: Path) -> None: - self._session = session - self._data_dir = data_dir - - def _find_by_id(self, workspace_id: WorkspaceId | None, file_id: FileId) -> FileOrm: - """Find file by ID.""" - file_orm: FileOrm | None = ( - self._session.query(FileOrm) - .filter( - FileOrm.id == file_id, - or_( - FileOrm.workspace_id == workspace_id, - FileOrm.workspace_id.is_(None), - ), - ) - .first() - ) - if file_orm is None: - error_msg = f"File {file_id} not found" - raise NotFoundError(error_msg) - return file_orm - - def _get_static_file_path(self, file: File) -> Path: - """Get the path for the static file based on extension.""" - # For application/octet-stream, don't add .bin extension - extension = "" - if file.media_type != "application/octet-stream": - extension = mimetypes.guess_extension(file.media_type) or "" - base_name = f"{file.id}{extension}" - path = self._data_dir / "static" / base_name - if file.workspace_id is not None: - path = ( - self._data_dir - / "workspaces" - / str(file.workspace_id) - / "static" - / base_name - ) - path.parent.mkdir(parents=True, exist_ok=True) - return path - - def list_( - self, workspace_id: WorkspaceId | None, query: ListQuery - ) -> ListResponse[File]: - """List files with pagination and filtering.""" - q = self._session.query(FileOrm).filter( - or_( - FileOrm.workspace_id == workspace_id, - FileOrm.workspace_id.is_(None), - ), - ) - orms: list[FileOrm] - orms, has_more = list_all(q, query, FileOrm.id) - data = [orm.to_model() for orm in orms] - return ListResponse( - data=data, - has_more=has_more, - first_id=data[0].id if data else None, - last_id=data[-1].id if data else None, - ) - - def retrieve(self, workspace_id: WorkspaceId | None, file_id: FileId) -> File: - """Retrieve file metadata by ID.""" - file_orm = self._find_by_id(workspace_id, file_id) - return file_orm.to_model() - - def delete( - self, workspace_id: WorkspaceId | None, file_id: FileId, force: bool = False - ) -> None: - """Delete a file and its content. - - *Important*: We may be left with a static file that is not associated with any - file metadata if this fails. - """ - file_orm = self._find_by_id(workspace_id, file_id) - file = file_orm.to_model() - if file.workspace_id is None and not force: - error_msg = f"Default file {file_id} cannot be deleted" - raise ForbiddenError(error_msg) - self._session.delete(file_orm) - self._session.commit() - static_path = self._get_static_file_path(file) - static_path.unlink() - - def retrieve_file_content( - self, workspace_id: WorkspaceId | None, file_id: FileId - ) -> tuple[File, Path]: - """Get file metadata and path for downloading.""" - file = self.retrieve(workspace_id, file_id) - static_path = self._get_static_file_path(file) - return file, static_path - - async def _write_to_temp_file( - self, - file: UploadFile, - ) -> tuple[FileCreate, Path]: - size = 0 - temp_file = tempfile.NamedTemporaryFile( - delete=False, - suffix=".temp", - ) - temp_path = Path(temp_file.name) - with temp_file: - while chunk := await file.read(CHUNK_SIZE): - temp_file.write(chunk) - size += len(chunk) - if size > MAX_FILE_SIZE: - raise FileTooLargeError(MAX_FILE_SIZE) - mime_type = file.content_type or "application/octet-stream" - params = FileCreate( - filename=file.filename, - size=size, - media_type=mime_type, - ) - return params, temp_path - - def create( - self, workspace_id: WorkspaceId | None, params: FileCreate, path: Path - ) -> File: - """Create a file and its content. - - *Important*: We may be left with a static file that is not associated with any - file metadata if this fails. - """ - file_model = File.create(workspace_id, params) - static_path = self._get_static_file_path(file_model) - shutil.move(path, static_path) - file_orm = FileOrm.from_model(file_model) - self._session.add(file_orm) - self._session.commit() - return file_model - - async def upload_file( - self, - workspace_id: WorkspaceId | None, - file: UploadFile, - ) -> File: - """Upload a file. - - *Important*: We may be left with a static file that is not associated with any - file metadata if this fails. - """ - temp_path: Path | None = None - try: - params, temp_path = await self._write_to_temp_file(file) - file_model = self.create(workspace_id, params, temp_path) - except Exception: - logger.exception("Failed to upload file") - raise - else: - return file_model - finally: - if temp_path: - temp_path.unlink(missing_ok=True) diff --git a/src/askui/chat/api/health/__init__.py b/src/askui/chat/api/health/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/health/router.py b/src/askui/chat/api/health/router.py deleted file mode 100644 index b6a17580..00000000 --- a/src/askui/chat/api/health/router.py +++ /dev/null @@ -1,18 +0,0 @@ -from fastapi import APIRouter, status -from pydantic import BaseModel - -router = APIRouter(prefix="/health", tags=["healthcheck"]) - - -class HealthCheck(BaseModel): - status: str = "OK" - - -@router.get( - "", - summary="Perform a Health Check", - response_description="Return HTTP Status Code 200 (OK)", - status_code=status.HTTP_200_OK, -) -def get_health() -> HealthCheck: - return HealthCheck(status="OK") diff --git a/src/askui/chat/api/mcp_clients/__init__.py b/src/askui/chat/api/mcp_clients/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/mcp_clients/dependencies.py b/src/askui/chat/api/mcp_clients/dependencies.py deleted file mode 100644 index 3a2744fd..00000000 --- a/src/askui/chat/api/mcp_clients/dependencies.py +++ /dev/null @@ -1,14 +0,0 @@ -from fastapi import Depends - -from askui.chat.api.mcp_clients.manager import McpClientManagerManager -from askui.chat.api.mcp_configs.dependencies import McpConfigServiceDep -from askui.chat.api.mcp_configs.service import McpConfigService - - -def get_mcp_client_manager_manager( - mcp_config_service: McpConfigService = McpConfigServiceDep, -) -> McpClientManagerManager: - return McpClientManagerManager(mcp_config_service) - - -McpClientManagerManagerDep = Depends(get_mcp_client_manager_manager) diff --git a/src/askui/chat/api/mcp_clients/manager.py b/src/askui/chat/api/mcp_clients/manager.py deleted file mode 100644 index 3f7c8ef5..00000000 --- a/src/askui/chat/api/mcp_clients/manager.py +++ /dev/null @@ -1,156 +0,0 @@ -import types -from datetime import timedelta -from typing import Any, Type - -import anyio -import mcp -from fastmcp import Client -from fastmcp.client.client import CallToolResult, ProgressHandler -from fastmcp.exceptions import ToolError -from fastmcp.mcp_config import MCPConfig - -from askui.chat.api.mcp_configs.service import McpConfigService -from askui.chat.api.models import WorkspaceId - -McpServerName = str - - -class McpServerConnectionError(Exception): - """Exception raised when a MCP server connection fails.""" - - def __init__(self, mcp_server_name: McpServerName, error: Exception): - super().__init__(f"Failed to connect to MCP server: {mcp_server_name}: {error}") - self.mcp_server_name = mcp_server_name - self.error = error - - -class McpClientManager: - def __init__( - self, mcp_clients: dict[McpServerName, Client[Any]] | None = None - ) -> None: - self._mcp_clients = mcp_clients or {} - self._tools: dict[McpServerName, list[mcp.types.Tool]] = {} - - @classmethod - def from_config(cls, mcp_config: MCPConfig) -> "McpClientManager": - mcp_clients: dict[McpServerName, Client[Any]] = { - mcp_server_name: Client(mcp_server_config.to_transport()) - for mcp_server_name, mcp_server_config in mcp_config.mcpServers.items() - } - return cls(mcp_clients) - - async def connect(self) -> "McpClientManager": - for mcp_server_name, mcp_client in self._mcp_clients.items(): - try: - await mcp_client._connect() # noqa: SLF001 - except Exception as e: # noqa: PERF203 - raise McpServerConnectionError(mcp_server_name, e) from e - return self - - async def disconnect(self, force: bool = False) -> None: - for mcp_client in self._mcp_clients.values(): - if mcp_client.is_connected(): - await mcp_client._disconnect(force) # noqa: SLF001 - - async def list_tools( - self, - ) -> list[mcp.types.Tool]: - tools: list[mcp.types.Tool] = [] - for mcp_server_name, mcp_client in self._mcp_clients.items(): - if mcp_server_name not in self._tools: - self._tools[mcp_server_name] = await mcp_client.list_tools() - tools.extend(self._tools[mcp_server_name]) - return tools - - async def call_tool( - self, - name: str, - arguments: dict[str, Any] | None = None, - timeout: timedelta | float | None = None, # noqa: ASYNC109 - progress_handler: ProgressHandler | None = None, - raise_on_error: bool = True, - ) -> CallToolResult: - for mcp_server_name, tools in self._tools.items(): # Make lookup faster - for tool in tools: - if tool.name == name: - return await self._mcp_clients[mcp_server_name].call_tool( - name, - arguments, - timeout=timeout, - progress_handler=progress_handler, - raise_on_error=raise_on_error, - ) - error_msg = f"Unknown tool: {name}" - if raise_on_error: - raise ToolError(error_msg) - return CallToolResult( - content=[mcp.types.TextContent(type="text", text=error_msg)], - structured_content=None, - data=None, - is_error=True, - ) - - async def __aenter__(self) -> "McpClientManager": - return await self.connect() - - async def __aexit__( - self, - exc_type: Type[BaseException] | None, - exc_value: BaseException | None, - traceback: types.TracebackType | None, - ) -> None: - await self.disconnect() - - -McpClientManagerKey = str - - -class McpClientManagerManager: - _mcp_client_managers: dict[McpClientManagerKey, McpClientManager | None] = {} - _lock: anyio.Lock = anyio.Lock() - - def __init__(self, mcp_config_service: McpConfigService) -> None: - self._mcp_config_service = mcp_config_service - - async def get_mcp_client_manager( - self, workspace_id: WorkspaceId | None - ) -> McpClientManager | None: - key: McpClientManagerKey = ( - f"workspace_{workspace_id}" if workspace_id else "global" - ) - if key in McpClientManagerManager._mcp_client_managers: - return McpClientManagerManager._mcp_client_managers[key] - - fast_mcp_config = self._mcp_config_service.retrieve_fast_mcp_config( - workspace_id - ) - if not fast_mcp_config: - McpClientManagerManager._mcp_client_managers[key] = None - return None - - async with McpClientManagerManager._lock: - if key not in McpClientManagerManager._mcp_client_managers: - try: - mcp_client_manager = McpClientManager.from_config(fast_mcp_config) - McpClientManagerManager._mcp_client_managers[key] = ( - mcp_client_manager - ) - await mcp_client_manager.connect() - except Exception: - if key in McpClientManagerManager._mcp_client_managers: - if ( - _mcp_client_manager - := McpClientManagerManager._mcp_client_managers[key] - ): - await _mcp_client_manager.disconnect(force=True) - del McpClientManagerManager._mcp_client_managers[key] - raise - return McpClientManagerManager._mcp_client_managers[key] - - async def disconnect_all(self, force: bool = False) -> None: - async with McpClientManagerManager._lock: - for ( - mcp_client_manager - ) in McpClientManagerManager._mcp_client_managers.values(): - if mcp_client_manager: - await mcp_client_manager.disconnect(force) diff --git a/src/askui/chat/api/mcp_configs/__init__.py b/src/askui/chat/api/mcp_configs/__init__.py deleted file mode 100644 index 8b137891..00000000 --- a/src/askui/chat/api/mcp_configs/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/askui/chat/api/mcp_configs/dependencies.py b/src/askui/chat/api/mcp_configs/dependencies.py deleted file mode 100644 index fc807081..00000000 --- a/src/askui/chat/api/mcp_configs/dependencies.py +++ /dev/null @@ -1,16 +0,0 @@ -from fastapi import Depends - -from askui.chat.api.db.session import SessionDep -from askui.chat.api.dependencies import SettingsDep -from askui.chat.api.mcp_configs.service import McpConfigService -from askui.chat.api.settings import Settings - - -def get_mcp_config_service( - session: SessionDep, settings: Settings = SettingsDep -) -> McpConfigService: - """Get McpConfigService instance.""" - return McpConfigService(session, settings.mcp_configs) - - -McpConfigServiceDep = Depends(get_mcp_config_service) diff --git a/src/askui/chat/api/mcp_configs/models.py b/src/askui/chat/api/mcp_configs/models.py deleted file mode 100644 index a98fcfbe..00000000 --- a/src/askui/chat/api/mcp_configs/models.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Annotated, Literal - -from fastmcp.mcp_config import RemoteMCPServer as _RemoteMCPServer -from fastmcp.mcp_config import StdioMCPServer -from pydantic import BaseModel, Field - -from askui.chat.api.models import McpConfigId, WorkspaceId, WorkspaceResource -from askui.utils.datetime_utils import UnixDatetime, now -from askui.utils.id_utils import generate_time_ordered_id -from askui.utils.not_given import NOT_GIVEN, BaseModelWithNotGiven, NotGiven - - -class RemoteMCPServer(_RemoteMCPServer): - auth: Annotated[ - str | Literal["oauth"] | None, # noqa: PYI051 - Field( - description='Either a string representing a Bearer token or the literal "oauth" to use OAuth authentication.', - ), - ] = None - - -McpServer = StdioMCPServer | RemoteMCPServer - - -class McpConfigBase(BaseModel): - """Base MCP configuration model.""" - - name: str - mcp_server: McpServer - - -class McpConfigCreate(McpConfigBase): - """Parameters for creating an MCP configuration.""" - - -class McpConfigModify(BaseModelWithNotGiven): - """Parameters for modifying an MCP configuration.""" - - name: str | NotGiven = NOT_GIVEN - mcp_server: McpServer | NotGiven = NOT_GIVEN - - -class McpConfig(McpConfigBase, WorkspaceResource): - """An MCP configuration that can be stored and managed.""" - - id: McpConfigId - object: Literal["mcp_config"] = "mcp_config" - created_at: UnixDatetime - - @classmethod - def create( - cls, workspace_id: WorkspaceId | None, params: McpConfigCreate - ) -> "McpConfig": - return cls( - id=generate_time_ordered_id("mcpcnf"), - created_at=now(), - workspace_id=workspace_id, - **params.model_dump(), - ) - - def modify(self, params: McpConfigModify) -> "McpConfig": - return McpConfig.model_validate( - { - **self.model_dump(), - **params.model_dump(), - } - ) diff --git a/src/askui/chat/api/mcp_configs/orms.py b/src/askui/chat/api/mcp_configs/orms.py deleted file mode 100644 index 757d5004..00000000 --- a/src/askui/chat/api/mcp_configs/orms.py +++ /dev/null @@ -1,33 +0,0 @@ -"""MCP configuration database model.""" - -from datetime import datetime -from typing import Any -from uuid import UUID - -from sqlalchemy import JSON, String, Uuid -from sqlalchemy.orm import Mapped, mapped_column - -from askui.chat.api.db.orm.base import Base -from askui.chat.api.db.orm.types import UnixDatetime, create_prefixed_id_type -from askui.chat.api.mcp_configs.models import McpConfig - -McpConfigId = create_prefixed_id_type("mcpcnf") - - -class McpConfigOrm(Base): - """MCP configuration database model.""" - - __tablename__ = "mcp_configs" - - id: Mapped[str] = mapped_column(McpConfigId, primary_key=True) - workspace_id: Mapped[UUID | None] = mapped_column(Uuid, nullable=True, index=True) - created_at: Mapped[datetime] = mapped_column(UnixDatetime, nullable=False) - name: Mapped[str] = mapped_column(String, nullable=False) - mcp_server: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False) - - @classmethod - def from_model(cls, model: McpConfig) -> "McpConfigOrm": - return cls(**model.model_dump(exclude={"object"})) - - def to_model(self) -> McpConfig: - return McpConfig.model_validate(self, from_attributes=True) diff --git a/src/askui/chat/api/mcp_configs/router.py b/src/askui/chat/api/mcp_configs/router.py deleted file mode 100644 index 41b9e3c5..00000000 --- a/src/askui/chat/api/mcp_configs/router.py +++ /dev/null @@ -1,70 +0,0 @@ -from typing import Annotated - -from fastapi import APIRouter, Header, status - -from askui.chat.api.dependencies import ListQueryDep -from askui.chat.api.mcp_configs.dependencies import McpConfigServiceDep -from askui.chat.api.mcp_configs.models import ( - McpConfig, - McpConfigCreate, - McpConfigModify, -) -from askui.chat.api.mcp_configs.service import McpConfigService -from askui.chat.api.models import McpConfigId, WorkspaceId -from askui.utils.api_utils import ListQuery, ListResponse - -router = APIRouter(prefix="/mcp-configs", tags=["mcp-configs"]) - - -@router.get("", response_model_exclude_none=True) -def list_mcp_configs( - askui_workspace: Annotated[WorkspaceId | None, Header()], - query: ListQuery = ListQueryDep, - mcp_config_service: McpConfigService = McpConfigServiceDep, -) -> ListResponse[McpConfig]: - return mcp_config_service.list_(workspace_id=askui_workspace, query=query) - - -@router.post("", status_code=status.HTTP_201_CREATED, response_model_exclude_none=True) -def create_mcp_config( - params: McpConfigCreate, - askui_workspace: Annotated[WorkspaceId, Header()], - mcp_config_service: McpConfigService = McpConfigServiceDep, -) -> McpConfig: - """Create a new MCP configuration.""" - return mcp_config_service.create(workspace_id=askui_workspace, params=params) - - -@router.get("/{mcp_config_id}", response_model_exclude_none=True) -def retrieve_mcp_config( - mcp_config_id: McpConfigId, - askui_workspace: Annotated[WorkspaceId | None, Header()], - mcp_config_service: McpConfigService = McpConfigServiceDep, -) -> McpConfig: - """Get an MCP configuration by ID.""" - return mcp_config_service.retrieve( - workspace_id=askui_workspace, mcp_config_id=mcp_config_id - ) - - -@router.post("/{mcp_config_id}", response_model_exclude_none=True) -def modify_mcp_config( - mcp_config_id: McpConfigId, - params: McpConfigModify, - askui_workspace: Annotated[WorkspaceId, Header()], - mcp_config_service: McpConfigService = McpConfigServiceDep, -) -> McpConfig: - """Update an MCP configuration.""" - return mcp_config_service.modify( - workspace_id=askui_workspace, mcp_config_id=mcp_config_id, params=params - ) - - -@router.delete("/{mcp_config_id}", status_code=status.HTTP_204_NO_CONTENT) -def delete_mcp_config( - mcp_config_id: McpConfigId, - askui_workspace: Annotated[WorkspaceId | None, Header()], - mcp_config_service: McpConfigService = McpConfigServiceDep, -) -> None: - """Delete an MCP configuration.""" - mcp_config_service.delete(workspace_id=askui_workspace, mcp_config_id=mcp_config_id) diff --git a/src/askui/chat/api/mcp_configs/service.py b/src/askui/chat/api/mcp_configs/service.py deleted file mode 100644 index d540ee93..00000000 --- a/src/askui/chat/api/mcp_configs/service.py +++ /dev/null @@ -1,156 +0,0 @@ -from fastmcp.mcp_config import MCPConfig -from sqlalchemy import or_ -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session - -from askui.chat.api.db.queries import list_all -from askui.chat.api.mcp_configs.models import ( - McpConfig, - McpConfigCreate, - McpConfigId, - McpConfigModify, -) -from askui.chat.api.mcp_configs.orms import McpConfigOrm -from askui.chat.api.models import WorkspaceId -from askui.utils.api_utils import ( - LIST_LIMIT_MAX, - ForbiddenError, - LimitReachedError, - ListQuery, - ListResponse, - NotFoundError, -) - - -class McpConfigService: - """Service for managing McpConfig resources with database persistence.""" - - def __init__(self, session: Session, seeds: list[McpConfig]) -> None: - self._session = session - self._seeds = seeds - - def list_( - self, workspace_id: WorkspaceId | None, query: ListQuery - ) -> ListResponse[McpConfig]: - q = self._session.query(McpConfigOrm).filter( - or_( - McpConfigOrm.workspace_id == workspace_id, - McpConfigOrm.workspace_id.is_(None), - ), - ) - orms: list[McpConfigOrm] - orms, has_more = list_all(q, query, McpConfigOrm.id) - data = [orm.to_model() for orm in orms] - return ListResponse( - data=data, - has_more=has_more, - first_id=data[0].id if data else None, - last_id=data[-1].id if data else None, - ) - - def _find_by_id( - self, workspace_id: WorkspaceId | None, mcp_config_id: McpConfigId - ) -> McpConfigOrm: - mcp_config_orm: McpConfigOrm | None = ( - self._session.query(McpConfigOrm) - .filter( - McpConfigOrm.id == mcp_config_id, - or_( - McpConfigOrm.workspace_id == workspace_id, - McpConfigOrm.workspace_id.is_(None), - ), - ) - .first() - ) - if mcp_config_orm is None: - error_msg = f"MCP configuration {mcp_config_id} not found" - raise NotFoundError(error_msg) - return mcp_config_orm - - def retrieve( - self, workspace_id: WorkspaceId | None, mcp_config_id: McpConfigId - ) -> McpConfig: - mcp_config_model = self._find_by_id(workspace_id, mcp_config_id) - return mcp_config_model.to_model() - - def retrieve_fast_mcp_config( - self, workspace_id: WorkspaceId | None - ) -> MCPConfig | None: - list_response = self.list_( - workspace_id=workspace_id, - query=ListQuery(limit=LIST_LIMIT_MAX, order="asc"), - ) - mcp_servers_dict = { - mcp_config.name: mcp_config.mcp_server for mcp_config in list_response.data - } - return MCPConfig(mcpServers=mcp_servers_dict) if mcp_servers_dict else None - - def create( - self, workspace_id: WorkspaceId | None, params: McpConfigCreate - ) -> McpConfig: - try: - mcp_config = McpConfig.create(workspace_id, params) - mcp_config_model = McpConfigOrm.from_model(mcp_config) - self._session.add(mcp_config_model) - self._session.commit() - except IntegrityError as e: - if "MCP configuration limit reached" in str(e): - raise LimitReachedError(str(e)) from e - raise - else: - return mcp_config - - def modify( - self, - workspace_id: WorkspaceId | None, - mcp_config_id: McpConfigId, - params: McpConfigModify, - force: bool = False, - ) -> McpConfig: - mcp_config_model = self._find_by_id(workspace_id, mcp_config_id) - if mcp_config_model.workspace_id is None and not force: - error_msg = f"Default MCP configuration {mcp_config_id} cannot be modified" - raise ForbiddenError(error_msg) - mcp_config_model.update(params.model_dump()) - self._session.commit() - self._session.refresh(mcp_config_model) - return mcp_config_model.to_model() - - def delete( - self, - workspace_id: WorkspaceId | None, - mcp_config_id: McpConfigId, - force: bool = False, - ) -> None: - # Use a single query to find and delete atomically - mcp_config_model = ( - self._session.query(McpConfigOrm) - .filter( - McpConfigOrm.id == mcp_config_id, - or_( - McpConfigOrm.workspace_id == workspace_id, - McpConfigOrm.workspace_id.is_(None), - ), - ) - .first() - ) - - if mcp_config_model is None: - error_msg = f"MCP configuration {mcp_config_id} not found" - raise NotFoundError(error_msg) - - if mcp_config_model.workspace_id is None and not force: - error_msg = f"Default MCP configuration {mcp_config_id} cannot be deleted" - raise ForbiddenError(error_msg) - - self._session.delete(mcp_config_model) - self._session.commit() - - def seed(self) -> None: - """Seed the MCP configuration service with default MCP configurations.""" - for seed in self._seeds: - self._session.query(McpConfigOrm).filter( - McpConfigOrm.id == seed.id - ).delete() - self._session.add(McpConfigOrm.from_model(seed)) - self._session.commit() diff --git a/src/askui/chat/api/mcp_servers/__init__.py b/src/askui/chat/api/mcp_servers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/mcp_servers/android.py b/src/askui/chat/api/mcp_servers/android.py deleted file mode 100644 index 1ffe5da1..00000000 --- a/src/askui/chat/api/mcp_servers/android.py +++ /dev/null @@ -1,55 +0,0 @@ -from fastmcp import FastMCP - -from askui.chat.api.mcp_servers.android_setup_doc import ANDROID_SETUP_GUIDE -from askui.tools.android.agent_os_facade import AndroidAgentOsFacade -from askui.tools.android.ppadb_agent_os import PpadbAgentOs -from askui.tools.android.tools import ( - AndroidConnectTool, - AndroidDragAndDropTool, - AndroidGetConnectedDevicesSerialNumbersTool, - AndroidGetConnectedDisplaysInfosTool, - AndroidGetCurrentConnectedDeviceInfosTool, - AndroidKeyCombinationTool, - AndroidKeyTapEventTool, - AndroidScreenshotTool, - AndroidSelectDeviceBySerialNumberTool, - AndroidSelectDisplayByUniqueIDTool, - AndroidShellTool, - AndroidSwipeTool, - AndroidTapTool, - AndroidTypeTool, -) - -mcp = FastMCP(name="AskUI Android MCP") - -# Initialize the AndroidAgentOsFacade -ANDROID_AGENT_OS = PpadbAgentOs() -ANDROID_AGENT_OS_FACADE = AndroidAgentOsFacade(ANDROID_AGENT_OS) -TOOLS = [ - AndroidSelectDeviceBySerialNumberTool(ANDROID_AGENT_OS_FACADE), - AndroidSelectDisplayByUniqueIDTool(ANDROID_AGENT_OS_FACADE), - AndroidGetConnectedDevicesSerialNumbersTool(ANDROID_AGENT_OS_FACADE), - AndroidGetConnectedDisplaysInfosTool(ANDROID_AGENT_OS_FACADE), - AndroidGetCurrentConnectedDeviceInfosTool(ANDROID_AGENT_OS_FACADE), - AndroidConnectTool(ANDROID_AGENT_OS_FACADE), - AndroidScreenshotTool(ANDROID_AGENT_OS_FACADE), - AndroidTapTool(ANDROID_AGENT_OS_FACADE), - AndroidTypeTool(ANDROID_AGENT_OS_FACADE), - AndroidDragAndDropTool(ANDROID_AGENT_OS_FACADE), - AndroidKeyTapEventTool(ANDROID_AGENT_OS_FACADE), - AndroidSwipeTool(ANDROID_AGENT_OS_FACADE), - AndroidKeyCombinationTool(ANDROID_AGENT_OS_FACADE), - AndroidShellTool(ANDROID_AGENT_OS_FACADE), -] - -for tool in TOOLS: - mcp.add_tool(tool.to_mcp_tool({"android"})) - - -@mcp.tool( - description="""Provides step-by-step instructions for setting up Android emulators or real devices. - Use this tool when no device is connected or the ADB server cannot detect any devices.""", - tags={"android"}, -) -def android_setup_helper() -> str: - return ANDROID_SETUP_GUIDE diff --git a/src/askui/chat/api/mcp_servers/android_setup_doc.py b/src/askui/chat/api/mcp_servers/android_setup_doc.py deleted file mode 100644 index d5e9cab3..00000000 --- a/src/askui/chat/api/mcp_servers/android_setup_doc.py +++ /dev/null @@ -1,79 +0,0 @@ -ANDROID_SETUP_GUIDE = """ -# Guide: Setting Up Android Devices - -This guide explains how to prepare **Android Emulators** and **real Android devices** for automation with AskUI. - ---- - -## Android Emulator - -Automating an emulator with AskUI is straightforward once the emulator is installed and running. - -### 1. Install Android Studio -- Download and install **Android Studio** from the [official website](https://developer.android.com/studio). - -### 2. Create an Emulator with AVD Manager -1. Open **Android Studio**. -2. Go to **More Actions → Virtual Device Manager**. -3. Click **Create Virtual Device…**. -4. Choose a hardware profile (e.g., Pixel 5) → **Next**. -5. Select a system image (preferably one with the Play Store). Download may take a few minutes → **Next**. -6. Configure options if needed → **Finish**. - -📖 Reference: [Create and manage virtual devices](https://developer.android.com/studio/run/managing-avds) - -### 3. Start the Emulator -1. In **AVD Manager**, click the **Play** button next to your emulator. -2. Wait until the emulator boots fully. - ---- - -## Real Android Devices - -AskUI can also automate **physical Android devices** once ADB is installed. - -### 1. Enable Developer Options & USB Debugging -1. On your device, go to **Settings → About phone**. -2. Tap **Build number** seven times to enable Developer Options. -3. Go back to **Settings → Developer Options**. -4. Enable **USB Debugging**. - -📖 Reference: [Enable adb debugging on your device](https://developer.android.com/studio/command-line/adb#Enabling) - -### 2. Install ADB (Platform-Tools) -1. Download **Platform-Tools** from the [official ADB source](https://developer.android.com/studio/releases/platform-tools). -2. Extract the ZIP to a folder, e.g.: - - Windows: `C:\\platform-tools` - - macOS/Linux: `~/platform-tools` - -📖 Reference: [Android Debug Bridge](https://developer.android.com/studio/command-line/adb) - -### 3. Add ADB to PATH -- **Windows** - 1. Press `Win + S`, search for **Environment Variables**, and open it. - 2. Click **Environment Variables…**. - 3. Under **System variables → Path**, click **Edit…**. - 4. Add the path to your `platform-tools` folder. - 5. Save with **OK**. - -- **macOS/Linux** - Add this line to your shell config (`~/.bash_profile`, `~/.zshrc`, or `~/.bashrc`): - ```bash - export PATH="$PATH:/path/to/platform-tools" - ``` - Then save and reload your shell. - -### 4. Verify Device Connection -1. Connect your device via USB. - - On first connection, confirm the **USB Debugging Allow prompt** on your device. -2. Open a terminal and run: - ```bash - adb devices - ``` - Expected output: - ``` - List of devices attached - 1234567890abcdef device - ``` ---- -""" diff --git a/src/askui/chat/api/mcp_servers/computer.py b/src/askui/chat/api/mcp_servers/computer.py deleted file mode 100644 index 488b7404..00000000 --- a/src/askui/chat/api/mcp_servers/computer.py +++ /dev/null @@ -1,49 +0,0 @@ -from fastmcp import FastMCP - -from askui.tools.askui.askui_controller import AskUiControllerClient -from askui.tools.computer import ( - ComputerConnectTool, - ComputerDisconnectTool, - ComputerGetMousePositionTool, - ComputerKeyboardPressedTool, - ComputerKeyboardReleaseTool, - ComputerKeyboardTapTool, - ComputerListDisplaysTool, - ComputerMouseClickTool, - ComputerMouseHoldDownTool, - ComputerMouseReleaseTool, - ComputerMouseScrollTool, - ComputerMoveMouseTool, - ComputerRetrieveActiveDisplayTool, - ComputerScreenshotTool, - ComputerSetActiveDisplayTool, - ComputerTypeTool, -) -from askui.tools.computer_agent_os_facade import ComputerAgentOsFacade - -mcp = FastMCP(name="AskUI Computer MCP") - -COMPUTER_AGENT_OS = AskUiControllerClient() -COMPUTER_AGENT_OS_FACADE = ComputerAgentOsFacade(COMPUTER_AGENT_OS) - -TOOLS = [ - ComputerGetMousePositionTool(COMPUTER_AGENT_OS_FACADE), - ComputerKeyboardPressedTool(COMPUTER_AGENT_OS_FACADE), - ComputerKeyboardReleaseTool(COMPUTER_AGENT_OS_FACADE), - ComputerKeyboardTapTool(COMPUTER_AGENT_OS_FACADE), - ComputerListDisplaysTool(COMPUTER_AGENT_OS_FACADE), - ComputerMouseClickTool(COMPUTER_AGENT_OS_FACADE), - ComputerMouseHoldDownTool(COMPUTER_AGENT_OS_FACADE), - ComputerMouseReleaseTool(COMPUTER_AGENT_OS_FACADE), - ComputerMouseScrollTool(COMPUTER_AGENT_OS_FACADE), - ComputerMoveMouseTool(COMPUTER_AGENT_OS_FACADE), - ComputerRetrieveActiveDisplayTool(COMPUTER_AGENT_OS_FACADE), - ComputerScreenshotTool(COMPUTER_AGENT_OS_FACADE), - ComputerSetActiveDisplayTool(COMPUTER_AGENT_OS_FACADE), - ComputerTypeTool(COMPUTER_AGENT_OS_FACADE), - ComputerConnectTool(COMPUTER_AGENT_OS_FACADE), - ComputerDisconnectTool(COMPUTER_AGENT_OS_FACADE), -] - -for tool in TOOLS: - mcp.add_tool(tool.to_mcp_tool({"computer"})) diff --git a/src/askui/chat/api/mcp_servers/testing.py b/src/askui/chat/api/mcp_servers/testing.py deleted file mode 100644 index b76c8165..00000000 --- a/src/askui/chat/api/mcp_servers/testing.py +++ /dev/null @@ -1,79 +0,0 @@ -from fastmcp import FastMCP -from fastmcp.tools import Tool - -from askui.chat.api.dependencies import get_settings -from askui.tools.testing.execution_tools import ( - CreateExecutionTool, - DeleteExecutionTool, - ListExecutionTool, - ModifyExecutionTool, - RetrieveExecutionTool, -) -from askui.tools.testing.feature_tools import ( - CreateFeatureTool, - DeleteFeatureTool, - ListFeatureTool, - ModifyFeatureTool, - RetrieveFeatureTool, -) -from askui.tools.testing.scenario_tools import ( - CreateScenarioTool, - DeleteScenarioTool, - ListScenarioTool, - ModifyScenarioTool, - RetrieveScenarioTool, -) - -mcp = FastMCP(name="AskUI Testing MCP") - -settings = get_settings() -base_dir = settings.data_dir / "testing" - -FEATURE_TOOLS = [ - CreateFeatureTool(base_dir), - RetrieveFeatureTool(base_dir), - ListFeatureTool(base_dir), - ModifyFeatureTool(base_dir), - DeleteFeatureTool(base_dir), -] - -SCENARIO_TOOLS = [ - CreateScenarioTool(base_dir), - RetrieveScenarioTool(base_dir), - ListScenarioTool(base_dir), - ModifyScenarioTool(base_dir), - DeleteScenarioTool(base_dir), -] - -EXECUTION_TOOLS = [ - CreateExecutionTool(base_dir), - RetrieveExecutionTool(base_dir), - ListExecutionTool(base_dir), - ModifyExecutionTool(base_dir), - DeleteExecutionTool(base_dir), -] - - -TOOLS = [ - *FEATURE_TOOLS, - *SCENARIO_TOOLS, - *EXECUTION_TOOLS, -] - - -for tool in TOOLS: - tags = {"testing"} - if tool in FEATURE_TOOLS: - tags.add("feature") - if tool in SCENARIO_TOOLS: - tags.add("scenario") - if tool in EXECUTION_TOOLS: - tags.add("execution") - mcp.add_tool( - Tool.from_function( - tool.__call__, - name=tool.name, - description=tool.description, - tags=tags, - ), - ) diff --git a/src/askui/chat/api/mcp_servers/utility.py b/src/askui/chat/api/mcp_servers/utility.py deleted file mode 100644 index afa4d50f..00000000 --- a/src/askui/chat/api/mcp_servers/utility.py +++ /dev/null @@ -1,38 +0,0 @@ -import asyncio -from typing import Annotated - -from fastmcp import FastMCP -from pydantic import Field - -mcp = FastMCP(name="AskUI Utility MCP") - - -@mcp.tool( - description="Wait for a specified number of seconds", - tags={"utility"}, -) -async def utility_wait( - seconds: Annotated[ - float, - Field(ge=0.0, le=3600.0, description="Number of seconds to wait (0-3600)"), - ], -) -> str: - """ - Wait for the specified number of seconds. - - Args: - seconds (float): Number of seconds to wait, between 0 and 3600 (1 hour). - - Returns: - str: Confirmation message indicating the wait is complete. - - Example: - ```python - wait(5.0) # Wait for 5 seconds - ``` - """ - if seconds == 0: - return "Wait completed immediately (0 seconds)" - - await asyncio.sleep(seconds) - return f"Wait completed after {seconds} seconds" diff --git a/src/askui/chat/api/messages/__init__.py b/src/askui/chat/api/messages/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/messages/chat_history_manager.py b/src/askui/chat/api/messages/chat_history_manager.py deleted file mode 100644 index c766de23..00000000 --- a/src/askui/chat/api/messages/chat_history_manager.py +++ /dev/null @@ -1,99 +0,0 @@ -from anthropic.types.beta import BetaTextBlockParam, BetaToolUnionParam - -from askui.chat.api.messages.models import Message, MessageCreate -from askui.chat.api.messages.service import MessageService -from askui.chat.api.messages.translator import MessageTranslator -from askui.chat.api.models import MessageId, ThreadId, WorkspaceId -from askui.models.shared.agent_message_param import MessageParam -from askui.models.shared.settings import ActSystemPrompt -from askui.models.shared.truncation_strategies import TruncationStrategyFactory -from askui.utils.api_utils import NotFoundError - - -class ChatHistoryManager: - """ - Manages chat history by providing methods to retrieve and add messages. - - This service encapsulates the interaction between MessageService and MessageTranslator - to provide a clean interface for managing chat history in the context of AI agents. - """ - - def __init__( - self, - message_service: MessageService, - message_translator: MessageTranslator, - truncation_strategy_factory: TruncationStrategyFactory, - ) -> None: - """ - Initialize the chat history manager. - - Args: - message_service (MessageService): Service for managing message persistence. - message_translator (MessageTranslator): Translator for converting between - message formats. - truncation_strategy_factory (TruncationStrategyFactory): Factory for creating truncation strategies. - """ - self._message_service = message_service - self._message_translator = message_translator - self._message_content_translator = message_translator.content_translator - self._truncation_strategy_factory = truncation_strategy_factory - - async def retrieve_message_params( - self, - workspace_id: WorkspaceId, - thread_id: ThreadId, - model: str, - system: ActSystemPrompt | None, - tools: list[BetaToolUnionParam], - ) -> list[MessageParam]: - truncation_strategy = ( - self._truncation_strategy_factory.create_truncation_strategy( - system=system, - tools=tools, - messages=[], - model=model, - ) - ) - for msg in self._message_service.iter( - workspace_id=workspace_id, - thread_id=thread_id, - ): - anthropic_message = await self._message_translator.to_anthropic(msg) - truncation_strategy.append_message(anthropic_message) - return truncation_strategy.messages - - def retrieve_last_message( - self, - workspace_id: WorkspaceId, - thread_id: ThreadId, - ) -> MessageId: - last_message_id = self._message_service.retrieve_last_message_id( - workspace_id, thread_id - ) - if last_message_id is None: - error_msg = f"No messages found in thread {thread_id}" - raise NotFoundError(error_msg) - return last_message_id - - async def append_message( - self, - workspace_id: WorkspaceId, - thread_id: ThreadId, - assistant_id: str | None, - run_id: str, - message: MessageParam, - parent_id: str, - ) -> Message: - return self._message_service.create( - workspace_id=workspace_id, - thread_id=thread_id, - params=MessageCreate( - parent_id=parent_id, - assistant_id=assistant_id if message.role == "assistant" else None, - role=message.role, - content=await self._message_content_translator.from_anthropic( - message.content - ), - run_id=run_id, - ), - ) diff --git a/src/askui/chat/api/messages/dependencies.py b/src/askui/chat/api/messages/dependencies.py deleted file mode 100644 index cfea1d91..00000000 --- a/src/askui/chat/api/messages/dependencies.py +++ /dev/null @@ -1,56 +0,0 @@ -from fastapi import Depends - -from askui.chat.api.db.session import SessionDep -from askui.chat.api.dependencies import WorkspaceIdDep -from askui.chat.api.files.dependencies import FileServiceDep -from askui.chat.api.files.service import FileService -from askui.chat.api.messages.chat_history_manager import ChatHistoryManager -from askui.chat.api.messages.service import MessageService -from askui.chat.api.messages.translator import MessageTranslator -from askui.chat.api.models import WorkspaceId -from askui.models.shared.truncation_strategies import ( - SimpleTruncationStrategyFactory, - TruncationStrategyFactory, -) - - -def get_message_service( - session: SessionDep, -) -> MessageService: - """Get MessageService instance.""" - return MessageService(session) - - -MessageServiceDep = Depends(get_message_service) - - -def get_message_translator( - file_service: FileService = FileServiceDep, - workspace_id: WorkspaceId | None = WorkspaceIdDep, -) -> MessageTranslator: - return MessageTranslator(file_service, workspace_id) - - -MessageTranslatorDep = Depends(get_message_translator) - - -def get_truncation_strategy_factory() -> TruncationStrategyFactory: - return SimpleTruncationStrategyFactory() - - -TruncationStrategyFactoryDep = Depends(get_truncation_strategy_factory) - - -def get_chat_history_manager( - message_service: MessageService = MessageServiceDep, - message_translator: MessageTranslator = MessageTranslatorDep, - truncation_strategy_factory: TruncationStrategyFactory = TruncationStrategyFactoryDep, -) -> ChatHistoryManager: - return ChatHistoryManager( - message_service=message_service, - message_translator=message_translator, - truncation_strategy_factory=truncation_strategy_factory, - ) - - -ChatHistoryManagerDep = Depends(get_chat_history_manager) diff --git a/src/askui/chat/api/messages/models.py b/src/askui/chat/api/messages/models.py deleted file mode 100644 index 2f9d34aa..00000000 --- a/src/askui/chat/api/messages/models.py +++ /dev/null @@ -1,108 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel - -from askui.chat.api.models import ( - AssistantId, - FileId, - MessageId, - RunId, - ThreadId, - WorkspaceId, - WorkspaceResource, -) -from askui.models.shared.agent_message_param import ( - Base64ImageSourceParam, - BetaRedactedThinkingBlock, - BetaThinkingBlock, - CacheControlEphemeralParam, - StopReason, - TextBlockParam, - ToolUseBlockParam, - UrlImageSourceParam, -) -from askui.utils.datetime_utils import UnixDatetime, now -from askui.utils.id_utils import generate_time_ordered_id - -ROOT_MESSAGE_PARENT_ID = "msg_000000000000000000000000" - - -class BetaFileDocumentSourceParam(BaseModel): - file_id: str - type: Literal["file"] = "file" - - -Source = BetaFileDocumentSourceParam - - -class RequestDocumentBlockParam(BaseModel): - source: Source - type: Literal["document"] = "document" - cache_control: CacheControlEphemeralParam | None = None - - -class FileImageSourceParam(BaseModel): - """Image source that references a saved file.""" - - id: FileId - type: Literal["file"] = "file" - - -class ImageBlockParam(BaseModel): - source: Base64ImageSourceParam | UrlImageSourceParam | FileImageSourceParam - type: Literal["image"] = "image" - cache_control: CacheControlEphemeralParam | None = None - - -class ToolResultBlockParam(BaseModel): - tool_use_id: str - type: Literal["tool_result"] = "tool_result" - cache_control: CacheControlEphemeralParam | None = None - content: str | list[TextBlockParam | ImageBlockParam] - is_error: bool = False - - -ContentBlockParam = ( - ImageBlockParam - | TextBlockParam - | ToolResultBlockParam - | ToolUseBlockParam - | BetaThinkingBlock - | BetaRedactedThinkingBlock - | RequestDocumentBlockParam -) - - -class MessageParam(BaseModel): - role: Literal["user", "assistant"] - content: str | list[ContentBlockParam] - stop_reason: StopReason | None = None - - -class MessageBase(MessageParam): - assistant_id: AssistantId | None = None - run_id: RunId | None = None - parent_id: MessageId | None = None - - -class MessageCreate(MessageBase): - pass - - -class Message(MessageBase, WorkspaceResource): - id: MessageId - object: Literal["thread.message"] = "thread.message" - created_at: UnixDatetime - thread_id: ThreadId - - @classmethod - def create( - cls, workspace_id: WorkspaceId, thread_id: ThreadId, params: MessageCreate - ) -> "Message": - return cls( - id=generate_time_ordered_id("msg"), - created_at=now(), - workspace_id=workspace_id, - thread_id=thread_id, - **params.model_dump(), - ) diff --git a/src/askui/chat/api/messages/orms.py b/src/askui/chat/api/messages/orms.py deleted file mode 100644 index f2931bd4..00000000 --- a/src/askui/chat/api/messages/orms.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Message database model.""" - -from datetime import datetime -from typing import Any -from uuid import UUID - -from sqlalchemy import JSON, ForeignKey, String, Uuid -from sqlalchemy.orm import Mapped, mapped_column - -from askui.chat.api.assistants.orms import AssistantId -from askui.chat.api.db.orm.base import Base -from askui.chat.api.db.orm.types import ( - RunId, - ThreadId, - UnixDatetime, - create_prefixed_id_type, - create_sentinel_id_type, -) -from askui.chat.api.messages.models import ROOT_MESSAGE_PARENT_ID, Message - -MessageId = create_prefixed_id_type("msg") -_ParentMessageId = create_sentinel_id_type("msg", ROOT_MESSAGE_PARENT_ID) - - -class MessageOrm(Base): - """Message database model.""" - - __tablename__ = "messages" - - id: Mapped[str] = mapped_column(MessageId, primary_key=True) - thread_id: Mapped[str] = mapped_column( - ThreadId, - ForeignKey("threads.id", ondelete="CASCADE"), - nullable=False, - index=True, - ) - workspace_id: Mapped[UUID] = mapped_column(Uuid, nullable=False, index=True) - created_at: Mapped[datetime] = mapped_column(UnixDatetime, nullable=False) - role: Mapped[str] = mapped_column(String, nullable=False) - content: Mapped[str | list[dict[str, Any]]] = mapped_column(JSON, nullable=False) - stop_reason: Mapped[str | None] = mapped_column(String, nullable=True) - assistant_id: Mapped[str | None] = mapped_column( - AssistantId, ForeignKey("assistants.id", ondelete="SET NULL"), nullable=True - ) - run_id: Mapped[str | None] = mapped_column( - RunId, ForeignKey("runs.id", ondelete="SET NULL"), nullable=True - ) - parent_id: Mapped[str] = mapped_column( - _ParentMessageId, - ForeignKey("messages.id", ondelete="CASCADE"), - nullable=True, - index=True, - ) - - @classmethod - def from_model(cls, model: Message) -> "MessageOrm": - return cls( - **model.model_dump(exclude={"object", "created_at"}), - created_at=model.created_at, - ) - - def to_model(self) -> Message: - return Message.model_validate(self, from_attributes=True) diff --git a/src/askui/chat/api/messages/router.py b/src/askui/chat/api/messages/router.py deleted file mode 100644 index 79f2eb9e..00000000 --- a/src/askui/chat/api/messages/router.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import Annotated - -from fastapi import APIRouter, Header, status - -from askui.chat.api.dependencies import ListQueryDep -from askui.chat.api.messages.dependencies import MessageServiceDep -from askui.chat.api.messages.models import Message, MessageCreate -from askui.chat.api.messages.service import MessageService -from askui.chat.api.models import MessageId, ThreadId, WorkspaceId -from askui.utils.api_utils import ListQuery, ListResponse - -router = APIRouter(prefix="/threads/{thread_id}/messages", tags=["messages"]) - - -@router.get("") -def list_messages( - askui_workspace: Annotated[WorkspaceId, Header()], - thread_id: ThreadId, - query: ListQuery = ListQueryDep, - message_service: MessageService = MessageServiceDep, -) -> ListResponse[Message]: - return message_service.list_( - workspace_id=askui_workspace, thread_id=thread_id, query=query - ) - - -@router.get("/{message_id}/siblings") -def list_siblings( - askui_workspace: Annotated[WorkspaceId, Header()], - thread_id: ThreadId, - message_id: MessageId, - message_service: MessageService = MessageServiceDep, -) -> list[Message]: - """List all sibling messages for a given message. - - Sibling messages are messages that share the same `parent_id` as the specified message. - The specified message itself is included in the results. - Results are sorted by ID (chronological order, as IDs are BSON-based). - - Args: - askui_workspace (WorkspaceId): The workspace ID from header. - thread_id (ThreadId): The thread ID. - message_id (MessageId): The message ID to find siblings for. - message_service (MessageService): The message service dependency. - - Returns: - list[Message]: List of sibling messages sorted by ID. - - Raises: - NotFoundError: If the specified message does not exist. - """ - return message_service.list_siblings( - workspace_id=askui_workspace, - thread_id=thread_id, - message_id=message_id, - ) - - -@router.post("", status_code=status.HTTP_201_CREATED) -async def create_message( - askui_workspace: Annotated[WorkspaceId, Header()], - thread_id: ThreadId, - params: MessageCreate, - message_service: MessageService = MessageServiceDep, -) -> Message: - return message_service.create( - workspace_id=askui_workspace, - thread_id=thread_id, - params=params, - inject_cancelled_tool_results=True, - ) - - -@router.get("/{message_id}") -def retrieve_message( - askui_workspace: Annotated[WorkspaceId, Header()], - thread_id: ThreadId, - message_id: MessageId, - message_service: MessageService = MessageServiceDep, -) -> Message: - return message_service.retrieve( - workspace_id=askui_workspace, thread_id=thread_id, message_id=message_id - ) - - -@router.delete("/{message_id}", status_code=status.HTTP_204_NO_CONTENT) -def delete_message( - askui_workspace: Annotated[WorkspaceId, Header()], - thread_id: ThreadId, - message_id: MessageId, - message_service: MessageService = MessageServiceDep, -) -> None: - message_service.delete( - workspace_id=askui_workspace, thread_id=thread_id, message_id=message_id - ) diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py deleted file mode 100644 index 917e22d2..00000000 --- a/src/askui/chat/api/messages/service.py +++ /dev/null @@ -1,568 +0,0 @@ -from typing import Any, Iterator - -from sqlalchemy import CTE, desc, select -from sqlalchemy.orm import Query, Session - -from askui.chat.api.messages.models import ( - ROOT_MESSAGE_PARENT_ID, - ContentBlockParam, - Message, - MessageCreate, - ToolResultBlockParam, - ToolUseBlockParam, -) -from askui.chat.api.messages.orms import MessageOrm -from askui.chat.api.models import MessageId, ThreadId, WorkspaceId -from askui.chat.api.threads.orms import ThreadOrm -from askui.utils.api_utils import ( - LIST_LIMIT_DEFAULT, - ListOrder, - ListQuery, - ListResponse, - NotFoundError, -) - -_CANCELLED_TOOL_RESULT_CONTENT = ( - "Tool execution was cancelled because the previous run was interrupted. " - "Please retry the operation if needed." -) - - -class MessageService: - """Service for managing Message resources with database persistence.""" - - def __init__(self, session: Session) -> None: - self._session = session - - def _create_cancelled_tool_results( - self, - workspace_id: WorkspaceId, - thread_id: ThreadId, - parent_message: Message, - run_id: str | None, - ) -> MessageId: - """Create cancelled tool results if parent has pending tool_use blocks. - - Args: - workspace_id (WorkspaceId): The workspace ID. - thread_id (ThreadId): The thread ID. - parent_message (Message): The parent message to check for tool_use blocks. - run_id (str | None): The run ID to associate with the tool result message. - - Returns: - MessageId: The ID of the created tool result message, or the parent - message ID if no tool_use blocks were found. - """ - if not isinstance(parent_message.content, list): - return parent_message.id - - tool_use_blocks = [ - block - for block in parent_message.content - if isinstance(block, ToolUseBlockParam) - ] - if not tool_use_blocks: - return parent_message.id - - tool_result_content: list[ContentBlockParam] = [ - ToolResultBlockParam( - tool_use_id=block.id, - content=_CANCELLED_TOOL_RESULT_CONTENT, - is_error=True, - ) - for block in tool_use_blocks - ] - tool_result_params = MessageCreate( - role="user", - content=tool_result_content, - parent_id=parent_message.id, - run_id=run_id, - ) - tool_result_message = Message.create( - workspace_id, thread_id, tool_result_params - ) - self._session.add(MessageOrm.from_model(tool_result_message)) - return tool_result_message.id - - def _find_by_id( - self, workspace_id: WorkspaceId, thread_id: ThreadId, message_id: MessageId - ) -> MessageOrm: - """Find message by ID.""" - message_orm: MessageOrm | None = ( - self._session.query(MessageOrm) - .filter( - MessageOrm.id == message_id, - MessageOrm.thread_id == thread_id, - MessageOrm.workspace_id == workspace_id, - ) - .first() - ) - if message_orm is None: - error_msg = f"Message {message_id} not found in thread {thread_id}" - raise NotFoundError(error_msg) - return message_orm - - def _retrieve_latest_root( - self, workspace_id: WorkspaceId, thread_id: ThreadId - ) -> str | None: - """Retrieve the latest root message ID in a thread. - - Args: - workspace_id (WorkspaceId): The workspace ID. - thread_id (ThreadId): The thread ID. - - Returns: - str | None: The ID of the latest root message, or `None` if no root messages exist. - """ - return self._session.execute( - select(MessageOrm.id) - .filter( - MessageOrm.parent_id.is_(None), - MessageOrm.thread_id == thread_id, - MessageOrm.workspace_id == workspace_id, - ) - .order_by(desc(MessageOrm.id)) - .limit(1) - ).scalar_one_or_none() - - def _build_ancestors_cte( - self, message_id: MessageId, workspace_id: WorkspaceId, thread_id: ThreadId - ) -> CTE: - """Build a recursive CTE to traverse up the message tree from a given message. - - Args: - message_id (MessageId): The ID of the message to start traversing from. - workspace_id (WorkspaceId): The workspace ID. - thread_id (ThreadId): The thread ID. - - Returns: - CTE: A recursive common table expression that contains all ancestors of the message. - """ - # Build CTE to traverse up the tree from message_id - _ancestors_cte = ( - select(MessageOrm.id, MessageOrm.parent_id) - .filter( - MessageOrm.id == message_id, - MessageOrm.thread_id == thread_id, - MessageOrm.workspace_id == workspace_id, - ) - .cte(name="ancestors", recursive=True) - ) - - # Recursively traverse up until we hit NULL (root message) - _ancestors_recursive = select(MessageOrm.id, MessageOrm.parent_id).filter( - MessageOrm.id == _ancestors_cte.c.parent_id, - _ancestors_cte.c.parent_id.is_not(None), - ) - return _ancestors_cte.union_all(_ancestors_recursive) - - def _build_descendants_cte(self, message_id: MessageId) -> CTE: - """Build a recursive CTE to traverse down the message tree from a given message. - - Args: - message_id (MessageId): The ID of the message to start traversing from. - - Returns: - CTE: A recursive common table expression that contains all descendants of the message. - """ - # Build CTE to traverse down the tree from message_id - _descendants_cte = ( - select(MessageOrm.id, MessageOrm.parent_id) - .filter( - MessageOrm.id == message_id, - ) - .cte(name="descendants", recursive=True) - ) - - # Recursively traverse down - _descendants_recursive = select(MessageOrm.id, MessageOrm.parent_id).filter( - MessageOrm.parent_id == _descendants_cte.c.id, - ) - return _descendants_cte.union_all(_descendants_recursive) - - def _retrieve_latest_leaf(self, message_id: MessageId) -> str | None: - """Retrieve the latest leaf node in the subtree rooted at the given message. - - Args: - message_id (MessageId): The ID of the root message to start from. - - Returns: - str | None: The ID of the latest leaf node (highest ID), or `None` if no descendants exist. - """ - # Build CTE to traverse down the tree from message_id - _descendants_cte = self._build_descendants_cte(message_id) - - # Get the latest leaf (highest ID) - return self._session.execute( - select(_descendants_cte.c.id).order_by(desc(_descendants_cte.c.id)).limit(1) - ).scalar_one_or_none() - - def _retrieve_branch_root( - self, leaf_id: MessageId, workspace_id: WorkspaceId, thread_id: ThreadId - ) -> str | None: - """Retrieve the branch root node by traversing up from a leaf node. - - Args: - leaf_id (MessageId): The ID of the leaf message to start from. - workspace_id (WorkspaceId): The workspace ID. - thread_id (ThreadId): The thread ID. - - Returns: - str | None: The ID of the root node (with parent_id == NULL), or `None` if not found. - """ - # Build CTE to traverse up the tree from leaf_id - _ancestors_cte = self._build_ancestors_cte(leaf_id, workspace_id, thread_id) - - # Get the root node (the one with parent_id == NULL) - return self._session.execute( - select(MessageOrm.id).filter( - MessageOrm.id.in_(select(_ancestors_cte.c.id)), - MessageOrm.parent_id.is_(None), - ) - ).scalar_one_or_none() - - def _build_path_query(self, path_start: str, path_end: str) -> Query[MessageOrm]: - """Build a query for messages in the path from end to start. - - Args: - path_start (str): The ID of the path start message (upper node). - path_end (str): The ID of the path end message (lower node). - - Returns: - Query[MessageOrm]: A query object for fetching messages in the path. - """ - # Build path from path_end up to path_start using recursive CTE - # Start from path_end and traverse upward following parent_id until we reach path_start - _path_cte = ( - select(MessageOrm.id, MessageOrm.parent_id) - .filter( - MessageOrm.id == path_end, - ) - .cte(name="path", recursive=True) - ) - - # Recursively fetch parent nodes, stopping before we go past path_start - # No need to filter by thread_id/workspace_id - parent_id relationship ensures correct path - _path_recursive = select(MessageOrm.id, MessageOrm.parent_id).filter( - MessageOrm.id == _path_cte.c.parent_id, - # Stop recursion: don't fetch parent of path_start - _path_cte.c.id != path_start, - ) - - _path_cte = _path_cte.union_all(_path_recursive) - - return self._session.query(MessageOrm).join( - _path_cte, MessageOrm.id == _path_cte.c.id - ) - - def retrieve_last_message_id( - self, workspace_id: WorkspaceId, thread_id: ThreadId - ) -> MessageId | None: - """Get the last message ID in a thread. If no messages exist, return the root message ID.""" - return self._session.execute( - select(MessageOrm.id) - .filter( - MessageOrm.thread_id == thread_id, - MessageOrm.workspace_id == workspace_id, - ) - .order_by(desc(MessageOrm.id)) - .limit(1) - ).scalar_one_or_none() - - def create( - self, - workspace_id: WorkspaceId, - thread_id: ThreadId, - params: MessageCreate, - inject_cancelled_tool_results: bool = False, - ) -> Message: - """Create a new message. - - Args: - workspace_id (WorkspaceId): The workspace ID. - thread_id (ThreadId): The thread ID. - params (MessageCreate): The message creation parameters. - inject_cancelled_tool_results (bool, optional): If `True`, inject cancelled - tool results when the parent message has pending tool_use blocks. - Defaults to `False`. - """ - # Validate thread exists - thread_orm: ThreadOrm | None = ( - self._session.query(ThreadOrm) - .filter( - ThreadOrm.id == thread_id, - ThreadOrm.workspace_id == workspace_id, - ) - .first() - ) - if thread_orm is None: - error_msg = f"Thread {thread_id} not found" - raise NotFoundError(error_msg) - - if ( - params.parent_id is None - ): # If no parent ID is provided, use the last message in the thread - parent_id = self.retrieve_last_message_id(workspace_id, thread_id) - - # if the thread is empty, use the root message parent ID - if parent_id is None: - parent_id = ROOT_MESSAGE_PARENT_ID - params.parent_id = parent_id - - # Validate parent message exists (if not root) - if params.parent_id and params.parent_id != ROOT_MESSAGE_PARENT_ID: - parent_message_orm: MessageOrm | None = ( - self._session.query(MessageOrm) - .filter( - MessageOrm.id == params.parent_id, - MessageOrm.thread_id == thread_id, - MessageOrm.workspace_id == workspace_id, - ) - .first() - ) - if parent_message_orm is None: - error_msg = ( - f"Parent message {params.parent_id} not found in thread {thread_id}" - ) - raise NotFoundError(error_msg) - - # If parent has tool_use, create cancelled tool_result first - if inject_cancelled_tool_results: - params.parent_id = self._create_cancelled_tool_results( - workspace_id, - thread_id, - parent_message_orm.to_model(), - params.run_id, - ) - - message = Message.create(workspace_id, thread_id, params) - message_orm = MessageOrm.from_model(message) - self._session.add(message_orm) - self._session.commit() - return message - - def _get_path_endpoints( - self, workspace_id: WorkspaceId, thread_id: ThreadId, query: ListQuery - ) -> tuple[str, str] | None: - """Determine the path start and end node IDs for path traversal. - - Executes queries to get concrete ID values for the path start and end nodes. - - Args: - workspace_id (WorkspaceId): The workspace ID. - thread_id (ThreadId): The thread ID. - query (ListQuery): Pagination query (after/before, limit, order). - - Returns: - tuple[str, str] | None: A tuple of (path_start, path_end) where path_start is the - upper node and path_end is the lower node. Returns `None` if no messages exist - in the thread. - - Raises: - ValueError: If both `after` and `before` parameters are specified. - NotFoundError: If the specified message in `before` or `after` does not exist. - """ - if query.after and query.before: - error_msg = "Cannot specify both 'after' and 'before' parameters" - raise ValueError(error_msg) - - # Determine cursor and direction based on after/before and order - # Key insight: (after+desc) and (before+asc) both traverse UP (towards root) - # (after+asc) and (before+desc) both traverse DOWN (towards leaves) - _cursor = query.after or query.before - _should_traverse_up = (query.after and query.order == "desc") or ( - query.before and query.order == "asc" - ) - - path_start: str | None - path_end: str | None - - if _cursor: - if _should_traverse_up: - # Traverse UP: set path_end to cursor and find path_start by going to root - path_end = _cursor - path_start = self._retrieve_branch_root( - path_end, workspace_id, thread_id - ) - if path_start is None: - error_msg = f"Message with id '{path_end}' not found" - raise NotFoundError(error_msg) - else: - # Traverse DOWN: set path_start to cursor and find path_end by going to leaf - path_start = _cursor - path_end = self._retrieve_latest_leaf(path_start) - if path_end is None: - error_msg = f"Message with id '{path_start}' not found" - raise NotFoundError(error_msg) - else: - # No pagination - get the full branch from latest root to latest leaf - path_end = self.retrieve_last_message_id(workspace_id, thread_id) - if path_end is None: - return None - path_start = self._retrieve_branch_root(path_end, workspace_id, thread_id) - if path_start is None: - error_msg = f"Message with id '{path_end}' not found" - raise NotFoundError(error_msg) - - return path_start, path_end - - def list_( - self, workspace_id: WorkspaceId, thread_id: ThreadId, query: ListQuery - ) -> ListResponse[Message]: - """List messages in a tree path with pagination and filtering. - - Behavior: - - If `after` is provided: - - With `order=desc`: Returns path from `after` node up to root (excludes `after` itself) - - With `order=asc`: Returns path from `after` node down to latest leaf (excludes `after` itself) - - If `before` is provided: - - With `order=asc`: Returns path from `before` node up to root (excludes `before` itself) - - With `order=desc`: Returns path from `before` node down to latest leaf (excludes `before` itself) - - If neither: Returns main branch (root to latest leaf in entire thread) - - The method identifies a start_id (upper node) and end_id (leaf node), - traverses from end_id up to start_id, then applies the specified order. - - Args: - workspace_id (WorkspaceId): The workspace ID. - thread_id (ThreadId): The thread ID. - query (ListQuery): Pagination query (after/before, limit, order). - - Returns: - ListResponse[Message]: Paginated list of messages in the tree path. - - Raises: - ValueError: If both `after` and `before` parameters are specified. - NotFoundError: If the specified message in `before` or `after` does not exist. - """ - # Step 1: Get concrete path_start and path_end - _endpoints = self._get_path_endpoints(workspace_id, thread_id, query) - - # If no messages exist yet, return empty response - if _endpoints is None: - return ListResponse(data=[], has_more=False) - - _path_start, _path_end = _endpoints - - # Step 2: Build path query from path_end up to path_start - _query = self._build_path_query(_path_start, _path_end) - - # Build all filters at once for better query planning - _filters: list[Any] = [] - if query.after: - _filters.append(MessageOrm.id != query.after) - if query.before: - _filters.append(MessageOrm.id != query.before) - - if _filters: - _query = _query.filter(*_filters) - - orms = ( - _query.order_by( - MessageOrm.id if query.order == "asc" else desc(MessageOrm.id) - ) - .limit(query.limit + 1) - .all() - ) - - if not orms: - return ListResponse(data=[], has_more=False) - - has_more = len(orms) > query.limit - data = [orm.to_model() for orm in orms[: query.limit]] - - return ListResponse( - data=data, - has_more=has_more, - first_id=data[0].id if data else None, - last_id=data[-1].id if data else None, - ) - - def iter( - self, - workspace_id: WorkspaceId, - thread_id: ThreadId, - order: ListOrder = "asc", - batch_size: int = LIST_LIMIT_DEFAULT, - ) -> Iterator[Message]: - """Iterate through messages in batches.""" - - has_more = True - last_id: str | None = None - while has_more: - list_messages_response = self.list_( - workspace_id=workspace_id, - thread_id=thread_id, - query=ListQuery(limit=batch_size, order=order, after=last_id), - ) - has_more = list_messages_response.has_more - last_id = list_messages_response.last_id - for msg in list_messages_response.data: - yield msg - - def retrieve( - self, workspace_id: WorkspaceId, thread_id: ThreadId, message_id: MessageId - ) -> Message: - """Retrieve message by ID.""" - message_orm = self._find_by_id(workspace_id, thread_id, message_id) - return message_orm.to_model() - - def delete( - self, workspace_id: WorkspaceId, thread_id: ThreadId, message_id: MessageId - ) -> None: - """Delete a message.""" - message_orm = self._find_by_id(workspace_id, thread_id, message_id) - self._session.delete(message_orm) - self._session.commit() - - def list_siblings( - self, - workspace_id: WorkspaceId, - thread_id: ThreadId, - message_id: MessageId, - ) -> list[Message]: - """List all sibling messages for a given message. - - Sibling messages are messages that share the same `parent_id` as the specified message. - The specified message itself is included in the results. - Results are sorted by ID (chronological order, as IDs are BSON-based). - - Args: - workspace_id (WorkspaceId): The workspace ID. - thread_id (ThreadId): The thread ID. - message_id (MessageId): The message ID to find siblings for. - - Returns: - list[Message]: List of sibling messages sorted by ID. - - Raises: - NotFoundError: If the specified message does not exist. - """ - # Query for all sibling messages using a subquery to get parent_id - _parent_id_subquery = ( - select(MessageOrm.parent_id) - .filter( - MessageOrm.id == message_id, - MessageOrm.thread_id == thread_id, - MessageOrm.workspace_id == workspace_id, - ) - .scalar_subquery() - ) - - orms = ( - self._session.query(MessageOrm) - .filter( - MessageOrm.parent_id.is_not_distinct_from(_parent_id_subquery), - MessageOrm.thread_id == thread_id, - MessageOrm.workspace_id == workspace_id, - ) - .order_by(desc(MessageOrm.id)) - .all() - ) - - # Validate that the message exists (if no results, message doesn't exist) - if not orms: - error_msg = f"Message {message_id} not found in thread {thread_id}" - raise NotFoundError(error_msg) - - return [orm.to_model() for orm in orms] diff --git a/src/askui/chat/api/messages/translator.py b/src/askui/chat/api/messages/translator.py deleted file mode 100644 index 25de3788..00000000 --- a/src/askui/chat/api/messages/translator.py +++ /dev/null @@ -1,342 +0,0 @@ -from PIL import Image - -from askui.chat.api.files.service import FileService -from askui.chat.api.messages.models import ( - ContentBlockParam, - FileImageSourceParam, - ImageBlockParam, - MessageParam, - RequestDocumentBlockParam, - ToolResultBlockParam, -) -from askui.chat.api.models import WorkspaceId -from askui.data_extractor import DataExtractor -from askui.models.models import ModelName -from askui.models.shared.agent_message_param import ( - Base64ImageSourceParam, - TextBlockParam, - UrlImageSourceParam, -) -from askui.models.shared.agent_message_param import ( - ContentBlockParam as AnthropicContentBlockParam, -) -from askui.models.shared.agent_message_param import ( - ImageBlockParam as AnthropicImageBlockParam, -) -from askui.models.shared.agent_message_param import ( - MessageParam as AnthropicMessageParam, -) -from askui.models.shared.agent_message_param import ( - ToolResultBlockParam as AnthropicToolResultBlockParam, -) -from askui.utils.excel_utils import OfficeDocumentSource -from askui.utils.image_utils import ImageSource, image_to_base64 -from askui.utils.source_utils import Source, load_source - - -class RequestDocumentBlockParamTranslator: - """Translator for RequestDocumentBlockParam to/from Anthropic format.""" - - def __init__( - self, file_service: FileService, workspace_id: WorkspaceId | None - ) -> None: - self._file_service = file_service - self._workspace_id = workspace_id - self._data_extractor = DataExtractor() - - def extract_content( - self, source: Source, block: RequestDocumentBlockParam - ) -> list[AnthropicContentBlockParam]: - if isinstance(source, ImageSource): - return [ - AnthropicImageBlockParam( - source=Base64ImageSourceParam( - data=source.to_base64(), - media_type="image/png", - ), - type="image", - cache_control=block.cache_control, - ) - ] - if isinstance(source, OfficeDocumentSource): - with source.reader as r: - data = r.read() - return [ - TextBlockParam( - text=data.decode(), - type="text", - cache_control=block.cache_control, - ) - ] - text = self._data_extractor.get( - query="""Extract all the content of the PDF to Markdown format. - Preserve layout and formatting as much as possible, e.g., representing - tables as HTML tables. For all images, videos, figures, extract text - from it and describe what you are seeing, e.g., what is shown in the - image or figure, and include that description.""", - source=source, - model=ModelName.ASKUI, - ) - return [ - TextBlockParam( - text=text, - type="text", - cache_control=block.cache_control, - ) - ] - - async def to_anthropic( - self, block: RequestDocumentBlockParam - ) -> list[AnthropicContentBlockParam]: - file, path = self._file_service.retrieve_file_content( - self._workspace_id, block.source.file_id - ) - source = load_source(path) - content = self.extract_content(source, block) - return [ - TextBlockParam( - text=file.model_dump_json(), - type="text", - cache_control=block.cache_control, - ), - ] + content - - -class ImageBlockParamSourceTranslator: - def __init__( - self, file_service: FileService, workspace_id: WorkspaceId | None - ) -> None: - self._file_service = file_service - self._workspace_id = workspace_id - - async def from_anthropic( # noqa: RET503 - self, source: UrlImageSourceParam | Base64ImageSourceParam - ) -> UrlImageSourceParam | Base64ImageSourceParam | FileImageSourceParam: - if source.type == "url": - return source - if source.type == "base64": # noqa: RET503 - # Readd translation to FileImageSourceParam as soon as we support it in frontend - return source - # try: - # image = base64_to_image(source.data) - # bytes_io = BytesIO() - # image.save(bytes_io, format="PNG") - # bytes_io.seek(0) - # file = await self._file_service.upload_file( - # file=UploadFile( - # file=bytes_io, - # headers=Headers( - # { - # "Content-Type": "image/png", - # } - # ), - # ) - # ) - # except Exception as e: # noqa: BLE001 - # logger.warning(f"Failed to save image: {e}", exc_info=True) - # return source - # else: - # return FileImageSourceParam(id=file.id, type="file") - - async def to_anthropic( # noqa: RET503 - self, - source: UrlImageSourceParam | Base64ImageSourceParam | FileImageSourceParam, - ) -> UrlImageSourceParam | Base64ImageSourceParam: - if source.type == "url": - return source - if source.type == "base64": - return source - if source.type == "file": # noqa: RET503 - file, path = self._file_service.retrieve_file_content( - self._workspace_id, source.id - ) - image = Image.open(path) - return Base64ImageSourceParam( - data=image_to_base64(image), - media_type=file.media_type, - ) - - -class ImageBlockParamTranslator: - def __init__( - self, file_service: FileService, workspace_id: WorkspaceId | None - ) -> None: - self.source_translator = ImageBlockParamSourceTranslator( - file_service, workspace_id - ) - - async def from_anthropic(self, block: AnthropicImageBlockParam) -> ImageBlockParam: - return ImageBlockParam( - source=await self.source_translator.from_anthropic(block.source), - type="image", - cache_control=block.cache_control, - ) - - async def to_anthropic(self, block: ImageBlockParam) -> AnthropicImageBlockParam: - return AnthropicImageBlockParam( - source=await self.source_translator.to_anthropic(block.source), - type="image", - cache_control=block.cache_control, - ) - - -class ToolResultContentBlockParamTranslator: - def __init__( - self, file_service: FileService, workspace_id: WorkspaceId | None - ) -> None: - self.image_translator = ImageBlockParamTranslator(file_service, workspace_id) - - async def from_anthropic( - self, block: AnthropicImageBlockParam | TextBlockParam - ) -> ImageBlockParam | TextBlockParam: - if block.type == "image": - return await self.image_translator.from_anthropic(block) - return block - - async def to_anthropic( - self, block: ImageBlockParam | TextBlockParam - ) -> AnthropicImageBlockParam | TextBlockParam: - if block.type == "image": - return await self.image_translator.to_anthropic(block) - return block - - -class ToolResultContentTranslator: - def __init__( - self, file_service: FileService, workspace_id: WorkspaceId | None - ) -> None: - self.block_param_translator = ToolResultContentBlockParamTranslator( - file_service, workspace_id - ) - - async def from_anthropic( - self, content: str | list[AnthropicImageBlockParam | TextBlockParam] - ) -> str | list[ImageBlockParam | TextBlockParam]: - if isinstance(content, str): - return content - return [ - await self.block_param_translator.from_anthropic(block) for block in content - ] - - async def to_anthropic( - self, content: str | list[ImageBlockParam | TextBlockParam] - ) -> str | list[AnthropicImageBlockParam | TextBlockParam]: - if isinstance(content, str): - return content - return [ - await self.block_param_translator.to_anthropic(block) for block in content - ] - - -class ToolResultBlockParamTranslator: - def __init__( - self, file_service: FileService, workspace_id: WorkspaceId | None - ) -> None: - self.content_translator = ToolResultContentTranslator( - file_service, workspace_id - ) - - async def from_anthropic( - self, block: AnthropicToolResultBlockParam - ) -> ToolResultBlockParam: - return ToolResultBlockParam( - tool_use_id=block.tool_use_id, - type="tool_result", - cache_control=block.cache_control, - content=await self.content_translator.from_anthropic(block.content), - is_error=block.is_error, - ) - - async def to_anthropic( - self, block: ToolResultBlockParam - ) -> AnthropicToolResultBlockParam: - return AnthropicToolResultBlockParam( - tool_use_id=block.tool_use_id, - type="tool_result", - cache_control=block.cache_control, - content=await self.content_translator.to_anthropic(block.content), - is_error=block.is_error, - ) - - -class MessageContentBlockParamTranslator: - def __init__( - self, file_service: FileService, workspace_id: WorkspaceId | None - ) -> None: - self.image_translator = ImageBlockParamTranslator(file_service, workspace_id) - self.tool_result_translator = ToolResultBlockParamTranslator( - file_service, workspace_id - ) - self.request_document_translator = RequestDocumentBlockParamTranslator( - file_service, workspace_id - ) - - async def from_anthropic( - self, block: AnthropicContentBlockParam - ) -> list[ContentBlockParam]: - if block.type == "image": - return [await self.image_translator.from_anthropic(block)] - if block.type == "tool_result": - return [await self.tool_result_translator.from_anthropic(block)] - return [block] - - async def to_anthropic( - self, block: ContentBlockParam - ) -> list[AnthropicContentBlockParam]: - if block.type == "image": - return [await self.image_translator.to_anthropic(block)] - if block.type == "tool_result": - return [await self.tool_result_translator.to_anthropic(block)] - if block.type == "document": - return await self.request_document_translator.to_anthropic(block) - return [block] - - -class MessageContentTranslator: - def __init__( - self, file_service: FileService, workspace_id: WorkspaceId | None - ) -> None: - self.block_param_translator = MessageContentBlockParamTranslator( - file_service, workspace_id - ) - - async def from_anthropic( - self, content: list[AnthropicContentBlockParam] | str - ) -> list[ContentBlockParam] | str: - if isinstance(content, str): - return content - lists_of_blocks = [ - await self.block_param_translator.from_anthropic(block) for block in content - ] - return [block for sublist in lists_of_blocks for block in sublist] - - async def to_anthropic( - self, content: list[ContentBlockParam] | str - ) -> list[AnthropicContentBlockParam] | str: - if isinstance(content, str): - return content - lists_of_blocks = [ - await self.block_param_translator.to_anthropic(block) for block in content - ] - return [block for sublist in lists_of_blocks for block in sublist] - - -class MessageTranslator: - def __init__( - self, file_service: FileService, workspace_id: WorkspaceId | None - ) -> None: - self.content_translator = MessageContentTranslator(file_service, workspace_id) - - async def from_anthropic(self, message: AnthropicMessageParam) -> MessageParam: - return MessageParam( - role=message.role, - content=await self.content_translator.from_anthropic(message.content), - stop_reason=message.stop_reason, - ) - - async def to_anthropic(self, message: MessageParam) -> AnthropicMessageParam: - return AnthropicMessageParam( - role=message.role, - content=await self.content_translator.to_anthropic(message.content), - stop_reason=message.stop_reason, - ) diff --git a/src/askui/chat/api/models.py b/src/askui/chat/api/models.py deleted file mode 100644 index 9028abcc..00000000 --- a/src/askui/chat/api/models.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Annotated, TypeVar - -from pydantic import UUID4 - -from askui.utils.api_utils import Resource -from askui.utils.id_utils import IdField - -AssistantId = Annotated[str, IdField("asst")] -McpConfigId = Annotated[str, IdField("mcpcnf")] -FileId = Annotated[str, IdField("file")] -MessageId = Annotated[str, IdField("msg")] -RunId = Annotated[str, IdField("run")] -ScheduledJobId = Annotated[str, IdField("schedjob")] -ThreadId = Annotated[str, IdField("thread")] -WorkspaceId = UUID4 - - -class WorkspaceResource(Resource): - workspace_id: WorkspaceId | None = None - - -WorkspaceResourceT = TypeVar("WorkspaceResourceT", bound=WorkspaceResource) diff --git a/src/askui/chat/api/runs/__init__.py b/src/askui/chat/api/runs/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/runs/dependencies.py b/src/askui/chat/api/runs/dependencies.py deleted file mode 100644 index 7dd5085f..00000000 --- a/src/askui/chat/api/runs/dependencies.py +++ /dev/null @@ -1,96 +0,0 @@ -from fastapi import Depends -from pydantic import UUID4 -from sqlalchemy.orm import Session - -from askui.chat.api.assistants.dependencies import ( - AssistantServiceDep, - get_assistant_service, -) -from askui.chat.api.assistants.service import AssistantService -from askui.chat.api.db.session import SessionDep -from askui.chat.api.dependencies import SettingsDep, get_settings -from askui.chat.api.files.dependencies import get_file_service -from askui.chat.api.mcp_clients.dependencies import ( - McpClientManagerManagerDep, - get_mcp_client_manager_manager, -) -from askui.chat.api.mcp_clients.manager import McpClientManagerManager -from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service -from askui.chat.api.messages.chat_history_manager import ChatHistoryManager -from askui.chat.api.messages.dependencies import ( - ChatHistoryManagerDep, - get_chat_history_manager, - get_message_service, - get_message_translator, - get_truncation_strategy_factory, -) -from askui.chat.api.runs.models import RunListQuery -from askui.chat.api.settings import Settings - -from .service import RunService - -RunListQueryDep = Depends(RunListQuery) - - -def get_runs_service( - session: SessionDep, - assistant_service: AssistantService = AssistantServiceDep, - chat_history_manager: ChatHistoryManager = ChatHistoryManagerDep, - mcp_client_manager_manager: McpClientManagerManager = McpClientManagerManagerDep, - settings: Settings = SettingsDep, -) -> RunService: - """ - Get RunService instance for FastAPI dependency injection. - - This function is designed for use with FastAPI's DI system. - For manual construction outside of a request context, use `create_run_service()`. - """ - return RunService( - session=session, - assistant_service=assistant_service, - mcp_client_manager_manager=mcp_client_manager_manager, - chat_history_manager=chat_history_manager, - settings=settings, - ) - - -RunServiceDep = Depends(get_runs_service) - - -def create_run_service(session: Session, workspace_id: UUID4) -> RunService: - """ - Create a RunService with all required dependencies manually. - - Use this function when you need a `RunService` outside of FastAPI's - dependency injection context (e.g. APScheduler callbacks). - - Args: - session (Session): Database session. - workspace_id (UUID4): The workspace ID for the run execution. - - Returns: - RunService: Configured run service. - """ - settings = get_settings() - - assistant_service = get_assistant_service(session) - file_service = get_file_service(session, settings) - mcp_config_service = get_mcp_config_service(session, settings) - mcp_client_manager_manager = get_mcp_client_manager_manager(mcp_config_service) - - message_service = get_message_service(session) - message_translator = get_message_translator(file_service, workspace_id) - truncation_strategy_factory = get_truncation_strategy_factory() - chat_history_manager = get_chat_history_manager( - message_service, - message_translator, - truncation_strategy_factory, - ) - - return RunService( - session=session, - assistant_service=assistant_service, - mcp_client_manager_manager=mcp_client_manager_manager, - chat_history_manager=chat_history_manager, - settings=settings, - ) diff --git a/src/askui/chat/api/runs/events/__init__.py b/src/askui/chat/api/runs/events/__init__.py deleted file mode 100644 index 56b48f30..00000000 --- a/src/askui/chat/api/runs/events/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from askui.chat.api.runs.events.done_events import DoneEvent -from askui.chat.api.runs.events.error_events import ErrorEvent -from askui.chat.api.runs.events.event_base import EventBase -from askui.chat.api.runs.events.events import Event -from askui.chat.api.runs.events.message_events import MessageEvent -from askui.chat.api.runs.events.run_events import RunEvent - -__all__ = [ - "DoneEvent", - "ErrorEvent", - "EventBase", - "Event", - "MessageEvent", - "RunEvent", -] diff --git a/src/askui/chat/api/runs/events/done_events.py b/src/askui/chat/api/runs/events/done_events.py deleted file mode 100644 index 458daa8a..00000000 --- a/src/askui/chat/api/runs/events/done_events.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import Literal - -from askui.chat.api.runs.events.event_base import EventBase - - -class DoneEvent(EventBase): - event: Literal["done"] = "done" - data: Literal["[DONE]"] = "[DONE]" diff --git a/src/askui/chat/api/runs/events/error_events.py b/src/askui/chat/api/runs/events/error_events.py deleted file mode 100644 index 98107479..00000000 --- a/src/askui/chat/api/runs/events/error_events.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel - -from askui.chat.api.runs.events.event_base import EventBase - - -class ErrorEventDataError(BaseModel): - message: str - - -class ErrorEventData(BaseModel): - error: ErrorEventDataError - - -class ErrorEvent(EventBase): - event: Literal["error"] = "error" - data: ErrorEventData diff --git a/src/askui/chat/api/runs/events/event_base.py b/src/askui/chat/api/runs/events/event_base.py deleted file mode 100644 index 60250480..00000000 --- a/src/askui/chat/api/runs/events/event_base.py +++ /dev/null @@ -1,7 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel - - -class EventBase(BaseModel): - object: Literal["event"] = "event" diff --git a/src/askui/chat/api/runs/events/events.py b/src/askui/chat/api/runs/events/events.py deleted file mode 100644 index 94ecb9d2..00000000 --- a/src/askui/chat/api/runs/events/events.py +++ /dev/null @@ -1,10 +0,0 @@ -from pydantic import TypeAdapter - -from askui.chat.api.runs.events.done_events import DoneEvent -from askui.chat.api.runs.events.error_events import ErrorEvent -from askui.chat.api.runs.events.message_events import MessageEvent -from askui.chat.api.runs.events.run_events import RunEvent - -Event = DoneEvent | ErrorEvent | MessageEvent | RunEvent - -EventAdapter: TypeAdapter[Event] = TypeAdapter(Event) diff --git a/src/askui/chat/api/runs/events/io_publisher.py b/src/askui/chat/api/runs/events/io_publisher.py deleted file mode 100644 index f0893010..00000000 --- a/src/askui/chat/api/runs/events/io_publisher.py +++ /dev/null @@ -1,43 +0,0 @@ -"""IO publisher for publishing events to stdout.""" - -import json -import sys -from typing import Any - -from askui.chat.api.runs.events.events import Event -from askui.chat.api.settings import Settings - - -class IOPublisher: - """Publisher that serializes events to JSON and writes to stdout.""" - - def __init__(self, enabled: bool) -> None: - """ - Initialize the IO publisher. - - Args: - settings: The settings instance containing configuration for the IO publisher. - """ - self._enabled = enabled - - def publish(self, event: Event) -> None: - """ - Publish an event by serializing it to JSON and writing to stdout. - - If the publisher is disabled, this method does nothing. - - Args: - event: The event to publish - """ - if not self._enabled: - return - - try: - event_dict: dict[str, Any] = event.model_dump(mode="json") - event_json = json.dumps(event_dict) - - sys.stdout.write(event_json + "\n") - sys.stdout.flush() - except (TypeError, ValueError, AttributeError, OSError) as e: - sys.stderr.write(f"Error publishing event: {e}\n") - sys.stderr.flush() diff --git a/src/askui/chat/api/runs/events/message_events.py b/src/askui/chat/api/runs/events/message_events.py deleted file mode 100644 index f8eb374e..00000000 --- a/src/askui/chat/api/runs/events/message_events.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Literal - -from askui.chat.api.messages.models import Message -from askui.chat.api.runs.events.event_base import EventBase - - -class MessageEvent(EventBase): - data: Message - event: Literal["thread.message.created"] diff --git a/src/askui/chat/api/runs/events/run_events.py b/src/askui/chat/api/runs/events/run_events.py deleted file mode 100644 index 66a83517..00000000 --- a/src/askui/chat/api/runs/events/run_events.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Literal - -from askui.chat.api.runs.events.event_base import EventBase -from askui.chat.api.runs.models import Run - - -class RunEvent(EventBase): - data: Run - event: Literal[ - "thread.run.created", - "thread.run.queued", - "thread.run.in_progress", - "thread.run.completed", - "thread.run.failed", - "thread.run.cancelling", - "thread.run.cancelled", - "thread.run.expired", - ] diff --git a/src/askui/chat/api/runs/events/service.py b/src/askui/chat/api/runs/events/service.py deleted file mode 100644 index 4b9210f7..00000000 --- a/src/askui/chat/api/runs/events/service.py +++ /dev/null @@ -1,343 +0,0 @@ -import asyncio -import logging -import types -from abc import ABC, abstractmethod -from contextlib import asynccontextmanager -from pathlib import Path -from typing import TYPE_CHECKING, AsyncIterator, Type - -import aiofiles - -if TYPE_CHECKING: - from aiofiles.threadpool.text import AsyncTextIOWrapper - -from askui.chat.api.models import RunId, ThreadId, WorkspaceId -from askui.chat.api.runs.events.done_events import DoneEvent -from askui.chat.api.runs.events.error_events import ( - ErrorEvent, - ErrorEventData, - ErrorEventDataError, -) -from askui.chat.api.runs.events.events import Event, EventAdapter -from askui.chat.api.runs.events.run_events import RunEvent -from askui.chat.api.runs.models import Run - -logger = logging.getLogger(__name__) - - -class EventFileManager: - """Manages the lifecycle of a single event file with reference counting.""" - - def __init__(self, file_path: Path) -> None: - self.file_path = file_path - self.readers_count = 0 - self.writer_active = False - self._lock = asyncio.Lock() - self._file_created_event = asyncio.Event() - self._new_event_event = asyncio.Event() - - async def add_reader(self) -> None: - """Add a reader reference.""" - async with self._lock: - self.readers_count += 1 - - async def remove_reader(self) -> None: - """Remove a reader reference and cleanup if no refs remain.""" - async with self._lock: - self.readers_count -= 1 - await self._cleanup_if_needed() - - async def set_writer_active(self, active: bool) -> None: - """Set writer active status.""" - async with self._lock: - self.writer_active = active - if not active: - await self._cleanup_if_needed() - - async def _cleanup_if_needed(self) -> None: - """Delete file if no active connections remain.""" - if not self.writer_active and self.readers_count == 0: - try: - if self.file_path.exists(): - self.file_path.unlink() - # we keep the parent directory - except FileNotFoundError: - pass # Already deleted - - async def notify_file_created(self) -> None: - """Signal that the file has been created.""" - self._file_created_event.set() - - async def wait_for_file(self, timeout: float = 30.0) -> None: - """Wait for the file to be created. - - Args: - timeout: Timeout in seconds. - - Raises: - TimeoutError: If the file is not created within the timeout. - """ - await asyncio.wait_for(self._file_created_event.wait(), timeout) - - async def notify_new_event(self) -> None: - """Signal that a new event has been written to the file.""" - self._new_event_event.set() - - async def wait_for_new_event( - self, timeout: float = 30.0, clear: bool = False - ) -> None: - """Wait for a new event to be written to the file.""" - await asyncio.wait_for(self._new_event_event.wait(), timeout) - if clear: - self._new_event_event.clear() - - -class RetrieveRunService(ABC): - @abstractmethod - def retrieve( - self, workspace_id: WorkspaceId, thread_id: ThreadId, run_id: RunId - ) -> Run: - raise NotImplementedError - - -class EventWriter: - """Writer for appending events to a JSONL file.""" - - def __init__(self, manager: EventFileManager): - self._manager = manager - self._file: "AsyncTextIOWrapper | None" = None - - async def write_event(self, event: Event) -> None: - """Write an event to the file.""" - if self._file is None: - self._file = await aiofiles.open( - self._manager.file_path, "a", encoding="utf-8" - ).__aenter__() - await self._manager.notify_file_created() - - event_json = event.model_dump_json() - await self._file.write(f"{event_json}\n") - await self._file.flush() - await self._manager.notify_new_event() - - async def __aenter__(self) -> "EventWriter": - return self - - async def __aexit__( - self, - exc_type: Type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: types.TracebackType | None, - ) -> None: - if self._file: - await self._file.close() - - -class EventReader: - """Reader for streaming events from a JSONL file.""" - - def __init__( - self, - manager: EventFileManager, - run_service: RetrieveRunService, - start_index: int, - workspace_id: WorkspaceId, - thread_id: ThreadId, - run_id: RunId, - ): - self._manager = manager - self._run_service = run_service - self._start_index = start_index - self._workspace_id = workspace_id - self._thread_id = thread_id - self._run_id = run_id - - async def _iter_final_events(self, run: Run) -> AsyncIterator[Event]: - match run.status: - case "completed": - yield RunEvent(data=run, event="thread.run.completed") - yield DoneEvent() - case "failed": - yield ErrorEvent( - data=ErrorEventData( - error=ErrorEventDataError( - message=run.last_error.message - if run.last_error - else "Unknown error" - ) - ) - ) - case "cancelled": - yield RunEvent(data=run, event="thread.run.cancelled") - yield DoneEvent() - case "expired": - yield RunEvent(data=run, event="thread.run.expired") - yield DoneEvent() - case _: - pass - - async def read_events(self) -> AsyncIterator[Event]: # noqa: C901 - """ - Stream events from the file starting at the specified index. - Continues reading until a terminal event (DoneEvent or ErrorEvent) is found. - - Yields: - Event objects parsed from the JSONL file. - """ - while True: - try: - await self._manager.wait_for_file() - break - except asyncio.exceptions.TimeoutError: - logger.warning( - "Timeout waiting for file %s to be created", - self._manager.file_path, - ) - if run := self._run_service.retrieve( - self._workspace_id, self._thread_id, self._run_id - ): - if run.status not in ("queued", "in_progress"): - async for event in self._iter_final_events(run): - yield event - return - - line_index = -1 - current_position = 0 - async with aiofiles.open( - self._manager.file_path, "r", encoding="utf-8" - ) as file: - while True: - if await file.tell() != current_position: - await file.seek(current_position) - async for line in file: - line_index += 1 - if line_index < self._start_index: - continue - - if stripped_line := line.strip(): - event = EventAdapter.validate_json(stripped_line) - yield event - if isinstance(event, (DoneEvent, ErrorEvent)): - return - await asyncio.sleep(0.25) - current_position = await file.tell() - while True: - try: - await self._manager.wait_for_new_event(clear=True) - break - except asyncio.exceptions.TimeoutError: - logger.warning( - "Timeout waiting for file %s to have a new event", - self._manager.file_path, - ) - if run := self._run_service.retrieve( - self._workspace_id, self._thread_id, self._run_id - ): - if run.status not in ( - "queued", - "in_progress", - "cancelling", - ): - async for event in self._iter_final_events(run): - yield event - return - - -class EventService: - """ - Service for managing event files with concurrent read/write access. - - Features: - - Single writer, multiple readers per file - - Automatic file cleanup when all connections close - - Thread-safe operations - - Performant streaming reads - """ - - _file_managers: dict[RunId, EventFileManager] = {} - _lock = asyncio.Lock() - - def __init__(self, base_dir: Path, run_service: RetrieveRunService) -> None: - self._base_dir = base_dir - self._run_service = run_service - - def _get_event_path(self, thread_id: ThreadId, run_id: RunId) -> Path: - """Get the file path for an event.""" - return self._base_dir / "events" / thread_id / f"{run_id}.jsonl" - - async def _get_or_create_manager( - self, thread_id: ThreadId, run_id: RunId - ) -> EventFileManager: - """Get or create a file manager for the session.""" - async with self._lock: - if run_id not in self._file_managers: - events_file = self._get_event_path(thread_id, run_id) - events_file.parent.mkdir(parents=True, exist_ok=True) - self._file_managers[run_id] = EventFileManager(events_file) - return self._file_managers[run_id] - - @asynccontextmanager - async def create_writer( - self, thread_id: ThreadId, run_id: RunId - ) -> AsyncIterator["EventWriter"]: - """ - Create a writer context manager for appending events to a file. - - Args: - thread_id: Thread ID of the file to write. - run_id: Run ID of the file to write. - - Yields: - EventWriter instance for writing events. - """ - manager = await self._get_or_create_manager(thread_id, run_id) - await manager.set_writer_active(True) - - try: - writer = EventWriter(manager) - yield writer - finally: - await manager.set_writer_active(False) - # Cleanup manager reference if file was deleted - async with self._lock: - if not manager.file_path.exists(): - self._file_managers.pop(run_id, None) - - @asynccontextmanager - async def create_reader( - self, - workspace_id: WorkspaceId, - thread_id: ThreadId, - run_id: RunId, - start_index: int = 0, - ) -> AsyncIterator["EventReader"]: - """ - Create a reader context manager for reading events from a file. - - Args: - thread_id: Thread ID of the file to read. - run_id: Run ID of the file to read. - start_index: Index to start reading from (0-based). - - Yields: - EventReader instance for reading events. - """ - manager = await self._get_or_create_manager(thread_id, run_id) - await manager.add_reader() - - try: - reader = EventReader( - manager=manager, - run_service=self._run_service, - start_index=start_index, - workspace_id=workspace_id, - thread_id=thread_id, - run_id=run_id, - ) - yield reader - finally: - await manager.remove_reader() - # Cleanup manager reference if file was deleted - async with self._lock: - if not manager.file_path.exists(): - self._file_managers.pop(run_id, None) diff --git a/src/askui/chat/api/runs/models.py b/src/askui/chat/api/runs/models.py deleted file mode 100644 index 5d311b1d..00000000 --- a/src/askui/chat/api/runs/models.py +++ /dev/null @@ -1,208 +0,0 @@ -from dataclasses import dataclass -from datetime import timedelta -from typing import Annotated, Literal - -from fastapi import Query -from pydantic import BaseModel, Field, computed_field - -from askui.chat.api.models import ( - AssistantId, - RunId, - ThreadId, - WorkspaceId, - WorkspaceResource, -) -from askui.chat.api.threads.models import ThreadCreate -from askui.utils.api_utils import ListQuery -from askui.utils.datetime_utils import UnixDatetime, now -from askui.utils.id_utils import generate_time_ordered_id - -RunStatus = Literal[ - "queued", - "in_progress", - "completed", - "cancelling", - "cancelled", - "failed", - "expired", -] - - -class RunError(BaseModel): - """Error information for a failed run.""" - - message: str - code: Literal["server_error", "rate_limit_exceeded", "invalid_prompt"] - - -class RunCreate(BaseModel): - """Parameters for creating a run.""" - - stream: bool = False - assistant_id: AssistantId - model: str | None = None - - -class RunStart(BaseModel): - """Parameters for starting a run.""" - - type: Literal["start"] = "start" - status: Literal["in_progress"] = "in_progress" - started_at: UnixDatetime = Field(default_factory=now) - expires_at: UnixDatetime = Field( - default_factory=lambda: now() + timedelta(minutes=10) - ) - - -class RunPing(BaseModel): - """Parameters for pinging a run.""" - - type: Literal["ping"] = "ping" - expires_at: UnixDatetime = Field( - default_factory=lambda: now() + timedelta(minutes=10) - ) - - -class RunComplete(BaseModel): - """Parameters for completing a run.""" - - type: Literal["complete"] = "complete" - status: Literal["completed"] = "completed" - completed_at: UnixDatetime = Field(default_factory=now) - - -class RunTryCancelling(BaseModel): - """Parameters for trying to cancel a run.""" - - type: Literal["try_cancelling"] = "try_cancelling" - status: Literal["cancelling"] = "cancelling" - tried_cancelling_at: UnixDatetime = Field(default_factory=now) - - -class RunCancel(BaseModel): - """Parameters for canceling a run.""" - - type: Literal["cancel"] = "cancel" - status: Literal["cancelled"] = "cancelled" - cancelled_at: UnixDatetime = Field(default_factory=now) - - -class RunFail(BaseModel): - """Parameters for failing a run.""" - - type: Literal["fail"] = "fail" - status: Literal["failed"] = "failed" - failed_at: UnixDatetime = Field(default_factory=now) - last_error: RunError - - -RunModify = RunStart | RunPing | RunComplete | RunTryCancelling | RunCancel | RunFail - - -class ThreadAndRunCreate(RunCreate): - thread: ThreadCreate - - -def map_status_to_readable_description(status: RunStatus) -> str: - match status: - case "queued": - return "Run has been queued." - case "in_progress": - return "Run is in progress." - case "completed": - return "Run has been completed." - case "cancelled": - return "Run has been cancelled." - case "failed": - return "Run has failed." - case "expired": - return "Run has expired." - case "cancelling": - return "Run is being cancelled." - - -class Run(WorkspaceResource): - """A run execution within a thread.""" - - id: RunId - object: Literal["thread.run"] = "thread.run" - thread_id: ThreadId - created_at: UnixDatetime - expires_at: UnixDatetime - started_at: UnixDatetime | None = None - completed_at: UnixDatetime | None = None - failed_at: UnixDatetime | None = None - cancelled_at: UnixDatetime | None = None - tried_cancelling_at: UnixDatetime | None = None - last_error: RunError | None = None - assistant_id: AssistantId | None = None - - @classmethod - def create( - cls, workspace_id: WorkspaceId, thread_id: ThreadId, params: RunCreate - ) -> "Run": - return cls( - id=generate_time_ordered_id("run"), - workspace_id=workspace_id, - thread_id=thread_id, - created_at=now(), - expires_at=now() + timedelta(minutes=10), - **params.model_dump(exclude={"model", "stream"}), - ) - - @computed_field # type: ignore[prop-decorator] - @property - def status(self) -> RunStatus: - if self.cancelled_at: - return "cancelled" - if self.failed_at: - return "failed" - if self.completed_at: - return "completed" - if self.expires_at and self.expires_at < now(): - return "expired" - if self.tried_cancelling_at: - return "cancelling" - if self.started_at: - return "in_progress" - return "queued" - - def validate_modify(self, params: RunModify) -> None: # noqa: C901 - status_description = map_status_to_readable_description(self.status) - error_msg = status_description - match params.type: - case "start": - if self.status != "queued": - error_msg += " Cannot start it (again). Please create a new run." - raise ValueError(error_msg) - case "ping": - if self.status != "in_progress": - error_msg += " Cannot ping. Run is not in progress." - raise ValueError(error_msg) - case "complete": - if self.status != "in_progress": - error_msg += " Cannot complete. Run is not in progress." - raise ValueError(error_msg) - case "try_cancelling": - if self.status not in ["queued", "in_progress"]: - error_msg += " Cannot cancel (again)." - if self.status != "cancelling": - # I think this just sounds better if this is only added if it - # is not being cancelled as it is still in progress while being - # cancelled. - error_msg += " Run is neither queued nor in progress." - raise ValueError(error_msg) - case "cancel": - if self.status not in ["queued", "in_progress", "cancelling"]: - error_msg += " Cannot cancel. Run is neither queued, in progress, nor has it been tried to be cancelled." - raise ValueError(error_msg) - case "fail": - if self.status not in ["queued", "in_progress", "cancelling"]: - error_msg += " Cannot fail. Run is neither queued, in progress, nor has it been tried to be cancelled." - raise ValueError(error_msg) - - -@dataclass(kw_only=True) -class RunListQuery(ListQuery): - thread: Annotated[ThreadId | None, Query()] = None - status: Annotated[list[RunStatus] | None, Query()] = None diff --git a/src/askui/chat/api/runs/orms.py b/src/askui/chat/api/runs/orms.py deleted file mode 100644 index 7dfd84a8..00000000 --- a/src/askui/chat/api/runs/orms.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Run database model.""" - -from datetime import datetime -from typing import Any -from uuid import UUID - -from sqlalchemy import JSON, ForeignKey, Uuid -from sqlalchemy.orm import Mapped, mapped_column -from sqlalchemy.sql.sqltypes import String - -from askui.chat.api.assistants.orms import AssistantId -from askui.chat.api.db.orm.base import Base -from askui.chat.api.db.orm.types import ThreadId, UnixDatetime, create_prefixed_id_type -from askui.chat.api.runs.models import Run - -RunId = create_prefixed_id_type("run") - - -class RunOrm(Base): - """Run database model.""" - - __tablename__ = "runs" - - id: Mapped[str] = mapped_column(RunId, primary_key=True) - thread_id: Mapped[str] = mapped_column( - ThreadId, - ForeignKey("threads.id", ondelete="CASCADE"), - nullable=False, - index=True, - ) - workspace_id: Mapped[UUID] = mapped_column(Uuid, nullable=False, index=True) - status: Mapped[str] = mapped_column(String, nullable=False, index=True) - created_at: Mapped[datetime] = mapped_column(UnixDatetime, nullable=False) - expires_at: Mapped[datetime] = mapped_column(UnixDatetime, nullable=False) - started_at: Mapped[datetime | None] = mapped_column(UnixDatetime, nullable=True) - completed_at: Mapped[datetime | None] = mapped_column(UnixDatetime, nullable=True) - failed_at: Mapped[datetime | None] = mapped_column(UnixDatetime, nullable=True) - cancelled_at: Mapped[datetime | None] = mapped_column(UnixDatetime, nullable=True) - tried_cancelling_at: Mapped[datetime | None] = mapped_column( - UnixDatetime, nullable=True - ) - last_error: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) - assistant_id: Mapped[str | None] = mapped_column( - AssistantId, ForeignKey("assistants.id", ondelete="SET NULL"), nullable=True - ) - - @classmethod - def from_model(cls, model: Run) -> "RunOrm": - return cls(**model.model_dump(exclude={"object"})) - - def to_model(self) -> Run: - return Run.model_validate(self, from_attributes=True) diff --git a/src/askui/chat/api/runs/router.py b/src/askui/chat/api/runs/router.py deleted file mode 100644 index b01bb284..00000000 --- a/src/askui/chat/api/runs/router.py +++ /dev/null @@ -1,151 +0,0 @@ -from collections.abc import AsyncGenerator -from typing import Annotated - -from fastapi import APIRouter, BackgroundTasks, Header, Path, Query, Response, status -from fastapi.responses import JSONResponse, StreamingResponse -from pydantic import BaseModel - -from askui.chat.api.models import RunId, ThreadId, WorkspaceId -from askui.chat.api.runs.models import RunCreate -from askui.chat.api.threads.dependencies import ThreadFacadeDep -from askui.chat.api.threads.facade import ThreadFacade -from askui.utils.api_utils import ListResponse - -from .dependencies import RunListQueryDep, RunServiceDep -from .models import Run, RunCancel, RunListQuery, ThreadAndRunCreate -from .service import RunService - -router = APIRouter(tags=["runs"]) - - -@router.post("/threads/{thread_id}/runs") -async def create_run( - askui_workspace: Annotated[WorkspaceId, Header()], - thread_id: Annotated[ThreadId, Path(...)], - params: RunCreate, - background_tasks: BackgroundTasks, - run_service: RunService = RunServiceDep, -) -> Response: - stream = params.stream - run, async_generator = await run_service.create( - workspace_id=askui_workspace, thread_id=thread_id, params=params - ) - if stream: - - async def sse_event_stream() -> AsyncGenerator[str, None]: - async for event in async_generator: - data = ( - event.data.model_dump_json() - if isinstance(event.data, BaseModel) - else event.data - ) - yield f"event: {event.event}\ndata: {data}\n\n" - - return StreamingResponse( - status_code=status.HTTP_201_CREATED, - content=sse_event_stream(), - media_type="text/event-stream", - ) - - async def _run_async_generator() -> None: - async for _ in async_generator: - pass - - background_tasks.add_task(_run_async_generator) - return JSONResponse( - status_code=status.HTTP_201_CREATED, content=run.model_dump(mode="json") - ) - - -@router.post("/runs") -async def create_thread_and_run( - askui_workspace: Annotated[WorkspaceId, Header()], - params: ThreadAndRunCreate, - background_tasks: BackgroundTasks, - thread_facade: ThreadFacade = ThreadFacadeDep, -) -> Response: - stream = params.stream - run, async_generator = await thread_facade.create_thread_and_run( - workspace_id=askui_workspace, params=params - ) - if stream: - - async def sse_event_stream() -> AsyncGenerator[str, None]: - async for event in async_generator: - data = ( - event.data.model_dump_json() - if isinstance(event.data, BaseModel) - else event.data - ) - yield f"event: {event.event}\ndata: {data}\n\n" - - return StreamingResponse( - status_code=status.HTTP_201_CREATED, - content=sse_event_stream(), - media_type="text/event-stream", - ) - - async def _run_async_generator() -> None: - async for _ in async_generator: - pass - - background_tasks.add_task(_run_async_generator) - return JSONResponse( - status_code=status.HTTP_201_CREATED, content=run.model_dump(mode="json") - ) - - -@router.get("/threads/{thread_id}/runs/{run_id}") -async def retrieve_run( - askui_workspace: Annotated[WorkspaceId, Header()], - thread_id: Annotated[ThreadId, Path(...)], - run_id: Annotated[RunId, Path(...)], - stream: Annotated[bool, Query()] = False, - run_service: RunService = RunServiceDep, -) -> Response: - if not stream: - return JSONResponse( - content=run_service.retrieve( - workspace_id=askui_workspace, thread_id=thread_id, run_id=run_id - ).model_dump(mode="json"), - ) - - async def sse_event_stream() -> AsyncGenerator[str, None]: - async for event in run_service.retrieve_stream( - workspace_id=askui_workspace, thread_id=thread_id, run_id=run_id - ): - data = ( - event.data.model_dump_json() - if isinstance(event.data, BaseModel) - else event.data - ) - yield f"event: {event.event}\ndata: {data}\n\n" - - return StreamingResponse( - content=sse_event_stream(), - media_type="text/event-stream", - ) - - -@router.get("/runs") -async def list_runs( - askui_workspace: Annotated[WorkspaceId, Header()], - query: RunListQuery = RunListQueryDep, - run_service: RunService = RunServiceDep, -) -> ListResponse[Run]: - return run_service.list_(workspace_id=askui_workspace, query=query) - - -@router.post("/threads/{thread_id}/runs/{run_id}/cancel") -def cancel_run( - askui_workspace: Annotated[WorkspaceId, Header()], - thread_id: Annotated[ThreadId, Path(...)], - run_id: Annotated[RunId, Path(...)], - run_service: RunService = RunServiceDep, -) -> Run: - return run_service.modify( - workspace_id=askui_workspace, - thread_id=thread_id, - run_id=run_id, - params=RunCancel(), - ) diff --git a/src/askui/chat/api/runs/runner/__init__.py b/src/askui/chat/api/runs/runner/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py deleted file mode 100644 index 84889fac..00000000 --- a/src/askui/chat/api/runs/runner/runner.py +++ /dev/null @@ -1,261 +0,0 @@ -import json -import logging -from abc import ABC, abstractmethod -from datetime import datetime, timezone - -from anthropic.types.beta import BetaCacheControlEphemeralParam, BetaTextBlockParam -from anyio.abc import ObjectStream -from asyncer import asyncify, syncify - -from askui.chat.api.assistants.models import Assistant -from askui.chat.api.mcp_clients.manager import McpClientManagerManager -from askui.chat.api.messages.chat_history_manager import ChatHistoryManager -from askui.chat.api.models import MessageId, RunId, ThreadId, WorkspaceId -from askui.chat.api.runs.events.done_events import DoneEvent -from askui.chat.api.runs.events.error_events import ( - ErrorEvent, - ErrorEventData, - ErrorEventDataError, -) -from askui.chat.api.runs.events.events import Event -from askui.chat.api.runs.events.message_events import MessageEvent -from askui.chat.api.runs.events.run_events import RunEvent -from askui.chat.api.runs.events.service import RetrieveRunService -from askui.chat.api.runs.models import ( - Run, - RunCancel, - RunComplete, - RunError, - RunFail, - RunModify, - RunPing, - RunStart, -) -from askui.chat.api.settings import Settings -from askui.custom_agent import CustomAgent -from askui.models.shared.agent_message_param import MessageParam -from askui.models.shared.agent_on_message_cb import OnMessageCbParam -from askui.models.shared.prompts import ActSystemPrompt -from askui.models.shared.settings import ActSettings, MessageSettings -from askui.models.shared.tools import ToolCollection -from askui.prompts.act_prompts import caesr_system_prompt - -logger = logging.getLogger(__name__) - - -class RunnerRunService(RetrieveRunService, ABC): - @abstractmethod - def modify( - self, - workspace_id: WorkspaceId, - thread_id: ThreadId, - run_id: RunId, - params: RunModify, - ) -> Run: - raise NotImplementedError - - -class Runner: - def __init__( - self, - run_id: RunId, - thread_id: ThreadId, - workspace_id: WorkspaceId, - assistant: Assistant, - chat_history_manager: ChatHistoryManager, - mcp_client_manager_manager: McpClientManagerManager, - run_service: RunnerRunService, - settings: Settings, - last_message_id: MessageId, - model: str | None = None, - ) -> None: - self._run_id = run_id - self._workspace_id = workspace_id - self._thread_id = thread_id - self._assistant = assistant - self._chat_history_manager = chat_history_manager - self._mcp_client_manager_manager = mcp_client_manager_manager - self._run_service = run_service - self._settings = settings - self._last_message_id = last_message_id - self._model: str | None = model - - def _retrieve_run(self) -> Run: - return self._run_service.retrieve( - workspace_id=self._workspace_id, - thread_id=self._thread_id, - run_id=self._run_id, - ) - - def _modify_run(self, params: RunModify) -> Run: - return self._run_service.modify( - workspace_id=self._workspace_id, - thread_id=self._thread_id, - run_id=self._run_id, - params=params, - ) - - def _build_system(self) -> ActSystemPrompt: - metadata = json.dumps( - { - **self._get_run_extra_info(), - "continued_by_user_at": datetime.now(timezone.utc).strftime( - "%A, %B %d, %Y %H:%M:%S %z" - ), - } - ) - assistant_prompt = self._assistant.system if self._assistant.system else "" - - return caesr_system_prompt(assistant_prompt, metadata) - - async def _run_agent( - self, - send_stream: ObjectStream[Event], - ) -> None: - async def async_on_message( - on_message_cb_param: OnMessageCbParam, - ) -> MessageParam | None: - created_message = await self._chat_history_manager.append_message( - workspace_id=self._workspace_id, - thread_id=self._thread_id, - assistant_id=self._assistant.id, - run_id=self._run_id, - message=on_message_cb_param.message, - parent_id=self._last_message_id, - ) - # Update the parent_id for the next message - self._last_message_id = created_message.id - await send_stream.send( - MessageEvent( - data=created_message, - event="thread.message.created", - ) - ) - updated_run = self._retrieve_run() - if self._should_abort(updated_run): - return None - self._modify_run(RunPing()) - return on_message_cb_param.message - - on_message = syncify(async_on_message) - mcp_client = await self._mcp_client_manager_manager.get_mcp_client_manager( - self._workspace_id - ) - - def _run_agent_inner() -> None: - tools = ToolCollection( - mcp_client=mcp_client, - include=set(self._assistant.tools), - ) - betas = tools.retrieve_tool_beta_flags() - system = self._build_system() - model = self._get_model() - messages = syncify(self._chat_history_manager.retrieve_message_params)( - workspace_id=self._workspace_id, - thread_id=self._thread_id, - tools=tools.to_params(), - system=system, - model=model, - ) - custom_agent = CustomAgent() - custom_agent.act( - messages, - model=model, - on_message=on_message, - tools=tools, - settings=ActSettings( - messages=MessageSettings( - betas=betas, - system=system, - thinking={"type": "enabled", "budget_tokens": 4096}, - max_tokens=8192, - ), - ), - ) - - await asyncify(_run_agent_inner)() - - def _get_run_extra_info(self) -> dict[str, str]: - return { - "run_id": self._run_id, - "thread_id": self._thread_id, - "workspace_id": str(self._workspace_id), - "assistant_id": self._assistant.id, - } - - async def run( - self, - send_stream: ObjectStream[Event], - ) -> None: - try: - updated_run = self._modify_run(RunStart()) - logger.info( - "Run started", - extra=self._get_run_extra_info(), - ) - await send_stream.send( - RunEvent( - data=updated_run, - event="thread.run.in_progress", - ) - ) - await self._run_agent(send_stream=send_stream) - updated_run = self._retrieve_run() - if updated_run.status == "in_progress": - self._modify_run(RunComplete()) - await send_stream.send( - RunEvent( - data=updated_run, - event="thread.run.completed", - ) - ) - if updated_run.status == "cancelling": - await send_stream.send( - RunEvent( - data=updated_run, - event="thread.run.cancelling", - ) - ) - self._modify_run(RunCancel()) - await send_stream.send( - RunEvent( - data=updated_run, - event="thread.run.cancelled", - ) - ) - if updated_run.status == "expired": - await send_stream.send( - RunEvent( - data=updated_run, - event="thread.run.expired", - ) - ) - await send_stream.send(DoneEvent()) - except Exception as e: # noqa: BLE001 - logger.exception( - "Run failed", - extra=self._get_run_extra_info(), - ) - updated_run = self._retrieve_run() - self._modify_run( - RunFail(last_error=RunError(message=str(e), code="server_error")), - ) - await send_stream.send( - RunEvent( - data=updated_run, - event="thread.run.failed", - ) - ) - await send_stream.send( - ErrorEvent( - data=ErrorEventData(error=ErrorEventDataError(message=str(e))) - ) - ) - - def _should_abort(self, run: Run) -> bool: - return run.status in ("cancelled", "cancelling", "expired") - - def _get_model(self) -> str: - if self._model is not None: - return self._model - return self._settings.model diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py deleted file mode 100644 index 726c29f9..00000000 --- a/src/askui/chat/api/runs/service.py +++ /dev/null @@ -1,216 +0,0 @@ -from collections.abc import AsyncGenerator -from datetime import datetime, timezone - -import anyio -from sqlalchemy import ColumnElement, or_ -from sqlalchemy.orm import Session -from typing_extensions import override - -from askui.chat.api.assistants.service import AssistantService -from askui.chat.api.db.queries import list_all -from askui.chat.api.mcp_clients.manager import McpClientManagerManager -from askui.chat.api.messages.chat_history_manager import ChatHistoryManager -from askui.chat.api.models import RunId, ThreadId, WorkspaceId -from askui.chat.api.runs.events.events import DoneEvent, ErrorEvent, Event, RunEvent -from askui.chat.api.runs.events.io_publisher import IOPublisher -from askui.chat.api.runs.events.service import EventService -from askui.chat.api.runs.models import ( - Run, - RunCreate, - RunListQuery, - RunModify, - RunStatus, -) -from askui.chat.api.runs.orms import RunOrm -from askui.chat.api.runs.runner.runner import Runner, RunnerRunService -from askui.chat.api.settings import Settings -from askui.utils.api_utils import ListResponse, NotFoundError - - -class RunService(RunnerRunService): - """Service for managing Run resources with database persistence.""" - - def __init__( - self, - session: Session, - assistant_service: AssistantService, - mcp_client_manager_manager: McpClientManagerManager, - chat_history_manager: ChatHistoryManager, - settings: Settings, - ) -> None: - self._session = session - self._assistant_service = assistant_service - self._mcp_client_manager_manager = mcp_client_manager_manager - self._chat_history_manager = chat_history_manager - self._settings = settings - self._event_service = EventService(settings.data_dir, self) - self._io_publisher = IOPublisher(settings.enable_io_events) - - def _find_by_id( - self, workspace_id: WorkspaceId | None, thread_id: ThreadId, run_id: RunId - ) -> RunOrm: - """Find run by ID.""" - run_orm: RunOrm | None = ( - self._session.query(RunOrm) - .filter( - RunOrm.id == run_id, - RunOrm.thread_id == thread_id, - RunOrm.workspace_id == workspace_id, - ) - .first() - ) - if run_orm is None: - error_msg = f"Run {run_id} not found in thread {thread_id}" - raise NotFoundError(error_msg) - return run_orm - - def _create( - self, workspace_id: WorkspaceId, thread_id: ThreadId, params: RunCreate - ) -> Run: - """Create a new run.""" - run = Run.create(workspace_id, thread_id, params) - run_orm = RunOrm.from_model(run) - self._session.add(run_orm) - self._session.commit() - return run - - async def create( - self, - workspace_id: WorkspaceId, - thread_id: ThreadId, - params: RunCreate, - ) -> tuple[Run, AsyncGenerator[Event, None]]: - assistant = self._assistant_service.retrieve( - workspace_id=workspace_id, assistant_id=params.assistant_id - ) - run = self._create(workspace_id, thread_id, params) - send_stream, receive_stream = anyio.create_memory_object_stream[Event]() - - last_message_id = self._chat_history_manager.retrieve_last_message( - workspace_id=workspace_id, - thread_id=thread_id, - ) - runner = Runner( - run_id=run.id, - thread_id=thread_id, - workspace_id=workspace_id, - assistant=assistant, - chat_history_manager=self._chat_history_manager, - mcp_client_manager_manager=self._mcp_client_manager_manager, - run_service=self, - settings=self._settings, - last_message_id=last_message_id, - model=params.model, - ) - - async def event_generator() -> AsyncGenerator[Event, None]: - try: - async with self._event_service.create_writer( - thread_id, run.id - ) as event_writer: - run_created_event = RunEvent( - data=run, - event="thread.run.created", - ) - await event_writer.write_event(run_created_event) - yield run_created_event - run_queued_event = RunEvent( - data=run, - event="thread.run.queued", - ) - await event_writer.write_event(run_queued_event) - yield run_queued_event - - async def run_runner() -> None: - try: - await runner.run(send_stream) # type: ignore[arg-type] - finally: - await send_stream.aclose() - - async with anyio.create_task_group() as tg: - tg.start_soon(run_runner) - - while True: - try: - event = await receive_stream.receive() - await event_writer.write_event(event) - yield event - if isinstance(event, DoneEvent) or isinstance( - event, ErrorEvent - ): - self._io_publisher.publish(event) - break - except anyio.EndOfStream: - break - finally: - await send_stream.aclose() - - return run, event_generator() - - @override - def modify( - self, - workspace_id: WorkspaceId, - thread_id: ThreadId, - run_id: RunId, - params: RunModify, - ) -> Run: - run_orm = self._find_by_id(workspace_id, thread_id, run_id) - run = run_orm.to_model() - run.validate_modify(params) - run_orm.update(params.model_dump(exclude={"type"})) - self._session.commit() - self._session.refresh(run_orm) - return run_orm.to_model() - - @override - def retrieve( - self, workspace_id: WorkspaceId, thread_id: ThreadId, run_id: RunId - ) -> Run: - """Retrieve run by ID.""" - run_orm = self._find_by_id(workspace_id, thread_id, run_id) - return run_orm.to_model() - - async def retrieve_stream( - self, workspace_id: WorkspaceId, thread_id: ThreadId, run_id: RunId - ) -> AsyncGenerator[Event, None]: - async with self._event_service.create_reader( - workspace_id=workspace_id, thread_id=thread_id, run_id=run_id - ) as event_reader: - async for event in event_reader.read_events(): - yield event - - def _build_status_condition(self, status: RunStatus) -> ColumnElement[bool]: - match status: - case "expired": - return (RunOrm.status == "expired") | ( - (RunOrm.status.in_(["queued", "in_progress", "cancelling"])) - & (RunOrm.expires_at < datetime.now(tz=timezone.utc)) - ) - case _: - return RunOrm.status == status - - def list_( - self, workspace_id: WorkspaceId, query: RunListQuery - ) -> ListResponse[Run]: - """List runs with pagination and filtering.""" - q = self._session.query(RunOrm).filter(RunOrm.workspace_id == workspace_id) - - if query.thread: - q = q.filter(RunOrm.thread_id == query.thread) - - if query.status: - status_conditions = [ - self._build_status_condition(status) for status in query.status - ] - q = q.filter(or_(*status_conditions)) - - orms: list[RunOrm] - orms, has_more = list_all(q, query, RunOrm.id) - data = [orm.to_model() for orm in orms] - return ListResponse( - data=data, - has_more=has_more, - first_id=data[0].id if data else None, - last_id=data[-1].id if data else None, - ) diff --git a/src/askui/chat/api/scheduled_jobs/__init__.py b/src/askui/chat/api/scheduled_jobs/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/scheduled_jobs/dependencies.py b/src/askui/chat/api/scheduled_jobs/dependencies.py deleted file mode 100644 index bc835878..00000000 --- a/src/askui/chat/api/scheduled_jobs/dependencies.py +++ /dev/null @@ -1,14 +0,0 @@ -"""FastAPI dependencies for scheduled jobs.""" - -from fastapi import Depends - -from askui.chat.api.scheduled_jobs.scheduler import scheduler -from askui.chat.api.scheduled_jobs.service import ScheduledJobService - - -def get_scheduled_job_service() -> ScheduledJobService: - """Get ScheduledJobService instance with the singleton scheduler.""" - return ScheduledJobService(scheduler=scheduler) - - -ScheduledJobServiceDep = Depends(get_scheduled_job_service) diff --git a/src/askui/chat/api/scheduled_jobs/executor.py b/src/askui/chat/api/scheduled_jobs/executor.py deleted file mode 100644 index baff6100..00000000 --- a/src/askui/chat/api/scheduled_jobs/executor.py +++ /dev/null @@ -1,122 +0,0 @@ -"""Executor for scheduled job callbacks.""" - -import base64 -import logging -import os -from typing import Any - -from sqlalchemy.orm import Session - -from askui.chat.api.db.engine import engine -from askui.chat.api.messages.dependencies import get_message_service -from askui.chat.api.runs.dependencies import create_run_service -from askui.chat.api.runs.models import RunCreate -from askui.chat.api.scheduled_jobs.models import ( - MessageRerunnerData, - ScheduledJobExecutionResult, - scheduled_job_data_adapter, -) - -_logger = logging.getLogger(__name__) - - -async def execute_job( - **_kwargs: Any, -) -> ScheduledJobExecutionResult: - """ - APScheduler callback that creates fresh services and executes the job. - - This function is called by APScheduler when a job fires. It creates fresh - database sessions and service instances to avoid stale connections. - - Args: - **_kwargs (Any): Keyword arguments containing job data. - - Returns: - ScheduledJobExecutionResult: The result containing job data and optional error. - """ - # Validates and returns the correct concrete type based on the `type` discriminator - job_data = scheduled_job_data_adapter.validate_python(_kwargs) - - _logger.info( - "Executing scheduled job: workspace=%s, thread=%s", - job_data.workspace_id, - job_data.thread_id, - ) - - error: str | None = None - - try: - # future proofing of new job types - if isinstance(job_data, MessageRerunnerData): # pyright: ignore[reportUnnecessaryIsInstance] - # Save previous ASKUI_TOKEN and AUTHORIZATION_HEADER env vars - _previous_authorization = os.environ.get("ASKUI__AUTHORIZATION") - - # remove authorization header since it takes precedence over the token and is set when forwarding bearer token - os.environ["ASKUI__AUTHORIZATION"] = ( - f"Basic {base64.b64encode(job_data.askui_token.get_secret_value().encode()).decode()}" - ) - - try: - await _execute_message_rerunner_job(job_data) - finally: - # Restore previous AUTHORIZATION_HEADER env var - if _previous_authorization is not None: - os.environ["ASKUI__AUTHORIZATION"] = _previous_authorization - except Exception as e: - error = f"{type(e).__name__}: {e}" - _logger.exception("Scheduled job failed: %s", error) - - # Always return job data with optional error - return ScheduledJobExecutionResult(data=job_data, error=error) - - -async def _execute_message_rerunner_job( - job_data: MessageRerunnerData, -) -> None: - """ - Execute a message rerunner job. - - Args: - job_data: The job data. - """ - with Session(engine) as session: - message_service = get_message_service(session) - run_service = create_run_service(session, job_data.workspace_id) - - # Create message - message_service.create( - workspace_id=job_data.workspace_id, - thread_id=job_data.thread_id, - params=job_data.message, - ) - - # Create and execute run - _logger.debug("Creating run with assistant %s", job_data.assistant_id) - run, generator = await run_service.create( - workspace_id=job_data.workspace_id, - thread_id=job_data.thread_id, - params=RunCreate(assistant_id=job_data.assistant_id, model=job_data.model), - ) - - # Consume generator to completion of run - _logger.debug("Waiting for run %s to complete", run.id) - async for _event in generator: - pass - - # Check if run completed with error - completed_run = run_service.retrieve( - workspace_id=job_data.workspace_id, - thread_id=job_data.thread_id, - run_id=run.id, - ) - - if completed_run.status == "failed": - error_message = ( - completed_run.last_error.message - if completed_run.last_error - else "Run failed with unknown error" - ) - raise RuntimeError(error_message) - - _logger.info("Scheduled job completed: run_id=%s", run.id) diff --git a/src/askui/chat/api/scheduled_jobs/models.py b/src/askui/chat/api/scheduled_jobs/models.py deleted file mode 100644 index c220e9d0..00000000 --- a/src/askui/chat/api/scheduled_jobs/models.py +++ /dev/null @@ -1,205 +0,0 @@ -from typing import Literal, Union - -from apscheduler import Schedule -from apscheduler.triggers.date import DateTrigger -from pydantic import BaseModel, Field, SecretStr, TypeAdapter - -from askui.chat.api.messages.models import ROOT_MESSAGE_PARENT_ID, MessageCreate -from askui.chat.api.models import ( - AssistantId, - MessageId, - ScheduledJobId, - ThreadId, - WorkspaceId, -) -from askui.utils.datetime_utils import UnixDatetime -from askui.utils.id_utils import generate_time_ordered_id - - -class ScheduledMessageCreate(MessageCreate): - """ - Message creation parameters for scheduled jobs. - - Extends `MessageCreate` with required `parent_id` to ensure the message - is added to the correct branch in the conversation tree, since the thread - state may change between scheduling and execution. - - Args: - parent_id (MessageId): The parent message ID for branching. - Required for scheduled messages. Use `ROOT_MESSAGE_PARENT_ID` to - create a new root branch. - """ - - parent_id: MessageId = Field(default=ROOT_MESSAGE_PARENT_ID) # pyright: ignore - - -# ============================================================================= -# API Input Models (without workspace_id - injected from header) -# ============================================================================= - - -class _BaseMessageRerunnerDataCreate(BaseModel): - """ - API input data for the message_rerunner job type. - - This job type creates a new message in a thread and executes a run - at the scheduled time. - - Args: - type (ScheduledJobType): The type of the scheduled job. - thread_id (ThreadId): The thread to add the message to. - assistant_id (AssistantId): The assistant to run. - model (str): The model to use for the run. - message (ScheduledMessageCreate): The message to create. - askui_token (str): The AskUI token to use for authenticated API calls - when the job executes. This is a long-lived credential that doesn't - expire like Bearer tokens. - """ - - type: Literal["message_rerunner"] = "message_rerunner" - name: str - thread_id: ThreadId - assistant_id: AssistantId - model: str - message: ScheduledMessageCreate - askui_token: SecretStr - - -class ScheduledJobCreate(BaseModel): - """ - API input data for scheduled job creation. - - Args: - next_fire_time (UnixDatetime): The time when the job should execute. - data (ScheduledJobData): The data for the job. - """ - - next_fire_time: UnixDatetime - data: _BaseMessageRerunnerDataCreate - - -# ============================================================================= -# Internal Models (with workspace_id - populated after injection) -# ============================================================================= - - -class MessageRerunnerData(_BaseMessageRerunnerDataCreate): - """ - Internal data for the message_rerunner job type. - - Extends `MessageRerunnerDataCreate` with required `workspace_id` that is - injected from the request header. - - Args: - workspace_id (WorkspaceId): The workspace this job belongs to. - """ - - workspace_id: WorkspaceId - - -# Discriminated union of all job data types (extensible for future types) -ScheduledJobData = Union[MessageRerunnerData] - -scheduled_job_data_adapter: TypeAdapter[ScheduledJobData] = TypeAdapter( - ScheduledJobData -) - - -class ScheduledJob(BaseModel): - """ - A scheduled job that will execute at a specified time. - - Maps to APScheduler's `Schedule` structure for easy conversion. - - Args: - id (ScheduledJobId): Unique identifier for the scheduled job. - Maps to `Schedule.id`. - next_fire_time (UnixDatetime): When the job is scheduled to execute. - Maps to `Schedule.next_fire_time` or `Schedule.trigger.run_time`. - data (ScheduledJobData): Type-specific job data. Always contains `type` and - `workspace_id`. Maps to `Schedule.kwargs`. - object (Literal["scheduled_job"]): Object type identifier. - """ - - id: ScheduledJobId - object: Literal["scheduled_job"] = "scheduled_job" - next_fire_time: UnixDatetime - data: ScheduledJobData - - @classmethod - def create( - cls, - workspace_id: WorkspaceId, - params: ScheduledJobCreate, - ) -> "ScheduledJob": - """ - Create a new ScheduledJob with a generated ID. - - Args: - workspace_id (WorkspaceId): The workspace this job belongs to. - params (ScheduledJobCreate): The job creation parameters. - - Returns: - ScheduledJob: The created scheduled job. - """ - return cls( - id=generate_time_ordered_id("schedjob"), - next_fire_time=params.next_fire_time, - data=MessageRerunnerData( - workspace_id=workspace_id, - name=params.data.name, - thread_id=params.data.thread_id, - assistant_id=params.data.assistant_id, - model=params.data.model, - message=params.data.message, - askui_token=params.data.askui_token, - ), - ) - - @classmethod - def from_schedule(cls, schedule: Schedule) -> "ScheduledJob": - """ - Create a ScheduledJob from an APScheduler Schedule. - - Args: - schedule (Schedule): The APScheduler schedule to convert. - - Returns: - ScheduledJob: The converted scheduled job. - - Raises: - ValueError: If the schedule has no determinable `next_fire_time`. - """ - # Extract next_fire_time from schedule or trigger - next_fire_time: UnixDatetime - if schedule.next_fire_time is not None: - next_fire_time = schedule.next_fire_time - elif isinstance(schedule.trigger, DateTrigger): - next_fire_time = schedule.trigger.run_time - else: - error_msg = f"Schedule {schedule.id} has no next_fire_time" - raise ValueError(error_msg) - # Reconstruct data from kwargs - data = MessageRerunnerData.model_validate(schedule.kwargs or {}) - - return cls( - id=schedule.id, - next_fire_time=next_fire_time, - data=data, - ) - - -class ScheduledJobExecutionResult(BaseModel): - """ - Return value stored by the job executor in APScheduler's job result. - - This ensures we always have job data available even if the job fails, - since APScheduler clears return_value on exception. - - Args: - data (ScheduledJobData): The job data that was executed. - error (str | None): Error message if the job failed. - """ - - data: ScheduledJobData - error: str | None = None diff --git a/src/askui/chat/api/scheduled_jobs/router.py b/src/askui/chat/api/scheduled_jobs/router.py deleted file mode 100644 index 794e9474..00000000 --- a/src/askui/chat/api/scheduled_jobs/router.py +++ /dev/null @@ -1,59 +0,0 @@ -"""API router for scheduled jobs.""" - -from typing import Annotated - -from fastapi import APIRouter, Header, status - -from askui.chat.api.models import ScheduledJobId, WorkspaceId -from askui.chat.api.scheduled_jobs.dependencies import ScheduledJobServiceDep -from askui.chat.api.scheduled_jobs.models import ScheduledJob, ScheduledJobCreate -from askui.chat.api.scheduled_jobs.service import ScheduledJobService -from askui.utils.api_utils import ListResponse - -router = APIRouter(prefix="/scheduled-jobs", tags=["scheduled-jobs"]) - - -@router.post("", status_code=status.HTTP_201_CREATED) -async def create_scheduled_job( - askui_workspace: Annotated[WorkspaceId, Header()], - params: ScheduledJobCreate, - scheduled_job_service: ScheduledJobService = ScheduledJobServiceDep, -) -> ScheduledJob: - """Create a new scheduled job.""" - - return await scheduled_job_service.create( - workspace_id=askui_workspace, - params=params, - ) - - -@router.get("") -async def list_scheduled_jobs( - askui_workspace: Annotated[WorkspaceId, Header()], - scheduled_job_service: ScheduledJobService = ScheduledJobServiceDep, -) -> ListResponse[ScheduledJob]: - """List scheduled jobs with optional status filter.""" - return await scheduled_job_service.list_( - workspace_id=askui_workspace, - ) - - -@router.delete("/{job_id}", status_code=status.HTTP_204_NO_CONTENT) -async def cancel_scheduled_job( - askui_workspace: Annotated[WorkspaceId, Header()], - job_id: ScheduledJobId, - scheduled_job_service: ScheduledJobService = ScheduledJobServiceDep, -) -> None: - """ - Cancel a scheduled job. - - Only works for jobs with status 'pending'. Removes the job from the scheduler. - Cancelled jobs have no history (they are simply removed). - - Raises: - NotFoundError: If the job is not found or already executed. - """ - await scheduled_job_service.cancel( - workspace_id=askui_workspace, - job_id=job_id, - ) diff --git a/src/askui/chat/api/scheduled_jobs/scheduler.py b/src/askui/chat/api/scheduled_jobs/scheduler.py deleted file mode 100644 index aa3d2afb..00000000 --- a/src/askui/chat/api/scheduled_jobs/scheduler.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -Module-level APScheduler singleton management. - -Similar to how `engine.py` manages the database engine, this module manages -the APScheduler instance as a singleton to ensure jobs persist across requests. - -Uses the shared database engine from `engine.py` which is configured with -optimized SQLite pragmas for concurrent access (WAL mode, etc.). -""" - -import logging -from datetime import timedelta -from typing import Any - -from apscheduler import AsyncScheduler -from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore - -from askui.chat.api.db.engine import engine - -_logger = logging.getLogger(__name__) - -# Use shared engine from db/engine.py (already configured with SQLite pragmas) -# APScheduler will create its own tables (apscheduler_*) in the same database -_data_store: Any = SQLAlchemyDataStore(engine_or_url=engine) - -# Module-level singleton scheduler instance -# - max_concurrent_jobs=1: only one job runs at a time (sequential execution) -# At module level: just create the scheduler (don't start it) -scheduler: AsyncScheduler = AsyncScheduler( - data_store=_data_store, - max_concurrent_jobs=1, - cleanup_interval=timedelta(minutes=1), # Cleanup every minute -) - - -async def start_scheduler() -> None: - """ - Start the scheduler to begin processing jobs. - - This initializes the scheduler and starts it in the background so it can - poll for and execute scheduled jobs while the FastAPI application handles requests. - """ - # First initialize the scheduler via context manager entry - await scheduler.__aenter__() - # Then start background processing of jobs - await scheduler.start_in_background() - _logger.info("Scheduler started in background") - - -async def shutdown_scheduler() -> None: - """Shut down the scheduler gracefully.""" - await scheduler.__aexit__(None, None, None) - _logger.info("Scheduler shut down") diff --git a/src/askui/chat/api/scheduled_jobs/service.py b/src/askui/chat/api/scheduled_jobs/service.py deleted file mode 100644 index 37d2e0f2..00000000 --- a/src/askui/chat/api/scheduled_jobs/service.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Service for managing scheduled jobs.""" - -import logging -from datetime import timedelta - -from apscheduler import AsyncScheduler, Schedule -from apscheduler.triggers.date import DateTrigger - -from askui.chat.api.models import ScheduledJobId, WorkspaceId -from askui.chat.api.scheduled_jobs.executor import execute_job -from askui.chat.api.scheduled_jobs.models import ScheduledJob, ScheduledJobCreate -from askui.utils.api_utils import ListResponse, NotFoundError - -logger = logging.getLogger(__name__) - - -class ScheduledJobService: - """ - Service for managing scheduled jobs using APScheduler. - - This service provides methods to create, list, and cancel scheduled jobs. - Job data is stored in APScheduler's SQLAlchemy data store. - - Args: - scheduler (Any): The APScheduler `AsyncScheduler` instance to use. - """ - - def __init__(self, scheduler: AsyncScheduler) -> None: - self._scheduler: AsyncScheduler = scheduler - - async def create( - self, - workspace_id: WorkspaceId, - params: ScheduledJobCreate, - ) -> ScheduledJob: - """ - Create a new scheduled job. - - Args: - workspace_id (WorkspaceId): The workspace this job belongs to. - params (ScheduledJobCreate): The job creation parameters. - - Returns: - ScheduledJob: The created scheduled job. - """ - job = ScheduledJob.create( - workspace_id=workspace_id, - params=params, - ) - - # Prepare kwargs for the job callback - - logger.info( - "Creating scheduled job: id=%s, type=%s, next_fire_time=%s", - job.id, - job.data.type, - job.next_fire_time, - ) - - await self._scheduler.add_schedule( - func_or_task_id=execute_job, - trigger=DateTrigger(run_time=job.next_fire_time), - id=job.id, - kwargs={ - **job.data.model_dump(mode="json"), - "askui_token": job.data.askui_token.get_secret_value(), - }, - misfire_grace_time=timedelta(minutes=10), - job_result_expiration_time=timedelta(weeks=30000), # Never expire - ) - - logger.info("Scheduled job created: %s", job.id) - return job - - async def list_( - self, - workspace_id: WorkspaceId, - ) -> ListResponse[ScheduledJob]: - """ - List pending scheduled jobs. - - Args: - workspace_id (WorkspaceId): Filter by workspace. - query (ListQuery): Query parameters. - - Returns: - ListResponse[ScheduledJob]: Paginated list of pending scheduled jobs. - """ - jobs = await self._get_pending_jobs(workspace_id) - - return ListResponse( - data=jobs, - has_more=False, - first_id=jobs[0].id if jobs else None, - last_id=jobs[-1].id if jobs else None, - ) - - async def cancel( - self, - workspace_id: WorkspaceId, - job_id: ScheduledJobId, - ) -> None: - """ - Cancel a scheduled job. - - This removes the schedule from APScheduler. Only works for pending jobs. - - Args: - workspace_id (WorkspaceId): The workspace the job belongs to. - job_id (ScheduledJobId): The job ID to cancel. - - Raises: - NotFoundError: If the job is not found or already executed. - """ - logger.info("Canceling scheduled job: %s", job_id) - - schedules: list[Schedule] = await self._scheduler.data_store.get_schedules( - {job_id} - ) - - if not schedules: - msg = f"Scheduled job {job_id} not found" - raise NotFoundError(msg) - - scheduled_job = ScheduledJob.from_schedule(schedules[0]) - if scheduled_job.data.workspace_id != workspace_id: - msg = f"Scheduled job {job_id} not found in workspace {workspace_id}" - raise NotFoundError(msg) - - await self._scheduler.data_store.remove_schedules([job_id]) - logger.info("Scheduled job canceled: %s", job_id) - - async def _get_pending_jobs(self, workspace_id: WorkspaceId) -> list[ScheduledJob]: - """Get pending jobs from APScheduler schedules.""" - scheduled_jobs: list[ScheduledJob] = [] - - schedules: list[Schedule] = await self._scheduler.data_store.get_schedules() - - for schedule in schedules: - scheduled_job = ScheduledJob.from_schedule(schedule) - if scheduled_job.data.workspace_id != workspace_id: - continue - scheduled_jobs.append(scheduled_job) - - return sorted(scheduled_jobs, key=lambda x: x.next_fire_time) diff --git a/src/askui/chat/api/settings.py b/src/askui/chat/api/settings.py deleted file mode 100644 index c2823984..00000000 --- a/src/askui/chat/api/settings.py +++ /dev/null @@ -1,124 +0,0 @@ -from pathlib import Path - -from fastmcp.mcp_config import StdioMCPServer -from pydantic import BaseModel, Field, field_validator -from pydantic_settings import BaseSettings, SettingsConfigDict - -from askui.chat.api.mcp_configs.models import McpConfig, RemoteMCPServer -from askui.chat.api.telemetry.integrations.fastapi.settings import TelemetrySettings -from askui.chat.api.telemetry.logs.settings import LogFilter, LogSettings -from askui.telemetry.otel import OtelSettings -from askui.utils.datetime_utils import now - - -class DbSettings(BaseModel): - """Database configuration settings.""" - - url: str = Field( - default_factory=lambda: f"sqlite:///{(Path.cwd().absolute() / 'askui_chat.db').as_posix()}", - description="Database URL for SQLAlchemy connection (used for all data including scheduler)", - ) - auto_migrate: bool = Field( - default=True, - description="Whether to run migrations automatically on startup", - ) - - @field_validator("url") - @classmethod - def validate_sqlite_url(cls, v: str) -> str: - """Ensure only synchronous SQLite URLs are allowed.""" - if not v.startswith("sqlite://"): - error_msg = ( - "Only synchronous SQLite URLs are allowed (must start with 'sqlite://')" - ) - raise ValueError(error_msg) - return v - - -def _get_default_mcp_configs(chat_api_host: str, chat_api_port: int) -> list[McpConfig]: - return [ - McpConfig( - id="mcpcnf_68ac2c4edc4b2f27faa5a252", - created_at=now(), - name="askui_chat", - mcp_server=RemoteMCPServer( - url=f"http://{chat_api_host}:{chat_api_port}/mcp/sse", - transport="sse", - ), - ), - McpConfig( - id="mcpcnf_68ac2c4edc4b2f27faa5a251", - created_at=now(), - name="playwright", - mcp_server=StdioMCPServer( - command="npx", - args=[ - "@playwright/mcp@latest", - "--isolated", - ], - ), - ), - ] - - -class Settings(BaseSettings): - """Settings for the chat API.""" - - model_config = SettingsConfigDict( - env_prefix="ASKUI__CHAT_API__", env_nested_delimiter="__" - ) - - data_dir: Path = Field( - default_factory=lambda: Path.cwd() / "chat", - description="Base directory for chat data (used during migration)", - ) - db: DbSettings = Field(default_factory=DbSettings) - host: str = Field( - default="127.0.0.1", - description="Host for the chat API", - ) - port: int = Field( - default=9261, - description="Port for the chat API", - ge=1024, - le=65535, - ) - mcp_configs: list[McpConfig] = Field( - default_factory=lambda data: _get_default_mcp_configs( - data["host"], data["port"] - ), - description=( - "Global MCP configurations used to " - "connect to MCP servers shared across all workspaces." - ), - ) - model: str = Field( - default="askui/claude-haiku-4-5-20251001", - description="Default model to use for chat interactions", - ) - allow_origins: list[str] = Field( - default_factory=lambda: [ - "https://app.caesr.ai", - "https://hub.askui.com", - ], - description="CORS allowed origins for the chat API", - ) - telemetry: TelemetrySettings = Field( - default_factory=lambda: TelemetrySettings( - log=LogSettings( - filters=[ - LogFilter(type="equals", key="path", value="/v1/health"), - LogFilter(type="equals", key="path", value="/v1/metrics"), - LogFilter(type="equals", key="method", value="OPTIONS"), - ], - ), - ), - ) - otel: OtelSettings = Field( - default_factory=OtelSettings, - description="OpenTelemetry configuration settings", - ) - enable_io_events: bool = Field( - default=False, - description="Whether to enable the publishing events to stdout", - ) diff --git a/src/askui/chat/api/telemetry/__init__.py b/src/askui/chat/api/telemetry/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/telemetry/integrations/__init__.py b/src/askui/chat/api/telemetry/integrations/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/telemetry/integrations/fastapi/__init__.py b/src/askui/chat/api/telemetry/integrations/fastapi/__init__.py deleted file mode 100644 index f6c98e4a..00000000 --- a/src/askui/chat/api/telemetry/integrations/fastapi/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -from asgi_correlation_id import CorrelationIdMiddleware -from fastapi import FastAPI -from prometheus_fastapi_instrumentator import Instrumentator -from starlette_context.middleware import RawContextMiddleware - -from askui.chat.api.telemetry.integrations.fastapi.settings import TelemetrySettings -from askui.chat.api.telemetry.logs import propagate_logs_up, setup_logging, silence_logs - -from .fastapi_middleware import ( - AccessLoggingMiddleware, - ExceptionHandlingMiddleware, - ProcessTimingMiddleware, - TracingMiddleware, -) -from .structlog_processors import merge_starlette_contextvars - - -def instrument( - app: FastAPI, - settings: TelemetrySettings | None = None, -) -> None: - _settings = settings or TelemetrySettings() - setup_logging( - _settings.log, - pre_processors=[merge_starlette_contextvars], - ) - silence_logs(["uvicorn.access"]) - propagate_logs_up(["uvicorn", "uvicorn.error"]) - app.add_middleware(ExceptionHandlingMiddleware) - app.add_middleware(TracingMiddleware) - app.add_middleware(ProcessTimingMiddleware) - app.add_middleware(AccessLoggingMiddleware) - app.add_middleware(CorrelationIdMiddleware) - app.add_middleware(RawContextMiddleware) - Instrumentator().instrument(app).expose(app, endpoint="/v1/metrics") diff --git a/src/askui/chat/api/telemetry/integrations/fastapi/fastapi_middleware.py b/src/askui/chat/api/telemetry/integrations/fastapi/fastapi_middleware.py deleted file mode 100644 index 825cf082..00000000 --- a/src/askui/chat/api/telemetry/integrations/fastapi/fastapi_middleware.py +++ /dev/null @@ -1,82 +0,0 @@ -import logging -from typing import Awaitable, Callable - -import structlog -from asgi_correlation_id.context import correlation_id -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.requests import Request -from starlette.responses import Response -from starlette.types import ASGIApp - -from askui.chat.api.telemetry.integrations.fastapi.models import AccessLogLine, TimeSpan - -from . import structlog_context -from .utils import compact - -access_logger = structlog.stdlib.get_logger("api.access") -error_logger = structlog.stdlib.get_logger("api.error") - - -EVENT = "API Accessed" - - -class ExceptionHandlingMiddleware(BaseHTTPMiddleware): - async def dispatch( - self, request: Request, call_next: Callable[[Request], Awaitable[Response]] - ) -> Response: - try: - return await call_next(request) - except Exception: # noqa: BLE001 - error_message = "Uncaught exception raised handling request" - error_logger.exception(error_message) - return Response("Internal Server Error", status_code=500) - - -class TracingMiddleware(BaseHTTPMiddleware): - async def dispatch( - self, request: Request, call_next: Callable[[Request], Awaitable[Response]] - ) -> Response: - request_id = correlation_id.get() - structlog_context.bind(request_id=request_id) - return await call_next(request) - - -class ProcessTimingMiddleware(BaseHTTPMiddleware): - async def dispatch( - self, request: Request, call_next: Callable[[Request], Awaitable[Response]] - ) -> Response: - time_span = TimeSpan() - - response = await call_next(request) - time_span.end() - response.headers.append("x-process-time", str(time_span.in_s)) - structlog_context.bind(time_ms=time_span.in_ms) - return response - - -class AccessLoggingMiddleware(BaseHTTPMiddleware): - def __init__(self, app: ASGIApp): - super().__init__(app) - - def determine_log_level(self, request: Request) -> int: # noqa: ARG002 - return logging.INFO - - async def dispatch( - self, request: Request, call_next: Callable[[Request], Awaitable[Response]] - ) -> Response: - response = await call_next(request) - access_log_line = AccessLogLine( - level=self.determine_log_level(request), - event=EVENT, - method=request.method, - path=request.url.path, - query=request.url.query, - status=response.status_code, - http_version=request.scope["http_version"], - ip=request.client.host if request.client else None, - port=request.client.port if request.client else None, - ) - await access_logger.alog( - **compact({**access_log_line, **structlog_context.get()}) - ) - return response diff --git a/src/askui/chat/api/telemetry/integrations/fastapi/models.py b/src/askui/chat/api/telemetry/integrations/fastapi/models.py deleted file mode 100644 index 003615cd..00000000 --- a/src/askui/chat/api/telemetry/integrations/fastapi/models.py +++ /dev/null @@ -1,43 +0,0 @@ -import time -from typing import Optional, TypedDict - - -class AccessLogLine(TypedDict): - level: int - event: str - method: str - path: str - query: Optional[str] - status: int - http_version: str - ip: Optional[str] - port: Optional[int] - - -class TimeSpanData(TypedDict, total=False): - started_at: int - ended_at: Optional[int] - - -class TimeSpan: - def __init__(self) -> None: - self.started_at: int = time.perf_counter_ns() - self.ended_at: Optional[int] = None - - def end(self) -> None: - self.ended_at = time.perf_counter_ns() - - @property - def in_ns(self) -> Optional[int]: - if self.ended_at is None: - return None - - return self.ended_at - self.started_at - - @property - def in_ms(self) -> Optional[float]: - return self.in_ns / 10**6 if self.in_ns is not None else None - - @property - def in_s(self) -> Optional[float]: - return self.in_ns / 10**9 if self.in_ns is not None else None diff --git a/src/askui/chat/api/telemetry/integrations/fastapi/settings.py b/src/askui/chat/api/telemetry/integrations/fastapi/settings.py deleted file mode 100644 index 7ac3e27d..00000000 --- a/src/askui/chat/api/telemetry/integrations/fastapi/settings.py +++ /dev/null @@ -1,7 +0,0 @@ -from pydantic import BaseModel, Field - -from askui.chat.api.telemetry.logs.settings import LogSettings - - -class TelemetrySettings(BaseModel): - log: LogSettings = Field(default_factory=LogSettings) diff --git a/src/askui/chat/api/telemetry/integrations/fastapi/structlog_context.py b/src/askui/chat/api/telemetry/integrations/fastapi/structlog_context.py deleted file mode 100644 index c33a9a08..00000000 --- a/src/askui/chat/api/telemetry/integrations/fastapi/structlog_context.py +++ /dev/null @@ -1,24 +0,0 @@ -from copy import deepcopy -from typing import Any - -from starlette_context import context - -STRUCTLOG_REQUEST_CONTEXT_KEY = "structlog_context" - - -def is_available() -> bool: - return context.exists() - - -def get() -> dict[str, Any]: - return deepcopy(context.get(STRUCTLOG_REQUEST_CONTEXT_KEY, {})) - - -def bind(**kw: Any) -> None: - new_context = get() - new_context.update(kw) - context[STRUCTLOG_REQUEST_CONTEXT_KEY] = new_context - - -def reset() -> None: - context[STRUCTLOG_REQUEST_CONTEXT_KEY] = {} diff --git a/src/askui/chat/api/telemetry/integrations/fastapi/structlog_processors.py b/src/askui/chat/api/telemetry/integrations/fastapi/structlog_processors.py deleted file mode 100644 index d4212aaf..00000000 --- a/src/askui/chat/api/telemetry/integrations/fastapi/structlog_processors.py +++ /dev/null @@ -1,22 +0,0 @@ -import logging - -from structlog.types import EventDict - -from . import structlog_context - - -def merge_starlette_contextvars( - logger: logging.Logger, # noqa: ARG001 - method_name: str, # noqa: ARG001 - event_dict: EventDict, -) -> EventDict: - """ - Merges the starlette contextvars into the structlog contextvars. - """ - - if structlog_context.is_available(): - return { - **event_dict, - **structlog_context.get(), - } - return event_dict diff --git a/src/askui/chat/api/telemetry/integrations/fastapi/utils.py b/src/askui/chat/api/telemetry/integrations/fastapi/utils.py deleted file mode 100644 index 2dba0c91..00000000 --- a/src/askui/chat/api/telemetry/integrations/fastapi/utils.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Any - - -def compact(d: dict[Any, Any]) -> dict[Any, Any]: - result = {} - for k, v in d.items(): - if isinstance(v, dict): - v = compact(v) - if not (not v and type(v) not in (bool, int, float, complex)): - result[k] = v - return result diff --git a/src/askui/chat/api/telemetry/logs/__init__.py b/src/askui/chat/api/telemetry/logs/__init__.py deleted file mode 100644 index 739ec2d8..00000000 --- a/src/askui/chat/api/telemetry/logs/__init__.py +++ /dev/null @@ -1,64 +0,0 @@ -import logging as logging_stdlib -import sys -from types import TracebackType - -from structlog import types as structlog_types - -from .settings import LogSettings -from .structlog import setup_structlog - -logger = logging_stdlib.getLogger(__name__) - - -def setup_uncaught_exception_logging(logger: logging_stdlib.Logger) -> None: - def handle_uncaught_exception( - exc_type: type[BaseException], - exc_value: BaseException, - exc_traceback: TracebackType | None, - ) -> None: - """ - Log any uncaught exception instead of letting it be printed by Python - (but leave KeyboardInterrupt untouched to allow users to Ctrl+C to stop) - See https://stackoverflow.com/a/16993115/3641865 - """ - if issubclass(exc_type, KeyboardInterrupt): - sys.__excepthook__(exc_type, exc_value, exc_traceback) - return - - logger.error( - "Uncaught exception raised", exc_info=(exc_type, exc_value, exc_traceback) - ) - - sys.excepthook = handle_uncaught_exception - - -def propagate_logs_up(loggers: list[str]) -> None: - for logger_name in loggers: - logger = logging_stdlib.getLogger(logger_name) - logger.handlers.clear() - logger.propagate = True - - -def silence_logs(loggers: list[str]) -> None: - for logger_name in loggers: - logger = logging_stdlib.getLogger(logger_name) - logger.handlers.clear() - logger.propagate = False - - -_logging_setup = False - - -def setup_logging( - settings: LogSettings, - pre_processors: list[structlog_types.Processor] | None = None, -) -> None: - global _logging_setup - if _logging_setup: - logger.debug("Logging already setup. Skipping setup...") - return - logging_stdlib.captureWarnings(True) - root_logger = logging_stdlib.getLogger() - setup_structlog(root_logger, settings, pre_processors) - setup_uncaught_exception_logging(root_logger) - _logging_setup = True diff --git a/src/askui/chat/api/telemetry/logs/settings.py b/src/askui/chat/api/telemetry/logs/settings.py deleted file mode 100644 index 81b2386f..00000000 --- a/src/askui/chat/api/telemetry/logs/settings.py +++ /dev/null @@ -1,24 +0,0 @@ -import logging -from typing import Literal - -from pydantic import BaseModel, Field - -logger = logging.getLogger(__name__) - -LogFormat = Literal["JSON", "LOGFMT"] -LogLevel = Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] - - -class EqualsLogFilter(BaseModel): - type: Literal["equals"] - key: str - value: str - - -LogFilter = EqualsLogFilter - - -class LogSettings(BaseModel): - format: LogFormat = Field("LOGFMT") - level: LogLevel = Field("INFO") - filters: list[LogFilter] | None = None diff --git a/src/askui/chat/api/telemetry/logs/structlog.py b/src/askui/chat/api/telemetry/logs/structlog.py deleted file mode 100644 index 20ddae48..00000000 --- a/src/askui/chat/api/telemetry/logs/structlog.py +++ /dev/null @@ -1,85 +0,0 @@ -import logging - -import structlog -from structlog.dev import plain_traceback - -from .settings import LogFormat, LogLevel, LogSettings -from .structlog_processors import ( - create_filter_processor, - drop_color_message_key_processor, - flatten_dict_processor, -) - - -def setup_structlog( - root_logger: logging.Logger, - settings: LogSettings, - pre_processors: list[structlog.types.Processor] | None = None, -) -> None: - shared_processors = (pre_processors or []) + get_shared_processors(settings) - structlog.configure( - processors=shared_processors - + [structlog.stdlib.ProcessorFormatter.wrap_for_formatter], - logger_factory=structlog.stdlib.LoggerFactory(), - cache_logger_on_first_use=True, - ) - formatter = structlog.stdlib.ProcessorFormatter( - foreign_pre_chain=shared_processors, - processors=[ - structlog.stdlib.ProcessorFormatter.remove_processors_meta, - get_renderer(settings.format), - ], - ) - configure_stdlib_logger(root_logger, settings.level, formatter) - - -def configure_stdlib_logger( - logger: logging.Logger, log_level: LogLevel, formatter: logging.Formatter -) -> None: - handler = logging.StreamHandler() - handler.setFormatter(formatter) - logger.addHandler(handler) - logger.setLevel(log_level) - - -EVENT_KEY = "message" - - -def get_shared_processors(settings: LogSettings) -> list[structlog.types.Processor]: - """Returns a list of processors, i.e., a processor chain, that can be shared between - structlog and stdlib loggers so that their content is consistent.""" - format_dependent_processors = get_format_dependent_processors(settings.format) - filter_processor = create_filter_processor(settings.filters) - return [ - structlog.contextvars.merge_contextvars, - structlog.stdlib.add_logger_name, - structlog.stdlib.add_log_level, - structlog.stdlib.PositionalArgumentsFormatter(), - structlog.stdlib.ExtraAdder(), - drop_color_message_key_processor, - structlog.processors.TimeStamper(fmt="iso"), - structlog.processors.StackInfoRenderer(), - *format_dependent_processors, - structlog.processors.EventRenamer(EVENT_KEY), - filter_processor, - ] - - -def get_format_dependent_processors( - log_format: LogFormat, -) -> list[structlog.types.Processor]: - if log_format == "JSON": - return [structlog.processors.format_exc_info] - return [ - structlog.dev.set_exc_info, - flatten_dict_processor, - ] - - -def get_renderer(log_format: LogFormat) -> structlog.types.Processor: - if log_format == "JSON": - return structlog.processors.JSONRenderer() - return structlog.dev.ConsoleRenderer( - event_key=EVENT_KEY, - exception_formatter=plain_traceback, - ) diff --git a/src/askui/chat/api/telemetry/logs/structlog_processors.py b/src/askui/chat/api/telemetry/logs/structlog_processors.py deleted file mode 100644 index eea83ced..00000000 --- a/src/askui/chat/api/telemetry/logs/structlog_processors.py +++ /dev/null @@ -1,78 +0,0 @@ -import logging - -from structlog import DropEvent -from structlog.types import EventDict, Processor - -from askui.chat.api.telemetry.logs.settings import LogFilter - -from .utils import flatten_dict - - -def flatten_dict_processor( - logger: logging.Logger, # noqa: ARG001 - method_name: str, # noqa: ARG001 - event_dict: EventDict, -) -> EventDict: - """ - Flattens a nested event dictionary deeply. Nested keys are concatenated with dot notation. - """ - return flatten_dict(event_dict) - - -def drop_color_message_key_processor( - logger: logging.Logger, # noqa: ARG001 - method_name: str, # noqa: ARG001 - event_dict: EventDict, -) -> EventDict: - """ - Uvicorn logs the message a second time in the extra `color_message`, but we don't - need it. This processor drops the key from the event dict if it exists. - """ - event_dict.pop("color_message", None) - return event_dict - - -def null_processor( - logger: logging.Logger, # noqa: ARG001 - method_name: str, # noqa: ARG001 - event_dict: EventDict, -) -> EventDict: - """ - A processor that does nothing. - """ - return event_dict - - -def create_filter_processor(filters: list[LogFilter] | None) -> Processor: - """ - Creates a structlog processor that filters out log lines based on field matches. - - Args: - filters (dict[str, Any] | None): Dictionary of field names to values to filter out. - If a log line has a field with a matching value, it will be filtered out. - - Returns: - A structlog processor function that filters log lines. - """ - if not filters: - return null_processor - - def filter_processor( - logger: logging.Logger, # noqa: ARG001 - method_name: str, # noqa: ARG001 - event_dict: EventDict, - ) -> EventDict: - """ - Filters out log lines where any field matches the filter values. - Returns None to drop the log line, or the event_dict to keep it. - """ - for filter_ in filters: - if filter_.type == "equals": - if ( - filter_.key in event_dict - and event_dict[filter_.key] == filter_.value - ): - raise DropEvent - return event_dict - - return filter_processor diff --git a/src/askui/chat/api/telemetry/logs/utils.py b/src/askui/chat/api/telemetry/logs/utils.py deleted file mode 100644 index c978902a..00000000 --- a/src/askui/chat/api/telemetry/logs/utils.py +++ /dev/null @@ -1,16 +0,0 @@ -from collections.abc import MutableMapping -from typing import Any - - -def flatten_dict( - d: MutableMapping[Any, Any], parent_key: str = "", sep: str = "." -) -> MutableMapping[str, Any]: - result: list[tuple[str, Any]] = [] - for k, v in d.items(): - k = str(k) - new_key = parent_key + sep + k if parent_key else k - if isinstance(v, MutableMapping): - result.extend(flatten_dict(v, new_key, sep=sep).items()) - else: - result.append((new_key, v)) - return dict(result) diff --git a/src/askui/chat/api/threads/__init__.py b/src/askui/chat/api/threads/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/threads/dependencies.py b/src/askui/chat/api/threads/dependencies.py deleted file mode 100644 index 9eef2483..00000000 --- a/src/askui/chat/api/threads/dependencies.py +++ /dev/null @@ -1,30 +0,0 @@ -from fastapi import Depends - -from askui.chat.api.db.session import SessionDep -from askui.chat.api.runs.dependencies import RunServiceDep -from askui.chat.api.runs.service import RunService -from askui.chat.api.threads.facade import ThreadFacade -from askui.chat.api.threads.service import ThreadService - - -def get_thread_service( - session: SessionDep, -) -> ThreadService: - """Get ThreadService instance.""" - return ThreadService(session=session) - - -ThreadServiceDep = Depends(get_thread_service) - - -def get_thread_facade( - thread_service: ThreadService = ThreadServiceDep, - run_service: RunService = RunServiceDep, -) -> ThreadFacade: - return ThreadFacade( - thread_service=thread_service, - run_service=run_service, - ) - - -ThreadFacadeDep = Depends(get_thread_facade) diff --git a/src/askui/chat/api/threads/facade.py b/src/askui/chat/api/threads/facade.py deleted file mode 100644 index 4cb6f705..00000000 --- a/src/askui/chat/api/threads/facade.py +++ /dev/null @@ -1,32 +0,0 @@ -from collections.abc import AsyncGenerator - -from askui.chat.api.models import WorkspaceId -from askui.chat.api.runs.events.events import Event -from askui.chat.api.runs.models import Run, ThreadAndRunCreate -from askui.chat.api.runs.service import RunService -from askui.chat.api.threads.service import ThreadService - - -class ThreadFacade: - """ - Facade service that coordinates operations across threads, messages, and runs. - """ - - def __init__( - self, - thread_service: ThreadService, - run_service: RunService, - ) -> None: - self._thread_service = thread_service - self._run_service = run_service - - async def create_thread_and_run( - self, workspace_id: WorkspaceId, params: ThreadAndRunCreate - ) -> tuple[Run, AsyncGenerator[Event, None]]: - """Create a thread and a run, ensuring the thread exists first.""" - thread = self._thread_service.create(workspace_id, params.thread) - return await self._run_service.create( - workspace_id=workspace_id, - thread_id=thread.id, - params=params, - ) diff --git a/src/askui/chat/api/threads/models.py b/src/askui/chat/api/threads/models.py deleted file mode 100644 index fea45552..00000000 --- a/src/askui/chat/api/threads/models.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel - -from askui.chat.api.messages.models import MessageCreate -from askui.chat.api.models import ThreadId, WorkspaceId, WorkspaceResource -from askui.utils.datetime_utils import UnixDatetime, now -from askui.utils.id_utils import generate_time_ordered_id -from askui.utils.not_given import NOT_GIVEN, BaseModelWithNotGiven, NotGiven - - -class ThreadBase(BaseModel): - """Base thread model.""" - - name: str | None = None - - -class ThreadCreate(ThreadBase): - """Parameters for creating a thread.""" - - messages: list[MessageCreate] | None = None - - -class ThreadModify(BaseModelWithNotGiven): - """Parameters for modifying a thread.""" - - name: str | None | NotGiven = NOT_GIVEN - - -class Thread(ThreadBase, WorkspaceResource): - """A chat thread/session.""" - - id: ThreadId - object: Literal["thread"] = "thread" - created_at: UnixDatetime - - @classmethod - def create(cls, workspace_id: WorkspaceId, params: ThreadCreate) -> "Thread": - return cls( - id=generate_time_ordered_id("thread"), - created_at=now(), - workspace_id=workspace_id, - **params.model_dump(exclude={"messages"}), - ) diff --git a/src/askui/chat/api/threads/orms.py b/src/askui/chat/api/threads/orms.py deleted file mode 100644 index 9e2b3cfc..00000000 --- a/src/askui/chat/api/threads/orms.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Thread database model.""" - -from datetime import datetime -from uuid import UUID - -from sqlalchemy import String, Uuid -from sqlalchemy.orm import Mapped, mapped_column - -from askui.chat.api.db.orm.base import Base -from askui.chat.api.db.orm.types import UnixDatetime, create_prefixed_id_type -from askui.chat.api.threads.models import Thread - -ThreadId = create_prefixed_id_type("thread") - - -class ThreadOrm(Base): - """Thread database model.""" - - __tablename__ = "threads" - - id: Mapped[str] = mapped_column(ThreadId, primary_key=True) - workspace_id: Mapped[UUID] = mapped_column(Uuid, nullable=False, index=True) - created_at: Mapped[datetime] = mapped_column(UnixDatetime, nullable=False) - name: Mapped[str | None] = mapped_column(String, nullable=True) - - @classmethod - def from_model(cls, model: Thread) -> "ThreadOrm": - return cls(**model.model_dump(exclude={"object"})) - - def to_model(self) -> Thread: - return Thread.model_validate(self, from_attributes=True) diff --git a/src/askui/chat/api/threads/router.py b/src/askui/chat/api/threads/router.py deleted file mode 100644 index 5ea65319..00000000 --- a/src/askui/chat/api/threads/router.py +++ /dev/null @@ -1,60 +0,0 @@ -from typing import Annotated - -from fastapi import APIRouter, Header, status - -from askui.chat.api.dependencies import ListQueryDep -from askui.chat.api.models import ThreadId, WorkspaceId -from askui.chat.api.threads.dependencies import ThreadServiceDep -from askui.chat.api.threads.models import Thread, ThreadCreate, ThreadModify -from askui.chat.api.threads.service import ThreadService -from askui.utils.api_utils import ListQuery, ListResponse - -router = APIRouter(prefix="/threads", tags=["threads"]) - - -@router.get("") -def list_threads( - askui_workspace: Annotated[WorkspaceId, Header()], - query: ListQuery = ListQueryDep, - thread_service: ThreadService = ThreadServiceDep, -) -> ListResponse[Thread]: - return thread_service.list_(workspace_id=askui_workspace, query=query) - - -@router.post("", status_code=status.HTTP_201_CREATED) -def create_thread( - askui_workspace: Annotated[WorkspaceId, Header()], - params: ThreadCreate, - thread_service: ThreadService = ThreadServiceDep, -) -> Thread: - return thread_service.create(workspace_id=askui_workspace, params=params) - - -@router.get("/{thread_id}") -def retrieve_thread( - askui_workspace: Annotated[WorkspaceId, Header()], - thread_id: ThreadId, - thread_service: ThreadService = ThreadServiceDep, -) -> Thread: - return thread_service.retrieve(workspace_id=askui_workspace, thread_id=thread_id) - - -@router.post("/{thread_id}") -def modify_thread( - askui_workspace: Annotated[WorkspaceId, Header()], - thread_id: ThreadId, - params: ThreadModify, - thread_service: ThreadService = ThreadServiceDep, -) -> Thread: - return thread_service.modify( - workspace_id=askui_workspace, thread_id=thread_id, params=params - ) - - -@router.delete("/{thread_id}", status_code=status.HTTP_204_NO_CONTENT) -def delete_thread( - askui_workspace: Annotated[WorkspaceId, Header()], - thread_id: ThreadId, - thread_service: ThreadService = ThreadServiceDep, -) -> None: - thread_service.delete(workspace_id=askui_workspace, thread_id=thread_id) diff --git a/src/askui/chat/api/threads/service.py b/src/askui/chat/api/threads/service.py deleted file mode 100644 index 413bea1f..00000000 --- a/src/askui/chat/api/threads/service.py +++ /dev/null @@ -1,75 +0,0 @@ -from sqlalchemy.orm import Session - -from askui.chat.api.db.queries import list_all -from askui.chat.api.models import ThreadId, WorkspaceId -from askui.chat.api.threads.models import Thread, ThreadCreate, ThreadModify -from askui.chat.api.threads.orms import ThreadOrm -from askui.utils.api_utils import ListQuery, ListResponse, NotFoundError - - -class ThreadService: - """Service for managing Thread resources with database persistence.""" - - def __init__(self, session: Session) -> None: - self._session = session - - def _find_by_id(self, workspace_id: WorkspaceId, thread_id: ThreadId) -> ThreadOrm: - """Find thread by ID.""" - thread_orm: ThreadOrm | None = ( - self._session.query(ThreadOrm) - .filter( - ThreadOrm.id == thread_id, - ThreadOrm.workspace_id == workspace_id, - ) - .first() - ) - if thread_orm is None: - error_msg = f"Thread {thread_id} not found" - raise NotFoundError(error_msg) - return thread_orm - - def list_( - self, workspace_id: WorkspaceId, query: ListQuery - ) -> ListResponse[Thread]: - """List threads with pagination and filtering.""" - q = self._session.query(ThreadOrm).filter( - ThreadOrm.workspace_id == workspace_id - ) - orms: list[ThreadOrm] - orms, has_more = list_all(q, query, ThreadOrm.id) - data = [orm.to_model() for orm in orms] - return ListResponse( - data=data, - has_more=has_more, - first_id=data[0].id if data else None, - last_id=data[-1].id if data else None, - ) - - def retrieve(self, workspace_id: WorkspaceId, thread_id: ThreadId) -> Thread: - """Retrieve thread by ID.""" - thread_orm = self._find_by_id(workspace_id, thread_id) - return thread_orm.to_model() - - def create(self, workspace_id: WorkspaceId, params: ThreadCreate) -> Thread: - """Create a new thread.""" - thread = Thread.create(workspace_id, params) - thread_orm = ThreadOrm.from_model(thread) - self._session.add(thread_orm) - self._session.commit() - return thread - - def modify( - self, workspace_id: WorkspaceId, thread_id: ThreadId, params: ThreadModify - ) -> Thread: - """Modify an existing thread.""" - thread_orm = self._find_by_id(workspace_id, thread_id) - thread_orm.update(params.model_dump()) - self._session.commit() - self._session.refresh(thread_orm) - return thread_orm.to_model() - - def delete(self, workspace_id: WorkspaceId, thread_id: ThreadId) -> None: - """Delete a thread and cascade to messages and runs.""" - thread_orm = self._find_by_id(workspace_id, thread_id) - self._session.delete(thread_orm) - self._session.commit() diff --git a/src/askui/chat/api/utils.py b/src/askui/chat/api/utils.py deleted file mode 100644 index caf1586a..00000000 --- a/src/askui/chat/api/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Callable, Type - -from askui.chat.api.models import WorkspaceId, WorkspaceResourceT - - -def build_workspace_filter_fn( - workspace: WorkspaceId | None, - resource_type: Type[WorkspaceResourceT], # noqa: ARG001 -) -> Callable[[WorkspaceResourceT], bool]: - def filter_fn(resource: WorkspaceResourceT) -> bool: - return resource.workspace_id is None or resource.workspace_id == workspace - - return filter_fn diff --git a/src/askui/chat/api/workflows/__init__.py b/src/askui/chat/api/workflows/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/workflows/dependencies.py b/src/askui/chat/api/workflows/dependencies.py deleted file mode 100644 index 79dbdb66..00000000 --- a/src/askui/chat/api/workflows/dependencies.py +++ /dev/null @@ -1,13 +0,0 @@ -from fastapi import Depends - -from askui.chat.api.dependencies import SettingsDep -from askui.chat.api.settings import Settings -from askui.chat.api.workflows.service import WorkflowService - - -def get_workflow_service(settings: Settings = SettingsDep) -> WorkflowService: - """Get WorkflowService instance.""" - return WorkflowService(settings.data_dir) - - -WorkflowServiceDep = Depends(get_workflow_service) diff --git a/src/askui/chat/api/workflows/models.py b/src/askui/chat/api/workflows/models.py deleted file mode 100644 index daf984bb..00000000 --- a/src/askui/chat/api/workflows/models.py +++ /dev/null @@ -1,73 +0,0 @@ -from typing import Annotated, Literal - -from pydantic import BaseModel, Field - -from askui.chat.api.models import WorkspaceId, WorkspaceResource -from askui.utils.datetime_utils import UnixDatetime, now -from askui.utils.id_utils import IdField, generate_time_ordered_id - -WorkflowId = Annotated[str, IdField("wf")] - - -class WorkflowCreateParams(BaseModel): - """ - Parameters for creating a workflow via API. - """ - - name: str - description: str - tags: list[str] = Field(default_factory=list) - - -class WorkflowModifyParams(BaseModel): - """ - Parameters for modifying a workflow via API. - """ - - name: str | None = None - description: str | None = None - tags: list[str] | None = None - - -class Workflow(WorkspaceResource): - """ - A workflow resource in the chat API. - - Args: - id (WorkflowId): The id of the workflow. Must start with the 'wf_' prefix and be - followed by one or more alphanumerical characters. - object (Literal['workflow']): The object type, always 'workflow'. - created_at (UnixDatetime): The creation time as a Unix timestamp. - name (str): The name or title of the workflow. - description (str): A detailed description of the workflow's purpose and steps. - tags (list[str], optional): Tags associated with the workflow for filtering or - categorization. Default is an empty list. - workspace_id (WorkspaceId | None, optional): The workspace this workflow belongs to. - """ - - id: WorkflowId - object: Literal["workflow"] = "workflow" - created_at: UnixDatetime - name: str - description: str - tags: list[str] = Field(default_factory=list) - - @classmethod - def create( - cls, workspace_id: WorkspaceId | None, params: WorkflowCreateParams - ) -> "Workflow": - return cls( - id=generate_time_ordered_id("wf"), - created_at=now(), - workspace_id=workspace_id, - **params.model_dump(), - ) - - def modify(self, params: WorkflowModifyParams) -> "Workflow": - update_data = {k: v for k, v in params.model_dump().items() if v is not None} - return Workflow.model_validate( - { - **self.model_dump(), - **update_data, - } - ) diff --git a/src/askui/chat/api/workflows/router.py b/src/askui/chat/api/workflows/router.py deleted file mode 100644 index c0546db7..00000000 --- a/src/askui/chat/api/workflows/router.py +++ /dev/null @@ -1,111 +0,0 @@ -from typing import Annotated - -from fastapi import APIRouter, Header, Path, Query, status - -from askui.chat.api.dependencies import ListQueryDep -from askui.chat.api.models import WorkspaceId -from askui.chat.api.workflows.dependencies import WorkflowServiceDep -from askui.chat.api.workflows.models import ( - Workflow, - WorkflowCreateParams, - WorkflowId, - WorkflowModifyParams, -) -from askui.chat.api.workflows.service import WorkflowService -from askui.utils.api_utils import ListQuery, ListResponse - -router = APIRouter(prefix="/workflows", tags=["workflows"]) - - -@router.get("") -def list_workflows( - askui_workspace: Annotated[WorkspaceId | None, Header()], - tags: Annotated[list[str] | None, Query()] = None, - query: ListQuery = ListQueryDep, - workflow_service: WorkflowService = WorkflowServiceDep, -) -> ListResponse[Workflow]: - """ - List workflows with optional tag filtering. - - Args: - askui_workspace: The workspace ID from header - tags: Optional list of tags to filter by - query: Standard list query parameters (limit, after, before, order) - workflow_service: Injected workflow service - - Returns: - ListResponse containing workflows matching the criteria - """ - return workflow_service.list_(workspace_id=askui_workspace, query=query, tags=tags) - - -@router.post("", status_code=status.HTTP_201_CREATED) -def create_workflow( - askui_workspace: Annotated[WorkspaceId | None, Header()], - params: WorkflowCreateParams, - workflow_service: WorkflowService = WorkflowServiceDep, -) -> Workflow: - """ - Create a new workflow. - - Args: - askui_workspace: The workspace ID from header - params: Workflow creation parameters (name, description, tags) - workflow_service: Injected workflow service - - Returns: - The created workflow - """ - return workflow_service.create(workspace_id=askui_workspace, params=params) - - -@router.get("/{workflow_id}") -def retrieve_workflow( - askui_workspace: Annotated[WorkspaceId | None, Header()], - workflow_id: Annotated[WorkflowId, Path(...)], - workflow_service: WorkflowService = WorkflowServiceDep, -) -> Workflow: - """ - Retrieve a specific workflow by ID. - - Args: - askui_workspace: The workspace ID from header - workflow_id: The workflow ID to retrieve - workflow_service: Injected workflow service - - Returns: - The requested workflow - - Raises: - NotFoundError: If workflow doesn't exist or user doesn't have access - """ - return workflow_service.retrieve( - workspace_id=askui_workspace, workflow_id=workflow_id - ) - - -@router.patch("/{workflow_id}") -def modify_workflow( - askui_workspace: Annotated[WorkspaceId | None, Header()], - workflow_id: Annotated[WorkflowId, Path(...)], - params: WorkflowModifyParams, - workflow_service: WorkflowService = WorkflowServiceDep, -) -> Workflow: - """ - Modify an existing workflow. - - Args: - askui_workspace: The workspace ID from header - workflow_id: The workflow ID to modify - params: Workflow modification parameters (name, description, tags) - workflow_service: Injected workflow service - - Returns: - The modified workflow - - Raises: - NotFoundError: If workflow doesn't exist or user doesn't have access - """ - return workflow_service.modify( - workspace_id=askui_workspace, workflow_id=workflow_id, params=params - ) diff --git a/src/askui/chat/api/workflows/service.py b/src/askui/chat/api/workflows/service.py deleted file mode 100644 index 31e6862d..00000000 --- a/src/askui/chat/api/workflows/service.py +++ /dev/null @@ -1,109 +0,0 @@ -from pathlib import Path -from typing import Callable - -from askui.chat.api.models import WorkspaceId -from askui.chat.api.utils import build_workspace_filter_fn -from askui.chat.api.workflows.models import ( - Workflow, - WorkflowCreateParams, - WorkflowId, - WorkflowModifyParams, -) -from askui.utils.api_utils import ( - ConflictError, - ListQuery, - ListResponse, - NotFoundError, - list_resources, -) - - -def _build_workflow_filter_fn( - workspace_id: WorkspaceId | None, - tags: list[str] | None = None, -) -> Callable[[Workflow], bool]: - workspace_filter: Callable[[Workflow], bool] = build_workspace_filter_fn( - workspace_id, Workflow - ) - - def filter_fn(workflow: Workflow) -> bool: - if not workspace_filter(workflow): - return False - if tags is not None: - return any(tag in workflow.tags for tag in tags) - return True - - return filter_fn - - -class WorkflowService: - def __init__(self, base_dir: Path) -> None: - self._base_dir = base_dir - self._workflows_dir = base_dir / "workflows" - - def _get_workflow_path(self, workflow_id: WorkflowId, new: bool = False) -> Path: - workflow_path = self._workflows_dir / f"{workflow_id}.json" - exists = workflow_path.exists() - if new and exists: - error_msg = f"Workflow {workflow_id} already exists" - raise ConflictError(error_msg) - if not new and not exists: - error_msg = f"Workflow {workflow_id} not found" - raise NotFoundError(error_msg) - return workflow_path - - def list_( - self, - workspace_id: WorkspaceId | None, - query: ListQuery, - tags: list[str] | None = None, - ) -> ListResponse[Workflow]: - return list_resources( - base_dir=self._workflows_dir, - query=query, - resource_type=Workflow, - filter_fn=_build_workflow_filter_fn(workspace_id, tags=tags), - ) - - def retrieve( - self, workspace_id: WorkspaceId | None, workflow_id: WorkflowId - ) -> Workflow: - try: - workflow_path = self._get_workflow_path(workflow_id) - workflow = Workflow.model_validate_json( - workflow_path.read_text(encoding="utf-8") - ) - - # Check workspace access - if workspace_id is not None and workflow.workspace_id != workspace_id: - error_msg = f"Workflow {workflow_id} not found" - raise NotFoundError(error_msg) - - except FileNotFoundError as e: - error_msg = f"Workflow {workflow_id} not found" - raise NotFoundError(error_msg) from e - else: - return workflow - - def create( - self, workspace_id: WorkspaceId | None, params: WorkflowCreateParams - ) -> Workflow: - workflow = Workflow.create(workspace_id, params) - self._save(workflow, new=True) - return workflow - - def modify( - self, - workspace_id: WorkspaceId | None, - workflow_id: WorkflowId, - params: WorkflowModifyParams, - ) -> Workflow: - workflow = self.retrieve(workspace_id, workflow_id) - modified = workflow.modify(params) - self._save(modified) - return modified - - def _save(self, workflow: Workflow, new: bool = False) -> None: - self._workflows_dir.mkdir(parents=True, exist_ok=True) - workflow_file = self._get_workflow_path(workflow.id, new=new) - workflow_file.write_text(workflow.model_dump_json(), encoding="utf-8") diff --git a/src/askui/chat/migrations/__init__.py b/src/askui/chat/migrations/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/migrations/alembic.ini b/src/askui/chat/migrations/alembic.ini deleted file mode 100644 index a035c7ac..00000000 --- a/src/askui/chat/migrations/alembic.ini +++ /dev/null @@ -1,114 +0,0 @@ -# A generic, single database configuration. - -[alembic] -# path to migration scripts -script_location = src/askui/chat/migrations - -# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s -# Uncomment the line below if you want the files to be prepended with date and time -# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s - -# sys.path path, will be prepended to sys.path if present. -# defaults to the current working directory. -prepend_sys_path = . - -# timezone to use when rendering the date within the migration file -# as well as the filename. -# If specified, requires the python-dateutil library that can be -# installed by adding `alembic[tz]` to the pip requirements -# string value is passed to dateutil.tz.gettz() -# leave blank for localtime -# timezone = - -# max length of characters to apply to the -# "slug" field -# truncate_slug_length = 40 - -# set to 'true' to run the environment during -# the 'revision' command, regardless of autogenerate -# revision_environment = false - -# set to 'true' to allow .pyc and .pyo files without -# a source .py file to be detected as revisions in the -# versions/ directory -# sourceless = false - -# version number format -version_num_format = %%04d - -# version path separator; As mentioned above, this is the character used to split -# version_locations. The default within new alembic.ini files is "os", which uses -# os.pathsep. If this key is omitted entirely, it falls back to the legacy -# behavior of splitting on spaces and/or commas. -# Valid values for version_path_separator are: -# -# version_path_separator = : -# version_path_separator = ; -# version_path_separator = space -version_path_separator = os - -# set to 'true' to search source files recursively -# in each "version_locations" directory -# new in Alembic version 1.10 -# recursive_version_locations = false - -# the output encoding used when revision files -# are written from script.py.mako -# output_encoding = utf-8 - -# overridden inside env.py -sqlalchemy.url = driver://user:pass@localhost/dbname - - -[post_write_hooks] -# post_write_hooks defines scripts or Python functions that are run -# on newly generated revision scripts. See the documentation for further -# detail and examples - -# format using "black" - use the console_scripts runner, against the "black" entrypoint -# hooks = black -# black.type = console_scripts -# black.entrypoint = black -# black.options = -l 79 REVISION_SCRIPT_FILENAME - -# lint with attempts to fix using "ruff" - use the exec runner, execute a binary -# hooks = ruff -# ruff.type = exec -# ruff.executable = %(here)s/.venv/bin/ruff -# ruff.options = --fix REVISION_SCRIPT_FILENAME - -# Logging configuration -# Overridden by env.py -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = INFO -handlers = -qualname = - -[logger_sqlalchemy] -level = INFO -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S diff --git a/src/askui/chat/migrations/env.py b/src/askui/chat/migrations/env.py deleted file mode 100644 index 197b0513..00000000 --- a/src/askui/chat/migrations/env.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Alembic environment configuration.""" - -import logging - -from alembic import context - -# We need to import the orms to ensure they are registered -import askui.chat.api.assistants.orms -import askui.chat.api.files.orms -import askui.chat.api.mcp_configs.orms -import askui.chat.api.messages.orms -import askui.chat.api.runs.orms -import askui.chat.api.threads.orms -from askui.chat.api.db.orm.base import Base -from askui.chat.api.dependencies import get_settings -from askui.chat.api.telemetry.logs import setup_logging - -config = context.config -settings = get_settings() -setup_logging(settings.telemetry.log) -sqlalchemy_logger = logging.getLogger("sqlalchemy.engine") -alembic_logger = logging.getLogger("alembic") -sqlalchemy_logger.setLevel(settings.telemetry.log.level) -alembic_logger.setLevel(settings.telemetry.log.level) -target_metadata = Base.metadata - - -def get_url() -> str: - """Get database URL from settings.""" - return settings.db.url - - -def run_migrations_offline() -> None: - """Run migrations in 'offline' mode. - - This configures the context with just a URL - and not an Engine, though an Engine is acceptable - here as well. By skipping the Engine creation - we don't even need a DBAPI to be available. - - Calls to context.execute() here emit the given string to the - script output. - - """ - url = get_url() - context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - dialect_opts={"paramstyle": "named"}, - ) - - with context.begin_transaction(): - context.run_migrations() - - -def run_migrations_online() -> None: - """Run migrations in 'online' mode. - - In this scenario we need to create an Engine - and associate a connection with the context. - - """ - from askui.chat.api.db.engine import engine - - with engine.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata) - - with context.begin_transaction(): - context.run_migrations() - - -if context.is_offline_mode(): - run_migrations_offline() -else: - run_migrations_online() diff --git a/src/askui/chat/migrations/runner.py b/src/askui/chat/migrations/runner.py deleted file mode 100644 index 014a9f74..00000000 --- a/src/askui/chat/migrations/runner.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Migration runner for Alembic.""" - -import logging -from pathlib import Path - -from alembic import command -from alembic.config import Config - -logger = logging.getLogger(__name__) - - -def run_migrations() -> None: - """Run Alembic migrations to upgrade database to head.""" - migrations_dir = Path(__file__).parent - alembic_cfg = Config(str(migrations_dir / "alembic.ini")) - alembic_cfg.set_main_option("script_location", str(migrations_dir)) - logger.info("Running database migrations...") - try: - command.upgrade(alembic_cfg, "head") - logger.info("Database migrations completed successfully") - except Exception: - logger.exception("Failed to run database migrations") - raise diff --git a/src/askui/chat/migrations/script.py.mako b/src/askui/chat/migrations/script.py.mako deleted file mode 100644 index fbc4b07d..00000000 --- a/src/askui/chat/migrations/script.py.mako +++ /dev/null @@ -1,26 +0,0 @@ -"""${message} - -Revision ID: ${up_revision} -Revises: ${down_revision | comma,n} -Create Date: ${create_date} - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa -${imports if imports else ""} - -# revision identifiers, used by Alembic. -revision: str = ${repr(up_revision)} -down_revision: Union[str, None] = ${repr(down_revision)} -branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} -depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} - - -def upgrade() -> None: - ${upgrades if upgrades else "pass"} - - -def downgrade() -> None: - ${downgrades if downgrades else "pass"} diff --git a/src/askui/chat/migrations/shared/__init__.py b/src/askui/chat/migrations/shared/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/migrations/shared/assistants/__init__.py b/src/askui/chat/migrations/shared/assistants/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/migrations/shared/assistants/models.py b/src/askui/chat/migrations/shared/assistants/models.py deleted file mode 100644 index 693e1d1d..00000000 --- a/src/askui/chat/migrations/shared/assistants/models.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Annotated, Any, Literal - -from pydantic import BaseModel, BeforeValidator, Field - -from askui.chat.migrations.shared.models import UnixDatetimeV1, WorkspaceIdV1 -from askui.chat.migrations.shared.utils import build_prefixer - -AssistantIdV1 = Annotated[ - str, Field(pattern=r"^asst_[a-z0-9]+$"), BeforeValidator(build_prefixer("asst")) -] - - -class AssistantV1(BaseModel): - id: AssistantIdV1 - object: Literal["assistant"] = "assistant" - created_at: UnixDatetimeV1 - workspace_id: WorkspaceIdV1 | None = None - name: str | None = None - description: str | None = None - avatar: str | None = None - tools: list[str] = Field(default_factory=list) - system: str | None = None - - def to_db_dict(self) -> dict[str, Any]: - return { - **self.model_dump(exclude={"id", "object"}), - "id": self.id.removeprefix("asst_"), - "workspace_id": self.workspace_id.hex if self.workspace_id else None, - } diff --git a/src/askui/chat/migrations/shared/assistants/seeds.py b/src/askui/chat/migrations/shared/assistants/seeds.py deleted file mode 100644 index 4ccf194b..00000000 --- a/src/askui/chat/migrations/shared/assistants/seeds.py +++ /dev/null @@ -1,400 +0,0 @@ -import platform -import sys - -from askui.chat.migrations.shared.assistants.models import AssistantV1 -from askui.chat.migrations.shared.utils import now_v1 - -COMPUTER_AGENT_V1 = AssistantV1( - id="asst_68ac2c4edc4b2f27faa5a253", - created_at=now_v1(), - name="Computer Agent", - avatar="", - system=( - f""" -* You are utilising a {sys.platform} machine using {platform.machine()} architecture with internet access. -* When you cannot find something (application window, ui element etc.) on the currently selected/active displa/screen, check the other available displays by listing them and checking which one is currently active and then going through the other displays one by one until you find it or you have checked all of them. -* When asked to perform web tasks try to open the browser (firefox, chrome, safari, ...) if not already open. Often you can find the browser icons in the toolbars of the operating systems. -* When viewing a page it can be helpful to zoom out/in so that you can see everything on the page. Either that, or make sure you scroll down/up to see everything before deciding something isn't available. -* When using your function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. - - - -* When using Firefox, if a startup wizard appears, IGNORE IT. Do not even click "skip this step". Instead, click on the address bar where it says "Search or enter address", and enter the appropriate search term or URL there. -* If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool. -""" - ), - tools=[ - "computer_disconnect", - "computer_connect", - "computer_mouse_click", - "computer_get_mouse_position", - "computer_keyboard_pressed", - "computer_keyboard_release", - "computer_keyboard_tap", - "computer_list_displays", - "computer_mouse_hold_down", - "computer_mouse_release", - "computer_mouse_scroll", - "computer_move_mouse", - "computer_retrieve_active_display", - "computer_screenshot", - "computer_set_active_display", - "computer_type", - ], -) - -ANDROID_AGENT_V1 = AssistantV1( - id="asst_68ac2c4edc4b2f27faa5a255", - created_at=now_v1(), - name="Android Agent", - avatar="", - system=( - """ - -You are an autonomous Android device control agent operating via ADB on a test device with full system access. -Your primary goal is to execute tasks efficiently and reliably while maintaining system stability. - - - -* Autonomy: Operate independently and make informed decisions without requiring user input. -* Never ask for other tasks to be done, only do the task you are given. -* Reliability: Ensure actions are repeatable and maintain system stability. -* Efficiency: Optimize operations to minimize latency and resource usage. -* Safety: Always verify actions before execution, even with full system access. - - - -1. Tool Usage: - * Verify tool availability before starting any operation - * Use the most direct and efficient tool for each task - * Combine tools strategically for complex operations - * Prefer built-in tools over shell commands when possible - -2. Error Handling: - * Assess failures systematically: check tool availability, permissions, and device state - * Implement retry logic with exponential backoff for transient failures - * Use fallback strategies when primary approaches fail - * Provide clear, actionable error messages with diagnostic information - -3. Performance Optimization: - * Use one-liner shell commands with inline filtering (grep, cut, awk, jq) for efficiency - * Minimize screen captures and coordinate calculations - * Cache device state information when appropriate - * Batch related operations when possible - -4. Screen Interaction: - * Ensure all coordinates are integers and within screen bounds - * Implement smart scrolling for off-screen elements - * Use appropriate gestures (tap, swipe, drag) based on context - * Verify element visibility before interaction - -5. System Access: - * Leverage full system access responsibly - * Use shell commands for system-level operations - * Monitor system state and resource usage - * Maintain system stability during operations - -6. Recovery Strategies: - * If an element is not visible, try: - - Scrolling in different directions - - Adjusting view parameters - - Using alternative interaction methods - * If a tool fails: - - Check device connection and state - - Verify tool availability and permissions - - Try alternative tools or approaches - * If stuck: - - Provide clear diagnostic information - - Suggest potential solutions - - Request user intervention only if necessary - -7. Best Practices: - * Document all significant operations - * Maintain operation logs for debugging - * Implement proper cleanup after operations - * Follow Android best practices for UI interaction - - -* This is a test device with full system access - use this capability responsibly -* Always verify the success of critical operations -* Maintain system stability as the highest priority -* Provide clear, actionable feedback for all operations -* Use the most efficient method for each task - -""" - ), - tools=[ - "android_screenshot_tool", - "android_tap_tool", - "android_type_tool", - "android_drag_and_drop_tool", - "android_key_event_tool", - "android_swipe_tool", - "android_key_combination_tool", - "android_shell_tool", - "android_connect_tool", - "android_get_connected_devices_serial_numbers_tool", - "android_get_connected_displays_infos_tool", - "android_get_current_connected_device_infos_tool", - "android_get_connected_device_display_infos_tool", - "android_select_device_by_serial_number_tool", - "android_select_display_by_unique_id_tool", - "android_setup_helper", - ], -) - -WEB_AGENT_V1 = AssistantV1( - id="asst_68ac2c4edc4b2f27faa5a256", - created_at=now_v1(), - name="Web Agent", - avatar="", - system=( - """ - -* You are utilizing a webbrowser in full-screen mode. So you are only seeing the content of the currently opened webpage (tab). -* It can be helpful to zoom in/out or scroll down/up so that you can see everything on the page. Make sure to that before deciding something isn't available. -* When using your tools, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. -* If a tool call returns with an error that a browser distribution is not found, stop, so that the user can install it and, then, continue the conversation. - -""" - ), - tools=[ - "browser_click", - "browser_close", - "browser_console_messages", - "browser_drag", - "browser_evaluate", - "browser_file_upload", - "browser_fill_form", - "browser_handle_dialog", - "browser_hover", - "browser_navigate", - "browser_navigate_back", - "browser_network_requests", - "browser_press_key", - "browser_resize", - "browser_select_option", - "browser_snapshot", - "browser_take_screenshot", - "browser_type", - "browser_wait_for", - "browser_tabs", - "browser_mouse_click_xy", - "browser_mouse_drag_xy", - "browser_mouse_move_xy", - "browser_pdf_save", - "browser_verify_element_visible", - "browser_verify_list_visible", - "browser_verify_text_visible", - "browser_verify_value", - ], -) - -TESTING_AGENT_V1 = AssistantV1( - id="asst_68ac2c4edc4b2f27faa5a257", - created_at=now_v1(), - name="Testing Agent", - avatar="", - system=( - """ -You are an advanced AI testing agent responsible for managing and executing software tests. Your primary goal is to create, refine, and execute test scenarios based on given specifications or targets. You have access to various tools and subagents to accomplish this task. - -Available tools: -1. Feature management: retrieve, list, modify, create, delete -2. Scenario management: retrieve, list, modify, create, delete -3. Execution management: retrieve, list, modify, create, delete -4. Tools for executing tests using subagents: - - create_thread_and_run_v1_runs_post: Delegate tasks to subagents - - retrieve_run_v1_threads: Check the status of a run - - list_messages_v1_threads: Retrieve messages from a thread - - utility_wait: Wait for a specified number of seconds - -Subagents: -1. Computer control agent (ID: asst_68ac2c4edc4b2f27faa5a253) -2. Web browser control agent (ID: asst_68ac2c4edc4b2f27faa5a256) - -Main process: -1. Analyze test specification -2. Create and refine features if necessary by exploring the features (exploratory testing) -3. Create and refine scenarios if necessary by exploring the scenarios (exploratory testing) -4. Execute scenarios -5. Report results -6. Handle user feedback - -Detailed instructions: - -1. Analyze the test specification: - -{TEST_SPECIFICATION} - - -Review the provided test specification carefully. Identify the key features, functionalities, or areas that need to be tested. -Instead of a test specification, the user may also provide just the testing target (feature, url, application name etc.). Make -sure that you ask the user if it is a webapp or desktop app or where to find the app in general if not clear from the specification. - -2. Create and refine features: -a. Use the feature management tools to list existing features. -b. Create new features based on user input and if necessary exploring the features in the actual application using a subagent, ensuring no duplicates. -c. Present the features to the user and wait for feedback. -d. Refine the features based on user feedback until confirmation is received. - -3. Create and refine scenarios: -a. For each confirmed feature, use the scenario management tools to list existing scenarios. -b. Create new scenarios using Gherkin syntax, ensuring no duplicates. -c. Present the scenarios to the user and wait for feedback. -d. Refine the scenarios based on user feedback until confirmation is received. - -4. Execute scenarios: -a. Determine whether to use the computer control agent or web browser control agent (prefer web browser if possible). -b. Create and run a thread with the chosen subagent with a user message that contains the commands (scenario) to be executed. Set `stream` to `false` to wait for the agent to complete. -c. Use the retrieve_run_v1_threads tool to check the status of the task and the utility_wait tool for it to complete with an exponential backoff starting with 5 seconds increasing. -d. Collect and analyze the responses from the agent using the list_messages_v1_threads tool. Usually, you only need the last message within the thread (`limit=1`) which contains a summary of the execution results. If you need more details, you can use a higher limit and potentially multiple calls to the tool. - -5. Report results: -a. Use the execution management tools to create new execution records. -b. Update the execution records with the results (passed, failed, etc.). -c. Present a summary of the execution results to the user. - -6. Handle user feedback: -a. Review user feedback on the executions. -b. Based on feedback, determine whether to restart the process, modify existing tests, or perform other actions. - -Handling user commands: -Respond appropriately to user commands, such as: - -{USER_COMMAND} - - -- Execute existing scenarios -- List all available features -- Modify specific features or scenarios -- Delete features or scenarios - -Output format (for none tool calls): -``` -[Your detailed response, including any necessary explanations, lists, or summaries] - -**Next Actions**: -[Clearly state the next actions you will take or the next inputs you require from the user] - -``` - -Important reminders: -1. Always check for existing features and scenarios before creating new ones to avoid duplicates. -2. Use Gherkin syntax when creating or modifying scenarios. -3. Prefer the web browser control agent for test execution when possible. -4. Always wait for user confirmation before proceeding to the next major step in the process. -5. Be prepared to restart the process or modify existing tests based on user feedback. -6. Use tags for organizing the features and scenarios describing what is being tested and how it is being tested. -7. Prioritize sunny cases and critical features/scenarios first if not specified otherwise by the user. - -Your final output should only include the content within the and tags. Do not include any other tags or internal thought processes in your final output. -""" - ), - tools=[ - "create_feature", - "retrieve_feature", - "list_features", - "modify_feature", - "delete_feature", - "create_scenario", - "retrieve_scenario", - "list_scenarios", - "modify_scenario", - "delete_scenario", - "create_execution", - "retrieve_execution", - "list_executions", - "modify_execution", - "delete_execution", - "create_thread_and_run_v1_runs_post", - "retrieve_run_v1_threads", - "utility_wait", - "list_messages_v1_threads", - ], -) - -ORCHESTRATOR_AGENT_V1 = AssistantV1( - id="asst_68ac2c4edc4b2f27faa5a258", - created_at=now_v1(), - name="Orchestrator", - avatar="", - system=( - """ -You are an AI agent called "Orchestrator" with the ID "asst_68ac2c4edc4b2f27faa5a258". Your primary role is to perform high-level planning and management of all tasks involved in responding to a given prompt. For simple prompts, you will respond directly. For more complex, you will delegate and route the execution of these tasks to other specialized agents. - -You have the following tools at your disposal: - -1. list_assistants_v1_assistants_get - This tool enables you to discover all available assistants (agents) for task delegation. - -2. create_thread_and_run_v1_runs_post - This tool enables you to delegate tasks to other agents by starting a conversation (thread) with initial messages containing necessary instructions, and then running (calling/executing) the agent to get a response. The "stream" parameter should always be set to "false". - -3. retrieve_run_v1_threads - This tool enables you to retrieve the details of a run by its ID and, by that, checking wether an assistant is still answering or completed its answer (`status` field). - -4. list_messages_v1_threads - This tool enables you to retrieve the messages of the assistant. Depending on the prompt, you may only need the last message within the thread (`limit=1`) or the whole thread using a higher limit and potentially multiple calls to the tool. - -5. utility_wait - This tool enables you to wait for a specified number of seconds, e.g. to wait for an agent to finish its task / complete its answer. - -Your main task is to analyze the user prompt and classify it as simple vs. complex. For simple prompts, respond directly. For complex prompts, create a comprehensive plan to address it by utilizing the available agents. - -Follow these steps to complete your task: - -1. Analyze the user prompt and identify the main components or subtasks required to provide a complete response. - -2. Use the list_assistants_v1_assistants_get tool to discover all available agents. - -3. Create a plan that outlines how you will delegate these subtasks to the most appropriate agents based on their specialties. - -4. For each subtask: - a. Prepare clear and concise instructions for the chosen agent. - b. Use the create_thread_and_run_v1_runs_post tool to delegate the task to the agent. - c. Include all necessary context and information in the initial messages. - d. Set the "stream" parameter to "true". - -5. Use the retrieve_run_v1_threads tool to check the status of the task and the utility_wait tool for it to complete with an exponential backoff starting with 5 seconds increasing. - -5. Collect and analyze the responses from each agent using the list_messages_v1_threads tool. - -6. Synthesize the information from all agents into a coherent and comprehensive response to the original user prompt. - -Present your final output should be eitehr in the format of - -[Simple answer] - -or - -[ -# Plan -[Provide a detailed plan outlining the subtasks and the agents assigned to each] - -# Report -[For each agent interaction, include: -1. The agent's ID and specialty -2. The subtask assigned -3. A summary of the instructions given -4. A brief summary of the agent's response] - -# Answer -[Synthesize all the information into a cohesive response to the original user prompt] -] -""" - ), - tools=[ - "list_assistants_v1_assistants_get", - "create_thread_and_run_v1_runs_post", - "retrieve_run_v1_threads", - "utility_wait", - "list_messages_v1_threads", - ], -) - -SEEDS_V1 = [ - COMPUTER_AGENT_V1, - ANDROID_AGENT_V1, - WEB_AGENT_V1, -] diff --git a/src/askui/chat/migrations/shared/files/__init__.py b/src/askui/chat/migrations/shared/files/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/migrations/shared/files/models.py b/src/askui/chat/migrations/shared/files/models.py deleted file mode 100644 index c57595d8..00000000 --- a/src/askui/chat/migrations/shared/files/models.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Annotated, Any, Literal - -from pydantic import BaseModel, BeforeValidator, Field - -from askui.chat.migrations.shared.models import UnixDatetimeV1, WorkspaceIdV1 -from askui.chat.migrations.shared.utils import build_prefixer - -FileIdV1 = Annotated[ - str, Field(pattern=r"^file_[a-z0-9]+$"), BeforeValidator(build_prefixer("file")) -] - - -class FileV1(BaseModel): - id: FileIdV1 - object: Literal["file"] = "file" - created_at: UnixDatetimeV1 - filename: str = Field(min_length=1) - size: int = Field(ge=0) - media_type: str - workspace_id: WorkspaceIdV1 | None = Field(default=None, exclude=True) - - def to_db_dict(self) -> dict[str, Any]: - return { - **self.model_dump(exclude={"id", "object"}), - "id": self.id.removeprefix("file_"), - "workspace_id": self.workspace_id.hex if self.workspace_id else None, - } diff --git a/src/askui/chat/migrations/shared/mcp_configs/__init__.py b/src/askui/chat/migrations/shared/mcp_configs/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/migrations/shared/mcp_configs/models.py b/src/askui/chat/migrations/shared/mcp_configs/models.py deleted file mode 100644 index 0f5be421..00000000 --- a/src/askui/chat/migrations/shared/mcp_configs/models.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Annotated, Any, Literal - -from fastmcp.mcp_config import RemoteMCPServer as _RemoteMCPServer -from fastmcp.mcp_config import StdioMCPServer -from httpx import Auth -from pydantic import BaseModel, BeforeValidator, Field - -from askui.chat.migrations.shared.models import UnixDatetimeV1, WorkspaceIdV1 -from askui.chat.migrations.shared.utils import build_prefixer - -McpConfigIdV1 = Annotated[ - str, Field(pattern=r"^mcpcnf_[a-z0-9]+$"), BeforeValidator(build_prefixer("mcpcnf")) -] - - -class RemoteMCPServerV1(_RemoteMCPServer): - auth: Annotated[ - str | Literal["oauth"] | Auth | None, # noqa: PYI051 - Field( - description='Either a string representing a Bearer token or the literal "oauth" to use OAuth authentication.', - ), - ] = None - - -McpServerV1 = StdioMCPServer | RemoteMCPServerV1 - - -class McpConfigV1(BaseModel): - id: McpConfigIdV1 - object: Literal["mcp_config"] = "mcp_config" - created_at: UnixDatetimeV1 - workspace_id: WorkspaceIdV1 | None = None - name: str - mcp_server: McpServerV1 - - def to_db_dict(self) -> dict[str, Any]: - return { - **self.model_dump(exclude={"id", "object"}), - "id": self.id.removeprefix("mcpcnf_"), - "workspace_id": self.workspace_id.hex if self.workspace_id else None, - } diff --git a/src/askui/chat/migrations/shared/messages/__init__.py b/src/askui/chat/migrations/shared/messages/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/migrations/shared/messages/models.py b/src/askui/chat/migrations/shared/messages/models.py deleted file mode 100644 index 8a900bd6..00000000 --- a/src/askui/chat/migrations/shared/messages/models.py +++ /dev/null @@ -1,166 +0,0 @@ -from typing import Annotated, Any, Literal - -from pydantic import BaseModel, BeforeValidator, Field - -from askui.chat.migrations.shared.assistants.models import AssistantIdV1 -from askui.chat.migrations.shared.models import UnixDatetimeV1, WorkspaceIdV1 -from askui.chat.migrations.shared.runs.models import RunIdV1 -from askui.chat.migrations.shared.threads.models import ThreadIdV1 -from askui.chat.migrations.shared.utils import build_prefixer - -MessageIdV1 = Annotated[ - str, Field(pattern=r"^msg_[a-z0-9]+$"), BeforeValidator(build_prefixer("msg")) -] - - -class CacheControlEphemeralParamV1(BaseModel): - type: Literal["ephemeral"] = "ephemeral" - - -class CitationCharLocationParamV1(BaseModel): - cited_text: str - document_index: int - document_title: str | None = None - end_char_index: int - start_char_index: int - type: Literal["char_location"] = "char_location" - - -class CitationPageLocationParamV1(BaseModel): - cited_text: str - document_index: int - document_title: str | None = None - end_page_number: int - start_page_number: int - type: Literal["page_location"] = "page_location" - - -class CitationContentBlockLocationParamV1(BaseModel): - cited_text: str - document_index: int - document_title: str | None = None - end_block_index: int - start_block_index: int - type: Literal["content_block_location"] = "content_block_location" - - -TextCitationParamV1 = ( - CitationCharLocationParamV1 - | CitationPageLocationParamV1 - | CitationContentBlockLocationParamV1 -) - - -class UrlImageSourceParamV1(BaseModel): - type: Literal["url"] = "url" - url: str - - -class Base64ImageSourceParamV1(BaseModel): - data: str - media_type: Literal["image/jpeg", "image/png", "image/gif", "image/webp"] - type: Literal["base64"] = "base64" - - -class FileImageSourceParamV1(BaseModel): - """Image source that references a saved file.""" - - id: str # FileId equivalent - type: Literal["file"] = "file" - - -class ImageBlockParamV1(BaseModel): - source: Base64ImageSourceParamV1 | UrlImageSourceParamV1 | FileImageSourceParamV1 - type: Literal["image"] = "image" - cache_control: CacheControlEphemeralParamV1 | None = None - - -class TextBlockParamV1(BaseModel): - text: str - type: Literal["text"] = "text" - cache_control: CacheControlEphemeralParamV1 | None = None - citations: list[TextCitationParamV1] | None = None - - -class ToolResultBlockParamV1(BaseModel): - tool_use_id: str - type: Literal["tool_result"] = "tool_result" - cache_control: CacheControlEphemeralParamV1 | None = None - content: str | list[TextBlockParamV1 | ImageBlockParamV1] - is_error: bool = False - - -class ToolUseBlockParamV1(BaseModel): - id: str - input: object - name: str - type: Literal["tool_use"] = "tool_use" - cache_control: CacheControlEphemeralParamV1 | None = None - - -class BetaThinkingBlockV1(BaseModel): - signature: str - thinking: str - type: Literal["thinking"] - - -class BetaRedactedThinkingBlockV1(BaseModel): - data: str - type: Literal["redacted_thinking"] - - -class BetaFileDocumentSourceParamV1(BaseModel): - file_id: str - type: Literal["file"] = "file" - - -SourceV1 = BetaFileDocumentSourceParamV1 - - -class RequestDocumentBlockParamV1(BaseModel): - source: SourceV1 - type: Literal["document"] = "document" - cache_control: CacheControlEphemeralParamV1 | None = None - - -ContentBlockParamV1 = ( - ImageBlockParamV1 - | TextBlockParamV1 - | ToolResultBlockParamV1 - | ToolUseBlockParamV1 - | BetaThinkingBlockV1 - | BetaRedactedThinkingBlockV1 - | RequestDocumentBlockParamV1 -) - - -StopReasonV1 = Literal[ - "end_turn", "max_tokens", "stop_sequence", "tool_use", "pause_turn", "refusal" -] - - -class MessageV1(BaseModel): - id: MessageIdV1 - object: Literal["thread.message"] = "thread.message" - created_at: UnixDatetimeV1 - thread_id: ThreadIdV1 - role: Literal["user", "assistant"] - content: str | list[ContentBlockParamV1] - stop_reason: StopReasonV1 | None = None - assistant_id: AssistantIdV1 | None = None - run_id: RunIdV1 | None = None - workspace_id: WorkspaceIdV1 = Field(exclude=True) - - def to_db_dict(self) -> dict[str, Any]: - return { - **self.model_dump( - exclude={"id", "thread_id", "assistant_id", "run_id", "object"} - ), - "id": self.id.removeprefix("msg_"), - "thread_id": self.thread_id.removeprefix("thread_"), - "assistant_id": self.assistant_id.removeprefix("asst_") - if self.assistant_id - else None, - "run_id": self.run_id.removeprefix("run_") if self.run_id else None, - "workspace_id": self.workspace_id.hex, - } diff --git a/src/askui/chat/migrations/shared/models.py b/src/askui/chat/migrations/shared/models.py deleted file mode 100644 index 750d9d87..00000000 --- a/src/askui/chat/migrations/shared/models.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Annotated -from uuid import UUID - -from pydantic import AwareDatetime, PlainSerializer - -UnixDatetimeV1 = Annotated[ - AwareDatetime, - PlainSerializer( - lambda v: int(v.timestamp()), - return_type=int, - ), -] -WorkspaceIdV1 = UUID diff --git a/src/askui/chat/migrations/shared/runs/__init__.py b/src/askui/chat/migrations/shared/runs/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/migrations/shared/runs/models.py b/src/askui/chat/migrations/shared/runs/models.py deleted file mode 100644 index 201f832f..00000000 --- a/src/askui/chat/migrations/shared/runs/models.py +++ /dev/null @@ -1,74 +0,0 @@ -from typing import Annotated, Any, Literal - -from pydantic import BaseModel, BeforeValidator, Field, computed_field - -from askui.chat.migrations.shared.assistants.models import AssistantIdV1 -from askui.chat.migrations.shared.models import UnixDatetimeV1, WorkspaceIdV1 -from askui.chat.migrations.shared.threads.models import ThreadIdV1 -from askui.chat.migrations.shared.utils import build_prefixer, now_v1 - -RunStatusV1 = Literal[ - "queued", - "in_progress", - "completed", - "cancelling", - "cancelled", - "failed", - "expired", -] - - -class RunErrorV1(BaseModel): - """Error information for a failed run.""" - - message: str - code: Literal["server_error", "rate_limit_exceeded", "invalid_prompt"] - - -RunIdV1 = Annotated[ - str, Field(pattern=r"^run_[a-z0-9]+$"), BeforeValidator(build_prefixer("run")) -] - - -class RunV1(BaseModel): - id: RunIdV1 - object: Literal["thread.run"] = "thread.run" - thread_id: ThreadIdV1 - created_at: UnixDatetimeV1 - expires_at: UnixDatetimeV1 - started_at: UnixDatetimeV1 | None = None - completed_at: UnixDatetimeV1 | None = None - failed_at: UnixDatetimeV1 | None = None - cancelled_at: UnixDatetimeV1 | None = None - tried_cancelling_at: UnixDatetimeV1 | None = None - last_error: RunErrorV1 | None = None - assistant_id: AssistantIdV1 | None = None - workspace_id: WorkspaceIdV1 = Field(exclude=True) - - def to_db_dict(self) -> dict[str, Any]: - return { - **self.model_dump(exclude={"id", "thread_id", "assistant_id", "object"}), - "id": self.id.removeprefix("run_"), - "thread_id": self.thread_id.removeprefix("thread_"), - "assistant_id": self.assistant_id.removeprefix("asst_") - if self.assistant_id - else None, - "workspace_id": self.workspace_id.hex, - } - - @computed_field # type: ignore[prop-decorator] - @property - def status(self) -> RunStatusV1: - if self.cancelled_at: - return "cancelled" - if self.failed_at: - return "failed" - if self.completed_at: - return "completed" - if self.expires_at and self.expires_at < now_v1(): - return "expired" - if self.tried_cancelling_at: - return "cancelling" - if self.started_at: - return "in_progress" - return "queued" diff --git a/src/askui/chat/migrations/shared/settings.py b/src/askui/chat/migrations/shared/settings.py deleted file mode 100644 index c7671a1a..00000000 --- a/src/askui/chat/migrations/shared/settings.py +++ /dev/null @@ -1,29 +0,0 @@ -from pathlib import Path -from typing import Annotated -from uuid import UUID - -from pydantic import AwareDatetime, Field, PlainSerializer -from pydantic_settings import BaseSettings, SettingsConfigDict - -# Local models to avoid dependencies on askui.chat.api -UnixDatetime = Annotated[ - AwareDatetime, - PlainSerializer( - lambda v: int(v.timestamp()), - return_type=int, - ), -] - -AssistantId = Annotated[str, Field(pattern=r"^asst_[a-z0-9]+$")] -WorkspaceId = UUID - - -class SettingsV1(BaseSettings): - model_config = SettingsConfigDict( - env_prefix="ASKUI__CHAT_API__", env_nested_delimiter="__" - ) - - data_dir: Path = Field( - default_factory=lambda: Path.cwd() / "chat", - description="Base directory for chat data (used during migration)", - ) diff --git a/src/askui/chat/migrations/shared/threads/__init__.py b/src/askui/chat/migrations/shared/threads/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/migrations/shared/threads/models.py b/src/askui/chat/migrations/shared/threads/models.py deleted file mode 100644 index 5c90c549..00000000 --- a/src/askui/chat/migrations/shared/threads/models.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import Annotated, Any, Literal - -from pydantic import BaseModel, BeforeValidator, Field - -from askui.chat.migrations.shared.models import UnixDatetimeV1, WorkspaceIdV1 -from askui.chat.migrations.shared.utils import build_prefixer - -ThreadIdV1 = Annotated[ - str, Field(pattern=r"^thread_[a-z0-9]+$"), BeforeValidator(build_prefixer("thread")) -] - - -class ThreadV1(BaseModel): - id: ThreadIdV1 - object: Literal["thread"] = "thread" - created_at: UnixDatetimeV1 - name: str | None = None - workspace_id: WorkspaceIdV1 = Field(exclude=True) - - def to_db_dict(self) -> dict[str, Any]: - return { - **self.model_dump(exclude={"id", "object"}), - "id": self.id.removeprefix("thread_"), - "workspace_id": self.workspace_id.hex, - } diff --git a/src/askui/chat/migrations/shared/utils.py b/src/askui/chat/migrations/shared/utils.py deleted file mode 100644 index 5345ed1c..00000000 --- a/src/askui/chat/migrations/shared/utils.py +++ /dev/null @@ -1,17 +0,0 @@ -from datetime import datetime, timezone -from typing import Callable - -from pydantic import AwareDatetime - - -def now_v1() -> AwareDatetime: - return datetime.now(tz=timezone.utc) - - -def build_prefixer(prefix: str) -> Callable[[str], str]: - def prefixer(id_: str) -> str: - if id_.startswith(prefix): - return id_ - return f"{prefix}_{id_}" - - return prefixer diff --git a/src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py b/src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py deleted file mode 100644 index 2b931f1c..00000000 --- a/src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py +++ /dev/null @@ -1,124 +0,0 @@ -"""import_json_assistants - -Revision ID: 057f82313448 -Revises: 4d1e043b4254 -Create Date: 2025-10-10 11:21:55.527341 - -""" - -import json -import logging -from typing import Sequence, Union - -from alembic import op -from sqlalchemy import Connection, MetaData, Table - -from askui.chat.migrations.shared.assistants.models import AssistantV1 -from askui.chat.migrations.shared.settings import SettingsV1 - -# revision identifiers, used by Alembic. -revision: str = "057f82313448" -down_revision: Union[str, None] = "4d1e043b4254" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - - -BATCH_SIZE = 1000 - - -def _insert_assistants_batch( - connection: Connection, assistants_table: Table, assistants_batch: list[AssistantV1] -) -> None: - """Insert a batch of assistants into the database, ignoring conflicts.""" - if not assistants_batch: - logger.info("No assistants to insert, skipping batch") - return - - connection.execute( - assistants_table.insert().prefix_with("OR REPLACE"), - [assistant.to_db_dict() for assistant in assistants_batch], - ) - - -settings = SettingsV1() -assistants_dir = settings.data_dir / "assistants" - - -def upgrade() -> None: - """Import existing assistants from JSON files.""" - - # Skip if directory doesn't exist (e.g., first-time setup) - if not assistants_dir.exists(): - logger.info( - "Assistants directory does not exist, skipping import of assistants", - extra={"assistants_dir": str(assistants_dir)}, - ) - return - - # Get the table from the current database schema - connection = op.get_bind() - assistants_table = Table("assistants", MetaData(), autoload_with=connection) - - # Get all JSON files in the assistants directory - json_files = list(assistants_dir.glob("*.json")) - - # Process assistants in batches - assistants_batch: list[AssistantV1] = [] - - for json_file in json_files: - try: - content = json_file.read_text(encoding="utf-8").strip() - data = json.loads(content) - assistant = AssistantV1.model_validate(data) - assistants_batch.append(assistant) - - if len(assistants_batch) >= BATCH_SIZE: - _insert_assistants_batch(connection, assistants_table, assistants_batch) - assistants_batch.clear() - except Exception: # noqa: PERF203 - error_msg = "Failed to import" - logger.exception(error_msg, extra={"json_file": str(json_file)}) - continue - - # Insert remaining assistants in the final batch - if assistants_batch: - _insert_assistants_batch(connection, assistants_table, assistants_batch) - - -def downgrade() -> None: - """Recreate JSON files for assistants during downgrade.""" - - assistants_dir.mkdir(parents=True, exist_ok=True) - - connection = op.get_bind() - assistants_table = Table("assistants", MetaData(), autoload_with=connection) - - # Fetch all assistants from the database - result = connection.execute(assistants_table.select()) - rows = result.fetchall() - if not rows: - logger.info( - "No assistants found in the database, skipping export of rows to json", - ) - return - - for row in rows: - try: - assistant: AssistantV1 = AssistantV1.model_validate( - row, from_attributes=True - ) - json_path = assistants_dir / f"{assistant.id}.json" - if json_path.exists(): - logger.info( - "Json file for assistant already exists, skipping export of row to json", - extra={"assistant_id": assistant.id, "json_path": str(json_path)}, - ) - continue - with json_path.open("w", encoding="utf-8") as f: - f.write(assistant.model_dump_json()) - except Exception as e: # noqa: PERF203 - error_msg = f"Failed to export row to json: {e}" - logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) - continue diff --git a/src/askui/chat/migrations/versions/1a2b3c4d5e6f_create_threads_table.py b/src/askui/chat/migrations/versions/1a2b3c4d5e6f_create_threads_table.py deleted file mode 100644 index 011676bf..00000000 --- a/src/askui/chat/migrations/versions/1a2b3c4d5e6f_create_threads_table.py +++ /dev/null @@ -1,36 +0,0 @@ -"""create_threads_table - -Revision ID: 1a2b3c4d5e6f -Revises: 9e0f1a2b3c4d -Create Date: 2025-01-27 12:00:00.000000 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "1a2b3c4d5e6f" -down_revision: Union[str, None] = "9e0f1a2b3c4d" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "threads", - sa.Column("id", sa.String(24), nullable=False, primary_key=True), - sa.Column("workspace_id", sa.Uuid(), nullable=False, index=True), - sa.Column("created_at", sa.Integer(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("threads") - # ### end Alembic commands ### diff --git a/src/askui/chat/migrations/versions/2b3c4d5e6f7a_create_messages_table.py b/src/askui/chat/migrations/versions/2b3c4d5e6f7a_create_messages_table.py deleted file mode 100644 index 16a618c4..00000000 --- a/src/askui/chat/migrations/versions/2b3c4d5e6f7a_create_messages_table.py +++ /dev/null @@ -1,60 +0,0 @@ -"""create_messages_table - -Revision ID: 2b3c4d5e6f7a -Revises: 6f7a8b9c0d1e -Create Date: 2025-01-27 12:01:00.000000 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "2b3c4d5e6f7a" -down_revision: Union[str, None] = "6f7a8b9c0d1e" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "messages", - sa.Column("id", sa.String(24), nullable=False, primary_key=True), - sa.Column( - "thread_id", - sa.String(24), - sa.ForeignKey( - "threads.id", ondelete="CASCADE", name="fk_messages_thread_id" - ), - nullable=False, - ), - sa.Column("workspace_id", sa.Uuid(), nullable=False, index=True), - sa.Column("created_at", sa.Integer(), nullable=False), - sa.Column("role", sa.String(), nullable=False), - sa.Column("content", sa.JSON(), nullable=False), - sa.Column("stop_reason", sa.String(), nullable=True), - sa.Column( - "assistant_id", - sa.String(24), - sa.ForeignKey( - "assistants.id", ondelete="SET NULL", name="fk_messages_assistant_id" - ), - nullable=True, - ), - sa.Column( - "run_id", - sa.String(24), - sa.ForeignKey("runs.id", ondelete="SET NULL", name="fk_messages_run_id"), - nullable=True, - ), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("messages") - # ### end Alembic commands ### diff --git a/src/askui/chat/migrations/versions/3c4d5e6f7a8b_create_runs_table.py b/src/askui/chat/migrations/versions/3c4d5e6f7a8b_create_runs_table.py deleted file mode 100644 index 668c8e22..00000000 --- a/src/askui/chat/migrations/versions/3c4d5e6f7a8b_create_runs_table.py +++ /dev/null @@ -1,72 +0,0 @@ -"""create_runs_table - -Revision ID: 3c4d5e6f7a8b -Revises: 4d5e6f7a8b9c -Create Date: 2025-01-27 12:02:00.000000 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "3c4d5e6f7a8b" -down_revision: Union[str, None] = "4d5e6f7a8b9c" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "runs", - sa.Column("id", sa.String(24), nullable=False, primary_key=True), - sa.Column( - "thread_id", - sa.String(24), - sa.ForeignKey("threads.id", ondelete="CASCADE", name="fk_runs_thread_id"), - nullable=False, - index=True, - ), - sa.Column("workspace_id", sa.Uuid(), nullable=False, index=True), - sa.Column("created_at", sa.Integer(), nullable=False), - sa.Column("expires_at", sa.Integer(), nullable=False), - sa.Column("started_at", sa.Integer(), nullable=True), - sa.Column("completed_at", sa.Integer(), nullable=True), - sa.Column("failed_at", sa.Integer(), nullable=True), - sa.Column("cancelled_at", sa.Integer(), nullable=True), - sa.Column("tried_cancelling_at", sa.Integer(), nullable=True), - sa.Column("last_error", sa.JSON(), nullable=True), - sa.Column( - "assistant_id", - sa.String(24), - sa.ForeignKey( - "assistants.id", ondelete="SET NULL", name="fk_runs_assistant_id" - ), - nullable=True, - ), - sa.Column( - "status", - sa.Enum( - "queued", - "in_progress", - "completed", - "cancelled", - "failed", - "expired", - "cancelling", - ), - nullable=False, - index=True, - ), - ) - - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("runs") - # ### end Alembic commands ### diff --git a/src/askui/chat/migrations/versions/4d1e043b4254_create_assistants_table.py b/src/askui/chat/migrations/versions/4d1e043b4254_create_assistants_table.py deleted file mode 100644 index cbbdcfd0..00000000 --- a/src/askui/chat/migrations/versions/4d1e043b4254_create_assistants_table.py +++ /dev/null @@ -1,40 +0,0 @@ -"""create_assistants_table - -Revision ID: 4d1e043b4254 -Revises: -Create Date: 2025-10-10 11:21:24.218911 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "4d1e043b4254" -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "assistants", - sa.Column("id", sa.String(24), nullable=False, primary_key=True), - sa.Column("workspace_id", sa.Uuid(), nullable=True, index=True), - sa.Column("created_at", sa.Integer(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("description", sa.String(), nullable=True), - sa.Column("avatar", sa.Text(), nullable=True), - sa.Column("tools", sa.JSON(), nullable=False), - sa.Column("system", sa.Text(), nullable=True), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("assistants") - # ### end Alembic commands ### diff --git a/src/askui/chat/migrations/versions/4d5e6f7a8b9c_import_json_threads.py b/src/askui/chat/migrations/versions/4d5e6f7a8b9c_import_json_threads.py deleted file mode 100644 index deaf3d1a..00000000 --- a/src/askui/chat/migrations/versions/4d5e6f7a8b9c_import_json_threads.py +++ /dev/null @@ -1,140 +0,0 @@ -"""import_json_threads - -Revision ID: 4d5e6f7a8b9c -Revises: 1a2b3c4d5e6f -Create Date: 2025-01-27 12:03:00.000000 - -""" - -import json -import logging -from typing import Sequence, Union - -from alembic import op -from sqlalchemy import Connection, MetaData, Table - -from askui.chat.migrations.shared.settings import SettingsV1 -from askui.chat.migrations.shared.threads.models import ThreadV1 - -# revision identifiers, used by Alembic. -revision: str = "4d5e6f7a8b9c" -down_revision: Union[str, None] = "1a2b3c4d5e6f" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - - -BATCH_SIZE = 1000 - - -def _insert_threads_batch( - connection: Connection, threads_table: Table, threads_batch: list[ThreadV1] -) -> None: - """Insert a batch of threads into the database, ignoring conflicts.""" - if not threads_batch: - logger.info("No threads to insert, skipping batch") - return - - connection.execute( - threads_table.insert().prefix_with("OR REPLACE"), - [thread.to_db_dict() for thread in threads_batch], - ) - - -settings = SettingsV1() -workspaces_dir = settings.data_dir / "workspaces" - - -def upgrade() -> None: # noqa: C901 - """Import existing threads from JSON files in workspace directories.""" - - # Skip if workspaces directory doesn't exist (e.g., first-time setup) - if not workspaces_dir.exists(): - logger.info( - "Workspaces directory does not exist, skipping import of threads", - extra={"workspaces_dir": str(workspaces_dir)}, - ) - return - - # Get the table from the current database schema - connection = op.get_bind() - threads_table = Table("threads", MetaData(), autoload_with=connection) - - # Process threads in batches - threads_batch: list[ThreadV1] = [] - - # Iterate through all workspace directories - for workspace_dir in workspaces_dir.iterdir(): - if not workspace_dir.is_dir(): - logger.info( - "Skipping non-directory in workspaces", - extra={"path": str(workspace_dir)}, - ) - continue - - workspace_id = workspace_dir.name - threads_dir = workspace_dir / "threads" - - if not threads_dir.exists(): - logger.info( - "Threads directory does not exist, skipping workspace", - extra={"workspace_id": workspace_id, "threads_dir": str(threads_dir)}, - ) - continue - - # Get all JSON files in the threads directory - json_files = list(threads_dir.glob("*.json")) - - for json_file in json_files: - try: - content = json_file.read_text(encoding="utf-8").strip() - data = json.loads(content) - thread = ThreadV1.model_validate({**data, "workspace_id": workspace_id}) - threads_batch.append(thread) - if len(threads_batch) >= BATCH_SIZE: - _insert_threads_batch(connection, threads_table, threads_batch) - threads_batch.clear() - except Exception: # noqa: PERF203 - error_msg = "Failed to import thread" - logger.exception(error_msg, extra={"json_file": str(json_file)}) - continue - - # Insert remaining threads in the final batch - if threads_batch: - _insert_threads_batch(connection, threads_table, threads_batch) - - -def downgrade() -> None: - """Recreate JSON files for threads during downgrade.""" - - connection = op.get_bind() - threads_table = Table("threads", MetaData(), autoload_with=connection) - - # Fetch all threads from the database - result = connection.execute(threads_table.select()) - rows = result.fetchall() - if not rows: - logger.info( - "No threads found in the database, skipping export of rows to json", - ) - return - - for row in rows: - try: - thread_model: ThreadV1 = ThreadV1.model_validate(row, from_attributes=True) - threads_dir = workspaces_dir / str(thread_model.workspace_id) / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - json_path = threads_dir / f"{thread_model.id}.json" - if json_path.exists(): - logger.info( - "Json file for thread already exists, skipping export of row to json", - extra={"thread_id": thread_model.id, "json_path": str(json_path)}, - ) - continue - with json_path.open("w", encoding="utf-8") as f: - f.write(thread_model.model_dump_json()) - except Exception as e: # noqa: PERF203 - error_msg = f"Failed to export row to json: {e}" - logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) - continue diff --git a/src/askui/chat/migrations/versions/5a1b2c3d4e5f_create_mcp_configs_table.py b/src/askui/chat/migrations/versions/5a1b2c3d4e5f_create_mcp_configs_table.py deleted file mode 100644 index 7c3c2001..00000000 --- a/src/askui/chat/migrations/versions/5a1b2c3d4e5f_create_mcp_configs_table.py +++ /dev/null @@ -1,50 +0,0 @@ -"""create_mcp_configs_table - -Revision ID: 5a1b2c3d4e5f -Revises: c35e88ea9595 -Create Date: 2025-01-27 10:00:00.000000 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "5a1b2c3d4e5f" -down_revision: Union[str, None] = "c35e88ea9595" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "mcp_configs", - sa.Column("id", sa.String(24), nullable=False, primary_key=True), - sa.Column("workspace_id", sa.Uuid(), nullable=True, index=True), - sa.Column("created_at", sa.Integer(), nullable=False), - sa.Column("name", sa.String(), nullable=False), - sa.Column("mcp_server", sa.JSON(), nullable=False), - ) - - # Add constraint to enforce MCP configuration limit - op.execute(""" - CREATE TRIGGER check_mcp_config_limit - BEFORE INSERT ON mcp_configs - WHEN ( - SELECT COUNT(*) FROM mcp_configs - WHERE workspace_id = NEW.workspace_id OR workspace_id IS NULL - ) >= 100 - BEGIN - SELECT RAISE(ABORT, 'MCP configuration limit reached. You may only have 100 MCP configurations. You can delete some MCP configurations to create new ones.'); - END; - """) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("mcp_configs") - # ### end Alembic commands ### diff --git a/src/askui/chat/migrations/versions/5e6f7a8b9c0d_import_json_messages.py b/src/askui/chat/migrations/versions/5e6f7a8b9c0d_import_json_messages.py deleted file mode 100644 index 6312c65f..00000000 --- a/src/askui/chat/migrations/versions/5e6f7a8b9c0d_import_json_messages.py +++ /dev/null @@ -1,291 +0,0 @@ -"""import_json_messages - -Revision ID: 5e6f7a8b9c0d -Revises: 2b3c4d5e6f7a -Create Date: 2025-01-27 12:04:00.000000 - -""" - -import json -import logging -from typing import Sequence, Union - -from alembic import op -from sqlalchemy import Connection, MetaData, Table, text - -from askui.chat.migrations.shared.messages.models import MessageV1 -from askui.chat.migrations.shared.settings import SettingsV1 - -# revision identifiers, used by Alembic. -revision: str = "5e6f7a8b9c0d" -down_revision: Union[str, None] = "2b3c4d5e6f7a" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - - -BATCH_SIZE = 1000 - - -def _insert_messages_batch( - connection: Connection, messages_table: Table, messages_batch: list[MessageV1] -) -> None: - """Insert a batch of messages into the database, handling foreign key violations.""" - if not messages_batch: - logger.info("No messages to insert, skipping batch") - return - - # Validate and fix foreign key references - valid_messages = _validate_and_fix_foreign_keys(connection, messages_batch) - - if valid_messages: - connection.execute( - messages_table.insert().prefix_with("OR REPLACE"), - [message.to_db_dict() for message in valid_messages], - ) - - -def _validate_and_fix_foreign_keys( # noqa: C901 - connection: Connection, messages_batch: list[MessageV1] -) -> list[MessageV1]: - """ - Validate foreign key references and fix invalid ones. - - - If thread_id is invalid: ignore the message completely - - If assistant_id is invalid: set to None - - If run_id is invalid: set to None - """ - if not messages_batch: - logger.info("Empty message batch, nothing to validate") - return [] - - # Extract all foreign key values - thread_ids = {msg.thread_id.removeprefix("thread_") for msg in messages_batch} - assistant_ids = { - msg.assistant_id.removeprefix("asst_") - for msg in messages_batch - if msg.assistant_id - } - run_ids = {msg.run_id.removeprefix("run_") for msg in messages_batch if msg.run_id} - - # Check which foreign keys exist in the database - valid_thread_ids: set[str] = set() - if thread_ids: - # Create placeholders for SQLite IN clause - placeholders = ",".join([":id" + str(i) for i in range(len(thread_ids))]) - params = {f"id{i}": thread_id for i, thread_id in enumerate(thread_ids)} - result = connection.execute( - text(f"SELECT id FROM threads WHERE id IN ({placeholders})"), params - ) - valid_thread_ids = {row[0] for row in result} - - valid_assistant_ids: set[str] = set() - if assistant_ids: - # Create placeholders for SQLite IN clause - placeholders = ",".join([":id" + str(i) for i in range(len(assistant_ids))]) - params = { - f"id{i}": assistant_id for i, assistant_id in enumerate(assistant_ids) - } - result = connection.execute( - text(f"SELECT id FROM assistants WHERE id IN ({placeholders})"), params - ) - valid_assistant_ids = {row[0] for row in result} - - valid_run_ids: set[str] = set() - if run_ids: - # Create placeholders for SQLite IN clause - placeholders = ",".join([":id" + str(i) for i in range(len(run_ids))]) - params = {f"id{i}": run_id for i, run_id in enumerate(run_ids)} - result = connection.execute( - text(f"SELECT id FROM runs WHERE id IN ({placeholders})"), params - ) - valid_run_ids = {row[0] for row in result} - - # Process each message - valid_messages: list[MessageV1] = [] - for message in messages_batch: - thread_id = message.thread_id.removeprefix("thread_") - assistant_id = ( - message.assistant_id.removeprefix("asst_") if message.assistant_id else None - ) - run_id = message.run_id.removeprefix("run_") if message.run_id else None - - # If thread_id is invalid, ignore the message completely - if thread_id not in valid_thread_ids: - logger.warning( - "Ignoring message with invalid thread_id (thread does not exist)", - extra={ - "message_id": message.id, - "thread_id": thread_id, - "workspace_id": str(message.workspace_id), - }, - ) - continue - - # Check and fix assistant_id and run_id - fixed_assistant_id = None - fixed_run_id = None - changes_made: list[str] = [] - - if assistant_id is not None and assistant_id not in valid_assistant_ids: - fixed_assistant_id = None - changes_made.append(f"assistant_id set to None (was: {assistant_id})") - elif assistant_id is not None: - fixed_assistant_id = assistant_id - - if run_id is not None and run_id not in valid_run_ids: - fixed_run_id = None - changes_made.append(f"run_id set to None (was: {run_id})") - elif run_id is not None: - fixed_run_id = run_id - - # Create a copy of the message with fixed foreign keys - if changes_made: - logger.info( - "Fixed foreign key references for message", - extra={ - "message_id": message.id, - "thread_id": thread_id, - "changes": changes_made, - }, - ) - - # Create new message with fixed foreign keys - fixed_message = MessageV1( - id=message.id, - object=message.object, - created_at=message.created_at, - thread_id=message.thread_id, - role=message.role, - content=message.content, - stop_reason=message.stop_reason, - assistant_id=f"asst_{fixed_assistant_id}" - if fixed_assistant_id - else None, - run_id=f"run_{fixed_run_id}" if fixed_run_id else None, - workspace_id=message.workspace_id, - ) - valid_messages.append(fixed_message) - else: - # No changes needed, use original message - valid_messages.append(message) - - return valid_messages - - -settings = SettingsV1() -workspaces_dir = settings.data_dir / "workspaces" - - -def upgrade() -> None: # noqa: C901 - """Import existing messages from JSON files in workspace directories.""" - # Skip if workspaces directory doesn't exist (e.g., first-time setup) - if not workspaces_dir.exists(): - logger.info( - "Workspaces directory does not exist, skipping import of messages", - extra={"workspaces_dir": str(workspaces_dir)}, - ) - return - - # Get the table from the current database schema - connection = op.get_bind() - messages_table = Table("messages", MetaData(), autoload_with=connection) - - # Process messages in batches - messages_batch: list[MessageV1] = [] - - # Iterate through all workspace directories - for workspace_dir in workspaces_dir.iterdir(): - if not workspace_dir.is_dir(): - logger.info( - "Skipping non-directory in workspaces", - extra={"path": str(workspace_dir)}, - ) - continue - - workspace_id = workspace_dir.name - messages_dir = workspace_dir / "messages" - - if not messages_dir.exists(): - logger.info( - "Messages directory does not exist, skipping workspace", - extra={"workspace_id": workspace_id, "messages_dir": str(messages_dir)}, - ) - continue - - # Iterate through thread directories - for thread_dir in messages_dir.iterdir(): - if not thread_dir.is_dir(): - logger.info( - "Skipping non-directory in messages", - extra={"path": str(thread_dir)}, - ) - continue - - # Get all JSON files in the thread directory - json_files = list(thread_dir.glob("*.json")) - - for json_file in json_files: - try: - content = json_file.read_text(encoding="utf-8").strip() - data = json.loads(content) - message = MessageV1.model_validate( - {**data, "workspace_id": workspace_id} - ) - messages_batch.append(message) - if len(messages_batch) >= BATCH_SIZE: - _insert_messages_batch( - connection, messages_table, messages_batch - ) - messages_batch.clear() - except Exception: # noqa: PERF203 - error_msg = "Failed to import message" - logger.exception(error_msg, extra={"json_file": str(json_file)}) - continue - - # Insert remaining messages in the final batch - if messages_batch: - _insert_messages_batch(connection, messages_table, messages_batch) - - -def downgrade() -> None: - """Recreate JSON files for messages during downgrade.""" - - connection = op.get_bind() - messages_table = Table("messages", MetaData(), autoload_with=connection) - - # Fetch all messages from the database - result = connection.execute(messages_table.select()) - rows = result.fetchall() - if not rows: - logger.info( - "No messages found in the database, skipping export of rows to json", - ) - return - - for row in rows: - try: - message_model: MessageV1 = MessageV1.model_validate( - row, from_attributes=True - ) - messages_dir = ( - workspaces_dir - / str(message_model.workspace_id) - / "messages" - / message_model.thread_id - ) - messages_dir.mkdir(parents=True, exist_ok=True) - json_path = messages_dir / f"{message_model.id}.json" - if json_path.exists(): - logger.info( - "Json file for message already exists, skipping export of row to json", - extra={"message_id": message_model.id, "json_path": str(json_path)}, - ) - continue - with json_path.open("w", encoding="utf-8") as f: - f.write(message_model.model_dump_json()) - except Exception as e: # noqa: PERF203 - error_msg = f"Failed to export row to json: {e}" - logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) - continue diff --git a/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py b/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py deleted file mode 100644 index ca11ef15..00000000 --- a/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py +++ /dev/null @@ -1,127 +0,0 @@ -"""import_json_mcp_configs - -Revision ID: 6b2c3d4e5f6a -Revises: 5a1b2c3d4e5f -Create Date: 2025-01-27 10:01:00.000000 - -""" - -import json -import logging -from typing import Sequence, Union - -from alembic import op -from sqlalchemy import Connection, MetaData, Table - -from askui.chat.migrations.shared.mcp_configs.models import McpConfigV1 -from askui.chat.migrations.shared.settings import SettingsV1 - -# revision identifiers, used by Alembic. -revision: str = "6b2c3d4e5f6a" -down_revision: Union[str, None] = "5a1b2c3d4e5f" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - - -BATCH_SIZE = 1000 - - -def _insert_mcp_configs_batch( - connection: Connection, - mcp_configs_table: Table, - mcp_configs_batch: list[McpConfigV1], -) -> None: - """Insert a batch of MCP configs into the database, ignoring conflicts.""" - if not mcp_configs_batch: - logger.info("No MCP configs to insert, skipping batch") - return - - connection.execute( - mcp_configs_table.insert().prefix_with("OR REPLACE"), - [mcp_config.to_db_dict() for mcp_config in mcp_configs_batch], - ) - - -settings = SettingsV1() -mcp_configs_dir = settings.data_dir / "mcp_configs" - - -def upgrade() -> None: - """Import existing MCP configs from JSON files.""" - - # Skip if directory doesn't exist (e.g., first-time setup) - if not mcp_configs_dir.exists(): - logger.info( - "MCP configs directory does not exist, skipping import of MCP configs", - extra={"mcp_configs_dir": str(mcp_configs_dir)}, - ) - return - - # Get the table from the current database schema - connection = op.get_bind() - mcp_configs_table = Table("mcp_configs", MetaData(), autoload_with=connection) - - # Get all JSON files in the mcp_configs directory - json_files = list(mcp_configs_dir.glob("*.json")) - - # Process MCP configs in batches - mcp_configs_batch: list[McpConfigV1] = [] - - for json_file in json_files: - try: - content = json_file.read_text(encoding="utf-8").strip() - data = json.loads(content) - mcp_config = McpConfigV1.model_validate(data) - mcp_configs_batch.append(mcp_config) - if len(mcp_configs_batch) >= BATCH_SIZE: - _insert_mcp_configs_batch( - connection, mcp_configs_table, mcp_configs_batch - ) - mcp_configs_batch.clear() - except Exception: # noqa: PERF203 - error_msg = "Failed to import" - logger.exception(error_msg, extra={"json_file": str(json_file)}) - continue - - # Insert remaining MCP configs in the final batch - if mcp_configs_batch: - _insert_mcp_configs_batch(connection, mcp_configs_table, mcp_configs_batch) - - -def downgrade() -> None: - """Recreate JSON files for MCP configs during downgrade.""" - - mcp_configs_dir.mkdir(parents=True, exist_ok=True) - - connection = op.get_bind() - mcp_configs_table = Table("mcp_configs", MetaData(), autoload_with=connection) - - # Fetch all MCP configs from the database - result = connection.execute(mcp_configs_table.select()) - rows = result.fetchall() - if not rows: - logger.info( - "No MCP configs found in the database, skipping export of rows to json", - ) - return - - for row in rows: - try: - mcp_config: McpConfigV1 = McpConfigV1.model_validate( - row, from_attributes=True - ) - json_path = mcp_configs_dir / f"{mcp_config.id}.json" - if json_path.exists(): - logger.info( - "Json file for mcp config already exists, skipping export of row to json", - extra={"mcp_config_id": mcp_config.id, "json_path": str(json_path)}, - ) - continue - with json_path.open("w", encoding="utf-8") as f: - f.write(mcp_config.model_dump_json()) - except Exception as e: # noqa: PERF203 - error_msg = f"Failed to export row to json: {e}" - logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) - continue diff --git a/src/askui/chat/migrations/versions/6f7a8b9c0d1e_import_json_runs.py b/src/askui/chat/migrations/versions/6f7a8b9c0d1e_import_json_runs.py deleted file mode 100644 index 6c63c2fa..00000000 --- a/src/askui/chat/migrations/versions/6f7a8b9c0d1e_import_json_runs.py +++ /dev/null @@ -1,267 +0,0 @@ -"""import_json_runs - -Revision ID: 6f7a8b9c0d1e -Revises: 3c4d5e6f7a8b -Create Date: 2025-01-27 12:05:00.000000 - -""" - -import json -import logging -from typing import Sequence, Union - -from alembic import op -from sqlalchemy import Connection, MetaData, Table, text - -from askui.chat.migrations.shared.runs.models import RunV1 -from askui.chat.migrations.shared.settings import SettingsV1 - -# revision identifiers, used by Alembic. -revision: str = "6f7a8b9c0d1e" -down_revision: Union[str, None] = "3c4d5e6f7a8b" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - - -BATCH_SIZE = 1000 - - -def _insert_runs_batch( - connection: Connection, runs_table: Table, runs_batch: list[RunV1] -) -> None: - """Insert a batch of runs into the database, handling foreign key violations.""" - if not runs_batch: - logger.info("No runs to insert, skipping batch") - return - - # Validate and fix foreign key references - valid_runs = _validate_and_fix_foreign_keys(connection, runs_batch) - - if valid_runs: - connection.execute( - runs_table.insert().prefix_with("OR REPLACE"), - [run.to_db_dict() for run in valid_runs], - ) - - -def _validate_and_fix_foreign_keys( # noqa: C901 - connection: Connection, runs_batch: list[RunV1] -) -> list[RunV1]: - """ - Validate foreign key references and fix invalid ones. - - - If thread_id is invalid: ignore the run completely - - If assistant_id is invalid: set to None - """ - if not runs_batch: - logger.info("Empty run batch, nothing to validate") - return [] - - # Extract all foreign key values - thread_ids = {run.thread_id.removeprefix("thread_") for run in runs_batch} - assistant_ids = { - run.assistant_id.removeprefix("asst_") for run in runs_batch if run.assistant_id - } - - # Check which foreign keys exist in the database - valid_thread_ids: set[str] = set() - if thread_ids: - # Create placeholders for SQLite IN clause - placeholders = ",".join([":id" + str(i) for i in range(len(thread_ids))]) - params = {f"id{i}": thread_id for i, thread_id in enumerate(thread_ids)} - result = connection.execute( - text(f"SELECT id FROM threads WHERE id IN ({placeholders})"), params - ) - valid_thread_ids = {row[0] for row in result} - - valid_assistant_ids: set[str] = set() - if assistant_ids: - # Create placeholders for SQLite IN clause - placeholders = ",".join([":id" + str(i) for i in range(len(assistant_ids))]) - params = { - f"id{i}": assistant_id for i, assistant_id in enumerate(assistant_ids) - } - result = connection.execute( - text(f"SELECT id FROM assistants WHERE id IN ({placeholders})"), params - ) - valid_assistant_ids = {row[0] for row in result} - - # Process each run - valid_runs: list[RunV1] = [] - for run in runs_batch: - thread_id = run.thread_id.removeprefix("thread_") - assistant_id = ( - run.assistant_id.removeprefix("asst_") if run.assistant_id else None - ) - - # If thread_id is invalid, ignore the run completely - if thread_id not in valid_thread_ids: - logger.warning( - "Ignoring run with invalid thread_id (thread does not exist)", - extra={ - "run_id": run.id, - "thread_id": thread_id, - "workspace_id": str(run.workspace_id), - }, - ) - continue - - # Check and fix assistant_id - fixed_assistant_id = None - changes_made: list[str] = [] - - if assistant_id is not None and assistant_id not in valid_assistant_ids: - fixed_assistant_id = None - changes_made.append(f"assistant_id set to None (was: {assistant_id})") - elif assistant_id is not None: - fixed_assistant_id = assistant_id - - # Create a copy of the run with fixed foreign keys - if changes_made: - logger.info( - "Fixed foreign key references for run", - extra={ - "run_id": run.id, - "thread_id": thread_id, - "changes": changes_made, - }, - ) - - # Create new run with fixed foreign keys - fixed_run = RunV1( - id=run.id, - object=run.object, - thread_id=run.thread_id, - created_at=run.created_at, - expires_at=run.expires_at, - started_at=run.started_at, - completed_at=run.completed_at, - failed_at=run.failed_at, - cancelled_at=run.cancelled_at, - tried_cancelling_at=run.tried_cancelling_at, - last_error=run.last_error, - assistant_id=f"asst_{fixed_assistant_id}" - if fixed_assistant_id - else None, - workspace_id=run.workspace_id, - ) - valid_runs.append(fixed_run) - else: - # No changes needed, use original run - valid_runs.append(run) - - return valid_runs - - -settings = SettingsV1() -workspaces_dir = settings.data_dir / "workspaces" - - -def upgrade() -> None: # noqa: C901 - """Import existing runs from JSON files in workspace directories.""" - - # Skip if workspaces directory doesn't exist (e.g., first-time setup) - if not workspaces_dir.exists(): - logger.info( - "Workspaces directory does not exist, skipping import of runs", - extra={"workspaces_dir": str(workspaces_dir)}, - ) - return - - # Get the table from the current database schema - connection = op.get_bind() - runs_table = Table("runs", MetaData(), autoload_with=connection) - - # Process runs in batches - runs_batch: list[RunV1] = [] - - # Iterate through all workspace directories - for workspace_dir in workspaces_dir.iterdir(): - if not workspace_dir.is_dir(): - logger.info( - "Skipping non-directory in workspaces", - extra={"path": str(workspace_dir)}, - ) - continue - - workspace_id = workspace_dir.name - runs_dir = workspace_dir / "runs" - - if not runs_dir.exists(): - logger.info( - "Runs directory does not exist, skipping workspace", - extra={"workspace_id": workspace_id, "runs_dir": str(runs_dir)}, - ) - continue - - # Iterate through thread directories - for thread_dir in runs_dir.iterdir(): - if not thread_dir.is_dir(): - logger.info( - "Skipping non-directory in runs", - extra={"path": str(thread_dir)}, - ) - continue - - # Get all JSON files in the thread directory - json_files = list(thread_dir.glob("*.json")) - - for json_file in json_files: - try: - content = json_file.read_text(encoding="utf-8").strip() - data = json.loads(content) - run = RunV1.model_validate({**data, "workspace_id": workspace_id}) - runs_batch.append(run) - if len(runs_batch) >= BATCH_SIZE: - _insert_runs_batch(connection, runs_table, runs_batch) - runs_batch.clear() - except Exception: # noqa: PERF203 - error_msg = "Failed to import run" - logger.exception(error_msg, extra={"json_file": str(json_file)}) - continue - - # Insert remaining runs in the final batch - if runs_batch: - _insert_runs_batch(connection, runs_table, runs_batch) - - -def downgrade() -> None: - """Recreate JSON files for runs during downgrade.""" - - connection = op.get_bind() - runs_table = Table("runs", MetaData(), autoload_with=connection) - - # Fetch all runs from the database - result = connection.execute(runs_table.select()) - rows = result.fetchall() - if not rows: - logger.info( - "No runs found in the database, skipping export of rows to json", - ) - return - - for row in rows: - try: - run_model: RunV1 = RunV1.model_validate(row, from_attributes=True) - runs_dir = ( - workspaces_dir - / str(run_model.workspace_id) - / "runs" - / run_model.thread_id - ) - runs_dir.mkdir(parents=True, exist_ok=True) - json_path = runs_dir / f"{run_model.id}.json" - if json_path.exists(): - logger.info( - "Json file for run already exists, skipping export of row to json", - extra={"run_id": run_model.id, "json_path": str(json_path)}, - ) - continue - with json_path.open("w", encoding="utf-8") as f: - f.write(run_model.model_dump_json()) - except Exception as e: # noqa: PERF203 - error_msg = f"Failed to export row to json: {e}" - logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) - continue diff --git a/src/askui/chat/migrations/versions/7b8c9d0e1f2a_add_parent_id_to_messages.py b/src/askui/chat/migrations/versions/7b8c9d0e1f2a_add_parent_id_to_messages.py deleted file mode 100644 index 2c8a1daf..00000000 --- a/src/askui/chat/migrations/versions/7b8c9d0e1f2a_add_parent_id_to_messages.py +++ /dev/null @@ -1,92 +0,0 @@ -"""add_parent_id_to_messages - -Revision ID: 7b8c9d0e1f2a -Revises: 5e6f7a8b9c0d -Create Date: 2025-11-05 12:00:00.000000 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "7b8c9d0e1f2a" -down_revision: Union[str, None] = "5e6f7a8b9c0d" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # Get database connection - connection = op.get_bind() - - # Check if parent_id column already exists - inspector = sa.inspect(connection) - columns = [col["name"] for col in inspector.get_columns("messages")] - column_exists = "parent_id" in columns - - # Only run batch operation if column doesn't exist - if not column_exists: - # Add column, foreign key, and index all in one batch operation - # This ensures the table is only recreated once in SQLite - with op.batch_alter_table("messages") as batch_op: - # Add parent_id column - batch_op.add_column(sa.Column("parent_id", sa.String(24), nullable=True)) - - # Add foreign key constraint (self-referential) - # parent_id remains nullable - NULL indicates a root message - batch_op.create_foreign_key( - "fk_messages_parent_id", - "messages", - ["parent_id"], - ["id"], - ondelete="CASCADE", - ) - - # Add index for performance - batch_op.create_index("ix_messages_parent_id", ["parent_id"]) - - # NOW populate parent_id values AFTER the table structure is finalized - # Fetch all threads - threads_result = connection.execute(sa.text("SELECT id FROM threads")) - thread_ids = [row[0] for row in threads_result] - - # For each thread, set up parent-child relationships - for thread_id in thread_ids: - # Get all messages in this thread, sorted by ID (which is time-ordered) - messages_result = connection.execute( - sa.text( - "SELECT id FROM messages WHERE thread_id = :thread_id ORDER BY id ASC" - ), - {"thread_id": thread_id}, - ) - message_ids = [row[0] for row in messages_result] - - # Set parent_id for each message - for i, message_id in enumerate(message_ids): - if i == 0: - # First message in thread has NULL as parent (root message) - parent_id = None - else: - # Each subsequent message's parent is the previous message - parent_id = message_ids[i - 1] - - connection.execute( - sa.text( - "UPDATE messages SET parent_id = :parent_id WHERE id = :message_id" - ), - {"parent_id": parent_id, "message_id": message_id}, - ) - - -def downgrade() -> None: - # Use batch_alter_table for SQLite compatibility - with op.batch_alter_table("messages") as batch_op: - # Drop index - batch_op.drop_index("ix_messages_parent_id") - # Drop foreign key constraint - batch_op.drop_constraint("fk_messages_parent_id", type_="foreignkey") - # Drop column - batch_op.drop_column("parent_id") diff --git a/src/askui/chat/migrations/versions/8d9e0f1a2b3c_create_files_table.py b/src/askui/chat/migrations/versions/8d9e0f1a2b3c_create_files_table.py deleted file mode 100644 index c227e0c8..00000000 --- a/src/askui/chat/migrations/versions/8d9e0f1a2b3c_create_files_table.py +++ /dev/null @@ -1,38 +0,0 @@ -"""create_files_table - -Revision ID: 8d9e0f1a2b3c -Revises: 6b2c3d4e5f6a -Create Date: 2025-01-27 11:00:00.000000 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "8d9e0f1a2b3c" -down_revision: Union[str, None] = "6b2c3d4e5f6a" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "files", - sa.Column("id", sa.String(24), nullable=False, primary_key=True), - sa.Column("workspace_id", sa.Uuid(), nullable=True, index=True), - sa.Column("created_at", sa.Integer(), nullable=False), - sa.Column("filename", sa.String(), nullable=False), - sa.Column("size", sa.Integer(), nullable=False), - sa.Column("media_type", sa.String(), nullable=False), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("files") - # ### end Alembic commands ### diff --git a/src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py b/src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py deleted file mode 100644 index 750cce56..00000000 --- a/src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py +++ /dev/null @@ -1,143 +0,0 @@ -"""import_json_files - -Revision ID: 9e0f1a2b3c4d -Revises: 8d9e0f1a2b3c -Create Date: 2025-01-27 11:01:00.000000 - -""" - -import json -import logging -from typing import Sequence, Union - -from alembic import op -from sqlalchemy import Connection, MetaData, Table - -from askui.chat.migrations.shared.files.models import FileV1 -from askui.chat.migrations.shared.settings import SettingsV1 - -# revision identifiers, used by Alembic. -revision: str = "9e0f1a2b3c4d" -down_revision: Union[str, None] = "8d9e0f1a2b3c" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - - -BATCH_SIZE = 1000 - - -def _insert_files_batch( - connection: Connection, files_table: Table, files_batch: list[FileV1] -) -> None: - """Insert a batch of files into the database, ignoring conflicts.""" - if not files_batch: - logger.info("No files to insert, skipping batch") - return - - connection.execute( - files_table.insert().prefix_with("OR REPLACE"), - [file.to_db_dict() for file in files_batch], - ) - - -settings = SettingsV1() -workspaces_dir = settings.data_dir / "workspaces" - - -def upgrade() -> None: # noqa: C901 - """Import existing files from JSON files in workspace static directories.""" - - # Skip if workspaces directory doesn't exist (e.g., first-time setup) - if not workspaces_dir.exists(): - logger.info( - "Workspaces directory does not exist, skipping import of files", - extra={"workspaces_dir": str(workspaces_dir)}, - ) - return - - # Get the table from the current database schema - connection = op.get_bind() - files_table = Table("files", MetaData(), autoload_with=connection) - - # Process files in batches - files_batch: list[FileV1] = [] - - # Iterate through all workspace directories - for workspace_dir in workspaces_dir.iterdir(): - if not workspace_dir.is_dir(): - logger.info( - "Skipping non-directory in workspaces", - extra={"path": str(workspace_dir)}, - ) - continue - - workspace_id = workspace_dir.name - files_dir = workspace_dir / "files" - - if not files_dir.exists(): - logger.info( - "Files directory does not exist, skipping workspace", - extra={"workspace_id": workspace_id, "files_dir": str(files_dir)}, - ) - continue - - # Get all JSON files in the static directory - json_files = list(files_dir.glob("*.json")) - - for json_file in json_files: - try: - content = json_file.read_text(encoding="utf-8").strip() - data = json.loads(content) - file = FileV1.model_validate({**data, "workspace_id": workspace_id}) - files_batch.append(file) - if len(files_batch) >= BATCH_SIZE: - _insert_files_batch(connection, files_table, files_batch) - files_batch.clear() - except Exception: # noqa: PERF203 - error_msg = "Failed to import file" - logger.exception(error_msg, extra={"json_file": str(json_file)}) - continue - - # Insert remaining files in the final batch - if files_batch: - _insert_files_batch(connection, files_table, files_batch) - - -def downgrade() -> None: - """Recreate JSON files for files during downgrade.""" - - connection = op.get_bind() - files_table = Table("files", MetaData(), autoload_with=connection) - - # Fetch all files from the database - result = connection.execute(files_table.select()) - rows = result.fetchall() - if not rows: - logger.info( - "No files found in the database, skipping export of rows to json", - ) - return - - for row in rows: - try: - file_model: FileV1 = FileV1.model_validate(row, from_attributes=True) - if file_model.workspace_id: - files_dir = workspaces_dir / str(file_model.workspace_id) / "files" - else: - files_dir = settings.data_dir / "files" - files_dir.mkdir(parents=True, exist_ok=True) - json_path = files_dir / f"{file_model.id}.json" - if json_path.exists(): - logger.info( - "Json file for file already exists, skipping export of row to json", - extra={"file_id": file_model.id, "json_path": str(json_path)}, - ) - continue - with json_path.open("w", encoding="utf-8") as f: - f.write(file_model.model_dump_json()) - except Exception as e: # noqa: PERF203 - error_msg = f"Failed to export row to json: {e}" - logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) - continue diff --git a/src/askui/chat/migrations/versions/__init__.py b/src/askui/chat/migrations/versions/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/migrations/versions/c35e88ea9595_seed_default_assistants.py b/src/askui/chat/migrations/versions/c35e88ea9595_seed_default_assistants.py deleted file mode 100644 index ad47e506..00000000 --- a/src/askui/chat/migrations/versions/c35e88ea9595_seed_default_assistants.py +++ /dev/null @@ -1,61 +0,0 @@ -"""seed_default_assistants - -Revision ID: c35e88ea9595 -Revises: 057f82313448 -Create Date: 2025-10-10 11:22:12.576195 - -""" - -import logging -from typing import Sequence, Union - -from alembic import op -from sqlalchemy import MetaData, Table -from sqlalchemy.exc import IntegrityError - -from askui.chat.migrations.shared.assistants.seeds import SEEDS_V1 - -# revision identifiers, used by Alembic. -revision: str = "c35e88ea9595" -down_revision: Union[str, None] = "057f82313448" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - - -def upgrade() -> None: - """Seed default assistants one by one, skipping duplicates. - - For each assistant in `SEEDS_V1`, insert a row into `assistants`. If a - row with the same `id` already exists, skip it and log on debug level. - """ - connection = op.get_bind() - assistants_table: Table = Table("assistants", MetaData(), autoload_with=connection) - - for seed in SEEDS_V1: - payload: dict[str, object] = seed.to_db_dict() - try: - connection.execute(assistants_table.insert().values(**payload)) - except IntegrityError: - logger.info( - "Assistant already exists, skipping", extra={"assistant_id": seed.id} - ) - continue - except Exception as e: # noqa: PERF203 - logger.exception( - "Failed to insert assistant", - extra={"assistant": seed.model_dump_json()}, - exc_info=e, - ) - continue - - -def downgrade() -> None: - """Remove exactly those assistants that were seeded in upgrade().""" - connection = op.get_bind() - assistant_table: Table = Table("assistants", MetaData(), autoload_with=connection) - - seed_db_ids: list[str] = [seed.id for seed in SEEDS_V1] - for id_ in seed_db_ids: - connection.execute(assistant_table.delete().where(assistant_table.c.id == id_)) diff --git a/tests/integration/chat/__init__.py b/tests/integration/chat/__init__.py deleted file mode 100644 index baa44a5b..00000000 --- a/tests/integration/chat/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Chat integration tests.""" diff --git a/tests/integration/chat/api/__init__.py b/tests/integration/chat/api/__init__.py deleted file mode 100644 index 13477e6b..00000000 --- a/tests/integration/chat/api/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Chat API integration tests diff --git a/tests/integration/chat/api/conftest.py b/tests/integration/chat/api/conftest.py deleted file mode 100644 index 77be8da2..00000000 --- a/tests/integration/chat/api/conftest.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Chat API integration test configuration and fixtures.""" - -import tempfile -import uuid -from collections.abc import Generator -from pathlib import Path - -import pytest -from fastapi import FastAPI -from fastapi.testclient import TestClient -from sqlalchemy import create_engine -from sqlalchemy.orm import Session, sessionmaker - -from askui.chat.api.app import app -from askui.chat.api.assistants.dependencies import get_assistant_service -from askui.chat.api.assistants.service import AssistantService -from askui.chat.api.db.orm.base import Base -from askui.chat.api.files.service import FileService - - -@pytest.fixture -def test_db_session() -> Generator[Session, None, None]: - """Create a test database session with temporary SQLite file.""" - with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as temp_db: - # Create an engine with the temporary file - engine = create_engine(f"sqlite:///{temp_db.name}", echo=True) - # Create all tables - Base.metadata.create_all(engine) - # Create a session - SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - session = SessionLocal() - try: - yield session - finally: - session.close() - - -@pytest.fixture -def test_app() -> FastAPI: - """Get the FastAPI test application.""" - return app - - -@pytest.fixture -def test_client( - test_app: FastAPI, test_db_session: Session -) -> Generator[TestClient, None, None]: - """Yield a TestClient with common overrides - (assistants service uses the test DB). - """ - app.dependency_overrides[get_assistant_service] = lambda: AssistantService( - test_db_session - ) - try: - yield TestClient(test_app) - finally: - app.dependency_overrides.pop(get_assistant_service, None) - - -@pytest.fixture -def temp_workspace_dir() -> Path: - """Create a temporary workspace directory for testing.""" - temp_dir = tempfile.mkdtemp() - return Path(temp_dir) - - -@pytest.fixture -def test_workspace_id() -> str: - """Get a test workspace ID.""" - return str(uuid.uuid4()) - - -@pytest.fixture -def test_headers(test_workspace_id: str) -> dict[str, str]: - """Get test headers with workspace ID.""" - return {"askui-workspace": test_workspace_id} - - -@pytest.fixture -def mock_file_service( - test_db_session: Session, temp_workspace_dir: Path -) -> FileService: - """Create a mock file service with temporary workspace.""" - return FileService(test_db_session, temp_workspace_dir) - - -def create_test_app_with_overrides( - test_db_session: Session, workspace_path: Path -) -> FastAPI: - """Create a test app with all dependencies overridden.""" - from askui.chat.api.app import app - from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - - # Create a copy of the app to avoid modifying the global one - test_app = FastAPI() - test_app.router = app.router - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - def override_set_env_from_headers() -> None: - # No-op for testing - pass - - test_app.dependency_overrides[get_workspace_dir] = override_workspace_dir - test_app.dependency_overrides[get_file_service] = override_file_service - test_app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers - - return test_app diff --git a/tests/integration/chat/api/test_assistants.py b/tests/integration/chat/api/test_assistants.py deleted file mode 100644 index 46197d4a..00000000 --- a/tests/integration/chat/api/test_assistants.py +++ /dev/null @@ -1,513 +0,0 @@ -"""Integration tests for the assistants API endpoints.""" - -from datetime import datetime, timezone -from uuid import UUID - -from fastapi import status -from fastapi.testclient import TestClient -from sqlalchemy.orm import Session - -from askui.chat.api.assistants.models import Assistant -from askui.chat.api.assistants.orms import AssistantOrm -from askui.chat.api.models import WorkspaceId - - -class TestAssistantsAPI: - """Test suite for the assistants API endpoints.""" - - def _create_test_assistant( - self, - assistant_id: str, - workspace_id: WorkspaceId | None = None, - name: str = "Test Assistant", - description: str = "A test assistant", - avatar: str | None = None, - created_at: datetime | None = None, - ) -> Assistant: - """Create a test assistant model.""" - if created_at is None: - created_at = datetime.fromtimestamp(1234567890, tz=timezone.utc) - return Assistant( - id=assistant_id, - object="assistant", - created_at=created_at, - name=name, - description=description, - avatar=avatar, - workspace_id=workspace_id, - ) - - def _add_assistant_to_db( - self, assistant: Assistant, test_db_session: Session - ) -> None: - """Add an assistant to the test database.""" - assistant_orm = AssistantOrm.from_model(assistant) - test_db_session.add(assistant_orm) - test_db_session.commit() - - def test_list_assistants_empty( - self, test_headers: dict[str, str], test_client: TestClient - ) -> None: - """Test listing assistants when no assistants exist.""" - response = test_client.get("/v1/assistants", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert data["data"] == [] - assert data["has_more"] is False - - def test_list_assistants_with_assistants( - self, - test_headers: dict[str, str], - test_db_session: Session, - test_client: TestClient, - ) -> None: - """Test listing assistants when assistants exist.""" - workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 - mock_assistant = self._create_test_assistant( - "asst_test123", workspace_id=workspace_id, avatar="test_avatar.png" - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - response = test_client.get("/v1/assistants", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert len(data["data"]) == 1 - assert data["data"][0]["id"] == "asst_test123" - assert data["data"][0]["name"] == "Test Assistant" - assert data["data"][0]["description"] == "A test assistant" - assert data["data"][0]["avatar"] == "test_avatar.png" - assert data["has_more"] is False - - def test_list_assistants_with_pagination( - self, - test_headers: dict[str, str], - test_db_session: Session, - test_client: TestClient, - ) -> None: - """Test listing assistants with pagination parameters.""" - workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 - - # Create multiple mock assistants in the database - for i in range(5): - mock_assistant = self._create_test_assistant( - f"asst_test{i}", - workspace_id=workspace_id, - name=f"Test Assistant {i}", - description=f"Test assistant {i}", - created_at=datetime.fromtimestamp(1234567890 + i, tz=timezone.utc), - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - - response = test_client.get("/v1/assistants?limit=3", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert len(data["data"]) == 3 - assert data["has_more"] is True - - def test_create_assistant( - self, test_headers: dict[str, str], test_client: TestClient - ) -> None: - """Test creating a new assistant.""" - assistant_data = { - "name": "New Test Assistant", - "description": "A newly created test assistant", - "avatar": "new_avatar.png", - } - response = test_client.post( - "/v1/assistants", json=assistant_data, headers=test_headers - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["name"] == "New Test Assistant" - assert data["description"] == "A newly created test assistant" - assert data["avatar"] == "new_avatar.png" - assert data["object"] == "assistant" - assert "id" in data - assert "created_at" in data - - def test_create_assistant_minimal( - self, test_headers: dict[str, str], test_client: TestClient - ) -> None: - """Test creating an assistant with minimal data.""" - response = test_client.post("/v1/assistants", json={}, headers=test_headers) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["object"] == "assistant" - assert data["name"] is None - assert data["description"] is None - assert data["avatar"] is None - - def test_create_assistant_with_tools_and_system( - self, test_headers: dict[str, str], test_client: TestClient - ) -> None: - """Test creating a new assistant with tools and system prompt.""" - response = test_client.post( - "/v1/assistants", - headers=test_headers, - json={ - "name": "Custom Assistant", - "description": "A custom assistant with tools", - "tools": ["tool1", "tool2", "tool3"], - "system": "You are a helpful custom assistant.", - }, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["name"] == "Custom Assistant" - assert data["description"] == "A custom assistant with tools" - assert data["tools"] == ["tool1", "tool2", "tool3"] - assert data["system"] == "You are a helpful custom assistant." - assert "id" in data - assert "created_at" in data - - def test_create_assistant_with_empty_tools( - self, test_headers: dict[str, str], test_client: TestClient - ) -> None: - """Test creating a new assistant with empty tools list.""" - response = test_client.post( - "/v1/assistants", - headers=test_headers, - json={ - "name": "Empty Tools Assistant", - "tools": [], - }, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["name"] == "Empty Tools Assistant" - assert data["tools"] == [] - assert "id" in data - assert "created_at" in data - - def test_retrieve_assistant( - self, - test_headers: dict[str, str], - test_db_session: Session, - test_client: TestClient, - ) -> None: - """Test retrieving an existing assistant.""" - mock_assistant = self._create_test_assistant("asst_test123") - self._add_assistant_to_db(mock_assistant, test_db_session) - response = test_client.get("/v1/assistants/asst_test123", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["id"] == "asst_test123" - assert data["name"] == "Test Assistant" - assert data["description"] == "A test assistant" - - def test_retrieve_assistant_not_found( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test retrieving a non-existent assistant.""" - response = test_client.get( - "/v1/assistants/asst_nonexistent123", headers=test_headers - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - data = response.json() - assert "detail" in data - - def test_modify_assistant( - self, - test_headers: dict[str, str], - test_db_session: Session, - test_client: TestClient, - ) -> None: - """Test modifying an existing assistant.""" - workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 - mock_assistant = self._create_test_assistant( - "asst_test123", - workspace_id=workspace_id, - name="Original Name", - description="Original description", - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - modify_data = { - "name": "Modified Name", - "description": "Modified description", - } - response = test_client.post( - "/v1/assistants/asst_test123", - json=modify_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["name"] == "Modified Name" - assert data["description"] == "Modified description" - assert data["id"] == "asst_test123" - assert data["created_at"] == 1234567890 - - def test_modify_assistant_with_tools_and_system( - self, - test_headers: dict[str, str], - test_db_session: Session, - test_client: TestClient, - ) -> None: - """Test modifying an assistant with tools and system prompt.""" - workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 - mock_assistant = self._create_test_assistant( - "asst_test123", - workspace_id=workspace_id, - name="Original Name", - description="Original description", - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - modify_data = { - "name": "Modified Name", - "tools": ["new_tool1", "new_tool2"], - "system": "You are a modified custom assistant.", - } - response = test_client.post( - "/v1/assistants/asst_test123", - json=modify_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["name"] == "Modified Name" - assert data["tools"] == ["new_tool1", "new_tool2"] - assert data["system"] == "You are a modified custom assistant." - assert data["id"] == "asst_test123" - assert data["created_at"] == 1234567890 - - def test_modify_assistant_partial( - self, - test_headers: dict[str, str], - test_db_session: Session, - test_client: TestClient, - ) -> None: - """Test modifying an assistant with partial data.""" - workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 - mock_assistant = self._create_test_assistant( - "asst_test123", - workspace_id=workspace_id, - name="Original Name", - description="Original description", - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - modify_data = {"name": "Only Name Modified"} - response = test_client.post( - "/v1/assistants/asst_test123", - json=modify_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["name"] == "Only Name Modified" - assert data["description"] == "Original description" # Unchanged - - def test_modify_assistant_not_found( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test modifying a non-existent assistant.""" - modify_data = {"name": "Modified Name"} - response = test_client.post( - "/v1/assistants/asst_nonexistent123", json=modify_data, headers=test_headers - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - - def test_delete_assistant( - self, - test_headers: dict[str, str], - test_db_session: Session, - test_client: TestClient, - ) -> None: - """Test deleting an existing assistant.""" - workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 - mock_assistant = self._create_test_assistant( - "asst_test123", workspace_id=workspace_id - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - response = test_client.delete( - "/v1/assistants/asst_test123", headers=test_headers - ) - - assert response.status_code == status.HTTP_204_NO_CONTENT - assert response.content == b"" - - def test_delete_assistant_not_found( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test deleting a non-existent assistant.""" - response = test_client.delete( - "/v1/assistants/asst_nonexistent123", headers=test_headers - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - - def test_modify_default_assistant_forbidden( - self, - test_headers: dict[str, str], - test_db_session: Session, - test_client: TestClient, - ) -> None: - """Test that modifying a default assistant returns 403 Forbidden.""" - default_assistant = self._create_test_assistant( - "asst_default123", - workspace_id=None, # No workspace_id = default - name="Default Assistant", - description="This is a default assistant", - ) - self._add_assistant_to_db(default_assistant, test_db_session) - # Try to modify the default assistant - response = test_client.post( - "/v1/assistants/asst_default123", - headers=test_headers, - json={"name": "Modified Name"}, - ) - assert response.status_code == 403 - assert "cannot be modified" in response.json()["detail"] - - def test_delete_default_assistant_forbidden( - self, - test_headers: dict[str, str], - test_db_session: Session, - test_client: TestClient, - ) -> None: - """Test that deleting a default assistant returns 403 Forbidden.""" - default_assistant = self._create_test_assistant( - "asst_default456", - workspace_id=None, # No workspace_id = default - name="Default Assistant", - description="This is a default assistant", - ) - self._add_assistant_to_db(default_assistant, test_db_session) - # Try to delete the default assistant - response = test_client.delete( - "/v1/assistants/asst_default456", - headers=test_headers, - ) - assert response.status_code == 403 - assert "cannot be deleted" in response.json()["detail"] - - def test_list_assistants_includes_default_and_workspace( - self, - test_headers: dict[str, str], - test_db_session: Session, - test_client: TestClient, - ) -> None: - """Test that listing assistants includes both default and - workspace-scoped ones. - """ - # Create a default assistant (no workspace_id) - default_assistant = self._create_test_assistant( - "asst_default789", - workspace_id=None, # No workspace_id = default - name="Default Assistant", - description="This is a default assistant", - ) - self._add_assistant_to_db(default_assistant, test_db_session) - - # Create a workspace-scoped assistant - workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 - workspace_assistant = self._create_test_assistant( - "asst_workspace123", - workspace_id=workspace_id, - name="Workspace Assistant", - description="This is a workspace assistant", - ) - self._add_assistant_to_db(workspace_assistant, test_db_session) - - # List assistants - should include both - response = test_client.get("/v1/assistants", headers=test_headers) - assert response.status_code == 200 - - data = response.json() - assistant_ids = [assistant["id"] for assistant in data["data"]] - - # Should include both default and workspace assistants - assert "asst_default789" in assistant_ids - assert "asst_workspace123" in assistant_ids - - # Verify workspace_id fields - default_assistant_data = next( - a for a in data["data"] if a["id"] == "asst_default789" - ) - workspace_assistant_data = next( - a for a in data["data"] if a["id"] == "asst_workspace123" - ) - - assert default_assistant_data["workspace_id"] is None - assert workspace_assistant_data["workspace_id"] == str(workspace_id) - - def test_retrieve_default_assistant_success( - self, - test_headers: dict[str, str], - test_db_session: Session, - test_client: TestClient, - ) -> None: - """Test that retrieving a default assistant works.""" - default_assistant = self._create_test_assistant( - "asst_defaultretrieve", - workspace_id=None, # No workspace_id = default - name="Default Assistant", - description="This is a default assistant", - ) - self._add_assistant_to_db(default_assistant, test_db_session) - # Retrieve the default assistant - response = test_client.get( - "/v1/assistants/asst_defaultretrieve", - headers=test_headers, - ) - assert response.status_code == 200 - - data = response.json() - assert data["id"] == "asst_defaultretrieve" - assert data["workspace_id"] is None - - def test_workspace_scoped_assistant_operations_success( - self, - test_headers: dict[str, str], - test_db_session: Session, - test_client: TestClient, - ) -> None: - """Test that workspace-scoped assistants can be modified and deleted.""" - workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 - workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 - workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 - workspace_assistant = self._create_test_assistant( - "asst_workspaceops", - workspace_id=workspace_id, - name="Workspace Assistant", - description="This is a workspace assistant", - ) - self._add_assistant_to_db(workspace_assistant, test_db_session) - # Modify the workspace assistant - response = test_client.post( - "/v1/assistants/asst_workspaceops", - headers=test_headers, - json={"name": "Modified Workspace Assistant"}, - ) - assert response.status_code == 200 - - data = response.json() - assert data["name"] == "Modified Workspace Assistant" - assert data["workspace_id"] == str(workspace_id) - - # Delete the workspace assistant - response = test_client.delete( - "/v1/assistants/asst_workspaceops", - headers=test_headers, - ) - assert response.status_code == 204 - - # Verify it's deleted - response = test_client.get( - "/v1/assistants/asst_workspaceops", - headers=test_headers, - ) - assert response.status_code == 404 diff --git a/tests/integration/chat/api/test_files.py b/tests/integration/chat/api/test_files.py deleted file mode 100644 index 8eb76e2f..00000000 --- a/tests/integration/chat/api/test_files.py +++ /dev/null @@ -1,612 +0,0 @@ -"""Integration tests for the files API endpoints.""" - -import io -import tempfile -from datetime import datetime, timezone -from pathlib import Path -from uuid import UUID - -import pytest -from fastapi import status -from fastapi.testclient import TestClient -from sqlalchemy.orm import Session - -from askui.chat.api.files.models import File -from askui.chat.api.files.orms import FileOrm -from askui.chat.api.files.service import FileService -from askui.chat.api.models import FileId -from askui.utils.api_utils import NotFoundError - - -class TestFilesAPI: - """Test suite for the files API endpoints.""" - - def _add_file_to_db(self, file: File, test_db_session: Session) -> None: - """Add a file to the test database.""" - file_orm = FileOrm.from_model(file) - test_db_session.add(file_orm) - test_db_session.commit() - - def test_list_files_empty( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test listing files when no files exist.""" - response = test_client.get("/v1/files", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert data["data"] == [] - assert data["has_more"] is False - - def test_list_files_with_files( - self, - test_headers: dict[str, str], - test_db_session: Session, - ) -> None: - """Test listing files when files exist.""" - # Create a mock file in the temporary workspace - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - files_dir = workspace_path / "files" - files_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock file - workspace_id = UUID(test_headers["askui-workspace"]) - mock_file = File( - id="file_test123", - object="file", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - filename="test.txt", - size=32, - media_type="text/plain", - workspace_id=workspace_id, - ) - (files_dir / "file_test123.json").write_text(mock_file.model_dump_json()) - - # Add file to database - self._add_file_to_db(mock_file, test_db_session) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - - try: - with TestClient(app) as client: - response = client.get("/v1/files", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert len(data["data"]) == 1 - assert data["data"][0]["id"] == "file_test123" - assert data["data"][0]["filename"] == "test.txt" - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_list_files_with_pagination( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test listing files with pagination parameters.""" - # Create multiple mock files in the temporary workspace - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - files_dir = workspace_path / "files" - files_dir.mkdir(parents=True, exist_ok=True) - - # Create multiple mock files - workspace_id = UUID(test_headers["askui-workspace"]) - for i in range(5): - mock_file = File( - id=f"file_test{i}", - object="file", - created_at=datetime.fromtimestamp(1234567890 + i, tz=timezone.utc), - filename=f"test{i}.txt", - size=32, - media_type="text/plain", - workspace_id=workspace_id, - ) - (files_dir / f"file_test{i}.json").write_text(mock_file.model_dump_json()) - # Add file to database - self._add_file_to_db(mock_file, test_db_session) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - - try: - with TestClient(app) as client: - response = client.get("/v1/files?limit=2", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert len(data["data"]) == 2 - assert data["has_more"] is True - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_upload_file_success( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test successful file upload.""" - file_content = b"test file content" - files = {"file": ("test.txt", io.BytesIO(file_content), "text/plain")} - - response = test_client.post("/v1/files", files=files, headers=test_headers) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["object"] == "file" - assert data["filename"] == "test.txt" - assert data["size"] == len(file_content) - assert data["media_type"] == "text/plain" - assert "id" in data - assert "created_at" in data - - def test_upload_file_without_filename( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test file upload with simple filename.""" - file_content = b"test file content" - # Test with a simple filename - files = {"file": ("test", io.BytesIO(file_content), "text/plain")} - - # Create a test app with overridden dependencies - from .conftest import create_test_app_with_overrides - - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - test_app = create_test_app_with_overrides(test_db_session, workspace_path) - - with TestClient(test_app) as client: - response = client.post("/v1/files", files=files, headers=test_headers) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["object"] == "file" - # Should use the provided filename - assert data["filename"] == "test" - assert data["size"] == len(file_content) - assert data["media_type"] == "text/plain" - - def test_upload_file_large_size( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test file upload with file exceeding size limit.""" - # Create a file larger than 20MB - large_content = b"x" * (21 * 1024 * 1024) - files = {"file": ("large.txt", io.BytesIO(large_content), "text/plain")} - - response = test_client.post("/v1/files", files=files, headers=test_headers) - - assert response.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE - data = response.json() - assert "detail" in data - - def test_retrieve_file_success( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test successful file retrieval.""" - # Create a mock file in the temporary workspace - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - files_dir = workspace_path / "files" - files_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock file - workspace_id = UUID(test_headers["askui-workspace"]) - mock_file = File( - id="file_test123", - object="file", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - filename="test.txt", - size=32, - media_type="text/plain", - workspace_id=workspace_id, - ) - (files_dir / "file_test123.json").write_text(mock_file.model_dump_json()) - - # Add file to database - self._add_file_to_db(mock_file, test_db_session) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - - try: - with TestClient(app) as client: - response = client.get("/v1/files/file_test123", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["id"] == "file_test123" - assert data["filename"] == "test.txt" - assert data["size"] == 32 - assert data["media_type"] == "text/plain" - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_retrieve_file_not_found( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test file retrieval when file doesn't exist.""" - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - def override_set_env_from_headers() -> None: - # No-op for testing - pass - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers - - try: - with TestClient(app) as client: - response = client.get( - "/v1/files/file_nonexistent123", headers=test_headers - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - data = response.json() - assert "detail" in data - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_download_file_success( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test successful file download.""" - # Create a mock file in the temporary workspace - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - files_dir = workspace_path / "files" - workspace_id = UUID(test_headers["askui-workspace"]) - static_dir = workspace_path / "workspaces" / str(workspace_id) / "static" - files_dir.mkdir(parents=True, exist_ok=True) - static_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock file - mock_file = File( - id="file_test123", - object="file", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - filename="test.txt", - size=32, - media_type="text/plain", - workspace_id=workspace_id, - ) - (files_dir / "file_test123.json").write_text(mock_file.model_dump_json()) - - # Create the actual file content - file_content = b"test file content" - (static_dir / "file_test123.txt").write_bytes(file_content) - - # Add file to database - self._add_file_to_db(mock_file, test_db_session) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - - try: - with TestClient(app) as client: - response = client.get( - "/v1/files/file_test123/content", headers=test_headers - ) - - assert response.status_code == status.HTTP_200_OK - assert response.content == file_content - assert response.headers["content-type"].startswith("text/plain") - assert ( - response.headers["content-disposition"] - == 'attachment; filename="test.txt"' - ) - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_download_file_not_found( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test file download when file doesn't exist.""" - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - def override_set_env_from_headers() -> None: - # No-op for testing - pass - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers - - try: - with TestClient(app) as client: - response = client.get( - "/v1/files/file_nonexistent123/content", headers=test_headers - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - data = response.json() - assert "detail" in data - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_delete_file_success( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test successful file deletion.""" - # Create a mock file in the temporary workspace - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - files_dir = workspace_path / "files" - workspace_id = UUID(test_headers["askui-workspace"]) - static_dir = workspace_path / "workspaces" / str(workspace_id) / "static" - files_dir.mkdir(parents=True, exist_ok=True) - static_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock file - mock_file = File( - id="file_test123", - object="file", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - filename="test.txt", - size=32, - media_type="text/plain", - workspace_id=workspace_id, - ) - (files_dir / "file_test123.json").write_text(mock_file.model_dump_json()) - - # Create the actual file content - file_content = b"test file content" - (static_dir / "file_test123.txt").write_bytes(file_content) - - # Add file to database - self._add_file_to_db(mock_file, test_db_session) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - - def override_workspace_dir() -> Path: - return workspace_path - - file_service_override = FileService(test_db_session, workspace_path) - - def override_file_service() -> FileService: - return file_service_override - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - - try: - with TestClient(app) as client: - response = client.delete("/v1/files/file_test123", headers=test_headers) - - assert response.status_code == status.HTTP_204_NO_CONTENT - - # Verify static file is deleted (JSON files are no longer used) - assert not (static_dir / "file_test123.txt").exists() - - # Verify file is deleted from database - with pytest.raises(NotFoundError): - file_service_override.retrieve(workspace_id, FileId("file_test123")) - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_delete_file_not_found( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test file deletion when file doesn't exist.""" - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - def override_set_env_from_headers() -> None: - # No-op for testing - pass - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers - - try: - with TestClient(app) as client: - response = client.delete( - "/v1/files/file_nonexistent123", headers=test_headers - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - data = response.json() - assert "detail" in data - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_upload_different_file_types( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test uploading different file types.""" - # Test JSON file - json_content = b'{"key": "value"}' - json_files = { - "file": ("data.json", io.BytesIO(json_content), "application/json") - } - - response = test_client.post("/v1/files", files=json_files, headers=test_headers) - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["media_type"] == "application/json" - assert data["filename"] == "data.json" - - # Test PDF file - pdf_content = b"%PDF-1.4\ntest pdf content" - pdf_files = { - "file": ("document.pdf", io.BytesIO(pdf_content), "application/pdf") - } - - response = test_client.post("/v1/files", files=pdf_files, headers=test_headers) - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["media_type"] == "application/pdf" - assert data["filename"] == "document.pdf" - - def test_upload_file_without_content_type( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test file upload without content type.""" - file_content = b"test file content" - files = {"file": ("test.txt", io.BytesIO(file_content), None)} - - response = test_client.post("/v1/files", files=files, headers=test_headers) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - # FastAPI might infer content type from filename, so we just check it's not None - assert data["media_type"] is not None - assert data["media_type"] != "" - - def test_list_files_with_filtering( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test listing files with filtering parameters.""" - # Create multiple mock files in the temporary workspace - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - files_dir = workspace_path / "files" - files_dir.mkdir(parents=True, exist_ok=True) - - # Create multiple mock files with different timestamps - workspace_id = UUID(test_headers["askui-workspace"]) - for i in range(3): - mock_file = File( - id=f"file_test{i}", - object="file", - created_at=datetime.fromtimestamp(1234567890 + i, tz=timezone.utc), - filename=f"test{i}.txt", - size=32, - media_type="text/plain", - workspace_id=workspace_id, - ) - (files_dir / f"file_test{i}.json").write_text(mock_file.model_dump_json()) - # Add file to database - self._add_file_to_db(mock_file, test_db_session) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - - try: - with TestClient(app) as client: - # Test with after parameter - response = client.get( - "/v1/files?after=file_test0", headers=test_headers - ) - assert response.status_code == status.HTTP_200_OK - data = response.json() - # In descending lexicographic order, file_test0 is the last file, - # so there are no files "after" it - assert len(data["data"]) == 0 - - # Test with before parameter - response = client.get( - "/v1/files?before=file_test2", headers=test_headers - ) - assert response.status_code == status.HTTP_200_OK - data = response.json() - # In descending lexicographic order, file_test2 is the first file, - # so there are no files "before" it - assert len(data["data"]) == 0 - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() diff --git a/tests/integration/chat/api/test_files_edge_cases.py b/tests/integration/chat/api/test_files_edge_cases.py deleted file mode 100644 index dad236af..00000000 --- a/tests/integration/chat/api/test_files_edge_cases.py +++ /dev/null @@ -1,460 +0,0 @@ -"""Edge case and error scenario tests for the files API endpoints.""" - -import io -import tempfile -from pathlib import Path - -from fastapi import status -from fastapi.testclient import TestClient -from sqlalchemy.orm import Session - - -class TestFilesAPIEdgeCases: - """Test suite for edge cases and error scenarios in the files API.""" - - def test_upload_empty_file( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test uploading an empty file.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - from askui.chat.api.files.service import FileService - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - def override_set_env_from_headers() -> None: - # No-op for testing - pass - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers - - try: - with TestClient(app) as client: - empty_content = b"" - files = {"file": ("empty.txt", io.BytesIO(empty_content), "text/plain")} - - response = client.post("/v1/files", files=files, headers=test_headers) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["size"] == 0 - assert data["filename"] == "empty.txt" - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_upload_file_with_special_characters_in_filename( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test uploading a file with special characters in the filename.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - from askui.chat.api.files.service import FileService - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - def override_set_env_from_headers() -> None: - # No-op for testing - pass - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers - - try: - with TestClient(app) as client: - file_content = b"test content" - special_filename = "file with spaces & special chars!@#$%^&*().txt" - files = { - "file": (special_filename, io.BytesIO(file_content), "text/plain") - } - - response = client.post("/v1/files", files=files, headers=test_headers) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["filename"] == special_filename - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_upload_file_with_very_long_filename( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test uploading a file with a very long filename.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - from askui.chat.api.files.service import FileService - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - def override_set_env_from_headers() -> None: - # No-op for testing - pass - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers - - try: - with TestClient(app) as client: - file_content = b"test content" - long_filename = "a" * 255 + ".txt" # Very long filename - files = { - "file": (long_filename, io.BytesIO(file_content), "text/plain") - } - - response = client.post("/v1/files", files=files, headers=test_headers) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["filename"] == long_filename - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_upload_file_with_unknown_mime_type( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test uploading a file with an unknown MIME type.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - from askui.chat.api.files.service import FileService - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - def override_set_env_from_headers() -> None: - # No-op for testing - pass - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers - - try: - with TestClient(app) as client: - file_content = b"test content" - unknown_mime = "application/unknown-type" - files = {"file": ("test.xyz", io.BytesIO(file_content), unknown_mime)} - - response = client.post("/v1/files", files=files, headers=test_headers) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["media_type"] == unknown_mime - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_upload_file_with_binary_content( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test uploading a file with binary content.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - from askui.chat.api.files.service import FileService - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - def override_set_env_from_headers() -> None: - # No-op for testing - pass - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers - - try: - with TestClient(app) as client: - # Create binary content (PNG header) - binary_content = b"\x89PNG\r\n\x1a\n" + b"x" * 100 - files = {"file": ("test.png", io.BytesIO(binary_content), "image/png")} - - response = client.post("/v1/files", files=files, headers=test_headers) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["media_type"] == "image/png" - assert data["size"] == len(binary_content) - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_upload_file_without_workspace_header( - self, test_client: TestClient - ) -> None: - """Test uploading a file without workspace header.""" - file_content = b"test content" - files = {"file": ("test.txt", io.BytesIO(file_content), "text/plain")} - - response = test_client.post("/v1/files", files=files) - - # Should fail due to missing workspace header - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - - def test_upload_file_with_invalid_workspace_header( - self, test_client: TestClient - ) -> None: - """Test uploading a file with an invalid workspace header.""" - file_content = b"test content" - files = {"file": ("test.txt", io.BytesIO(file_content), "text/plain")} - invalid_headers = {"askui-workspace": "invalid-uuid"} - - response = test_client.post("/v1/files", files=files, headers=invalid_headers) - - # Should fail due to invalid workspace format - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - - def test_upload_file_with_malformed_file_data( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test uploading with malformed file data.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - from askui.chat.api.files.service import FileService - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - def override_set_env_from_headers() -> None: - # No-op for testing - pass - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers - - try: - with TestClient(app) as client: - # Send request without file data - response = client.post("/v1/files", headers=test_headers) - - # Should fail due to missing file - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_upload_file_with_corrupted_content( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test uploading a file with corrupted content.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - from askui.chat.api.files.service import FileService - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - def override_set_env_from_headers() -> None: - # No-op for testing - pass - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers - - try: - with TestClient(app) as client: - # Create a file-like object that raises an error when read - class CorruptedFile: - def read(self, size: int) -> bytes: # noqa: ARG002 - error_msg = "Simulated corruption" - raise IOError(error_msg) - - files = {"file": ("corrupted.txt", CorruptedFile(), "text/plain")} - - response = client.post("/v1/files", files=files, headers=test_headers) # type: ignore[arg-type] - - # Should fail due to corruption - FastAPI returns 400 for this case - assert response.status_code == status.HTTP_400_BAD_REQUEST - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_list_files_with_invalid_pagination( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test listing files with invalid pagination parameters.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - from askui.chat.api.files.service import FileService - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - def override_set_env_from_headers() -> None: - # No-op for testing - pass - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers - - try: - with TestClient(app) as client: - # Test with negative limit - response = client.get("/v1/files?limit=-1", headers=test_headers) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - - # Test with zero limit - response = client.get("/v1/files?limit=0", headers=test_headers) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - - # Test with very large limit - response = client.get("/v1/files?limit=10000", headers=test_headers) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_retrieve_file_with_invalid_id_format( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test retrieving a file with an invalid ID format.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - from askui.chat.api.files.service import FileService - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - def override_set_env_from_headers() -> None: - # No-op for testing - pass - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers - - try: - with TestClient(app) as client: - # Test with empty ID - FastAPI returns 200 for this (lists files) - response = client.get("/v1/files/", headers=test_headers) - assert response.status_code == status.HTTP_200_OK - - # Test with ID containing invalid characters - should fail validation - response = client.get("/v1/files/file@#$%", headers=test_headers) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() - - def test_delete_file_with_invalid_id_format( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test deleting a file with an invalid ID format.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir - from askui.chat.api.files.dependencies import get_file_service - from askui.chat.api.files.service import FileService - - def override_workspace_dir() -> Path: - return workspace_path - - def override_file_service() -> FileService: - return FileService(test_db_session, workspace_path) - - def override_set_env_from_headers() -> None: - # No-op for testing - pass - - app.dependency_overrides[get_workspace_dir] = override_workspace_dir - app.dependency_overrides[get_file_service] = override_file_service - app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers - - try: - with TestClient(app) as client: - # Test with empty ID - FastAPI returns 405 Method Not Allowed for this - response = client.delete("/v1/files/", headers=test_headers) - assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED - - # Test with ID containing invalid characters - should fail validation - response = client.delete("/v1/files/file@#$%", headers=test_headers) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - finally: - # Clean up dependency overrides - app.dependency_overrides.clear() diff --git a/tests/integration/chat/api/test_files_service.py b/tests/integration/chat/api/test_files_service.py deleted file mode 100644 index 00d29629..00000000 --- a/tests/integration/chat/api/test_files_service.py +++ /dev/null @@ -1,373 +0,0 @@ -"""Integration tests for the FileService class.""" - -import tempfile -from pathlib import Path -from unittest.mock import AsyncMock - -import pytest -from fastapi import UploadFile -from sqlalchemy.orm import Session - -from askui.chat.api.files.models import File, FileCreate -from askui.chat.api.files.service import FileService -from askui.chat.api.models import FileId -from askui.utils.api_utils import FileTooLargeError, NotFoundError - - -class TestFileService: - """Test suite for the FileService class.""" - - @pytest.fixture - def temp_workspace_dir(self) -> Path: - """Create a temporary workspace directory for testing.""" - temp_dir = tempfile.mkdtemp() - return Path(temp_dir) - - @pytest.fixture - def file_service( - self, test_db_session: Session, temp_workspace_dir: Path - ) -> FileService: - """Create a FileService instance with temporary workspace.""" - return FileService(test_db_session, temp_workspace_dir) - - @pytest.fixture - def sample_file_params(self) -> FileCreate: - """Create sample file creation parameters.""" - return FileCreate(filename="test.txt", size=32, media_type="text/plain") - - def test_get_static_file_path(self, file_service: FileService) -> None: - """Test getting static file path based on file extension.""" - from datetime import datetime, timezone - - file = File( - id="file_test123", - object="file", - created_at=datetime.now(timezone.utc), - filename="test.txt", - size=32, - media_type="text/plain", - workspace_id=None, - ) - - static_path = file_service._get_static_file_path(file) - expected_path = file_service._data_dir / "static" / "file_test123.txt" - assert static_path == expected_path - - def test_get_static_file_path_no_extension(self, file_service: FileService) -> None: - """Test getting static file path when MIME type has no extension.""" - from datetime import datetime, timezone - - file = File( - id="file_test123", - object="file", - created_at=datetime.now(timezone.utc), - filename="test", - size=32, - media_type="application/octet-stream", - workspace_id=None, - ) - - static_path = file_service._get_static_file_path(file) - expected_path = file_service._data_dir / "static" / "file_test123" - assert static_path == expected_path - - def test_list_files_empty(self, file_service: FileService) -> None: - """Test listing files when no files exist.""" - from askui.utils.api_utils import ListQuery - - query = ListQuery() - result = file_service.list_(None, query) - - assert result.object == "list" - assert result.data == [] - assert result.has_more is False - - def test_list_files_with_files( - self, file_service: FileService, sample_file_params: FileCreate - ) -> None: - """Test listing files when files exist.""" - from askui.utils.api_utils import ListQuery - - # Create a file first - temp_file = Path(tempfile.mktemp()) - file_content = b"test content" - temp_file.write_bytes(file_content) - - # Update the size to match the actual file content - params = FileCreate( - filename=sample_file_params.filename, - size=len(file_content), - media_type=sample_file_params.media_type, - ) - - try: - file = file_service.create(None, params, temp_file) - - query = ListQuery() - result = file_service.list_(None, query) - - assert result.object == "list" - assert len(result.data) == 1 - assert result.data[0].id == file.id - assert result.data[0].filename == file.filename - finally: - temp_file.unlink(missing_ok=True) - - def test_retrieve_file_success( - self, file_service: FileService, sample_file_params: FileCreate - ) -> None: - """Test successful file retrieval.""" - # Create a file first - temp_file = Path(tempfile.mktemp()) - file_content = b"test content" - temp_file.write_bytes(file_content) - - # Update the size to match the actual file content - params = FileCreate( - filename=sample_file_params.filename, - size=len(file_content), - media_type=sample_file_params.media_type, - ) - - try: - file = file_service.create(None, params, temp_file) - - retrieved_file = file_service.retrieve(None, file.id) - - assert retrieved_file.id == file.id - assert retrieved_file.filename == file.filename - assert retrieved_file.size == file.size - assert retrieved_file.media_type == file.media_type - finally: - temp_file.unlink(missing_ok=True) - - def test_retrieve_file_not_found(self, file_service: FileService) -> None: - """Test file retrieval when file doesn't exist.""" - file_id = FileId("file_nonexistent123") - - with pytest.raises(NotFoundError): - file_service.retrieve(None, file_id) - - def test_delete_file_success( - self, file_service: FileService, sample_file_params: FileCreate - ) -> None: - """Test successful file deletion.""" - from uuid import UUID - - # Create a workspace_id for the test file (non-default files can be deleted) - workspace_id = UUID("75592acb-9f48-4a10-8331-ea8faeed54a5") - - # Create a file first - temp_file = Path(tempfile.mktemp()) - file_content = b"test content" - temp_file.write_bytes(file_content) - - # Update the size to match the actual file content - params = FileCreate( - filename=sample_file_params.filename, - size=len(file_content), - media_type=sample_file_params.media_type, - ) - - try: - file = file_service.create(workspace_id, params, temp_file) - - # Verify file exists by retrieving it - retrieved_file = file_service.retrieve(workspace_id, file.id) - assert retrieved_file.id == file.id - - # Delete the file - file_service.delete(workspace_id, file.id) - - # Verify file is deleted by trying to retrieve it - # (should raise NotFoundError) - with pytest.raises(NotFoundError): - file_service.retrieve(workspace_id, file.id) - finally: - temp_file.unlink(missing_ok=True) - - def test_delete_file_not_found(self, file_service: FileService) -> None: - """Test file deletion when file doesn't exist.""" - file_id = FileId("file_nonexistent123") - - with pytest.raises(NotFoundError): - file_service.delete(None, file_id) - - def test_retrieve_file_content_success( - self, file_service: FileService, sample_file_params: FileCreate - ) -> None: - """Test successful file content retrieval.""" - # Create a file first - temp_file = Path(tempfile.mktemp()) - file_content = b"test content" - temp_file.write_bytes(file_content) - - # Update the size to match the actual file content - params = FileCreate( - filename=sample_file_params.filename, - size=len(file_content), - media_type=sample_file_params.media_type, - ) - - try: - file = file_service.create(None, params, temp_file) - - retrieved_file, file_path = file_service.retrieve_file_content( - None, file.id - ) - - assert retrieved_file.id == file.id - assert file_path.exists() - finally: - temp_file.unlink(missing_ok=True) - - def test_retrieve_file_content_not_found(self, file_service: FileService) -> None: - """Test file content retrieval when file doesn't exist.""" - file_id = FileId("file_nonexistent123") - - with pytest.raises(NotFoundError): - file_service.retrieve_file_content(None, file_id) - - def test_create_file_success( - self, file_service: FileService, sample_file_params: FileCreate - ) -> None: - """Test successful file creation.""" - temp_file = Path(tempfile.mktemp()) - file_content = b"test content" - temp_file.write_bytes(file_content) - - try: - # Update the size to match the actual file content - params = FileCreate( - filename=sample_file_params.filename, - size=len(file_content), - media_type=sample_file_params.media_type, - ) - - file = file_service.create(None, params, temp_file) - - assert file.id.startswith("file_") - assert file.filename == params.filename - assert file.size == params.size - assert file.media_type == params.media_type - # created_at is a datetime, compare with timezone-aware datetime - from datetime import datetime, timezone - - assert isinstance(file.created_at, datetime) - assert file.created_at > datetime(2020, 1, 1, tzinfo=timezone.utc) - - # Verify static file was moved - static_path = file_service._get_static_file_path(file) - assert static_path.exists() - - finally: - temp_file.unlink(missing_ok=True) - - def test_create_file_without_filename(self, file_service: FileService) -> None: - """Test file creation without filename.""" - temp_file = Path(tempfile.mktemp()) - file_content = b"test content" - temp_file.write_bytes(file_content) - - params = FileCreate( - filename=None, size=len(file_content), media_type="text/plain" - ) - - try: - file = file_service.create(None, params, temp_file) - - # Should auto-generate filename with extension - assert file.filename.endswith(".txt") - assert file.filename.startswith("file_") - - finally: - temp_file.unlink(missing_ok=True) - - @pytest.mark.asyncio - async def test_write_to_temp_file_success(self, file_service: FileService) -> None: - """Test successful writing to temporary file.""" - file_content = b"test file content" - mock_upload_file = AsyncMock(spec=UploadFile) - mock_upload_file.content_type = "text/plain" - mock_upload_file.filename = None - mock_upload_file.read.side_effect = [ - file_content, - b"", - ] # Read content, then empty - - params, temp_path = await file_service._write_to_temp_file(mock_upload_file) - - assert params.filename is None # No filename provided - assert params.size == len(file_content) - assert params.media_type == "text/plain" - assert temp_path.exists() - assert temp_path.read_bytes() == file_content - - # Cleanup - temp_path.unlink() - - @pytest.mark.asyncio - async def test_write_to_temp_file_large_size( - self, file_service: FileService - ) -> None: - """Test writing to temporary file with size exceeding limit.""" - # Create content larger than 20MB - large_content = b"x" * (21 * 1024 * 1024) - mock_upload_file = AsyncMock(spec=UploadFile) - mock_upload_file.content_type = "text/plain" - mock_upload_file.filename = "test.txt" - mock_upload_file.read.side_effect = [ - large_content, # Read all content at once - ] - - with pytest.raises(FileTooLargeError): - await file_service._write_to_temp_file(mock_upload_file) - - @pytest.mark.asyncio - async def test_write_to_temp_file_no_content_type( - self, file_service: FileService - ) -> None: - """Test writing to temporary file without content type.""" - file_content = b"test content" - mock_upload_file = AsyncMock(spec=UploadFile) - mock_upload_file.content_type = None - mock_upload_file.filename = "test.txt" - mock_upload_file.read.side_effect = [file_content, b""] - - params, temp_path = await file_service._write_to_temp_file(mock_upload_file) - - assert params.media_type == "application/octet-stream" # Default fallback - - # Cleanup - temp_path.unlink() - - @pytest.mark.asyncio - async def test_upload_file_success(self, file_service: FileService) -> None: - """Test successful file upload.""" - file_content = b"test file content" - mock_upload_file = AsyncMock(spec=UploadFile) - mock_upload_file.filename = "test.txt" - mock_upload_file.content_type = "text/plain" - mock_upload_file.read.side_effect = [file_content, b""] - - file = await file_service.upload_file(None, mock_upload_file) - - assert file.filename == "test.txt" - assert file.size == len(file_content) - assert file.media_type == "text/plain" - assert file.id.startswith("file_") - - # Verify static file was created - static_path = file_service._get_static_file_path(file) - assert static_path.exists() - - @pytest.mark.asyncio - async def test_upload_file_upload_failure(self, file_service: FileService) -> None: - """Test file upload when writing fails.""" - mock_upload_file = AsyncMock(spec=UploadFile) - mock_upload_file.filename = "test.txt" - mock_upload_file.content_type = "text/plain" - mock_upload_file.read.side_effect = Exception("Simulated upload failure") - - with pytest.raises(Exception, match="Simulated upload failure"): - await file_service.upload_file(None, mock_upload_file) diff --git a/tests/integration/chat/api/test_health.py b/tests/integration/chat/api/test_health.py deleted file mode 100644 index 74ad60d2..00000000 --- a/tests/integration/chat/api/test_health.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Integration tests for the health API endpoint.""" - -from fastapi import status -from fastapi.testclient import TestClient - - -class TestHealthAPI: - """Test suite for the health API endpoint.""" - - def test_health_check( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test the health check endpoint.""" - response = test_client.get("/v1/health", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["status"] == "OK" - - def test_health_check_without_headers(self, test_client: TestClient) -> None: - """Test the health check endpoint without workspace headers.""" - response = test_client.get("/v1/health") - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["status"] == "OK" - - def test_health_check_response_structure( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test that the health check response has the correct structure.""" - response = test_client.get("/v1/health", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - - # Check that only the expected fields are present - assert set(data.keys()) == {"status"} - assert isinstance(data["status"], str) - assert data["status"] == "OK" - - def test_health_check_content_type( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test that the health check response has the correct content type.""" - response = test_client.get("/v1/health", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - assert response.headers["content-type"] == "application/json" diff --git a/tests/integration/chat/api/test_mcp_configs.py b/tests/integration/chat/api/test_mcp_configs.py deleted file mode 100644 index cb32aa5d..00000000 --- a/tests/integration/chat/api/test_mcp_configs.py +++ /dev/null @@ -1,611 +0,0 @@ -"""Integration tests for the MCP configs API endpoints.""" - -from uuid import UUID - -from fastapi import status -from fastapi.testclient import TestClient -from fastmcp.mcp_config import StdioMCPServer -from sqlalchemy.orm import Session - -from askui.chat.api.mcp_configs.models import McpConfig -from askui.chat.api.mcp_configs.orms import McpConfigOrm -from askui.chat.api.mcp_configs.service import McpConfigService - - -class TestMcpConfigsAPI: - """Test suite for the MCP configs API endpoints.""" - - def test_list_mcp_configs_with_configs( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test listing MCP configs when configs exist.""" - from datetime import datetime, timezone - - # Create a mock MCP config in the database - workspace_id = UUID(test_headers["askui-workspace"]) - mock_config = McpConfig( - id="mcpcnf_test123", - object="mcp_config", - created_at=datetime.fromtimestamp(1234567890, timezone.utc), - name="Test MCP Config", - mcp_server=StdioMCPServer(type="stdio", command="test_command"), - workspace_id=workspace_id, - ) - mcp_config_orm = McpConfigOrm.from_model(mock_config) - test_db_session.add(mcp_config_orm) - test_db_session.commit() - - from askui.chat.api.app import app - from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service - - def override_mcp_config_service() -> McpConfigService: - return McpConfigService(test_db_session, seeds=[]) - - app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service - - try: - with TestClient(app) as client: - response = client.get("/v1/mcp-configs", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert len(data["data"]) == 1 - assert data["data"][0]["id"] == "mcpcnf_test123" - assert data["data"][0]["name"] == "Test MCP Config" - assert data["data"][0]["mcp_server"]["type"] == "stdio" - finally: - app.dependency_overrides.clear() - - def test_list_mcp_configs_with_pagination( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test listing MCP configs with pagination parameters.""" - from datetime import datetime, timezone - - # Create multiple mock MCP configs in the database - workspace_id = UUID(test_headers["askui-workspace"]) - for i in range(5): - mock_config = McpConfig( - id=f"mcpcnf_test{i}", - object="mcp_config", - created_at=datetime.fromtimestamp(1234567890 + i, timezone.utc), - name=f"Test MCP Config {i}", - mcp_server=StdioMCPServer(type="stdio", command=f"test_command_{i}"), - workspace_id=workspace_id, - ) - mcp_config_orm = McpConfigOrm.from_model(mock_config) - test_db_session.add(mcp_config_orm) - test_db_session.commit() - - from askui.chat.api.app import app - from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service - - def override_mcp_config_service() -> McpConfigService: - return McpConfigService(test_db_session, seeds=[]) - - app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service - - try: - with TestClient(app) as client: - response = client.get("/v1/mcp-configs?limit=3", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert len(data["data"]) == 3 - assert data["has_more"] is True - finally: - app.dependency_overrides.clear() - - def test_create_mcp_config( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test creating a new MCP config.""" - from askui.chat.api.app import app - from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service - - def override_mcp_config_service() -> McpConfigService: - return McpConfigService(test_db_session, seeds=[]) - - app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service - - try: - with TestClient(app) as client: - config_data = { - "name": "New MCP Config", - "mcp_server": {"type": "stdio", "command": "new_command"}, - } - response = client.post( - "/v1/mcp-configs", json=config_data, headers=test_headers - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["name"] == "New MCP Config" - assert data["mcp_server"]["type"] == "stdio" - assert data["mcp_server"]["command"] == "new_command" - finally: - app.dependency_overrides.clear() - - def test_create_mcp_config_minimal( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test creating an MCP config with minimal data.""" - from askui.chat.api.app import app - from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service - - def override_mcp_config_service() -> McpConfigService: - return McpConfigService(test_db_session, seeds=[]) - - app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service - - try: - with TestClient(app) as client: - response = client.post( - "/v1/mcp-configs", - json={ - "name": "Minimal Config", - "mcp_server": {"type": "stdio", "command": "minimal"}, - }, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["object"] == "mcp_config" - assert data["name"] == "Minimal Config" - assert data["mcp_server"]["type"] == "stdio" - assert data["mcp_server"]["command"] == "minimal" - finally: - app.dependency_overrides.clear() - - def test_retrieve_mcp_config( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test retrieving an existing MCP config.""" - from datetime import datetime, timezone - - mock_config = McpConfig( - id="mcpcnf_test123", - object="mcp_config", - created_at=datetime.fromtimestamp(1234567890, timezone.utc), - name="Test MCP Config", - mcp_server=StdioMCPServer(type="stdio", command="test_command"), - workspace_id=None, - ) - mcp_config_orm = McpConfigOrm.from_model(mock_config) - test_db_session.add(mcp_config_orm) - test_db_session.commit() - - from askui.chat.api.app import app - from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service - - def override_mcp_config_service() -> McpConfigService: - return McpConfigService(test_db_session, seeds=[]) - - app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service - - try: - with TestClient(app) as client: - response = client.get( - "/v1/mcp-configs/mcpcnf_test123", headers=test_headers - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["id"] == "mcpcnf_test123" - assert data["name"] == "Test MCP Config" - assert data["mcp_server"]["type"] == "stdio" - assert data["mcp_server"]["command"] == "test_command" - finally: - app.dependency_overrides.clear() - - def test_retrieve_mcp_config_not_found( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test retrieving a non-existent MCP config.""" - response = test_client.get( - "/v1/mcp-configs/mcpcnf_nonexistent123", headers=test_headers - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - data = response.json() - assert "detail" in data - - def test_modify_mcp_config( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test modifying an existing MCP config.""" - from datetime import datetime, timezone - - workspace_id = UUID(test_headers["askui-workspace"]) - mock_config = McpConfig( - id="mcpcnf_test123", - object="mcp_config", - created_at=datetime.fromtimestamp(1234567890, timezone.utc), - name="Original Name", - mcp_server=StdioMCPServer(type="stdio", command="original_command"), - workspace_id=workspace_id, - ) - mcp_config_orm = McpConfigOrm.from_model(mock_config) - test_db_session.add(mcp_config_orm) - test_db_session.commit() - - from askui.chat.api.app import app - from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service - - def override_mcp_config_service() -> McpConfigService: - return McpConfigService(test_db_session, seeds=[]) - - app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service - - try: - with TestClient(app) as client: - modify_data = { - "name": "Modified Name", - "mcp_server": {"type": "stdio", "command": "modified_command"}, - } - response = client.post( - "/v1/mcp-configs/mcpcnf_test123", - json=modify_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["name"] == "Modified Name" - assert data["mcp_server"]["type"] == "stdio" - assert data["mcp_server"]["command"] == "modified_command" - finally: - app.dependency_overrides.clear() - - def test_modify_mcp_config_partial( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test modifying an MCP config with partial data.""" - from datetime import datetime, timezone - - workspace_id = UUID(test_headers["askui-workspace"]) - mock_config = McpConfig( - id="mcpcnf_test123", - object="mcp_config", - created_at=datetime.fromtimestamp(1234567890, timezone.utc), - name="Original Name", - mcp_server=StdioMCPServer(type="stdio", command="original_command"), - workspace_id=workspace_id, - ) - mcp_config_orm = McpConfigOrm.from_model(mock_config) - test_db_session.add(mcp_config_orm) - test_db_session.commit() - - from askui.chat.api.app import app - from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service - - def override_mcp_config_service() -> McpConfigService: - return McpConfigService(test_db_session, seeds=[]) - - app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service - - try: - with TestClient(app) as client: - modify_data = {"name": "Only Name Modified"} - response = client.post( - "/v1/mcp-configs/mcpcnf_test123", - json=modify_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["name"] == "Only Name Modified" - - finally: - app.dependency_overrides.clear() - - def test_modify_mcp_config_not_found( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test modifying a non-existent MCP config.""" - modify_data = {"name": "Modified Name"} - response = test_client.post( - "/v1/mcp-configs/mcpcnf_nonexistent123", - json=modify_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - - def test_delete_mcp_config( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test deleting an existing MCP config.""" - from datetime import datetime, timezone - - workspace_id = UUID(test_headers["askui-workspace"]) - mock_config = McpConfig( - id="mcpcnf_test123", - object="mcp_config", - created_at=datetime.fromtimestamp(1234567890, timezone.utc), - name="Test MCP Config", - mcp_server=StdioMCPServer(type="stdio", command="test_command"), - workspace_id=workspace_id, - ) - mcp_config_orm = McpConfigOrm.from_model(mock_config) - test_db_session.add(mcp_config_orm) - test_db_session.commit() - - from askui.chat.api.app import app - from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service - - def override_mcp_config_service() -> McpConfigService: - return McpConfigService(test_db_session, seeds=[]) - - app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service - - try: - with TestClient(app) as client: - response = client.delete( - "/v1/mcp-configs/mcpcnf_test123", headers=test_headers - ) - - assert response.status_code == status.HTTP_204_NO_CONTENT - assert response.content == b"" - finally: - app.dependency_overrides.clear() - - def test_delete_mcp_config_not_found( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test deleting a non-existent MCP config.""" - response = test_client.delete( - "/v1/mcp-configs/mcpcnf_nonexistent123", headers=test_headers - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - - def test_modify_default_mcp_config_forbidden( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test that modifying a default MCP configuration returns 403 Forbidden.""" - from datetime import datetime, timezone - - # Create a default MCP config (no workspace_id) in the database - default_config = McpConfig( - id="mcpcnf_default123", - object="mcp_config", - created_at=datetime.fromtimestamp(1234567890, timezone.utc), - name="Default MCP Config", - mcp_server=StdioMCPServer(type="stdio", command="default_command"), - workspace_id=None, # No workspace_id = default - ) - mcp_config_orm = McpConfigOrm.from_model(default_config) - test_db_session.add(mcp_config_orm) - test_db_session.commit() - - from askui.chat.api.app import app - from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service - - def override_mcp_config_service() -> McpConfigService: - return McpConfigService(test_db_session, seeds=[]) - - app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service - - try: - with TestClient(app) as client: - # Try to modify the default MCP config - response = client.post( - "/v1/mcp-configs/mcpcnf_default123", - headers=test_headers, - json={"name": "Modified Name"}, - ) - assert response.status_code == 403 - assert "cannot be modified" in response.json()["detail"] - finally: - app.dependency_overrides.clear() - - def test_delete_default_mcp_config_forbidden( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test that deleting a default MCP configuration returns 403 Forbidden.""" - from datetime import datetime, timezone - - # Create a default MCP config (no workspace_id) in the database - default_config = McpConfig( - id="mcpcnf_default456", - object="mcp_config", - created_at=datetime.fromtimestamp(1234567890, timezone.utc), - name="Default MCP Config", - mcp_server=StdioMCPServer(type="stdio", command="default_command"), - workspace_id=None, # No workspace_id = default - ) - mcp_config_orm = McpConfigOrm.from_model(default_config) - test_db_session.add(mcp_config_orm) - test_db_session.commit() - - from askui.chat.api.app import app - from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service - - def override_mcp_config_service() -> McpConfigService: - return McpConfigService(test_db_session, seeds=[]) - - app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service - - try: - with TestClient(app) as client: - # Try to delete the default MCP config - response = client.delete( - "/v1/mcp-configs/mcpcnf_default456", - headers=test_headers, - ) - assert response.status_code == 403 - assert "cannot be deleted" in response.json()["detail"] - finally: - app.dependency_overrides.clear() - - def test_list_mcp_configs_includes_default_and_workspace( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test that listing MCP configs includes both default and workspace-scoped - ones.""" - from datetime import datetime, timezone - - # Create a default MCP config (no workspace_id) in the database - default_config = McpConfig( - id="mcpcnf_default789", - object="mcp_config", - created_at=datetime.fromtimestamp(1234567890, timezone.utc), - name="Default MCP Config", - mcp_server=StdioMCPServer(type="stdio", command="default_command"), - workspace_id=None, # No workspace_id = default - ) - mcp_config_orm = McpConfigOrm.from_model(default_config) - test_db_session.add(mcp_config_orm) - - # Create a workspace-scoped MCP config - workspace_id = UUID(test_headers["askui-workspace"]) - workspace_config = McpConfig( - id="mcpcnf_workspace123", - object="mcp_config", - created_at=datetime.fromtimestamp(1234567890, timezone.utc), - name="Workspace MCP Config", - mcp_server=StdioMCPServer(type="stdio", command="workspace_command"), - workspace_id=workspace_id, - ) - workspace_config_orm = McpConfigOrm.from_model(workspace_config) - test_db_session.add(workspace_config_orm) - test_db_session.commit() - - from askui.chat.api.app import app - from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service - - def override_mcp_config_service() -> McpConfigService: - return McpConfigService(test_db_session, seeds=[]) - - app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service - - try: - with TestClient(app) as client: - # List MCP configs - should include both - response = client.get("/v1/mcp-configs", headers=test_headers) - assert response.status_code == 200 - - data = response.json() - config_ids = [config["id"] for config in data["data"]] - - # Should include both default and workspace configs - assert "mcpcnf_default789" in config_ids - assert "mcpcnf_workspace123" in config_ids - - # Verify workspace_id fields - default_config_data = next( - c for c in data["data"] if c["id"] == "mcpcnf_default789" - ) - workspace_config_data = next( - c for c in data["data"] if c["id"] == "mcpcnf_workspace123" - ) - - # Default config should not have workspace_id field (excluded when None) - assert "workspace_id" not in default_config_data - assert workspace_config_data["workspace_id"] == str(workspace_id) - finally: - app.dependency_overrides.clear() - - def test_retrieve_default_mcp_config_success( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test that retrieving a default MCP configuration works.""" - from datetime import datetime, timezone - - # Create a default MCP config (no workspace_id) in the database - default_config = McpConfig( - id="mcpcnf_defaultretrieve", - object="mcp_config", - created_at=datetime.fromtimestamp(1234567890, timezone.utc), - name="Default MCP Config", - mcp_server=StdioMCPServer(type="stdio", command="default_command"), - workspace_id=None, # No workspace_id = default - ) - mcp_config_orm = McpConfigOrm.from_model(default_config) - test_db_session.add(mcp_config_orm) - test_db_session.commit() - - from askui.chat.api.app import app - from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service - - def override_mcp_config_service() -> McpConfigService: - return McpConfigService(test_db_session, seeds=[]) - - app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service - - try: - with TestClient(app) as client: - # Retrieve the default MCP config - response = client.get( - "/v1/mcp-configs/mcpcnf_defaultretrieve", - headers=test_headers, - ) - assert response.status_code == 200 - - data = response.json() - assert data["id"] == "mcpcnf_defaultretrieve" - # Default config should not have workspace_id field (excluded when None) - assert "workspace_id" not in data - finally: - app.dependency_overrides.clear() - - def test_workspace_scoped_mcp_config_operations_success( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test that workspace-scoped MCP configs can be modified and deleted.""" - from datetime import datetime, timezone - - workspace_id = UUID(test_headers["askui-workspace"]) - workspace_config = McpConfig( - id="mcpcnf_workspaceops", - object="mcp_config", - created_at=datetime.fromtimestamp(1234567890, timezone.utc), - name="Workspace MCP Config", - mcp_server=StdioMCPServer(type="stdio", command="workspace_command"), - workspace_id=workspace_id, - ) - mcp_config_orm = McpConfigOrm.from_model(workspace_config) - test_db_session.add(mcp_config_orm) - test_db_session.commit() - - from askui.chat.api.app import app - from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service - - def override_mcp_config_service() -> McpConfigService: - return McpConfigService(test_db_session, seeds=[]) - - app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service - - try: - with TestClient(app) as client: - # Modify the workspace MCP config - response = client.post( - "/v1/mcp-configs/mcpcnf_workspaceops", - headers=test_headers, - json={"name": "Modified Workspace MCP Config"}, - ) - assert response.status_code == 200 - - data = response.json() - assert data["name"] == "Modified Workspace MCP Config" - assert data["workspace_id"] == str(workspace_id) - - # Delete the workspace MCP config - response = client.delete( - "/v1/mcp-configs/mcpcnf_workspaceops", - headers=test_headers, - ) - assert response.status_code == 204 - - # Verify it's deleted - response = client.get( - "/v1/mcp-configs/mcpcnf_workspaceops", - headers=test_headers, - ) - assert response.status_code == 404 - finally: - app.dependency_overrides.clear() diff --git a/tests/integration/chat/api/test_message_service.py b/tests/integration/chat/api/test_message_service.py deleted file mode 100644 index 0075460c..00000000 --- a/tests/integration/chat/api/test_message_service.py +++ /dev/null @@ -1,464 +0,0 @@ -"""Unit tests for the MessageService.""" - -from datetime import datetime, timezone -from uuid import UUID, uuid4 - -import pytest -from sqlalchemy.orm import Session - -from askui.chat.api.messages.models import ( - ROOT_MESSAGE_PARENT_ID, - Message, - MessageCreate, -) -from askui.chat.api.messages.service import MessageService -from askui.chat.api.threads.models import Thread -from askui.chat.api.threads.orms import ThreadOrm -from askui.utils.api_utils import ListQuery - - -class TestMessageServicePagination: - """Test pagination behavior with different order and after/before parameters.""" - - @pytest.fixture - def _workspace_id(self) -> UUID: - """Create a test workspace ID.""" - return uuid4() - - @pytest.fixture - def _thread_id(self, test_db_session: Session, _workspace_id: UUID) -> str: - """Create a test thread.""" - _thread = Thread( - id="thread_testpagination", - object="thread", - created_at=datetime.now(timezone.utc), - name="Test Thread for Pagination", - workspace_id=_workspace_id, - ) - _thread_orm = ThreadOrm.from_model(_thread) - test_db_session.add(_thread_orm) - test_db_session.commit() - return _thread.id - - @pytest.fixture - def _message_service(self, test_db_session: Session) -> MessageService: - """Create a MessageService instance.""" - return MessageService(test_db_session) - - @pytest.fixture - def _messages( - self, - _message_service: MessageService, - _workspace_id: UUID, - _thread_id: str, - ) -> list[Message]: - """Create two branches of messages for testing. - - Branch 1: Messages 0-9 (linear chain from ROOT) - Branch 2: Messages 10-19 (separate linear chain from ROOT) - """ - _created_messages: list[Message] = [] - - # Create first branch: messages 0-9 (linear chain) - for i in range(10): - _msg = _message_service.create( - workspace_id=_workspace_id, - thread_id=_thread_id, - params=MessageCreate( - role="user" if i % 2 == 0 else "assistant", - content=f"Test message {i}", - parent_id=( - ROOT_MESSAGE_PARENT_ID - if i == 0 - else _created_messages[i - 1].id - ), - ), - ) - _created_messages.append(_msg) - - # Create second branch: messages 10-19 (separate linear chain from ROOT) - for i in range(10, 20): - _msg = _message_service.create( - workspace_id=_workspace_id, - thread_id=_thread_id, - params=MessageCreate( - role="user" if i % 2 == 0 else "assistant", - content=f"Test message {i}", - parent_id=( - ROOT_MESSAGE_PARENT_ID - if i == 10 - else _created_messages[i - 1].id - ), - ), - ) - _created_messages.append(_msg) - - return _created_messages - - def test_list_asc_without_after( - self, - _message_service: MessageService, - _workspace_id: UUID, - _thread_id: str, - _messages: list[Message], - ) -> None: - """Test listing messages in ascending order without 'after' parameter.""" - # Without before/after, gets latest branch (branch 2) - _response = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=5, order="asc"), - ) - - assert len(_response.data) == 5 - # Should get the first 5 messages from branch 2 (10, 11, 12, 13, 14) - assert [_msg.content for _msg in _response.data] == [ - "Test message 10", - "Test message 11", - "Test message 12", - "Test message 13", - "Test message 14", - ] - assert _response.has_more is True - - def test_list_asc_with_after( - self, - _message_service: MessageService, - _workspace_id: UUID, - _thread_id: str, - _messages: list[Message], - ) -> None: - """Test listing messages in ascending order with 'after' parameter.""" - # First, get the first page from branch 2 (default) - _first_page = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=3, order="asc"), - ) - - assert len(_first_page.data) == 3 - assert [_msg.content for _msg in _first_page.data] == [ - "Test message 10", - "Test message 11", - "Test message 12", - ] - - # Now get the second page using 'after' - _second_page = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=3, order="asc", after=_first_page.last_id), - ) - - assert len(_second_page.data) == 3 - # Should get the next 3 messages (13, 14, 15) - assert [_msg.content for _msg in _second_page.data] == [ - "Test message 13", - "Test message 14", - "Test message 15", - ] - assert _second_page.has_more is True - - # Get the third page - _third_page = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=3, order="asc", after=_second_page.last_id), - ) - - assert len(_third_page.data) == 3 - # Should get the next 3 messages (16, 17, 18) - assert [_msg.content for _msg in _third_page.data] == [ - "Test message 16", - "Test message 17", - "Test message 18", - ] - - def test_list_desc_without_after( - self, - _message_service: MessageService, - _workspace_id: UUID, - _thread_id: str, - _messages: list[Message], - ) -> None: - """Test listing messages in descending order without 'after' parameter.""" - # Without before/after, gets latest branch (branch 2) - _response = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=5, order="desc"), - ) - - assert len(_response.data) == 5 - # Should get the last 5 messages from branch 2 (19, 18, 17, 16, 15) - assert [_msg.content for _msg in _response.data] == [ - "Test message 19", - "Test message 18", - "Test message 17", - "Test message 16", - "Test message 15", - ] - assert _response.has_more is True - - def test_list_desc_with_after( - self, - _message_service: MessageService, - _workspace_id: UUID, - _thread_id: str, - _messages: list[Message], - ) -> None: - """Test listing messages in descending order with 'after' parameter.""" - # First, get the first page from branch 2 (default) - _first_page = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=3, order="desc"), - ) - - assert len(_first_page.data) == 3 - assert [_msg.content for _msg in _first_page.data] == [ - "Test message 19", - "Test message 18", - "Test message 17", - ] - - # Now get the second page using 'after' - _second_page = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=3, order="desc", after=_first_page.last_id), - ) - - assert len(_second_page.data) == 3 - # Should get the previous 3 messages (16, 15, 14) - assert [_msg.content for _msg in _second_page.data] == [ - "Test message 16", - "Test message 15", - "Test message 14", - ] - - def test_iter_asc( - self, - _message_service: MessageService, - _workspace_id: UUID, - _thread_id: str, - _messages: list[Message], - ) -> None: - """Test iterating through messages in ascending order.""" - # Without before/after, iter returns the latest branch (branch 2) - _collected_messages: list[Message] = list( - _message_service.iter( - workspace_id=_workspace_id, - thread_id=_thread_id, - order="asc", - batch_size=3, - ) - ) - - # Should get all 10 messages from branch 2 in ascending order - assert len(_collected_messages) == 10 - assert [_msg.content for _msg in _collected_messages] == [ - f"Test message {i}" for i in range(10, 20) - ] - - def test_iter_desc( - self, - _message_service: MessageService, - _workspace_id: UUID, - _thread_id: str, - _messages: list[Message], - ) -> None: - """Test iterating through messages in descending order.""" - # Without before/after, iter returns the latest branch (branch 2) - _collected_messages: list[Message] = list( - _message_service.iter( - workspace_id=_workspace_id, - thread_id=_thread_id, - order="desc", - batch_size=3, - ) - ) - - # Should get all 10 messages from branch 2 in descending order - assert len(_collected_messages) == 10 - assert [_msg.content for _msg in _collected_messages] == [ - f"Test message {i}" for i in range(19, 9, -1) - ] - - def test_list_asc_with_before( - self, - _message_service: MessageService, - _workspace_id: UUID, - _thread_id: str, - _messages: list[Message], - ) -> None: - """Test listing messages in ascending order with 'before' parameter.""" - # Get messages before message 7 in ascending order - # Should get messages from root up to (but excluding) message 7 - _response = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=10, order="asc", before=_messages[7].id), - ) - - # Should get messages 0-6 in ascending order - assert len(_response.data) == 7 - assert [_msg.content for _msg in _response.data] == [ - f"Test message {i}" for i in range(7) - ] - assert _response.has_more is False - - def test_list_desc_with_before( - self, - _message_service: MessageService, - _workspace_id: UUID, - _thread_id: str, - _messages: list[Message], - ) -> None: - """Test listing messages in descending order with 'before' parameter.""" - # Get messages before (i.e., after in the tree) message 3 in descending - # order. Should get messages from message 3 down to the latest leaf - # (excluding message 3) - _response = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=10, order="desc", before=_messages[3].id), - ) - - # Should get messages 9-4 in descending order (excluding message 3) - assert len(_response.data) == 6 - assert [_msg.content for _msg in _response.data] == [ - f"Test message {i}" for i in range(9, 3, -1) - ] - assert _response.has_more is False - - def test_list_asc_with_before_paginated( - self, - _message_service: MessageService, - _workspace_id: UUID, - _thread_id: str, - _messages: list[Message], - ) -> None: - """Test listing messages in ascending order with 'before' and pagination.""" - # Get 3 messages before message 7 in ascending order - _response = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=3, order="asc", before=_messages[7].id), - ) - - # Should get messages 0-2 in ascending order - assert len(_response.data) == 3 - assert [_msg.content for _msg in _response.data] == [ - "Test message 0", - "Test message 1", - "Test message 2", - ] - assert _response.has_more is True - - def test_list_desc_with_before_paginated( - self, - _message_service: MessageService, - _workspace_id: UUID, - _thread_id: str, - _messages: list[Message], - ) -> None: - """Test listing messages in descending order with 'before' and pagination.""" - # Get 3 messages before (after in tree) message 3 in descending order - _response = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=3, order="desc", before=_messages[3].id), - ) - - # Should get messages 9-7 in descending order - assert len(_response.data) == 3 - assert [_msg.content for _msg in _response.data] == [ - "Test message 9", - "Test message 8", - "Test message 7", - ] - assert _response.has_more is True - - def test_list_branch1_with_after( - self, - _message_service: MessageService, - _workspace_id: UUID, - _thread_id: str, - _messages: list[Message], - ) -> None: - """Test querying branch 1 by starting from its first message.""" - # Query from the first message of branch 1 downward - _response = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=20, order="asc", after=_messages[0].id), - ) - - # Should get messages 1-9 from branch 1 (excluding message 0) - assert len(_response.data) == 9 - assert [_msg.content for _msg in _response.data] == [ - f"Test message {i}" for i in range(1, 10) - ] - assert _response.has_more is False - - def test_list_branch2_with_after( - self, - _message_service: MessageService, - _workspace_id: UUID, - _thread_id: str, - _messages: list[Message], - ) -> None: - """Test querying branch 2 by starting from its first message.""" - # Query from the first message of branch 2 downward - _response = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=20, order="asc", after=_messages[10].id), - ) - - # Should get messages 11-19 from branch 2 (excluding message 10) - assert len(_response.data) == 9 - assert [_msg.content for _msg in _response.data] == [ - f"Test message {i}" for i in range(11, 20) - ] - assert _response.has_more is False - - def test_list_branches_separately( - self, - _message_service: MessageService, - _workspace_id: UUID, - _thread_id: str, - _messages: list[Message], - ) -> None: - """Test that the two branches are separate by querying from each.""" - # Get branch 1: query from branch 1's last message going up - _branch1_response = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=20, order="desc", after=_messages[9].id), - ) - - # Should get messages 9-0 from branch 1 in descending order - assert len(_branch1_response.data) == 9 - assert [_msg.content for _msg in _branch1_response.data] == [ - f"Test message {i}" for i in range(8, -1, -1) - ] - - # Get branch 2: query from branch 2's last message going up - _branch2_response = _message_service.list_( - workspace_id=_workspace_id, - thread_id=_thread_id, - query=ListQuery(limit=20, order="desc", after=_messages[19].id), - ) - - # Should get messages 19-10 from branch 2 in descending order - assert len(_branch2_response.data) == 9 - assert [_msg.content for _msg in _branch2_response.data] == [ - f"Test message {i}" for i in range(18, 9, -1) - ] - - # Verify no overlap between branches - _branch1_ids = {_msg.id for _msg in _branch1_response.data} - _branch2_ids = {_msg.id for _msg in _branch2_response.data} - assert _branch1_ids.isdisjoint(_branch2_ids) diff --git a/tests/integration/chat/api/test_messages.py b/tests/integration/chat/api/test_messages.py deleted file mode 100644 index 1ae347f8..00000000 --- a/tests/integration/chat/api/test_messages.py +++ /dev/null @@ -1,516 +0,0 @@ -"""Integration tests for the messages API endpoints.""" - -import tempfile -from datetime import datetime, timezone -from pathlib import Path -from uuid import UUID - -from fastapi import status -from fastapi.testclient import TestClient -from sqlalchemy.orm import Session - -from askui.chat.api.messages.models import ROOT_MESSAGE_PARENT_ID, Message -from askui.chat.api.messages.orms import MessageOrm -from askui.chat.api.messages.service import MessageService -from askui.chat.api.threads.models import Thread -from askui.chat.api.threads.orms import ThreadOrm -from askui.chat.api.threads.service import ThreadService - - -class TestMessagesAPI: - """Test suite for the messages API endpoints.""" - - def _add_thread_to_db(self, thread: Thread, test_db_session: Session) -> None: - """Add a thread to the test database.""" - thread_orm = ThreadOrm.from_model(thread) - test_db_session.add(thread_orm) - test_db_session.commit() - - def _add_message_to_db(self, message: Message, test_db_session: Session) -> None: - """Add a message to the test database.""" - message_orm = MessageOrm.from_model(message) - test_db_session.add(message_orm) - test_db_session.commit() - - def test_list_messages_empty( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test listing messages when no messages exist.""" - # First create a thread - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - - workspace_id = UUID(test_headers["askui-workspace"]) - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - name="Test Thread", - workspace_id=workspace_id, - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - - # Add thread to database - self._add_thread_to_db(mock_thread, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.messages.dependencies import get_message_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - def override_message_service() -> MessageService: - return MessageService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_message_service] = override_message_service - - try: - with TestClient(app) as client: - response = client.get( - "/v1/threads/thread_test123/messages", headers=test_headers - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert data["data"] == [] - assert data["has_more"] is False - finally: - app.dependency_overrides.clear() - - def test_list_messages_with_messages( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test listing messages when messages exist.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - messages_dir = workspace_path / "messages" / "thread_test123" - messages_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock thread - workspace_id = UUID(test_headers["askui-workspace"]) - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - name="Test Thread", - workspace_id=workspace_id, - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - - # Add thread to database - self._add_thread_to_db(mock_thread, test_db_session) - - # Create a mock message - workspace_id = UUID(test_headers["askui-workspace"]) - mock_message = Message( - id="msg_test123", - parent_id=ROOT_MESSAGE_PARENT_ID, - object="thread.message", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - thread_id="thread_test123", - role="user", - content="Hello, this is a test message", - metadata={"key": "value"}, - workspace_id=workspace_id, - ) - (messages_dir / "msg_test123.json").write_text(mock_message.model_dump_json()) - - # Add message to database - self._add_message_to_db(mock_message, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.messages.dependencies import get_message_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - def override_message_service() -> MessageService: - return MessageService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_message_service] = override_message_service - - try: - with TestClient(app) as client: - response = client.get( - "/v1/threads/thread_test123/messages", headers=test_headers - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert len(data["data"]) == 1 - assert data["data"][0]["id"] == "msg_test123" - assert data["data"][0]["content"] == "Hello, this is a test message" - assert data["data"][0]["role"] == "user" - finally: - app.dependency_overrides.clear() - - def test_list_messages_with_pagination( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test listing messages with pagination parameters.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - messages_dir = workspace_path / "messages" / "thread_test123" - messages_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock thread - workspace_id = UUID(test_headers["askui-workspace"]) - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - name="Test Thread", - workspace_id=workspace_id, - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - - # Add thread to database - self._add_thread_to_db(mock_thread, test_db_session) - - # Create multiple mock messages - workspace_id = UUID(test_headers["askui-workspace"]) - for i in range(5): - mock_message = Message( - id=f"msg_test{i}", - object="thread.message", - parent_id=ROOT_MESSAGE_PARENT_ID if i == 0 else f"msg_test{i - 1}", - created_at=datetime.fromtimestamp(1234567890 + i, tz=timezone.utc), - thread_id="thread_test123", - role="user" if i % 2 == 0 else "assistant", - content=f"Test message {i}", - workspace_id=workspace_id, - ) - (messages_dir / f"msg_test{i}.json").write_text( - mock_message.model_dump_json() - ) - # Add message to database - self._add_message_to_db(mock_message, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.messages.dependencies import get_message_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - def override_message_service() -> MessageService: - return MessageService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_message_service] = override_message_service - - try: - with TestClient(app) as client: - response = client.get( - "/v1/threads/thread_test123/messages?limit=3", headers=test_headers - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert len(data["data"]) == 3 - assert data["has_more"] is True - finally: - app.dependency_overrides.clear() - - def test_create_message( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test creating a new message.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock thread - workspace_id = UUID(test_headers["askui-workspace"]) - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - name="Test Thread", - workspace_id=workspace_id, - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - - # Add thread to database - self._add_thread_to_db(mock_thread, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.messages.dependencies import get_message_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - def override_message_service() -> MessageService: - return MessageService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_message_service] = override_message_service - - try: - with TestClient(app) as client: - message_data = { - "role": "user", - "content": "Hello, this is a new message", - "metadata": {"key": "value", "number": 42}, - } - response = client.post( - "/v1/threads/thread_test123/messages", - json=message_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["role"] == "user" - assert data["content"] == "Hello, this is a new message" - - assert data["object"] == "thread.message" - assert data["thread_id"] == "thread_test123" - assert "id" in data - assert "created_at" in data - finally: - app.dependency_overrides.clear() - - def test_create_message_minimal( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test creating a message with minimal data.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock thread - workspace_id = UUID(test_headers["askui-workspace"]) - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - name="Test Thread", - workspace_id=workspace_id, - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - - # Add thread to database - self._add_thread_to_db(mock_thread, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.messages.dependencies import get_message_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - def override_message_service() -> MessageService: - return MessageService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_message_service] = override_message_service - - try: - with TestClient(app) as client: - message_data = {"role": "user", "content": "Minimal message"} - response = client.post( - "/v1/threads/thread_test123/messages", - json=message_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["object"] == "thread.message" - assert data["role"] == "user" - assert data["content"] == "Minimal message" - - finally: - app.dependency_overrides.clear() - - def test_create_message_invalid_thread( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test creating a message in a non-existent thread.""" - message_data = {"role": "user", "content": "Test message"} - response = test_client.post( - "/v1/threads/thread_nonexistent123/messages", - json=message_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - - def test_retrieve_message( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test retrieving an existing message.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - messages_dir = workspace_path / "messages" / "thread_test123" - messages_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock thread - workspace_id = UUID(test_headers["askui-workspace"]) - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - name="Test Thread", - workspace_id=workspace_id, - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - - # Add thread to database - self._add_thread_to_db(mock_thread, test_db_session) - - # Create a mock message - workspace_id = UUID(test_headers["askui-workspace"]) - mock_message = Message( - id="msg_test123", - object="thread.message", - parent_id=ROOT_MESSAGE_PARENT_ID, - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - thread_id="thread_test123", - role="user", - content="Test message content", - metadata={"key": "value"}, - workspace_id=workspace_id, - ) - (messages_dir / "msg_test123.json").write_text(mock_message.model_dump_json()) - - # Add message to database - self._add_message_to_db(mock_message, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.messages.dependencies import get_message_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - def override_message_service() -> MessageService: - return MessageService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_message_service] = override_message_service - - try: - with TestClient(app) as client: - response = client.get( - "/v1/threads/thread_test123/messages/msg_test123", - headers=test_headers, - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["id"] == "msg_test123" - assert data["content"] == "Test message content" - assert data["role"] == "user" - assert data["thread_id"] == "thread_test123" - finally: - app.dependency_overrides.clear() - - def test_retrieve_message_not_found( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test retrieving a non-existent message.""" - response = test_client.get( - "/v1/threads/thread_test123/messages/msg_nonexistent123", - headers=test_headers, - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - data = response.json() - assert "detail" in data - - def test_delete_message( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test deleting an existing message.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - messages_dir = workspace_path / "messages" / "thread_test123" - messages_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock thread - workspace_id = UUID(test_headers["askui-workspace"]) - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - name="Test Thread", - workspace_id=workspace_id, - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - - # Add thread to database - self._add_thread_to_db(mock_thread, test_db_session) - - # Create a mock message - workspace_id = UUID(test_headers["askui-workspace"]) - mock_message = Message( - id="msg_test123", - object="thread.message", - parent_id=ROOT_MESSAGE_PARENT_ID, - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - thread_id="thread_test123", - role="user", - content="Test message to delete", - workspace_id=workspace_id, - ) - (messages_dir / "msg_test123.json").write_text(mock_message.model_dump_json()) - - # Add message to database - self._add_message_to_db(mock_message, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.messages.dependencies import get_message_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - def override_message_service() -> MessageService: - return MessageService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_message_service] = override_message_service - - try: - with TestClient(app) as client: - response = client.delete( - "/v1/threads/thread_test123/messages/msg_test123", - headers=test_headers, - ) - - assert response.status_code == status.HTTP_204_NO_CONTENT - assert response.content == b"" - finally: - app.dependency_overrides.clear() - - def test_delete_message_not_found( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test deleting a non-existent message.""" - response = test_client.delete( - "/v1/threads/thread_test123/messages/msg_nonexistent123", - headers=test_headers, - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/integration/chat/api/test_request_document_translator.py b/tests/integration/chat/api/test_request_document_translator.py deleted file mode 100644 index 14e4e42c..00000000 --- a/tests/integration/chat/api/test_request_document_translator.py +++ /dev/null @@ -1,276 +0,0 @@ -"""Integration tests for RequestDocumentBlockParamTranslator.""" - -import pathlib -import shutil -import tempfile -from typing import Generator - -import pytest -from PIL import Image -from sqlalchemy.orm import Session - -from askui.chat.api.files.service import FileService -from askui.chat.api.messages.models import RequestDocumentBlockParam -from askui.chat.api.messages.translator import RequestDocumentBlockParamTranslator -from askui.models.shared.agent_message_param import CacheControlEphemeralParam -from askui.utils.excel_utils import OfficeDocumentSource -from askui.utils.image_utils import ImageSource - - -class TestRequestDocumentBlockParamTranslator: - """Integration tests for RequestDocumentBlockParamTranslator with real files.""" - - @pytest.fixture - def temp_dir(self) -> Generator[pathlib.Path, None, None]: - """Create a temporary directory for test files.""" - temp_dir = pathlib.Path(tempfile.mkdtemp()) - yield temp_dir - # Cleanup: remove the temporary directory and all its contents - shutil.rmtree(temp_dir, ignore_errors=True) - - @pytest.fixture - def file_service( - self, test_db_session: Session, temp_dir: pathlib.Path - ) -> FileService: - """Create a FileService instance using the temporary directory.""" - return FileService(test_db_session, temp_dir) - - @pytest.fixture - def translator( - self, file_service: FileService - ) -> RequestDocumentBlockParamTranslator: - """Create a RequestDocumentBlockParamTranslator instance.""" - return RequestDocumentBlockParamTranslator(file_service, None) - - @pytest.fixture - def cache_control(self) -> CacheControlEphemeralParam: - """Sample cache control parameter.""" - return CacheControlEphemeralParam(type="ephemeral") - - def test_extract_content_from_image( - self, - translator: RequestDocumentBlockParamTranslator, - path_fixtures_github_com__icon: pathlib.Path, - temp_dir: pathlib.Path, - cache_control: CacheControlEphemeralParam, - ) -> None: - """Test extracting content from an image file.""" - # Copy the fixture image to the temporary directory - temp_image_path = temp_dir / "test_icon.png" - shutil.copy2(path_fixtures_github_com__icon, temp_image_path) - - # Create a document block with cache control - document_block = RequestDocumentBlockParam( - source={"file_id": "image123", "type": "file"}, - type="document", - cache_control=cache_control, - ) - - # Load the image source using PIL Image from the temporary file - pil_image = Image.open(temp_image_path) - image_source = ImageSource(pil_image) - - # Extract content - result = translator.extract_content(image_source, document_block) - - # Should return a list with one image block - assert isinstance(result, list) - assert len(result) == 1 - - # First element should be an image block - image_block = result[0] - assert image_block.type == "image" - assert image_block.cache_control == cache_control - - # Check the source is base64 encoded - assert image_block.source.type == "base64" - assert image_block.source.media_type == "image/png" - assert isinstance(image_block.source.data, str) - assert len(image_block.source.data) > 0 - - def test_extract_content_from_excel( - self, - translator: RequestDocumentBlockParamTranslator, - path_fixtures_dummy_excel: pathlib.Path, - temp_dir: pathlib.Path, - cache_control: CacheControlEphemeralParam, - ) -> None: - """Test extracting content from an Excel file.""" - # Copy the fixture Excel file to the temporary directory - temp_excel_path = temp_dir / "test_data.xlsx" - shutil.copy2(path_fixtures_dummy_excel, temp_excel_path) - - # Create a document block with cache control - document_block = RequestDocumentBlockParam( - source={"file_id": "excel123", "type": "file"}, - type="document", - cache_control=cache_control, - ) - - # Load the Excel source from the temporary file - excel_source = OfficeDocumentSource(root=temp_excel_path) - - # Extract content - result = translator.extract_content(excel_source, document_block) - - # Should return a list with one text block - assert isinstance(result, list) - assert len(result) == 1 - - # First element should be a text block - text_block = result[0] - assert text_block.type == "text" - assert text_block.cache_control == cache_control - - # Check the text content - assert isinstance(text_block.text, str) - assert len(text_block.text) > 0 - - def test_extract_content_from_word( - self, - translator: RequestDocumentBlockParamTranslator, - path_fixtures_dummy_doc: pathlib.Path, - temp_dir: pathlib.Path, - cache_control: CacheControlEphemeralParam, - ) -> None: - """Test extracting content from a Word document.""" - # Copy the fixture Word file to the temporary directory - temp_doc_path = temp_dir / "test_document.docx" - shutil.copy2(path_fixtures_dummy_doc, temp_doc_path) - - # Create a document block with cache control - document_block = RequestDocumentBlockParam( - source={"file_id": "word123", "type": "file"}, - type="document", - cache_control=cache_control, - ) - - # Load the Word source from the temporary file - word_source = OfficeDocumentSource(root=temp_doc_path) - - # Extract content - result = translator.extract_content(word_source, document_block) - - # Should return a list with one text block - assert isinstance(result, list) - assert len(result) == 1 - - # First element should be a text block - text_block = result[0] - assert text_block.type == "text" - assert text_block.cache_control == cache_control - - # Check the text content - assert isinstance(text_block.text, str) - assert len(text_block.text) > 0 - - def test_extract_content_from_image_no_cache_control( - self, - translator: RequestDocumentBlockParamTranslator, - path_fixtures_github_com__icon: pathlib.Path, - temp_dir: pathlib.Path, - ) -> None: - """Test extracting content from an image file without cache control.""" - # Copy the fixture image to the temporary directory - temp_image_path = temp_dir / "test_icon_no_cache.png" - shutil.copy2(path_fixtures_github_com__icon, temp_image_path) - - # Create a document block without cache control - document_block = RequestDocumentBlockParam( - source={"file_id": "image123", "type": "file"}, - type="document", - ) - - # Load the image source using PIL Image from the temporary file - pil_image = Image.open(temp_image_path) - image_source = ImageSource(pil_image) - - # Extract content - result = translator.extract_content(image_source, document_block) - - # Should return a list with one image block - assert isinstance(result, list) - assert len(result) == 1 - - # First element should be an image block - image_block = result[0] - assert image_block.type == "image" - assert image_block.cache_control is None - - # Check the source is base64 encoded - assert image_block.source.type == "base64" - assert image_block.source.media_type == "image/png" - assert isinstance(image_block.source.data, str) - assert len(image_block.source.data) > 0 - - def test_extract_content_from_excel_no_cache_control( - self, - translator: RequestDocumentBlockParamTranslator, - path_fixtures_dummy_excel: pathlib.Path, - temp_dir: pathlib.Path, - ) -> None: - """Test extracting content from an Excel file without cache control.""" - # Copy the fixture Excel file to the temporary directory - temp_excel_path = temp_dir / "test_data_no_cache.xlsx" - shutil.copy2(path_fixtures_dummy_excel, temp_excel_path) - - # Create a document block without cache control - document_block = RequestDocumentBlockParam( - source={"file_id": "excel123", "type": "file"}, - type="document", - ) - - # Load the Excel source from the temporary file - excel_source = OfficeDocumentSource(root=temp_excel_path) - - # Extract content - result = translator.extract_content(excel_source, document_block) - - # Should return a list with one text block - assert isinstance(result, list) - assert len(result) == 1 - - # First element should be a text block - text_block = result[0] - assert text_block.type == "text" - assert text_block.cache_control is None - - # Check the text content - assert isinstance(text_block.text, str) - assert len(text_block.text) > 0 - - def test_extract_content_from_word_no_cache_control( - self, - translator: RequestDocumentBlockParamTranslator, - path_fixtures_dummy_doc: pathlib.Path, - temp_dir: pathlib.Path, - ) -> None: - """Test extracting content from a Word document without cache control.""" - # Copy the fixture Word file to the temporary directory - temp_doc_path = temp_dir / "test_document_no_cache.docx" - shutil.copy2(path_fixtures_dummy_doc, temp_doc_path) - - # Create a document block without cache control - document_block = RequestDocumentBlockParam( - source={"file_id": "word123", "type": "file"}, - type="document", - ) - - # Load the Word source from the temporary file - word_source = OfficeDocumentSource(root=temp_doc_path) - - # Extract content - result = translator.extract_content(word_source, document_block) - - # Should return a list with one text block - assert isinstance(result, list) - assert len(result) == 1 - - # First element should be a text block - text_block = result[0] - assert text_block.type == "text" - assert text_block.cache_control is None - - # Check the text content - assert isinstance(text_block.text, str) - assert len(text_block.text) > 0 diff --git a/tests/integration/chat/api/test_runs.py b/tests/integration/chat/api/test_runs.py deleted file mode 100644 index aadaf3af..00000000 --- a/tests/integration/chat/api/test_runs.py +++ /dev/null @@ -1,1149 +0,0 @@ -"""Integration tests for the runs API endpoints.""" - -import tempfile -from datetime import datetime, timezone -from pathlib import Path -from unittest.mock import Mock -from uuid import UUID - -from fastapi import status -from fastapi.testclient import TestClient -from sqlalchemy.orm import Session - -from askui.chat.api.assistants.models import Assistant -from askui.chat.api.assistants.orms import AssistantOrm -from askui.chat.api.assistants.service import AssistantService -from askui.chat.api.models import WorkspaceId -from askui.chat.api.runs.models import Run -from askui.chat.api.runs.orms import RunOrm -from askui.chat.api.runs.service import RunService -from askui.chat.api.settings import Settings -from askui.chat.api.threads.models import Thread -from askui.chat.api.threads.orms import ThreadOrm -from askui.chat.api.threads.service import ThreadService - - -def create_mock_mcp_client_manager_manager() -> Mock: - """Create a properly configured mock MCP config service.""" - mock_service = Mock() - # Configure mock to return proper data structure - mock_service.get_mcp_client_manager.return_value = None - return mock_service - - -class TestRunsAPI: - """Test suite for the runs API endpoints.""" - - def _create_test_assistant( - self, - assistant_id: str, - workspace_id: WorkspaceId | None = None, - name: str = "Test Assistant", - description: str = "A test assistant", - avatar: str | None = None, - created_at: datetime | None = None, - tools: list[str] | None = None, - system: str | None = None, - ) -> Assistant: - """Create a test assistant model.""" - if created_at is None: - created_at = datetime.fromtimestamp(1234567890, tz=timezone.utc) - if tools is None: - tools = [] - return Assistant( - id=assistant_id, - object="assistant", - created_at=created_at, - name=name, - description=description, - avatar=avatar, - workspace_id=workspace_id, - tools=tools, - system=system, - ) - - def _add_assistant_to_db( - self, assistant: Assistant, test_db_session: Session - ) -> None: - """Add an assistant to the test database.""" - assistant_orm = AssistantOrm.from_model(assistant) - test_db_session.add(assistant_orm) - test_db_session.commit() - - def _add_thread_to_db(self, thread: Thread, test_db_session: Session) -> None: - """Add a thread to the test database.""" - thread_orm = ThreadOrm.from_model(thread) - test_db_session.add(thread_orm) - test_db_session.commit() - - def _add_run_to_db(self, run: Run, test_db_session: Session) -> None: - """Add a run to the test database.""" - # Need to include status (computed field) in the model dump - run_dict = run.model_dump(exclude={"object"}) - run_dict["status"] = run.status # Add computed status field - run_orm = RunOrm(**run_dict) - test_db_session.add(run_orm) - test_db_session.commit() - - def _create_test_workspace(self) -> Path: - """Create a temporary workspace directory for testing.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - return workspace_path - - def _create_test_thread( - self, - workspace_path: Path, - thread_id: str = "thread_test123", - test_db_session: Session | None = None, - workspace_id: UUID | None = None, - ) -> Thread: - """Create a test thread in the workspace.""" - threads_dir = workspace_path / "threads" - if workspace_id is None and test_db_session is not None: - # Need workspace_id if adding to DB - error_msg = "workspace_id required when test_db_session is provided" - raise ValueError(error_msg) - mock_thread = Thread( - id=thread_id, - object="thread", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - name="Test Thread", - workspace_id=workspace_id, - ) - (threads_dir / f"{thread_id}.json").write_text(mock_thread.model_dump_json()) - if test_db_session is not None and workspace_id is not None: - self._add_thread_to_db(mock_thread, test_db_session) - return mock_thread - - def _create_test_run( - self, - workspace_path: Path, - thread_id: str = "thread_test123", - run_id: str = "run_test123", - test_db_session: Session | None = None, - workspace_id: UUID | None = None, - ) -> Run: - """Create a test run in the workspace.""" - runs_dir = workspace_path / "runs" / thread_id - runs_dir.mkdir(parents=True, exist_ok=True) - - mock_run = Run( - id=run_id, - object="thread.run", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - thread_id=thread_id, - assistant_id="asst_test123", - expires_at=datetime.fromtimestamp(1755846718, tz=timezone.utc), - started_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - completed_at=datetime.fromtimestamp(1234567900, tz=timezone.utc), - workspace_id=workspace_id, - ) - (runs_dir / f"{run_id}.json").write_text(mock_run.model_dump_json()) - if test_db_session is not None and workspace_id is not None: - self._add_run_to_db(mock_run, test_db_session) - return mock_run - - def _setup_runs_dependencies( - self, workspace_path: Path, test_db_session: Session - ) -> None: - """Set up dependency overrides for runs and threads services.""" - from askui.chat.api.app import app - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - return ThreadService(session=test_db_session) - - def override_runs_service() -> RunService: - assistant_service = AssistantService(test_db_session) - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - settings = Settings(data_dir=workspace_path) - return RunService( - session=test_db_session, - assistant_service=assistant_service, - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=settings, - ) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service - - def _create_multiple_test_runs( - self, - workspace_path: Path, - thread_id: str = "thread_test123", - count: int = 5, - test_db_session: Session | None = None, - workspace_id: UUID | None = None, - ) -> None: - """Create multiple test runs in the workspace.""" - runs_dir = workspace_path / "runs" / thread_id - runs_dir.mkdir(parents=True, exist_ok=True) - - for i in range(count): - mock_run = Run( - id=f"run_test{i}", - object="thread.run", - created_at=datetime.fromtimestamp(1234567890 + i, tz=timezone.utc), - thread_id=thread_id, - assistant_id=f"asst_test{i}", - expires_at=datetime.fromtimestamp( - 1234567890 + i + 600, tz=timezone.utc - ), - workspace_id=workspace_id, - ) - (runs_dir / f"run_test{i}.json").write_text(mock_run.model_dump_json()) - if test_db_session is not None and workspace_id is not None: - self._add_run_to_db(mock_run, test_db_session) - - def _cleanup_dependencies(self) -> None: - """Clean up dependency overrides.""" - from askui.chat.api.app import app - - app.dependency_overrides.clear() - - def test_list_runs_empty( - self, - test_headers: dict[str, str], - test_client: TestClient, - test_db_session: Session, - ) -> None: - """Test listing runs when no runs exist.""" - workspace_path = self._create_test_workspace() - self._create_test_thread(workspace_path) - - self._setup_runs_dependencies(workspace_path, test_db_session) - - try: - response = test_client.get( - "/v1/runs?thread=thread_test123", headers=test_headers - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert data["data"] == [] - assert data["has_more"] is False - finally: - self._cleanup_dependencies() - - def test_list_runs_with_runs( - self, - test_headers: dict[str, str], - test_client: TestClient, - test_db_session: Session, - ) -> None: - """Test listing runs when runs exist.""" - workspace_path = self._create_test_workspace() - workspace_id = UUID(test_headers["askui-workspace"]) - self._create_test_thread( - workspace_path, test_db_session=test_db_session, workspace_id=workspace_id - ) - # Add assistant for foreign key - mock_assistant = self._create_test_assistant( - assistant_id="asst_test123", - workspace_id=workspace_id, - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - self._create_test_run( - workspace_path, - test_db_session=test_db_session, - workspace_id=workspace_id, - ) - - self._setup_runs_dependencies(workspace_path, test_db_session) - - try: - response = test_client.get( - "/v1/runs?thread=thread_test123", headers=test_headers - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert len(data["data"]) == 1 - assert data["data"][0]["id"] == "run_test123" - assert data["data"][0]["status"] == "completed" - assert data["data"][0]["assistant_id"] == "asst_test123" - finally: - self._cleanup_dependencies() - - def test_list_runs_with_pagination( - self, - test_headers: dict[str, str], - test_client: TestClient, - test_db_session: Session, - ) -> None: - """Test listing runs with pagination parameters.""" - workspace_path = self._create_test_workspace() - workspace_id = UUID(test_headers["askui-workspace"]) - self._create_test_thread( - workspace_path, test_db_session=test_db_session, workspace_id=workspace_id - ) - # Add assistants for foreign keys - for i in range(5): - mock_assistant = self._create_test_assistant( - assistant_id=f"asst_test{i}", - workspace_id=workspace_id, - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - self._create_multiple_test_runs( - workspace_path, - test_db_session=test_db_session, - workspace_id=workspace_id, - ) - - self._setup_runs_dependencies(workspace_path, test_db_session) - - try: - response = test_client.get( - "/v1/runs?thread=thread_test123&limit=3", headers=test_headers - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert len(data["data"]) == 3 - assert data["has_more"] is True - finally: - self._cleanup_dependencies() - - def test_create_run( - self, - test_headers: dict[str, str], - test_client: TestClient, - test_db_session: Session, - ) -> None: - """Test creating a new run.""" - workspace_path = self._create_test_workspace() - workspace_id = UUID(test_headers["askui-workspace"]) - self._create_test_thread( - workspace_path, test_db_session=test_db_session, workspace_id=workspace_id - ) - self._setup_runs_dependencies(workspace_path, test_db_session) - self._add_assistant_to_db( - self._create_test_assistant( - assistant_id="asst_test123", workspace_id=workspace_id - ), - test_db_session, - ) - - try: - run_data = { - "assistant_id": "asst_test123", - "stream": False, - "metadata": {"key": "value", "number": 42}, - } - response = test_client.post( - "/v1/threads/thread_test123/runs", - json=run_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["assistant_id"] == "asst_test123" - assert data["thread_id"] == "thread_test123" - assert data["object"] == "thread.run" - assert "id" in data - assert "created_at" in data - finally: - self._cleanup_dependencies() - - def test_create_run_minimal( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test creating a run with minimal data.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock thread - workspace_id = UUID(test_headers["askui-workspace"]) - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - name="Test Thread", - workspace_id=workspace_id, - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - - # Add thread to database - self._add_thread_to_db(mock_thread, test_db_session) - - # Add assistant to database (required for foreign key) - mock_assistant = self._create_test_assistant( - assistant_id="asst_test123", - workspace_id=workspace_id, - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - return ThreadService(session=test_db_session) - - def override_runs_service() -> RunService: - mock_assistant_service = Mock() - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - settings = Settings(data_dir=workspace_path) - return RunService( - session=test_db_session, - assistant_service=mock_assistant_service, - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=settings, - ) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service - - try: - with TestClient(app) as client: - run_data = {"assistant_id": "asst_test123"} - response = client.post( - "/v1/threads/thread_test123/runs", - json=run_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["object"] == "thread.run" - assert data["assistant_id"] == "asst_test123" - # stream field is not returned in the response - finally: - app.dependency_overrides.clear() - - def test_create_run_streaming( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test creating a streaming run.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock thread - workspace_id = UUID(test_headers["askui-workspace"]) - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - name="Test Thread", - workspace_id=workspace_id, - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - - # Add thread to database - self._add_thread_to_db(mock_thread, test_db_session) - - # Add assistant to database (required for foreign key) - mock_assistant = self._create_test_assistant( - assistant_id="asst_test123", - workspace_id=workspace_id, - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - return ThreadService(session=test_db_session) - - def override_runs_service() -> RunService: - mock_assistant_service = Mock() - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - settings = Settings(data_dir=workspace_path) - return RunService( - session=test_db_session, - assistant_service=mock_assistant_service, - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=settings, - ) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service - - try: - with TestClient(app) as client: - run_data = { - "assistant_id": "asst_test123", - "stream": True, - } - response = client.post( - "/v1/threads/thread_test123/runs", - json=run_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_201_CREATED - assert "text/event-stream" in response.headers["content-type"] - finally: - app.dependency_overrides.clear() - - def test_create_thread_and_run( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test creating a thread and run in one request.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Add assistant to database (required for foreign key) - workspace_id = UUID(test_headers["askui-workspace"]) - mock_assistant = self._create_test_assistant( - assistant_id="asst_test123", - workspace_id=workspace_id, - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - return ThreadService(session=test_db_session) - - def override_runs_service() -> RunService: - mock_assistant_service = Mock() - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - settings = Settings(data_dir=workspace_path) - return RunService( - session=test_db_session, - assistant_service=mock_assistant_service, - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=settings, - ) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service - - try: - with TestClient(app) as client: - thread_and_run_data = { - "assistant_id": "asst_test123", - "stream": False, - "thread": { - "name": "Test Thread", - "messages": [ - {"role": "user", "content": "Hello, how are you?"} - ], - }, - "metadata": {"key": "value", "number": 42}, - } - response = client.post( - "/v1/runs", - json=thread_and_run_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["assistant_id"] == "asst_test123" - assert data["object"] == "thread.run" - assert "id" in data - assert "created_at" in data - assert "thread_id" in data - finally: - app.dependency_overrides.clear() - - def test_create_thread_and_run_minimal( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test creating a thread and run with minimal data.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Add assistant to database (required for foreign key) - workspace_id = UUID(test_headers["askui-workspace"]) - mock_assistant = self._create_test_assistant( - assistant_id="asst_test123", - workspace_id=workspace_id, - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - return ThreadService(session=test_db_session) - - def override_runs_service() -> RunService: - mock_assistant_service = Mock() - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - settings = Settings(data_dir=workspace_path) - return RunService( - session=test_db_session, - assistant_service=mock_assistant_service, - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=settings, - ) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service - - try: - with TestClient(app) as client: - thread_and_run_data = {"assistant_id": "asst_test123", "thread": {}} - response = client.post( - "/v1/runs", - json=thread_and_run_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["object"] == "thread.run" - assert data["assistant_id"] == "asst_test123" - assert "id" in data - assert "thread_id" in data - finally: - app.dependency_overrides.clear() - - def test_create_thread_and_run_streaming( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test creating a streaming thread and run.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Add assistant to database (required for foreign key) - workspace_id = UUID(test_headers["askui-workspace"]) - mock_assistant = self._create_test_assistant( - assistant_id="asst_test123", - workspace_id=workspace_id, - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - return ThreadService(session=test_db_session) - - def override_runs_service() -> RunService: - mock_assistant_service = Mock() - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - settings = Settings(data_dir=workspace_path) - return RunService( - session=test_db_session, - assistant_service=mock_assistant_service, - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=settings, - ) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service - - try: - with TestClient(app) as client: - thread_and_run_data = { - "assistant_id": "asst_test123", - "stream": True, - "thread": { - "name": "Streaming Thread", - "messages": [{"role": "user", "content": "Tell me a story"}], - }, - } - response = client.post( - "/v1/runs", - json=thread_and_run_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_201_CREATED - assert "text/event-stream" in response.headers["content-type"] - finally: - app.dependency_overrides.clear() - - def test_create_thread_and_run_with_messages( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test creating a thread and run with initial messages.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Add assistant to database (required for foreign key) - workspace_id = UUID(test_headers["askui-workspace"]) - mock_assistant = self._create_test_assistant( - assistant_id="asst_test123", - workspace_id=workspace_id, - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - return ThreadService(session=test_db_session) - - def override_runs_service() -> RunService: - mock_assistant_service = Mock() - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - settings = Settings(data_dir=workspace_path) - return RunService( - session=test_db_session, - assistant_service=mock_assistant_service, - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=settings, - ) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service - - try: - with TestClient(app) as client: - thread_and_run_data = { - "assistant_id": "asst_test123", - "stream": False, - "thread": { - "name": "Conversation Thread", - "messages": [ - {"role": "user", "content": "What is the weather like?"}, - { - "role": "assistant", - "content": ( - "I don't have access to real-time weather data." - ), - }, - {"role": "user", "content": "Can you help me plan my day?"}, - ], - }, - } - response = client.post( - "/v1/runs", - json=thread_and_run_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["assistant_id"] == "asst_test123" - assert data["object"] == "thread.run" - assert "id" in data - assert "thread_id" in data - finally: - app.dependency_overrides.clear() - - def test_create_thread_and_run_validation_error( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test creating thread and run with invalid data.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - from askui.chat.api.app import app - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - return ThreadService(session=test_db_session) - - def override_runs_service() -> RunService: - mock_assistant_service = Mock() - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - settings = Settings(data_dir=workspace_path) - return RunService( - session=test_db_session, - assistant_service=mock_assistant_service, - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=settings, - ) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service - - try: - with TestClient(app) as client: - # Missing required assistant_id - invalid_data = {"thread": {}} # type: ignore[var-annotated] - response = client.post( - "/v1/runs", - json=invalid_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - data = response.json() - assert "detail" in data - finally: - app.dependency_overrides.clear() - - def test_create_thread_and_run_empty_thread( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test creating thread and run with completely empty thread object.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Add assistant to database (required for foreign key) - workspace_id = UUID(test_headers["askui-workspace"]) - mock_assistant = self._create_test_assistant( - assistant_id="asst_test123", - workspace_id=workspace_id, - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - return ThreadService(session=test_db_session) - - def override_runs_service() -> RunService: - mock_assistant_service = Mock() - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - settings = Settings(data_dir=workspace_path) - return RunService( - session=test_db_session, - assistant_service=mock_assistant_service, - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=settings, - ) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service - - try: - with TestClient(app) as client: - thread_and_run_data = {"assistant_id": "asst_test123", "thread": {}} - response = client.post( - "/v1/runs", - json=thread_and_run_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["assistant_id"] == "asst_test123" - assert "thread_id" in data - finally: - app.dependency_overrides.clear() - - def test_create_run_invalid_thread( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test creating a run in a non-existent thread.""" - run_data = {"assistant_id": "asst_test123"} - response = test_client.post( - "/v1/threads/thread_nonexistent123/runs", - json=run_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - - def test_retrieve_run( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test retrieving an existing run.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - runs_dir = workspace_path / "runs" / "thread_test123" - runs_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock thread - workspace_id = UUID(test_headers["askui-workspace"]) - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - name="Test Thread", - workspace_id=workspace_id, - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - - # Add thread to database - self._add_thread_to_db(mock_thread, test_db_session) - - # Create and add assistant to database (required for foreign key) - mock_assistant = self._create_test_assistant( - assistant_id="asst_test123", - workspace_id=workspace_id, - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - - # Create a mock run - mock_run = Run( - id="run_test123", - object="thread.run", - created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - thread_id="thread_test123", - assistant_id="asst_test123", - expires_at=datetime.fromtimestamp(1755846718, tz=timezone.utc), - started_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), - completed_at=datetime.fromtimestamp(1234567900, tz=timezone.utc), - workspace_id=workspace_id, - ) - (runs_dir / "run_test123.json").write_text(mock_run.model_dump_json()) - - # Add run to database - self._add_run_to_db(mock_run, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - return ThreadService(session=test_db_session) - - def override_runs_service() -> RunService: - mock_assistant_service = Mock() - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - settings = Settings(data_dir=workspace_path) - return RunService( - session=test_db_session, - assistant_service=mock_assistant_service, - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=settings, - ) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service - - try: - with TestClient(app) as client: - response = client.get( - "/v1/threads/thread_test123/runs/run_test123", headers=test_headers - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["id"] == "run_test123" - assert data["status"] == "completed" - assert data["assistant_id"] == "asst_test123" - assert data["thread_id"] == "thread_test123" - finally: - app.dependency_overrides.clear() - - def test_retrieve_run_not_found( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test retrieving a non-existent run.""" - response = test_client.get( - "/v1/threads/thread_test123/runs/run_nonexistent123", headers=test_headers - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - data = response.json() - assert "detail" in data - - def test_cancel_run( - self, test_headers: dict[str, str], test_db_session: Session - ) -> None: - """Test canceling an existing run.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - runs_dir = workspace_path / "runs" / "thread_test123" - runs_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock thread - workspace_id = UUID(test_headers["askui-workspace"]) - import time - - current_time = int(time.time()) - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=datetime.fromtimestamp(current_time, tz=timezone.utc), - name="Test Thread", - workspace_id=workspace_id, - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - self._add_thread_to_db(mock_thread, test_db_session) - - # Create and add assistant to database (required for foreign key) - mock_assistant = self._create_test_assistant( - assistant_id="asst_test123", - workspace_id=workspace_id, - ) - self._add_assistant_to_db(mock_assistant, test_db_session) - - # Create a mock run - mock_run = Run( - id="run_test123", - object="thread.run", - created_at=datetime.fromtimestamp(current_time, tz=timezone.utc), - thread_id="thread_test123", - assistant_id="asst_test123", - expires_at=datetime.fromtimestamp(current_time + 600, tz=timezone.utc), - workspace_id=workspace_id, - ) - (runs_dir / "run_test123.json").write_text(mock_run.model_dump_json()) - self._add_run_to_db(mock_run, test_db_session) - - from askui.chat.api.app import app - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - return ThreadService(session=test_db_session) - - def override_runs_service() -> RunService: - mock_assistant_service = Mock() - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - settings = Settings(data_dir=workspace_path) - return RunService( - session=test_db_session, - assistant_service=mock_assistant_service, - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=settings, - ) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service - - try: - with TestClient(app) as client: - response = client.post( - "/v1/threads/thread_test123/runs/run_test123/cancel", - headers=test_headers, - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["id"] == "run_test123" - # The cancel operation sets the status to "cancelled" - assert data["status"] == "cancelled" - finally: - app.dependency_overrides.clear() - - def test_cancel_run_not_found( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test canceling a non-existent run.""" - response = test_client.post( - "/v1/threads/thread_test123/runs/run_nonexistent123/cancel", - headers=test_headers, - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - - def test_create_run_with_custom_assistant( - self, - test_headers: dict[str, str], - test_db_session: Session, - test_client: TestClient, - ) -> None: - """Test creating a run with a custom assistant.""" - workspace_path = self._create_test_workspace() - workspace_id = UUID(test_headers["askui-workspace"]) - self._create_test_thread( - workspace_path, test_db_session=test_db_session, workspace_id=workspace_id - ) - - # Create a custom assistant in the database - custom_assistant = self._create_test_assistant( - "asst_custom123", - workspace_id=workspace_id, - name="Custom Assistant", - tools=["tool1", "tool2"], - system="You are a custom assistant.", - ) - self._add_assistant_to_db(custom_assistant, test_db_session) - - self._setup_runs_dependencies(workspace_path, test_db_session) - - try: - response = test_client.post( - "/v1/threads/thread_test123/runs", - headers=test_headers, - json={"assistant_id": "asst_custom123"}, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["assistant_id"] == "asst_custom123" - assert data["thread_id"] == "thread_test123" - assert data["status"] == "queued" - assert "id" in data - assert "created_at" in data - finally: - self._cleanup_dependencies() - - def test_create_run_with_custom_assistant_empty_tools( - self, - test_headers: dict[str, str], - test_db_session: Session, - test_client: TestClient, - ) -> None: - """Test creating a run with a custom assistant that has empty tools.""" - workspace_path = self._create_test_workspace() - workspace_id = UUID(test_headers["askui-workspace"]) - self._create_test_thread( - workspace_path, test_db_session=test_db_session, workspace_id=workspace_id - ) - - # Create a custom assistant with empty tools in the database - empty_tools_assistant = self._create_test_assistant( - "asst_customempty123", - workspace_id=workspace_id, - name="Empty Tools Assistant", - tools=[], - system="You are a assistant with no tools.", - ) - self._add_assistant_to_db(empty_tools_assistant, test_db_session) - - self._setup_runs_dependencies(workspace_path, test_db_session) - - try: - response = test_client.post( - "/v1/threads/thread_test123/runs", - headers=test_headers, - json={"assistant_id": "asst_customempty123"}, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["assistant_id"] == "asst_customempty123" - assert data["thread_id"] == "thread_test123" - assert data["status"] == "queued" - assert "id" in data - assert "created_at" in data - finally: - self._cleanup_dependencies() diff --git a/tests/integration/chat/api/test_threads.py b/tests/integration/chat/api/test_threads.py deleted file mode 100644 index e3cfbb72..00000000 --- a/tests/integration/chat/api/test_threads.py +++ /dev/null @@ -1,339 +0,0 @@ -"""Integration tests for the threads API endpoints.""" - -from typing import TYPE_CHECKING -from uuid import UUID - -from fastapi import status -from fastapi.testclient import TestClient -from sqlalchemy.orm import Session - -from askui.chat.api.threads.models import ThreadCreate -from askui.chat.api.threads.service import ThreadService - -if TYPE_CHECKING: - from askui.chat.api.models import WorkspaceId - - -class TestThreadsAPI: - """Test suite for the threads API endpoints.""" - - def test_list_threads_empty( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test listing threads when no threads exist.""" - response = test_client.get("/v1/threads", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert data["data"] == [] - assert data["has_more"] is False - - def test_list_threads_with_threads( - self, - test_db_session: Session, - test_headers: dict[str, str], - test_workspace_id: str, - ) -> None: - """Test listing threads when threads exist.""" - from askui.chat.api.app import app - from askui.chat.api.threads.dependencies import get_thread_service - - thread_service = ThreadService(test_db_session) - workspace_id: WorkspaceId = UUID(test_workspace_id) - # Create a thread via the service - created_thread = thread_service.create( - workspace_id=workspace_id, - params=ThreadCreate(name="Test Thread"), - ) - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - - try: - with TestClient(app) as client: - response = client.get("/v1/threads", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert len(data["data"]) == 1 - assert data["data"][0]["id"] == created_thread.id - assert data["data"][0]["name"] == "Test Thread" - finally: - app.dependency_overrides.clear() - - def test_list_threads_with_pagination( - self, - test_db_session: Session, - test_headers: dict[str, str], - test_workspace_id: str, - ) -> None: - """Test listing threads with pagination parameters.""" - from askui.chat.api.app import app - from askui.chat.api.threads.dependencies import get_thread_service - - thread_service = ThreadService(test_db_session) - workspace_id: WorkspaceId = UUID(test_workspace_id) - # Create multiple threads via the service - for i in range(5): - thread_service.create( - workspace_id=workspace_id, - params=ThreadCreate(name=f"Test Thread {i}"), - ) - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - - try: - with TestClient(app) as client: - response = client.get("/v1/threads?limit=3", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert len(data["data"]) == 3 - assert data["has_more"] is True - finally: - app.dependency_overrides.clear() - - def test_create_thread( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test creating a new thread.""" - from askui.chat.api.app import app - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - - try: - with TestClient(app) as client: - thread_data = { - "name": "New Test Thread", - } - response = client.post( - "/v1/threads", json=thread_data, headers=test_headers - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["name"] == "New Test Thread" - assert data["object"] == "thread" - assert "id" in data - assert "created_at" in data - finally: - app.dependency_overrides.clear() - - def test_create_thread_minimal( - self, test_db_session: Session, test_headers: dict[str, str] - ) -> None: - """Test creating a thread with minimal data.""" - from askui.chat.api.app import app - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - - try: - with TestClient(app) as client: - response = client.post("/v1/threads", json={}, headers=test_headers) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["object"] == "thread" - assert data["name"] is None - finally: - app.dependency_overrides.clear() - - def test_retrieve_thread( - self, - test_db_session: Session, - test_headers: dict[str, str], - test_workspace_id: str, - ) -> None: - """Test retrieving an existing thread.""" - from askui.chat.api.app import app - from askui.chat.api.threads.dependencies import get_thread_service - - thread_service = ThreadService(test_db_session) - workspace_id: WorkspaceId = UUID(test_workspace_id) - # Create a thread via the service - created_thread = thread_service.create( - workspace_id=workspace_id, - params=ThreadCreate(name="Test Thread"), - ) - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - - try: - with TestClient(app) as client: - response = client.get( - f"/v1/threads/{created_thread.id}", headers=test_headers - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["id"] == created_thread.id - assert data["name"] == "Test Thread" - finally: - app.dependency_overrides.clear() - - def test_retrieve_thread_not_found( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test retrieving a non-existent thread.""" - response = test_client.get( - "/v1/threads/thread_nonexistent123", headers=test_headers - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - data = response.json() - assert "detail" in data - - def test_modify_thread( - self, - test_db_session: Session, - test_headers: dict[str, str], - test_workspace_id: str, - ) -> None: - """Test modifying an existing thread.""" - from askui.chat.api.app import app - from askui.chat.api.threads.dependencies import get_thread_service - - thread_service = ThreadService(test_db_session) - workspace_id: WorkspaceId = UUID(test_workspace_id) - # Create a thread via the service - created_thread = thread_service.create( - workspace_id=workspace_id, - params=ThreadCreate(name="Original Name"), - ) - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - - try: - with TestClient(app) as client: - modify_data = { - "name": "Modified Name", - } - response = client.post( - f"/v1/threads/{created_thread.id}", - json=modify_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["name"] == "Modified Name" - assert data["id"] == created_thread.id - # API returns Unix timestamp, convert datetime to timestamp for - # comparison - assert data["created_at"] == int(created_thread.created_at.timestamp()) - finally: - app.dependency_overrides.clear() - - def test_modify_thread_partial( - self, - test_db_session: Session, - test_headers: dict[str, str], - test_workspace_id: str, - ) -> None: - """Test modifying a thread with partial data.""" - from askui.chat.api.app import app - from askui.chat.api.threads.dependencies import get_thread_service - - thread_service = ThreadService(test_db_session) - workspace_id: WorkspaceId = UUID(test_workspace_id) - # Create a thread via the service - created_thread = thread_service.create( - workspace_id=workspace_id, - params=ThreadCreate(name="Original Name"), - ) - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - - try: - with TestClient(app) as client: - modify_data = {"name": "Only Name Modified"} - response = client.post( - f"/v1/threads/{created_thread.id}", - json=modify_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["name"] == "Only Name Modified" - finally: - app.dependency_overrides.clear() - - def test_modify_thread_not_found( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test modifying a non-existent thread.""" - modify_data = {"name": "Modified Name"} - response = test_client.post( - "/v1/threads/thread_nonexistent123", json=modify_data, headers=test_headers - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - - def test_delete_thread( - self, - test_db_session: Session, - test_headers: dict[str, str], - test_workspace_id: str, - ) -> None: - """Test deleting an existing thread.""" - from askui.chat.api.app import app - from askui.chat.api.threads.dependencies import get_thread_service - - thread_service = ThreadService(test_db_session) - workspace_id: WorkspaceId = UUID(test_workspace_id) - # Create a thread via the service - created_thread = thread_service.create( - workspace_id=workspace_id, - params=ThreadCreate(name="Test Thread"), - ) - - def override_thread_service() -> ThreadService: - return ThreadService(test_db_session) - - app.dependency_overrides[get_thread_service] = override_thread_service - - try: - with TestClient(app) as client: - response = client.delete( - f"/v1/threads/{created_thread.id}", headers=test_headers - ) - - assert response.status_code == status.HTTP_204_NO_CONTENT - assert response.content == b"" - finally: - app.dependency_overrides.clear() - - def test_delete_thread_not_found( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test deleting a non-existent thread.""" - response = test_client.delete( - "/v1/threads/thread_nonexistent123", headers=test_headers - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/unit/test_request_document_translator.py b/tests/unit/test_request_document_translator.py deleted file mode 100644 index b5fea6a2..00000000 --- a/tests/unit/test_request_document_translator.py +++ /dev/null @@ -1,134 +0,0 @@ -from pathlib import Path -from unittest.mock import MagicMock - -import pytest -import pytest_mock - -from askui.chat.api.messages.models import RequestDocumentBlockParam -from askui.chat.api.messages.translator import RequestDocumentBlockParamTranslator -from askui.models.shared.agent_message_param import ( - CacheControlEphemeralParam, - TextBlockParam, -) -from askui.utils.pdf_utils import PdfSource - - -class TestRequestDocumentBlockParamTranslator: - """Test cases for RequestDocumentBlockParamTranslator.""" - - @pytest.fixture - def file_service(self) -> MagicMock: - """Mock file service.""" - return MagicMock() - - @pytest.fixture - def translator( - self, file_service: MagicMock - ) -> RequestDocumentBlockParamTranslator: - """Create translator instance.""" - return RequestDocumentBlockParamTranslator(file_service, None) - - @pytest.fixture - def cache_control(self) -> CacheControlEphemeralParam: - """Sample cache control parameter.""" - return CacheControlEphemeralParam(type="ephemeral") - - def test_init(self, file_service: MagicMock) -> None: - """Test translator initialization.""" - translator = RequestDocumentBlockParamTranslator(file_service, None) - assert translator._file_service == file_service - - @pytest.mark.asyncio - async def test_to_anthropic_success( - self, - translator: RequestDocumentBlockParamTranslator, - cache_control: CacheControlEphemeralParam, - mocker: pytest_mock.MockerFixture, - ) -> None: - """Test successful conversion to Anthropic format.""" - document_block = RequestDocumentBlockParam( - source={"file_id": "xyz789", "type": "file"}, - type="document", - cache_control=cache_control, - ) - - # Mock the file service response - mock_file = MagicMock() - mock_file.model_dump_json.return_value = '{"id": "xyz789", "name": "test.pdf"}' - mock_path = Path("/tmp/test.pdf") - mocker.patch.object( - translator._file_service, - "retrieve_file_content", - return_value=(mock_file, mock_path), - ) - - # Mock the load_source function to avoid filesystem access - mock_pdf_source = PdfSource(root=mock_path) - mocker.patch( - "askui.chat.api.messages.translator.load_source", - return_value=mock_pdf_source, - ) - - # Mock the extract_content method to return a simple text block - mock_text_block = TextBlockParam( - text="Extracted text content", type="text", cache_control=cache_control - ) - mocker.patch.object( - translator, "extract_content", return_value=[mock_text_block] - ) - - result = await translator.to_anthropic(document_block) - - assert isinstance(result, list) - assert len(result) == 2 # file info + extracted content - # First element should be the file info as TextBlockParam - assert isinstance(result[0], TextBlockParam) - assert result[0].type == "text" - assert result[0].cache_control == cache_control - # Second element should be the extracted content - assert result[1] == mock_text_block - - @pytest.mark.asyncio - async def test_to_anthropic_no_cache_control( - self, - translator: RequestDocumentBlockParamTranslator, - mocker: pytest_mock.MockerFixture, - ) -> None: - """Test conversion without cache control.""" - document_block = RequestDocumentBlockParam( - source={"file_id": "def456", "type": "file"}, - type="document", - ) - - # Mock the file service response - mock_file = MagicMock() - mock_file.model_dump_json.return_value = '{"id": "def456", "name": "test.pdf"}' - mock_path = Path("/tmp/test.pdf") - mocker.patch.object( - translator._file_service, - "retrieve_file_content", - return_value=(mock_file, mock_path), - ) - - # Mock the load_source function to avoid filesystem access - mock_pdf_source = PdfSource(root=mock_path) - mocker.patch( - "askui.chat.api.messages.translator.load_source", - return_value=mock_pdf_source, - ) - - # Mock the extract_content method to return a simple text block - mock_text_block = TextBlockParam(text="Extracted text content", type="text") - mocker.patch.object( - translator, "extract_content", return_value=[mock_text_block] - ) - - result = await translator.to_anthropic(document_block) - - assert isinstance(result, list) - assert len(result) == 2 # file info + extracted content - # First element should be the file info as TextBlockParam - assert isinstance(result[0], TextBlockParam) - assert result[0].cache_control is None - # Second element should be the extracted content - assert result[1] == mock_text_block