-
Notifications
You must be signed in to change notification settings - Fork 22
feat: Add annotation feature and cron jobs for transcript corrections #256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
88b826d
e6eaa4c
3413eec
42e0036
d2c50db
a2cc30b
a6eb773
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
import asyncio
import logging
import os
import signal
import sys
from datetime import datetime, timezone

from advanced_omi_backend.database import init_db
from advanced_omi_backend.workers.annotation_jobs import (
    finetune_hallucination_model,
    surface_error_suggestions,
)
| # Configure logging | ||
| logging.basicConfig( | ||
| level=logging.INFO, | ||
| format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", | ||
| stream=sys.stdout | ||
| ) | ||
| logger = logging.getLogger("cron_scheduler") | ||
|
|
||
| # Frequency configuration (in seconds) | ||
| SUGGESTION_INTERVAL = 24 * 60 * 60 # Daily | ||
| TRAINING_INTERVAL = 7 * 24 * 60 * 60 # Weekly | ||
|
|
||
| # For testing purposes, we can check more frequently if ENV var is set | ||
| if os.getenv("DEV_MODE", "false").lower() == "true": | ||
| SUGGESTION_INTERVAL = 60 # 1 minute | ||
| TRAINING_INTERVAL = 300 # 5 minutes | ||
|
|
||
| async def run_scheduler(): | ||
| logger.info("Starting Cron Scheduler...") | ||
|
|
||
| # Initialize DB connection | ||
| await init_db() | ||
|
|
||
| last_suggestion_run = datetime.min | ||
| last_training_run = datetime.min | ||
|
|
||
| while True: | ||
| now = datetime.utcnow() | ||
|
|
||
| # Check Suggestions Job | ||
| if (now - last_suggestion_run).total_seconds() >= SUGGESTION_INTERVAL: | ||
| logger.info("Running scheduled job: surface_error_suggestions") | ||
| try: | ||
| await surface_error_suggestions() | ||
| last_suggestion_run = now | ||
| except Exception as e: | ||
| logger.error(f"Error in surface_error_suggestions: {e}", exc_info=True) | ||
|
|
||
| # Check Training Job | ||
| if (now - last_training_run).total_seconds() >= TRAINING_INTERVAL: | ||
| logger.info("Running scheduled job: finetune_hallucination_model") | ||
| try: | ||
| await finetune_hallucination_model() | ||
| last_training_run = now | ||
| except Exception as e: | ||
| logger.error(f"Error in finetune_hallucination_model: {e}", exc_info=True) | ||
|
|
||
|
Comment on lines
+40
to
+57
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Jobs lack the user_id scoping required by the coding guidelines. Consider one of these approaches:
Would you like me to generate a solution that iterates over users and invokes these jobs per-user? 🤖 Prompt for AI Agents |
||
| # Sleep for a bit before next check (e.g. 1 minute) | ||
| await asyncio.sleep(60) | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
def handle_shutdown(signum, frame):
    """Signal handler for SIGTERM/SIGINT: log the shutdown and exit with status 0."""
    logger.info("Shutting down Cron Scheduler...")
    # Equivalent to sys.exit(0): raises SystemExit so cleanup handlers run.
    raise SystemExit(0)
|
|
||
if __name__ == "__main__":
    # Install shutdown handlers: SIGTERM from container orchestrators,
    # SIGINT from interactive Ctrl-C.
    for _sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(_sig, handle_shutdown)

    try:
        asyncio.run(run_scheduler())
    except KeyboardInterrupt:
        pass
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| from datetime import datetime, timezone | ||
| from typing import Optional, List | ||
| from pydantic import Field | ||
| from beanie import Document, Indexed | ||
| from enum import Enum | ||
| import uuid | ||
|
|
||
class TranscriptAnnotation(Document):
    """Model for transcript annotations/corrections.

    One document per correction to a single transcript segment of a
    conversation, either entered by a user or suggested by a model.
    Persisted by Beanie in the ``transcript_annotations`` collection.
    """

    class AnnotationStatus(str, Enum):
        # Lifecycle of a correction.
        PENDING = "pending"     # suggested, awaiting user review
        ACCEPTED = "accepted"   # applied correction
        REJECTED = "rejected"   # dismissed by the user

    class AnnotationSource(str, Enum):
        # Origin of the correction.
        USER = "user"
        MODEL_SUGGESTION = "model_suggestion"

    # NOTE(review): overrides Beanie's default PydanticObjectId primary key
    # with a UUID string — confirm this matches the ID convention used by
    # the project's other Document models (e.g. User).
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))

    conversation_id: Indexed(str)  # conversation the corrected segment belongs to
    segment_index: int             # index into the transcript's segment list
    original_text: str             # segment text before the correction
    corrected_text: str            # segment text after the correction
    user_id: Indexed(str)          # owning user; queries scope by this field

    status: AnnotationStatus = Field(default=AnnotationStatus.ACCEPTED)  # User edits are accepted by default
    source: AnnotationSource = Field(default=AnnotationSource.USER)

    # Timezone-aware creation/update timestamps (UTC).
    # NOTE(review): updated_at is only set at creation; nothing here bumps
    # it on later edits — confirm callers update it explicitly.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    class Settings:
        # Collection name and secondary indexes for the common query paths.
        name = "transcript_annotations"
        indexes = [
            "conversation_id",
            "user_id",
            "status"
        ]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| from fastapi import APIRouter, HTTPException, Depends | ||
