diff --git a/.env.template b/.env.template index c2a4d8a2..388edbf5 100644 --- a/.env.template +++ b/.env.template @@ -90,16 +90,12 @@ CHAT_TEMPERATURE=0.7 # SPEECH-TO-TEXT CONFIGURATION # ======================================== -# Primary transcription provider: deepgram, mistral, or parakeet +# Primary transcription provider: deepgram or parakeet TRANSCRIPTION_PROVIDER=deepgram # Deepgram configuration DEEPGRAM_API_KEY=your-deepgram-key-here -# Mistral configuration (when TRANSCRIPTION_PROVIDER=mistral) -MISTRAL_API_KEY=your-mistral-key-here -MISTRAL_MODEL=voxtral-mini-2507 - # Parakeet ASR configuration (when TRANSCRIPTION_PROVIDER=parakeet) PARAKEET_ASR_URL=http://host.docker.internal:8767 diff --git a/.github/workflows/robot-tests.yml b/.github/workflows/robot-tests.yml index 3333266d..b48b5e75 100644 --- a/.github/workflows/robot-tests.yml +++ b/.github/workflows/robot-tests.yml @@ -85,6 +85,18 @@ jobs: echo "✓ Test config.yml created from tests/configs/deepgram-openai.yml" ls -lh config/config.yml + - name: Create plugins.yml from template + run: | + echo "Creating plugins.yml from template..." + if [ -f "config/plugins.yml.template" ]; then + cp config/plugins.yml.template config/plugins.yml + echo "✓ plugins.yml created from template" + ls -lh config/plugins.yml + else + echo "❌ ERROR: config/plugins.yml.template not found" + exit 1 + fi + - name: Run Robot Framework tests working-directory: tests env: diff --git a/.gitignore b/.gitignore index 23141c6b..6fa02d7f 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,10 @@ tests/setup/.env.test config/config.yml !config/config.yml.template +# Plugins config (contains secrets) +config/plugins.yml +!config/plugins.yml.template + # Config backups config/*.backup.* config/*.backup* diff --git a/CLAUDE.md b/CLAUDE.md index 7f5f5507..88c901be 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -18,7 +18,7 @@ This supports a comprehensive web dashboard for management. Chronicle includes an **interactive setup wizard** for easy configuration. 
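For orientation, a typical first run looks roughly like this (a sketch only: `init.sh` is referenced elsewhere in this repo, e.g. by `Dockerfile.k8s`, but verify the wizard's actual entry point in your checkout):

```bash
# Hypothetical first-run sequence -- confirm the script name/location first:
cd backends/advanced
./init.sh                       # interactive setup wizard; writes .env and config files from templates
docker compose up --build -d    # start the stack once configuration is written
```
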
The wizard guides you through: - Service selection (backend + optional services) - Authentication setup (admin account, JWT secrets) -- Transcription provider configuration (Deepgram, Mistral, or offline ASR) +- Transcription provider configuration (Deepgram or offline ASR) - LLM provider setup (OpenAI or Ollama) - Memory provider selection (Chronicle Native with Qdrant or OpenMemory MCP) - Network configuration and HTTPS setup @@ -115,16 +115,8 @@ cp .env.template .env # Configure API keys # Run full integration test suite ./run-test.sh -# Manual test execution (for debugging) -source .env && export DEEPGRAM_API_KEY && export OPENAI_API_KEY -uv run robot --outputdir test-results --loglevel INFO ../../tests/integration/integration_test.robot - # Leave test containers running for debugging (don't auto-cleanup) -CLEANUP_CONTAINERS=false source .env && export DEEPGRAM_API_KEY && export OPENAI_API_KEY -uv run robot --outputdir test-results --loglevel INFO ../../tests/integration/integration_test.robot - -# Manual cleanup when needed -docker compose -f docker-compose-test.yml down -v +CLEANUP_CONTAINERS=false ./run-test.sh ``` #### Test Configuration Flags @@ -185,12 +177,12 @@ docker compose up --build ## Architecture Overview ### Key Components -- **Audio Pipeline**: Real-time Opus/PCM → Application-level processing → Deepgram/Mistral transcription → memory extraction +- **Audio Pipeline**: Real-time Opus/PCM → Application-level processing → Deepgram transcription → memory extraction - **Wyoming Protocol**: WebSocket communication uses Wyoming protocol (JSONL + binary) for structured audio sessions - **Unified Pipeline**: Job-based tracking system for all audio processing (WebSocket and file uploads) - **Job Tracker**: Tracks pipeline jobs with stage events (audio → transcription → memory) and completion status - **Task Management**: BackgroundTaskManager tracks all async tasks to prevent orphaned processes -- **Unified Transcription**: Deepgram/Mistral transcription with fallback to offline ASR services +- **Unified Transcription**: Deepgram transcription with fallback to offline ASR services - **Memory System**: Pluggable providers (Chronicle native or OpenMemory MCP) - **Authentication**: Email-based login with MongoDB ObjectId user system - **Client Management**: Auto-generated client IDs as `{user_id_suffix}-{device_name}`, centralized ClientManager @@ -206,7 +198,7 @@ Required: Recommended: - Vector Storage: Qdrant (Chronicle provider) or OpenMemory MCP server - - Transcription: Deepgram, Mistral, or offline ASR services + - Transcription: Deepgram or offline ASR services Optional: - Parakeet ASR: Offline transcription service @@ -330,12 +322,7 @@ Chronicle supports multiple transcription services: TRANSCRIPTION_PROVIDER=deepgram DEEPGRAM_API_KEY=your-deepgram-key-here -# Option 2: Mistral (Voxtral models) -TRANSCRIPTION_PROVIDER=mistral -MISTRAL_API_KEY=your-mistral-key-here -MISTRAL_MODEL=voxtral-mini-2507 - -# Option 3: Local ASR (Parakeet) +# Option 2: Local ASR (Parakeet) PARAKEET_ASR_URL=http://host.docker.internal:8767 ``` @@ -353,7 +340,7 @@ SPEAKER_SERVICE_URL=http://speaker-recognition:8085 ### Common Endpoints - **GET /health**: Basic application health check - **GET /readiness**: Service dependency validation -- **WS /ws_pcm**: Primary audio streaming endpoint (Wyoming protocol + raw PCM fallback) +- **WS /ws**: Audio streaming endpoint with codec parameter (Wyoming protocol, supports pcm and opus codecs) - **GET /api/conversations**: User's conversations with transcripts - 
**GET /api/memories/search**: Semantic memory search with relevance scoring - **POST /auth/jwt/login**: Email-based login (returns JWT token) diff --git a/Docs/audio-pipeline-architecture.md b/Docs/audio-pipeline-architecture.md new file mode 100644 index 00000000..f36f6e40 --- /dev/null +++ b/Docs/audio-pipeline-architecture.md @@ -0,0 +1,1244 @@ +# Audio Pipeline Architecture + +This document explains how audio flows through the Chronicle system from initial capture to final storage, including all intermediate processing stages, Redis streams, and data storage locations. + +## Table of Contents + +- [Overview](#overview) +- [Architecture Diagram](#architecture-diagram) +- [Data Sources](#data-sources) +- [Redis Streams: The Central Pipeline](#redis-streams-the-central-pipeline) +- [Producer: AudioStreamProducer](#producer-audiostreamproducer) +- [Dual-Consumer Architecture](#dual-consumer-architecture) +- [Transcription Results Aggregator](#transcription-results-aggregator) +- [Job Queue Orchestration (RQ)](#job-queue-orchestration-rq) +- [Data Storage](#data-storage) +- [Complete End-to-End Flow](#complete-end-to-end-flow) +- [Key Design Patterns](#key-design-patterns) +- [Failure Handling](#failure-handling) + +## Overview + +Chronicle's audio pipeline is built on three core technologies: + +- **Redis Streams**: Distributed message queues for audio chunks and transcription results +- **Background Tasks**: Async consumers that process streams independently +- **RQ Job Queue**: Orchestrates session-level and conversation-level workflows + +**Key Insight**: Multiple workers can independently consume the **same audio stream** using Redis Consumer Groups, enabling parallel processing paths (transcription + disk persistence) without duplication. + +## Architecture Diagram + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ AUDIO INPUT │ +│ WebSocket (/ws) │ File Upload (/audio/upload) │ Google Drive │ +└────────────────────────────────┬────────────────────────────────┘ + ↓ + ┌────────────────────────┐ + │ AudioStreamProducer │ + │ - Chunk audio (0.25s) │ + │ - Session metadata │ + └────────────┬───────────┘ + ↓ + ┌────────────────────────────────┐ + │ Redis Stream (Per Client) │ + │ audio:stream:{client_id} │ + └─────┬──────────────────┬───────┘ + ↓ ↓ + ┌───────────────────────┐ ┌──────────────────────┐ + │ Transcription Consumer│ │ Audio Persistence │ + │ Group (streaming/batch)│ │ Consumer Group │ + │ │ │ │ + │ → Deepgram WebSocket │ │ → Writes WAV files │ + │ → Batch buffering │ │ → Monitors rotation │ + │ → Publish results │ │ → Stores file paths │ + └───────────┬───────────┘ └──────────┬───────────┘ + ↓ ↓ + ┌───────────────────────┐ ┌──────────────────────┐ + │ transcription:results │ │ Disk Storage │ + │ :{session_id} │ │ data/chunks/*.wav │ + └───────────┬───────────┘ └──────────────────────┘ + ↓ + ┌───────────────────────┐ + │ TranscriptionResults │ + │ Aggregator │ + │ - Combines chunks │ + │ - Merges timestamps │ + └───────────┬───────────┘ + ↓ + ┌───────────────────────┐ + │ RQ Job Pipeline │ + ├───────────────────────┤ + │ speech_detection_job │ ← Session-level + │ ↓ │ + │ open_conversation_job │ ← Conversation-level + │ ↓ │ + │ Post-Conversation: │ + │ • transcribe_full │ + │ • speaker_recognition │ + │ • memory_extraction │ + │ • title_generation │ + └───────────┬───────────┘ + ↓ + ┌───────────────────────┐ + │ Final Storage │ + ├───────────────────────┤ + │ MongoDB: conversations│ + │ Disk: WAV files │ + │ Qdrant: Memories │ + 
└───────────────────────┘ +``` + +## Data Sources + +### 1. WebSocket Streaming (`/ws`) + +**Endpoint**: `/ws?codec=pcm|opus&token=xxx&device_name=xxx` + +**Handlers**: +- `handle_pcm_websocket()` - Raw PCM audio +- `handle_omi_websocket()` - Opus-encoded audio (compressed, used by OMI devices) + +**Protocol**: Wyoming Protocol (JSON lines + binary frames) + +**Authentication**: JWT token required + +**Location**: `backends/advanced/src/advanced_omi_backend/routers/websocket_routes.py` + +**Container**: `chronicle-backend` + +### 2. File Upload (`/audio/upload`) + +**Endpoint**: `POST /api/audio/upload` + +**Accepts**: Multiple WAV files (multipart form data) + +**Authentication**: Admin only + +**Device ID**: Auto-generated as `{user_id_suffix}-upload` or custom `device_name` + +**Location**: `backends/advanced/src/advanced_omi_backend/routers/api_router.py` + +**Container**: `chronicle-backend` + +### 3. Google Drive Upload + +**Endpoint**: `POST /api/audio/upload_audio_from_gdrive` + +**Source**: Google Drive folder ID + +**Processing**: Downloads files and enqueues for processing + +**Container**: `chronicle-backend` + +## Redis Streams: The Central Pipeline + +### Stream Naming Convention + +``` +audio:stream:{client_id} +``` + +**Examples**: +- `audio:stream:user01-phone` +- `audio:stream:user01-omi-device` +- `audio:stream:user01-upload` + +**Characteristics**: +- **Client-specific isolation**: Each device has its own stream +- **Fan-out pattern**: Multiple consumer groups read the same stream +- **MAXLEN constraint**: Keeps last 25,000 entries (auto-trimming) +- **No TTL**: Streams persist until manually deleted +- **Container**: `redis` service + +### Session Metadata Storage + +``` +audio:session:{session_id} +``` + +**Type**: Redis Hash + +**Fields**: +- `user_id`: MongoDB ObjectId +- `client_id`: Device identifier +- `connection_id`: WebSocket connection ID +- `stream_name`: `audio:stream:{client_id}` +- `status`: `"active"` → `"finalizing"` → `"complete"` +- `chunks_published`: Integer count +- `speech_detection_job_id`: RQ job ID +- `audio_persistence_job_id`: RQ job ID +- `websocket_connected`: `true|false` +- `transcription_error`: Error message (if any) + +**TTL**: 1 hour + +**Container**: `redis` + +### Transcription Results Stream + +``` +transcription:results:{session_id} +``` + +**Type**: Redis Stream + +**Written by**: Transcription consumers (streaming or batch) + +**Read by**: `TranscriptionResultsAggregator` + +**Message Fields**: +- `text`: Transcribed text for this chunk +- `chunk_id`: Redis message ID from audio stream +- `provider`: `"deepgram"` or `"parakeet"` +- `confidence`: Float (0.0-1.0) +- `words`: JSON array of word-level timestamps +- `segments`: JSON array of speaker segments + +**Lifecycle**: Deleted when conversation completes + +**Container**: `redis` + +### Conversation Tracking + +``` +conversation:current:{session_id} +``` + +**Type**: Redis String + +**Value**: Current `conversation_id` (UUID) + +**Purpose**: Signals audio persistence job to rotate WAV file + +**TTL**: 24 hours + +**Container**: `redis` + +### Audio File Path Mapping + +``` +audio:file:{conversation_id} +``` + +**Type**: Redis String + +**Value**: File path (e.g., `1704067200000_user01-phone_convid.wav`) + +**Purpose**: Links conversation to its audio file on disk + +**TTL**: 24 hours + +**Container**: `redis` + +## Producer: AudioStreamProducer + +**File**: `backends/advanced/src/advanced_omi_backend/services/audio_stream/producer.py` + +**Container**: `chronicle-backend` 
(in-memory, no persistence) + +### Responsibilities + +#### 1. Session Initialization + +```python +async def init_session( + session_id: str, + user_id: str, + client_id: str, + provider: str, + mode: str +) -> None +``` + +**Actions**: +- Creates `audio:session:{session_id}` hash in Redis +- Initializes in-memory buffer for chunking +- Stores session metadata (user, client, provider) + +#### 2. Audio Chunking + +```python +async def add_audio_chunk( + session_id: str, + audio_data: bytes +) -> list[str] +``` + +**Process**: +1. Buffers incoming audio (arbitrary size from WebSocket) +2. Creates **fixed-size chunks**: 0.25 seconds = 8,000 bytes + - Assumes: 16kHz sample rate, 16-bit mono PCM +3. Prevents cutting audio mid-word (aligned chunks) +4. Publishes each chunk to `audio:stream:{client_id}` via `XADD` +5. Returns Redis message IDs for tracking + +**In-Memory Storage**: Session buffers stored in `AudioStreamProducer._session_buffers` dict + +#### 3. Session End Signal + +```python +async def send_session_end_signal(session_id: str) -> None +``` + +**Actions**: +- Publishes special `{"type": "END"}` message to stream +- Signals all consumers to flush buffers and finalize +- Updates session status to `"finalizing"` + +### Data Location + +**Memory**: `chronicle-backend` container (in-memory buffers) + +**Redis**: Published chunks in `audio:stream:{client_id}` (redis container) + +## Dual-Consumer Architecture + +Chronicle uses **Redis Consumer Groups** to enable multiple independent consumers to read the **same audio stream** without message duplication. + +### Consumer Group 1: Transcription + +Two implementations available: + +#### A. Streaming Transcription Consumer + +**File**: `backends/advanced/src/advanced_omi_backend/services/transcription/streaming_consumer.py` + +**Class**: `StreamingTranscriptionConsumer` + +**Consumer Group**: `streaming-transcription` + +**Provider**: Deepgram (WebSocket-based) + +**Process**: +1. Discovers `audio:stream:*` streams dynamically using `SCAN` +2. Opens persistent WebSocket connection to Deepgram per stream +3. Sends audio chunks **immediately** (no buffering) +4. Publishes **interim results** to `transcription:interim:{session_id}` (Redis Pub/Sub) +5. Publishes **final results** to `transcription:results:{session_id}` (Redis Stream) +6. Triggers plugins on final results only +7. ACKs messages with `XACK` to prevent reprocessing +8. Handles END signal: closes WebSocket, cleans up + +**Container**: `chronicle-backend` (Background Task via `BackgroundTaskManager`) + +**Real-time Updates**: Interim results pushed to WebSocket clients via Pub/Sub + +#### B. Batch Transcription Consumer + +**File**: `backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py` + +**Class**: `BaseAudioStreamConsumer` + +**Consumer Group**: `{provider_name}_workers` (e.g., `deepgram_workers`, `parakeet_workers`) + +**Providers**: Deepgram (batch), Parakeet ASR (offline) + +**Process**: +1. Reads from `audio:stream:{client_id}` using `XREADGROUP` +2. Buffers chunks per session (default: 30 chunks = ~7.5 seconds) +3. When buffer full: + - Combines chunks into single audio buffer + - Transcribes using provider API + - Adjusts word/segment timestamps relative to session start + - Publishes result to `transcription:results:{session_id}` +4. Flushes remaining buffer on END signal +5. ACKs all buffered messages with `XACK` +6. 
Trims stream to keep only last 1,000 entries (`XTRIM MAXLEN`) + +**Container**: `chronicle-backend` (Background Task) + +**Batching Benefits**: Reduces API calls, improves transcription accuracy (more context) + +### Consumer Group 2: Audio Persistence + +**File**: `backends/advanced/src/advanced_omi_backend/workers/audio_jobs.py` + +**Function**: `audio_streaming_persistence_job()` + +**Consumer Group**: `audio_persistence` + +**Consumer Name**: `persistence-worker-{session_id}` + +**Process**: +1. Reads audio chunks from `audio:stream:{client_id}` using `XREADGROUP` +2. Monitors `conversation:current:{session_id}` for rotation signals +3. On conversation rotation: + - Closes current WAV file + - Opens new WAV file with new conversation ID +4. Writes chunks immediately to disk (real-time persistence) +5. Stores file path in `audio:file:{conversation_id}` (Redis) +6. Handles END signal: closes file, returns statistics +7. ACKs messages after writing to disk + +**Container**: `chronicle-backend` (RQ Worker) + +**Output Location**: `backends/advanced/data/chunks/` (volume-mounted) + +**File Format**: `{timestamp_ms}_{client_id}_{conversation_id}.wav` + +### Fan-Out Pattern Visualization + +``` +audio:stream:user01-phone + ↓ + ├─ Consumer Group: "streaming-transcription" + │ └─ Worker: streaming-worker-12345 + │ → Reads: chunks → Deepgram WS → Results stream + │ + ├─ Consumer Group: "deepgram_workers" + │ ├─ Worker: deepgram-worker-67890 + │ ├─ Worker: deepgram-worker-67891 + │ └─ Reads: chunks → Buffer (30) → Batch API → Results stream + │ + └─ Consumer Group: "audio_persistence" + └─ Worker: persistence-worker-sessionXYZ + → Reads: chunks → WAV file (disk) +``` + +**Key Benefits**: +- **Horizontal scaling**: Multiple workers per group +- **Independent processing**: Each group processes all messages +- **No message loss**: Messages ACKed only after processing +- **Decoupled**: Producer doesn't know about consumers + +## Transcription Results Aggregator + +**File**: `backends/advanced/src/advanced_omi_backend/services/audio_stream/aggregator.py` + +**Class**: `TranscriptionResultsAggregator` + +**Container**: `chronicle-backend` (in-memory, stateless) + +### Methods + +#### Get Combined Results + +```python +async def get_combined_results(session_id: str) -> dict +``` + +**Returns**: +```python +{ + "text": "Full transcript...", + "segments": [SpeakerSegment, ...], + "words": [Word, ...], + "provider": "deepgram", + "chunk_count": 42 +} +``` + +**Process**: +- Reads all entries from `transcription:results:{session_id}` +- For **streaming mode**: Uses latest final result only (supersedes interim) +- For **batch mode**: Combines all chunks sequentially +- Adjusts timestamps across chunks (adds audio offset) +- Merges speaker segments, words + +#### Get Session Results (Raw) + +```python +async def get_session_results(session_id: str) -> list[dict] +``` + +**Returns**: Raw list of transcription result messages + +#### Get Real-time Results + +```python +async def get_realtime_results( + session_id: str, + last_id: str = "0-0" +) -> tuple[list[dict], str] +``` + +**Returns**: `(new_results, new_last_id)` + +**Purpose**: Incremental polling for live UI updates + +### Data Location + +**Input**: `transcription:results:{session_id}` stream (redis container) + +**Processing**: In-memory (chronicle-backend container) + +**Output**: Returned to caller (no persistence) + +## Job Queue Orchestration (RQ) + +**Library**: Python RQ (Redis Queue) + +**File**: 
`backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py` + +**Containers**: +- `chronicle-backend` (enqueues jobs) +- `rq-worker` (executes jobs) + +### Job Pipeline + +``` +Session Starts + ↓ +┌─────────────────────────────────┐ +│ stream_speech_detection_job │ ← Session-level (long-running) +│ - Polls transcription results │ +│ - Analyzes speech content │ +│ - Checks speaker filters │ +└─────────────┬───────────────────┘ + ↓ (when speech detected) +┌─────────────────────────────────┐ +│ open_conversation_job │ ← Conversation-level (long-running) +│ - Creates conversation │ +│ - Signals file rotation │ +│ - Monitors activity │ +│ - Detects end conditions │ +└─────────────┬───────────────────┘ + ↓ (when conversation ends) +┌─────────────────────────────────┐ +│ Post-Conversation Pipeline │ ← Parallel batch jobs +├─────────────────────────────────┤ +│ • transcribe_full_audio_job │ +│ • recognize_speakers_job │ +│ • memory_extraction_job │ +│ • generate_title_summary_job │ +│ • dispatch_conversation_complete│ +└─────────────────────────────────┘ +``` + +### Session-Level Jobs + +#### Speech Detection Job + +**File**: `backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py` + +**Function**: `stream_speech_detection_job()` + +**Scope**: Entire session (can handle multiple conversations) + +**Max Duration**: 24 hours + +**Process**: +1. Polls `TranscriptionResultsAggregator.get_combined_results()` (1-second intervals) +2. Analyzes speech content: + - Word count > 10 + - Duration > 5 seconds + - Confidence > threshold +3. If speaker filter enabled: checks for enrolled speakers +4. When speech detected: + - Creates conversation in MongoDB + - Enqueues `open_conversation_job` + - **Exits** (restarts when conversation completes) +5. Handles transcription errors (marks session with error flag) + +**RQ Queue**: `speech_detection_queue` (dedicated queue) + +**Container**: `rq-worker` + +### Conversation-Level Jobs + +#### Open Conversation Job + +**File**: `backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py` + +**Function**: `open_conversation_job()` + +**Scope**: Single conversation + +**Max Duration**: 3 hours + +**Process**: +1. Creates conversation document in MongoDB `conversations` collection +2. Sets `conversation:current:{session_id}` = `conversation_id` (Redis) + - **Triggers audio persistence job to rotate WAV file** +3. Polls for transcription updates (1-second intervals) +4. Tracks speech activity (inactivity timeout = 60 seconds default) +5. Detects end conditions: + - WebSocket disconnect + - User manual stop + - Inactivity timeout +6. Waits for audio file path from persistence job +7. Saves `audio_path` to conversation document +8. Triggers conversation-level plugins +9. Enqueues post-conversation jobs +10. Calls `handle_end_of_conversation()` for cleanup + restart + +**RQ Queue**: `default` + +**Container**: `rq-worker` + +#### Audio Persistence Job + +**File**: `backends/advanced/src/advanced_omi_backend/workers/audio_jobs.py` + +**Function**: `audio_streaming_persistence_job()` + +**Scope**: Entire session (parallel with open_conversation_job) + +**Max Duration**: 24 hours + +**Process**: +1. Monitors `conversation:current:{session_id}` for rotation signals +2. For each conversation: + - Opens new WAV file: `{timestamp}_{client_id}_{conversation_id}.wav` + - Writes chunks immediately as they arrive from stream + - Stores file path in `audio:file:{conversation_id}` +3. 
On rotation signal: + - Closes current file + - Opens new file for next conversation +4. On END signal: + - Closes file + - Returns statistics (chunk count, bytes, duration) + +**Output**: WAV files in `backends/advanced/data/chunks/` + +**Container**: `rq-worker` + +### Post-Conversation Pipeline + +All jobs run **in parallel** after conversation completes: + +#### 1. Transcribe Full Audio Job + +**File**: `backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py` + +**Function**: `transcribe_full_audio_job()` + +**Input**: Audio file from disk (`data/chunks/*.wav`) + +**Process**: +- Batch transcribes entire conversation audio +- Validates meaningful speech +- Marks conversation `deleted` if no speech detected +- Stores transcript, segments, words in MongoDB + +**Container**: `rq-worker` + +#### 2. Recognize Speakers Job + +**File**: `backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py` + +**Function**: `recognize_speakers_job()` + +**Prerequisite**: `transcribe_full_audio_job` completes + +**Process**: +- Sends audio + segments to speaker recognition service +- Identifies speakers using voice embeddings +- Updates segment speaker labels in MongoDB + +**Optional**: Only runs if `DISABLE_SPEAKER_RECOGNITION=false` + +**Container**: `rq-worker` + +**External Service**: `speaker-recognition` container (if enabled) + +#### 3. Memory Extraction Job + +**File**: `backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py` + +**Function**: `memory_extraction_job()` + +**Prerequisite**: `transcribe_full_audio_job` completes + +**Process**: +- Uses LLM (OpenAI/Ollama) to extract semantic facts +- Stores embeddings in vector database: + - **Chronicle provider**: Qdrant + - **OpenMemory MCP provider**: External OpenMemory server + +**Container**: `rq-worker` + +**External Services**: +- `ollama` or OpenAI API (LLM) +- `qdrant` or OpenMemory MCP (vector storage) + +#### 4. Generate Title Summary Job + +**File**: `backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py` + +**Function**: `generate_title_summary_job()` + +**Prerequisite**: `transcribe_full_audio_job` completes + +**Process**: +- Uses LLM to generate: + - Title (short summary) + - Summary (1-2 sentences) + - Detailed summary (paragraph) +- Updates conversation document in MongoDB + +**Container**: `rq-worker` + +#### 5. Dispatch Conversation Complete Event + +**File**: `backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py` + +**Function**: `dispatch_conversation_complete_event_job()` + +**Process**: +- Triggers `conversation.complete` plugin event +- Only runs for **file uploads** (not streaming sessions) + +**Container**: `rq-worker` + +### Session Restart + +**File**: `backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py` + +**Function**: `handle_end_of_conversation()` + +**Process**: +1. Deletes transcription results stream: `transcription:results:{session_id}` +2. Increments `session:conversation_count:{session_id}` +3. Checks if session still active (WebSocket connected) +4. If active: Re-enqueues `stream_speech_detection_job` for next conversation +5. 
Cleans up consumer groups and pending messages + +**Purpose**: Allows continuous recording with multiple conversations per session + +## Data Storage + +### MongoDB Collections + +**Database**: `chronicle` + +**Container**: `mongo` + +**Volume**: `mongodb_data` (persistent) + +#### `conversations` Collection + +**Schema**: +```python +{ + "_id": ObjectId, + "conversation_id": "uuid-string", + "audio_uuid": "session_id", + "user_id": ObjectId, + "client_id": "user01-phone", + + # Content + "title": "Meeting notes", + "summary": "Discussion about...", + "detailed_summary": "Longer summary...", + "transcript": "Full transcript text", + "audio_path": "1704067200000_user01-phone_convid.wav", + + # Versioned Transcripts + "active_transcript_version": "v1", + "transcript_versions": { + "v1": { + "text": "Full transcript", + "segments": [SpeakerSegment], + "words": [Word], + "provider": "deepgram", + "processing_time_seconds": 45.2, + "created_at": "2025-01-11T12:00:00Z" + } + }, + "segments": [SpeakerSegment], # From active version + + # Metadata + "created_at": "2025-01-11T12:00:00Z", + "completed_at": "2025-01-11T12:15:00Z", + "end_reason": "user_stopped|inactivity_timeout|websocket_disconnect", + "deleted": false +} +``` + +**Indexes**: +- `user_id` (for user-scoped queries) +- `client_id` (for device filtering) +- `conversation_id` (unique) + +#### `audio_chunks` Collection + +**Purpose**: Stores raw audio session data + +**Schema**: +```python +{ + "_id": ObjectId, + "audio_uuid": "session_id", + "user_id": ObjectId, + "client_id": "user01-phone", + "created_at": "2025-01-11T12:00:00Z", + "metadata": { ... } +} +``` + +**Use Case**: Speech-driven architecture (sessions without conversations) + +#### `users` Collection + +**Purpose**: User accounts, authentication, preferences + +**Schema**: +```python +{ + "_id": ObjectId, + "email": "user@example.com", + "hashed_password": "...", + "is_active": true, + "is_superuser": false, + "created_at": "2025-01-11T12:00:00Z" +} +``` + +### Disk Storage + +**Location**: `backends/advanced/data/chunks/` + +**Container**: `chronicle-backend` (volume-mounted) + +**Volume**: `./backends/advanced/data/chunks:/app/data/chunks` + +**File Format**: WAV files + +**Naming Convention**: `{timestamp_ms}_{client_id}_{conversation_id}.wav` + +**Example**: `1704067200000_user01-phone_550e8400-e29b-41d4-a716-446655440000.wav` + +**Created by**: `audio_streaming_persistence_job()` + +**Read by**: Post-conversation transcription jobs + +**Retention**: Manual cleanup (no automatic deletion) + +### Redis Storage + +**Container**: `redis` + +**Volume**: `redis_data` (persistent) + +| Key Pattern | Type | Purpose | TTL | Created By | +|-------------|------|---------|-----|------------| +| `audio:stream:{client_id}` | Stream | Audio chunks for transcription | None (MAXLEN=25k) | AudioStreamProducer | +| `audio:session:{session_id}` | Hash | Session metadata | 1 hour | AudioStreamProducer | +| `transcription:results:{session_id}` | Stream | Transcription results | Manual delete | Transcription consumers | +| `transcription:interim:{session_id}` | Pub/Sub | Real-time interim results | N/A (ephemeral) | Streaming consumer | +| `conversation:current:{session_id}` | String | Current conversation ID | 24 hours | open_conversation_job | +| `audio:file:{conversation_id}` | String | Audio file path | 24 hours | audio_persistence_job | +| `session:conversation_count:{session_id}` | Counter | Conversation count | 1 hour | handle_end_of_conversation | +| 
`speech_detection_job:{client_id}` | String | Job ID for cleanup | 1 hour | speech_detection_job | +| `rq:job:{job_id}` | Hash | RQ job metadata | 24 hours (default) | RQ | + +### Vector Storage (Memory) + +#### Option A: Qdrant (Chronicle Native Provider) + +**Container**: `qdrant` + +**Volume**: `qdrant_data` (persistent) + +**Ports**: 6333 (HTTP), 6334 (gRPC) + +**Collections**: User-specific collections for semantic embeddings + +**Written by**: `memory_extraction_job()` + +**Read by**: Memory search API (`/api/memories/search`) + +#### Option B: OpenMemory MCP + +**Container**: `openmemory-mcp` (external service) + +**Port**: 8765 + +**Protocol**: MCP (Model Context Protocol) + +**Collections**: Cross-client memory storage + +**Written by**: `memory_extraction_job()` (via MCP provider) + +**Read by**: Memory search API (via MCP provider) + +## Complete End-to-End Flow + +### Step-by-Step Data Journey + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ 1. AUDIO INPUT │ +└─────────────────────────────────────────────────────────────────┘ + WebSocket (/ws) or File Upload (/audio/upload) + ↓ + Container: chronicle-backend + ↓ + AudioStreamProducer.init_session() + - Creates: audio:session:{session_id} (Redis) + - Initializes: In-memory buffer (chronicle-backend container) + ↓ + AudioStreamProducer.add_audio_chunk() + - Buffers: In-memory (chronicle-backend) + - Chunks: Fixed 0.25s chunks (8,000 bytes) + - Publishes: audio:stream:{client_id} (Redis) + - Returns: Redis message IDs + +┌─────────────────────────────────────────────────────────────────┐ +│ 2. SESSION-LEVEL JOB (RQ) │ +└─────────────────────────────────────────────────────────────────┘ + stream_speech_detection_job + Container: rq-worker + ↓ + Polls: TranscriptionResultsAggregator.get_combined_results() + Reads: transcription:results:{session_id} (Redis) + ↓ + Analyzes: Word count, duration, confidence + ↓ + When speech detected: + - Creates: Conversation document (MongoDB) + - Enqueues: open_conversation_job (RQ) + - Exits (restarts when conversation ends) + +┌─────────────────────────────────────────────────────────────────┐ +│ 3a. TRANSCRIPTION CONSUMER (Background Task) │ +└─────────────────────────────────────────────────────────────────┘ + StreamingTranscriptionConsumer (or BaseAudioStreamConsumer) + Container: chronicle-backend (Background Task) + ↓ + Reads: audio:stream:{client_id} (Redis, via XREADGROUP) + Consumer Group: streaming-transcription (or batch provider) + ↓ + STREAMING PATH: + • Opens: WebSocket to Deepgram + • Sends: Chunks immediately (no buffering) + • Publishes Interim: transcription:interim:{session_id} (Redis Pub/Sub) + • Publishes Final: transcription:results:{session_id} (Redis Stream) + • Triggers: Plugins on final results + + BATCH PATH: + • Buffers: 30 chunks (~7.5s) in memory (chronicle-backend) + • Combines: All buffered chunks + • Transcribes: Via provider API (Deepgram/Parakeet) + • Adjusts: Timestamps relative to session start + • Publishes: transcription:results:{session_id} (Redis Stream) + +┌─────────────────────────────────────────────────────────────────┐ +│ 3b. 
AUDIO PERSISTENCE CONSUMER (RQ Job) │ +└─────────────────────────────────────────────────────────────────┘ + audio_streaming_persistence_job + Container: rq-worker + ↓ + Reads: audio:stream:{client_id} (Redis, via XREADGROUP) + Consumer Group: audio_persistence + ↓ + Monitors: conversation:current:{session_id} (Redis) + ↓ + For each conversation: + • Opens: New WAV file (data/chunks/, chronicle-backend volume) + • Writes: Chunks immediately (real-time) + • Stores: audio:file:{conversation_id} = path (Redis) + ↓ + On rotation signal: + • Closes: Current file + • Opens: New file for next conversation + ↓ + On END signal: + • Closes: File + • Returns: Statistics (chunks, bytes, duration) + +┌─────────────────────────────────────────────────────────────────┐ +│ 4. CONVERSATION-LEVEL JOB (RQ) │ +└─────────────────────────────────────────────────────────────────┘ + open_conversation_job + Container: rq-worker + ↓ + Creates: Conversation document (MongoDB conversations collection) + ↓ + Sets: conversation:current:{session_id} = conversation_id (Redis) + → Triggers audio persistence job to rotate WAV file + ↓ + Polls: TranscriptionResultsAggregator for updates (1s intervals) + Reads: transcription:results:{session_id} (Redis) + ↓ + Tracks: Speech activity (inactivity timeout = 60s) + ↓ + Detects End: + - Inactivity (60s) + - User manual stop + - WebSocket disconnect + ↓ + Waits: For audio file path from persistence job + Reads: audio:file:{conversation_id} (Redis) + ↓ + Saves: audio_path to conversation document (MongoDB) + ↓ + Enqueues: POST-CONVERSATION PIPELINE (RQ) + +┌─────────────────────────────────────────────────────────────────┐ +│ 5. POST-CONVERSATION PIPELINE (RQ - Parallel Jobs) │ +└─────────────────────────────────────────────────────────────────┘ + All jobs run in parallel + Container: rq-worker + ↓ + Reads: Audio file from disk (data/chunks/*.wav) + + ┌─ transcribe_full_audio_job + │ - Batch transcribes: Complete audio file + │ - Validates: Meaningful speech + │ - Marks deleted: If no speech + │ - Stores: MongoDB (transcript, segments, words) + │ + │ └─ recognize_speakers_job (if enabled) + │ - Sends: Audio + segments to speaker-recognition service + │ - Identifies: Speakers via voice embeddings + │ - Updates: MongoDB (segment speaker labels) + │ + │ └─ memory_extraction_job + │ - Uses: LLM (OpenAI/Ollama) to extract facts + │ - Stores: Qdrant (Chronicle) or OpenMemory MCP (vector DB) + │ + └─ generate_title_summary_job + - Uses: LLM (OpenAI/Ollama) + - Generates: Title, summary, detailed_summary + - Stores: MongoDB (conversation document) + + └─ dispatch_conversation_complete_event_job + - Triggers: conversation.complete plugins + - Only for: File uploads (not streaming) + + All results stored: MongoDB conversations collection + +┌─────────────────────────────────────────────────────────────────┐ +│ 6. SESSION RESTART │ +└─────────────────────────────────────────────────────────────────┘ + handle_end_of_conversation() + Container: chronicle-backend + ↓ + Deletes: transcription:results:{session_id} (Redis) + ↓ + Increments: session:conversation_count:{session_id} (Redis) + ↓ + Checks: Session still active? 
(WebSocket connected) + ↓ + If active: + - Re-enqueues: stream_speech_detection_job (RQ) + - Session remains: "active" for next conversation +``` + +### Data Locations Summary + +| Stage | Data Type | Location | Container | +|-------|-----------|----------|-----------| +| Input | Audio bytes | In-memory buffers | chronicle-backend | +| Producer | Fixed chunks | `audio:stream:{client_id}` | redis | +| Session metadata | Hash | `audio:session:{session_id}` | redis | +| Transcription consumer | Interim results | `transcription:interim:{session_id}` (Pub/Sub) | redis | +| Transcription consumer | Final results | `transcription:results:{session_id}` (Stream) | redis | +| Audio persistence | WAV files | `data/chunks/*.wav` (disk volume) | chronicle-backend (volume) | +| Audio persistence | File paths | `audio:file:{conversation_id}` | redis | +| Conversation job | Conversation doc | MongoDB `conversations` | mongo | +| Post-processing | Transcript | MongoDB `conversations` | mongo | +| Post-processing | Memories | Qdrant or OpenMemory MCP | qdrant / openmemory-mcp | +| Post-processing | Title/summary | MongoDB `conversations` | mongo | + +## Key Design Patterns + +### 1. Speech-Driven Architecture + +**Principle**: Conversations only created when speech is detected + +**Benefits**: +- Clean user experience (no noise-only sessions in UI) +- Reduced memory processing load +- Automatic quality filtering + +**Implementation**: +- `audio_chunks` collection: Always stores sessions +- `conversations` collection: Only created with speech +- Speech detection: Analyzes word count, duration, confidence + +### 2. Versioned Processing + +**Principle**: Store multiple versions of transcripts/memories + +**Benefits**: +- Reprocess without losing originals +- A/B testing different providers +- Rollback to previous versions + +**Implementation**: +- `transcript_versions` dict with version IDs (v1, v2, ...) +- `active_transcript_version` pointer +- `segments` field mirrors active version (quick access) + +### 3. Session-Level vs Conversation-Level + +**Session**: WebSocket connection lifetime (multiple conversations) +- Duration: Up to 24 hours +- Job: `stream_speech_detection_job` +- Purpose: Continuous monitoring for speech + +**Conversation**: Speech burst between silence periods +- Duration: Typically minutes +- Job: `open_conversation_job` +- Purpose: Process single meaningful exchange + +**Benefits**: +- Continuous recording without manual start/stop +- Automatic conversation segmentation +- Efficient resource usage (one session, many conversations) + +### 4. Job Metadata Cascading + +**Pattern**: Parent jobs link to child jobs + +**Example**: +``` +speech_detection_job + ↓ job_id stored in +audio:session:{session_id} + ↓ creates +open_conversation_job + ↓ job_id stored in +conversation document + ↓ creates +post-conversation jobs (parallel) +``` + +**Benefits**: +- Job grouping and cleanup +- Dependency tracking +- Debugging (trace job lineage) + +### 5. Real-Time + Batch Hybrid + +**Real-Time Path** (Streaming Consumer): +- Low latency (interim results in <1 second) +- WebSocket to Deepgram +- Publishes to Pub/Sub for live UI updates + +**Batch Path** (Batch Consumer): +- High accuracy (more context) +- Buffers 7.5 seconds +- API-based transcription + +**Both paths** write to same `transcription:results:{session_id}` stream + +**Benefits**: +- Live UI updates (interim results) +- Accurate final results (batch processing) +- Provider flexibility (switch between streaming/batch) + +### 6. 
Fan-Out via Redis Consumer Groups + +**Pattern**: Multiple consumer groups read same stream + +**Example**: `audio:stream:{client_id}` consumed by: +- Transcription consumer group +- Audio persistence consumer group + +**Benefits**: +- Parallel processing paths +- Horizontal scaling (multiple workers per group) +- No message duplication (each group processes independently) + +### 7. File Rotation via Redis Signals + +**Pattern**: Conversation job signals persistence job via Redis key + +**Implementation**: +```python +# Conversation job +redis.set(f"conversation:current:{session_id}", conversation_id) + +# Persistence job (monitors key) +current_conv = redis.get(f"conversation:current:{session_id}") +if current_conv != last_conv: + close_current_file() + open_new_file(current_conv) +``` + +**Benefits**: +- Decoupled jobs (no direct communication) +- Real-time file rotation +- Multiple files per session (one per conversation) + +## Failure Handling + +### Transcription Errors + +**Detection**: `stream_speech_detection_job` polls results + +**Action**: +- Sets `transcription_error` field in `audio:session:{session_id}` +- Logs error for debugging +- Session remains active (can recover) + +### No Meaningful Speech + +**Detection**: `transcribe_full_audio_job` validates transcript + +**Criteria**: +- Word count < 10 +- Duration < 5 seconds +- All words low confidence + +**Action**: +- Marks conversation `deleted=True` +- Sets `end_reason="no_meaningful_speech"` +- Conversation hidden from UI + +### Audio File Not Ready + +**Detection**: `open_conversation_job` waits for file path + +**Timeout**: 30 seconds (configurable) + +**Action**: +- Marks conversation `deleted=True` +- Sets `end_reason="audio_file_not_ready"` +- Logs error for debugging + +### Job Zombies (Stuck Jobs) + +**Detection**: `check_job_alive()` utility + +**Method**: Checks Redis for job existence + +**Action**: +- Returns `False` if job missing +- Caller can retry or fail gracefully + +### Dead Consumers + +**Detection**: Consumer group lag monitoring + +**Cleanup**: +- Removes idle consumers (>30 seconds) +- Claims pending messages from dead consumers +- Redistributes to active workers + +### Stream Trimming + +**Prevention**: Streams don't grow unbounded + +**Implementation**: +- `XTRIM MAXLEN 25000` on `audio:stream:{client_id}` +- Keeps last 25k messages (~104 minutes @ 0.25s chunks) +- Deletes `transcription:results:{session_id}` after conversation ends + +### Session Timeout + +**Max Duration**: 24 hours + +**Action**: +- Jobs exit gracefully +- Session marked `"complete"` +- Resources cleaned up (streams deleted, consumer groups removed) + +--- + +## Conclusion + +Chronicle's audio pipeline is designed for: +- **Real-time processing**: Low-latency transcription and live UI updates +- **Horizontal scalability**: Redis Consumer Groups enable multiple workers +- **Fault tolerance**: Decoupled components, job retries, graceful error handling +- **Resource efficiency**: Speech-driven architecture filters noise automatically +- **Flexibility**: Pluggable providers (Deepgram/Parakeet, OpenAI/Ollama, Qdrant/OpenMemory) + +All coordinated through **Redis Streams** for data flow and **RQ** for orchestration, with **MongoDB** for final storage and **disk** for audio archives. diff --git a/app/README.md b/app/README.md index d73dd748..e85e83e5 100644 --- a/app/README.md +++ b/app/README.md @@ -120,14 +120,14 @@ The app connects to any backend that accepts OPUS audio streams: 2. 
**Advanced Backend** (`backends/advanced/`) - Full transcription and memory features - Real-time processing with speaker recognition - - WebSocket endpoint: `/ws_pcm` + - WebSocket endpoint: `/ws?codec=pcm` ### Connection Setup #### Local Development ``` -Backend URL: ws://[machine-ip]:8000/ws_pcm -Example: ws://192.168.1.100:8000/ws_pcm +Backend URL: ws://[machine-ip]:8000/ws?codec=pcm +Example: ws://192.168.1.100:8000/ws?codec=pcm ``` #### Public Access (Production) @@ -138,7 +138,7 @@ Use ngrok or similar tunneling service: ngrok http 8000 # Use provided URL in app -Backend URL: wss://[ngrok-subdomain].ngrok.io/ws_pcm +Backend URL: wss://[ngrok-subdomain].ngrok.io/ws?codec=pcm ``` ### Configuration Steps @@ -147,8 +147,8 @@ Backend URL: wss://[ngrok-subdomain].ngrok.io/ws_pcm 2. **Open the mobile app** 3. **Navigate to Settings** 4. **Enter Backend URL**: - - Local: `ws://[your-ip]:8000/ws_pcm` - - Public: `wss://[your-domain]/ws_pcm` + - Local: `ws://[your-ip]:8000/ws?codec=pcm` + - Public: `wss://[your-domain]/ws?codec=pcm` 5. **Save configuration** ## Phone Audio Streaming (NEW) @@ -176,7 +176,7 @@ Stream audio directly from your phone's microphone to Chronicle backend, bypassi - **iOS**: iOS 13+ with microphone permissions - **Android**: Android API 21+ with microphone permissions - **Network**: Stable connection to Chronicle backend -- **Backend**: Advanced backend running with `/ws_pcm` endpoint +- **Backend**: Advanced backend running with `/ws?codec=pcm` endpoint #### Switching Audio Sources - **Mutual Exclusion**: Cannot use Bluetooth and phone audio simultaneously @@ -187,7 +187,7 @@ Stream audio directly from your phone's microphone to Chronicle backend, bypassi #### Audio Not Streaming - **Check Permissions**: Ensure microphone access granted -- **Verify Backend URL**: Confirm `ws://[ip]:8000/ws_pcm` format +- **Verify Backend URL**: Confirm `ws://[ip]:8000/ws?codec=pcm` format - **Network Connection**: Test backend connectivity - **Authentication**: Verify JWT token is valid @@ -292,7 +292,7 @@ curl -i -N -H "Connection: Upgrade" \ -H "Upgrade: websocket" \ -H "Sec-WebSocket-Key: test" \ -H "Sec-WebSocket-Version: 13" \ - http://[backend-ip]:8000/ws_pcm + http://[backend-ip]:8000/ws?codec=pcm ``` ## Development @@ -338,7 +338,7 @@ npx expo build:android ### WebSocket Communication ```javascript // Connect to backend -const ws = new WebSocket('ws://backend-url:8000/ws_pcm'); +const ws = new WebSocket('ws://backend-url:8000/ws?codec=pcm'); // Send audio data ws.send(audioBuffer); diff --git a/app/app/components/BackendStatus.tsx b/app/app/components/BackendStatus.tsx index 75fdd7a8..4f55d37f 100644 --- a/app/app/components/BackendStatus.tsx +++ b/app/app/components/BackendStatus.tsx @@ -208,9 +208,9 @@ export const BackendStatus: React.FC = ({ - Enter the WebSocket URL of your backend server. Simple backend: http://localhost:8000/ (no auth). + Enter the WebSocket URL of your backend server. Simple backend: http://localhost:8000/ (no auth). Advanced backend: http://localhost:8080/ (requires login). Status is automatically checked. 
- The websocket URL can be different or the same as the HTTP URL, with /ws_omi suffix + The websocket URL can be different or the same as the HTTP URL, with /ws endpoint and codec parameter (e.g., /ws?codec=pcm) ); diff --git a/app/app/index.tsx b/app/app/index.tsx index fc924d92..649a2e2b 100644 --- a/app/app/index.tsx +++ b/app/app/index.tsx @@ -322,10 +322,16 @@ export default function App() { // Convert HTTP/HTTPS to WS/WSS protocol finalWebSocketUrl = finalWebSocketUrl.replace(/^http:/, 'ws:').replace(/^https:/, 'wss:'); - // Ensure /ws_pcm endpoint is included - if (!finalWebSocketUrl.includes('/ws_pcm')) { - // Remove trailing slash if present, then add /ws_pcm - finalWebSocketUrl = finalWebSocketUrl.replace(/\/$/, '') + '/ws_pcm'; + // Ensure /ws endpoint is included + if (!finalWebSocketUrl.includes('/ws')) { + // Remove trailing slash if present, then add /ws + finalWebSocketUrl = finalWebSocketUrl.replace(/\/$/, '') + '/ws'; + } + + // Add codec parameter if not present + if (!finalWebSocketUrl.includes('codec=')) { + const separator = finalWebSocketUrl.includes('?') ? '&' : '?'; + finalWebSocketUrl = finalWebSocketUrl + separator + 'codec=pcm'; } // Check if this is the advanced backend (requires authentication) or simple backend diff --git a/backends/advanced/.dockerignore b/backends/advanced/.dockerignore index 2dd9b44f..f0f7f05c 100644 --- a/backends/advanced/.dockerignore +++ b/backends/advanced/.dockerignore @@ -17,5 +17,5 @@ !nginx.conf.template !start.sh !start-k8s.sh -!start-workers.sh +!worker_orchestrator.py !Caddyfile \ No newline at end of file diff --git a/backends/advanced/.env.template b/backends/advanced/.env.template index a63ab6f5..9c11af67 100644 --- a/backends/advanced/.env.template +++ b/backends/advanced/.env.template @@ -216,4 +216,41 @@ CORS_ORIGINS=http://localhost:5173,http://localhost:3000,http://127.0.0.1:5173,h LANGFUSE_PUBLIC_KEY="" LANGFUSE_SECRET_KEY="" LANGFUSE_HOST="http://x.x.x.x:3002" -LANGFUSE_ENABLE_TELEMETRY=False \ No newline at end of file +LANGFUSE_ENABLE_TELEMETRY=False + +# ======================================== +# TAILSCALE CONFIGURATION (Optional) +# ======================================== +# Required for accessing remote services on Tailscale network (e.g., Home Assistant plugin) +# +# To enable Tailscale Docker integration: +# 1. Get auth key from: https://login.tailscale.com/admin/settings/keys +# 2. Set TS_AUTHKEY below +# 3. Start Tailscale: docker compose --profile tailscale up -d +# +# The Tailscale container provides proxy access to remote services at: +# http://host.docker.internal:18123 (proxies to Home Assistant on Tailscale) +# +TS_AUTHKEY=your-tailscale-auth-key-here + +# ======================================== +# HOME ASSISTANT PLUGIN (Optional) +# ======================================== +# Required for Home Assistant voice control via wake word (e.g., "Hey Vivi, turn off the lights") +# +# To get a long-lived access token: +# 1. Go to Home Assistant → Profile → Security tab +# 2. Scroll to "Long-lived access tokens" +# 3. Click "Create Token" +# 4. Copy the token and paste it below +# +# Configuration in config/plugins.yml: +# - Enable the homeassistant plugin +# - Set ha_url to your Home Assistant URL +# - Set ha_token to ${HA_TOKEN} (reads from this variable) +# +# SECURITY: This token grants full access to your Home Assistant. 
+# - Never commit .env or config/plugins.yml to version control +# - Rotate the token if it's ever exposed +# +HA_TOKEN= \ No newline at end of file diff --git a/backends/advanced/Dockerfile b/backends/advanced/Dockerfile index 352bcfe9..886c1f32 100644 --- a/backends/advanced/Dockerfile +++ b/backends/advanced/Dockerfile @@ -1,6 +1,9 @@ -FROM python:3.12-slim-bookworm AS builder +# ============================================ +# Base stage - common setup +# ============================================ +FROM python:3.12-slim-bookworm AS base -# Install system dependencies for building +# Install system dependencies RUN apt-get update && \ apt-get install -y --no-install-recommends \ build-essential \ @@ -9,40 +12,59 @@ RUN apt-get update && \ curl \ ffmpeg \ && rm -rf /var/lib/apt/lists/* - # portaudio19-dev \ # Install uv COPY --from=ghcr.io/astral-sh/uv:0.6.10 /uv /uvx /bin/ -# Set up the working directory +# Set up working directory WORKDIR /app -# Copy package structure and dependency files first +# Copy package structure and dependency files COPY pyproject.toml README.md ./ COPY uv.lock . RUN mkdir -p src/advanced_omi_backend COPY src/advanced_omi_backend/__init__.py src/advanced_omi_backend/ -# Install dependencies using uv with deepgram extra -# Use cache mount for BuildKit, fallback for legacy builds -# RUN --mount=type=cache,target=/root/.cache/uv \ -# uv sync --extra deepgram -# Fallback for legacy Docker builds (CI compatibility) + +# ============================================ +# Production stage - production dependencies only +# ============================================ +FROM base AS prod + +# Install production dependencies only RUN uv sync --extra deepgram # Copy all application code COPY . . -# Copy configuration files if they exist, otherwise they will be created from templates at runtime -# The files are expected to exist, but we handle the case where they don't gracefully - +# Copy configuration files if they exist COPY diarization_config.json* ./ +# Copy and make startup script executable +COPY start.sh ./ +RUN chmod +x start.sh + +# Run the application +CMD ["./start.sh"] + + +# ============================================ +# Dev/Test stage - includes test dependencies +# ============================================ +FROM base AS dev + +# Install production + test dependencies +RUN uv sync --extra deepgram --group test + +# Copy all application code +COPY . . + +# Copy configuration files if they exist +COPY diarization_config.json* ./ -# Copy and make startup scripts executable +# Copy and make startup script executable COPY start.sh ./ -COPY start-workers.sh ./ -RUN chmod +x start.sh start-workers.sh +RUN chmod +x start.sh -# Run the application with workers +# Run the application CMD ["./start.sh"] diff --git a/backends/advanced/Dockerfile.k8s b/backends/advanced/Dockerfile.k8s index b746752a..6500ccf5 100644 --- a/backends/advanced/Dockerfile.k8s +++ b/backends/advanced/Dockerfile.k8s @@ -36,9 +36,9 @@ COPY . . 
 # Copy memory config (created by init.sh from template)

-# Copy and make K8s startup scripts executable
-COPY start-k8s.sh start-workers.sh ./
-RUN chmod +x start-k8s.sh start-workers.sh
+# Copy and make K8s startup script executable
+COPY start-k8s.sh ./
+RUN chmod +x start-k8s.sh

 # Activate virtual environment in PATH
 ENV PATH="/app/.venv/bin:$PATH"
diff --git a/backends/advanced/Docs/architecture.md b/backends/advanced/Docs/architecture.md
index 7c6427bb..739f0ed7 100644
--- a/backends/advanced/Docs/architecture.md
+++ b/backends/advanced/Docs/architecture.md
@@ -22,7 +22,7 @@ graph TB
     %% Main WebSocket Server
     subgraph "WebSocket Server"
-        WS["/ws_pcm endpoint"]
+        WS["/ws?codec=pcm endpoint"]
         AUTH[JWT Auth]
     end
@@ -237,13 +237,13 @@ Wyoming is a peer-to-peer protocol for voice assistants that combines JSONL (JSO
 #### Backend Implementation

-**Advanced Backend (`/ws_pcm`)**:
+**Advanced Backend (`/ws?codec=pcm`)**:
 - **Full Wyoming Protocol Support**: Parses all Wyoming events for comprehensive session management
 - **Session State Tracking**: Only processes audio chunks when session is active (after receiving audio-start)
 - **Conversation Boundaries**: Uses Wyoming audio-start/stop events to define precise conversation segments
 - **PCM Audio Processing**: Direct processing of PCM audio data from all apps

-**Advanced Backend (`/ws_omi`)**:
+**Advanced Backend (`/ws?codec=opus`)**:
 - **Wyoming Protocol + Opus Decoding**: Combines Wyoming session management with OMI Opus decoding
 - **Continuous Streaming**: OMI devices stream continuously, audio-start/stop events are optional
 - **Timestamp Preservation**: Uses timestamps from Wyoming headers when provided
@@ -1006,8 +1006,8 @@ src/advanced_omi_backend/
 - `POST /api/conversations/{conversation_id}/activate-transcript` - Switch transcript version
 - `POST /api/conversations/{conversation_id}/activate-memory` - Switch memory version
 - `POST /api/audio/upload` - Batch audio file upload and processing
-- WebSocket `/ws_omi` - Real-time Opus audio streaming with Wyoming protocol (OMI devices)
-- WebSocket `/ws_pcm` - Real-time PCM audio streaming with Wyoming protocol (all apps)
+- WebSocket `/ws?codec=opus` - Real-time Opus audio streaming with Wyoming protocol (OMI devices)
+- WebSocket `/ws?codec=pcm` - Real-time PCM audio streaming with Wyoming protocol (all apps)

 ### Authentication & Authorization
 - **JWT Tokens**: All API endpoints require valid JWT authentication
diff --git a/backends/advanced/Docs/auth.md b/backends/advanced/Docs/auth.md
index acbf8df4..7998750e 100644
--- a/backends/advanced/Docs/auth.md
+++ b/backends/advanced/Docs/auth.md
@@ -100,13 +100,13 @@ curl -X POST "http://localhost:8000/auth/jwt/login" \
 #### Token-based (Recommended)
 ```javascript
-const ws = new WebSocket('ws://localhost:8000/ws_pcm?token=JWT_TOKEN&device_name=phone');
+const ws = new WebSocket('ws://localhost:8000/ws?codec=pcm&token=JWT_TOKEN&device_name=phone');
 ```

 #### Cookie-based
 ```javascript
 // Requires existing cookie from web login
-const ws = new WebSocket('ws://localhost:8000/ws_pcm?device_name=phone');
+const ws = new WebSocket('ws://localhost:8000/ws?codec=pcm&device_name=phone');
 ```

 ## Client ID Management
@@ -183,8 +183,8 @@ COOKIE_SECURE=false
 - `PATCH /api/users/me` - Update user profile

 ### WebSocket Endpoints
-- `ws://host/ws` - Opus audio stream with auth
-- `ws://host/ws_pcm` - PCM audio stream with auth
+- `ws://host/ws?codec=opus` - Opus audio stream with auth
+- `ws://host/ws?codec=pcm` - PCM audio stream with auth (default)

 ## Error
Handling diff --git a/backends/advanced/Docs/memory-configuration-guide.md b/backends/advanced/Docs/memory-configuration-guide.md index 12796e13..66244003 100644 --- a/backends/advanced/Docs/memory-configuration-guide.md +++ b/backends/advanced/Docs/memory-configuration-guide.md @@ -65,7 +65,7 @@ memory: - **Embeddings**: `text-embedding-3-small`, `text-embedding-3-large` #### Ollama Models (Local) -- **LLM**: `llama3`, `mistral`, `qwen2.5` +- **LLM**: `llama3`, `qwen2.5` - **Embeddings**: `nomic-embed-text`, `all-minilm` ## Hot Reload diff --git a/backends/advanced/README.md b/backends/advanced/README.md index 0f5a4490..7f3d5a24 100644 --- a/backends/advanced/README.md +++ b/backends/advanced/README.md @@ -31,7 +31,7 @@ Modern React-based web dashboard located in `./webui/` with: **The setup wizard guides you through:** - **Authentication**: Admin email/password setup with secure keys -- **Transcription Provider**: Choose between Deepgram, Mistral, or Offline (Parakeet) +- **Transcription Provider**: Choose between Deepgram or Offline (Parakeet) - **LLM Provider**: Choose between OpenAI (recommended) or Ollama for memory extraction - **Memory Provider**: Choose between Friend-Lite Native or OpenMemory MCP - **HTTPS Configuration**: Optional SSL setup for microphone access (uses Caddy) diff --git a/backends/advanced/docker-compose-test.yml b/backends/advanced/docker-compose-test.yml index e4203f91..e01a75f6 100644 --- a/backends/advanced/docker-compose-test.yml +++ b/backends/advanced/docker-compose-test.yml @@ -7,21 +7,24 @@ services: build: context: . dockerfile: Dockerfile + target: dev # Use dev stage with test dependencies + command: ["./start.sh", "--test"] ports: - "8001:8000" # Avoid conflict with dev on 8000 volumes: - ./src:/app/src # Mount source code for easier development - ./data/test_audio_chunks:/app/audio_chunks - - ./data/test_debug_dir:/app/debug_dir + - ./data/test_debug_dir:/app/debug # Fixed: mount to /app/debug for plugin database - ./data/test_data:/app/data - ${CONFIG_FILE:-../../config/config.yml}:/app/config.yml # Mount config.yml for model registry and memory settings (writable for admin config updates) + - ${PLUGINS_CONFIG:-../../tests/config/plugins.test.yml}:/app/plugins.yml # Mount test plugins config environment: # Override with test-specific settings - MONGODB_URI=mongodb://mongo-test:27017/test_db - QDRANT_BASE_URL=qdrant-test - QDRANT_PORT=6333 - REDIS_URL=redis://redis-test:6379/0 - - DEBUG_DIR=/app/debug_dir + - DEBUG_DIR=/app/debug # Fixed: match plugin database mount path # Import API keys from environment - DEEPGRAM_API_KEY=${DEEPGRAM_API_KEY} - OPENAI_API_KEY=${OPENAI_API_KEY} @@ -44,6 +47,9 @@ services: - CORS_ORIGINS=http://localhost:3001,http://localhost:8001,https://localhost:3001,https://localhost:8001 # Set low inactivity timeout for tests (2 seconds instead of 60) - SPEECH_INACTIVITY_THRESHOLD_SECONDS=2 + # Set low speech detection thresholds for tests + - SPEECH_DETECTION_MIN_DURATION=2.0 # 2 seconds instead of 10 + - SPEECH_DETECTION_MIN_WORDS=5 # 5 words instead of 10 # Wait for audio queue to drain before timing out (test mode) - WAIT_FOR_AUDIO_QUEUE_DRAIN=true depends_on: @@ -53,8 +59,6 @@ services: condition: service_healthy redis-test: condition: service_started - speaker-service-test: - condition: service_healthy healthcheck: test: ["CMD", "curl", "-f", "http://localhost:8000/readiness"] interval: 10s @@ -154,20 +158,23 @@ services: build: context: . 
dockerfile: Dockerfile - command: ./start-workers.sh + target: dev # Use dev stage with test dependencies + command: ["uv", "run", "--group", "test", "python", "worker_orchestrator.py"] volumes: - ./src:/app/src + - ./worker_orchestrator.py:/app/worker_orchestrator.py - ./data/test_audio_chunks:/app/audio_chunks - - ./data/test_debug_dir:/app/debug_dir + - ./data/test_debug_dir:/app/debug # Fixed: mount to /app/debug for plugin database - ./data/test_data:/app/data - ${CONFIG_FILE:-../../config/config.yml}:/app/config.yml # Mount config.yml for model registry and memory settings (writable for admin config updates) + - ${PLUGINS_CONFIG:-../../tests/config/plugins.test.yml}:/app/plugins.yml # Mount test plugins config environment: # Same environment as backend - MONGODB_URI=mongodb://mongo-test:27017/test_db - QDRANT_BASE_URL=qdrant-test - QDRANT_PORT=6333 - REDIS_URL=redis://redis-test:6379/0 - - DEBUG_DIR=/app/debug_dir + - DEBUG_DIR=/app/debug # Fixed: match plugin database mount path - DEEPGRAM_API_KEY=${DEEPGRAM_API_KEY} - OPENAI_API_KEY=${OPENAI_API_KEY} - GROQ_API_KEY=${GROQ_API_KEY} @@ -185,6 +192,9 @@ services: - SPEAKER_SERVICE_URL=http://speaker-service-test:8085 # Set low inactivity timeout for tests (2 seconds instead of 60) - SPEECH_INACTIVITY_THRESHOLD_SECONDS=2 + # Set low speech detection thresholds for tests + - SPEECH_DETECTION_MIN_DURATION=2.0 # 2 seconds instead of 10 + - SPEECH_DETECTION_MIN_WORDS=5 # 5 words instead of 10 # Wait for audio queue to drain before timing out (test mode) - WAIT_FOR_AUDIO_QUEUE_DRAIN=true depends_on: @@ -196,8 +206,6 @@ services: condition: service_started qdrant-test: condition: service_started - speaker-service-test: - condition: service_healthy restart: unless-stopped # Mycelia - AI memory and timeline service (test environment) diff --git a/backends/advanced/docker-compose.yml b/backends/advanced/docker-compose.yml index f46a23fa..ceaaf6a8 100644 --- a/backends/advanced/docker-compose.yml +++ b/backends/advanced/docker-compose.yml @@ -1,8 +1,35 @@ services: + tailscale: + image: tailscale/tailscale:latest + container_name: advanced-tailscale + hostname: chronicle-tailscale + environment: + - TS_AUTHKEY=${TS_AUTHKEY} + - TS_STATE_DIR=/var/lib/tailscale + - TS_USERSPACE=false + - TS_ACCEPT_DNS=true + volumes: + - tailscale-state:/var/lib/tailscale + devices: + - /dev/net/tun:/dev/net/tun + cap_add: + - NET_ADMIN + restart: unless-stopped + profiles: + - tailscale # Optional profile + ports: + - "18123:18123" # HA proxy port + command: > + sh -c "tailscaled & + tailscale up --authkey=$${TS_AUTHKEY} --accept-dns=true && + apk add --no-cache socat 2>/dev/null || true && + socat TCP-LISTEN:18123,fork,reuseaddr TCP:100.99.62.5:8123" + chronicle-backend: build: context: . 
dockerfile: Dockerfile + target: prod # Use prod stage without test dependencies ports: - "8000:8000" env_file: @@ -12,7 +39,8 @@ services: - ./data/audio_chunks:/app/audio_chunks - ./data/debug_dir:/app/debug_dir - ./data:/app/data - - ../../config/config.yml:/app/config.yml # Removed :ro to allow UI config saving + - ../../config/config.yml:/app/config.yml # Main config file + - ../../config/plugins.yml:/app/plugins.yml # Plugin configuration environment: - DEEPGRAM_API_KEY=${DEEPGRAM_API_KEY} - PARAKEET_ASR_URL=${PARAKEET_ASR_URL} @@ -26,6 +54,7 @@ services: - NEO4J_HOST=${NEO4J_HOST} - NEO4J_USER=${NEO4J_USER} - NEO4J_PASSWORD=${NEO4J_PASSWORD} + - HA_TOKEN=${HA_TOKEN} - CORS_ORIGINS=http://localhost:3010,http://localhost:8000,http://192.168.1.153:3010,http://192.168.1.153:8000,https://localhost:3010,https://localhost:8000,https://100.105.225.45,https://localhost - REDIS_URL=redis://redis:6379/0 depends_on: @@ -35,6 +64,8 @@ services: condition: service_healthy redis: condition: service_healthy + extra_hosts: + - "host.docker.internal:host-gateway" # Access host's Tailscale network healthcheck: test: ["CMD", "curl", "-f", "http://localhost:8000/readiness"] interval: 30s @@ -46,27 +77,37 @@ services: # Unified Worker Container # No CUDA needed for chronicle-backend and workers, workers only orchestrate jobs and call external services # Runs all workers in a single container for efficiency: - # - 3 RQ workers (transcription, memory, default queues) - # - 1 Audio stream worker (Redis Streams consumer - must be single to maintain sequential chunks) + # - 6 RQ workers (transcription, memory, default queues) + # - 1 Audio persistence worker (audio queue) + # - 1+ Stream workers (conditional based on config.yml - Deepgram/Parakeet) + # Uses Python orchestrator for process management, health monitoring, and self-healing workers: build: context: . 
dockerfile: Dockerfile - command: ["./start-workers.sh"] + target: prod # Use prod stage without test dependencies + command: ["uv", "run", "python", "worker_orchestrator.py"] env_file: - .env volumes: - ./src:/app/src - - ./start-workers.sh:/app/start-workers.sh + - ./worker_orchestrator.py:/app/worker_orchestrator.py - ./data/audio_chunks:/app/audio_chunks - ./data:/app/data - - ../../config/config.yml:/app/config.yml # Removed :ro for consistency + - ../../config/config.yml:/app/config.yml + - ../../config/plugins.yml:/app/plugins.yml environment: - DEEPGRAM_API_KEY=${DEEPGRAM_API_KEY} - PARAKEET_ASR_URL=${PARAKEET_ASR_URL} - OPENAI_API_KEY=${OPENAI_API_KEY} - GROQ_API_KEY=${GROQ_API_KEY} + - HA_TOKEN=${HA_TOKEN} - REDIS_URL=redis://redis:6379/0 + # Worker orchestrator configuration (optional - defaults shown) + - WORKER_CHECK_INTERVAL=${WORKER_CHECK_INTERVAL:-10} + - MIN_RQ_WORKERS=${MIN_RQ_WORKERS:-6} + - WORKER_STARTUP_GRACE_PERIOD=${WORKER_STARTUP_GRACE_PERIOD:-30} + - WORKER_SHUTDOWN_TIMEOUT=${WORKER_SHUTDOWN_TIMEOUT:-30} depends_on: redis: condition: service_healthy @@ -226,3 +267,5 @@ volumes: driver: local neo4j_logs: driver: local + tailscale-state: + driver: local diff --git a/backends/advanced/init.py b/backends/advanced/init.py index dddbfdcb..e566cc72 100644 --- a/backends/advanced/init.py +++ b/backends/advanced/init.py @@ -49,6 +49,9 @@ def __init__(self, args=None): self.console.print("[red][ERROR][/red] Run wizard.py from project root to create config.yml") sys.exit(1) + # Ensure plugins.yml exists (copy from template if missing) + self._ensure_plugins_yml_exists() + def print_header(self, title: str): """Print a colorful header""" self.console.print() @@ -107,6 +110,26 @@ def prompt_choice(self, prompt: str, choices: Dict[str, str], default: str = "1" self.console.print(f"Using default choice: {default}") return default + def _ensure_plugins_yml_exists(self): + """Ensure plugins.yml exists by copying from template if missing.""" + plugins_yml = Path("../../config/plugins.yml") + plugins_template = Path("../../config/plugins.yml.template") + + if not plugins_yml.exists(): + if plugins_template.exists(): + self.console.print("[blue][INFO][/blue] plugins.yml not found, creating from template...") + shutil.copy2(plugins_template, plugins_yml) + self.console.print(f"[green]✅[/green] Created {plugins_yml} from template") + self.console.print("[yellow][NOTE][/yellow] Edit config/plugins.yml to configure plugins") + self.console.print("[yellow][NOTE][/yellow] Set HA_TOKEN in .env for Home Assistant integration") + else: + raise RuntimeError( + f"Template file not found: {plugins_template}\n" + f"The repository structure is incomplete. Please ensure config/plugins.yml.template exists." + ) + else: + self.console.print(f"[blue][INFO][/blue] Found existing {plugins_yml}") + def backup_existing_env(self): """Backup existing .env file""" env_path = Path(".env") @@ -136,6 +159,41 @@ def mask_api_key(self, key: str, show_chars: int = 5) -> str: return f"{key_clean[:show_chars]}{'*' * min(15, len(key_clean) - show_chars * 2)}{key_clean[-show_chars:]}" + def prompt_with_existing_masked(self, prompt_text: str, env_key: str, placeholders: list, + is_password: bool = False, default: str = "") -> str: + """ + Prompt for a value, showing masked existing value from .env if present. 
+ + Args: + prompt_text: The prompt to display + env_key: The .env key to check for existing value + placeholders: List of placeholder values to treat as "not set" + is_password: Whether to mask the value (for passwords/tokens) + default: Default value if no existing value + + Returns: + User input value, existing value if reused, or default + """ + existing_value = self.read_existing_env_value(env_key) + + # Check if existing value is valid (not empty and not a placeholder) + has_valid_existing = existing_value and existing_value not in placeholders + + if has_valid_existing: + # Show masked value with option to reuse + if is_password: + masked = self.mask_api_key(existing_value) + display_prompt = f"{prompt_text} ({masked}) [press Enter to reuse, or enter new]" + else: + display_prompt = f"{prompt_text} ({existing_value}) [press Enter to reuse, or enter new]" + + user_input = self.prompt_value(display_prompt, "") + # If user pressed Enter (empty input), reuse existing value + return user_input if user_input else existing_value + else: + # No existing value, prompt normally + return self.prompt_value(prompt_text, default) + def setup_authentication(self): """Configure authentication settings""" @@ -192,15 +250,14 @@ def setup_transcription(self): self.console.print("[blue][INFO][/blue] Deepgram selected") self.console.print("Get your API key from: https://console.deepgram.com/") - # Check for existing API key - existing_key = self.read_existing_env_value("DEEPGRAM_API_KEY") - if existing_key and existing_key not in ['your_deepgram_api_key_here', 'your-deepgram-key-here']: - masked_key = self.mask_api_key(existing_key) - prompt_text = f"Deepgram API key ({masked_key}) [press Enter to reuse, or enter new]" - api_key_input = self.prompt_value(prompt_text, "") - api_key = api_key_input if api_key_input else existing_key - else: - api_key = self.prompt_value("Deepgram API key (leave empty to skip)", "") + # Use the new masked prompt function + api_key = self.prompt_with_existing_masked( + prompt_text="Deepgram API key (leave empty to skip)", + env_key="DEEPGRAM_API_KEY", + placeholders=['your_deepgram_api_key_here', 'your-deepgram-key-here'], + is_password=True, + default="" + ) if api_key: # Write API key to .env @@ -250,15 +307,14 @@ def setup_llm(self): self.console.print("[blue][INFO][/blue] OpenAI selected") self.console.print("Get your API key from: https://platform.openai.com/api-keys") - # Check for existing API key - existing_key = self.read_existing_env_value("OPENAI_API_KEY") - if existing_key and existing_key not in ['your_openai_api_key_here', 'your-openai-key-here']: - masked_key = self.mask_api_key(existing_key) - prompt_text = f"OpenAI API key ({masked_key}) [press Enter to reuse, or enter new]" - api_key_input = self.prompt_value(prompt_text, "") - api_key = api_key_input if api_key_input else existing_key - else: - api_key = self.prompt_value("OpenAI API key (leave empty to skip)", "") + # Use the new masked prompt function + api_key = self.prompt_with_existing_masked( + prompt_text="OpenAI API key (leave empty to skip)", + env_key="OPENAI_API_KEY", + placeholders=['your_openai_api_key_here', 'your-openai-key-here'], + is_password=True, + default="" + ) if api_key: self.config["OPENAI_API_KEY"] = api_key @@ -370,6 +426,11 @@ def setup_optional_services(self): self.config["PARAKEET_ASR_URL"] = self.args.parakeet_asr_url self.console.print(f"[green][SUCCESS][/green] Parakeet ASR configured via args: {self.args.parakeet_asr_url}") + # Check if Tailscale auth key provided via 
args + if hasattr(self.args, 'ts_authkey') and self.args.ts_authkey: + self.config["TS_AUTHKEY"] = self.args.ts_authkey + self.console.print(f"[green][SUCCESS][/green] Tailscale auth key configured (Docker integration enabled)") + def setup_obsidian(self): """Configure Obsidian/Neo4j integration""" # Check if enabled via command line @@ -443,14 +504,14 @@ def setup_https(self): self.console.print("[blue][INFO][/blue] For distributed deployments, use your Tailscale IP (e.g., 100.64.1.2)") self.console.print("[blue][INFO][/blue] For local-only access, use 'localhost'") - # Check for existing SERVER_IP - existing_ip = self.read_existing_env_value("SERVER_IP") - if existing_ip and existing_ip not in ['localhost', 'your-server-ip-here']: - prompt_text = f"Server IP/Domain for SSL certificate ({existing_ip}) [press Enter to reuse, or enter new]" - server_ip_input = self.prompt_value(prompt_text, "") - server_ip = server_ip_input if server_ip_input else existing_ip - else: - server_ip = self.prompt_value("Server IP/Domain for SSL certificate (Tailscale IP or localhost)", "localhost") + # Use the new masked prompt function (not masked for IP, but shows existing) + server_ip = self.prompt_with_existing_masked( + prompt_text="Server IP/Domain for SSL certificate (Tailscale IP or localhost)", + env_key="SERVER_IP", + placeholders=['localhost', 'your-server-ip-here'], + is_password=False, + default="localhost" + ) if enable_https: @@ -707,6 +768,8 @@ def main(): help="Enable Obsidian/Neo4j integration (default: prompt user)") parser.add_argument("--neo4j-password", help="Neo4j password (default: prompt user)") + parser.add_argument("--ts-authkey", + help="Tailscale auth key for Docker integration (default: prompt user)") args = parser.parse_args() diff --git a/backends/advanced/pyproject.toml b/backends/advanced/pyproject.toml index e7bcb50a..aa26a9b2 100644 --- a/backends/advanced/pyproject.toml +++ b/backends/advanced/pyproject.toml @@ -114,4 +114,5 @@ test = [ "requests-mock>=1.12.1", "pytest-json-report>=1.5.0", "pytest-html>=4.0.0", + "aiosqlite>=0.20.0", # For test plugin event storage ] diff --git a/backends/advanced/run-test.sh b/backends/advanced/run-test.sh index 01204be6..c68a30ea 100755 --- a/backends/advanced/run-test.sh +++ b/backends/advanced/run-test.sh @@ -91,6 +91,29 @@ if [ -n "$_CONFIG_FILE_OVERRIDE" ]; then print_info "Using command-line override: CONFIG_FILE=$CONFIG_FILE" fi +# Load HF_TOKEN from speaker-recognition/.env (proper location for this credential) +SPEAKER_ENV="../../extras/speaker-recognition/.env" +if [ -f "$SPEAKER_ENV" ] && [ -z "$HF_TOKEN" ]; then + print_info "Loading HF_TOKEN from speaker-recognition service..." + set -a + source "$SPEAKER_ENV" + set +a +fi + +# Display HF_TOKEN status with masking +if [ -n "$HF_TOKEN" ]; then + if [ ${#HF_TOKEN} -gt 15 ]; then + MASKED_TOKEN="${HF_TOKEN:0:5}***************${HF_TOKEN: -5}" + else + MASKED_TOKEN="***************" + fi + print_info "HF_TOKEN configured: $MASKED_TOKEN" + export HF_TOKEN +else + print_warning "HF_TOKEN not found - speaker recognition tests may fail" + print_info "Configure via wizard: uv run --with-requirements ../../setup-requirements.txt python ../../wizard.py" +fi + # Set default CONFIG_FILE if not provided # This allows testing with different provider combinations # Usage: CONFIG_FILE=../../tests/configs/parakeet-ollama.yml ./run-test.sh @@ -166,6 +189,18 @@ if [ ! 
-f "diarization_config.json" ] && [ -f "diarization_config.json.template" print_success "diarization_config.json created" fi +# Ensure plugins.yml exists (required for Docker volume mount) +if [ ! -f "../../config/plugins.yml" ]; then + if [ -f "../../config/plugins.yml.template" ]; then + print_info "Creating config/plugins.yml from template..." + cp ../../config/plugins.yml.template ../../config/plugins.yml + print_success "config/plugins.yml created" + else + print_error "config/plugins.yml.template not found - repository structure incomplete" + exit 1 + fi +fi + # Note: Robot Framework dependencies are managed via tests/test-requirements.txt # The integration tests use Docker containers for service dependencies @@ -176,15 +211,25 @@ print_info "Using environment variables from .env file for test configuration" # Clean test environment print_info "Cleaning test environment..." -sudo rm -rf ./test_audio_chunks/ ./test_data/ ./test_debug_dir/ ./mongo_data_test/ ./qdrant_data_test/ ./test_neo4j/ || true +rm -rf ./test_audio_chunks/ ./test_data/ ./test_debug_dir/ ./mongo_data_test/ ./qdrant_data_test/ ./test_neo4j/ 2>/dev/null || true + +# If cleanup fails due to permissions, try with docker +if [ -d "./data/test_audio_chunks/" ] || [ -d "./data/test_data/" ] || [ -d "./data/test_debug_dir/" ]; then + print_warning "Permission denied, using docker to clean test directories..." + docker run --rm -v "$(pwd)/data:/data" alpine sh -c 'rm -rf /data/test_*' 2>/dev/null || true +fi # Use unique project name to avoid conflicts with development environment export COMPOSE_PROJECT_NAME="advanced-backend-test" # Stop any existing test containers print_info "Stopping existing test containers..." +# Try cleanup with current project name docker compose -f docker-compose-test.yml down -v || true +# Also try cleanup with default project name (in case containers were started without COMPOSE_PROJECT_NAME) +COMPOSE_PROJECT_NAME=advanced docker compose -f docker-compose-test.yml down -v 2>/dev/null || true + # Run integration tests print_info "Running integration tests..." print_info "Using fresh mode (CACHED_MODE=False) for clean testing" @@ -211,8 +256,9 @@ export TEST_MODE=dev # Run the Robot Framework integration tests with extended timeout (mem0 needs time for comprehensive extraction) # IMPORTANT: Robot tests must be run from the repository root where backends/ and tests/ are siblings +# Run full test suite from tests/integration/ directory (includes all test files) print_info "Starting Robot Framework integration tests (timeout: 15 minutes)..." -if (cd ../.. && timeout 900 robot --outputdir test-results --loglevel INFO tests/integration/integration_test.robot); then +if (cd ../.. && timeout 900 uv run --with-requirements tests/test-requirements.txt robot --outputdir test-results --loglevel INFO tests/integration/); then print_success "Integration tests completed successfully!" else TEST_EXIT_CODE=$? @@ -222,6 +268,8 @@ else if [ "${CLEANUP_CONTAINERS:-true}" != "false" ]; then print_info "Cleaning up test containers after failure..." docker compose -f docker-compose-test.yml down -v || true + # Also cleanup with default project name + COMPOSE_PROJECT_NAME=advanced docker compose -f docker-compose-test.yml down -v 2>/dev/null || true docker system prune -f || true else print_warning "Skipping cleanup (CLEANUP_CONTAINERS=false) - containers left running for debugging" @@ -234,6 +282,8 @@ fi if [ "${CLEANUP_CONTAINERS:-true}" != "false" ]; then print_info "Cleaning up test containers..." 
docker compose -f docker-compose-test.yml down -v || true + # Also cleanup with default project name + COMPOSE_PROJECT_NAME=advanced docker compose -f docker-compose-test.yml down -v 2>/dev/null || true docker system prune -f || true else print_warning "Skipping cleanup (CLEANUP_CONTAINERS=false) - containers left running" diff --git a/backends/advanced/scripts/laptop_client.py b/backends/advanced/scripts/laptop_client.py index 385a4a1b..a0047f3b 100644 --- a/backends/advanced/scripts/laptop_client.py +++ b/backends/advanced/scripts/laptop_client.py @@ -15,7 +15,7 @@ # Default WebSocket settings DEFAULT_HOST = "localhost" DEFAULT_PORT = 8000 -DEFAULT_ENDPOINT = "/ws_pcm" +DEFAULT_ENDPOINT = "/ws?codec=pcm" # Audio format will be determined from the InputMicStream instance diff --git a/backends/advanced/src/advanced_omi_backend/app_config.py b/backends/advanced/src/advanced_omi_backend/app_config.py index 1e24fb54..15e825ec 100644 --- a/backends/advanced/src/advanced_omi_backend/app_config.py +++ b/backends/advanced/src/advanced_omi_backend/app_config.py @@ -47,11 +47,6 @@ def __init__(self): os.getenv("NEW_CONVERSATION_TIMEOUT_MINUTES", "1.5") ) - # Audio cropping configuration - self.audio_cropping_enabled = os.getenv("AUDIO_CROPPING_ENABLED", "true").lower() == "true" - self.min_speech_segment_duration = float(os.getenv("MIN_SPEECH_SEGMENT_DURATION", "1.0")) - self.cropping_context_padding = float(os.getenv("CROPPING_CONTEXT_PADDING", "0.1")) - # Transcription Configuration (registry-based) self.transcription_provider = get_transcription_provider(None) if self.transcription_provider: diff --git a/backends/advanced/src/advanced_omi_backend/app_factory.py b/backends/advanced/src/advanced_omi_backend/app_factory.py index 7ccda184..8a162cec 100644 --- a/backends/advanced/src/advanced_omi_backend/app_factory.py +++ b/backends/advanced/src/advanced_omi_backend/app_factory.py @@ -111,6 +111,11 @@ async def lifespan(app: FastAPI): from advanced_omi_backend.services.audio_stream import AudioStreamProducer app.state.audio_stream_producer = AudioStreamProducer(app.state.redis_audio_stream) application_logger.info("✅ Redis client for audio streaming producer initialized") + + # Initialize ClientManager Redis for cross-container client→user mapping + from advanced_omi_backend.client_manager import initialize_redis_for_client_manager + initialize_redis_for_client_manager(config.redis_url) + except Exception as e: application_logger.error(f"Failed to initialize Redis client for audio streaming: {e}", exc_info=True) application_logger.warning("Audio streaming producer will not be available") @@ -122,6 +127,36 @@ async def lifespan(app: FastAPI): # SystemTracker is used for monitoring and debugging application_logger.info("Using SystemTracker for monitoring and debugging") + # Initialize plugins using plugin service + try: + from advanced_omi_backend.services.plugin_service import init_plugin_router, set_plugin_router + + plugin_router = init_plugin_router() + + if plugin_router: + # Initialize async resources for each enabled plugin + for plugin_id, plugin in plugin_router.plugins.items(): + if plugin.enabled: + try: + await plugin.initialize() + application_logger.info(f"✅ Plugin '{plugin_id}' initialized") + except Exception as e: + application_logger.error(f"Failed to initialize plugin '{plugin_id}': {e}", exc_info=True) + + application_logger.info(f"Plugins initialized: {len(plugin_router.plugins)} active") + + # Store in app state for API access + app.state.plugin_router = plugin_router + # 
Register with plugin service for worker access + set_plugin_router(plugin_router) + else: + application_logger.info("No plugins configured") + app.state.plugin_router = None + + except Exception as e: + application_logger.error(f"Failed to initialize plugin system: {e}", exc_info=True) + app.state.plugin_router = None + application_logger.info("Application ready - using application-level processing architecture.") logger.info("App ready") @@ -162,6 +197,14 @@ async def lifespan(app: FastAPI): # Stop metrics collection and save final report application_logger.info("Metrics collection stopped") + # Shutdown plugins + try: + from advanced_omi_backend.services.plugin_service import cleanup_plugin_router + await cleanup_plugin_router() + application_logger.info("Plugins shut down") + except Exception as e: + application_logger.error(f"Error shutting down plugins: {e}") + # Shutdown memory service and speaker service shutdown_memory_service() application_logger.info("Memory and speaker services shut down.") diff --git a/backends/advanced/src/advanced_omi_backend/client_manager.py b/backends/advanced/src/advanced_omi_backend/client_manager.py index 5a3131b5..e55b3502 100644 --- a/backends/advanced/src/advanced_omi_backend/client_manager.py +++ b/backends/advanced/src/advanced_omi_backend/client_manager.py @@ -9,6 +9,7 @@ import logging import uuid from typing import TYPE_CHECKING, Dict, Optional +import redis.asyncio as redis if TYPE_CHECKING: from advanced_omi_backend.client import ClientState @@ -21,6 +22,9 @@ _client_to_user_mapping: Dict[str, str] = {} # Active clients only _all_client_user_mappings: Dict[str, str] = {} # All clients including disconnected +# Redis client for cross-container client→user mapping +_redis_client: Optional[redis.Redis] = None + class ClientManager: """ @@ -372,9 +376,33 @@ def unregister_client_user_mapping(client_id: str): logger.warning(f"⚠️ Attempted to unregister non-existent client {client_id}") +async def track_client_user_relationship_async(client_id: str, user_id: str, ttl: int = 86400): + """ + Track that a client belongs to a user (async, writes to Redis for cross-container support). + + Args: + client_id: The client ID + user_id: The user ID that owns this client + ttl: Time-to-live in seconds (default 24 hours) + """ + _all_client_user_mappings[client_id] = user_id # In-memory fallback + + if _redis_client: + try: + await _redis_client.setex(f"client:owner:{client_id}", ttl, user_id) + logger.debug(f"✅ Tracked client {client_id} → user {user_id} in Redis (TTL: {ttl}s)") + except Exception as e: + logger.warning(f"Failed to track client in Redis: {e}") + else: + logger.debug(f"Tracked client {client_id} relationship to user {user_id} (in-memory only)") + + def track_client_user_relationship(client_id: str, user_id: str): """ - Track that a client belongs to a user (persists after disconnection for database queries). + Track that a client belongs to a user (sync version for backward compatibility). + + WARNING: This is synchronous and cannot use Redis. Use track_client_user_relationship_async() + instead in async contexts for cross-container support. Args: client_id: The client ID @@ -444,9 +472,45 @@ def get_user_clients_active(user_id: str) -> list[str]: return user_clients +def initialize_redis_for_client_manager(redis_url: str): + """ + Initialize Redis client for cross-container client→user mapping. 
+ + Args: + redis_url: Redis connection URL + """ + global _redis_client + _redis_client = redis.from_url(redis_url, decode_responses=True) + logger.info(f"✅ ClientManager Redis initialized: {redis_url}") + + +async def get_client_owner_async(client_id: str) -> Optional[str]: + """ + Get the user ID that owns a specific client (async Redis lookup). + + Args: + client_id: The client ID to look up + + Returns: + User ID if found, None otherwise + """ + if _redis_client: + try: + user_id = await _redis_client.get(f"client:owner:{client_id}") + return user_id + except Exception as e: + logger.warning(f"Redis lookup failed for client {client_id}: {e}") + + # Fallback to in-memory mapping + return _all_client_user_mappings.get(client_id) + + def get_client_owner(client_id: str) -> Optional[str]: """ - Get the user ID that owns a specific client. + Get the user ID that owns a specific client (sync version for backward compatibility). + + WARNING: This is synchronous and cannot use Redis. Use get_client_owner_async() instead + in async contexts for cross-container support. Args: client_id: The client ID to look up diff --git a/backends/advanced/src/advanced_omi_backend/clients/audio_stream_client.py b/backends/advanced/src/advanced_omi_backend/clients/audio_stream_client.py index af89fd51..edddd914 100644 --- a/backends/advanced/src/advanced_omi_backend/clients/audio_stream_client.py +++ b/backends/advanced/src/advanced_omi_backend/clients/audio_stream_client.py @@ -65,7 +65,7 @@ def __init__( base_url: str, token: str, device_name: str = "python-client", - endpoint: str = "ws_pcm", + endpoint: str = "ws?codec=pcm", ): """Initialize the audio stream client. @@ -73,7 +73,7 @@ def __init__( base_url: Base URL of the backend (e.g., "http://localhost:8000") token: JWT authentication token device_name: Device name for client identification - endpoint: WebSocket endpoint ("ws_pcm" or "ws_omi") + endpoint: WebSocket endpoint ("ws?codec=pcm" or "ws?codec=opus") """ self.base_url = base_url self.token = token @@ -87,7 +87,9 @@ def __init__( def ws_url(self) -> str: """Build WebSocket URL from base URL.""" url = self.base_url.replace("http://", "ws://").replace("https://", "wss://") - return f"{url}/{self.endpoint}?token={self.token}&device_name={self.device_name}" + # Check if endpoint already has query params + separator = "&" if "?" in self.endpoint else "?" + return f"{url}/{self.endpoint}{separator}token={self.token}&device_name={self.device_name}" async def connect(self, wait_for_ready: bool = True) -> WebSocketClientProtocol: """Connect to the WebSocket endpoint. 
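For clarity, here is a minimal standalone sketch of the query-separator rule that the `ws_url` property above implements; `build_ws_url` is an illustrative helper, not part of the client class:

```python
# Illustrative sketch of the ws_url separator rule: auth parameters are
# appended with "&" when the endpoint already carries a query string
# such as "ws?codec=pcm", and with "?" otherwise.
def build_ws_url(base_url: str, endpoint: str, token: str, device_name: str) -> str:
    url = base_url.replace("http://", "ws://").replace("https://", "wss://")
    separator = "&" if "?" in endpoint else "?"
    return f"{url}/{endpoint}{separator}token={token}&device_name={device_name}"

print(build_ws_url("http://localhost:8000", "ws?codec=pcm", "JWT_TOKEN", "phone"))
# -> ws://localhost:8000/ws?codec=pcm&token=JWT_TOKEN&device_name=phone
```

This is also why the auth.md examples above use `&` before `token=` and `device_name=`: the `?` is already consumed by `codec=pcm`.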
@@ -105,8 +107,8 @@ async def connect(self, wait_for_ready: bool = True) -> WebSocketClientProtocol: self.ws = await websockets.connect(self.ws_url) logger.info("WebSocket connected") - if wait_for_ready and self.endpoint == "ws_pcm": - # PCM endpoint sends "ready" message after auth (line 261-268 in websocket_controller.py) + if wait_for_ready and "codec=pcm" in self.endpoint: + # PCM codec sends "ready" message after auth (line 261-268 in websocket_controller.py) ready_msg = await self.ws.recv() ready = json.loads(ready_msg.strip() if isinstance(ready_msg, str) else ready_msg.decode().strip()) if ready.get("type") != "ready": diff --git a/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py index 4810810d..e63dd883 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py @@ -219,14 +219,13 @@ async def upload_and_process_audio_files( ) -async def get_conversation_audio_path(conversation_id: str, user: User, cropped: bool = False) -> Path: +async def get_conversation_audio_path(conversation_id: str, user: User) -> Path: """ Get the file path for a conversation's audio file. Args: conversation_id: The conversation ID user: The authenticated user - cropped: If True, return cropped audio path; if False, return original audio path Returns: Path object for the audio file @@ -244,12 +243,11 @@ async def get_conversation_audio_path(conversation_id: str, user: User, cropped: if not user.is_superuser and conversation.user_id != str(user.user_id): raise ValueError("Access denied") - # Get the appropriate audio path - audio_path = conversation.cropped_audio_path if cropped else conversation.audio_path + # Get the audio path + audio_path = conversation.audio_path if not audio_path: - audio_type = "cropped" if cropped else "original" - raise ValueError(f"No {audio_type} audio file available for this conversation") + raise ValueError(f"No audio file available for this conversation") # Build full file path from advanced_omi_backend.app_config import get_audio_chunk_dir @@ -261,39 +259,3 @@ async def get_conversation_audio_path(conversation_id: str, user: User, cropped: raise ValueError("Audio file not found on disk") return file_path - - -async def get_cropped_audio_info(audio_uuid: str, user: User): - """ - Get audio cropping metadata from the conversation. - - This is an audio service operation that retrieves cropping-related metadata - such as speech segments, cropped audio path, and cropping timestamps. - - Used for: Checking cropping status and retrieving audio processing details. - Works with: Conversation model. 
- """ - try: - # Find the conversation - conversation = await Conversation.find_one(Conversation.audio_uuid == audio_uuid) - if not conversation: - return JSONResponse(status_code=404, content={"error": "Conversation not found"}) - - # Check ownership for non-admin users - if not user.is_superuser: - if conversation.user_id != str(user.user_id): - return JSONResponse(status_code=404, content={"error": "Conversation not found"}) - - return { - "audio_uuid": audio_uuid, - "cropped_audio_path": conversation.cropped_audio_path, - "speech_segments": conversation.speech_segments if hasattr(conversation, 'speech_segments') else [], - "cropped_duration": conversation.cropped_duration if hasattr(conversation, 'cropped_duration') else None, - "cropped_at": conversation.cropped_at if hasattr(conversation, 'cropped_at') else None, - "original_audio_path": conversation.audio_path, - } - - except Exception as e: - # Database or unexpected errors when fetching audio metadata - audio_logger.exception("Error fetching cropped audio info") - return JSONResponse(status_code=500, content={"error": "Error fetching cropped audio info"}) diff --git a/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py index b9533391..943d86bd 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py @@ -103,7 +103,6 @@ async def get_conversation(conversation_id: str, user: User): "user_id": conversation.user_id, "client_id": conversation.client_id, "audio_path": conversation.audio_path, - "cropped_audio_path": conversation.cropped_audio_path, "created_at": conversation.created_at.isoformat() if conversation.created_at else None, "deleted": conversation.deleted, "deletion_reason": conversation.deletion_reason, @@ -154,7 +153,6 @@ async def get_conversations(user: User): "user_id": conv.user_id, "client_id": conv.client_id, "audio_path": conv.audio_path, - "cropped_audio_path": conv.cropped_audio_path, "created_at": conv.created_at.isoformat() if conv.created_at else None, "deleted": conv.deleted, "deletion_reason": conv.deletion_reason, @@ -210,7 +208,6 @@ async def delete_conversation(conversation_id: str, user: User): # Get file paths before deletion audio_path = conversation.audio_path - cropped_audio_path = conversation.cropped_audio_path audio_uuid = conversation.audio_uuid client_id = conversation.client_id @@ -237,17 +234,6 @@ async def delete_conversation(conversation_id: str, user: User): except Exception as e: logger.warning(f"Failed to delete audio file {audio_path}: {e}") - if cropped_audio_path: - try: - # Construct full path to cropped audio file - full_cropped_path = Path("/app/audio_chunks") / cropped_audio_path - if full_cropped_path.exists(): - full_cropped_path.unlink() - deleted_files.append(str(full_cropped_path)) - logger.info(f"Deleted cropped audio file: {full_cropped_path}") - except Exception as e: - logger.warning(f"Failed to delete cropped audio file {cropped_audio_path}: {e}") - logger.info(f"Successfully deleted conversation {conversation_id} for user {user.user_id}") # Prepare response message @@ -321,10 +307,9 @@ async def reprocess_transcript(conversation_id: str, user: User): import uuid version_id = str(uuid.uuid4()) - # Enqueue job chain with RQ (transcription -> speaker recognition -> cropping -> memory) + # Enqueue job chain with RQ (transcription -> speaker 
recognition -> memory) from advanced_omi_backend.workers.transcription_jobs import transcribe_full_audio_job from advanced_omi_backend.workers.speaker_jobs import recognise_speakers_job - from advanced_omi_backend.workers.audio_jobs import process_cropping_job from advanced_omi_backend.workers.memory_jobs import process_memory_job from advanced_omi_backend.controllers.queue_controller import transcription_queue, memory_queue, default_queue, JOB_RESULT_TTL @@ -361,33 +346,19 @@ async def reprocess_transcript(conversation_id: str, user: User): ) logger.info(f"📥 RQ: Enqueued speaker recognition job {speaker_job.id} (depends on {transcript_job.id})") - # Job 3: Audio cropping (depends on speaker recognition) - cropping_job = default_queue.enqueue( - process_cropping_job, - conversation_id, - str(full_audio_path), - depends_on=speaker_job, - job_timeout=300, - result_ttl=JOB_RESULT_TTL, - job_id=f"crop_{conversation_id[:8]}", - description=f"Crop audio for {conversation_id[:8]}", - meta={'audio_uuid': audio_uuid, 'conversation_id': conversation_id} - ) - logger.info(f"📥 RQ: Enqueued audio cropping job {cropping_job.id} (depends on {speaker_job.id})") - - # Job 4: Extract memories (depends on cropping) + # Job 3: Extract memories (depends on speaker recognition) # Note: redis_client is injected by @async_job decorator, don't pass it directly memory_job = memory_queue.enqueue( process_memory_job, conversation_id, - depends_on=cropping_job, + depends_on=speaker_job, job_timeout=1800, result_ttl=JOB_RESULT_TTL, job_id=f"memory_{conversation_id[:8]}", description=f"Extract memories for {conversation_id[:8]}", meta={'audio_uuid': audio_uuid, 'conversation_id': conversation_id} ) - logger.info(f"📥 RQ: Enqueued memory job {memory_job.id} (depends on {cropping_job.id})") + logger.info(f"📥 RQ: Enqueued memory job {memory_job.id} (depends on {speaker_job.id})") job = transcript_job # For backward compatibility with return value logger.info(f"Created transcript reprocessing job {job.id} (version: {version_id}) for conversation {conversation_id}") diff --git a/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py index 91773756..cd4f7455 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py @@ -113,9 +113,12 @@ def get_jobs( Returns: Dict with jobs list and pagination metadata matching frontend expectations """ + logger.info(f"🔍 DEBUG get_jobs: Filtering - queue_name={queue_name}, job_type={job_type}, client_id={client_id}") all_jobs = [] + seen_job_ids = set() # Track which job IDs we've already processed to avoid duplicates queues_to_check = [queue_name] if queue_name else QUEUE_NAMES + logger.info(f"🔍 DEBUG get_jobs: Checking queues: {queues_to_check}") for qname in queues_to_check: queue = get_queue(qname) @@ -131,6 +134,11 @@ def get_jobs( for job_ids, status in registries: for job_id in job_ids: + # Skip if we've already processed this job_id (prevents duplicates across registries) + if job_id in seen_job_ids: + continue + seen_job_ids.add(job_id) + try: job = Job.fetch(job_id, connection=redis_conn) @@ -140,16 +148,23 @@ def get_jobs( # Extract just the function name (e.g., "listen_for_speech_job" from "module.listen_for_speech_job") func_name = job.func_name.split('.')[-1] if job.func_name else "unknown" + # Debug: Log job details before filtering + logger.debug(f"🔍 DEBUG get_jobs: Job 
{job_id} - func_name={func_name}, full_func_name={job.func_name}, meta_client_id={job.meta.get('client_id', '') if job.meta else ''}, status={status}") + # Apply job_type filter if job_type and job_type not in func_name: + logger.debug(f"🔍 DEBUG get_jobs: Filtered out {job_id} - job_type '{job_type}' not in func_name '{func_name}'") continue # Apply client_id filter (partial match in meta) if client_id: job_client_id = job.meta.get("client_id", "") if job.meta else "" if client_id not in job_client_id: + logger.debug(f"🔍 DEBUG get_jobs: Filtered out {job_id} - client_id '{client_id}' not in job_client_id '{job_client_id}'") continue + logger.debug(f"🔍 DEBUG get_jobs: Including job {job_id} in results") + all_jobs.append({ "job_id": job.id, "job_type": func_name, @@ -182,6 +197,8 @@ def get_jobs( paginated_jobs = all_jobs[offset:offset + limit] has_more = (offset + limit) < total_jobs + logger.info(f"🔍 DEBUG get_jobs: Found {total_jobs} matching jobs (returning {len(paginated_jobs)} after pagination)") + return { "jobs": paginated_jobs, "pagination": { @@ -290,12 +307,22 @@ def start_streaming_jobs( user_id, client_id, job_timeout=86400, # 24 hours for all-day sessions - result_ttl=JOB_RESULT_TTL, + ttl=None, # No pre-run expiry (job can wait indefinitely in queue) + result_ttl=JOB_RESULT_TTL, # Cleanup AFTER completion + failure_ttl=86400, # Cleanup failed jobs after 24h job_id=f"speech-detect_{session_id[:12]}", description=f"Listening for speech...", meta={'audio_uuid': session_id, 'client_id': client_id, 'session_level': True} ) + # Log job enqueue with TTL information for debugging + actual_ttl = redis_conn.ttl(f"rq:job:{speech_job.id}") logger.info(f"📥 RQ: Enqueued speech detection job {speech_job.id}") + logger.info( + f"🔍 Job enqueue details: ID={speech_job.id}, " + f"job_timeout={speech_job.timeout}, result_ttl={speech_job.result_ttl}, " + f"failure_ttl={speech_job.failure_ttl}, redis_key_ttl={actual_ttl}, " + f"queue_length={transcription_queue.count}, client_id={client_id}" + ) # Store job ID for cleanup (keyed by client_id for easy WebSocket cleanup) try: @@ -313,12 +340,22 @@ def start_streaming_jobs( user_id, client_id, job_timeout=86400, # 24 hours for all-day sessions - result_ttl=JOB_RESULT_TTL, + ttl=None, # No pre-run expiry (job can wait indefinitely in queue) + result_ttl=JOB_RESULT_TTL, # Cleanup AFTER completion + failure_ttl=86400, # Cleanup failed jobs after 24h job_id=f"audio-persist_{session_id[:12]}", description=f"Audio persistence for session {session_id[:12]}", meta={'audio_uuid': session_id, 'session_level': True} # Mark as session-level job ) + # Log job enqueue with TTL information for debugging + actual_ttl = redis_conn.ttl(f"rq:job:{audio_job.id}") logger.info(f"📥 RQ: Enqueued audio persistence job {audio_job.id} on audio queue") + logger.info( + f"🔍 Job enqueue details: ID={audio_job.id}, " + f"job_timeout={audio_job.timeout}, result_ttl={audio_job.result_ttl}, " + f"failure_ttl={audio_job.failure_ttl}, redis_key_ttl={actual_ttl}, " + f"queue_length={audio_queue.count}, client_id={client_id}" + ) return { 'speech_detection': speech_job.id, @@ -341,10 +378,9 @@ def start_post_conversation_jobs( This creates the standard processing chain after a conversation is created: 1. [Optional] Transcription job - Batch transcription (if post_transcription=True) - 2. Audio cropping job - Removes silence from audio - 3. Speaker recognition job - Identifies speakers in audio - 4. Memory extraction job - Extracts memories from conversation (parallel) - 5. 
Title/summary generation job - Generates title and summary (parallel) + 2. Speaker recognition job - Identifies speakers in audio + 3. Memory extraction job - Extracts memories from conversation (parallel) + 4. Title/summary generation job - Generates title and summary (parallel) Args: conversation_id: Conversation identifier @@ -354,16 +390,15 @@ def start_post_conversation_jobs( post_transcription: If True, run batch transcription step (for uploads) If False, skip transcription (streaming already has it) transcript_version_id: Transcript version ID (auto-generated if None) - depends_on_job: Optional job dependency for cropping job + depends_on_job: Optional job dependency for first job Returns: Dict with job IDs (transcription will be None if post_transcription=False) """ from advanced_omi_backend.workers.transcription_jobs import transcribe_full_audio_job from advanced_omi_backend.workers.speaker_jobs import recognise_speakers_job - from advanced_omi_backend.workers.audio_jobs import process_cropping_job from advanced_omi_backend.workers.memory_jobs import process_memory_job - from advanced_omi_backend.workers.conversation_jobs import generate_title_summary_job + from advanced_omi_backend.workers.conversation_jobs import generate_title_summary_job, dispatch_conversation_complete_event_job version_id = transcript_version_id or str(uuid.uuid4()) @@ -392,29 +427,11 @@ def start_post_conversation_jobs( meta=job_meta ) logger.info(f"📥 RQ: Enqueued transcription job {transcription_job.id}, meta={transcription_job.meta}") - crop_depends_on = transcription_job - # Step 2: Audio cropping job (depends on transcription if it ran, otherwise depends_on_job) - crop_job_id = f"crop_{conversation_id[:12]}" - logger.info(f"🔍 DEBUG: Creating crop job with job_id={crop_job_id}, conversation_id={conversation_id[:12]}, audio_uuid={audio_uuid[:12]}") + # Speaker recognition depends on transcription (no cropping step) + speaker_depends_on = transcription_job - cropping_job = default_queue.enqueue( - process_cropping_job, - conversation_id, - audio_file_path, - job_timeout=300, # 5 minutes - result_ttl=JOB_RESULT_TTL, - depends_on=crop_depends_on, - job_id=crop_job_id, - description=f"Crop audio for conversation {conversation_id[:8]}", - meta=job_meta - ) - logger.info(f"📥 RQ: Enqueued cropping job {cropping_job.id}, meta={cropping_job.meta}") - - # Speaker recognition depends on cropping - speaker_depends_on = cropping_job - - # Step 3: Speaker recognition job + # Step 2: Speaker recognition job speaker_job_id = f"speaker_{conversation_id[:12]}" logger.info(f"🔍 DEBUG: Creating speaker job with job_id={speaker_job_id}, conversation_id={conversation_id[:12]}, audio_uuid={audio_uuid[:12]}") @@ -434,7 +451,7 @@ def start_post_conversation_jobs( ) logger.info(f"📥 RQ: Enqueued speaker recognition job {speaker_job.id}, meta={speaker_job.meta} (depends on {speaker_depends_on.id})") - # Step 4: Memory extraction job (parallel with title/summary) + # Step 3: Memory extraction job (parallel with title/summary) memory_job_id = f"memory_{conversation_id[:12]}" logger.info(f"🔍 DEBUG: Creating memory job with job_id={memory_job_id}, conversation_id={conversation_id[:12]}, audio_uuid={audio_uuid[:12]}") @@ -450,7 +467,7 @@ def start_post_conversation_jobs( ) logger.info(f"📥 RQ: Enqueued memory extraction job {memory_job.id}, meta={memory_job.meta} (depends on {speaker_job.id})") - # Step 5: Title/summary generation job (parallel with memory, independent) + # Step 4: Title/summary generation job (parallel with memory, 
independent) # This ensures conversations always get titles/summaries even if memory job fails title_job_id = f"title_summary_{conversation_id[:12]}" logger.info(f"🔍 DEBUG: Creating title/summary job with job_id={title_job_id}, conversation_id={conversation_id[:12]}, audio_uuid={audio_uuid[:12]}") @@ -467,12 +484,34 @@ def start_post_conversation_jobs( ) logger.info(f"📥 RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (depends on {speaker_job.id})") + # Step 5: Dispatch conversation.complete event (runs after both memory and title/summary complete) + # This ensures plugins receive the event after all processing is done + event_job_id = f"event_complete_{conversation_id[:12]}" + logger.info(f"🔍 DEBUG: Creating conversation complete event job with job_id={event_job_id}, conversation_id={conversation_id[:12]}, audio_uuid={audio_uuid[:12]}") + + # Event job depends on both memory and title/summary jobs completing + # Use RQ's depends_on list to wait for both + event_dispatch_job = default_queue.enqueue( + dispatch_conversation_complete_event_job, + conversation_id, + audio_uuid, + client_id or "", + user_id, + job_timeout=120, # 2 minutes + result_ttl=JOB_RESULT_TTL, + depends_on=[memory_job, title_summary_job], # Wait for both parallel jobs + job_id=event_job_id, + description=f"Dispatch conversation complete event for {conversation_id[:8]}", + meta=job_meta + ) + logger.info(f"📥 RQ: Enqueued conversation complete event job {event_dispatch_job.id}, meta={event_dispatch_job.meta} (depends on {memory_job.id} and {title_summary_job.id})") + return { - 'cropping': cropping_job.id, 'transcription': transcription_job.id if transcription_job else None, 'speaker_recognition': speaker_job.id, 'memory': memory_job.id, - 'title_summary': title_summary_job.id + 'title_summary': title_summary_job.id, + 'event_dispatch': event_dispatch_job.id } diff --git a/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py index a3836898..d1a22695 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py @@ -9,13 +9,61 @@ import logging import time -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Literal from fastapi.responses import JSONResponse logger = logging.getLogger(__name__) +async def mark_session_complete( + redis_client, + session_id: str, + reason: Literal[ + "websocket_disconnect", + "user_stopped", + "inactivity_timeout", + "max_duration", + "all_jobs_complete" + ], +) -> None: + """ + Single source of truth for marking sessions as complete. + + This function ensures that both 'status' and 'completion_reason' are ALWAYS + set together atomically, preventing race conditions where workers check status + before completion_reason is set. 
+ + Args: + redis_client: Redis async client + session_id: Session UUID + reason: Why the session is completing (enforced by type system) + + Usage: + # WebSocket disconnect + await mark_session_complete(redis, session_id, "websocket_disconnect") + + # User manually stopped + await mark_session_complete(redis, session_id, "user_stopped") + + # Inactivity timeout + await mark_session_complete(redis, session_id, "inactivity_timeout") + + # Max duration reached + await mark_session_complete(redis, session_id, "max_duration") + + # All jobs finished + await mark_session_complete(redis, session_id, "all_jobs_complete") + """ + session_key = f"audio:session:{session_id}" + await redis_client.hset(session_key, mapping={ + "status": "complete", + "completed_at": str(time.time()), + "completion_reason": reason + }) + logger.info(f"✅ Session {session_id[:12]} marked complete: {reason}") + + async def get_session_info(redis_client, session_id: str) -> Optional[Dict]: """ Get detailed information about a specific session. @@ -192,8 +240,7 @@ async def get_streaming_status(request): # All jobs complete - this is truly a completed session # Update Redis status if it wasn't already marked complete if status not in ["complete", "completed", "finalized"]: - await redis_client.hset(key, "status", "complete") - logger.info(f"✅ Marked session {session_id} as complete (all jobs terminal)") + await mark_session_complete(redis_client, session_id, "all_jobs_complete") # Get additional session data for completed sessions session_key = f"audio:session:{session_id}" diff --git a/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py index aced763f..f5ff3275 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py @@ -7,6 +7,7 @@ import shutil import time from datetime import UTC, datetime +from pathlib import Path import yaml from fastapi import HTTPException @@ -555,3 +556,139 @@ async def validate_chat_config_yaml(prompt_text: str) -> dict: except Exception as e: logger.error(f"Error validating chat config: {e}") return {"valid": False, "error": f"Validation error: {str(e)}"} + + +# Plugin Configuration Management Functions + +async def get_plugins_config_yaml() -> str: + """Get plugins configuration as YAML text.""" + try: + plugins_yml_path = Path("/app/plugins.yml") + + # Default empty plugins config + default_config = """plugins: + # No plugins configured yet + # Example plugin configuration: + # homeassistant: + # enabled: true + # access_level: transcript + # trigger: + # type: wake_word + # wake_word: vivi + # ha_url: http://localhost:8123 + # ha_token: YOUR_TOKEN_HERE +""" + + if not plugins_yml_path.exists(): + return default_config + + with open(plugins_yml_path, 'r') as f: + yaml_content = f.read() + + return yaml_content + + except Exception as e: + logger.error(f"Error loading plugins config: {e}") + raise + + +async def save_plugins_config_yaml(yaml_content: str) -> dict: + """Save plugins configuration from YAML text.""" + try: + plugins_yml_path = Path("/app/plugins.yml") + + # Validate YAML can be parsed + try: + parsed_config = yaml.safe_load(yaml_content) + if not isinstance(parsed_config, dict): + raise ValueError("Configuration must be a YAML dictionary") + + # Validate has 'plugins' key + if 'plugins' not in parsed_config: + raise ValueError("Configuration must contain 'plugins' 
key") + + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML syntax: {e}") + + # Create config directory if it doesn't exist + plugins_yml_path.parent.mkdir(parents=True, exist_ok=True) + + # Backup existing config + if plugins_yml_path.exists(): + backup_path = str(plugins_yml_path) + '.backup' + shutil.copy2(plugins_yml_path, backup_path) + logger.info(f"Created plugins config backup at {backup_path}") + + # Save new config + with open(plugins_yml_path, 'w') as f: + f.write(yaml_content) + + # Hot-reload plugins (optional - may require restart) + try: + from advanced_omi_backend.services.plugin_service import get_plugin_router + plugin_router = get_plugin_router() + if plugin_router: + logger.info("Plugin configuration updated - restart backend for changes to take effect") + except Exception as reload_err: + logger.warning(f"Could not reload plugins: {reload_err}") + + logger.info("Plugins configuration updated successfully") + + return { + "success": True, + "message": "Plugins configuration updated successfully. Restart backend for changes to take effect." + } + + except Exception as e: + logger.error(f"Error saving plugins config: {e}") + raise + + +async def validate_plugins_config_yaml(yaml_content: str) -> dict: + """Validate plugins configuration YAML.""" + try: + # Parse YAML + try: + parsed_config = yaml.safe_load(yaml_content) + except yaml.YAMLError as e: + return {"valid": False, "error": f"Invalid YAML syntax: {e}"} + + # Check structure + if not isinstance(parsed_config, dict): + return {"valid": False, "error": "Configuration must be a YAML dictionary"} + + if 'plugins' not in parsed_config: + return {"valid": False, "error": "Configuration must contain 'plugins' key"} + + plugins = parsed_config['plugins'] + if not isinstance(plugins, dict): + return {"valid": False, "error": "'plugins' must be a dictionary"} + + # Validate each plugin + valid_access_levels = ['transcript', 'conversation', 'memory'] + valid_trigger_types = ['wake_word', 'always', 'conditional'] + + for plugin_id, plugin_config in plugins.items(): + if not isinstance(plugin_config, dict): + return {"valid": False, "error": f"Plugin '{plugin_id}' config must be a dictionary"} + + # Check required fields + if 'enabled' in plugin_config and not isinstance(plugin_config['enabled'], bool): + return {"valid": False, "error": f"Plugin '{plugin_id}': 'enabled' must be boolean"} + + if 'access_level' in plugin_config and plugin_config['access_level'] not in valid_access_levels: + return {"valid": False, "error": f"Plugin '{plugin_id}': invalid access_level (must be one of {valid_access_levels})"} + + if 'trigger' in plugin_config: + trigger = plugin_config['trigger'] + if not isinstance(trigger, dict): + return {"valid": False, "error": f"Plugin '{plugin_id}': 'trigger' must be a dictionary"} + + if 'type' in trigger and trigger['type'] not in valid_trigger_types: + return {"valid": False, "error": f"Plugin '{plugin_id}': invalid trigger type (must be one of {valid_trigger_types})"} + + return {"valid": True, "message": "Configuration is valid"} + + except Exception as e: + logger.error(f"Error validating plugins config: {e}") + return {"valid": False, "error": f"Validation error: {str(e)}"} diff --git a/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py index 50ffc77f..28e9924f 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py +++ 
b/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py @@ -17,10 +17,12 @@ from fastapi import WebSocket, WebSocketDisconnect, Query from friend_lite.decoder import OmiOpusDecoder +import redis.asyncio as redis from advanced_omi_backend.auth import websocket_auth from advanced_omi_backend.client_manager import generate_client_id, get_client_manager from advanced_omi_backend.constants import OMI_CHANNELS, OMI_SAMPLE_RATE, OMI_SAMPLE_WIDTH +from advanced_omi_backend.controllers.session_controller import mark_session_complete from advanced_omi_backend.utils.audio_utils import process_audio_chunk from advanced_omi_backend.services.audio_stream import AudioStreamProducer from advanced_omi_backend.services.audio_stream.producer import get_audio_stream_producer @@ -39,6 +41,89 @@ pending_connections: set[str] = set() +async def subscribe_to_interim_results(websocket: WebSocket, session_id: str) -> None: + """ + Subscribe to interim transcription results from Redis Pub/Sub and forward to client WebSocket. + + Runs as background task during WebSocket connection. Listens for interim and final + transcription results published by the Deepgram streaming consumer and forwards them + to the connected client for real-time transcript display. + + Args: + websocket: Connected WebSocket client + session_id: Session ID (client_id) to subscribe to + + Note: + This task runs continuously until the WebSocket disconnects or the task is cancelled. + Results are published to Redis Pub/Sub channel: transcription:interim:{session_id} + """ + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") + + try: + # Create Redis client for Pub/Sub + redis_client = await redis.from_url(redis_url, decode_responses=True) + + # Create Pub/Sub instance + pubsub = redis_client.pubsub() + + # Subscribe to interim results channel for this session + channel = f"transcription:interim:{session_id}" + await pubsub.subscribe(channel) + + logger.info(f"📢 Subscribed to interim results channel: {channel}") + + # Listen for messages + while True: + try: + message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0) + + if message and message['type'] == 'message': + # Parse result data + try: + result_data = json.loads(message['data']) + + # Forward to client WebSocket + await websocket.send_json({ + "type": "interim_transcript", + "data": result_data + }) + + # Log for debugging + is_final = result_data.get("is_final", False) + text_preview = result_data.get("text", "")[:50] + result_type = "FINAL" if is_final else "interim" + logger.debug(f"✉️ Forwarded {result_type} result to client {session_id}: {text_preview}...") + + except json.JSONDecodeError as e: + logger.error(f"Failed to parse interim result JSON: {e}") + except Exception as send_error: + logger.error(f"Failed to send interim result to client {session_id}: {send_error}") + # WebSocket might be closed, exit loop + break + + except asyncio.TimeoutError: + # No message received, continue waiting + continue + except asyncio.CancelledError: + logger.info(f"Interim results subscriber cancelled for session {session_id}") + break + except Exception as e: + logger.error(f"Error in interim results subscriber for {session_id}: {e}", exc_info=True) + break + + except Exception as e: + logger.error(f"Failed to initialize interim results subscriber for {session_id}: {e}", exc_info=True) + finally: + try: + # Unsubscribe and close connections + await pubsub.unsubscribe(channel) + await pubsub.close() + await redis_client.aclose() + 
logger.info(f"🔕 Unsubscribed from interim results channel: {channel}") + except Exception as cleanup_error: + logger.error(f"Error cleaning up interim results subscriber: {cleanup_error}") + + async def parse_wyoming_protocol(ws: WebSocket) -> tuple[dict, Optional[bytes]]: """Parse Wyoming protocol: JSON header line followed by optional binary payload. @@ -105,9 +190,9 @@ async def create_client_state(client_id: str, user, device_name: Optional[str] = client_id, CHUNK_DIR, user.user_id, user.email ) - # Also track in persistent mapping (for database queries) - from advanced_omi_backend.client_manager import track_client_user_relationship - track_client_user_relationship(client_id, user.user_id) + # Also track in persistent mapping (for database queries + cross-container Redis) + from advanced_omi_backend.client_manager import track_client_user_relationship_async + await track_client_user_relationship_async(client_id, user.user_id) # Register client in user model (persistent) from advanced_omi_backend.users import register_client_to_user @@ -166,13 +251,8 @@ async def cleanup_client_state(client_id: str): client_id_bytes = await async_redis.hget(key, "client_id") if client_id_bytes and client_id_bytes.decode() == client_id: # Mark session as complete (WebSocket disconnected) - await async_redis.hset(key, mapping={ - "status": "complete", - "completed_at": str(time.time()), - "completion_reason": "websocket_disconnect" - }) session_id = key.decode().replace("audio:session:", "") - logger.info(f"📊 Marked session {session_id[:12]} as complete (WebSocket disconnect)") + await mark_session_complete(async_redis, session_id, "websocket_disconnect") sessions_closed += 1 if cursor == 0: @@ -181,12 +261,12 @@ async def cleanup_client_state(client_id: str): if sessions_closed > 0: logger.info(f"✅ Closed {sessions_closed} active session(s) for client {client_id}") - # Delete Redis Streams for this client + # Set TTL on Redis Streams for this client (allows consumer groups to finish processing) stream_pattern = f"audio:stream:{client_id}" stream_key = await async_redis.exists(stream_pattern) if stream_key: - await async_redis.delete(stream_pattern) - logger.info(f"🧹 Deleted Redis stream: {stream_pattern}") + await async_redis.expire(stream_pattern, 60) # 60 second TTL for consumer group fan-out + logger.info(f"⏰ Set 60s TTL on Redis stream: {stream_pattern}") else: logger.debug(f"No Redis stream found for client {client_id}") @@ -279,8 +359,9 @@ async def _initialize_streaming_session( user_id: str, user_email: str, client_id: str, - audio_format: dict -) -> None: + audio_format: dict, + websocket: Optional[WebSocket] = None +) -> Optional[asyncio.Task]: """ Initialize streaming session with Redis and enqueue processing jobs. 
@@ -291,15 +372,18 @@ async def _initialize_streaming_session( user_email: User email client_id: Client ID audio_format: Audio format dict from audio-start event + websocket: Optional WebSocket connection to launch interim results subscriber + + Returns: + Interim results subscriber task if websocket provided and session initialized, None otherwise """ if hasattr(client_state, 'stream_session_id'): application_logger.debug(f"Session already initialized for {client_id}") - return + return None - # Initialize stream session - client_state.stream_session_id = str(uuid.uuid4()) - client_state.stream_chunk_count = 0 - client_state.stream_audio_format = audio_format + # Initialize stream session - use client_id as session_id for predictable lookup + # All other session metadata goes to Redis (single source of truth) + client_state.stream_session_id = client_state.client_id application_logger.info(f"🆔 Created stream session: {client_state.stream_session_id}") # Determine transcription provider from config.yml @@ -313,21 +397,31 @@ async def _initialize_streaming_session( if not stt_model: raise ValueError("No default STT model configured in config.yml (defaults.stt)") - provider = stt_model.model_provider.lower() - if provider not in ["deepgram", "parakeet"]: - raise ValueError(f"Unsupported STT provider: {provider}. Expected: deepgram or parakeet") + # Use model_provider for session tracking (generic, not validated against hardcoded list) + provider = stt_model.model_provider.lower() if stt_model.model_provider else stt_model.name application_logger.info(f"📋 Using STT provider: {provider} (model: {stt_model.name})") - - # Initialize session tracking in Redis + + # Initialize session tracking in Redis (SINGLE SOURCE OF TRUTH for session metadata) + # This includes user_email, connection info, audio format, chunk counters, job IDs, etc. 
+ connection_id = f"ws_{client_id}_{int(time.time())}" await audio_stream_producer.init_session( session_id=client_state.stream_session_id, user_id=user_id, client_id=client_id, + user_email=user_email, + connection_id=connection_id, mode="streaming", provider=provider ) + # Store audio format in Redis session (not in ClientState) + from advanced_omi_backend.services.audio_stream.producer import get_audio_stream_producer + import json + session_key = f"audio:session:{client_state.stream_session_id}" + redis_client = audio_stream_producer.redis_client + await redis_client.hset(session_key, "audio_format", json.dumps(audio_format)) + # Enqueue streaming jobs (speech detection + audio persistence) from advanced_omi_backend.controllers.queue_controller import start_streaming_jobs @@ -337,8 +431,22 @@ async def _initialize_streaming_session( client_id=client_id ) - client_state.speech_detection_job_id = job_ids['speech_detection'] - client_state.audio_persistence_job_id = job_ids['audio_persistence'] + # Store job IDs in Redis session (not in ClientState) + await audio_stream_producer.update_session_job_ids( + session_id=client_state.stream_session_id, + speech_detection_job_id=job_ids['speech_detection'], + audio_persistence_job_id=job_ids['audio_persistence'] + ) + + # Launch interim results subscriber if WebSocket provided + subscriber_task = None + if websocket: + subscriber_task = asyncio.create_task( + subscribe_to_interim_results(websocket, client_state.stream_session_id) + ) + application_logger.info(f"📡 Launched interim results subscriber for session {client_state.stream_session_id}") + + return subscriber_task async def _finalize_streaming_session( @@ -399,11 +507,10 @@ async def _finalize_streaming_session( f"✅ Session {session_id[:12]} marked as finalizing - open_conversation_job will handle cleanup" ) - # Clear session state - for attr in ['stream_session_id', 'stream_chunk_count', 'stream_audio_format', - 'speech_detection_job_id', 'audio_persistence_job_id']: - if hasattr(client_state, attr): - delattr(client_state, attr) + # Clear session state from ClientState (only stream_session_id is stored there now) + # All other session metadata lives in Redis (single source of truth) + if hasattr(client_state, 'stream_session_id'): + delattr(client_state, 'stream_session_id') except Exception as finalize_error: application_logger.error( @@ -439,14 +546,18 @@ async def _publish_audio_to_stream( application_logger.warning(f"⚠️ Received audio chunk before session initialized for {client_id}") return - # Increment chunk count and format chunk ID - client_state.stream_chunk_count += 1 - chunk_id = f"{client_state.stream_chunk_count:05d}" + session_id = client_state.stream_session_id + + # Increment chunk count in Redis (single source of truth) and format chunk ID + session_key = f"audio:session:{session_id}" + redis_client = audio_stream_producer.redis_client + chunk_count = await redis_client.hincrby(session_key, "chunks_published", 1) + chunk_id = f"{chunk_count:05d}" # Publish to Redis Stream using producer await audio_stream_producer.add_audio_chunk( audio_data=audio_data, - session_id=client_state.stream_session_id, + session_id=session_id, chunk_id=chunk_id, user_id=user_id, client_id=client_id, @@ -516,8 +627,9 @@ async def _handle_streaming_mode_audio( audio_format: dict, user_id: str, user_email: str, - client_id: str -) -> None: + client_id: str, + websocket: Optional[WebSocket] = None +) -> Optional[asyncio.Task]: """ Handle audio chunk in streaming mode. 
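Taken together, the `audio:session:{id}` hash now carries everything that previously lived on `ClientState`. A rough sketch of the resulting key layout and the atomic chunk counter, using the field names that appear in this diff (the internals of `init_session` itself are not shown):

```python
import asyncio
import json

import redis.asyncio as redis


async def demo(session_id: str = "client-abc") -> None:
    r = redis.from_url("redis://localhost:6379/0", decode_responses=True)
    key = f"audio:session:{session_id}"

    # Fields written at session init (via init_session plus the follow-up hset)
    await r.hset(key, mapping={
        "user_id": "u1",
        "client_id": session_id,  # client_id doubles as session_id for predictable lookup
        "user_email": "user@example.com",
        "mode": "streaming",
        "provider": "deepgram",
        "audio_format": json.dumps({"rate": 16000, "width": 2, "channels": 1}),
    })

    # Per-chunk counter: HINCRBY is atomic, so concurrent writers stay consistent
    chunk_count = await r.hincrby(key, "chunks_published", 1)
    chunk_id = f"{chunk_count:05d}"
    print(chunk_id)  # -> 00001

    await r.aclose()


asyncio.run(demo())
```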
@@ -529,16 +641,22 @@ async def _handle_streaming_mode_audio( user_id: User ID user_email: User email client_id: Client ID + websocket: Optional WebSocket connection to launch interim results subscriber + + Returns: + Interim results subscriber task if websocket provided and session initialized, None otherwise """ # Initialize session if needed + subscriber_task = None if not hasattr(client_state, 'stream_session_id'): - await _initialize_streaming_session( + subscriber_task = await _initialize_streaming_session( client_state, audio_stream_producer, user_id, user_email, client_id, - audio_format + audio_format, + websocket=websocket # Pass WebSocket to launch interim results subscriber ) # Publish to Redis Stream @@ -553,6 +671,8 @@ async def _handle_streaming_mode_audio( audio_format.get("width", 2) ) + return subscriber_task + async def _handle_batch_mode_audio( client_state, @@ -589,8 +709,9 @@ async def _handle_audio_chunk( audio_format: dict, user_id: str, user_email: str, - client_id: str -) -> None: + client_id: str, + websocket: Optional[WebSocket] = None +) -> Optional[asyncio.Task]: """ Route audio chunk to appropriate mode handler (streaming or batch). @@ -602,18 +723,24 @@ async def _handle_audio_chunk( user_id: User ID user_email: User email client_id: Client ID + websocket: Optional WebSocket connection to launch interim results subscriber + + Returns: + Interim results subscriber task if websocket provided and streaming mode, None otherwise """ recording_mode = getattr(client_state, 'recording_mode', 'batch') if recording_mode == "streaming": - await _handle_streaming_mode_audio( + return await _handle_streaming_mode_audio( client_state, audio_stream_producer, audio_data, - audio_format, user_id, user_email, client_id + audio_format, user_id, user_email, client_id, + websocket=websocket ) else: await _handle_batch_mode_audio( client_state, audio_data, audio_format, client_id ) + return None async def _handle_audio_session_start( @@ -788,6 +915,7 @@ async def handle_omi_websocket( client_id = None client_state = None + interim_subscriber_task = None try: # Setup connection (accept, auth, create client state) @@ -814,13 +942,14 @@ async def handle_omi_websocket( if header["type"] == "audio-start": # Handle audio session start application_logger.info(f"🎙️ OMI audio session started for {client_id}") - await _initialize_streaming_session( + interim_subscriber_task = await _initialize_streaming_session( client_state, audio_stream_producer, user.user_id, user.email, client_id, - header.get("data", {"rate": OMI_SAMPLE_RATE, "width": OMI_SAMPLE_WIDTH, "channels": OMI_CHANNELS}) + header.get("data", {"rate": OMI_SAMPLE_RATE, "width": OMI_SAMPLE_WIDTH, "channels": OMI_CHANNELS}), + websocket=ws # Pass WebSocket to launch interim results subscriber ) elif header["type"] == "audio-chunk" and payload: @@ -883,6 +1012,16 @@ async def handle_omi_websocket( except Exception as e: application_logger.error(f"❌ WebSocket error for client {client_id}: {e}", exc_info=True) finally: + # Cancel interim results subscriber task if running + if interim_subscriber_task and not interim_subscriber_task.done(): + interim_subscriber_task.cancel() + try: + await interim_subscriber_task + except asyncio.CancelledError: + application_logger.info(f"Interim subscriber task cancelled for {client_id}") + except Exception as task_error: + application_logger.error(f"Error cancelling interim subscriber task: {task_error}") + # Clean up pending connection tracking pending_connections.discard(pending_client_id) @@ 
-909,6 +1048,7 @@ async def handle_pcm_websocket( client_id = None client_state = None + interim_subscriber_task = None try: # Setup connection (accept, auth, create client state) @@ -1011,15 +1151,19 @@ async def handle_pcm_websocket( # Route to appropriate mode handler audio_format = control_header.get("data", {}) - await _handle_audio_chunk( + task = await _handle_audio_chunk( client_state, audio_stream_producer, audio_data, audio_format, user.user_id, user.email, - client_id + client_id, + websocket=ws ) + # Store subscriber task if it was created (first streaming chunk) + if task and not interim_subscriber_task: + interim_subscriber_task = task else: application_logger.warning(f"Expected binary payload for audio-chunk, got: {payload_msg.keys()}") else: @@ -1044,15 +1188,19 @@ async def handle_pcm_websocket( # Route to appropriate mode handler with default format default_format = {"rate": 16000, "width": 2, "channels": 1} - await _handle_audio_chunk( + task = await _handle_audio_chunk( client_state, audio_stream_producer, audio_data, default_format, user.user_id, user.email, - client_id + client_id, + websocket=ws ) + # Store subscriber task if it was created (first streaming chunk) + if task and not interim_subscriber_task: + interim_subscriber_task = task else: application_logger.warning(f"Unexpected message format in streaming mode: {message.keys()}") @@ -1115,6 +1263,16 @@ async def handle_pcm_websocket( f"❌ PCM WebSocket error for client {client_id}: {e}", exc_info=True ) finally: + # Cancel interim results subscriber task if running + if interim_subscriber_task and not interim_subscriber_task.done(): + interim_subscriber_task.cancel() + try: + await interim_subscriber_task + except asyncio.CancelledError: + application_logger.info(f"Interim subscriber task cancelled for {client_id}") + except Exception as task_error: + application_logger.error(f"Error cancelling interim subscriber task: {task_error}") + # Clean up pending connection tracking pending_connections.discard(pending_client_id) diff --git a/backends/advanced/src/advanced_omi_backend/main.py b/backends/advanced/src/advanced_omi_backend/main.py index df51e1cc..5160c230 100644 --- a/backends/advanced/src/advanced_omi_backend/main.py +++ b/backends/advanced/src/advanced_omi_backend/main.py @@ -2,7 +2,7 @@ """ Unified Omi-audio service - * Accepts Opus packets over a WebSocket (`/ws`) or PCM over a WebSocket (`/ws_pcm`). + * Accepts audio over a unified WebSocket endpoint (`/ws`) with codec parameter (pcm or opus). * Uses a central queue to decouple audio ingestion from processing. * A saver consumer buffers PCM and writes 30-second WAV chunks to `./data/audio_chunks/`. * A transcription consumer sends each chunk to a Wyoming ASR service. 
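Since `/ws_pcm` and `/ws_omi` are gone, clients now speak to the single `/ws` endpoint. A minimal client sketch, assuming the codec is selected via a query parameter named `codec` and that each Wyoming event is a JSON header line optionally followed by a binary payload; the `audio-stop` event name follows the Wyoming convention and auth is omitted:

```python
import asyncio
import json

import websockets  # pip install websockets


async def stream_pcm(url: str = "ws://localhost:8000/ws?codec=pcm") -> None:
    async with websockets.connect(url) as ws:
        # Wyoming protocol: JSON header line first...
        await ws.send(json.dumps({
            "type": "audio-start",
            "data": {"rate": 16000, "width": 2, "channels": 1},
        }) + "\n")

        # ...then audio-chunk headers, each followed by a binary PCM payload
        silence = b"\x00" * 3200  # 100 ms of 16 kHz / 16-bit mono silence
        header = {
            "type": "audio-chunk",
            "data": {"rate": 16000, "width": 2, "channels": 1},
            "payload_length": len(silence),
        }
        await ws.send(json.dumps(header) + "\n")
        await ws.send(silence)

        await ws.send(json.dumps({"type": "audio-stop"}) + "\n")


asyncio.run(stream_pcm())
```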
diff --git a/backends/advanced/src/advanced_omi_backend/middleware/app_middleware.py b/backends/advanced/src/advanced_omi_backend/middleware/app_middleware.py index eafeffec..4cff21eb 100644 --- a/backends/advanced/src/advanced_omi_backend/middleware/app_middleware.py +++ b/backends/advanced/src/advanced_omi_backend/middleware/app_middleware.py @@ -56,8 +56,6 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): "/auth/jwt/logout", "/auth/cookie/logout", "/ws", - "/ws_omi", - "/ws_pcm", "/mcp", "/health", "/auth/health", diff --git a/backends/advanced/src/advanced_omi_backend/models/audio_file.py b/backends/advanced/src/advanced_omi_backend/models/audio_file.py index e1e2c09a..ca154500 100644 --- a/backends/advanced/src/advanced_omi_backend/models/audio_file.py +++ b/backends/advanced/src/advanced_omi_backend/models/audio_file.py @@ -41,9 +41,6 @@ class AudioFile(Document): user_id: Indexed(str) = Field(description="User who owns this audio") user_email: Optional[str] = Field(None, description="User email") - # Audio processing - cropped_audio_path: Optional[str] = Field(None, description="Path to cropped audio (speech only)") - # Speech-driven conversation linking conversation_id: Optional[str] = Field( None, diff --git a/backends/advanced/src/advanced_omi_backend/models/conversation.py b/backends/advanced/src/advanced_omi_backend/models/conversation.py index 01dd5d96..00178f10 100644 --- a/backends/advanced/src/advanced_omi_backend/models/conversation.py +++ b/backends/advanced/src/advanced_omi_backend/models/conversation.py @@ -19,12 +19,15 @@ class Conversation(Document): # Nested Enums class TranscriptProvider(str, Enum): - """Supported transcription providers.""" + """ + Transcription provider identifiers. + + Note: Actual providers are configured in config.yml. + Any provider name from config.yml is valid - this enum is for common values only. 
+ """ DEEPGRAM = "deepgram" - MISTRAL = "mistral" - PARAKEET = "parakeet" - SPEECH_DETECTION = "speech_detection" # Legacy value - UNKNOWN = "unknown" # Fallback value + SPEECH_DETECTION = "speech_detection" + UNKNOWN = "unknown" class MemoryProvider(str, Enum): """Supported memory providers.""" @@ -63,7 +66,7 @@ class TranscriptVersion(BaseModel): transcript: Optional[str] = Field(None, description="Full transcript text") segments: List["Conversation.SpeakerSegment"] = Field(default_factory=list, description="Speaker segments") provider: Optional["Conversation.TranscriptProvider"] = Field(None, description="Transcription provider used") - model: Optional[str] = Field(None, description="Model used (e.g., nova-3, voxtral-mini-2507)") + model: Optional[str] = Field(None, description="Model used (e.g., nova-3, parakeet)") created_at: datetime = Field(description="When this version was created") processing_time_seconds: Optional[float] = Field(None, description="Time taken to process") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional provider-specific metadata") @@ -87,7 +90,6 @@ class MemoryVersion(BaseModel): # Audio file reference audio_path: Optional[str] = Field(None, description="Path to audio file (relative to CHUNK_DIR)") - cropped_audio_path: Optional[str] = Field(None, description="Path to cropped audio file (relative to CHUNK_DIR)") # Creation metadata created_at: Indexed(datetime) = Field(default_factory=datetime.utcnow, description="When the conversation was created") diff --git a/backends/advanced/src/advanced_omi_backend/plugins/__init__.py b/backends/advanced/src/advanced_omi_backend/plugins/__init__.py new file mode 100644 index 00000000..3ccea7dc --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/__init__.py @@ -0,0 +1,18 @@ +""" +Chronicle plugin system for multi-level pipeline extension. + +Plugins can hook into different stages of the processing pipeline: +- transcript: When new transcript segment arrives +- conversation: When conversation processing completes +- memory: After memory extraction finishes + +Trigger types control when plugins execute: +- wake_word: Only when transcript starts with specified wake word +- always: Execute on every invocation at access level +- conditional: Execute based on custom condition (future) +""" + +from .base import BasePlugin, PluginContext, PluginResult +from .router import PluginRouter + +__all__ = ['BasePlugin', 'PluginContext', 'PluginResult', 'PluginRouter'] diff --git a/backends/advanced/src/advanced_omi_backend/plugins/base.py b/backends/advanced/src/advanced_omi_backend/plugins/base.py new file mode 100644 index 00000000..e5dfcc36 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/base.py @@ -0,0 +1,131 @@ +""" +Base plugin classes for Chronicle multi-level plugin architecture. 
+ +Provides: +- PluginContext: Context passed to plugin execution +- PluginResult: Result from plugin execution +- BasePlugin: Abstract base class for all plugins +""" +from abc import ABC, abstractmethod +from typing import Optional, Dict, Any, List +from dataclasses import dataclass, field + + +@dataclass +class PluginContext: + """Context passed to plugin execution""" + user_id: str + event: str # Event name (e.g., "transcript.streaming", "conversation.complete") + data: Dict[str, Any] # Event-specific data + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class PluginResult: + """Result from plugin execution""" + success: bool + data: Optional[Dict[str, Any]] = None + message: Optional[str] = None + should_continue: bool = True # Whether to continue normal processing + + +class BasePlugin(ABC): + """ + Base class for all Chronicle plugins. + + Plugins can hook into different stages of the processing pipeline: + - transcript: When new transcript segment arrives + - conversation: When conversation processing completes + - memory: When memory extraction finishes + + Subclasses should: + 1. Set SUPPORTED_ACCESS_LEVELS to list which levels they support + 2. Implement initialize() for plugin initialization + 3. Implement the appropriate callback methods (on_transcript, on_conversation_complete, on_memory_processed) + 4. Optionally implement cleanup() for resource cleanup + """ + + # Subclasses declare which access levels they support + SUPPORTED_ACCESS_LEVELS: List[str] = [] + + def __init__(self, config: Dict[str, Any]): + """ + Initialize plugin with configuration. + + Args: + config: Plugin configuration from config/plugins.yml + Contains: enabled, subscriptions, trigger, and plugin-specific config + """ + self.config = config + self.enabled = config.get('enabled', False) + self.subscriptions = config.get('subscriptions', []) + self.trigger = config.get('trigger', {'type': 'always'}) + + @abstractmethod + async def initialize(self): + """ + Initialize plugin resources (connect to services, etc.) + + Called during application startup after plugin registration. + Raise an exception if initialization fails. + """ + pass + + async def cleanup(self): + """ + Clean up plugin resources. + + Called during application shutdown. + Override if your plugin needs cleanup (closing connections, etc.) + """ + pass + + # Access-level specific methods (implement only what you need) + + async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]: + """ + Called when new transcript segment arrives. + + Context data contains: + - transcript: str - The transcript text + - segment_id: str - Unique segment identifier + - conversation_id: str - Current conversation ID + + For wake_word triggers, router adds: + - command: str - Command with wake word stripped + - original_transcript: str - Full transcript + + Returns: + PluginResult with success status, optional message, and should_continue flag + """ + pass + + async def on_conversation_complete(self, context: PluginContext) -> Optional[PluginResult]: + """ + Called when conversation processing completes. 
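To make the `BasePlugin` contract concrete, here is an illustrative toy plugin built on these base classes (not part of the diff); it implements the one required method plus a single callback:

```python
from typing import List, Optional

from advanced_omi_backend.plugins.base import BasePlugin, PluginContext, PluginResult


class EchoPlugin(BasePlugin):
    """Toy plugin that acknowledges every transcript segment."""

    SUPPORTED_ACCESS_LEVELS: List[str] = ["transcript"]

    async def initialize(self):
        # Nothing to set up; a real plugin would open clients/connections here
        pass

    async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]:
        text = context.data.get("transcript", "")
        return PluginResult(
            success=True,
            message=f"heard: {text[:40]}",
            should_continue=True,  # don't block normal pipeline processing
        )


# Config shape mirrors what BasePlugin.__init__ reads from config/plugins.yml
echo = EchoPlugin({
    "enabled": True,
    "subscriptions": ["transcript.streaming"],
    "trigger": {"type": "always"},
})
```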
+ + Context data contains: + - conversation: dict - Full conversation data + - transcript: str - Complete transcript + - duration: float - Conversation duration + - conversation_id: str - Conversation identifier + + Returns: + PluginResult with success status, optional message, and should_continue flag + """ + pass + + async def on_memory_processed(self, context: PluginContext) -> Optional[PluginResult]: + """ + Called after memory extraction finishes. + + Context data contains: + - memories: list - Extracted memories + - conversation: dict - Source conversation + - memory_count: int - Number of memories created + - conversation_id: str - Conversation identifier + + Returns: + PluginResult with success status, optional message, and should_continue flag + """ + pass diff --git a/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/__init__.py b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/__init__.py new file mode 100644 index 00000000..11b831e9 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/__init__.py @@ -0,0 +1,9 @@ +""" +Home Assistant plugin for Chronicle. + +Allows control of Home Assistant devices via natural language wake word commands. +""" + +from .plugin import HomeAssistantPlugin + +__all__ = ['HomeAssistantPlugin'] diff --git a/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/command_parser.py b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/command_parser.py new file mode 100644 index 00000000..cc73626d --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/command_parser.py @@ -0,0 +1,97 @@ +""" +LLM-based command parser for Home Assistant integration. + +This module provides structured command parsing using LLM to extract +intent, target entities/areas, and parameters from natural language. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + + +@dataclass +class ParsedCommand: + """Structured representation of a parsed Home Assistant command.""" + + action: str + """Action to perform (e.g., turn_on, turn_off, set_brightness, toggle)""" + + target_type: str + """Type of target (area, entity, all_in_area)""" + + target: str + """Target identifier (area name or entity name)""" + + entity_type: Optional[str] = None + """Entity domain filter (e.g., light, switch, fan) - None means all types""" + + parameters: Dict[str, Any] = field(default_factory=dict) + """Additional parameters (e.g., brightness_pct=50, color='red')""" + + +# LLM System Prompt for Command Parsing +COMMAND_PARSER_SYSTEM_PROMPT = """You are a smart home command parser for Home Assistant. + +Extract structured information from natural language commands. 
+Return ONLY valid JSON in this exact format (no markdown, no code blocks, no explanation): + +{ + "action": "turn_off", + "target_type": "area", + "target": "study", + "entity_type": "light", + "parameters": {} +} + +ACTIONS (choose one): +- turn_on: Turn on entities +- turn_off: Turn off entities +- toggle: Toggle entity state +- set_brightness: Set brightness level +- set_color: Set color + +TARGET_TYPE (choose one): +- area: Targeting all entities of a type in an area (e.g., "study lights") +- all_in_area: Targeting ALL entities in an area (e.g., "everything in study") +- entity: Targeting a specific entity by name (e.g., "desk lamp") + +ENTITY_TYPE (optional, use null if not specified): +- light: Light entities +- switch: Switch entities +- fan: Fan entities +- cover: Covers/blinds +- null: All entity types (when target_type is "all_in_area") + +PARAMETERS (optional, empty dict if none): +- brightness_pct: Brightness percentage (0-100) +- color: Color name (e.g., "red", "blue", "warm white") + +EXAMPLES: + +Command: "turn off study lights" +Response: {"action": "turn_off", "target_type": "area", "target": "study", "entity_type": "light", "parameters": {}} + +Command: "turn off everything in study" +Response: {"action": "turn_off", "target_type": "all_in_area", "target": "study", "entity_type": null, "parameters": {}} + +Command: "turn on desk lamp" +Response: {"action": "turn_on", "target_type": "entity", "target": "desk lamp", "entity_type": null, "parameters": {}} + +Command: "set study lights to 50%" +Response: {"action": "set_brightness", "target_type": "area", "target": "study", "entity_type": "light", "parameters": {"brightness_pct": 50}} + +Command: "turn on living room fan" +Response: {"action": "turn_on", "target_type": "area", "target": "living room", "entity_type": "fan", "parameters": {}} + +Command: "turn off all lights" +Response: {"action": "turn_off", "target_type": "entity", "target": "all", "entity_type": "light", "parameters": {}} + +Command: "toggle hallway light" +Response: {"action": "toggle", "target_type": "entity", "target": "hallway light", "entity_type": null, "parameters": {}} + +Remember: +1. Return ONLY the JSON object, no markdown formatting +2. Use lowercase for action, target_type, target, entity_type +3. Use null (not "null" string) for missing entity_type +4. Always include all 5 fields: action, target_type, target, entity_type, parameters +""" diff --git a/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/entity_cache.py b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/entity_cache.py new file mode 100644 index 00000000..e8624f1b --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/entity_cache.py @@ -0,0 +1,133 @@ +""" +Entity cache for Home Assistant integration. + +This module provides caching and lookup functionality for Home Assistant areas and entities. 
+""" + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Dict, List, Optional +import logging + +logger = logging.getLogger(__name__) + + +@dataclass +class EntityCache: + """Cache for Home Assistant areas and entities.""" + + areas: List[str] = field(default_factory=list) + """List of area names (e.g., ["study", "living_room"])""" + + area_entities: Dict[str, List[str]] = field(default_factory=dict) + """Map of area names to entity IDs (e.g., {"study": ["light.tubelight_3"]})""" + + entity_details: Dict[str, Dict] = field(default_factory=dict) + """Full entity state data keyed by entity_id""" + + last_refresh: datetime = field(default_factory=datetime.now) + """Timestamp of last cache refresh""" + + def find_entity_by_name(self, name: str) -> Optional[str]: + """ + Find entity ID by fuzzy name matching. + + Matching priority: + 1. Exact friendly_name match (case-insensitive) + 2. Partial friendly_name match (case-insensitive) + 3. Entity ID match (e.g., "tubelight_3" → "light.tubelight_3") + + Args: + name: Entity name to search for + + Returns: + Entity ID if found, None otherwise + """ + name_lower = name.lower().strip() + + # Step 1: Exact friendly_name match + for entity_id, details in self.entity_details.items(): + friendly_name = details.get('attributes', {}).get('friendly_name', '') + if friendly_name.lower() == name_lower: + logger.debug(f"Exact match: {name} → {entity_id} (friendly_name: {friendly_name})") + return entity_id + + # Step 2: Partial friendly_name match + for entity_id, details in self.entity_details.items(): + friendly_name = details.get('attributes', {}).get('friendly_name', '') + if name_lower in friendly_name.lower(): + logger.debug(f"Partial match: {name} → {entity_id} (friendly_name: {friendly_name})") + return entity_id + + # Step 3: Entity ID match (try adding common domains) + common_domains = ['light', 'switch', 'fan', 'cover'] + for domain in common_domains: + candidate_id = f"{domain}.{name_lower.replace(' ', '_')}" + if candidate_id in self.entity_details: + logger.debug(f"Entity ID match: {name} → {candidate_id}") + return candidate_id + + logger.warning(f"No entity found matching: {name}") + return None + + def get_entities_in_area( + self, + area: str, + entity_type: Optional[str] = None + ) -> List[str]: + """ + Get all entities in an area, optionally filtered by domain. + + Args: + area: Area name (case-insensitive) + entity_type: Entity domain filter (e.g., "light", "switch") + + Returns: + List of entity IDs in the area + """ + area_lower = area.lower().strip() + + # Find matching area (case-insensitive) + matching_area = None + for area_name in self.areas: + if area_name.lower() == area_lower: + matching_area = area_name + break + + if not matching_area: + logger.warning(f"Area not found: {area}") + return [] + + # Get entities in area + entities = self.area_entities.get(matching_area, []) + + # Filter by entity type if specified + if entity_type: + entity_type_lower = entity_type.lower() + entities = [ + e for e in entities + if e.split('.')[0] == entity_type_lower + ] + + logger.debug( + f"Found {len(entities)} entities in area '{matching_area}'" + + (f" (type: {entity_type})" if entity_type else "") + ) + + return entities + + def get_cache_age_seconds(self) -> float: + """Get cache age in seconds.""" + return (datetime.now() - self.last_refresh).total_seconds() + + def is_stale(self, max_age_seconds: int = 3600) -> bool: + """ + Check if cache is stale. 
+ + Args: + max_age_seconds: Maximum cache age before considering stale (default: 1 hour) + + Returns: + True if cache is older than max_age_seconds + """ + return self.get_cache_age_seconds() > max_age_seconds diff --git a/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/mcp_client.py b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/mcp_client.py new file mode 100644 index 00000000..42ede8dc --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/mcp_client.py @@ -0,0 +1,421 @@ +""" +MCP client for communicating with Home Assistant's MCP Server. + +Home Assistant exposes an MCP server at /api/mcp that provides tools +for controlling smart home devices. +""" + +import json +import logging +from typing import Any, Dict, List, Optional + +import httpx + +logger = logging.getLogger(__name__) + + +class MCPError(Exception): + """MCP protocol error""" + pass + + +class HAMCPClient: + """ + MCP Client for Home Assistant's /api/mcp endpoint. + + Implements the Model Context Protocol for communicating with + Home Assistant's built-in MCP server. + """ + + def __init__(self, base_url: str, token: str, timeout: int = 30): + """ + Initialize the MCP client. + + Args: + base_url: Base URL of Home Assistant (e.g., http://localhost:8123) + token: Long-lived access token for authentication + timeout: Request timeout in seconds + + """ + self.base_url = base_url.rstrip('/') + self.mcp_url = f"{self.base_url}/api/mcp" + self.token = token + self.timeout = timeout + self.client = httpx.AsyncClient(timeout=timeout) + self._request_id = 0 + + async def close(self): + """Close the HTTP client""" + await self.client.aclose() + + def _next_request_id(self) -> int: + """Generate next request ID""" + self._request_id += 1 + return self._request_id + + async def _send_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict[str, Any]: + """ + Send MCP protocol request to Home Assistant. + + Args: + method: MCP method name (e.g., "tools/list", "tools/call") + params: Optional method parameters + + Returns: + Response data from MCP server + + Raises: + MCPError: If request fails or returns an error + """ + payload = { + "jsonrpc": "2.0", + "id": self._next_request_id(), + "method": method + } + + if params: + payload["params"] = params + + headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json" + } + + try: + logger.debug(f"MCP Request: {method} with params: {params}") + response = await self.client.post( + self.mcp_url, + json=payload, + headers=headers + ) + response.raise_for_status() + + data = response.json() + + # Check for JSON-RPC error + if "error" in data: + error = data["error"] + raise MCPError(f"MCP Error {error.get('code')}: {error.get('message')}") + + return data.get("result", {}) + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error calling MCP endpoint: {e.response.status_code}") + raise MCPError(f"HTTP {e.response.status_code}: {e.response.text}") + except httpx.RequestError as e: + logger.error(f"Request error calling MCP endpoint: {e}") + raise MCPError(f"Request failed: {e}") + except Exception as e: + logger.error(f"Unexpected error calling MCP endpoint: {e}") + raise MCPError(f"Unexpected error: {e}") + + async def list_tools(self) -> List[Dict[str, Any]]: + """ + Get list of available MCP tools from Home Assistant. 
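On the wire, `_send_mcp_request` produces standard JSON-RPC 2.0 envelopes. For illustration, a `tools/call` request and a matching success/error response pair (the response bodies here are illustrative, not captured from Home Assistant):

```python
# Request sent by _send_mcp_request for a tool call
request = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "tools/call",
    "params": {"name": "turn_off", "arguments": {"entity_id": "light.hall_light"}},
}

# Success: the result is wrapped in MCP content blocks, which call_tool unwraps
success = {
    "jsonrpc": "2.0",
    "id": 1,
    "result": {"content": [{"type": "text", "text": "Turned off light.hall_light"}]},
}

# Failure: a JSON-RPC error object, surfaced by the client as MCPError
failure = {
    "jsonrpc": "2.0",
    "id": 1,
    "error": {"code": -32602, "message": "Unknown tool: turn_off"},
}
```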
+ + Returns: + List of tool definitions with schema + + Example tool: + { + "name": "turn_on", + "description": "Turn on a light or switch", + "inputSchema": { + "type": "object", + "properties": { + "entity_id": {"type": "string"} + } + } + } + """ + result = await self._send_mcp_request("tools/list") + tools = result.get("tools", []) + logger.info(f"Retrieved {len(tools)} tools from Home Assistant MCP") + return tools + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute a tool via MCP. + + Args: + tool_name: Name of the tool to call (e.g., "turn_on", "turn_off") + arguments: Tool arguments (e.g., {"entity_id": "light.hall_light"}) + + Returns: + Tool execution result + + Raises: + MCPError: If tool execution fails + + Example: + >>> await client.call_tool("turn_off", {"entity_id": "light.hall_light"}) + {"success": True} + """ + params = { + "name": tool_name, + "arguments": arguments + } + + logger.info(f"Calling MCP tool '{tool_name}' with args: {arguments}") + result = await self._send_mcp_request("tools/call", params) + + # MCP tool results are wrapped in content blocks + content = result.get("content", []) + if content and isinstance(content, list): + # Extract text content from first block + first_block = content[0] + if isinstance(first_block, dict) and first_block.get("type") == "text": + return {"result": first_block.get("text"), "success": True} + + return result + + async def test_connection(self) -> bool: + """ + Test connection to Home Assistant MCP server. + + Returns: + True if connection successful, False otherwise + """ + try: + tools = await self.list_tools() + logger.info(f"MCP connection test successful ({len(tools)} tools available)") + return True + except Exception as e: + logger.error(f"MCP connection test failed: {e}") + return False + + async def _render_template(self, template: str) -> Any: + """ + Render a Home Assistant template using the Template API. + + Args: + template: Jinja2 template string (e.g., "{{ areas() }}") + + Returns: + Rendered template result (parsed as JSON if possible) + + Raises: + MCPError: If template rendering fails + + Example: + >>> await client._render_template("{{ areas() }}") + ["study", "living_room", "bedroom"] + """ + headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json" + } + + payload = {"template": template} + + try: + logger.debug(f"Rendering template: {template}") + response = await self.client.post( + f"{self.base_url}/api/template", + json=payload, + headers=headers + ) + response.raise_for_status() + + result = response.text.strip() + + # Try to parse as JSON (for lists, dicts) + if result.startswith('[') or result.startswith('{'): + try: + return json.loads(result) + except json.JSONDecodeError: + logger.warning(f"Failed to parse template result as JSON: {result}") + return result + + return result + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error rendering template: {e.response.status_code}") + raise MCPError(f"HTTP {e.response.status_code}: {e.response.text}") + except httpx.RequestError as e: + logger.error(f"Request error rendering template: {e}") + raise MCPError(f"Request failed: {e}") + + async def fetch_areas(self) -> List[str]: + """ + Fetch all areas from Home Assistant using Template API. 
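The Template API can also be exercised directly; a minimal httpx sketch equivalent to `_render_template`, with placeholder URL and token:

```python
import asyncio

import httpx


async def render(template: str) -> str:
    headers = {
        "Authorization": "Bearer YOUR_LONG_LIVED_TOKEN",
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            "http://localhost:8123/api/template",
            json={"template": template},
            headers=headers,
        )
        resp.raise_for_status()
        return resp.text.strip()


# `areas()` returns a list object; `| to_json` makes the output parseable
print(asyncio.run(render("{{ areas() | to_json }}")))
```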
+ + Returns: + List of area names + + Example: + >>> await client.fetch_areas() + ["study", "living_room", "bedroom"] + """ + template = "{{ areas() | to_json }}" + areas = await self._render_template(template) + + if isinstance(areas, list): + logger.info(f"Fetched {len(areas)} areas from Home Assistant") + return areas + else: + logger.warning(f"Unexpected areas format: {type(areas)}") + return [] + + async def fetch_area_entities(self, area_name: str) -> List[str]: + """ + Fetch all entity IDs in a specific area. + + Args: + area_name: Name of the area + + Returns: + List of entity IDs in the area + + Example: + >>> await client.fetch_area_entities("study") + ["light.tubelight_3", "switch.desk_fan"] + """ + template = f"{{{{ area_entities('{area_name}') | to_json }}}}" + entities = await self._render_template(template) + + if isinstance(entities, list): + logger.info(f"Fetched {len(entities)} entities from area '{area_name}'") + return entities + else: + logger.warning(f"Unexpected entities format for area '{area_name}': {type(entities)}") + return [] + + async def fetch_entity_states(self) -> Dict[str, Dict]: + """ + Fetch all entity states from Home Assistant. + + Returns: + Dict mapping entity_id to state data (includes attributes, area_id) + + Example: + >>> await client.fetch_entity_states() + { + "light.tubelight_3": { + "state": "on", + "attributes": {"friendly_name": "Study Light", ...}, + "area_id": "study" + } + } + """ + headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json" + } + + try: + logger.debug("Fetching all entity states") + response = await self.client.get( + f"{self.base_url}/api/states", + headers=headers + ) + response.raise_for_status() + + states = response.json() + entity_details = {} + + # Enrich with area information + for state in states: + entity_id = state.get('entity_id') + if entity_id: + # Get area_id using Template API + try: + area_template = f"{{{{ area_id('{entity_id}') }}}}" + area_id = await self._render_template(area_template) + state['area_id'] = area_id if area_id else None + except Exception as e: + logger.debug(f"Failed to get area for {entity_id}: {e}") + state['area_id'] = None + + entity_details[entity_id] = state + + logger.info(f"Fetched {len(entity_details)} entity states") + return entity_details + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error fetching states: {e.response.status_code}") + raise MCPError(f"HTTP {e.response.status_code}: {e.response.text}") + except httpx.RequestError as e: + logger.error(f"Request error fetching states: {e}") + raise MCPError(f"Request failed: {e}") + + async def call_service( + self, + domain: str, + service: str, + entity_ids: List[str], + **parameters + ) -> Dict[str, Any]: + """ + Call a Home Assistant service directly via REST API. 
+ + Args: + domain: Service domain (e.g., "light", "switch") + service: Service name (e.g., "turn_on", "turn_off") + entity_ids: List of entity IDs to target + **parameters: Additional service parameters (e.g., brightness_pct=50) + + Returns: + Service call response + + Example: + >>> await client.call_service("light", "turn_on", ["light.study"], brightness_pct=50) + [{"entity_id": "light.study", "state": "on"}] + """ + headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json" + } + + payload = { + "entity_id": entity_ids, + **parameters + } + + service_url = f"{self.base_url}/api/services/{domain}/{service}" + + try: + logger.info(f"Calling service {domain}.{service} for {len(entity_ids)} entities") + logger.debug(f"Service payload: {payload}") + + response = await self.client.post( + service_url, + json=payload, + headers=headers + ) + response.raise_for_status() + + result = response.json() + logger.info(f"Service call successful: {domain}.{service}") + return result + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error calling service: {e.response.status_code}") + raise MCPError(f"HTTP {e.response.status_code}: {e.response.text}") + except httpx.RequestError as e: + logger.error(f"Request error calling service: {e}") + raise MCPError(f"Request failed: {e}") + + async def discover_entities(self) -> Dict[str, Dict]: + """ + Discover available entities from MCP tools. + + Parses the available tools to build an index of entities + that can be controlled. + + Returns: + Dict mapping entity_id to metadata + """ + tools = await self.list_tools() + entities = {} + + for tool in tools: + # Extract entity information from tool schemas + # This will depend on how HA MCP structures its tools + # For now, we'll just log what we find + logger.debug(f"Tool: {tool.get('name')} - {tool.get('description')}") + + # TODO: Parse tool schemas to extract entity_id information + # For now, return empty dict - will be populated based on actual HA MCP response + + return entities diff --git a/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/plugin.py b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/plugin.py new file mode 100644 index 00000000..931dd813 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/plugin.py @@ -0,0 +1,598 @@ +""" +Home Assistant plugin for Chronicle. + +Enables control of Home Assistant devices through natural language commands +triggered by a wake word. +""" + +import json +import logging +from typing import Any, Dict, List, Optional + +from ..base import BasePlugin, PluginContext, PluginResult +from .entity_cache import EntityCache +from .mcp_client import HAMCPClient, MCPError + +logger = logging.getLogger(__name__) + + +class HomeAssistantPlugin(BasePlugin): + """ + Plugin for controlling Home Assistant devices via wake word commands. + + Example: + User says: "Vivi, turn off the hall lights" + -> Wake word "vivi" detected by router + -> Command "turn off the hall lights" passed to on_transcript() + -> Plugin parses command and calls HA MCP to execute + -> Returns: PluginResult with "I've turned off the hall light" + """ + + SUPPORTED_ACCESS_LEVELS: List[str] = ['transcript'] + + def __init__(self, config: Dict[str, Any]): + """ + Initialize Home Assistant plugin. 
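A short usage sketch of `call_service`, matching the signature above (token and entity ID are placeholders):

```python
import asyncio

from advanced_omi_backend.plugins.homeassistant.mcp_client import HAMCPClient


async def main() -> None:
    client = HAMCPClient("http://localhost:8123", token="YOUR_LONG_LIVED_TOKEN")
    try:
        # POSTs to /api/services/light/turn_on with entity_id plus extra params
        await client.call_service(
            "light", "turn_on", ["light.tubelight_3"], brightness_pct=50
        )
    finally:
        await client.close()


asyncio.run(main())
```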
+ + Args: + config: Plugin configuration with keys: + - ha_url: Home Assistant URL + - ha_token: Long-lived access token + - wake_word: Wake word for triggering commands (handled by router) + - enabled: Whether plugin is enabled + - access_level: Should be 'transcript' + - trigger: Should be {'type': 'wake_word', 'wake_word': '...'} + """ + super().__init__(config) + self.mcp_client: Optional[HAMCPClient] = None + self.available_tools: List[Dict] = [] + self.entities: Dict[str, Dict] = {} + + # Entity cache for area-based commands + self.entity_cache: Optional[EntityCache] = None + self.cache_initialized = False + + # Configuration + self.ha_url = config.get('ha_url', 'http://localhost:8123') + self.ha_token = config.get('ha_token', '') + self.wake_word = config.get('wake_word', 'vivi') + self.timeout = config.get('timeout', 30) + + async def initialize(self): + """ + Initialize the Home Assistant plugin. + + Connects to Home Assistant MCP server and discovers available tools. + + Raises: + MCPError: If connection or discovery fails + """ + if not self.enabled: + logger.info("Home Assistant plugin is disabled, skipping initialization") + return + + if not self.ha_token: + raise ValueError("Home Assistant token is required") + + logger.info(f"Initializing Home Assistant plugin (URL: {self.ha_url})") + + # Create MCP client (used for REST API calls, not MCP protocol) + self.mcp_client = HAMCPClient( + base_url=self.ha_url, + token=self.ha_token, + timeout=self.timeout + ) + + # Test basic API connectivity with Template API + try: + logger.info("Testing Home Assistant API connectivity...") + test_result = await self.mcp_client._render_template("{{ 1 + 1 }}") + if str(test_result).strip() != "2": + raise ValueError(f"Unexpected template result: {test_result}") + logger.info("Home Assistant API connection successful") + except Exception as e: + raise MCPError(f"Failed to connect to Home Assistant API: {e}") + + logger.info("Home Assistant plugin initialized successfully") + + async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]: + """ + Execute Home Assistant command from wake word transcript. + + Called by the router when a wake word is detected in the transcript. + The router has already stripped the wake word and extracted the command. 
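Pulling the config keys together, a plausible `config/plugins.yml` entry for this plugin. Field names are taken from `__init__` above, `BasePlugin.__init__`, and the validator earlier in this diff; the values are placeholders:

```yaml
plugins:
  homeassistant:
    enabled: true
    access_level: transcript
    subscriptions:
      - transcript.streaming
    trigger:
      type: wake_word
      wake_word: vivi
    ha_url: http://localhost:8123
    ha_token: YOUR_LONG_LIVED_TOKEN
    timeout: 30
```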
+ + Args: + context: PluginContext containing: + - user_id: User ID who issued the command + - access_level: 'transcript' + - data: Dict with: + - command: str - Command with wake word already stripped + - original_transcript: str - Full transcript with wake word + - transcript: str - Original transcript + - segment_id: str - Unique segment identifier + - conversation_id: str - Current conversation ID + - metadata: Optional additional metadata + + Returns: + PluginResult with: + - success: True if command executed + - message: User-friendly response + - data: Dict with action details + - should_continue: False to stop normal processing + + Example: + Context data: + { + 'command': 'turn off study lights', + 'original_transcript': 'vivi turn off study lights', + 'conversation_id': 'conv_123' + } + + Returns: + PluginResult( + success=True, + message="I've turned off 1 light in study", + data={'action': 'turn_off', 'entity_ids': ['light.tubelight_3']}, + should_continue=False + ) + """ + command = context.data.get('command', '') + + if not command: + return PluginResult( + success=False, + message="No command provided", + should_continue=True + ) + + if not self.mcp_client: + logger.error("MCP client not initialized") + return PluginResult( + success=False, + message="Sorry, Home Assistant is not connected", + should_continue=True + ) + + try: + # Step 1: Parse command using hybrid LLM + fallback parsing + logger.info(f"Processing HA command: '{command}'") + parsed = await self._parse_command_hybrid(command) + + if not parsed: + return PluginResult( + success=False, + message="Sorry, I couldn't understand that command", + should_continue=True + ) + + # Step 2: Resolve entities from parsed command + try: + entity_ids = await self._resolve_entities(parsed) + except ValueError as e: + logger.warning(f"Entity resolution failed: {e}") + return PluginResult( + success=False, + message=str(e), + should_continue=True + ) + + # Step 3: Determine service and domain + # Extract domain from first entity (all should have same domain for area-based) + domain = entity_ids[0].split('.')[0] if entity_ids else 'light' + + # Map action to service name + service_map = { + 'turn_on': 'turn_on', + 'turn_off': 'turn_off', + 'toggle': 'toggle', + 'set_brightness': 'turn_on', # brightness uses turn_on with params + 'set_color': 'turn_on' # color uses turn_on with params + } + service = service_map.get(parsed.action, 'turn_on') + + # Step 4: Call Home Assistant service + logger.info( + f"Calling {domain}.{service} for {len(entity_ids)} entities: {entity_ids}" + ) + + result = await self.mcp_client.call_service( + domain=domain, + service=service, + entity_ids=entity_ids, + **parsed.parameters + ) + + # Step 5: Format user-friendly response + entity_type_name = parsed.entity_type or domain + if parsed.target_type == 'area': + message = ( + f"I've {parsed.action.replace('_', ' ')} {len(entity_ids)} " + f"{entity_type_name}{'s' if len(entity_ids) != 1 else ''} " + f"in {parsed.target}" + ) + elif parsed.target_type == 'all_in_area': + message = ( + f"I've {parsed.action.replace('_', ' ')} {len(entity_ids)} " + f"entities in {parsed.target}" + ) + else: + message = f"I've {parsed.action.replace('_', ' ')} {parsed.target}" + + logger.info(f"HA command executed successfully: {message}") + + return PluginResult( + success=True, + data={ + 'action': parsed.action, + 'entity_ids': entity_ids, + 'target_type': parsed.target_type, + 'target': parsed.target, + 'ha_result': result + }, + message=message, + should_continue=False # 
Stop normal processing - HA command handled + ) + + except MCPError as e: + logger.error(f"Home Assistant API error: {e}", exc_info=True) + return PluginResult( + success=False, + message=f"Sorry, Home Assistant couldn't execute that: {e}", + should_continue=True + ) + except Exception as e: + logger.error(f"Command execution failed: {e}", exc_info=True) + return PluginResult( + success=False, + message="Sorry, something went wrong while executing that command", + should_continue=True + ) + + async def cleanup(self): + """Clean up resources""" + if self.mcp_client: + await self.mcp_client.close() + logger.info("Closed Home Assistant MCP client") + + async def _ensure_cache_initialized(self): + """Ensure entity cache is initialized. Lazy-load on first use.""" + if not self.cache_initialized: + logger.info("Entity cache not initialized, refreshing...") + await self._refresh_cache() + self.cache_initialized = True + + async def _refresh_cache(self): + """ + Refresh the entity cache from Home Assistant. + + Fetches: + - All areas + - Entities in each area + - Entity state details + """ + if not self.mcp_client: + logger.error("Cannot refresh cache: MCP client not initialized") + return + + try: + logger.info("Refreshing entity cache from Home Assistant...") + + # Fetch all areas + areas = await self.mcp_client.fetch_areas() + logger.debug(f"Fetched {len(areas)} areas: {areas}") + + # Fetch entities for each area + area_entities = {} + for area in areas: + entities = await self.mcp_client.fetch_area_entities(area) + area_entities[area] = entities + logger.debug(f"Area '{area}': {len(entities)} entities") + + # Fetch all entity states + entity_details = await self.mcp_client.fetch_entity_states() + logger.debug(f"Fetched {len(entity_details)} entity states") + + # Create cache + from datetime import datetime + self.entity_cache = EntityCache( + areas=areas, + area_entities=area_entities, + entity_details=entity_details, + last_refresh=datetime.now() + ) + + logger.info( + f"Entity cache refreshed: {len(areas)} areas, " + f"{len(entity_details)} entities" + ) + + except Exception as e: + logger.error(f"Failed to refresh entity cache: {e}", exc_info=True) + raise + + async def _parse_command_with_llm(self, command: str) -> Optional['ParsedCommand']: + """ + Parse command using LLM with structured system prompt. 
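A worked trace of the dispatch path above for "set study lights to 50%", with illustrative entity IDs, showing how the action-to-service mapping and domain extraction combine:

```python
# Inputs as they would arrive from the parser and entity resolver
parsed_action = "set_brightness"        # from the LLM parser
entity_ids = ["light.tubelight_3"]      # from _resolve_entities (area "study", type "light")

service_map = {
    "turn_on": "turn_on",
    "turn_off": "turn_off",
    "toggle": "toggle",
    "set_brightness": "turn_on",  # brightness rides on turn_on with params
    "set_color": "turn_on",       # color likewise
}

domain = entity_ids[0].split(".")[0]    # -> "light"
service = service_map[parsed_action]    # -> "turn_on"
params = {"brightness_pct": 50}
print(f"{domain}.{service}", entity_ids, params)
# -> light.turn_on ['light.tubelight_3'] {'brightness_pct': 50}
```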
+ + Args: + command: Natural language command (wake word already stripped) + + Returns: + ParsedCommand if parsing succeeds, None otherwise + + Example: + >>> await self._parse_command_with_llm("turn off study lights") + ParsedCommand( + action="turn_off", + target_type="area", + target="study", + entity_type="light", + parameters={} + ) + """ + try: + from advanced_omi_backend.llm_client import get_llm_client + from .command_parser import COMMAND_PARSER_SYSTEM_PROMPT, ParsedCommand + + llm_client = get_llm_client() + + logger.debug(f"Parsing command with LLM: '{command}'") + + # Use OpenAI chat format with system + user messages + response = llm_client.client.chat.completions.create( + model=llm_client.model, + messages=[ + {"role": "system", "content": COMMAND_PARSER_SYSTEM_PROMPT}, + {"role": "user", "content": f'Command: "{command}"\n\nReturn JSON only.'} + ], + temperature=0.1, + max_tokens=150 + ) + + result_text = response.choices[0].message.content.strip() + logger.debug(f"LLM response: {result_text}") + + # Remove markdown code blocks if present + if result_text.startswith('```'): + lines = result_text.split('\n') + result_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else result_text + result_text = result_text.strip() + + # Parse JSON response + result_json = json.loads(result_text) + + # Validate required fields + required_fields = ['action', 'target_type', 'target'] + if not all(field in result_json for field in required_fields): + logger.warning(f"LLM response missing required fields: {result_json}") + return None + + parsed = ParsedCommand( + action=result_json['action'], + target_type=result_json['target_type'], + target=result_json['target'], + entity_type=result_json.get('entity_type'), + parameters=result_json.get('parameters', {}) + ) + + logger.info( + f"LLM parsed command: action={parsed.action}, " + f"target_type={parsed.target_type}, target={parsed.target}, " + f"entity_type={parsed.entity_type}" + ) + + return parsed + + except json.JSONDecodeError as e: + logger.error(f"Failed to parse LLM JSON response: {e}\nResponse: {result_text}") + return None + except Exception as e: + logger.error(f"LLM command parsing failed: {e}", exc_info=True) + return None + + async def _resolve_entities(self, parsed: 'ParsedCommand') -> List[str]: + """ + Resolve ParsedCommand to actual Home Assistant entity IDs. + + Args: + parsed: ParsedCommand from LLM parsing + + Returns: + List of entity IDs to target + + Raises: + ValueError: If target not found or ambiguous + + Example: + >>> await self._resolve_entities(ParsedCommand( + ... action="turn_off", + ... target_type="area", + ... target="study", + ... entity_type="light" + ... )) + ["light.tubelight_3"] + """ + from .command_parser import ParsedCommand + + # Ensure cache is ready + await self._ensure_cache_initialized() + + if not self.entity_cache: + raise ValueError("Entity cache not initialized") + + if parsed.target_type == 'area': + # Get entities in area, filtered by type + entities = self.entity_cache.get_entities_in_area( + area=parsed.target, + entity_type=parsed.entity_type + ) + + if not entities: + entity_desc = f"{parsed.entity_type}s" if parsed.entity_type else "entities" + raise ValueError( + f"No {entity_desc} found in area '{parsed.target}'. 
" + f"Available areas: {', '.join(self.entity_cache.areas)}" + ) + + logger.info( + f"Resolved area '{parsed.target}' to {len(entities)} " + f"{parsed.entity_type or 'entity'}(s)" + ) + return entities + + elif parsed.target_type == 'all_in_area': + # Get ALL entities in area (no filter) + entities = self.entity_cache.get_entities_in_area( + area=parsed.target, + entity_type=None + ) + + if not entities: + raise ValueError( + f"No entities found in area '{parsed.target}'. " + f"Available areas: {', '.join(self.entity_cache.areas)}" + ) + + logger.info(f"Resolved 'all in {parsed.target}' to {len(entities)} entities") + return entities + + elif parsed.target_type == 'entity': + # Fuzzy match entity by name + entity_id = self.entity_cache.find_entity_by_name(parsed.target) + + if not entity_id: + raise ValueError( + f"Entity '{parsed.target}' not found. " + f"Try being more specific or check the entity name." + ) + + logger.info(f"Resolved entity '{parsed.target}' to {entity_id}") + return [entity_id] + + else: + raise ValueError(f"Unknown target type: {parsed.target_type}") + + async def _parse_command_fallback(self, command: str) -> Optional[Dict[str, Any]]: + """ + Fallback keyword-based command parser (used when LLM fails). + + Args: + command: Natural language command + + Returns: + Dict with 'tool', 'arguments', and optional metadata + None if parsing fails + + Example: + Input: "turn off the hall lights" + Output: { + "tool": "turn_off", + "arguments": {"entity_id": "light.hall_light"}, + "friendly_name": "Hall Light", + "action": "turn_off" + } + """ + logger.debug("Using fallback keyword-based parsing") + command_lower = command.lower().strip() + + # Determine action + tool = None + if any(word in command_lower for word in ['turn off', 'off', 'disable']): + tool = 'turn_off' + action_desc = 'turned off' + elif any(word in command_lower for word in ['turn on', 'on', 'enable']): + tool = 'turn_on' + action_desc = 'turned on' + elif 'toggle' in command_lower: + tool = 'toggle' + action_desc = 'toggled' + else: + logger.warning(f"Unknown action in command: {command}") + return None + + # Extract entity name from command + entity_query = command_lower + for action_word in ['turn off', 'turn on', 'toggle', 'off', 'on', 'the']: + entity_query = entity_query.replace(action_word, '').strip() + + logger.info(f"Searching for entity: '{entity_query}'") + + # Return placeholder (this will work if entity ID matches pattern) + return { + "tool": tool, + "arguments": { + "entity_id": f"light.{entity_query.replace(' ', '_')}" + }, + "friendly_name": entity_query.title(), + "action_desc": action_desc + } + + async def _parse_command_hybrid(self, command: str) -> Optional['ParsedCommand']: + """ + Hybrid command parser: Try LLM first, fallback to keywords. + + This provides the best of both worlds: + - LLM parsing for complex area-based and natural commands + - Keyword fallback for reliability when LLM fails or times out + + Args: + command: Natural language command + + Returns: + ParsedCommand if successful, None otherwise + + Example: + >>> await self._parse_command_hybrid("turn off study lights") + ParsedCommand(action="turn_off", target_type="area", target="study", ...) 
+ """ + import asyncio + from .command_parser import ParsedCommand + + # Try LLM parsing with timeout + try: + logger.debug("Attempting LLM-based command parsing...") + parsed = await asyncio.wait_for( + self._parse_command_with_llm(command), + timeout=5.0 + ) + + if parsed: + logger.info("LLM parsing succeeded") + return parsed + else: + logger.warning("LLM parsing returned None, falling back to keywords") + + except asyncio.TimeoutError: + logger.warning("LLM parsing timed out (>5s), falling back to keywords") + except Exception as e: + logger.warning(f"LLM parsing failed: {e}, falling back to keywords") + + # Fallback to keyword-based parsing + try: + logger.debug("Using fallback keyword parsing...") + fallback_result = await self._parse_command_fallback(command) + + if not fallback_result: + return None + + # Convert fallback format to ParsedCommand + # Extract entity_id from arguments + entity_id = fallback_result['arguments'].get('entity_id', '') + entity_name = entity_id.split('.', 1)[1] if '.' in entity_id else entity_id + + # Simple heuristic: assume it's targeting a single entity + parsed = ParsedCommand( + action=fallback_result['tool'], + target_type='entity', + target=entity_name.replace('_', ' '), + entity_type=None, + parameters={} + ) + + logger.info("Fallback parsing succeeded") + return parsed + + except Exception as e: + logger.error(f"Fallback parsing failed: {e}", exc_info=True) + return None diff --git a/backends/advanced/src/advanced_omi_backend/plugins/router.py b/backends/advanced/src/advanced_omi_backend/plugins/router.py new file mode 100644 index 00000000..21b82eb8 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/router.py @@ -0,0 +1,227 @@ +""" +Plugin routing system for multi-level plugin architecture. + +Routes pipeline events to appropriate plugins based on access level and triggers. +""" + +import logging +import re +import string +from typing import Dict, List, Optional + +from .base import BasePlugin, PluginContext, PluginResult + +logger = logging.getLogger(__name__) + + +def normalize_text_for_wake_word(text: str) -> str: + """ + Normalize text for wake word matching. + - Lowercase + - Replace punctuation with spaces + - Collapse multiple spaces to single space + - Strip leading/trailing whitespace + + Example: + "Hey, Vivi!" -> "hey vivi" + "HEY VIVI" -> "hey vivi" + "Hey-Vivi" -> "hey vivi" + """ + # Lowercase + text = text.lower() + # Replace punctuation with spaces (instead of removing, to preserve word boundaries) + text = text.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))) + # Normalize whitespace (collapse multiple spaces to single space) + text = re.sub(r'\s+', ' ', text) + # Strip leading/trailing whitespace + return text.strip() + + +def extract_command_after_wake_word(transcript: str, wake_word: str) -> str: + """ + Intelligently extract command after wake word in original transcript. + + Handles punctuation and spacing variations by creating a flexible regex pattern. 
+ + Example: + transcript: "Hey, Vivi, turn off lights" + wake_word: "hey vivi" + -> extracts: "turn off lights" + + Args: + transcript: Original transcript text with punctuation + wake_word: Configured wake word (will be normalized) + + Returns: + Command text after wake word, or full transcript if wake word boundary not found + """ + # Split wake word into parts (normalized) + wake_word_parts = normalize_text_for_wake_word(wake_word).split() + + if not wake_word_parts: + return transcript.strip() + + # Create regex pattern that allows punctuation/whitespace between parts + # Example: "hey" + "vivi" -> r"hey[\s,.\-!?]*vivi[\s,.\-!?]*" + # The pattern matches the wake word parts with optional punctuation/whitespace between and after + pattern_parts = [re.escape(part) for part in wake_word_parts] + # Allow optional punctuation/whitespace between parts + pattern = r'[\s,.\-!?;:]*'.join(pattern_parts) + # Add trailing punctuation/whitespace consumption after last wake word part + pattern = '^' + pattern + r'[\s,.\-!?;:]*' + + # Try to match wake word at start of transcript (case-insensitive) + match = re.match(pattern, transcript, re.IGNORECASE) + + if match: + # Extract everything after the matched wake word (including trailing punctuation) + command = transcript[match.end():].strip() + return command + else: + # Fallback: couldn't find wake word boundary, return full transcript + logger.warning(f"Could not find wake word boundary for '{wake_word}' in '{transcript}', using full transcript") + return transcript.strip() + + +class PluginRouter: + """Routes pipeline events to appropriate plugins based on event subscriptions""" + + def __init__(self): + self.plugins: Dict[str, BasePlugin] = {} + # Index plugins by event subscription for fast lookup + self._plugins_by_event: Dict[str, List[str]] = {} + + def register_plugin(self, plugin_id: str, plugin: BasePlugin): + """Register a plugin with the router""" + self.plugins[plugin_id] = plugin + + # Index by each event subscription + for event in plugin.subscriptions: + if event not in self._plugins_by_event: + self._plugins_by_event[event] = [] + self._plugins_by_event[event].append(plugin_id) + + logger.info(f"Registered plugin '{plugin_id}' for events: {plugin.subscriptions}") + + async def dispatch_event( + self, + event: str, + user_id: str, + data: Dict, + metadata: Optional[Dict] = None + ) -> List[PluginResult]: + """ + Dispatch event to all subscribed plugins. + + Args: + event: Event name (e.g., 'transcript.streaming', 'conversation.complete') + user_id: User ID for context + data: Event-specific data + metadata: Optional metadata + + Returns: + List of plugin results + """ + results = [] + + # Get plugins subscribed to this event + plugin_ids = self._plugins_by_event.get(event, []) + + for plugin_id in plugin_ids: + plugin = self.plugins[plugin_id] + + if not plugin.enabled: + continue + + # Check trigger condition (wake_word, etc.) 
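Taken together, the two helpers above make wake-word detection tolerant of transcription punctuation: matching happens on normalized text, while command extraction runs against the original transcript. A quick usage sketch of the functions defined in `plugins/router.py`:

```python
from advanced_omi_backend.plugins.router import (
    extract_command_after_wake_word,
    normalize_text_for_wake_word,
)

# Punctuation becomes spaces, case is folded, whitespace collapses.
assert normalize_text_for_wake_word("Hey, Vivi!") == "hey vivi"
assert normalize_text_for_wake_word("HEY   VIVI") == "hey vivi"

# The flexible regex built from "hey vivi" still matches "Hey, Vivi," in the
# original transcript, so only the command text survives extraction.
command = extract_command_after_wake_word("Hey, Vivi, turn off lights", "hey vivi")
assert command == "turn off lights"
```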
+ if not await self._should_trigger(plugin, data): + continue + + # Execute plugin + try: + context = PluginContext( + user_id=user_id, + event=event, + data=data, + metadata=metadata or {} + ) + + result = await self._execute_plugin(plugin, event, context) + + if result: + results.append(result) + + # If plugin says stop processing, break + if not result.should_continue: + logger.info(f"Plugin '{plugin_id}' stopped further processing") + break + + except Exception as e: + logger.error(f"Error executing plugin '{plugin_id}': {e}", exc_info=True) + + return results + + async def _should_trigger(self, plugin: BasePlugin, data: Dict) -> bool: + """Check if plugin should be triggered based on trigger configuration""" + trigger_type = plugin.trigger.get('type', 'always') + + if trigger_type == 'always': + return True + + elif trigger_type == 'wake_word': + # Normalize transcript for matching (handles punctuation and spacing) + transcript = data.get('transcript', '') + normalized_transcript = normalize_text_for_wake_word(transcript) + + # Support both singular 'wake_word' and plural 'wake_words' (list) + wake_words = plugin.trigger.get('wake_words', []) + if not wake_words: + # Fallback to singular wake_word for backward compatibility + wake_word = plugin.trigger.get('wake_word', '') + if wake_word: + wake_words = [wake_word] + + # Check if transcript starts with any wake word (after normalization) + for wake_word in wake_words: + normalized_wake_word = normalize_text_for_wake_word(wake_word) + if normalized_wake_word and normalized_transcript.startswith(normalized_wake_word): + # Smart extraction: find where wake word actually ends in original text + command = extract_command_after_wake_word(transcript, wake_word) + data['command'] = command + data['original_transcript'] = transcript + logger.debug(f"Wake word '{wake_word}' detected. 
Original: '{transcript}', Command: '{command}'") + return True + + return False + + elif trigger_type == 'conditional': + # Future: Custom condition checking + return True + + return False + + async def _execute_plugin( + self, + plugin: BasePlugin, + event: str, + context: PluginContext + ) -> Optional[PluginResult]: + """Execute plugin method for specified event""" + # Map events to plugin callback methods + if event.startswith('transcript.'): + return await plugin.on_transcript(context) + elif event.startswith('conversation.'): + return await plugin.on_conversation_complete(context) + elif event.startswith('memory.'): + return await plugin.on_memory_processed(context) + + return None + + async def cleanup_all(self): + """Clean up all registered plugins""" + for plugin_id, plugin in self.plugins.items(): + try: + await plugin.cleanup() + logger.info(f"Cleaned up plugin '{plugin_id}'") + except Exception as e: + logger.error(f"Error cleaning up plugin '{plugin_id}': {e}") diff --git a/backends/advanced/src/advanced_omi_backend/plugins/test_event/__init__.py b/backends/advanced/src/advanced_omi_backend/plugins/test_event/__init__.py new file mode 100644 index 00000000..5f3f2ecf --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/test_event/__init__.py @@ -0,0 +1,5 @@ +"""Test Event Plugin for integration testing""" + +from .plugin import TestEventPlugin + +__all__ = ['TestEventPlugin'] diff --git a/backends/advanced/src/advanced_omi_backend/plugins/test_event/event_storage.py b/backends/advanced/src/advanced_omi_backend/plugins/test_event/event_storage.py new file mode 100644 index 00000000..00bc674d --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/test_event/event_storage.py @@ -0,0 +1,297 @@ +""" +Event storage module for test plugin using SQLite. + +Provides async SQLite operations for logging and querying plugin events. 
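The event-to-callback mapping in `_execute_plugin` means a plugin only implements hooks for the event families it subscribes to. A minimal dispatch round-trip sketch — it assumes `BasePlugin` derives `.enabled`, `.subscriptions`, and `.trigger` from the config dict, as the router's attribute access implies; the plugin class and config values are illustrative:

```python
import asyncio

from advanced_omi_backend.plugins.base import BasePlugin, PluginContext, PluginResult
from advanced_omi_backend.plugins.router import PluginRouter


class EchoPlugin(BasePlugin):
    """Illustrative plugin that just prints transcript events."""

    async def on_transcript(self, context: PluginContext) -> PluginResult:
        print(f"[{context.event}] {context.data.get('transcript', '')}")
        return PluginResult(success=True, message="ok", should_continue=True)


async def main():
    router = PluginRouter()
    router.register_plugin("echo", EchoPlugin({
        "enabled": True,
        "subscriptions": ["transcript.streaming"],
        "trigger": {"type": "always"},
    }))
    # 'transcript.*' events route to on_transcript per _execute_plugin.
    await router.dispatch_event(
        event="transcript.streaming",
        user_id="user-1",
        data={"transcript": "hello world", "conversation_id": "c1"},
    )

asyncio.run(main())
```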
+""" +import json +import logging +import os +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional + +import aiosqlite + +logger = logging.getLogger(__name__) + + +class EventStorage: + """SQLite-based event storage for test plugin""" + + def __init__(self, db_path: str = "/app/debug/test_plugin_events.db"): + self.db_path = db_path + self.db: Optional[aiosqlite.Connection] = None + + async def initialize(self): + """Initialize database and create tables""" + # Ensure directory exists + logger.info(f"🔍 DEBUG: Initializing event storage with db_path={self.db_path}") + + db_dir = Path(self.db_path).parent + logger.info(f"🔍 DEBUG: Database directory: {db_dir}") + logger.info(f"🔍 DEBUG: Directory exists before mkdir: {db_dir.exists()}") + + try: + db_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"🔍 DEBUG: Directory created/verified: {db_dir}") + logger.info(f"🔍 DEBUG: Directory permissions: {oct(db_dir.stat().st_mode)}") + except Exception as e: + logger.error(f"🔍 DEBUG: Failed to create directory: {e}") + raise + + logger.info(f"🔍 DEBUG: Attempting to connect to SQLite database...") + try: + self.db = await aiosqlite.connect(self.db_path) + logger.info(f"🔍 DEBUG: Successfully connected to database") + + # Enable WAL mode for better concurrent access (allows concurrent reads/writes) + # This fixes the "readonly database" error when Robot tests access from host + await self.db.execute("PRAGMA journal_mode=WAL") + await self.db.execute("PRAGMA busy_timeout=5000") # Wait up to 5s for locks + logger.info(f"✓ Enabled WAL mode for concurrent access") + + # Set file permissions to 666 so host user can write (container runs as root) + # Robot tests run as host user and need write access to the database + try: + os.chmod(self.db_path, 0o666) + # Also set permissions on WAL and SHM files if they exist + wal_file = f"{self.db_path}-wal" + shm_file = f"{self.db_path}-shm" + if os.path.exists(wal_file): + os.chmod(wal_file, 0o666) + if os.path.exists(shm_file): + os.chmod(shm_file, 0o666) + logger.info(f"✓ Set database file permissions to 666 for host access") + except Exception as perm_error: + logger.warning(f"Could not set database permissions: {perm_error}") + + except Exception as e: + logger.error(f"🔍 DEBUG: Failed to connect to database: {e}") + logger.error(f"🔍 DEBUG: Database file exists: {Path(self.db_path).exists()}") + if Path(self.db_path).exists(): + logger.error(f"🔍 DEBUG: Database file permissions: {oct(Path(self.db_path).stat().st_mode)}") + raise + + # Create events table + await self.db.execute(""" + CREATE TABLE IF NOT EXISTS plugin_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp DATETIME NOT NULL, + event TEXT NOT NULL, + user_id TEXT NOT NULL, + data TEXT NOT NULL, + metadata TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Create index for faster queries + await self.db.execute(""" + CREATE INDEX IF NOT EXISTS idx_event_type + ON plugin_events(event) + """) + + await self.db.execute(""" + CREATE INDEX IF NOT EXISTS idx_user_id + ON plugin_events(user_id) + """) + + await self.db.commit() + logger.info(f"Event storage initialized at {self.db_path}") + + async def log_event( + self, + event: str, + user_id: str, + data: Dict[str, Any], + metadata: Optional[Dict[str, Any]] = None + ) -> int: + """ + Log an event to the database. 
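WAL mode is what lets the Robot tests on the host read the database while the backend container is still writing to it. A host-side read sketch — synchronous `sqlite3` is enough there; the path assumes the container's `/app/debug` is bind-mounted to `./debug`:

```python
import sqlite3

# Hypothetical host-side path; assumes /app/debug is bind-mounted to ./debug.
conn = sqlite3.connect("./debug/test_plugin_events.db")
conn.execute("PRAGMA busy_timeout = 5000")  # mirror the writer's lock tolerance

rows = conn.execute(
    "SELECT event, user_id, data FROM plugin_events ORDER BY created_at DESC LIMIT 5"
).fetchall()
for event, user_id, data in rows:
    print(event, user_id, data)
conn.close()
```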
+ + Args: + event: Event name (e.g., 'transcript.batch') + user_id: User ID from context + data: Event data dictionary + metadata: Optional metadata dictionary + + Returns: + Row ID of inserted event + """ + if not self.db: + raise RuntimeError("Event storage not initialized") + + timestamp = datetime.utcnow().isoformat() + data_json = json.dumps(data) + metadata_json = json.dumps(metadata) if metadata else None + + cursor = await self.db.execute( + """ + INSERT INTO plugin_events (timestamp, event, user_id, data, metadata) + VALUES (?, ?, ?, ?, ?) + """, + (timestamp, event, user_id, data_json, metadata_json) + ) + + await self.db.commit() + row_id = cursor.lastrowid + + logger.debug( + f"Logged event: {event} for user {user_id} (row_id={row_id})" + ) + + return row_id + + async def get_events_by_type(self, event: str) -> List[Dict[str, Any]]: + """ + Query events by event type. + + Args: + event: Event name to filter by + + Returns: + List of event dictionaries + """ + if not self.db: + raise RuntimeError("Event storage not initialized") + + cursor = await self.db.execute( + """ + SELECT id, timestamp, event, user_id, data, metadata, created_at + FROM plugin_events + WHERE event = ? + ORDER BY created_at DESC + """, + (event,) + ) + + rows = await cursor.fetchall() + return self._rows_to_dicts(rows) + + async def get_events_by_user(self, user_id: str) -> List[Dict[str, Any]]: + """ + Query events by user ID. + + Args: + user_id: User ID to filter by + + Returns: + List of event dictionaries + """ + if not self.db: + raise RuntimeError("Event storage not initialized") + + cursor = await self.db.execute( + """ + SELECT id, timestamp, event, user_id, data, metadata, created_at + FROM plugin_events + WHERE user_id = ? + ORDER BY created_at DESC + """, + (user_id,) + ) + + rows = await cursor.fetchall() + return self._rows_to_dicts(rows) + + async def get_all_events(self) -> List[Dict[str, Any]]: + """ + Get all logged events. + + Returns: + List of all event dictionaries + """ + if not self.db: + raise RuntimeError("Event storage not initialized") + + cursor = await self.db.execute( + """ + SELECT id, timestamp, event, user_id, data, metadata, created_at + FROM plugin_events + ORDER BY created_at DESC + """ + ) + + rows = await cursor.fetchall() + return self._rows_to_dicts(rows) + + async def clear_events(self) -> int: + """ + Clear all events from the database. + + Returns: + Number of rows deleted + """ + if not self.db: + raise RuntimeError("Event storage not initialized") + + cursor = await self.db.execute("DELETE FROM plugin_events") + await self.db.commit() + + deleted = cursor.rowcount + logger.info(f"Cleared {deleted} events from database") + + return deleted + + async def get_event_count(self, event: Optional[str] = None) -> int: + """ + Get count of events. + + Args: + event: Optional event type to filter by + + Returns: + Count of matching events + """ + if not self.db: + raise RuntimeError("Event storage not initialized") + + if event: + cursor = await self.db.execute( + "SELECT COUNT(*) FROM plugin_events WHERE event = ?", + (event,) + ) + else: + cursor = await self.db.execute( + "SELECT COUNT(*) FROM plugin_events" + ) + + row = await cursor.fetchone() + return row[0] if row else 0 + + def _rows_to_dicts(self, rows: List[tuple]) -> List[Dict[str, Any]]: + """ + Convert database rows to dictionaries. 
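End to end, the storage API is small: initialize once, `log_event` per dispatch, then query by type or user. A usage sketch against the `EventStorage` class above (the path and IDs are illustrative):

```python
import asyncio

from advanced_omi_backend.plugins.test_event.event_storage import EventStorage


async def main():
    storage = EventStorage(db_path="/tmp/plugin_events.db")  # illustrative path
    await storage.initialize()

    await storage.log_event(
        event="transcript.batch",
        user_id="user-1",
        data={"transcript": "hello", "conversation_id": "c1"},
        metadata={"source": "upload"},
    )

    assert await storage.get_event_count("transcript.batch") == 1
    events = await storage.get_events_by_type("transcript.batch")
    # _rows_to_dicts flattens data fields to the top level:
    assert events[0]["transcript"] == "hello"

    await storage.cleanup()

asyncio.run(main())
```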
+ + Args: + rows: List of database row tuples + + Returns: + List of event dictionaries + """ + events = [] + + for row in rows: + event_dict = { + 'id': row[0], + 'timestamp': row[1], + 'event': row[2], + 'user_id': row[3], + 'data': json.loads(row[4]) if row[4] else {}, + 'metadata': json.loads(row[5]) if row[5] else {}, + 'created_at': row[6] + } + + # Flatten data fields to top level for easier access in tests + if isinstance(event_dict['data'], dict): + event_dict.update(event_dict['data']) + + events.append(event_dict) + + return events + + async def cleanup(self): + """Close database connection""" + if self.db: + await self.db.close() + logger.info("Event storage connection closed") diff --git a/backends/advanced/src/advanced_omi_backend/plugins/test_event/plugin.py b/backends/advanced/src/advanced_omi_backend/plugins/test_event/plugin.py new file mode 100644 index 00000000..6b96e078 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/test_event/plugin.py @@ -0,0 +1,221 @@ +""" +Test Event Plugin + +Logs all plugin events to SQLite database for integration testing. +Subscribes to all event types to verify event dispatch system works correctly. +""" +import logging +from typing import Any, Dict, List, Optional + +from advanced_omi_backend.plugins.base import BasePlugin, PluginContext, PluginResult +from .event_storage import EventStorage + +logger = logging.getLogger(__name__) + + +class TestEventPlugin(BasePlugin): + """ + Test plugin that logs all events for verification. + + Subscribes to: + - transcript.streaming: Real-time WebSocket transcription + - transcript.batch: File upload batch transcription + - conversation.complete: Conversation processing complete + - memory.processed: Memory extraction complete + + All events are logged to SQLite database with full context for test verification. + """ + + SUPPORTED_ACCESS_LEVELS: List[str] = ['transcript', 'conversation', 'memory'] + + def __init__(self, config: Dict[str, Any]): + super().__init__(config) + self.storage = EventStorage( + db_path=config.get('db_path', '/app/debug/test_plugin_events.db') + ) + self.event_count = 0 + + async def initialize(self): + """Initialize the test plugin and event storage""" + try: + await self.storage.initialize() + logger.info("✅ Test Event Plugin initialized successfully") + except Exception as e: + logger.error(f"❌ Failed to initialize Test Event Plugin: {e}") + raise + + async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]: + """ + Log transcript events (streaming or batch). 
+ + Context data contains: + - transcript: str - The transcript text + - conversation_id: str - Conversation ID + - For streaming: is_final, confidence, words, segments + - For batch: word_count, segments + + Args: + context: Plugin context with event data + + Returns: + PluginResult indicating success + """ + try: + # Determine which transcript event this is based on context.event + event_type = context.event # 'transcript.streaming' or 'transcript.batch' + + # Extract key data fields + transcript = context.data.get('transcript', '') + conversation_id = context.data.get('conversation_id', 'unknown') + + # Log to storage + row_id = await self.storage.log_event( + event=event_type, + user_id=context.user_id, + data=context.data, + metadata=context.metadata + ) + + self.event_count += 1 + + logger.info( + f"📝 Logged {event_type} event (row_id={row_id}): " + f"user={context.user_id}, " + f"conversation={conversation_id}, " + f"transcript='{transcript[:50]}...'" + ) + + return PluginResult( + success=True, + message=f"Transcript event logged (row_id={row_id})", + should_continue=True # Don't block normal processing + ) + + except Exception as e: + logger.error(f"Error logging transcript event: {e}", exc_info=True) + return PluginResult( + success=False, + message=f"Failed to log transcript event: {e}", + should_continue=True + ) + + async def on_conversation_complete(self, context: PluginContext) -> Optional[PluginResult]: + """ + Log conversation completion events. + + Context data contains: + - conversation: dict - Full conversation data + - transcript: str - Complete conversation transcript + - duration: float - Conversation duration + - conversation_id: str - Conversation identifier + + Args: + context: Plugin context with event data + + Returns: + PluginResult indicating success + """ + try: + conversation_id = context.data.get('conversation_id', 'unknown') + duration = context.data.get('duration', 0) + + # Log to storage + row_id = await self.storage.log_event( + event=context.event, # 'conversation.complete' + user_id=context.user_id, + data=context.data, + metadata=context.metadata + ) + + self.event_count += 1 + + logger.info( + f"📝 Logged conversation.complete event (row_id={row_id}): " + f"user={context.user_id}, " + f"conversation={conversation_id}, " + f"duration={duration:.2f}s" + ) + + return PluginResult( + success=True, + message=f"Conversation event logged (row_id={row_id})", + should_continue=True + ) + + except Exception as e: + logger.error(f"Error logging conversation event: {e}", exc_info=True) + return PluginResult( + success=False, + message=f"Failed to log conversation event: {e}", + should_continue=True + ) + + async def on_memory_processed(self, context: PluginContext) -> Optional[PluginResult]: + """ + Log memory processing events. 
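All of the test plugin's handlers return `should_continue=True`, so it observes events without blocking other plugins; per the router loop above, returning `should_continue=False` stops dispatch of that event to later plugins. A sketch of a short-circuiting handler (a method body for some hypothetical `BasePlugin` subclass, not part of this diff):

```python
from advanced_omi_backend.plugins.base import PluginContext, PluginResult


# Method-body sketch for an illustrative BasePlugin subclass:
async def on_transcript(self, context: PluginContext) -> PluginResult:
    if context.data.get("command"):
        # Claim the event so plugins later in the dispatch order don't also act on it.
        return PluginResult(
            success=True,
            message="command consumed",
            should_continue=False,  # router breaks out of its plugin loop
        )
    return PluginResult(success=True, message="ignored", should_continue=True)
```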
+ + Context data contains: + - memories: list - Extracted memories + - conversation: dict - Source conversation + - memory_count: int - Number of memories created + - conversation_id: str - Conversation identifier + + Metadata contains: + - processing_time: float - Time spent processing + - memory_provider: str - Provider name + + Args: + context: Plugin context with event data + + Returns: + PluginResult indicating success + """ + try: + conversation_id = context.data.get('conversation_id', 'unknown') + memory_count = context.data.get('memory_count', 0) + memory_provider = context.metadata.get('memory_provider', 'unknown') + processing_time = context.metadata.get('processing_time', 0) + + # Log to storage + row_id = await self.storage.log_event( + event=context.event, # 'memory.processed' + user_id=context.user_id, + data=context.data, + metadata=context.metadata + ) + + self.event_count += 1 + + logger.info( + f"📝 Logged memory.processed event (row_id={row_id}): " + f"user={context.user_id}, " + f"conversation={conversation_id}, " + f"memory_count={memory_count}, " + f"provider={memory_provider}, " + f"processing_time={processing_time:.2f}s" + ) + + return PluginResult( + success=True, + message=f"Memory event logged (row_id={row_id})", + should_continue=True + ) + + except Exception as e: + logger.error(f"Error logging memory event: {e}", exc_info=True) + return PluginResult( + success=False, + message=f"Failed to log memory event: {e}", + should_continue=True + ) + + async def cleanup(self): + """Clean up plugin resources""" + try: + logger.info( + f"🧹 Test Event Plugin shutting down. " + f"Logged {self.event_count} total events" + ) + await self.storage.cleanup() + except Exception as e: + logger.error(f"Error during test plugin cleanup: {e}") diff --git a/backends/advanced/src/advanced_omi_backend/routers/api_router.py b/backends/advanced/src/advanced_omi_backend/routers/api_router.py index 9e761f8e..80c03eae 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/api_router.py +++ b/backends/advanced/src/advanced_omi_backend/routers/api_router.py @@ -6,6 +6,7 @@ """ import logging +import os from fastapi import APIRouter @@ -40,5 +41,13 @@ router.include_router(queue_router) router.include_router(health_router) # Also include under /api for frontend compatibility +# Conditionally include test routes (only in test environments) +if os.getenv("DEBUG_DIR"): + try: + from .modules.test_routes import router as test_router + router.include_router(test_router) + logger.info("✅ Test routes loaded (test environment detected)") + except Exception as e: + logger.error(f"Error loading test routes: {e}", exc_info=True) logger.info("API router initialized with all sub-modules") diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/audio_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/audio_routes.py index 056e7667..58a33ff5 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/audio_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/audio_routes.py @@ -37,7 +37,6 @@ async def upload_audio_from_drive_folder( @router.get("/get_audio/{conversation_id}") async def get_conversation_audio( conversation_id: str, - cropped: bool = Query(default=False, description="Serve cropped (speech-only) audio instead of original"), token: Optional[str] = Query(default=None, description="JWT token for audio element access"), current_user: Optional[User] = Depends(current_active_user_optional), ): @@ -52,7 +51,6 @@ async def 
get_conversation_audio( Args: conversation_id: The conversation ID - cropped: If True, serve cropped audio; if False, serve original audio token: Optional JWT token as query param (for audio elements) current_user: Authenticated user (from header) @@ -75,8 +73,7 @@ async def get_conversation_audio( try: file_path = await audio_controller.get_conversation_audio_path( conversation_id=conversation_id, - user=current_user, - cropped=cropped + user=current_user ) except ValueError as e: # Map ValueError messages to appropriate HTTP status codes diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py index 8da0f5b0..2fc05425 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py @@ -42,14 +42,6 @@ async def get_conversation_detail( return await conversation_controller.get_conversation(conversation_id, current_user) -@router.get("/{audio_uuid}/cropped") -async def get_cropped_audio_info( - audio_uuid: str, current_user: User = Depends(current_active_user) -): - """Get cropped audio information for a conversation. Users can only access their own conversations.""" - return await audio_controller.get_cropped_audio_info(audio_uuid, current_user) - - # New reprocessing endpoints @router.post("/{conversation_id}/reprocess-transcript") async def reprocess_transcript( diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/health_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/health_routes.py index d7a62ba9..96ee72fe 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/health_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/health_routes.py @@ -139,7 +139,6 @@ async def health_check(): "chunk_dir": str(os.getenv("CHUNK_DIR", "./audio_chunks")), "active_clients": get_client_manager().get_client_count(), "new_conversation_timeout_minutes": float(os.getenv("NEW_CONVERSATION_TIMEOUT_MINUTES", "1.5")), - "audio_cropping_enabled": os.getenv("AUDIO_CROPPING_ENABLED", "true").lower() == "true", "llm_provider": (_llm_def.model_provider if _llm_def else None), "llm_model": (_llm_def.model_name if _llm_def else None), "llm_base_url": (_llm_def.model_url if _llm_def else None), diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py index 2da3767b..38bafa9a 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py @@ -635,6 +635,12 @@ async def flush_all_jobs( # Try to fetch the job job = Job.fetch(job_id, connection=redis_conn) + # Skip session-level jobs (e.g., speech_detection, audio_persistence) + # These run for the entire session and should not be killed by test cleanup + if job.meta and job.meta.get("session_level"): + logger.info(f"Skipping session-level job {job_id} ({job.description})") + continue + # Handle running jobs differently to avoid worker deadlock if job.is_started: # Send stop command to worker instead of canceling/deleting immediately diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py index 0c261675..93e94817 100644 --- 
a/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py @@ -176,6 +176,53 @@ async def validate_chat_config( raise HTTPException(status_code=500, detail=str(e)) +# Plugin Configuration Management Endpoints + +@router.get("/admin/plugins/config", response_class=Response) +async def get_plugins_config(current_user: User = Depends(current_superuser)): + """Get plugins configuration as YAML. Admin only.""" + try: + yaml_content = await system_controller.get_plugins_config_yaml() + return Response(content=yaml_content, media_type="text/plain") + except Exception as e: + logger.error(f"Failed to get plugins config: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/admin/plugins/config") +async def save_plugins_config( + request: Request, + current_user: User = Depends(current_superuser) +): + """Save plugins configuration from YAML. Admin only.""" + try: + yaml_content = await request.body() + yaml_str = yaml_content.decode('utf-8') + result = await system_controller.save_plugins_config_yaml(yaml_str) + return JSONResponse(content=result) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Failed to save plugins config: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/admin/plugins/config/validate") +async def validate_plugins_config( + request: Request, + current_user: User = Depends(current_superuser) +): + """Validate plugins configuration YAML. Admin only.""" + try: + yaml_content = await request.body() + yaml_str = yaml_content.decode('utf-8') + result = await system_controller.validate_plugins_config_yaml(yaml_str) + return JSONResponse(content=result) + except Exception as e: + logger.error(f"Failed to validate plugins config: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/streaming/status") async def get_streaming_status(request: Request, current_user: User = Depends(current_superuser)): """Get status of active streaming sessions and Redis Streams health. Admin only.""" diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/test_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/test_routes.py new file mode 100644 index 00000000..6255b6d6 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/test_routes.py @@ -0,0 +1,121 @@ +""" +Test-only API routes for integration testing. + +These routes are ONLY loaded when DEBUG_DIR environment variable is set, +which happens in test environments. They should never be available in production. +""" + +import logging +from typing import Optional +from fastapi import APIRouter, HTTPException + +from advanced_omi_backend.services.plugin_service import get_plugin_router + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/test", tags=["testing"]) + + +@router.delete("/plugins/events") +async def clear_test_plugin_events(): + """ + Clear all test plugin events. + + This endpoint is ONLY available in test environments and provides a clean + way to reset plugin event state between tests without direct database access. 
+ + Returns: + dict: Confirmation message with number of events cleared + """ + plugin_router = get_plugin_router() + + if not plugin_router: + return {"message": "No plugin router initialized", "events_cleared": 0} + + total_cleared = 0 + + # Clear events from all plugins that have storage + for plugin_id, plugin in plugin_router.plugins.items(): + if hasattr(plugin, 'storage') and plugin.storage: + try: + cleared = await plugin.storage.clear_events() + total_cleared += cleared + logger.info(f"Cleared {cleared} events from plugin '{plugin_id}'") + except Exception as e: + logger.error(f"Error clearing events from plugin '{plugin_id}': {e}") + + return { + "message": "Test plugin events cleared", + "events_cleared": total_cleared + } + + +@router.get("/plugins/events/count") +async def get_test_plugin_event_count(event_type: Optional[str] = None): + """ + Get count of test plugin events. + + Args: + event_type: Optional event type to filter by (e.g., 'transcript.batch') + + Returns: + dict: Event count and event type filter + """ + plugin_router = get_plugin_router() + + if not plugin_router: + return {"count": 0, "event_type": event_type, "message": "No plugin router initialized"} + + # Get count from first plugin with storage (usually test_event plugin) + for plugin_id, plugin in plugin_router.plugins.items(): + if hasattr(plugin, 'storage') and plugin.storage: + try: + count = await plugin.storage.get_event_count(event_type) + return { + "count": count, + "event_type": event_type, + "plugin_id": plugin_id + } + except Exception as e: + logger.error(f"Error getting event count from plugin '{plugin_id}': {e}") + raise HTTPException(status_code=500, detail=str(e)) + + return {"count": 0, "event_type": event_type, "message": "No plugin with storage found"} + + +@router.get("/plugins/events") +async def get_test_plugin_events(event_type: Optional[str] = None): + """ + Get test plugin events. 
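In integration tests these routes replace direct SQLite access from the host: clear state before a scenario, then poll the count until the pipeline has fired. A polling sketch — the `/api/test` mount point is an assumption based on how the router is included:

```python
import time

import httpx

BASE = "http://localhost:8000/api/test"  # assumed mount point


def wait_for_events(event_type: str, expected: int, timeout: float = 30.0) -> None:
    """Poll the test route until `expected` events of `event_type` are logged."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        resp = httpx.get(f"{BASE}/plugins/events/count", params={"event_type": event_type})
        resp.raise_for_status()
        if resp.json()["count"] >= expected:
            return
        time.sleep(1.0)
    raise TimeoutError(f"Expected {expected} '{event_type}' events, got fewer")


httpx.delete(f"{BASE}/plugins/events")  # reset state between tests
# ... drive the audio pipeline here ...
wait_for_events("transcript.batch", expected=1)
```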
+ + Args: + event_type: Optional event type to filter by + + Returns: + dict: List of events + """ + plugin_router = get_plugin_router() + + if not plugin_router: + return {"events": [], "message": "No plugin router initialized"} + + # Get events from first plugin with storage + for plugin_id, plugin in plugin_router.plugins.items(): + if hasattr(plugin, 'storage') and plugin.storage: + try: + if event_type: + events = await plugin.storage.get_events_by_type(event_type) + else: + events = await plugin.storage.get_all_events() + + return { + "events": events, + "count": len(events), + "event_type": event_type, + "plugin_id": plugin_id + } + except Exception as e: + logger.error(f"Error getting events from plugin '{plugin_id}': {e}") + raise HTTPException(status_code=500, detail=str(e)) + + return {"events": [], "message": "No plugin with storage found"} diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/websocket_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/websocket_routes.py index d9754a87..2671d7f6 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/websocket_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/websocket_routes.py @@ -18,21 +18,34 @@ # Create router router = APIRouter(tags=["websocket"]) -@router.websocket("/ws_omi") -async def ws_endpoint_omi( +@router.websocket("/ws") +async def ws_endpoint( ws: WebSocket, + codec: str = Query("pcm"), token: Optional[str] = Query(None), device_name: Optional[str] = Query(None), ): - """Accepts WebSocket connections with Wyoming protocol, decodes OMI Opus audio, and processes per-client.""" - await handle_omi_websocket(ws, token, device_name) - - -@router.websocket("/ws_pcm") -async def ws_endpoint_pcm( - ws: WebSocket, - token: Optional[str] = Query(None), - device_name: Optional[str] = Query(None) -): - """Accepts WebSocket connections, processes PCM audio per-client.""" - await handle_pcm_websocket(ws, token, device_name) \ No newline at end of file + """ + WebSocket endpoint for audio streaming with multiple codec support. + + Args: + codec: Audio codec (pcm, opus). Default: pcm + token: JWT auth token + device_name: Device identifier + + Examples: + /ws?codec=pcm&token=xxx&device_name=laptop + /ws?codec=opus&token=xxx&device_name=omi-device + """ + # Validate and normalize codec + codec = codec.lower() + if codec not in ["pcm", "opus"]: + logger.warning(f"Unsupported codec requested: {codec}") + await ws.close(code=1008, reason=f"Unsupported codec: {codec}. 
Supported: pcm, opus") + return + + # Route to appropriate handler + if codec == "opus": + await handle_omi_websocket(ws, token, device_name) + else: + await handle_pcm_websocket(ws, token, device_name) \ No newline at end of file diff --git a/backends/advanced/src/advanced_omi_backend/services/audio_stream/aggregator.py b/backends/advanced/src/advanced_omi_backend/services/audio_stream/aggregator.py index 26b985ab..f31f7453 100644 --- a/backends/advanced/src/advanced_omi_backend/services/audio_stream/aggregator.py +++ b/backends/advanced/src/advanced_omi_backend/services/audio_stream/aggregator.py @@ -49,8 +49,8 @@ async def get_session_results(self, session_id: str) -> list[dict]: "text": fields[b"text"].decode(), "confidence": float(fields[b"confidence"].decode()), "provider": fields[b"provider"].decode(), - "chunk_id": fields[b"chunk_id"].decode(), - "processing_time": float(fields[b"processing_time"].decode()), + "chunk_id": fields.get(b"chunk_id", b"unknown").decode(), # Handle missing chunk_id gracefully + "processing_time": float(fields.get(b"processing_time", b"0.0").decode()), "timestamp": float(fields[b"timestamp"].decode()), } @@ -82,8 +82,6 @@ async def get_combined_results(self, session_id: str) -> dict: """ Get all transcription results combined into a single aggregated result. - This is what an aggregator should do - combine multiple chunks into one. - Args: session_id: Session identifier @@ -109,43 +107,24 @@ async def get_combined_results(self, session_id: str) -> dict: "provider": None } - # Combine text - full_text = " ".join([r.get("text", "") for r in results if r.get("text")]) - - # Combine words - all_words = [] - for r in results: - if "words" in r and r["words"]: - all_words.extend(r["words"]) - - # Combine segments - all_segments = [] - for r in results: - if "segments" in r and r["segments"]: - all_segments.extend(r["segments"]) - - # Sort segments by start time - all_segments.sort(key=lambda s: s.get("start", 0.0)) - - # Calculate average confidence - confidences = [r.get("confidence", 0.0) for r in results] - avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0 - - # Get provider (assume all chunks from same provider) - provider = results[0].get("provider") if results else None + # For streaming providers (Deepgram), use ONLY the latest final result + # Each is_final=true result supersedes interim results for the same speech segment + # The latest result contains the most accurate transcription with best timing/confidence + latest_result = results[-1] combined = { - "text": full_text, - "words": all_words, - "segments": all_segments, - "chunk_count": len(results), - "total_confidence": avg_confidence, - "provider": provider + "text": latest_result.get("text", ""), + "words": latest_result.get("words", []), + "segments": latest_result.get("segments", []), + "chunk_count": len(results), # Track how many results were received + "total_confidence": latest_result.get("confidence", 0.0), + "provider": latest_result.get("provider") } - logger.debug( - f"📦 Combined {len(results)} chunks for session {session_id}: " - f"{len(full_text)} chars, {len(all_words)} words, {len(all_segments)} segments" + logger.info( + f"🔤 TRANSCRIPT [AGGREGATOR] session={session_id}, " + f"total_results={len(results)}, words={len(combined['words'])}, " + f"text=\"{combined['text']}\"" ) return combined @@ -188,7 +167,7 @@ async def get_realtime_results( "text": fields[b"text"].decode(), "confidence": float(fields[b"confidence"].decode()), "provider": 
fields[b"provider"].decode(), - "chunk_id": fields[b"chunk_id"].decode(), + "chunk_id": fields.get(b"chunk_id", b"unknown").decode(), # Handle missing chunk_id gracefully } # Optional fields diff --git a/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py b/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py index 8ae0646b..aeb12e02 100644 --- a/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py +++ b/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py @@ -11,8 +11,6 @@ import redis.asyncio as redis from redis import exceptions as redis_exceptions -from redis.asyncio.lock import Lock - logger = logging.getLogger(__name__) @@ -28,8 +26,8 @@ def __init__(self, provider_name: str, redis_client: redis.Redis, buffer_chunks: """ Initialize consumer. - Dynamically discovers all audio:stream:* streams and claims them using Redis locks - to ensure exclusive processing (one consumer per stream). + Dynamically discovers all audio:stream:* streams and uses Redis consumer groups + for fan-out processing (multiple worker types can process the same stream). Args: provider_name: Provider name (e.g., "deepgram", "parakeet") @@ -47,9 +45,8 @@ def __init__(self, provider_name: str, redis_client: redis.Redis, buffer_chunks: self.running = False - # Dynamic stream discovery with exclusive locks + # Dynamic stream discovery - consumer groups handle fan-out self.active_streams = {} # {stream_name: True} - self.stream_locks = {} # {stream_name: Lock object} # Buffering: accumulate chunks per session self.session_buffers = {} # {session_id: {"chunks": [], "chunk_ids": [], "sample_rate": int}} @@ -73,59 +70,6 @@ async def discover_streams(self) -> list[str]: return streams - async def try_claim_stream(self, stream_name: str) -> bool: - """ - Try to claim exclusive ownership of a stream using Redis lock. 
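The removed locking made each stream exclusive to a single consumer; consumer groups give the same per-group, deliver-once semantics without manual lease renewal, and different groups (e.g. transcription vs. persistence workers) can each see every message. The core Redis calls, sketched with `redis.asyncio` using the stream/group conventions above:

```python
import redis.asyncio as redis
from redis import exceptions as redis_exceptions


async def consume(r: redis.Redis, stream: str, group: str, consumer: str):
    # Idempotent group creation, mirroring setup_consumer_group().
    try:
        await r.xgroup_create(stream, group, id="0", mkstream=True)
    except redis_exceptions.ResponseError as e:
        if "BUSYGROUP" not in str(e):
            raise

    while True:
        # Every group receives every message; within a group, each message
        # goes to exactly one consumer - no explicit lock needed.
        batches = await r.xreadgroup(group, consumer, {stream: ">"}, count=10, block=1000)
        for _, messages in batches or []:
            for message_id, fields in messages:
                ...  # transcribe / persist the chunk
                await r.xack(stream, group, message_id)
```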
- - Args: - stream_name: Stream to claim - - Returns: - True if lock acquired, False otherwise - """ - lock_key = f"consumer:lock:{stream_name}" - - # Create lock with 30 second timeout (will be renewed) - lock = Lock( - self.redis_client, - lock_key, - timeout=30, - blocking=False # Non-blocking - ) - - acquired = await lock.acquire(blocking=False) - - if acquired: - self.stream_locks[stream_name] = lock - logger.info(f"🔒 Claimed stream: {stream_name}") - return True - else: - logger.debug(f"⏭️ Stream already claimed by another consumer: {stream_name}") - return False - - async def release_stream(self, stream_name: str): - """Release lock on a stream.""" - if stream_name in self.stream_locks: - try: - await self.stream_locks[stream_name].release() - logger.info(f"🔓 Released stream: {stream_name}") - except Exception as e: - logger.warning(f"Failed to release lock for {stream_name}: {e}") - finally: - del self.stream_locks[stream_name] - - async def renew_stream_locks(self): - """Renew locks on all claimed streams.""" - for stream_name, lock in list(self.stream_locks.items()): - try: - await lock.reacquire() - except Exception as e: - logger.warning(f"Failed to renew lock for {stream_name}: {e}") - # Lock expired, remove from our list - del self.stream_locks[stream_name] - if stream_name in self.active_streams: - del self.active_streams[stream_name] - async def setup_consumer_group(self, stream_name: str): """Create consumer group if it doesn't exist.""" # Create consumer group (ignore error if already exists) @@ -257,14 +201,12 @@ async def transcribe_audio(self, audio_data: bytes, sample_rate: int) -> dict: pass async def start_consuming(self): - """Discover and consume from multiple streams with exclusive locking.""" + """Discover and consume from multiple streams using Redis consumer groups.""" self.running = True - logger.info(f"➡️ Starting dynamic stream consumer: {self.consumer_name}") + logger.info(f"➡️ Starting dynamic stream consumer: {self.consumer_name} (group: {self.group_name})") last_discovery = 0 - last_lock_renewal = 0 discovery_interval = 10 # Discover new streams every 10 seconds - lock_renewal_interval = 15 # Renew locks every 15 seconds while self.running: try: @@ -277,20 +219,13 @@ async def start_consuming(self): for stream_name in discovered: if stream_name not in self.active_streams: - # Try to claim this stream - if await self.try_claim_stream(stream_name): - # Setup consumer group for this stream - await self.setup_consumer_group(stream_name) - self.active_streams[stream_name] = True - logger.info(f"✅ Now consuming from {stream_name}") + # Setup consumer group for this stream (no manual lock needed) + await self.setup_consumer_group(stream_name) + self.active_streams[stream_name] = True + logger.info(f"✅ Now consuming from {stream_name} (group: {self.group_name})") last_discovery = current_time - # Periodically renew locks - if current_time - last_lock_renewal > lock_renewal_interval: - await self.renew_stream_locks() - last_lock_renewal = current_time - # Read from all active streams if not self.active_streams: # No streams claimed yet, wait and retry @@ -326,14 +261,6 @@ async def start_consuming(self): if stream_name in error_msg: logger.warning(f"➡️ [{self.consumer_name}] Stream {stream_name} was deleted, removing from active streams") - # Release the lock - lock_key = f"consumer:lock:{stream_name}" - try: - await self.redis_client.delete(lock_key) - logger.info(f"🔓 Released lock for deleted stream: {stream_name}") - except: - pass - # Remove from active 
streams del self.active_streams[stream_name] logger.info(f"➡️ [{self.consumer_name}] Removed {stream_name}, {len(self.active_streams)} streams remaining") @@ -419,9 +346,6 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st # Clean up session buffer del self.session_buffers[session_id] - # Release the consumer lock for this stream - await self.release_stream(stream_name) - # ACK the END message await self.redis_client.xack(stream_name, self.group_name, message_id) return diff --git a/backends/advanced/src/advanced_omi_backend/services/audio_stream/producer.py b/backends/advanced/src/advanced_omi_backend/services/audio_stream/producer.py index 66b0acf7..1fa06011 100644 --- a/backends/advanced/src/advanced_omi_backend/services/audio_stream/producer.py +++ b/backends/advanced/src/advanced_omi_backend/services/audio_stream/producer.py @@ -41,32 +41,57 @@ async def init_session( session_id: str, user_id: str, client_id: str, + user_email: str = "", + connection_id: str = "", mode: str = "streaming", provider: str = "deepgram" ): """ - Initialize session tracking metadata. + Initialize session tracking metadata in Redis. + + This is the SINGLE SOURCE OF TRUTH for session state. + All session metadata is stored here instead of in-memory ClientState. Args: - session_id: Session identifier - user_id: User identifier - client_id: Client identifier + session_id: Unique session identifier + user_id: User identifier (MongoDB ObjectId) + client_id: Client identifier (objectid_suffix-device_name) + user_email: User email for debugging/tracking + connection_id: WebSocket connection identifier mode: Processing mode (streaming/batch) - provider: Transcription provider ("deepgram", "mistral", etc.) + provider: Transcription provider from config.yml """ # Client-specific stream naming (one stream per client for isolation) stream_name = f"audio:stream:{client_id}" session_key = f"audio:session:{session_id}" await self.redis_client.hset(session_key, mapping={ + # User & Client tracking "user_id": user_id, + "user_email": user_email, "client_id": client_id, + "connection_id": connection_id, + + # Stream configuration "stream_name": stream_name, "provider": provider, "mode": mode, + + # Timestamps "started_at": str(time.time()), - "chunks_published": "0", "last_chunk_at": str(time.time()), + + # Counters + "chunks_published": "0", + + # Job tracking (populated by queue_controller when jobs start) + "speech_detection_job_id": "", + "audio_persistence_job_id": "", + + # Connection state + "websocket_connected": "true", + + # Session status "status": "active" }) @@ -134,6 +159,63 @@ async def send_session_end_signal(self, session_id: str): ) logger.info(f"📡 Sent end-of-session signal for {session_id} to {stream_name}") + async def get_session(self, session_id: str) -> dict: + """ + Get session metadata from Redis. + + Args: + session_id: Session identifier + + Returns: + Dictionary with session metadata, empty dict if not found + """ + session_key = f"audio:session:{session_id}" + session_data = await self.redis_client.hgetall(session_key) + + # Convert bytes to strings for easier handling + return {k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v + for k, v in session_data.items()} if session_data else {} + + async def update_session_job_ids( + self, + session_id: str, + speech_detection_job_id: str = None, + audio_persistence_job_id: str = None + ): + """ + Update job IDs in session metadata. 
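With the session hash as the single source of truth, any worker can reconstruct session state straight from Redis instead of consulting in-memory `ClientState`. A reader-side sketch using the key layout defined in `init_session`:

```python
import redis.asyncio as redis


async def describe_session(r: redis.Redis, session_id: str) -> None:
    raw = await r.hgetall(f"audio:session:{session_id}")
    session = {k.decode(): v.decode() for k, v in raw.items()}

    print(f"user:      {session.get('user_id')} ({session.get('user_email')})")
    print(f"stream:    {session.get('stream_name')} via {session.get('provider')}")
    print(f"chunks:    {session.get('chunks_published')}")
    print(f"connected: {session.get('websocket_connected')}")
    print(f"jobs:      speech={session.get('speech_detection_job_id') or '-'} "
          f"persist={session.get('audio_persistence_job_id') or '-'}")
```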
+ + Args: + session_id: Session identifier + speech_detection_job_id: Speech detection job ID (optional) + audio_persistence_job_id: Audio persistence job ID (optional) + """ + session_key = f"audio:session:{session_id}" + updates = {} + + if speech_detection_job_id: + updates["speech_detection_job_id"] = speech_detection_job_id + if audio_persistence_job_id: + updates["audio_persistence_job_id"] = audio_persistence_job_id + + if updates: + await self.redis_client.hset(session_key, mapping=updates) + logger.debug(f"📊 Updated job IDs for session {session_id}: {updates}") + + async def mark_websocket_disconnected(self, session_id: str): + """ + Mark session's websocket as disconnected. + + Args: + session_id: Session identifier + """ + session_key = f"audio:session:{session_id}" + await self.redis_client.hset(session_key, mapping={ + "websocket_connected": "false", + "disconnected_at": str(time.time()) + }) + logger.info(f"🔌 Marked websocket disconnected for session {session_id}") + async def finalize_session(self, session_id: str): """ Mark session as finalizing and clean up buffer. diff --git a/backends/advanced/src/advanced_omi_backend/services/plugin_service.py b/backends/advanced/src/advanced_omi_backend/services/plugin_service.py new file mode 100644 index 00000000..0dc693d6 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/services/plugin_service.py @@ -0,0 +1,175 @@ +"""Plugin service for accessing the global plugin router. + +This module provides singleton access to the plugin router, allowing +worker jobs to trigger plugins without accessing FastAPI app state directly. +""" + +import logging +import os +import re +from typing import Optional, Any +from pathlib import Path +import yaml + +from advanced_omi_backend.plugins import PluginRouter + +logger = logging.getLogger(__name__) + +# Global plugin router instance +_plugin_router: Optional[PluginRouter] = None + + +def expand_env_vars(value: Any) -> Any: + """ + Recursively expand environment variables in configuration values. + + Supports ${ENV_VAR} syntax. If the environment variable is not set, + the original placeholder is kept. + + Args: + value: Configuration value (can be str, dict, list, or other) + + Returns: + Value with environment variables expanded + + Examples: + >>> os.environ['MY_TOKEN'] = 'secret123' + >>> expand_env_vars('token: ${MY_TOKEN}') + 'token: secret123' + >>> expand_env_vars({'token': '${MY_TOKEN}'}) + {'token': 'secret123'} + """ + if isinstance(value, str): + # Pattern: ${ENV_VAR} or ${ENV_VAR:-default} + def replacer(match): + var_expr = match.group(1) + # Support default values: ${VAR:-default} + if ':-' in var_expr: + var_name, default = var_expr.split(':-', 1) + return os.environ.get(var_name.strip(), default.strip()) + else: + var_name = var_expr.strip() + env_value = os.environ.get(var_name) + if env_value is None: + logger.warning( + f"Environment variable '{var_name}' not found, " + f"keeping placeholder: ${{{var_name}}}" + ) + return match.group(0) # Keep original placeholder + return env_value + + return re.sub(r'\$\{([^}]+)\}', replacer, value) + + elif isinstance(value, dict): + return {k: expand_env_vars(v) for k, v in value.items()} + + elif isinstance(value, list): + return [expand_env_vars(item) for item in value] + + else: + return value + + +def get_plugin_router() -> Optional[PluginRouter]: + """Get the global plugin router instance. 
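`expand_env_vars` is what lets `plugins.yml` stay secret-free in version control: secrets live in the environment and are substituted at load time, with `${VAR:-default}` covering optional values. A usage sketch of the function above (the variable names and config keys here are illustrative):

```python
import os

from advanced_omi_backend.services.plugin_service import expand_env_vars

os.environ["HA_TOKEN"] = "secret123"  # illustrative secret

config = {
    "homeassistant": {
        "enabled": True,  # non-string values pass through untouched
        "token": "${HA_TOKEN}",                               # substituted from env
        "url": "${HA_URL:-http://homeassistant.local:8123}",  # default applies
    }
}

expanded = expand_env_vars(config)
assert expanded["homeassistant"]["token"] == "secret123"
assert expanded["homeassistant"]["url"] == "http://homeassistant.local:8123"
```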
+ + Returns: + Plugin router instance if initialized, None otherwise + """ + global _plugin_router + return _plugin_router + + +def set_plugin_router(router: PluginRouter) -> None: + """Set the global plugin router instance. + + This should be called during app initialization in app_factory.py. + + Args: + router: Initialized plugin router instance + """ + global _plugin_router + _plugin_router = router + logger.info("Plugin router registered with plugin service") + + +def init_plugin_router() -> Optional[PluginRouter]: + """Initialize the plugin router from configuration. + + This is called during app startup to create and register the plugin router. + + Returns: + Initialized plugin router, or None if no plugins configured + """ + global _plugin_router + + if _plugin_router is not None: + logger.warning("Plugin router already initialized") + return _plugin_router + + try: + _plugin_router = PluginRouter() + + # Load plugin configuration + plugins_yml = Path("/app/plugins.yml") + logger.info(f"🔍 Looking for plugins config at: {plugins_yml}") + logger.info(f"🔍 File exists: {plugins_yml.exists()}") + + if plugins_yml.exists(): + with open(plugins_yml, 'r') as f: + plugins_config = yaml.safe_load(f) + # Expand environment variables in configuration + plugins_config = expand_env_vars(plugins_config) + plugins_data = plugins_config.get('plugins', {}) + + logger.info(f"🔍 Loaded plugins config with {len(plugins_data)} plugin(s): {list(plugins_data.keys())}") + + # Initialize each enabled plugin + for plugin_id, plugin_config in plugins_data.items(): + logger.info(f"🔍 Processing plugin '{plugin_id}', enabled={plugin_config.get('enabled', False)}") + if not plugin_config.get('enabled', False): + continue + + try: + if plugin_id == 'homeassistant': + from advanced_omi_backend.plugins.homeassistant import HomeAssistantPlugin + plugin = HomeAssistantPlugin(plugin_config) + # Note: async initialization happens in app_factory lifespan + _plugin_router.register_plugin(plugin_id, plugin) + logger.info(f"✅ Plugin '{plugin_id}' registered") + elif plugin_id == 'test_event': + from advanced_omi_backend.plugins.test_event import TestEventPlugin + plugin = TestEventPlugin(plugin_config) + # Note: async initialization happens in app_factory lifespan + _plugin_router.register_plugin(plugin_id, plugin) + logger.info(f"✅ Plugin '{plugin_id}' registered") + else: + logger.warning(f"Unknown plugin: {plugin_id}") + + except Exception as e: + logger.error(f"Failed to register plugin '{plugin_id}': {e}", exc_info=True) + + logger.info(f"Plugins registered: {len(_plugin_router.plugins)} total") + else: + logger.info("No plugins.yml found, plugins disabled") + + return _plugin_router + + except Exception as e: + logger.error(f"Failed to initialize plugin router: {e}", exc_info=True) + _plugin_router = None + return None + + +async def cleanup_plugin_router() -> None: + """Clean up the plugin router and all registered plugins.""" + global _plugin_router + + if _plugin_router: + try: + await _plugin_router.cleanup_all() + logger.info("Plugin router cleanup complete") + except Exception as e: + logger.error(f"Error during plugin router cleanup: {e}") + finally: + _plugin_router = None diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py b/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py index 2e20171b..f481ac3f 100644 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py +++ 
b/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py @@ -10,6 +10,7 @@ import json import logging from typing import Optional +from urllib.parse import urlencode import httpx import websockets @@ -167,26 +168,65 @@ def __init__(self): def name(self) -> str: return self._name + async def transcribe(self, audio_data: bytes, sample_rate: int, **kwargs) -> dict: + """Not used for streaming providers - use start_stream/process_audio_chunk/end_stream instead.""" + raise NotImplementedError("Streaming providers do not support batch transcription") + async def start_stream(self, client_id: str, sample_rate: int = 16000, diarize: bool = False): - url = self.model.model_url + base_url = self.model.model_url ops = self.model.operations or {} + + # Build WebSocket URL with query parameters (for Deepgram streaming) + query_params = ops.get("query", {}) + query_dict = dict(query_params) if query_params else {} + + # Override sample_rate if provided + if sample_rate and "sample_rate" in query_dict: + query_dict["sample_rate"] = sample_rate + if diarize and "diarize" in query_dict: + query_dict["diarize"] = "true" + + # Normalize boolean values to lowercase strings (Deepgram expects "true"/"false", not "True"/"False") + normalized_query = {} + for k, v in query_dict.items(): + if isinstance(v, bool): + normalized_query[k] = "true" if v else "false" + else: + normalized_query[k] = v + + # Build query string with proper URL encoding (NO token in query) + query_str = urlencode(normalized_query) + url = f"{base_url}?{query_str}" if query_str else base_url + + # Debug: Log the URL + logger.info(f"🔗 Connecting to Deepgram WebSocket: {url}") + + # Connect to WebSocket with Authorization header (Deepgram requires this for server-side connections) + headers = {} + if self.model.api_key: + headers["Authorization"] = f"Token {self.model.api_key}" + + ws = await websockets.connect(url, additional_headers=headers) + + # Send start message if required by provider start_msg = (ops.get("start", {}) or {}).get("message", {}) - # Inject session_id if placeholder present - start_msg = json.loads(json.dumps(start_msg)) # deep copy - start_msg.setdefault("session_id", client_id) - # Apply sample rate and diarization if present - if "config" in start_msg and isinstance(start_msg["config"], dict): - start_msg["config"].setdefault("sample_rate", sample_rate) - if diarize: - start_msg["config"]["diarize"] = True - - ws = await websockets.connect(url, open_timeout=10) - await ws.send(json.dumps(start_msg)) - # Wait for confirmation; non-fatal if not provided - try: - await asyncio.wait_for(ws.recv(), timeout=2.0) - except Exception: - pass + if start_msg: + # Inject session_id if placeholder present + start_msg = json.loads(json.dumps(start_msg)) # deep copy + start_msg.setdefault("session_id", client_id) + # Apply sample rate and diarization if present + if "config" in start_msg and isinstance(start_msg["config"], dict): + start_msg["config"].setdefault("sample_rate", sample_rate) + if diarize: + start_msg["config"]["diarize"] = True + await ws.send(json.dumps(start_msg)) + + # Wait for confirmation; non-fatal if not provided + try: + await asyncio.wait_for(ws.recv(), timeout=2.0) + except Exception: + pass + self._streams[client_id] = {"ws": ws, "sample_rate": sample_rate, "final": None, "interim": []} async def process_audio_chunk(self, client_id: str, audio_chunk: bytes) -> dict | None: @@ -194,26 +234,67 @@ async def process_audio_chunk(self, client_id: str, audio_chunk: bytes) -> dict return None 
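The key change for Deepgram streaming: options travel as URL query parameters (booleans lowercased) while the API key moves into an `Authorization` header, which per the comments above Deepgram requires for server-side connections. A condensed sketch of the connection step — the model and option values are illustrative, since the real ones come from `config.yml` operations:

```python
from urllib.parse import urlencode

import websockets


async def connect_deepgram(api_key: str, sample_rate: int = 16000):
    # Options ride in the query string; boolean values must be lowercase strings.
    query = {
        "model": "nova-2",  # illustrative; the real model comes from config.yml
        "encoding": "linear16",
        "sample_rate": sample_rate,
        "interim_results": "true",
        "diarize": "true",
    }
    url = f"wss://api.deepgram.com/v1/listen?{urlencode(query)}"
    # The API key travels in the Authorization header, never the query string.
    return await websockets.connect(
        url, additional_headers={"Authorization": f"Token {api_key}"}
    )
```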
@@ -194,26 +234,67 @@ async def process_audio_chunk(self, client_id: str, audio_chunk: bytes) -> dict
             return None
         ws = self._streams[client_id]["ws"]
         ops = self.model.operations or {}
+
+        # Send chunk header if required (for providers like Parakeet)
         chunk_hdr = (ops.get("chunk_header", {}) or {}).get("message", {})
-        hdr = json.loads(json.dumps(chunk_hdr))
-        hdr.setdefault("type", "audio_chunk")
-        hdr.setdefault("session_id", client_id)
-        hdr.setdefault("rate", self._streams[client_id]["sample_rate"])
-        await ws.send(json.dumps(hdr))
+        if chunk_hdr:
+            hdr = json.loads(json.dumps(chunk_hdr))
+            hdr.setdefault("type", "audio_chunk")
+            hdr.setdefault("session_id", client_id)
+            hdr.setdefault("rate", self._streams[client_id]["sample_rate"])
+            await ws.send(json.dumps(hdr))
+
+        # Send audio chunk (raw bytes for Deepgram, or after header for others)
         await ws.send(audio_chunk)
 
-        # Non-blocking read for interim results
+        # Non-blocking read for results
         expect = (ops.get("expect", {}) or {})
+        extract = expect.get("extract", {})
         interim_type = expect.get("interim_type")
+        final_type = expect.get("final_type")
+
         try:
-            while True:
-                msg = await asyncio.wait_for(ws.recv(), timeout=0.01)
-                data = json.loads(msg)
-                if interim_type and data.get("type") == interim_type:
-                    self._streams[client_id]["interim"].append(data)
+            # Try to read a message (non-blocking)
+            msg = await asyncio.wait_for(ws.recv(), timeout=0.05)
+            data = json.loads(msg)
+
+            # A message counts as final only when the provider flags it as such
+            # (Deepgram sets is_final on the result payload)
+            is_final = False
+            if (final_type and data.get("type") == final_type) or (
+                interim_type and data.get("type") == interim_type
+            ):
+                is_final = data.get("is_final", False)
+
+            # Extract result data
+            text = _dotted_get(data, extract.get("text")) if extract.get("text") else data.get("text", "")
+            words = _dotted_get(data, extract.get("words")) if extract.get("words") else data.get("words", [])
+            segments = _dotted_get(data, extract.get("segments")) if extract.get("segments") else data.get("segments", [])
+
+            # Calculate confidence if available
+            confidence = data.get("confidence", 0.0)
+            if not confidence and words and isinstance(words, list):
+                # Calculate average word confidence
+                confidences = [w.get("confidence", 0.0) for w in words if isinstance(w, dict) and "confidence" in w]
+                if confidences:
+                    confidence = sum(confidences) / len(confidences)
+
+            # Return result with is_final flag
+            # Consumer decides what to do with interim vs final
+            return {
+                "text": text,
+                "words": words,
+                "segments": segments,
+                "is_final": is_final,
+                "confidence": confidence
+            }
+
         except asyncio.TimeoutError:
-            pass
-        return None
+            # No message available yet
+            return None
+        except Exception as e:
+            logger.error(f"Error processing audio chunk result for {client_id}: {e}")
+            return None
 
     async def end_stream(self, client_id: str) -> dict:
         if client_id not in self._streams:
diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/base.py b/backends/advanced/src/advanced_omi_backend/services/transcription/base.py
index 13893a68..7d0f2306 100644
--- a/backends/advanced/src/advanced_omi_backend/services/transcription/base.py
+++ b/backends/advanced/src/advanced_omi_backend/services/transcription/base.py
@@ -36,7 +36,6 @@ class TranscriptionProvider(Enum):
     """Available transcription providers for audio stream routing."""
     DEEPGRAM = "deepgram"
     PARAKEET = "parakeet"
-    MISTRAL = "mistral"
 
 class BaseTranscriptionProvider(abc.ABC):
diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/deepgram.py
b/backends/advanced/src/advanced_omi_backend/services/transcription/deepgram.py deleted file mode 100644 index ef54a3d9..00000000 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/deepgram.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -Deepgram transcription consumer for Redis Streams architecture. - -Uses the registry-driven transcription provider for Deepgram batch transcription. -""" - -import logging - -from advanced_omi_backend.services.audio_stream.consumer import BaseAudioStreamConsumer -from advanced_omi_backend.services.transcription import get_transcription_provider - -logger = logging.getLogger(__name__) - - -class DeepgramStreamConsumer: - """ - Deepgram consumer for Redis Streams architecture. - - Reads from: specified stream (client-specific or provider-specific) - Writes to: transcription:results:{session_id} - - Uses RegistryBatchTranscriptionProvider configured via config.yml for - Deepgram transcription. This ensures consistent behavior with batch - transcription jobs. - """ - - def __init__(self, redis_client, buffer_chunks: int = 30): - """ - Initialize Deepgram consumer. - - Dynamically discovers all audio:stream:* streams and claims them using Redis locks. - Uses config.yml stt-deepgram configuration for transcription. - - Args: - redis_client: Connected Redis client - buffer_chunks: Number of chunks to buffer before transcribing (default: 30 = ~7.5s) - """ - - # Get registry-driven transcription provider - self.provider = get_transcription_provider(mode="batch") - if not self.provider: - raise RuntimeError( - "Failed to load transcription provider. Ensure config.yml has a default 'stt' model configured." - ) - - # Create a concrete subclass that implements transcribe_audio - class _ConcreteConsumer(BaseAudioStreamConsumer): - def __init__(inner_self, provider_name: str, redis_client, buffer_chunks: int): - super().__init__(provider_name, redis_client, buffer_chunks) - inner_self._transcription_provider = self.provider - - async def transcribe_audio(inner_self, audio_data: bytes, sample_rate: int) -> dict: - """Transcribe using registry-driven transcription provider.""" - try: - result = await inner_self._transcription_provider.transcribe( - audio_data=audio_data, - sample_rate=sample_rate, - diarize=True - ) - - # Calculate confidence - confidence = 0.0 - if result.get("words"): - confidences = [ - w.get("confidence", 0) - for w in result["words"] - if "confidence" in w - ] - if confidences: - confidence = sum(confidences) / len(confidences) - - return { - "text": result.get("text", ""), - "words": result.get("words", []), - "segments": result.get("segments", []), - "confidence": confidence - } - - except Exception as e: - logger.error(f"Deepgram transcription failed: {e}", exc_info=True) - raise - - # Instantiate the concrete consumer - self._consumer = _ConcreteConsumer("deepgram", redis_client, buffer_chunks) - - async def start_consuming(self): - """Delegate to base consumer.""" - return await self._consumer.start_consuming() - - async def stop(self): - """Delegate to base consumer.""" - return await self._consumer.stop() diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/parakeet_stream_consumer.py b/backends/advanced/src/advanced_omi_backend/services/transcription/parakeet_stream_consumer.py deleted file mode 100644 index f629cefd..00000000 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/parakeet_stream_consumer.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -Parakeet stream consumer for Redis Streams 
architecture. - -Reads from: audio:stream:* streams -Writes to: transcription:results:{session_id} -""" - -import logging - -from advanced_omi_backend.services.audio_stream.consumer import BaseAudioStreamConsumer -from advanced_omi_backend.services.transcription import get_transcription_provider - -logger = logging.getLogger(__name__) - - -class ParakeetStreamConsumer: - """ - Parakeet consumer for Redis Streams architecture. - - Reads from: specified stream (client-specific or provider-specific) - Writes to: transcription:results:{session_id} - - This inherits from BaseAudioStreamConsumer and implements transcribe_audio(). - """ - - def __init__(self, redis_client, buffer_chunks: int = 30): - """ - Initialize Parakeet consumer. - - Dynamically discovers all audio:stream:* streams and claims them using Redis locks. - Uses config.yml stt-parakeet-batch configuration for transcription. - - Args: - redis_client: Connected Redis client - buffer_chunks: Number of chunks to buffer before transcribing (default: 30 = ~7.5s) - """ - # Get registry-driven transcription provider - self.provider = get_transcription_provider(mode="batch") - if not self.provider: - raise RuntimeError( - "Failed to load transcription provider. Ensure config.yml has a default 'stt' model configured." - ) - - # Create a concrete subclass that implements transcribe_audio - class _ConcreteConsumer(BaseAudioStreamConsumer): - def __init__(inner_self, provider_name: str, redis_client, buffer_chunks: int): - super().__init__(provider_name, redis_client, buffer_chunks) - inner_self._parakeet_provider = self.provider - - async def transcribe_audio(inner_self, audio_data: bytes, sample_rate: int) -> dict: - """Transcribe using ParakeetProvider.""" - try: - result = await inner_self._parakeet_provider.transcribe( - audio_data=audio_data, - sample_rate=sample_rate - ) - - # Calculate confidence (Parakeet may not provide confidence, default to 0.9) - confidence = 0.9 - if result.get("words"): - confidences = [ - w.get("confidence", 0.9) - for w in result["words"] - if "confidence" in w - ] - if confidences: - confidence = sum(confidences) / len(confidences) - - return { - "text": result.get("text", ""), - "words": result.get("words", []), - "segments": result.get("segments", []), - "confidence": confidence - } - - except Exception as e: - logger.error(f"Parakeet transcription failed: {e}", exc_info=True) - raise - - # Instantiate the concrete consumer - self._consumer = _ConcreteConsumer("parakeet", redis_client, buffer_chunks) - - async def start_consuming(self): - """Delegate to base consumer.""" - return await self._consumer.start_consuming() - - async def stop(self): - """Delegate to base consumer.""" - return await self._consumer.stop() - diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/streaming_consumer.py b/backends/advanced/src/advanced_omi_backend/services/transcription/streaming_consumer.py new file mode 100644 index 00000000..579bc195 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/streaming_consumer.py @@ -0,0 +1,506 @@ +""" +Generic streaming transcription consumer for real-time audio processing. + +Uses registry-driven transcription provider from config.yml (supports any streaming provider). 
+
+Reads from: audio:stream:* streams
+Publishes interim to: Redis Pub/Sub channel transcription:interim:{session_id}
+Writes final to: transcription:results:{session_id} Redis Stream
+Triggers plugins: streaming_transcript level (final results only)
+"""
+
+import asyncio
+import json
+import logging
+import os
+import time
+from typing import Dict, Optional
+
+import redis.asyncio as redis
+from redis import exceptions as redis_exceptions
+
+from advanced_omi_backend.plugins.router import PluginRouter
+from advanced_omi_backend.services.transcription import get_transcription_provider
+from advanced_omi_backend.client_manager import get_client_owner_async
+
+logger = logging.getLogger(__name__)
+
+
+class StreamingTranscriptionConsumer:
+    """
+    Generic streaming transcription consumer using registry-driven providers.
+
+    - Discovers audio:stream:* streams dynamically
+    - Uses Redis consumer groups for fan-out (allows batch workers to process same stream)
+    - Starts WebSocket connections using configured provider (from config.yml)
+    - Sends audio immediately (no buffering)
+    - Publishes interim results to Redis Pub/Sub for client display
+    - Publishes final results to Redis Streams for storage
+    - Triggers plugins only on final results
+
+    Supported providers (via config.yml): any streaming STT service with a WebSocket API
+    """
+
+    def __init__(self, redis_client: redis.Redis, plugin_router: Optional[PluginRouter] = None):
+        """
+        Initialize streaming transcription consumer.
+
+        Args:
+            redis_client: Connected Redis client
+            plugin_router: Plugin router for triggering plugins on final results
+        """
+        self.redis_client = redis_client
+        self.plugin_router = plugin_router
+
+        # Get streaming transcription provider from registry
+        self.provider = get_transcription_provider(mode="streaming")
+        if not self.provider:
+            raise RuntimeError(
+                "Failed to load streaming transcription provider. "
+                "Ensure config.yml has a default 'stt_stream' model configured."
+            )
+
+        # Stream configuration
+        self.stream_pattern = "audio:stream:*"
+        self.group_name = "streaming-transcription"
+        self.consumer_name = f"streaming-worker-{os.getpid()}"
+
+        self.running = False
+
+        # Active stream tracking - consumer groups handle fan-out
+        self.active_streams: Dict[str, Dict] = {}  # {stream_name: {"session_id": ...}}
+
+        # Session tracking for WebSocket connections
+        self.active_sessions: Dict[str, Dict] = {}  # {session_id: {"last_activity": timestamp}}
+
+    async def discover_streams(self) -> list[str]:
+        """
+        Discover all audio streams matching the pattern.
+
+        Returns:
+            List of stream names
+        """
+        streams = []
+        cursor = b"0"
+
+        while cursor:
+            cursor, keys = await self.redis_client.scan(
+                cursor, match=self.stream_pattern, count=100
+            )
+            if keys:
+                streams.extend([k.decode() if isinstance(k, bytes) else k for k in keys])
+
+        return streams
+
+    async def setup_consumer_group(self, stream_name: str):
+        """Create consumer group if it doesn't exist."""
+        try:
+            await self.redis_client.xgroup_create(
+                stream_name,
+                self.group_name,
+                "0",
+                mkstream=True
+            )
+            logger.debug(f"➡️ Created consumer group {self.group_name} for {stream_name}")
+        except redis_exceptions.ResponseError as e:
+            if "BUSYGROUP" not in str(e):
+                raise
+            logger.debug(f"➡️ Consumer group {self.group_name} already exists for {stream_name}")
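+
+    # Fan-out sketch (stream and consumer names below are illustrative): each consumer
+    # group keeps its own last-delivered ID, so this streaming group and any batch
+    # group can both read the full stream without stealing each other's messages:
+    #   XGROUP CREATE audio:stream:user01-phone streaming-transcription 0 MKSTREAM
+    #   XREADGROUP GROUP streaming-transcription streaming-worker-123 COUNT 10 BLOCK 1000 STREAMS audio:stream:user01-phone >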
+    async def start_session_stream(self, session_id: str, sample_rate: int = 16000):
+        """
+        Start a WebSocket connection to the configured provider for a session.
+
+        Args:
+            session_id: Session ID (client_id from audio stream)
+            sample_rate: Audio sample rate in Hz
+        """
+        try:
+            await self.provider.start_stream(
+                client_id=session_id,
+                sample_rate=sample_rate,
+                diarize=False  # Diarization is disabled for the streaming path in this pipeline
+            )
+
+            self.active_sessions[session_id] = {
+                "last_activity": time.time(),
+                "sample_rate": sample_rate
+            }
+
+            logger.info(f"🎙️ Started streaming WebSocket for session: {session_id}")
+
+        except Exception as e:
+            logger.error(f"Failed to start streaming transcription for {session_id}: {e}", exc_info=True)
+
+            # Set error flag in Redis so speech detection can detect failure early
+            session_key = f"audio:session:{session_id}"
+            try:
+                await self.redis_client.hset(session_key, "transcription_error", str(e))
+                logger.info(f"Set transcription error flag for {session_id}")
+            except Exception as redis_error:
+                logger.warning(f"Failed to set error flag in Redis: {redis_error}")
+
+            raise
+
+    async def end_session_stream(self, session_id: str):
+        """
+        End the provider WebSocket connection for a session.
+
+        Args:
+            session_id: Session ID
+        """
+        try:
+            # Get final result from the provider
+            final_result = await self.provider.end_stream(client_id=session_id)
+
+            # If there's a final result, publish it
+            if final_result and final_result.get("text"):
+                await self.publish_to_client(session_id, final_result, is_final=True)
+                await self.store_final_result(session_id, final_result)
+
+                # Trigger plugins on final result
+                if self.plugin_router:
+                    await self.trigger_plugins(session_id, final_result)
+
+            self.active_sessions.pop(session_id, None)
+            logger.info(f"🛑 Ended streaming WebSocket for session: {session_id}")
+
+        except Exception as e:
+            logger.error(f"Error ending stream for {session_id}: {e}", exc_info=True)
+
+    async def process_audio_chunk(self, session_id: str, audio_chunk: bytes, chunk_id: str):
+        """
+        Process a single audio chunk through the provider's WebSocket.
+
+        Args:
+            session_id: Session ID
+            audio_chunk: Raw audio bytes
+            chunk_id: Chunk identifier from Redis stream
+        """
+        try:
+            # Send audio chunk to the provider WebSocket and get result
+            result = await self.provider.process_audio_chunk(
+                client_id=session_id,
+                audio_chunk=audio_chunk
+            )
+
+            # Update last activity
+            if session_id in self.active_sessions:
+                self.active_sessions[session_id]["last_activity"] = time.time()
+
+            # The provider returns None if no response yet, or a dict with results
+            if result:
+                is_final = result.get("is_final", False)
+                text = result.get("text", "")
+                word_count = len(result.get("words", []))
+
+                # Track transcript at each step
+                logger.info(
+                    f"🔤 TRANSCRIPT [STREAMING] session={session_id}, is_final={is_final}, "
+                    f"words={word_count}, text=\"{text}\""
+                )
+
+                # Always publish to clients (interim + final) for real-time display
+                await self.publish_to_client(session_id, result, is_final=is_final)
+
+                # If final result, also store and trigger plugins
+                if is_final:
+                    logger.info(
+                        f"🔤 TRANSCRIPT [STORE] session={session_id}, words={word_count}, text=\"{text}\""
+                    )
+                    await self.store_final_result(session_id, result, chunk_id=chunk_id)
+
+                    # Trigger plugins on final results only
+                    if self.plugin_router:
+                        await self.trigger_plugins(session_id, result)
+
+        except Exception as e:
+            logger.error(f"Error processing audio chunk for {session_id}: {e}", exc_info=True)
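+
+    # Consumption sketch (hypothetical consumer, not part of this module): a WebSocket
+    # bridge or dashboard can surface interim text by subscribing to the channel that
+    # publish_to_client() below writes to:
+    #   pubsub = redis_client.pubsub()
+    #   await pubsub.subscribe(f"transcription:interim:{session_id}")
+    #   async for message in pubsub.listen():
+    #       if message["type"] == "message":
+    #           payload = json.loads(message["data"])  # {"text", "is_final", "words", ...}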
+    async def publish_to_client(self, session_id: str, result: Dict, is_final: bool):
+        """
+        Publish interim or final results to Redis Pub/Sub for client consumption.
+
+        Args:
+            session_id: Session ID
+            result: Transcription result from the provider
+            is_final: Whether this is a final result
+        """
+        try:
+            channel = f"transcription:interim:{session_id}"
+
+            # Prepare message for clients
+            message = {
+                "text": result.get("text", ""),
+                "is_final": is_final,
+                "words": result.get("words", []),
+                "confidence": result.get("confidence", 0.0),
+                "timestamp": time.time()
+            }
+
+            # Publish to Redis Pub/Sub
+            await self.redis_client.publish(channel, json.dumps(message))
+
+            result_type = "FINAL" if is_final else "interim"
+            logger.debug(f"📢 Published {result_type} result to {channel}: {message['text'][:50]}...")
+
+        except Exception as e:
+            logger.error(f"Error publishing to client for {session_id}: {e}", exc_info=True)
+
+    async def store_final_result(self, session_id: str, result: Dict, chunk_id: Optional[str] = None):
+        """
+        Store final transcription result to Redis Stream.
+
+        Args:
+            session_id: Session ID
+            result: Final transcription result
+            chunk_id: Optional chunk identifier
+        """
+        try:
+            stream_name = f"transcription:results:{session_id}"
+
+            # Prepare result entry - MUST match aggregator's expected schema
+            # All keys and values must be bytes to match consumer.py format
+            entry = {
+                b"text": result.get("text", "").encode(),
+                b"chunk_id": (chunk_id or f"final_{int(time.time() * 1000)}").encode(),
+                b"provider": f"{self.provider.name}-stream".encode(),
+                b"confidence": str(result.get("confidence", 0.0)).encode(),
+                b"processing_time": b"0.0",  # Streaming has minimal processing time
+                b"timestamp": str(time.time()).encode(),
+            }
+
+            # Add optional JSON fields
+            words = result.get("words", [])
+            if words:
+                entry[b"words"] = json.dumps(words).encode()
+
+            segments = result.get("segments", [])
+            if segments:
+                entry[b"segments"] = json.dumps(segments).encode()
+
+            # Write to Redis Stream
+            await self.redis_client.xadd(stream_name, entry)
+
+            logger.info(f"💾 Stored final result to {stream_name}: {result.get('text', '')[:50]}...")
+
+        except Exception as e:
+            logger.error(f"Error storing final result for {session_id}: {e}", exc_info=True)
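+
+    # Downstream sketch (assumed group/consumer names for illustration): the aggregator
+    # that merges these results can read the stream written above with, e.g.:
+    #   XREADGROUP GROUP aggregation agg-worker-1 COUNT 100 BLOCK 1000 STREAMS transcription:results:{session_id} >
+    # The byte-keyed schema deliberately mirrors the batch consumer's format so both
+    # streaming and batch results aggregate through the same code path.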
+    async def _get_user_id_from_client_id(self, client_id: str) -> Optional[str]:
+        """
+        Look up user_id from client_id using ClientManager (async Redis lookup).
+
+        Args:
+            client_id: Client ID to search for
+
+        Returns:
+            user_id if found, None otherwise
+        """
+        user_id = await get_client_owner_async(client_id)
+
+        if user_id:
+            logger.debug(f"Found user_id {user_id} for client_id {client_id} via Redis")
+        else:
+            logger.warning(f"No user_id found for client_id {client_id} in Redis")
+
+        return user_id
+
+    async def trigger_plugins(self, session_id: str, result: Dict):
+        """
+        Trigger plugins at streaming_transcript access level (final results only).
+
+        Args:
+            session_id: Session ID (client_id from stream name)
+            result: Final transcription result
+        """
+        try:
+            # Find user_id by looking up session with matching client_id
+            # session_id here is actually the client_id extracted from stream name
+            user_id = await self._get_user_id_from_client_id(session_id)
+
+            if not user_id:
+                logger.warning(
+                    f"Could not find user_id for client_id {session_id}. "
+                    "Plugins will not be triggered."
+                )
+                return
+
+            plugin_data = {
+                'transcript': result.get("text", ""),
+                'session_id': session_id,
+                'words': result.get("words", []),
+                'segments': result.get("segments", []),
+                'confidence': result.get("confidence", 0.0),
+                'is_final': True
+            }
+
+            # Dispatch transcript.streaming event
+            logger.info(f"🎯 Dispatching transcript.streaming event for user {user_id}, transcript: {plugin_data['transcript'][:50]}...")
+
+            plugin_results = await self.plugin_router.dispatch_event(
+                event='transcript.streaming',
+                user_id=user_id,
+                data=plugin_data,
+                metadata={'client_id': session_id}
+            )
+
+            if plugin_results:
+                logger.info(f"✅ Plugins triggered successfully: {len(plugin_results)} results")
+            else:
+                logger.info("ℹ️ No plugins triggered (no matching conditions)")
+
+        except Exception as e:
+            logger.error(f"Error triggering plugins for {session_id}: {e}", exc_info=True)
+
+    async def process_stream(self, stream_name: str):
+        """
+        Process a single audio stream.
+
+        Args:
+            stream_name: Redis stream name (e.g., "audio:stream:user01-phone")
+        """
+        # Extract session_id from stream name (format: audio:stream:{session_id})
+        session_id = stream_name.replace("audio:stream:", "")
+
+        # Track this stream
+        self.active_streams[stream_name] = {
+            "session_id": session_id,
+            "started_at": time.time()
+        }
+
+        # Start WebSocket connection to the configured provider
+        await self.start_session_stream(session_id)
+
+        stream_ended = False
+
+        try:
+            while self.running and not stream_ended:
+                # Read messages from Redis stream using consumer group
+                try:
+                    messages = await self.redis_client.xreadgroup(
+                        self.group_name,      # "streaming-transcription"
+                        self.consumer_name,   # "streaming-worker-{pid}"
+                        {stream_name: ">"},   # Read only new messages
+                        count=10,
+                        block=1000            # Block for 1 second
+                    )
+
+                    if not messages:
+                        # No new messages - end processing if the session is no longer alive
+                        if session_id not in self.active_sessions:
+                            logger.info(f"Session {session_id} no longer active, ending stream processing")
+                            stream_ended = True
+                        continue
+
+                    for stream, stream_messages in messages:
+                        logger.debug(f"📥 Read {len(stream_messages)} messages from {stream_name}")
+                        for message_id, fields in stream_messages:
+                            msg_id = message_id.decode() if isinstance(message_id, bytes) else message_id
+
+                            # Check for end marker
+                            if fields.get(b'end_marker') or fields.get('end_marker'):
+                                logger.info(f"End marker received for {session_id}")
+                                stream_ended = True
+                                # ACK the end marker
+                                await self.redis_client.xack(stream_name, self.group_name, msg_id)
+                                break
+
+                            # Extract audio data (producer sends as 'audio_data', not 'audio_chunk')
+                            audio_chunk = fields.get(b'audio_data') or fields.get('audio_data')
+                            if audio_chunk:
+                                logger.debug(f"🎵 Processing audio chunk {msg_id} ({len(audio_chunk)} bytes)")
+                                # Process audio chunk through the provider WebSocket
+                                await self.process_audio_chunk(
+                                    session_id=session_id,
+                                    audio_chunk=audio_chunk,
+                                    chunk_id=msg_id
+                                )
+                            else:
+                                logger.warning(f"⚠️ Message {msg_id} has no audio_data field")
+
+                            # ACK the message after processing
+                            await self.redis_client.xack(stream_name, self.group_name, msg_id)
+
+                        if stream_ended:
+                            break
+
+                except redis_exceptions.ResponseError as e:
+                    if "NOGROUP" in str(e):
+                        # Stream has expired or been deleted - exit gracefully
+                        logger.info(f"Stream {stream_name} expired or deleted, ending processing")
+                        stream_ended = True
+                        break
+                    else:
+                        logger.error(f"Redis error reading from stream {stream_name}: {e}", exc_info=True)
+                        await asyncio.sleep(1)
+                except Exception as e:
+                    logger.error(f"Error reading from stream {stream_name}: {e}", exc_info=True)
+                    await asyncio.sleep(1)
+
+        finally:
+            # End WebSocket connection
+            await self.end_session_stream(session_id)
+
+            # Remove from active streams tracking
+            self.active_streams.pop(stream_name, None)
+            logger.debug(f"Removed {stream_name} from active streams tracking")
+
+    async def start_consuming(self):
+        """
+        Start consuming audio streams and processing them through the provider WebSocket.
+        Uses Redis consumer groups for fan-out (allows batch workers to process same stream).
+        """
+        self.running = True
+        logger.info(f"🚀 Streaming transcription consumer started (group: {self.group_name})")
+
+        try:
+            while self.running:
+                # Discover available streams
+                streams = await self.discover_streams()
+
+                if streams:
+                    logger.debug(f"🔍 Discovered {len(streams)} audio streams")
+                else:
+                    logger.debug("🔍 No audio streams found")
+
+                # Setup consumer groups and spawn processing tasks
+                for stream_name in streams:
+                    if stream_name in self.active_streams:
+                        continue  # Already processing
+
+                    # Setup consumer group (no manual lock needed)
+                    await self.setup_consumer_group(stream_name)
+
+                    # Track stream and spawn task to process it
+                    session_id = stream_name.replace("audio:stream:", "")
+                    self.active_streams[stream_name] = {"session_id": session_id}
+
+                    # Spawn task to process this stream
+                    asyncio.create_task(self.process_stream(stream_name))
+                    logger.info(f"✅ Now consuming from {stream_name} (group: {self.group_name})")
+
+                # Sleep before next discovery cycle
+                await asyncio.sleep(5)
+
+        except Exception as e:
+            logger.error(f"Fatal error in consumer main loop: {e}", exc_info=True)
+        finally:
+            await self.stop()
+
+    async def stop(self):
+        """Stop consuming and clean up resources."""
+        logger.info("🛑 Stopping streaming transcription consumer...")
+        self.running = False
+
+        # End all active sessions
+        session_ids = list(self.active_sessions.keys())
+        for session_id in session_ids:
+            try:
+                await self.end_session_stream(session_id)
+            except Exception as e:
+                logger.error(f"Error ending session {session_id}: {e}")
+
+        logger.info("✅ Streaming transcription consumer stopped")
diff --git a/backends/advanced/src/advanced_omi_backend/utils/audio_utils.py b/backends/advanced/src/advanced_omi_backend/utils/audio_utils.py
index 3a3b554d..4d3fa0ae 100644
--- a/backends/advanced/src/advanced_omi_backend/utils/audio_utils.py
+++ b/backends/advanced/src/advanced_omi_backend/utils/audio_utils.py
@@ -275,73 +275,6 @@ async def process_audio_chunk(
     client_state.update_audio_received(chunk)
 
-
-async def _process_audio_cropping_with_relative_timestamps(
-    original_path: str,
-    speech_segments: list[tuple[float, float]],
-    output_path: str,
-    audio_uuid: str,
-    _deprecated_chunk_repo=None,  # Deprecated - kept for backward compatibility
-) -> tuple[bool, list[dict]]:
-    """
-    Process audio cropping with speech segments already in relative format.
-
-    The segments are expected to be in relative format (seconds from audio start),
-    as provided by Deepgram transcription. No timestamp conversion is needed.
-
-    Note: Database updates are now handled by the caller (audio_jobs.py).
- - Returns: - Tuple of (success: bool, segment_mapping: list[dict]) - """ - try: - # Validate input segments - validated_segments = [] - for start_rel, end_rel in speech_segments: - # Validate input timestamps - if start_rel >= end_rel: - logger.warning( - f"⚠️ Invalid speech segment: start={start_rel} >= end={end_rel}, skipping" - ) - continue - - # Ensure timestamps are positive (sanity check) - if start_rel < 0: - logger.warning( - f"⚠️ Negative start timestamp: {start_rel}, clamping to 0.0" - ) - start_rel = 0.0 - if end_rel < 0: - logger.warning( - f"⚠️ Negative end timestamp: {end_rel}, skipping segment" - ) - continue - - validated_segments.append((start_rel, end_rel)) - - logger.info(f"🕐 Processing cropping for {audio_uuid}") - logger.info(f"🕐 Input segments (relative timestamps): {speech_segments}") - logger.info(f"🕐 Validated segments: {validated_segments}") - - # Validate that we have valid segments - if not validated_segments: - logger.warning( - f"No valid segments for cropping {audio_uuid}" - ) - return False, [] - - success, segment_mapping = await _crop_audio_with_ffmpeg(original_path, validated_segments, output_path) - if success: - cropped_filename = output_path.split("/")[-1] - logger.info(f"Successfully processed cropped audio: {cropped_filename}") - return True, segment_mapping - else: - logger.error(f"Failed to crop audio for {audio_uuid}") - return False, segment_mapping - except Exception as e: - logger.error(f"Error in audio cropping task for {audio_uuid}: {e}", exc_info=True) - return False, [] - - def write_pcm_to_wav( pcm_data: bytes, output_path: str, @@ -383,142 +316,3 @@ def write_pcm_to_wav( except Exception as e: logger.error(f"❌ Failed to write PCM to WAV: {e}") raise - - -async def _crop_audio_with_ffmpeg( - original_path: str, speech_segments: list[tuple[float, float]], output_path: str -) -> tuple[bool, list[dict]]: - """ - Use ffmpeg to crop audio - runs as async subprocess, no GIL issues. 
- - Returns: - Tuple of (success: bool, segment_mapping: list[dict]) - - segment_mapping contains one entry per input segment with: - - original_index: Index in input speech_segments - - original_start/end: Original timestamps in source audio - - cropped_start/end: Where the speech starts/ends in cropped file (None if filtered) - - kept: Whether segment was kept (True) or filtered out (False) - """ - logger.info(f"Cropping audio {original_path} with {len(speech_segments)} speech segments") - - if not speech_segments: - logger.warning(f"No speech segments to crop for {original_path}") - return False, [] - - # Check if the original file exists - if not os.path.exists(original_path): - logger.error(f"Original audio file does not exist: {original_path}") - return False, [] - - # Filter out segments that are too short and build mapping - filtered_segments = [] - segment_mapping = [] - current_cropped_offset = 0.0 - - for idx, (start, end) in enumerate(speech_segments): - duration = end - start - if duration >= MIN_SPEECH_SEGMENT_DURATION: - # Add padding around speech segments - padded_start = max(0, start - CROPPING_CONTEXT_PADDING) - padded_end = end + CROPPING_CONTEXT_PADDING - padded_duration = padded_end - padded_start - - filtered_segments.append((padded_start, padded_end)) - - # Calculate where the speech (not padding) appears in cropped file - # The cropped file will have: [padding_before][speech][padding_after] - padding_before = start - padded_start - speech_start_in_cropped = current_cropped_offset + padding_before - speech_end_in_cropped = speech_start_in_cropped + duration - - segment_mapping.append({ - "original_index": idx, - "original_start": start, - "original_end": end, - "cropped_start": speech_start_in_cropped, - "cropped_end": speech_end_in_cropped, - "kept": True - }) - - # Move offset by the full padded duration - current_cropped_offset += padded_duration - else: - # Segment filtered out - segment_mapping.append({ - "original_index": idx, - "original_start": start, - "original_end": end, - "cropped_start": None, - "cropped_end": None, - "kept": False - }) - logger.debug( - f"Skipping short segment: {start}-{end} ({duration:.2f}s < {MIN_SPEECH_SEGMENT_DURATION}s)" - ) - - if not filtered_segments: - logger.warning( - f"No segments meet minimum duration ({MIN_SPEECH_SEGMENT_DURATION}s) for {original_path}" - ) - return False, segment_mapping - - logger.info( - f"Cropping audio {original_path} with {len(filtered_segments)} speech segments (filtered from {len(speech_segments)})" - ) - - try: - # Build ffmpeg filter for concatenating speech segments - filter_parts = [] - for i, (start, end) in enumerate(filtered_segments): - duration = end - start - filter_parts.append( - f"[0:a]atrim=start={start}:duration={duration},asetpts=PTS-STARTPTS[seg{i}]" - ) - - # Concatenate all segments - inputs = "".join(f"[seg{i}]" for i in range(len(filtered_segments))) - concat_filter = f"{inputs}concat=n={len(filtered_segments)}:v=0:a=1[out]" - - full_filter = ";".join(filter_parts + [concat_filter]) - - # Run ffmpeg as async subprocess - cmd = [ - "ffmpeg", - "-y", # -y = overwrite output - "-i", - original_path, - "-filter_complex", - full_filter, - "-map", - "[out]", - "-c:a", - "pcm_s16le", # Keep same format as original - output_path, - ] - - logger.info(f"Running ffmpeg command: {' '.join(cmd)}") - - process = await asyncio.create_subprocess_exec( - *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE - ) - - stdout, stderr = await process.communicate() - if stdout: - 
logger.debug(f"FFMPEG stdout: {stdout.decode()}") - - if process.returncode == 0: - # Calculate cropped duration - cropped_duration = sum(end - start for start, end in filtered_segments) - logger.info( - f"Successfully cropped {original_path} -> {output_path} ({cropped_duration:.1f}s from {len(filtered_segments)} segments)" - ) - return True, segment_mapping - else: - error_msg = stderr.decode() if stderr else "Unknown ffmpeg error" - logger.error(f"ffmpeg failed for {original_path}: {error_msg}") - return False, segment_mapping - - except Exception as e: - logger.error(f"Error running ffmpeg on {original_path}: {e}", exc_info=True) - return False, segment_mapping diff --git a/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py b/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py index b2cddf4c..3acba204 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py +++ b/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py @@ -87,37 +87,41 @@ def analyze_speech(transcript_data: dict) -> dict: valid_words = [w for w in words if w.get("confidence", 0) >= settings["min_confidence"]] if len(valid_words) < settings["min_words"]: - return { - "has_speech": False, - "reason": f"Not enough valid words ({len(valid_words)} < {settings['min_words']})", - "word_count": len(valid_words), - "duration": 0.0, - } - - # Calculate speech duration from word timing - if valid_words: - speech_start = valid_words[0].get("start", 0) - speech_end = valid_words[-1].get("end", 0) - speech_duration = speech_end - speech_start - - # Check minimum duration threshold - min_duration = settings.get("min_duration", 10.0) - if speech_duration < min_duration: - return { - "has_speech": False, - "reason": f"Speech too short ({speech_duration:.1f}s < {min_duration}s)", - "word_count": len(valid_words), - "duration": speech_duration, - } - - return { - "has_speech": True, - "word_count": len(valid_words), - "speech_start": speech_start, - "speech_end": speech_end, - "duration": speech_duration, - "reason": f"Valid speech detected ({len(valid_words)} words, {speech_duration:.1f}s)", - } + # Not enough valid words in word-level data - fall through to text-only analysis + # This handles cases where word-level data is incomplete or low confidence + logger.debug(f"Only {len(valid_words)} valid words, falling back to text-only analysis") + # Continue to Method 2 (don't return early) + else: + # Calculate speech duration from word timing + if valid_words: + speech_start = valid_words[0].get("start", 0) + speech_end = valid_words[-1].get("end", 0) + speech_duration = speech_end - speech_start + + # If no timing data (duration = 0), fall back to text-only analysis + # This happens with some streaming transcription services + if speech_duration == 0: + logger.debug("Word timing data missing, falling back to text-only analysis") + # Continue to Method 2 (text-only fallback) + else: + # Check minimum duration threshold when we have timing data + min_duration = settings.get("min_duration", 10.0) + if speech_duration < min_duration: + return { + "has_speech": False, + "reason": f"Speech too short ({speech_duration:.1f}s < {min_duration}s)", + "word_count": len(valid_words), + "duration": speech_duration, + } + + return { + "has_speech": True, + "word_count": len(valid_words), + "speech_start": speech_start, + "speech_end": speech_end, + "duration": speech_duration, + "reason": f"Valid speech detected ({len(valid_words)} words, {speech_duration:.1f}s)", + } # 
Method 2: Text-only fallback (when no word-level data available) text = transcript_data.get("text", "").strip() diff --git a/backends/advanced/src/advanced_omi_backend/utils/job_utils.py b/backends/advanced/src/advanced_omi_backend/utils/job_utils.py index 6200af82..ba9fcc74 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/job_utils.py +++ b/backends/advanced/src/advanced_omi_backend/utils/job_utils.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -async def check_job_alive(redis_client, current_job) -> bool: +async def check_job_alive(redis_client, current_job, session_id: Optional[str] = None) -> bool: """ Check if current RQ job still exists in Redis. @@ -20,6 +20,7 @@ async def check_job_alive(redis_client, current_job) -> bool: Args: redis_client: Async Redis client current_job: RQ job instance from get_current_job() + session_id: Optional session ID to check if session has ended naturally Returns: False if job is zombie (caller should exit), True otherwise @@ -32,13 +33,23 @@ async def check_job_alive(redis_client, current_job) -> bool: while True: # Check for zombie state each iteration - if not await check_job_alive(redis_client, current_job): + if not await check_job_alive(redis_client, current_job, session_id): break # ... do work ... """ if current_job: job_exists = await redis_client.exists(f"rq:job:{current_job.id}") if not job_exists: - logger.error(f"🧟 Zombie job detected - job {current_job.id} deleted from Redis, exiting") + # Check if this is a natural exit (session ended) vs true zombie + if session_id: + session_key = f"audio:session:{session_id}" + session_status = await redis_client.hget(session_key, "status") + if session_status and session_status.decode() in ["finalizing", "complete", "closed"]: + # Session ended naturally - not a zombie, just natural cleanup + logger.debug(f"📋 Job {current_job.id} ending naturally (session closed)") + return False + + # True zombie - job deleted while session still active + logger.error(f"🧟 Zombie job detected - job {current_job.id} deleted from Redis while session still active, exiting") return False return True diff --git a/backends/advanced/src/advanced_omi_backend/workers/__init__.py b/backends/advanced/src/advanced_omi_backend/workers/__init__.py index fb32797d..ea82056b 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/__init__.py +++ b/backends/advanced/src/advanced_omi_backend/workers/__init__.py @@ -6,7 +6,7 @@ - speaker_jobs: Speaker recognition and identification - conversation_jobs: Conversation management and updates - memory_jobs: Memory extraction and processing -- audio_jobs: Audio file processing and cropping +- audio_jobs: Audio file processing Queue configuration and utilities are in controllers/queue_controller.py """ @@ -36,9 +36,7 @@ # Import from audio_jobs from .audio_jobs import ( - process_cropping_job, audio_streaming_persistence_job, - enqueue_cropping, ) # Import from queue_controller @@ -78,10 +76,6 @@ "process_memory_job", "enqueue_memory_processing", - # Audio jobs - "process_cropping_job", - "enqueue_cropping", - # Queue utils "get_queue", "get_job_stats", diff --git a/backends/advanced/src/advanced_omi_backend/workers/audio_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/audio_jobs.py index 56df7149..99f6dd53 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/audio_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/audio_jobs.py @@ -21,170 +21,6 @@ logger = logging.getLogger(__name__) -@async_job(redis=True, beanie=True) 
-async def process_cropping_job( - conversation_id: str, - audio_path: str, - *, - redis_client=None -) -> Dict[str, Any]: - """ - RQ job function for audio cropping - removes silent segments from audio. - - This job: - 1. Reads transcript segments from conversation - 2. Extracts speech timestamps - 3. Creates cropped audio file with only speech segments - 4. Updates conversation with cropped file path - - Args: - conversation_id: Conversation ID - audio_path: Path to original audio file - redis_client: Redis client (injected by decorator) - - Returns: - Dict with processing results - """ - from pathlib import Path - from advanced_omi_backend.utils.audio_utils import _process_audio_cropping_with_relative_timestamps - from advanced_omi_backend.models.conversation import Conversation - from advanced_omi_backend.config import CHUNK_DIR - - try: - logger.info(f"🔄 RQ: Starting audio cropping for conversation {conversation_id}") - - # Get conversation to access segments - conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id) - if not conversation: - raise ValueError(f"Conversation {conversation_id} not found") - - # Extract speech segments from transcript (property returns data from active version) - segments = conversation.segments - if not segments or len(segments) == 0: - logger.warning(f"⚠️ No segments found for conversation {conversation_id}, skipping cropping") - return { - "success": False, - "conversation_id": conversation_id, - "reason": "no_segments" - } - - # Convert segments to (start, end) tuples - speech_segments = [(seg.start, seg.end) for seg in segments] - logger.info(f"Found {len(speech_segments)} speech segments for cropping") - - # Generate output path for cropped audio - audio_uuid = conversation.audio_uuid - - # Build full path from conversation.audio_path (which may include folder prefix) - # conversation.audio_path is like "fixtures/filename.wav" or just "filename.wav" - full_audio_path = CHUNK_DIR / conversation.audio_path - original_path = Path(full_audio_path) - cropped_filename = f"cropped_{original_path.name}" - - # If the conversation's audio_path contains a folder prefix, use the same folder for cropped audio - if conversation.audio_path and "/" in conversation.audio_path: - folder = conversation.audio_path.split("/")[0] - output_dir = CHUNK_DIR / folder - output_dir.mkdir(parents=True, exist_ok=True) - output_path = output_dir / cropped_filename - cropped_path_for_db = f"{folder}/{cropped_filename}" - else: - output_path = CHUNK_DIR / cropped_filename - cropped_path_for_db = cropped_filename - - # Process cropping (no repository needed - we update conversation directly) - success, segment_mapping = await _process_audio_cropping_with_relative_timestamps( - str(original_path), - speech_segments, - str(output_path), - audio_uuid, - None # No repository - we update conversation model directly - ) - - if not success: - logger.error(f"❌ RQ: Audio cropping failed for conversation {conversation_id}") - return { - "success": False, - "conversation_id": conversation_id, - "reason": "cropping_failed" - } - - # Calculate actual cropped duration from kept segments - kept_segments = [m for m in segment_mapping if m["kept"]] - if kept_segments: - # Duration is end of last kept segment - cropped_duration_seconds = kept_segments[-1]["cropped_end"] - else: - cropped_duration_seconds = 0.0 - - # Update segment timestamps using the mapping - # Only keep segments that weren't filtered out - updated_segments = [] - for i, seg in 
enumerate(segments): - if i >= len(segment_mapping): - logger.warning(f"⚠️ Segment {i} not in mapping, skipping") - continue - - mapping = segment_mapping[i] - if mapping["kept"]: - # Segment was kept - use the cropped timestamps - updated_seg = seg.model_copy() - updated_seg.start = mapping["cropped_start"] - updated_seg.end = mapping["cropped_end"] - updated_segments.append(updated_seg) - logger.debug( - f"Segment {i}: {seg.start:.2f}-{seg.end:.2f}s → " - f"{updated_seg.start:.2f}-{updated_seg.end:.2f}s (in cropped audio)" - ) - else: - # Segment was filtered out (too short) - logger.debug( - f"Segment {i} filtered out (duration {seg.end - seg.start:.2f}s < MIN_SPEECH_SEGMENT_DURATION)" - ) - - # Update conversation with cropped audio path and adjusted segments - conversation.cropped_audio_path = cropped_path_for_db - - # Update the active transcript version segments - # Find and update the version directly in the list to ensure Beanie detects the change - if conversation.active_transcript_version: - for i, version in enumerate(conversation.transcript_versions): - if version.version_id == conversation.active_transcript_version: - conversation.transcript_versions[i].segments = updated_segments - logger.info(f"📝 Updated segments in transcript version {version.version_id[:12]}") - break - - await conversation.save() - logger.info(f"💾 Updated conversation {conversation_id[:12]} with cropped_audio_path and adjusted {len(updated_segments)} segment timestamps") - - logger.info(f"✅ RQ: Completed audio cropping for conversation {conversation_id} ({cropped_duration_seconds:.1f}s)") - - # Update job metadata with cropped duration - from rq import get_current_job - current_job = get_current_job() - if current_job: - if not current_job.meta: - current_job.meta = {} - current_job.meta['cropped_duration_seconds'] = round(cropped_duration_seconds, 1) - current_job.meta['segments_cropped'] = len(speech_segments) - current_job.save_meta() - - return { - "success": True, - "conversation_id": conversation_id, - "audio_uuid": audio_uuid, - "original_path": str(original_path), - "cropped_path": str(output_path), - "cropped_filename": cropped_filename, - "segments_count": len(speech_segments), - "cropped_duration_seconds": cropped_duration_seconds - } - - except Exception as e: - logger.error(f"❌ RQ: Audio cropping failed for conversation {conversation_id}: {e}") - raise - - @async_job(redis=True, beanie=True) async def audio_streaming_persistence_job( session_id: str, @@ -267,7 +103,7 @@ async def audio_streaming_persistence_job( while True: # Check if job still exists in Redis (detect zombie state) - if not await check_job_alive(redis_client, current_job): + if not await check_job_alive(redis_client, current_job, session_id): if file_sink: await file_sink.close() break @@ -380,7 +216,7 @@ async def audio_streaming_persistence_job( # If no file open yet, wait for conversation to be created if not file_sink: - await asyncio.sleep(0.5) + await asyncio.sleep(0.0001) # Minimal sleep to yield to event loop continue # Read audio chunks from stream (non-blocking) @@ -390,7 +226,7 @@ async def audio_streaming_persistence_job( audio_consumer_name, {audio_stream_name: ">"}, count=20, # Read up to 20 chunks at a time for efficiency - block=500 # 500ms timeout + block=100 # 100ms timeout - more responsive ) if audio_messages: @@ -443,7 +279,7 @@ async def audio_streaming_persistence_job( # Stream might not exist yet or other transient errors logger.debug(f"Audio stream read error (non-fatal): {audio_error}") - await 
asyncio.sleep(0.1) # Check every 100ms for responsiveness + await asyncio.sleep(0.0001) # Minimal sleep to yield to event loop # Job complete - calculate final stats runtime_seconds = time.time() - start_time @@ -480,40 +316,3 @@ async def audio_streaming_persistence_job( # Enqueue wrapper functions - -def enqueue_cropping( - conversation_id: str, - audio_path: str, - priority: JobPriority = JobPriority.NORMAL -): - """ - Enqueue an audio cropping job. - - Args: - conversation_id: Conversation ID - audio_path: Path to audio file - priority: Job priority level - - Returns: - RQ Job object for tracking. - """ - timeout_mapping = { - JobPriority.URGENT: 300, # 5 minutes - JobPriority.HIGH: 240, # 4 minutes - JobPriority.NORMAL: 180, # 3 minutes - JobPriority.LOW: 120 # 2 minutes - } - - job = default_queue.enqueue( - process_cropping_job, - conversation_id, - audio_path, - job_timeout=timeout_mapping.get(priority, 180), - result_ttl=JOB_RESULT_TTL, - job_id=f"crop_{conversation_id[:12]}", - description=f"Crop audio for conversation {conversation_id[:12]}", - meta={'conversation_id': conversation_id} - ) - - logger.info(f"📥 RQ: Enqueued cropping job {job.id} for conversation {conversation_id}") - return job diff --git a/backends/advanced/src/advanced_omi_backend/workers/audio_stream_deepgram_worker.py b/backends/advanced/src/advanced_omi_backend/workers/audio_stream_deepgram_worker.py deleted file mode 100644 index a58682c1..00000000 --- a/backends/advanced/src/advanced_omi_backend/workers/audio_stream_deepgram_worker.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env python3 -""" -Deepgram audio stream worker. - -Starts a consumer that reads from audio:stream:deepgram and transcribes audio. -""" - -import asyncio -import logging -import os -import signal -import sys - -import redis.asyncio as redis - -from advanced_omi_backend.services.transcription.deepgram import DeepgramStreamConsumer - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" -) - -logger = logging.getLogger(__name__) - - -async def main(): - """Main worker entry point.""" - logger.info("🚀 Starting Deepgram audio stream worker") - - # Check that config.yml has Deepgram configured - # The registry provider will load configuration from config.yml - api_key = os.getenv("DEEPGRAM_API_KEY") - if not api_key: - logger.warning("DEEPGRAM_API_KEY environment variable not set") - logger.warning("Ensure config.yml has a default 'stt' model configured for Deepgram") - logger.warning("Audio transcription will use alternative providers if configured in config.yml") - - redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") - - # Create Redis client - redis_client = await redis.from_url( - redis_url, - encoding="utf-8", - decode_responses=False - ) - logger.info("Connected to Redis") - - # Create consumer with balanced buffer size - # 20 chunks = ~5 seconds of audio - # Balance between transcription accuracy and latency - # Consumer uses registry-driven provider from config.yml - consumer = DeepgramStreamConsumer( - redis_client=redis_client, - buffer_chunks=20 # 5 seconds - good context without excessive delay - ) - - # Setup signal handlers for graceful shutdown - def signal_handler(signum, frame): - logger.info(f"Received signal {signum}, shutting down...") - asyncio.create_task(consumer.stop()) - - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - try: - logger.info("✅ Deepgram worker ready") - - # This blocks until consumer is 
stopped - await consumer.start_consuming() - - except Exception as e: - logger.error(f"Worker error: {e}", exc_info=True) - sys.exit(1) - finally: - await redis_client.aclose() - logger.info("👋 Deepgram worker stopped") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/backends/advanced/src/advanced_omi_backend/workers/audio_stream_parakeet_worker.py b/backends/advanced/src/advanced_omi_backend/workers/audio_stream_parakeet_worker.py deleted file mode 100644 index 56f2f26b..00000000 --- a/backends/advanced/src/advanced_omi_backend/workers/audio_stream_parakeet_worker.py +++ /dev/null @@ -1,95 +0,0 @@ -#!/usr/bin/env python3 -""" -Parakeet audio stream worker. - -Starts a consumer that reads from audio:stream:* and transcribes audio using Parakeet. -""" - -import asyncio -import logging -import os -import signal -import sys - -import redis.asyncio as redis - -from advanced_omi_backend.services.transcription.parakeet_stream_consumer import ParakeetStreamConsumer - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" -) - -logger = logging.getLogger(__name__) - - -async def main(): - """Main worker entry point.""" - logger.info("🚀 Starting Parakeet audio stream worker") - - # Check that config.yml has Parakeet configured - # The registry provider will load configuration from config.yml - service_url = os.getenv("PARAKEET_ASR_URL") - if not service_url: - logger.warning("PARAKEET_ASR_URL environment variable not set") - logger.warning("Ensure config.yml has a default 'stt' model configured for Parakeet") - logger.warning("Audio transcription will use alternative providers if configured in config.yml") - - redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") - - # Create Redis client - redis_client = await redis.from_url( - redis_url, - encoding="utf-8", - decode_responses=False - ) - logger.info("Connected to Redis") - - # Create consumer with balanced buffer size - # 20 chunks = ~5 seconds of audio - # Balance between transcription accuracy and latency - # Consumer uses registry-driven provider from config.yml - consumer = ParakeetStreamConsumer( - redis_client=redis_client, - buffer_chunks=20 # 5 seconds - good context without excessive delay - ) - - # Setup signal handlers for graceful shutdown - shutdown_event = asyncio.Event() - - def signal_handler(signum, _frame): - logger.info(f"Received signal {signum}, shutting down...") - shutdown_event.set() - - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - try: - logger.info("✅ Parakeet worker ready") - - # This blocks until consumer is stopped or shutdown signaled - consume_task = asyncio.create_task(consumer.start_consuming()) - shutdown_task = asyncio.create_task(shutdown_event.wait()) - - done, pending = await asyncio.wait( - [consume_task, shutdown_task], - return_when=asyncio.FIRST_COMPLETED - ) - - # Cancel pending tasks - for task in pending: - task.cancel() - - await consumer.stop() - - except Exception as e: - logger.error(f"Worker error: {e}", exc_info=True) - sys.exit(1) - finally: - await redis_client.aclose() - logger.info("👋 Parakeet worker stopped") - - -if __name__ == "__main__": - asyncio.run(main()) - diff --git a/backends/advanced/src/advanced_omi_backend/workers/audio_stream_worker.py b/backends/advanced/src/advanced_omi_backend/workers/audio_stream_worker.py new file mode 100644 index 00000000..df133de4 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/workers/audio_stream_worker.py @@ 
-0,0 +1,115 @@
+#!/usr/bin/env python3
+"""
+Generic streaming transcription worker using registry-driven providers.
+
+Starts a consumer that reads from audio:stream:* streams and transcribes via the configured provider.
+Provider configuration is loaded from config.yml (supports any streaming STT service).
+Publishes interim results to Redis Pub/Sub for real-time client display.
+Publishes final results to Redis Streams for storage.
+Triggers plugins on final results only.
+"""
+
+import asyncio
+import logging
+import os
+import signal
+import sys
+
+import redis.asyncio as redis
+
+from advanced_omi_backend.services.plugin_service import init_plugin_router
+from advanced_omi_backend.services.transcription.streaming_consumer import StreamingTranscriptionConsumer
+from advanced_omi_backend.client_manager import initialize_redis_for_client_manager
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
+)
+
+logger = logging.getLogger(__name__)
+
+
+async def main():
+    """Main worker entry point."""
+    logger.info("🚀 Starting streaming transcription worker")
+    logger.info("📋 Provider configuration loaded from config.yml (defaults.stt_stream)")
+
+    redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
+
+    # Create Redis client
+    try:
+        redis_client = await redis.from_url(
+            redis_url,
+            encoding="utf-8",
+            decode_responses=False
+        )
+        logger.info(f"✅ Connected to Redis: {redis_url}")
+
+        # Initialize ClientManager Redis for cross-container client→user mapping
+        initialize_redis_for_client_manager(redis_url)
+
+    except Exception as e:
+        logger.error(f"Failed to connect to Redis: {e}", exc_info=True)
+        sys.exit(1)
+
+    # Initialize plugin router
+    try:
+        plugin_router = init_plugin_router()
+        if plugin_router:
+            logger.info(f"✅ Plugin router initialized with {len(plugin_router.plugins)} plugins")
+
+            # Initialize async plugins
+            for plugin_id, plugin in plugin_router.plugins.items():
+                try:
+                    await plugin.initialize()
+                    logger.info(f"✅ Plugin '{plugin_id}' initialized in streaming worker")
+                except Exception:
+                    # logger.exception already records the traceback
+                    logger.exception(f"Failed to initialize plugin '{plugin_id}' in streaming worker")
+        else:
+            logger.warning("No plugin router available - plugins will not be triggered")
+    except Exception as e:
+        logger.error(f"Failed to initialize plugin router: {e}", exc_info=True)
+        plugin_router = None
+
+    # Create streaming transcription consumer (uses registry-driven provider from config.yml)
+    try:
+        consumer = StreamingTranscriptionConsumer(
+            redis_client=redis_client,
+            plugin_router=plugin_router
+        )
+        logger.info("✅ Streaming transcription consumer created")
+    except Exception as e:
+        logger.error(f"Failed to create streaming transcription consumer: {e}", exc_info=True)
+        logger.error("Ensure config.yml has defaults.stt_stream configured with valid provider")
+        await redis_client.aclose()
+        sys.exit(1)
+
+    # Setup signal handlers for graceful shutdown
+    def signal_handler(signum, frame):
+        logger.info(f"Received signal {signum}, shutting down...")
+        asyncio.create_task(consumer.stop())
+
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    try:
+        logger.info("✅ Streaming transcription worker ready")
+        logger.info("📡 Listening for audio streams on audio:stream:* pattern")
+        logger.info("📢 Publishing interim results to transcription:interim:{session_id}")
+        logger.info("💾 Publishing final results to transcription:results:{session_id}")
+
+        # This blocks until consumer is stopped
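+        # Illustrative config.yml shape this worker assumes (model name, URL, and
+        # query values below are examples only, not a real configuration):
+        #   defaults:
+        #     stt_stream: stt-deepgram-stream
+        #   models:
+        #     stt-deepgram-stream:
+        #       model_url: wss://api.deepgram.com/v1/listen
+        #       api_key: ${DEEPGRAM_API_KEY}
+        #       operations:
+        #         query: {encoding: linear16, sample_rate: 16000, interim_results: true}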
await consumer.start_consuming() + + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down...") + except Exception as e: + logger.error(f"Worker error: {e}", exc_info=True) + sys.exit(1) + finally: + await redis_client.aclose() + logger.info("👋 Streaming transcription worker stopped") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py index d2b8c4fd..1d3f81f3 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py @@ -10,8 +10,12 @@ from datetime import datetime from typing import Dict, Any from rq.job import Job +from rq.exceptions import NoSuchJobError + from advanced_omi_backend.models.job import async_job from advanced_omi_backend.controllers.queue_controller import redis_conn +from advanced_omi_backend.controllers.session_controller import mark_session_complete +from advanced_omi_backend.services.plugin_service import get_plugin_router, init_plugin_router from advanced_omi_backend.utils.conversation_utils import ( analyze_speech, @@ -235,14 +239,35 @@ async def open_conversation_job( # Link job metadata to conversation (cascading updates) current_job.meta["conversation_id"] = conversation_id current_job.save_meta() - speech_job = Job.fetch(speech_job_id, connection=redis_conn) - speech_job.meta["conversation_id"] = conversation_id - speech_job.save_meta() - speaker_check_job_id = speech_job.meta.get("speaker_check_job_id") - if speaker_check_job_id: - speaker_check_job = Job.fetch(speaker_check_job_id, connection=redis_conn) - speaker_check_job.meta["conversation_id"] = conversation_id - speaker_check_job.save_meta() + + try: + speech_job = Job.fetch(speech_job_id, connection=redis_conn) + speech_job.meta["conversation_id"] = conversation_id + speech_job.save_meta() + speaker_check_job_id = speech_job.meta.get("speaker_check_job_id") + if speaker_check_job_id: + try: + speaker_check_job = Job.fetch(speaker_check_job_id, connection=redis_conn) + speaker_check_job.meta["conversation_id"] = conversation_id + speaker_check_job.save_meta() + except Exception as e: + if isinstance(e, NoSuchJobError): + logger.error( + f"❌ Missing job hash for speaker_check job {speaker_check_job_id}: " + f"Job was linked to speech_job {speech_job_id} but hash key disappeared. " + f"This may indicate TTL expiry or job collision." + ) + else: + raise + except Exception as e: + if isinstance(e, NoSuchJobError): + logger.error( + f"❌ Missing job hash for speech_job {speech_job_id}: " + f"Job was created for session {session_id} but hash key disappeared before metadata link. " + f"This may indicate TTL expiry or job collision." 
+ ) + else: + raise # Signal audio persistence job to rotate to this conversation's file rotation_signal_key = f"conversation:current:{session_id}" @@ -283,7 +308,7 @@ async def open_conversation_job( while True: # Check if job still exists in Redis (detect zombie state) from advanced_omi_backend.utils.job_utils import check_job_alive - if not await check_job_alive(redis_client, current_job): + if not await check_job_alive(redis_client, current_job, session_id): break # Check if session is finalizing (set by producer when recording stops) @@ -294,9 +319,9 @@ async def open_conversation_job( if status_str in ["finalizing", "complete"]: finalize_received = True - # Check if this was a WebSocket disconnect + # Get completion reason (guaranteed to exist with unified API) completion_reason = await redis_client.hget(session_key, "completion_reason") - completion_reason_str = completion_reason.decode() if completion_reason else None + completion_reason_str = completion_reason.decode() if completion_reason else "unknown" if completion_reason_str == "websocket_disconnect": logger.warning( @@ -306,7 +331,7 @@ async def open_conversation_job( timeout_triggered = False # This is a disconnect, not a timeout else: logger.info( - f"🛑 Session finalizing (reason: {completion_reason_str or 'user_stopped'}), " + f"🛑 Session finalizing (reason: {completion_reason_str}), " f"waiting for audio persistence job to complete..." ) break # Exit immediately when finalize signal received @@ -398,6 +423,42 @@ async def open_conversation_job( ) last_result_count = current_count + # Trigger transcript-level plugins on new transcript segments + try: + plugin_router = get_plugin_router() + if plugin_router: + # Get the latest transcript text for plugin processing + transcript_text = combined.get('text', '') + + if transcript_text: + plugin_data = { + 'transcript': transcript_text, + 'segment_id': f"{session_id}_{current_count}", + 'conversation_id': conversation_id, + 'segments': combined.get('segments', []), + 'word_count': speech_analysis.get('word_count', 0), + } + + plugin_results = await plugin_router.trigger_plugins( + access_level='streaming_transcript', + user_id=user_id, + data=plugin_data, + metadata={'client_id': client_id} + ) + + if plugin_results: + logger.info(f"📌 Triggered {len(plugin_results)} streaming transcript plugins") + for result in plugin_results: + if result.message: + logger.info(f" Plugin: {result.message}") + + # If plugin stopped processing, log it + if not result.should_continue: + logger.info(f" Plugin stopped normal processing") + + except Exception as e: + logger.warning(f"⚠️ Error triggering transcript-level plugins: {e}") + await asyncio.sleep(1) # Check every second for responsiveness logger.info( @@ -496,6 +557,43 @@ async def open_conversation_job( # Wait a moment to ensure jobs are registered in RQ await asyncio.sleep(0.5) + # Trigger conversation-level plugins + try: + plugin_router = get_plugin_router() + if plugin_router: + # Get conversation data for plugin context + conversation_model = await Conversation.find_one( + Conversation.conversation_id == conversation_id + ) + + plugin_data = { + 'conversation': { + 'conversation_id': conversation_id, + 'audio_uuid': session_id, + 'client_id': client_id, + 'user_id': user_id, + }, + 'transcript': conversation_model.transcript if conversation_model else "", + 'duration': time.time() - start_time, + 'conversation_id': conversation_id, + } + + plugin_results = await plugin_router.dispatch_event( + event='conversation.complete', + 
user_id=user_id, + data=plugin_data, + metadata={'end_reason': end_reason} + ) + + if plugin_results: + logger.info(f"📌 Triggered {len(plugin_results)} conversation-level plugins") + for result in plugin_results: + if result.message: + logger.info(f" Plugin result: {result.message}") + + except Exception as e: + logger.warning(f"⚠️ Error triggering conversation-level plugins: {e}") + # Call shared cleanup/restart logic return await handle_end_of_conversation( session_id=session_id, @@ -635,3 +733,112 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None) "detailed_summary": conversation.detailed_summary, "processing_time_seconds": processing_time, } + + +@async_job(redis=True, beanie=True) +async def dispatch_conversation_complete_event_job( + conversation_id: str, + audio_uuid: str, + client_id: str, + user_id: str, + *, + redis_client=None +) -> Dict[str, Any]: + """ + Dispatch conversation.complete plugin event for file upload processing. + + This job runs at the end of the post-conversation job chain to ensure + plugins receive the conversation.complete event for uploaded audio files. + WebSocket streaming dispatches this event in open_conversation_job instead. + + Args: + conversation_id: Conversation ID + audio_uuid: Audio UUID + client_id: Client ID + user_id: User ID + redis_client: Redis client (injected by decorator) + + Returns: + Dict with success status and plugin results + """ + from advanced_omi_backend.models.conversation import Conversation + + logger.info(f"📌 Dispatching conversation.complete event for conversation {conversation_id}") + + start_time = time.time() + + # Get the conversation to include in event data + conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id) + if not conversation: + logger.error(f"Conversation {conversation_id} not found") + return {"success": False, "error": "Conversation not found"} + + # Get user email for event data + from advanced_omi_backend.models.user import User + user = await User.get(user_id) + user_email = user.email if user else "" + + # Prepare plugin event data (same format as open_conversation_job) + try: + # Get or initialize plugin router (same pattern as transcription_jobs.py) + plugin_router = get_plugin_router() + if not plugin_router: + logger.info("🔧 Initializing plugin router in worker process...") + plugin_router = init_plugin_router() + + # Initialize all plugins asynchronously (same as app_factory.py) + if plugin_router: + for plugin_id, plugin in plugin_router.plugins.items(): + try: + await plugin.initialize() + logger.info(f"✅ Plugin '{plugin_id}' initialized") + except Exception as e: + logger.error(f"Failed to initialize plugin '{plugin_id}': {e}") + + if not plugin_router: + logger.warning("⚠️ Plugin router could not be initialized, skipping event dispatch") + return {"success": True, "skipped": True, "reason": "No plugin router"} + + plugin_data = { + 'conversation': { + 'audio_uuid': audio_uuid, + 'client_id': client_id, + 'user_id': user_id, + }, + 'transcript': conversation.transcript if conversation else "", + 'duration': 0, # Duration not tracked for file uploads + 'conversation_id': conversation_id, + } + + plugin_results = await plugin_router.dispatch_event( + event='conversation.complete', + user_id=user_id, + data=plugin_data, + metadata={'end_reason': 'file_upload'} + ) + + if plugin_results: + logger.info(f"📌 Triggered {len(plugin_results)} conversation-level plugins") + for result in plugin_results: + if result.message: + 
logger.info(f" Plugin result: {result.message}") + + processing_time = time.time() - start_time + logger.info( + f"✅ Conversation complete event dispatched for {conversation_id} in {processing_time:.2f}s" + ) + + return { + "success": True, + "conversation_id": conversation_id, + "plugin_count": len(plugin_results) if plugin_results else 0, + "processing_time_seconds": processing_time, + } + + except Exception as e: + logger.warning(f"⚠️ Error dispatching conversation complete event: {e}") + return { + "success": False, + "error": str(e), + "conversation_id": conversation_id, + } diff --git a/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py index 8b64d690..ee02b065 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py @@ -16,6 +16,7 @@ ) from advanced_omi_backend.models.job import BaseRQJob, JobPriority, async_job from advanced_omi_backend.services.memory.base import MemoryEntry +from advanced_omi_backend.services.plugin_service import get_plugin_router, init_plugin_router logger = logging.getLogger(__name__) @@ -240,6 +241,55 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict # This allows users to resume talking immediately after conversation closes, # without waiting for memory processing to complete. + # Trigger memory-level plugins + try: + # Get or initialize plugin router (same pattern as conversation_jobs.py) + plugin_router = get_plugin_router() + if not plugin_router: + logger.info("🔧 Initializing plugin router in worker process...") + plugin_router = init_plugin_router() + + # Initialize all plugins asynchronously (same as app_factory.py) + if plugin_router: + for plugin_id, plugin in plugin_router.plugins.items(): + try: + await plugin.initialize() + logger.info(f"✅ Plugin '{plugin_id}' initialized") + except Exception as e: + logger.error(f"Failed to initialize plugin '{plugin_id}': {e}") + + if plugin_router: + plugin_data = { + 'memories': created_memory_ids, + 'conversation': { + 'conversation_id': conversation_id, + 'client_id': client_id, + 'user_id': user_id, + 'user_email': user_email, + }, + 'memory_count': len(created_memory_ids), + 'conversation_id': conversation_id, + } + + plugin_results = await plugin_router.dispatch_event( + event='memory.processed', + user_id=user_id, + data=plugin_data, + metadata={ + 'processing_time': processing_time, + 'memory_provider': str(memory_provider), + } + ) + + if plugin_results: + logger.info(f"📌 Triggered {len(plugin_results)} memory-level plugins") + for result in plugin_results: + if result.message: + logger.info(f" Plugin result: {result.message}") + + except Exception as e: + logger.warning(f"⚠️ Error triggering memory-level plugins: {e}") + return { "success": True, "memories_created": len(created_memory_ids), diff --git a/backends/advanced/src/advanced_omi_backend/workers/orchestrator/__init__.py b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/__init__.py new file mode 100644 index 00000000..1c7b0d7a --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/__init__.py @@ -0,0 +1,28 @@ +""" +Worker Orchestrator Package + +This package provides a Python-based orchestration system for managing +Chronicle's worker processes, replacing the bash-based start-workers.sh script. 
+ +Components: +- config: Worker definitions and orchestrator configuration +- worker_registry: Build worker list with conditional logic +- process_manager: Process lifecycle management +- health_monitor: Health checks and self-healing +""" + +from .config import WorkerDefinition, OrchestratorConfig, WorkerType +from .worker_registry import build_worker_definitions +from .process_manager import ManagedWorker, ProcessManager, WorkerState +from .health_monitor import HealthMonitor + +__all__ = [ + "WorkerDefinition", + "OrchestratorConfig", + "WorkerType", + "build_worker_definitions", + "ManagedWorker", + "ProcessManager", + "WorkerState", + "HealthMonitor", +] diff --git a/backends/advanced/src/advanced_omi_backend/workers/orchestrator/config.py b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/config.py new file mode 100644 index 00000000..633d366a --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/config.py @@ -0,0 +1,91 @@ +""" +Worker Orchestrator Configuration + +Defines data structures for worker definitions and orchestrator configuration. +""" + +import os +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional, Callable, List + + +class WorkerType(Enum): + """Type of worker process""" + + RQ_WORKER = "rq_worker" # RQ queue worker + STREAM_CONSUMER = "stream_consumer" # Redis Streams consumer + + +@dataclass +class WorkerDefinition: + """ + Definition of a single worker process. + + Attributes: + name: Unique identifier for the worker + command: Full command to execute (as list for subprocess) + worker_type: Type of worker (RQ vs stream consumer) + queues: Queue names for RQ workers (empty for stream consumers) + enabled_check: Optional predicate function to determine if worker should start + restart_on_failure: Whether to automatically restart on failure + health_check: Optional custom health check function + """ + + name: str + command: List[str] + worker_type: WorkerType = WorkerType.RQ_WORKER + queues: List[str] = field(default_factory=list) + enabled_check: Optional[Callable[[], bool]] = None + restart_on_failure: bool = True + health_check: Optional[Callable[[], bool]] = None + + def is_enabled(self) -> bool: + """Check if this worker should be started""" + if self.enabled_check is None: + return True + return self.enabled_check() + + +@dataclass +class OrchestratorConfig: + """ + Global configuration for the worker orchestrator. + + All settings can be overridden via environment variables. 
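+
+    Example (hypothetical values): setting MIN_RQ_WORKERS=8 and
+    WORKER_CHECK_INTERVAL=5 in the environment tightens the self-healing
+    loop without any code changes.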
+ """ + + # Redis connection + redis_url: str = field( + default_factory=lambda: os.getenv("REDIS_URL", "redis://localhost:6379/0") + ) + + # Health monitoring settings + check_interval: int = field( + default_factory=lambda: int(os.getenv("WORKER_CHECK_INTERVAL", "10")) + ) + min_rq_workers: int = field( + default_factory=lambda: int(os.getenv("MIN_RQ_WORKERS", "6")) + ) + startup_grace_period: int = field( + default_factory=lambda: int(os.getenv("WORKER_STARTUP_GRACE_PERIOD", "30")) + ) + + # Shutdown settings + shutdown_timeout: int = field( + default_factory=lambda: int(os.getenv("WORKER_SHUTDOWN_TIMEOUT", "30")) + ) + + # Logging + log_level: str = field(default_factory=lambda: os.getenv("LOG_LEVEL", "INFO")) + + def __post_init__(self): + """Validate configuration after initialization""" + if self.check_interval <= 0: + raise ValueError("check_interval must be positive") + if self.min_rq_workers < 0: + raise ValueError("min_rq_workers must be non-negative") + if self.startup_grace_period < 0: + raise ValueError("startup_grace_period must be non-negative") + if self.shutdown_timeout <= 0: + raise ValueError("shutdown_timeout must be positive") diff --git a/backends/advanced/src/advanced_omi_backend/workers/orchestrator/health_monitor.py b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/health_monitor.py new file mode 100644 index 00000000..9b1149e2 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/health_monitor.py @@ -0,0 +1,317 @@ +""" +Health Monitor + +Self-healing monitor that detects and recovers from worker failures. +Periodically checks worker health and restarts failed workers. +""" + +import asyncio +import logging +import time +from typing import Optional + +from redis import Redis +from rq import Worker + +from .config import OrchestratorConfig, WorkerType +from .process_manager import ProcessManager, WorkerState + +logger = logging.getLogger(__name__) + + +class HealthMonitor: + """ + Self-healing monitor for worker processes. + + Periodically checks: + 1. Individual worker health (process liveness) + 2. RQ worker registration count in Redis + + Automatically restarts failed workers if configured. 
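+
+    A point-in-time summary is available via get_health_status(), e.g. for
+    exposing worker state through a health endpoint.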
+ """ + + def __init__( + self, + process_manager: ProcessManager, + config: OrchestratorConfig, + redis_client: Redis, + ): + self.process_manager = process_manager + self.config = config + self.redis = redis_client + self.running = False + self.monitor_task: Optional[asyncio.Task] = None + self.start_time = time.time() + self.last_registration_recovery: Optional[float] = None + self.registration_recovery_cooldown = 60 # seconds + + async def start(self): + """Start the health monitoring loop""" + if self.running: + logger.warning("Health monitor already running") + return + + self.running = True + self.start_time = time.time() + logger.info( + f"Starting health monitor (check interval: {self.config.check_interval}s, " + f"grace period: {self.config.startup_grace_period}s)" + ) + + self.monitor_task = asyncio.create_task(self._monitor_loop()) + + async def stop(self): + """Stop the health monitoring loop""" + if not self.running: + return + + logger.info("Stopping health monitor...") + self.running = False + + if self.monitor_task: + self.monitor_task.cancel() + try: + await self.monitor_task + except asyncio.CancelledError: + pass + + logger.info("Health monitor stopped") + + async def _monitor_loop(self): + """Main monitoring loop""" + try: + while self.running: + # Wait for startup grace period before starting checks + elapsed = time.time() - self.start_time + if elapsed < self.config.startup_grace_period: + remaining = self.config.startup_grace_period - elapsed + logger.debug( + f"In startup grace period - waiting {remaining:.0f}s before health checks" + ) + await asyncio.sleep(self.config.check_interval) + continue + + # Perform health checks + await self._check_health() + + # Wait for next check + await asyncio.sleep(self.config.check_interval) + + except asyncio.CancelledError: + logger.info("Health monitor loop cancelled") + raise + except Exception as e: + logger.error(f"Health monitor loop error: {e}", exc_info=True) + self.running = False # Mark monitor as stopped so callers know it's not active + raise # Re-raise to ensure the monitor task fails properly + + async def _check_health(self): + """Perform all health checks and restart failed workers""" + try: + # Check individual worker health + worker_health = self._check_worker_health() + + # Check RQ worker registration count + rq_health = self._check_rq_worker_registration() + + # If RQ workers lost registration, trigger bulk restart (matches old bash script behavior) + if not rq_health: + self._handle_registration_loss() + + # Restart failed workers + self._restart_failed_workers() + + # Log summary + if not worker_health or not rq_health: + logger.warning( + f"Health check: worker_health={worker_health}, rq_health={rq_health}" + ) + + except Exception as e: + logger.error(f"Error during health check: {e}", exc_info=True) + + def _check_worker_health(self) -> bool: + """ + Check individual worker health. + + Returns: + True if all workers are healthy + """ + all_healthy = True + + for worker in self.process_manager.get_all_workers(): + try: + is_healthy = worker.check_health() + if not is_healthy: + all_healthy = False + logger.warning( + f"{worker.name}: Health check failed (state={worker.state.value})" + ) + except Exception as e: + all_healthy = False + logger.error(f"{worker.name}: Health check raised exception: {e}") + + return all_healthy + + def _check_rq_worker_registration(self) -> bool: + """ + Check RQ worker registration count in Redis. 
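+
+        Equivalent check (sketch):
+            len(Worker.all(connection=self.redis)) >= self.config.min_rq_workers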
+ + This replicates the bash script's logic: + - Query Redis for all registered RQ workers + - Check if count >= min_rq_workers + + Returns: + True if RQ worker count is sufficient + """ + try: + workers = Worker.all(connection=self.redis) + worker_count = len(workers) + + if worker_count < self.config.min_rq_workers: + logger.warning( + f"RQ worker registration: {worker_count} workers " + f"(expected >= {self.config.min_rq_workers})" + ) + return False + + logger.debug(f"RQ worker registration: {worker_count} workers registered") + return True + + except Exception as e: + logger.error(f"Failed to check RQ worker registration: {e}") + return False + + def _restart_failed_workers(self): + """Restart workers that have failed and should be restarted""" + for worker in self.process_manager.get_all_workers(): + # Only restart if: + # 1. Worker state is FAILED + # 2. Worker definition has restart_on_failure=True + if ( + worker.state == WorkerState.FAILED + and worker.definition.restart_on_failure + ): + logger.warning( + f"{worker.name}: Worker failed, initiating restart " + f"(restart count: {worker.restart_count})" + ) + + success = self.process_manager.restart_worker(worker.name) + + if success: + logger.info( + f"{worker.name}: Restart successful " + f"(total restarts: {worker.restart_count})" + ) + else: + logger.error(f"{worker.name}: Restart failed") + + def _handle_registration_loss(self): + """ + Handle RQ worker registration loss. + + This replicates the old bash script's self-healing behavior: + - Check if cooldown period has passed + - Restart all RQ workers (bulk restart) + - Update recovery timestamp + + Cooldown prevents too-frequent restarts during Redis/network issues. + """ + current_time = time.time() + + # Check if cooldown period has passed + if self.last_registration_recovery is not None: + elapsed = current_time - self.last_registration_recovery + if elapsed < self.registration_recovery_cooldown: + remaining = self.registration_recovery_cooldown - elapsed + logger.debug( + f"Registration recovery cooldown active - " + f"waiting {remaining:.0f}s before next recovery attempt" + ) + return + + logger.warning( + "⚠️ RQ worker registration loss detected - initiating bulk restart " + "(replicating old start-workers.sh behavior)" + ) + + # Restart all RQ workers + success = self._restart_all_rq_workers() + + if success: + logger.info("✅ Bulk restart completed - workers should re-register soon") + else: + logger.error("❌ Bulk restart encountered errors - check individual worker logs") + + # Update recovery timestamp to start cooldown + self.last_registration_recovery = current_time + + def _restart_all_rq_workers(self) -> bool: + """ + Restart all RQ workers (bulk restart). 
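+
+        Stream consumers are left untouched; only workers with
+        WorkerType.RQ_WORKER are restarted.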
+ + This matches the old bash script's recovery mechanism: + - Kill all RQ workers + - Restart them + - Workers will automatically re-register with Redis on startup + + Returns: + True if all RQ workers restarted successfully, False otherwise + """ + rq_workers = [ + worker + for worker in self.process_manager.get_all_workers() + if worker.definition.worker_type == WorkerType.RQ_WORKER + ] + + if not rq_workers: + logger.warning("No RQ workers found to restart") + return False + + logger.info(f"Restarting {len(rq_workers)} RQ workers...") + + all_success = True + for worker in rq_workers: + logger.info(f" ↻ Restarting {worker.name}...") + success = self.process_manager.restart_worker(worker.name) + + if success: + logger.info(f" ✓ {worker.name} restarted successfully") + else: + logger.error(f" ✗ {worker.name} restart failed") + all_success = False + + return all_success + + def get_health_status(self) -> dict: + """ + Get current health status summary. + + Returns: + Dictionary with health status information + """ + worker_status = self.process_manager.get_status() + + # Count workers by state + state_counts = {} + for status in worker_status.values(): + state = status["state"] + state_counts[state] = state_counts.get(state, 0) + 1 + + # Check RQ worker registration + try: + rq_workers = Worker.all(connection=self.redis) + rq_worker_count = len(rq_workers) + except Exception: + rq_worker_count = -1 # Error indicator + + return { + "running": self.running, + "uptime": time.time() - self.start_time if self.running else 0, + "total_workers": len(worker_status), + "state_counts": state_counts, + "rq_worker_count": rq_worker_count, + "min_rq_workers": self.config.min_rq_workers, + "rq_healthy": rq_worker_count >= self.config.min_rq_workers, + } diff --git a/backends/advanced/src/advanced_omi_backend/workers/orchestrator/process_manager.py b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/process_manager.py new file mode 100644 index 00000000..21b7f23e --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/process_manager.py @@ -0,0 +1,305 @@ +""" +Process Manager + +Manages lifecycle of all worker processes with state tracking. +Handles process creation, monitoring, and graceful shutdown. +""" + +import logging +import subprocess +import time +from enum import Enum +from typing import Dict, List, Optional + +from .config import WorkerDefinition + +logger = logging.getLogger(__name__) + + +class WorkerState(Enum): + """Worker process lifecycle states""" + + PENDING = "pending" # Not yet started + STARTING = "starting" # Process started, waiting for health check + RUNNING = "running" # Healthy and running + UNHEALTHY = "unhealthy" # Running but health check failed + STOPPING = "stopping" # Shutdown initiated + STOPPED = "stopped" # Cleanly stopped + FAILED = "failed" # Crashed or failed to start + + +class ManagedWorker: + """ + Wraps a single worker process with state tracking. 
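+    State transitions follow the WorkerState enum above
+    (PENDING → STARTING → RUNNING, with UNHEALTHY/FAILED on error paths).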
+ + Attributes: + definition: Worker definition + process: Subprocess.Popen object (None if not started) + state: Current worker state + start_time: Timestamp when worker was started + restart_count: Number of times worker has been restarted + last_health_check: Timestamp of last health check + """ + + def __init__(self, definition: WorkerDefinition): + self.definition = definition + self.process: Optional[subprocess.Popen] = None + self.state = WorkerState.PENDING + self.start_time: Optional[float] = None + self.restart_count = 0 + self.last_health_check: Optional[float] = None + + @property + def name(self) -> str: + """Worker name""" + return self.definition.name + + @property + def pid(self) -> Optional[int]: + """Process ID (None if not started)""" + return self.process.pid if self.process else None + + @property + def is_alive(self) -> bool: + """Check if process is alive""" + if not self.process: + return False + return self.process.poll() is None + + def start(self) -> bool: + """ + Start the worker process. + + Returns: + True if started successfully, False otherwise + """ + if self.process and self.is_alive: + logger.warning(f"{self.name}: Already running (PID {self.pid})") + return False + + try: + logger.info(f"{self.name}: Starting worker...") + logger.debug(f"{self.name}: Command: {' '.join(self.definition.command)}") + + # Don't capture stdout/stderr - let it flow to container logs (Docker captures it) + # This prevents buffer overflow and blocking when worker output exceeds 64KB + # Worker logs will be visible via 'docker logs' command + self.process = subprocess.Popen( + self.definition.command, + stdout=None, # Inherit from parent (goes to container stdout) + stderr=None, # Inherit from parent (goes to container stderr) + ) + + self.state = WorkerState.STARTING + self.start_time = time.time() + + logger.info(f"{self.name}: Started with PID {self.pid}") + return True + + except Exception as e: + logger.error(f"{self.name}: Failed to start: {e}") + self.state = WorkerState.FAILED + return False + + def stop(self, timeout: int = 30) -> bool: + """ + Gracefully stop the worker process. + + Args: + timeout: Maximum wait time in seconds + + Returns: + True if stopped successfully, False otherwise + """ + if not self.process or not self.is_alive: + logger.debug(f"{self.name}: Already stopped") + self.state = WorkerState.STOPPED + return True + + try: + logger.info(f"{self.name}: Stopping worker (PID {self.pid})...") + self.state = WorkerState.STOPPING + + # Send SIGTERM for graceful shutdown + self.process.terminate() + + # Wait for process to exit + try: + self.process.wait(timeout=timeout) + logger.info(f"{self.name}: Stopped gracefully") + self.state = WorkerState.STOPPED + return True + + except subprocess.TimeoutExpired: + # Force kill if timeout exceeded + logger.warning( + f"{self.name}: Timeout expired, force killing (SIGKILL)..." + ) + self.process.kill() + self.process.wait(timeout=5) + logger.warning(f"{self.name}: Force killed") + self.state = WorkerState.STOPPED + return True + + except Exception as e: + logger.error(f"{self.name}: Error during shutdown: {e}") + self.state = WorkerState.FAILED + return False + + def check_health(self) -> bool: + """ + Check worker health. 
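+
+        Combines a basic process-liveness check with the optional custom
+        health_check from the worker definition.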
+ + Returns: + True if healthy, False otherwise + """ + self.last_health_check = time.time() + + # Basic liveness check + if not self.is_alive: + logger.warning(f"{self.name}: Process is not alive") + self.state = WorkerState.FAILED + return False + + # Custom health check if defined + if self.definition.health_check: + try: + if not self.definition.health_check(): + logger.warning(f"{self.name}: Custom health check failed") + self.state = WorkerState.UNHEALTHY + return False + except Exception as e: + logger.error(f"{self.name}: Health check raised exception: {e}") + self.state = WorkerState.UNHEALTHY + return False + + # Update state if currently starting + if self.state == WorkerState.STARTING: + self.state = WorkerState.RUNNING + + return True + + +class ProcessManager: + """ + Manages all worker processes. + + Provides high-level API for starting, stopping, and monitoring workers. + """ + + def __init__(self, worker_definitions: List[WorkerDefinition]): + self.workers: Dict[str, ManagedWorker] = { + defn.name: ManagedWorker(defn) for defn in worker_definitions + } + logger.info(f"ProcessManager initialized with {len(self.workers)} workers") + + def start_all(self) -> bool: + """ + Start all workers. + + Returns: + True if all workers started successfully + """ + logger.info("Starting all workers...") + success = True + + for worker in self.workers.values(): + if not worker.start(): + success = False + + if success: + logger.info("All workers started successfully") + else: + logger.warning("Some workers failed to start") + + return success + + def stop_all(self, timeout: int = 30) -> bool: + """ + Stop all workers gracefully. + + Args: + timeout: Maximum wait time per worker in seconds + + Returns: + True if all workers stopped successfully + """ + logger.info("Stopping all workers...") + success = True + + for worker in self.workers.values(): + if not worker.stop(timeout=timeout): + success = False + + if success: + logger.info("All workers stopped successfully") + else: + logger.warning("Some workers failed to stop cleanly") + + return success + + def restart_worker(self, name: str, timeout: int = 30) -> bool: + """ + Restart a specific worker. + + Args: + name: Worker name + timeout: Maximum wait time for shutdown in seconds + + Returns: + True if restarted successfully + """ + worker = self.workers.get(name) + if not worker: + logger.error(f"Worker '{name}' not found") + return False + + logger.info(f"Restarting worker: {name}") + + # Ensure worker is fully stopped before attempting restart + stop_success = worker.stop(timeout=timeout) + if not stop_success: + logger.error(f"{name}: Failed to stop cleanly, restart aborted") + worker.state = WorkerState.FAILED + return False + + # Attempt to start the worker + success = worker.start() + + if success: + worker.restart_count += 1 + logger.info(f"{name}: Restart #{worker.restart_count} successful") + else: + logger.error(f"{name}: Restart failed") + + return success + + def get_status(self) -> Dict[str, Dict]: + """ + Get detailed status of all workers. 
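+
+        Each entry mirrors the ManagedWorker fields, e.g. (hypothetical
+        values): {"pid": 123, "state": "running", "is_alive": True,
+        "restart_count": 0, ...}.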
+ + Returns: + Dictionary mapping worker name to status info + """ + status = {} + + for name, worker in self.workers.items(): + status[name] = { + "pid": worker.pid, + "state": worker.state.value, + "is_alive": worker.is_alive, + "restart_count": worker.restart_count, + "start_time": worker.start_time, + "last_health_check": worker.last_health_check, + "queues": worker.definition.queues, + } + + return status + + def get_worker(self, name: str) -> Optional[ManagedWorker]: + """Get worker by name""" + return self.workers.get(name) + + def get_all_workers(self) -> List[ManagedWorker]: + """Get all workers""" + return list(self.workers.values()) diff --git a/backends/advanced/src/advanced_omi_backend/workers/orchestrator/worker_registry.py b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/worker_registry.py new file mode 100644 index 00000000..a5cf4b74 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/worker_registry.py @@ -0,0 +1,137 @@ +""" +Worker Registry + +Builds the complete list of worker definitions with conditional logic. +Reuses model_registry.py for config.yml parsing. +""" + +import os +import logging +from typing import List + +from .config import WorkerDefinition, WorkerType + +logger = logging.getLogger(__name__) + + +def has_streaming_stt_configured() -> bool: + """ + Check if streaming STT provider is configured in config.yml. + + Returns: + True if defaults.stt_stream is configured, False otherwise + + Note: Batch STT is handled by RQ workers in transcription_jobs.py, + no separate worker needed. + """ + try: + from advanced_omi_backend.model_registry import get_models_registry + + registry = get_models_registry() + if registry and registry.defaults: + stt_stream_model = registry.get_default("stt_stream") + return stt_stream_model is not None + except Exception as e: + logger.warning(f"Failed to read streaming STT config from config.yml: {e}") + + return False + + +def build_worker_definitions() -> List[WorkerDefinition]: + """ + Build the complete list of worker definitions. 
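+
+    The streaming STT worker is included conditionally via
+    has_streaming_stt_configured(); only workers whose is_enabled() check
+    passes are returned.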
+ + Returns: + List of WorkerDefinition objects, including conditional workers + """ + workers = [] + + # 6x RQ Workers - Multi-queue workers (transcription, memory, default) + for i in range(1, 7): + workers.append( + WorkerDefinition( + name=f"rq-worker-{i}", + command=[ + "uv", + "run", + "python", + "-m", + "advanced_omi_backend.workers.rq_worker_entry", + "transcription", + "memory", + "default", + ], + worker_type=WorkerType.RQ_WORKER, + queues=["transcription", "memory", "default"], + restart_on_failure=True, + ) + ) + + # Audio Persistence Workers - Single-queue workers (audio queue) + # Multiple workers allow concurrent audio persistence for multiple sessions + for i in range(1, 4): # 3 audio workers + workers.append( + WorkerDefinition( + name=f"audio-persistence-{i}", + command=[ + "uv", + "run", + "python", + "-m", + "advanced_omi_backend.workers.rq_worker_entry", + "audio", + ], + worker_type=WorkerType.RQ_WORKER, + queues=["audio"], + restart_on_failure=True, + ) + ) + + # Streaming STT Worker - Conditional (if streaming STT is configured in config.yml) + # This worker uses the registry-driven streaming provider (RegistryStreamingTranscriptionProvider) + # Batch transcription happens via RQ jobs in transcription_jobs.py (already uses registry provider) + workers.append( + WorkerDefinition( + name="streaming-stt", + command=[ + "uv", + "run", + "python", + "-m", + "advanced_omi_backend.workers.audio_stream_worker", + ], + worker_type=WorkerType.STREAM_CONSUMER, + enabled_check=has_streaming_stt_configured, + restart_on_failure=True, + ) + ) + + # Log worker configuration + try: + from advanced_omi_backend.model_registry import get_models_registry + registry = get_models_registry() + if registry: + stt_stream = registry.get_default("stt_stream") + stt_batch = registry.get_default("stt") + if stt_stream: + logger.info(f"Streaming STT configured: {stt_stream.name} ({stt_stream.model_provider})") + if stt_batch: + logger.info(f"Batch STT configured: {stt_batch.name} ({stt_batch.model_provider}) - handled by RQ workers") + except Exception as e: + logger.warning(f"Failed to log STT configuration: {e}") + + enabled_workers = [w for w in workers if w.is_enabled()] + disabled_workers = [w for w in workers if not w.is_enabled()] + + logger.info(f"Total workers configured: {len(workers)}") + logger.info(f"Enabled workers: {len(enabled_workers)}") + logger.info( + f"Enabled worker names: {', '.join([w.name for w in enabled_workers])}" + ) + + if disabled_workers: + logger.info( + f"Disabled workers: {', '.join([w.name for w in disabled_workers])}" + ) + + return enabled_workers diff --git a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py index c9216d4f..b37f6454 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py @@ -9,6 +9,7 @@ import logging import time from typing import Dict, Any +from rq.exceptions import NoSuchJobError from advanced_omi_backend.models.job import JobPriority, BaseRQJob, async_job @@ -19,6 +20,7 @@ REDIS_URL, ) from advanced_omi_backend.utils.conversation_utils import analyze_speech, mark_conversation_deleted +from advanced_omi_backend.services.plugin_service import get_plugin_router logger = logging.getLogger(__name__) @@ -167,6 +169,10 @@ async def transcribe_full_audio_job( if not conversation: raise ValueError(f"Conversation {conversation_id} not found") 
+    # Extract user_id and client_id for plugin context
+    user_id = str(conversation.user_id) if conversation.user_id else None
+    client_id = conversation.client_id if hasattr(conversation, 'client_id') else None
+
     # Use the provided audio path
     actual_audio_path = audio_path
     logger.info(f"📁 Using audio for transcription: {audio_path}")
@@ -202,6 +208,62 @@
         f"📊 Transcription complete: {len(transcript_text)} chars, {len(segments)} segments, {len(words)} words"
     )
 
+    # Trigger transcript-level plugins BEFORE speech validation
+    # This ensures wake-word commands execute even if conversation gets deleted
+    logger.debug(f"About to trigger plugins - transcript_text exists: {bool(transcript_text)}")
+    if transcript_text:
+        try:
+            from advanced_omi_backend.services.plugin_service import init_plugin_router
+
+            # Initialize plugin router if not already initialized (worker context)
+            plugin_router = get_plugin_router()
+            logger.debug(f"Plugin router from service: {plugin_router is not None}")
+
+            if not plugin_router:
+                logger.info("🔧 Initializing plugin router in worker process...")
+                plugin_router = init_plugin_router()
+                logger.debug(f"After init, plugin_router: {plugin_router is not None}, plugins count: {len(plugin_router.plugins) if plugin_router else 0}")
+
+            # Initialize async plugins
+            if plugin_router:
+                for plugin_id, plugin in plugin_router.plugins.items():
+                    try:
+                        await plugin.initialize()
+                        logger.info(f"✅ Plugin '{plugin_id}' initialized in worker")
+                    except Exception as e:
+                        logger.exception(f"Failed to initialize plugin '{plugin_id}' in worker: {e}")
+
+            logger.debug(f"Plugin router final check: {plugin_router is not None}, has {len(plugin_router.plugins) if plugin_router else 0} plugins")
+
+            if plugin_router:
+                logger.debug(f"Preparing to trigger transcript plugins for conversation {conversation_id}")
+                plugin_data = {
+                    'transcript': transcript_text,
+                    'segment_id': f"{conversation_id}_batch",
+                    'conversation_id': conversation_id,
+                    'segments': segments,
+                    'word_count': len(words),
+                }
+
+                logger.debug(f"Dispatching transcript.batch event with user_id={user_id}, client_id={client_id}")
+                plugin_results = await plugin_router.dispatch_event(
+                    event='transcript.batch',
+                    user_id=user_id,
+                    data=plugin_data,
+                    metadata={'client_id': client_id}
+                )
+                logger.debug(f"Event dispatch returned {len(plugin_results) if plugin_results else 0} results")
+
+                if plugin_results:
+                    logger.info(f"✅ Triggered {len(plugin_results)} transcript plugins in batch mode")
+                    for result in plugin_results:
+                        if result.message:
+                            logger.info(f"   Plugin: {result.message}")
+        except Exception as e:
+            logger.exception(f"⚠️ Error triggering transcript plugins in batch mode: {e}")
+
+    logger.debug("Plugin processing complete, moving to speech validation")
+
     # Validate meaningful speech BEFORE any further processing
     transcript_data = {"text": transcript_text, "words": words}
     speech_analysis = analyze_speech(transcript_data)
@@ -250,7 +312,10 @@
                 cancelled_jobs.append(job_id)
                 logger.info(f"✅ Cancelled dependent job: {job_id}")
             except Exception as e:
-                logger.debug(f"Job {job_id} not found or already completed: {e}")
+                if isinstance(e, NoSuchJobError):
+                    logger.debug(f"Job {job_id} hash not found (likely already completed or expired)")
+                else:
+                    logger.debug(f"Job {job_id} not found or already completed: {e}")
 
         if cancelled_jobs:
             logger.info(
@@ -286,7 +351,7 @@ async def
transcribe_full_audio_job( for seg in segments: # Use identified_as if available (from speaker recognition), otherwise use speaker label speaker_id = seg.get("identified_as") or seg.get("speaker", "Unknown") - # Convert speaker ID to string if it's an integer (Deepgram returns int speaker IDs) + # Convert speaker ID to string if it's an integer (some providers return int speaker IDs) speaker_name = f"Speaker {speaker_id}" if isinstance(speaker_id, int) else speaker_id speaker_segments.append( @@ -299,8 +364,8 @@ async def transcribe_full_audio_job( ) ) elif transcript_text: - # NOTE: Parakeet falls here. - # If no segments but we have text, create a single segment from the full transcript + # Fallback: If no segments but we have text, create a single segment from the full transcript + # This handles providers that don't support segmentation # Calculate duration from words if available, otherwise estimate from audio start_time_seg = 0.0 end_time_seg = 0.0 @@ -526,18 +591,30 @@ async def stream_speech_detection_job( ) current_job.save_meta() + # Track when session closes for graceful shutdown + session_closed_at = None + final_check_grace_period = 15 # Wait up to 15 seconds for final transcription after session closes + # Main loop: Listen for speech while True: # Check if job still exists in Redis (detect zombie state) from advanced_omi_backend.utils.job_utils import check_job_alive - if not await check_job_alive(redis_client, current_job): + if not await check_job_alive(redis_client, current_job, session_id): break - # Exit conditions + # Check if session has closed session_status = await redis_client.hget(session_key, "status") - if session_status and session_status.decode() in ["complete", "closed"]: - logger.info(f"🛑 Session ended, exiting") + session_closed = session_status and session_status.decode() in ["complete", "closed"] + + if session_closed and session_closed_at is None: + # Session just closed - start grace period for final transcription + session_closed_at = time.time() + logger.info(f"🛑 Session closed, waiting up to {final_check_grace_period}s for final transcription results...") + + # Exit if grace period expired without speech + if session_closed_at and (time.time() - session_closed_at) > final_check_grace_period: + logger.info(f"✅ Session ended without speech (grace period expired)") break if time.time() - start_time > max_runtime: @@ -547,11 +624,35 @@ async def stream_speech_detection_job( # Get transcription results combined = await aggregator.get_combined_results(session_id) if not combined["text"]: + # Health check: detect transcription errors early during grace period + if session_closed_at: + # Check for streaming consumer errors in session metadata + error_status = await redis_client.hget(session_key, "transcription_error") + if error_status: + error_msg = error_status.decode() + logger.warning(f"❌ Transcription error detected: {error_msg}") + logger.info(f"✅ Session ended without speech (transcription error)") + break + + # Check if we've been waiting too long with no results at all + grace_elapsed = time.time() - session_closed_at + if grace_elapsed > 5 and not combined.get("chunk_count", 0): + # 5+ seconds with no transcription activity at all - likely API key issue + logger.warning(f"⚠️ No transcription activity after {grace_elapsed:.1f}s - possible API key or connectivity issue") + logger.info(f"✅ Session ended without speech (no transcription activity)") + break + await asyncio.sleep(2) continue # Step 1: Check for meaningful speech transcript_data = 
{"text": combined["text"], "words": combined.get("words", [])} + + logger.info( + f"🔤 TRANSCRIPT [SPEECH_DETECT] session={session_id}, " + f"words={len(combined.get('words', []))}, text=\"{combined['text']}\"" + ) + speech_analysis = analyze_speech(transcript_data) logger.info( @@ -610,8 +711,6 @@ async def stream_speech_detection_job( try: speaker_check_job.refresh() except Exception as e: - from rq.exceptions import NoSuchJobError - if isinstance(e, NoSuchJobError): logger.warning( f"⚠️ Speaker check job disappeared from Redis (likely completed quickly), assuming not enrolled" diff --git a/backends/advanced/start-k8s.sh b/backends/advanced/start-k8s.sh index a2f3d817..847e3a6e 100755 --- a/backends/advanced/start-k8s.sh +++ b/backends/advanced/start-k8s.sh @@ -79,19 +79,20 @@ sleep 1 # Function to start all workers start_workers() { - # NEW WORKERS - Redis Streams multi-provider architecture - # Single worker ensures sequential processing of audio chunks (matching start-workers.sh) - echo "🎵 Starting audio stream Deepgram worker (1 worker for sequential processing)..." - if python3 -m advanced_omi_backend.workers.audio_stream_deepgram_worker & + # NEW WORKERS - Registry-driven streaming transcription architecture + # Single worker ensures sequential processing of audio chunks (matching worker_orchestrator.py) + # Uses config.yml for provider selection (Deepgram, Parakeet, etc.) + echo "🎵 Starting streaming transcription worker (registry-driven provider from config.yml)..." + if python3 -m advanced_omi_backend.workers.audio_stream_worker & then AUDIO_WORKER_1_PID=$! - echo " ✅ Deepgram stream worker started with PID: $AUDIO_WORKER_1_PID" + echo " ✅ Streaming transcription worker started with PID: $AUDIO_WORKER_1_PID" else - echo " ❌ Failed to start Deepgram stream worker" + echo " ❌ Failed to start streaming transcription worker" exit 1 fi - # Start 3 RQ workers listening to ALL queues (matching start-workers.sh) + # Start 3 RQ workers listening to ALL queues (matching worker_orchestrator.py) echo "🔧 Starting RQ workers (3 workers, all queues: transcription, memory, default)..." if python3 -m advanced_omi_backend.workers.rq_worker_entry transcription memory default & then @@ -123,7 +124,7 @@ start_workers() { exit 1 fi - # Start 1 dedicated audio persistence worker (matching start-workers.sh) + # Start 1 dedicated audio persistence worker (matching worker_orchestrator.py) echo "💾 Starting audio persistence worker (1 worker for audio queue)..." if python3 -m advanced_omi_backend.workers.rq_worker_entry audio & then diff --git a/backends/advanced/start-workers.sh b/backends/advanced/start-workers.sh deleted file mode 100755 index 3fea5a39..00000000 --- a/backends/advanced/start-workers.sh +++ /dev/null @@ -1,204 +0,0 @@ -#!/bin/bash -# Unified worker startup script -# Starts all workers in a single container for efficiency - -set -e - -echo "🚀 Starting Chronicle Workers..." - -# Clean up any stale worker registrations from previous runs -echo "🧹 Cleaning up stale worker registrations from Redis..." 
-# Use RQ's cleanup command to remove dead workers -uv run python -c " -from rq import Worker -from redis import Redis -import os -import socket - -redis_url = os.getenv('REDIS_URL', 'redis://localhost:6379/0') -redis_conn = Redis.from_url(redis_url) -hostname = socket.gethostname() - -# Only clean up workers from THIS hostname (pod) -workers = Worker.all(connection=redis_conn) -cleaned = 0 -for worker in workers: - if worker.hostname == hostname: - worker.register_death() - cleaned += 1 -print(f'Cleaned up {cleaned} stale workers from {hostname}') -" 2>/dev/null || echo "No stale workers to clean" - -sleep 1 - -# Function to start all workers -start_workers() { - echo "🔧 Starting RQ workers (6 workers, all queues: transcription, memory, default)..." - uv run python -m advanced_omi_backend.workers.rq_worker_entry transcription memory default & - RQ_WORKER_1_PID=$! - uv run python -m advanced_omi_backend.workers.rq_worker_entry transcription memory default & - RQ_WORKER_2_PID=$! - uv run python -m advanced_omi_backend.workers.rq_worker_entry transcription memory default & - RQ_WORKER_3_PID=$! - uv run python -m advanced_omi_backend.workers.rq_worker_entry transcription memory default & - RQ_WORKER_4_PID=$! - uv run python -m advanced_omi_backend.workers.rq_worker_entry transcription memory default & - RQ_WORKER_5_PID=$! - uv run python -m advanced_omi_backend.workers.rq_worker_entry transcription memory default & - RQ_WORKER_6_PID=$! - - echo "💾 Starting audio persistence worker (1 worker for audio queue)..." - uv run python -m advanced_omi_backend.workers.rq_worker_entry audio & - AUDIO_PERSISTENCE_WORKER_PID=$! - - # Determine which STT provider to use from config.yml - echo "📋 Checking config.yml for default STT provider..." - DEFAULT_STT=$(uv run python -c " -from advanced_omi_backend.model_registry import get_models_registry -registry = get_models_registry() -if registry and registry.defaults: - stt_model = registry.get_default('stt') - if stt_model: - print(stt_model.model_provider or '') -" 2>/dev/null || echo "") - - echo "📋 Configured STT provider: ${DEFAULT_STT:-none}" - - # Only start Deepgram worker if configured as default STT - if [[ "$DEFAULT_STT" == "deepgram" ]] && [ -n "$DEEPGRAM_API_KEY" ]; then - echo "🎵 Starting audio stream Deepgram worker (1 worker for sequential processing)..." - uv run python -m advanced_omi_backend.workers.audio_stream_deepgram_worker & - AUDIO_STREAM_DEEPGRAM_WORKER_PID=$! - else - echo "⏭️ Skipping Deepgram stream worker (not configured as default STT or API key missing)" - AUDIO_STREAM_DEEPGRAM_WORKER_PID="" - fi - - # Only start Parakeet worker if configured as default STT - if [[ "$DEFAULT_STT" == "parakeet" ]]; then - echo "🎵 Starting audio stream Parakeet worker (1 worker for sequential processing)..." - uv run python -m advanced_omi_backend.workers.audio_stream_parakeet_worker & - AUDIO_STREAM_PARAKEET_WORKER_PID=$! 
- else - echo "⏭️ Skipping Parakeet stream worker (not configured as default STT)" - AUDIO_STREAM_PARAKEET_WORKER_PID="" - fi - - echo "✅ All workers started:" - echo " - RQ worker 1: PID $RQ_WORKER_1_PID (transcription, memory, default)" - echo " - RQ worker 2: PID $RQ_WORKER_2_PID (transcription, memory, default)" - echo " - RQ worker 3: PID $RQ_WORKER_3_PID (transcription, memory, default)" - echo " - RQ worker 4: PID $RQ_WORKER_4_PID (transcription, memory, default)" - echo " - RQ worker 5: PID $RQ_WORKER_5_PID (transcription, memory, default)" - echo " - RQ worker 6: PID $RQ_WORKER_6_PID (transcription, memory, default)" - echo " - Audio persistence worker: PID $AUDIO_PERSISTENCE_WORKER_PID (audio queue - file rotation)" - [ -n "$AUDIO_STREAM_DEEPGRAM_WORKER_PID" ] && echo " - Audio stream Deepgram worker: PID $AUDIO_STREAM_DEEPGRAM_WORKER_PID (Redis Streams consumer)" || true - [ -n "$AUDIO_STREAM_PARAKEET_WORKER_PID" ] && echo " - Audio stream Parakeet worker: PID $AUDIO_STREAM_PARAKEET_WORKER_PID (Redis Streams consumer)" || true -} - -# Function to check worker registration health -check_worker_health() { - WORKER_COUNT=$(uv run python -c " -from rq import Worker -from redis import Redis -import os -import sys - -try: - redis_url = os.getenv('REDIS_URL', 'redis://localhost:6379/0') - r = Redis.from_url(redis_url) - workers = Worker.all(connection=r) - print(len(workers)) -except Exception as e: - print('0', file=sys.stderr) - sys.exit(1) -" 2>/dev/null || echo "0") - echo "$WORKER_COUNT" -} - -# Self-healing monitoring function -monitor_worker_health() { - local CHECK_INTERVAL=10 # Check every 10 seconds - local MIN_WORKERS=6 # Expect at least 6 RQ workers - - echo "🩺 Starting self-healing monitor (check interval: ${CHECK_INTERVAL}s, min workers: ${MIN_WORKERS})" - - while true; do - sleep $CHECK_INTERVAL - - WORKER_COUNT=$(check_worker_health) - - if [ "$WORKER_COUNT" -lt "$MIN_WORKERS" ]; then - echo "⚠️ Self-healing: Only $WORKER_COUNT workers registered (expected >= $MIN_WORKERS)" - echo "🔧 Self-healing: Restarting all workers to restore registration..." - - # Kill all workers - kill $RQ_WORKER_1_PID $RQ_WORKER_2_PID $RQ_WORKER_3_PID $RQ_WORKER_4_PID $RQ_WORKER_5_PID $RQ_WORKER_6_PID $AUDIO_PERSISTENCE_WORKER_PID 2>/dev/null || true - [ -n "$AUDIO_STREAM_DEEPGRAM_WORKER_PID" ] && kill $AUDIO_STREAM_DEEPGRAM_WORKER_PID 2>/dev/null || true - [ -n "$AUDIO_STREAM_PARAKEET_WORKER_PID" ] && kill $AUDIO_STREAM_PARAKEET_WORKER_PID 2>/dev/null || true - wait 2>/dev/null || true - - # Restart workers - start_workers - - # Verify recovery - sleep 3 - NEW_WORKER_COUNT=$(check_worker_health) - echo "✅ Self-healing: Workers restarted - new count: $NEW_WORKER_COUNT" - fi - done -} - -# Function to handle shutdown -shutdown() { - echo "🛑 Shutting down workers..." 
- kill $MONITOR_PID 2>/dev/null || true - kill $RQ_WORKER_1_PID 2>/dev/null || true - kill $RQ_WORKER_2_PID 2>/dev/null || true - kill $RQ_WORKER_3_PID 2>/dev/null || true - kill $RQ_WORKER_4_PID 2>/dev/null || true - kill $RQ_WORKER_5_PID 2>/dev/null || true - kill $RQ_WORKER_6_PID 2>/dev/null || true - kill $AUDIO_PERSISTENCE_WORKER_PID 2>/dev/null || true - [ -n "$AUDIO_STREAM_DEEPGRAM_WORKER_PID" ] && kill $AUDIO_STREAM_DEEPGRAM_WORKER_PID 2>/dev/null || true - [ -n "$AUDIO_STREAM_PARAKEET_WORKER_PID" ] && kill $AUDIO_STREAM_PARAKEET_WORKER_PID 2>/dev/null || true - wait - echo "✅ All workers stopped" - exit 0 -} - -# Set up signal handlers -trap shutdown SIGTERM SIGINT - -# Configure Python logging for RQ workers -export PYTHONUNBUFFERED=1 - -# Start all workers -start_workers - -# Start self-healing monitor in background -monitor_worker_health & -MONITOR_PID=$! -echo "🩺 Self-healing monitor started: PID $MONITOR_PID" - -# Keep the script running and let the self-healing monitor handle worker failures -# Don't use wait -n (fail-fast on first worker exit) - this kills all workers when one fails -# Instead, wait for the monitor process or explicit shutdown signal -echo "⏳ Workers running - self-healing monitor will restart failed workers automatically" -wait $MONITOR_PID - -# If monitor exits (should only happen on SIGTERM/SIGINT), shut down gracefully -echo "🛑 Monitor exited, shutting down all workers..." -kill $RQ_WORKER_1_PID 2>/dev/null || true -kill $RQ_WORKER_2_PID 2>/dev/null || true -kill $RQ_WORKER_3_PID 2>/dev/null || true -kill $RQ_WORKER_4_PID 2>/dev/null || true -kill $RQ_WORKER_5_PID 2>/dev/null || true -kill $RQ_WORKER_6_PID 2>/dev/null || true -kill $AUDIO_PERSISTENCE_WORKER_PID 2>/dev/null || true -[ -n "$AUDIO_STREAM_DEEPGRAM_WORKER_PID" ] && kill $AUDIO_STREAM_DEEPGRAM_WORKER_PID 2>/dev/null || true -[ -n "$AUDIO_STREAM_PARAKEET_WORKER_PID" ] && kill $AUDIO_STREAM_PARAKEET_WORKER_PID 2>/dev/null || true -wait - -echo "✅ All workers stopped gracefully" -exit 0 diff --git a/backends/advanced/start.sh b/backends/advanced/start.sh index 5cc79635..feb8d57a 100755 --- a/backends/advanced/start.sh +++ b/backends/advanced/start.sh @@ -2,9 +2,17 @@ # Chronicle Backend Startup Script # Starts both the FastAPI backend and RQ workers +# Usage: ./start.sh [--test] set -e +# Check for test mode flag +TEST_MODE=false +if [[ "$1" == "--test" ]]; then + TEST_MODE=true + echo "🧪 Running in TEST mode (with test dependencies)" +fi + echo "🚀 Starting Chronicle Backend..." # Function to handle shutdown @@ -53,7 +61,12 @@ sleep 2 # Start the main FastAPI application echo "🌐 Starting FastAPI backend..." -uv run --extra deepgram python3 src/advanced_omi_backend/main.py & +# Use --group test in test mode +if [ "$TEST_MODE" = true ]; then + uv run --extra deepgram --group test python3 src/advanced_omi_backend/main.py & +else + uv run --extra deepgram python3 src/advanced_omi_backend/main.py & +fi BACKEND_PID=$! 
 # Wait for any process to exit
diff --git a/backends/advanced/tests/test_conversation_models.py b/backends/advanced/tests/test_conversation_models.py
index e4387c89..c2c27dd0 100644
--- a/backends/advanced/tests/test_conversation_models.py
+++ b/backends/advanced/tests/test_conversation_models.py
@@ -134,7 +134,7 @@ def test_add_transcript_version(self):
             version_id="v2",
             transcript="Updated transcript",
             segments=segments,
-            provider=TranscriptProvider.MISTRAL,
+            provider=TranscriptProvider.PARAKEET,
             set_as_active=False
         )
 
@@ -170,7 +170,7 @@ def test_set_active_versions(self):
         segments2 = [SpeakerSegment(start=0.0, end=5.0, text="Version 2", speaker="Speaker A")]
 
         conversation.add_transcript_version("v1", "Transcript 1", segments1, TranscriptProvider.DEEPGRAM)
-        conversation.add_transcript_version("v2", "Transcript 2", segments2, TranscriptProvider.MISTRAL, set_as_active=False)
+        conversation.add_transcript_version("v2", "Transcript 2", segments2, TranscriptProvider.PARAKEET, set_as_active=False)
 
         # Should be v1 active
         assert conversation.active_transcript_version == "v1"
@@ -213,7 +213,6 @@ def test_provider_enums(self):
         """Test that provider enums work correctly."""
         # Test TranscriptProvider enum
         assert TranscriptProvider.DEEPGRAM == "deepgram"
-        assert TranscriptProvider.MISTRAL == "mistral"
         assert TranscriptProvider.PARAKEET == "parakeet"
 
         # Test MemoryProvider enum
diff --git a/backends/advanced/uv.lock b/backends/advanced/uv.lock
index c73386c8..afd88ad2 100644
--- a/backends/advanced/uv.lock
+++ b/backends/advanced/uv.lock
@@ -56,6 +56,7 @@ dev = [
     { name = "pre-commit-uv" },
 ]
 test = [
+    { name = "aiosqlite" },
     { name = "pytest" },
     { name = "pytest-asyncio" },
     { name = "pytest-cov" },
@@ -108,6 +109,7 @@ dev = [
     { name = "pre-commit-uv", specifier = ">=4.1.4" },
 ]
 test = [
+    { name = "aiosqlite", specifier = ">=0.20.0" },
     { name = "pytest", specifier = ">=8.4.1" },
     { name = "pytest-asyncio", specifier = ">=1.0.0" },
     { name = "pytest-cov", specifier = ">=6.0.0" },
@@ -226,6 +228,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
 ]
 
+[[package]]
+name = "aiosqlite"
+version = "0.22.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/4e/8a/64761f4005f17809769d23e518d915db74e6310474e733e3593cfc854ef1/aiosqlite-0.22.1.tar.gz", hash = "sha256:043e0bd78d32888c0a9ca90fc788b38796843360c855a7262a532813133a0650", size = 14821, upload-time = "2025-12-23T19:25:43.997Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/00/b7/e3bf5133d697a08128598c8d0abc5e16377b51465a33756de24fa7dee953/aiosqlite-0.22.1-py3-none-any.whl", hash = "sha256:21c002eb13823fad740196c5a2e9d8e62f6243bd9e7e4a1f87fb5e44ecb4fceb", size = 17405, upload-time = "2025-12-23T19:25:42.139Z" },
+]
+
 [[package]]
 name = "annotated-doc"
 version = "0.0.4"
diff --git a/backends/advanced/webui/src/App.tsx b/backends/advanced/webui/src/App.tsx
index fca59623..42370975 100644
--- a/backends/advanced/webui/src/App.tsx
+++ b/backends/advanced/webui/src/App.tsx
@@ -13,6 +13,7 @@ import System from './pages/System'
 import Upload from './pages/Upload'
 import Queue from './pages/Queue'
 import LiveRecord from './pages/LiveRecord'
+import Plugins from './pages/Plugins'
 import ProtectedRoute from './components/auth/ProtectedRoute'
 import { ErrorBoundary, PageErrorBoundary } from './components/ErrorBoundary'
@@ -89,6 +90,11 @@ function App() {
               </ProtectedRoute>
             } />
+            <Route path="/plugins" element={
+              <ProtectedRoute>
+                <Plugins />
+              </ProtectedRoute>
+            } />
diff --git a/backends/advanced/webui/src/components/PluginSettings.tsx b/backends/advanced/webui/src/components/PluginSettings.tsx
new file mode 100644
index 00000000..05576120
--- /dev/null
+++ b/backends/advanced/webui/src/components/PluginSettings.tsx
@@ -0,0 +1,195 @@
+import { useState, useEffect } from 'react'
+import { Puzzle, RefreshCw, CheckCircle, Save, RotateCcw, AlertCircle } from 'lucide-react'
+import { systemApi } from '../services/api'
+import { useAuth } from '../contexts/AuthContext'
+
+interface PluginSettingsProps {
+  className?: string
+}
+
+export default function PluginSettings({ className }: PluginSettingsProps) {
+  const [configYaml, setConfigYaml] = useState('')
+  const [loading, setLoading] = useState(false)
+  const [validating, setValidating] = useState(false)
+  const [saving, setSaving] = useState(false)
+  const [message, setMessage] = useState('')
+  const [error, setError] = useState('')
+  const { isAdmin } = useAuth()
+
+  useEffect(() => {
+    loadPluginsConfig()
+  }, [])
+
+  const loadPluginsConfig = async () => {
+    setLoading(true)
+    setError('')
+    setMessage('')
+
+    try {
+      const response = await systemApi.getPluginsConfigRaw()
+      setConfigYaml(response.data.config_yaml || response.data)
+      setMessage('Configuration loaded successfully')
+      setTimeout(() => setMessage(''), 3000)
+    } catch (err: any) {
+      const status = err.response?.status
+      if (status === 401) {
+        setError('Unauthorized: admin privileges required')
+      } else {
+        setError(err.response?.data?.error || 'Failed to load configuration')
+      }
+    } finally {
+      setLoading(false)
+    }
+  }
+
+  const validateConfig = async () => {
+    if (!configYaml.trim()) {
+      setError('Configuration cannot be empty')
+      return
+    }
+
+    setValidating(true)
+    setError('')
+    setMessage('')
+
+    try {
+      const response = await systemApi.validatePluginsConfig(configYaml)
+      if (response.data.valid) {
+        setMessage('✅ Configuration is valid')
+      } else {
+        setError(response.data.error || 'Validation failed')
+      }
+      setTimeout(() => setMessage(''), 3000)
+    } catch (err: any) {
+      setError(err.response?.data?.error || 'Validation failed')
+    } finally {
+      setValidating(false)
+    }
+  }
+
+  const saveConfig = async () => {
+    if (!configYaml.trim()) {
+      setError('Configuration cannot be empty')
+      return
+    }
+
+    setSaving(true)
+    setError('')
+    setMessage('')
+
+    try {
+      await systemApi.updatePluginsConfigRaw(configYaml)
+      setMessage('✅ Configuration saved successfully. Restart backend for changes to take effect.')
+      setTimeout(() => setMessage(''), 5000)
+    } catch (err: any) {
+      setError(err.response?.data?.error || 'Failed to save configuration')
+    } finally {
+      setSaving(false)
+    }
+  }
+
+  const resetConfig = () => {
+    loadPluginsConfig()
+    setMessage('Configuration reset to file version')
+    setTimeout(() => setMessage(''), 3000)
+  }
+
+  if (!isAdmin) {
+    return null
+  }
+
+  return (
+    <div className={className}>
+      {/* Header */}
+      <div className="flex items-center justify-between mb-4">
+        <div className="flex items-center gap-2">
+          <Puzzle className="w-5 h-5" />
+          <h3 className="text-lg font-semibold">Plugin Configuration</h3>
+        </div>
+        <div className="flex items-center gap-2">
+          <button onClick={loadPluginsConfig} disabled={loading}>
+            <RefreshCw className="w-4 h-4" /> Reload
+          </button>
+          <button onClick={validateConfig} disabled={validating}>
+            <CheckCircle className="w-4 h-4" /> Validate
+          </button>
+          <button onClick={saveConfig} disabled={saving}>
+            <Save className="w-4 h-4" /> Save
+          </button>
+          <button onClick={resetConfig}>
+            <RotateCcw className="w-4 h-4" /> Reset
+          </button>
+        </div>
+      </div>
+
+      {/* Messages */}
+      {message && (
+        <div className="flex items-center gap-2 text-green-600">
+          <CheckCircle className="w-4 h-4" />
+          <span>{message}</span>
+        </div>
+      )}
+
+      {error && (
+        <div className="flex items-center gap-2 text-red-600">
+          <AlertCircle className="w-4 h-4" />
+          <span>{error}</span>
+        </div>
+      )}
+
+      {/* Editor */}
+      <div className="mt-4">
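+        {/* Raw YAML editor for config/plugins.yml; saved changes take effect after a backend restart (see saveConfig above) */}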