| from typing import List, Optional | ||
| from pydantic import BaseModel | ||
| from datetime import datetime | ||
|
|
||
| from advanced_omi_backend.models.annotation import TranscriptAnnotation | ||
| from advanced_omi_backend.models.conversation import Conversation | ||
| from advanced_omi_backend.auth import current_active_user | ||
| from advanced_omi_backend.models.user import User | ||
| from advanced_omi_backend.workers.memory_jobs import enqueue_memory_processing | ||
| from advanced_omi_backend.models.job import JobPriority | ||
|
|
||
| router = APIRouter() | ||
|
|
||
class AnnotationCreate(BaseModel):
    """Request body for creating a transcript correction."""

    conversation_id: str  # conversation owning the segment being corrected
    segment_index: int    # index of the target segment in the active transcript
    original_text: str    # segment text before the correction
    corrected_text: str   # replacement text to apply
    # User-submitted edits are applied immediately, hence the ACCEPTED default.
    status: Optional[TranscriptAnnotation.AnnotationStatus] = TranscriptAnnotation.AnnotationStatus.ACCEPTED
|
|
||
class AnnotationResponse(BaseModel):
    """API representation of a stored transcript annotation."""

    id: str               # annotation document id, stringified
    conversation_id: str
    segment_index: int
    original_text: str
    corrected_text: str
    status: str           # AnnotationStatus value as a plain string
    created_at: datetime
|
|
||
@router.post("/", response_model=AnnotationResponse)
async def create_annotation(
    annotation: AnnotationCreate,
    current_user: User = Depends(current_active_user)
):
    """Create a transcript correction, apply it to the active transcript,
    and enqueue memory reprocessing for the conversation.

    Raises:
        HTTPException 404: conversation not found or not owned by the caller.
        HTTPException 400: no active transcript, or segment index out of range.
    """
    # Verify conversation exists and belongs to user
    conversation = await Conversation.find_one({
        "conversation_id": annotation.conversation_id,
        "user_id": str(current_user.id)
    })

    if not conversation:
        raise HTTPException(status_code=404, detail="Conversation not found")

    # Validate the target segment BEFORE persisting anything, so a rejected
    # request cannot leave an orphaned annotation in the database.
    version = conversation.active_transcript
    if not version:
        raise HTTPException(status_code=400, detail="No active transcript found")
    if not (0 <= annotation.segment_index < len(version.segments)):
        raise HTTPException(status_code=400, detail="Segment index out of range")

    # Create annotation
    new_annotation = TranscriptAnnotation(
        conversation_id=annotation.conversation_id,
        segment_index=annotation.segment_index,
        original_text=annotation.original_text,
        corrected_text=annotation.corrected_text,
        user_id=str(current_user.id),
        status=annotation.status,
        source=TranscriptAnnotation.AnnotationSource.USER
    )

    await new_annotation.insert()

    # Apply the correction to the active transcript version in place.
    version.segments[annotation.segment_index].text = annotation.corrected_text

    # Persist the edit by replacing the matching version in the
    # conversation's version list before saving.
    for i, v in enumerate(conversation.transcript_versions):
        if v.version_id == version.version_id:
            conversation.transcript_versions[i] = version
            break

    await conversation.save()

    # Trigger memory reprocessing so derived memories reflect the correction.
    enqueue_memory_processing(
        client_id=conversation.client_id,
        user_id=str(current_user.id),
        user_email=current_user.email,
        conversation_id=conversation.conversation_id,
        priority=JobPriority.NORMAL
    )

    return AnnotationResponse(
        id=str(new_annotation.id),
        conversation_id=new_annotation.conversation_id,
        segment_index=new_annotation.segment_index,
        original_text=new_annotation.original_text,
        corrected_text=new_annotation.corrected_text,
        status=new_annotation.status,
        created_at=new_annotation.created_at
    )
|
|
||
@router.get("/{conversation_id}", response_model=List[AnnotationResponse])
async def get_annotations(
    conversation_id: str,
    current_user: User = Depends(current_active_user)
):
    """Return every annotation for the given conversation owned by the caller."""
    # Scope the query to the authenticated user so callers can only see
    # their own corrections.
    records = await TranscriptAnnotation.find({
        "conversation_id": conversation_id,
        "user_id": str(current_user.id)
    }).to_list()

    responses: List[AnnotationResponse] = []
    for record in records:
        responses.append(
            AnnotationResponse(
                id=str(record.id),
                conversation_id=record.conversation_id,
                segment_index=record.segment_index,
                original_text=record.original_text,
                corrected_text=record.corrected_text,
                status=record.status,
                created_at=record.created_at
            )
        )
    return responses
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replace deprecated
datetime.utcnow()withdatetime.now(timezone.utc).datetime.utcnow()is deprecated in Python 3.12+ in favor of timezone-aware alternatives.🔧 Use timezone-aware datetime
📝 Committable suggestion
🤖 Prompt for AI Agents