diff --git a/pyproject.toml b/pyproject.toml index a567ea2b4..bc1617e62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "aiohttp>=3.12.14", "authlib>=1.6.0", "openai==1.99.1", + "sqlalchemy>=2.0.42", ] [tool.pyright] diff --git a/src/app/database.py b/src/app/database.py new file mode 100644 index 000000000..2f1f84085 --- /dev/null +++ b/src/app/database.py @@ -0,0 +1,124 @@ +"""Database engine management.""" + +from pathlib import Path +from typing import Any + +from sqlalchemy import create_engine, text +from sqlalchemy.engine.base import Engine +from sqlalchemy.orm import sessionmaker, Session +from log import get_logger, logging +from configuration import configuration +from models.database.base import Base +from models.config import SQLiteDatabaseConfiguration, PostgreSQLDatabaseConfiguration + +logger = get_logger(__name__) + +engine: Engine | None = None +SessionLocal: sessionmaker | None = None + + +def get_engine() -> Engine: + """Get the database engine. Raises an error if not initialized.""" + if engine is None: + raise RuntimeError( + "Database engine not initialized. Call initialize_database() first." + ) + return engine + + +def create_tables() -> None: + """Create tables.""" + Base.metadata.create_all(get_engine()) + + +def get_session() -> Session: + """Get a database session. Raises an error if not initialized.""" + if SessionLocal is None: + raise RuntimeError( + "Database session not initialized. Call initialize_database() first." + ) + return SessionLocal() + + +def _create_sqlite_engine(config: SQLiteDatabaseConfiguration, **kwargs: Any) -> Engine: + """Create SQLite database engine.""" + if not Path(config.db_path).parent.exists(): + raise FileNotFoundError( + f"SQLite database directory does not exist: {config.db_path}" + ) + + try: + return create_engine(f"sqlite:///{config.db_path}", **kwargs) + except Exception as e: + logger.exception("Failed to create SQLite engine") + raise RuntimeError(f"SQLite engine creation failed: {e}") from e + + +def _create_postgres_engine( + config: PostgreSQLDatabaseConfiguration, **kwargs: Any +) -> Engine: + """Create PostgreSQL database engine.""" + postgres_url = ( + f"postgresql://{config.user}:{config.password}@" + f"{config.host}:{config.port}/{config.db}" + f"?sslmode={config.ssl_mode}&gssencmode={config.gss_encmode}" + ) + + is_custom_schema = config.namespace is not None and config.namespace != "public" + + connect_args = {} + if is_custom_schema: + connect_args["options"] = f"-csearch_path={config.namespace}" + + if config.ca_cert_path is not None: + connect_args["sslrootcert"] = str(config.ca_cert_path) + + try: + postgres_engine = create_engine( + postgres_url, connect_args=connect_args, **kwargs + ) + except Exception as e: + logger.exception("Failed to create PostgreSQL engine") + raise RuntimeError(f"PostgreSQL engine creation failed: {e}") from e + + if is_custom_schema: + try: + with postgres_engine.connect() as connection: + connection.execute( + text(f'CREATE SCHEMA IF NOT EXISTS "{config.namespace}"') + ) + connection.commit() + logger.info("Schema '%s' created or already exists", config.namespace) + except Exception as e: + logger.exception("Failed to create schema '%s'", config.namespace) + raise RuntimeError( + f"Schema creation failed for '{config.namespace}': {e}" + ) from e + + return postgres_engine + + +def initialize_database() -> None: + """Initialize the database engine.""" + db_config = configuration.database_configuration + + global engine, SessionLocal # pylint: disable=global-statement + + # Debug print all SQL statements if our logger is at-least DEBUG level + echo = bool(logger.isEnabledFor(logging.DEBUG)) + + create_engine_kwargs = { + "echo": echo, + } + + match db_config.db_type: + case "sqlite": + sqlite_config = db_config.config + assert isinstance(sqlite_config, SQLiteDatabaseConfiguration) + engine = _create_sqlite_engine(sqlite_config, **create_engine_kwargs) + case "postgres": + postgres_config = db_config.config + assert isinstance(postgres_config, PostgreSQLDatabaseConfiguration) + engine = _create_postgres_engine(postgres_config, **create_engine_kwargs) + + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py index f7133c959..87a0ce718 100644 --- a/src/app/endpoints/conversations.py +++ b/src/app/endpoints/conversations.py @@ -9,9 +9,16 @@ from client import AsyncLlamaStackClientHolder from configuration import configuration -from models.responses import ConversationResponse, ConversationDeleteResponse +from models.responses import ( + ConversationResponse, + ConversationDeleteResponse, + ConversationsListResponse, + ConversationDetails, +) +from models.database.conversations import UserConversation from auth import get_auth_dependency -from utils.endpoints import check_configuration_loaded +from app.database import get_session +from utils.endpoints import check_configuration_loaded, validate_conversation_ownership from utils.suid import check_suid logger = logging.getLogger("app.endpoints.handlers") @@ -66,6 +73,35 @@ }, } +conversations_list_responses: dict[int | str, dict[str, Any]] = { + 200: { + "conversations": [ + { + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "created_at": "2024-01-01T00:00:00Z", + "last_message_at": "2024-01-01T00:05:00Z", + "last_used_model": "gemini/gemini-1.5-flash", + "last_used_provider": "gemini", + "message_count": 5, + }, + { + "conversation_id": "456e7890-e12b-34d5-a678-901234567890", + "created_at": "2024-01-01T01:00:00Z", + "last_message_at": "2024-01-01T01:02:00Z", + "last_used_model": "gemini/gemini-2.0-flash", + "last_used_provider": "gemini", + "message_count": 2, + }, + ] + }, + 503: { + "detail": { + "response": "Unable to connect to Llama Stack", + "cause": "Connection error.", + } + }, +} + def simplify_session_data(session_data: dict) -> list[dict[str, Any]]: """Simplify session data to include only essential conversation information. @@ -109,10 +145,64 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]: return chat_history +@router.get("/conversations", responses=conversations_list_responses) +def get_conversations_list_endpoint_handler( + auth: Any = Depends(auth_dependency), +) -> ConversationsListResponse: + """Handle request to retrieve all conversations for the authenticated user.""" + check_configuration_loaded(configuration) + + user_id, _, _ = auth + + logger.info("Retrieving conversations for user %s", user_id) + + with get_session() as session: + try: + # Get all conversations for this user + user_conversations = ( + session.query(UserConversation).filter_by(user_id=user_id).all() + ) + + # Return conversation summaries with metadata + conversations = [ + ConversationDetails( + conversation_id=conv.id, + created_at=conv.created_at.isoformat() if conv.created_at else None, + last_message_at=( + conv.last_message_at.isoformat() + if conv.last_message_at + else None + ), + message_count=conv.message_count, + last_used_model=conv.last_used_model, + last_used_provider=conv.last_used_provider, + ) + for conv in user_conversations + ] + + logger.info( + "Found %d conversations for user %s", len(conversations), user_id + ) + + return ConversationsListResponse(conversations=conversations) + + except Exception as e: + logger.exception( + "Error retrieving conversations for user %s: %s", user_id, e + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unknown error", + "cause": f"Unknown error while getting conversations for user {user_id}", + }, + ) from e + + @router.get("/conversations/{conversation_id}", responses=conversation_responses) async def get_conversation_endpoint_handler( conversation_id: str, - _auth: Any = Depends(auth_dependency), + auth: Any = Depends(auth_dependency), ) -> ConversationResponse: """Handle request to retrieve a conversation by ID.""" check_configuration_loaded(configuration) @@ -128,6 +218,13 @@ async def get_conversation_endpoint_handler( }, ) + user_id, _, _ = auth + + validate_conversation_ownership( + user_id=user_id, + conversation_id=conversation_id, + ) + agent_id = conversation_id logger.info("Retrieving conversation %s", conversation_id) @@ -187,7 +284,7 @@ async def get_conversation_endpoint_handler( ) async def delete_conversation_endpoint_handler( conversation_id: str, - _auth: Any = Depends(auth_dependency), + auth: Any = Depends(auth_dependency), ) -> ConversationDeleteResponse: """Handle request to delete a conversation by ID.""" check_configuration_loaded(configuration) @@ -202,6 +299,14 @@ async def delete_conversation_endpoint_handler( "cause": f"Conversation ID {conversation_id} is not a valid UUID", }, ) + + user_id, _, _ = auth + + validate_conversation_ownership( + user_id=user_id, + conversation_id=conversation_id, + ) + agent_id = conversation_id logger.info("Deleting conversation %s", conversation_id) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 86b1ef443..8a8ee5664 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -22,11 +22,18 @@ from auth.interface import AuthTuple from client import AsyncLlamaStackClientHolder from configuration import configuration +from app.database import get_session import metrics +from models.database.conversations import UserConversation from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse from models.requests import QueryRequest, Attachment import constants -from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt +from utils.endpoints import ( + check_configuration_loaded, + get_agent, + get_system_prompt, + validate_conversation_ownership, +) from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups from utils.suid import get_suid @@ -65,6 +72,80 @@ def is_transcripts_enabled() -> bool: return configuration.user_data_collection_configuration.transcripts_enabled +def persist_user_conversation_details( + user_id: str, conversation_id: str, model: str, provider_id: str +) -> None: + """Associate conversation to user in the database.""" + with get_session() as session: + existing_conversation = ( + session.query(UserConversation) + .filter_by(id=conversation_id, user_id=user_id) + .first() + ) + + if not existing_conversation: + conversation = UserConversation( + id=conversation_id, + user_id=user_id, + last_used_model=model, + last_used_provider=provider_id, + message_count=1, + ) + session.add(conversation) + logger.debug( + "Associated conversation %s to user %s", conversation_id, user_id + ) + else: + existing_conversation.last_used_model = model + existing_conversation.last_used_provider = provider_id + existing_conversation.last_message_at = datetime.now(UTC) + existing_conversation.message_count += 1 + + session.commit() + + +def evaluate_model_hints( + user_conversation: UserConversation | None, + query_request: QueryRequest, +) -> tuple[str | None, str | None]: + """Evaluate model hints from user conversation.""" + model_id: str | None = query_request.model + provider_id: str | None = query_request.provider + + if user_conversation is not None: + if query_request.model is not None: + if query_request.model != user_conversation.last_used_model: + logger.debug( + "Model specified in request: %s, preferring it over user conversation model %s", + query_request.model, + user_conversation.last_used_model, + ) + else: + logger.debug( + "No model specified in request, using latest model from user conversation: %s", + user_conversation.last_used_model, + ) + model_id = user_conversation.last_used_model + + if query_request.provider is not None: + if query_request.provider != user_conversation.last_used_provider: + logger.debug( + "Provider specified in request: %s, " + "preferring it over user conversation provider %s", + query_request.provider, + user_conversation.last_used_provider, + ) + else: + logger.debug( + "No provider specified in request, " + "using latest provider from user conversation: %s", + user_conversation.last_used_provider, + ) + provider_id = user_conversation.last_used_provider + + return model_id, provider_id + + @router.post("/query", responses=query_response) async def query_endpoint_handler( query_request: QueryRequest, @@ -79,11 +160,34 @@ async def query_endpoint_handler( user_id, _, token = auth + user_conversation: UserConversation | None = None + if query_request.conversation_id: + user_conversation = validate_conversation_ownership( + user_id=user_id, conversation_id=query_request.conversation_id + ) + + if user_conversation is None: + logger.warning( + "User %s attempted to query conversation %s they don't own", + user_id, + query_request.conversation_id, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "response": "Access denied", + "cause": "You do not have permission to access this conversation", + }, + ) + try: # try to get Llama Stack client client = AsyncLlamaStackClientHolder().get_client() model_id, provider_id = select_model_and_provider_id( - await client.models.list(), query_request + await client.models.list(), + *evaluate_model_hints( + user_conversation=user_conversation, query_request=query_request + ), ) response, conversation_id = await retrieve_response( client, @@ -110,6 +214,13 @@ async def query_endpoint_handler( attachments=query_request.attachments or [], ) + persist_user_conversation_details( + user_id=user_id, + conversation_id=conversation_id, + model=model_id, + provider_id=provider_id, + ) + return QueryResponse(conversation_id=conversation_id, response=response) # connection to Llama Stack server @@ -127,12 +238,10 @@ async def query_endpoint_handler( def select_model_and_provider_id( - models: ModelListResponse, query_request: QueryRequest -) -> tuple[str, str | None]: + models: ModelListResponse, model_id: str | None, provider_id: str | None +) -> tuple[str, str]: """Select the model ID and provider ID based on the request or available models.""" # If model_id and provider_id are provided in the request, use them - model_id = query_request.model - provider_id = query_request.provider # If model_id is not provided in the request, check the configuration if not model_id or not provider_id: diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 009c2f017..48e29e1fa 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -23,6 +23,7 @@ from configuration import configuration import metrics from models.requests import QueryRequest +from models.database.conversations import UserConversation from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups @@ -34,6 +35,9 @@ store_transcript, select_model_and_provider_id, validate_attachments_metadata, + validate_conversation_ownership, + persist_user_conversation_details, + evaluate_model_hints, ) logger = logging.getLogger("app.endpoints.handlers") @@ -380,7 +384,7 @@ def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]: @router.post("/streaming_query") -async def streaming_query_endpoint_handler( +async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals _request: Request, query_request: QueryRequest, auth: Annotated[AuthTuple, Depends(auth_dependency)], @@ -394,11 +398,34 @@ async def streaming_query_endpoint_handler( user_id, _user_name, token = auth + user_conversation: UserConversation | None = None + if query_request.conversation_id is not None: + user_conversation = validate_conversation_ownership( + user_id=user_id, conversation_id=query_request.conversation_id + ) + + if user_conversation is None: + logger.warning( + "User %s attempted to query conversation %s they don't own", + user_id, + query_request.conversation_id, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "response": "Access denied", + "cause": "You do not have permission to access this conversation", + }, + ) + try: # try to get Llama Stack client client = AsyncLlamaStackClientHolder().get_client() model_id, provider_id = select_model_and_provider_id( - await client.models.list(), query_request + await client.models.list(), + *evaluate_model_hints( + user_conversation=user_conversation, query_request=query_request + ), ) response, conversation_id = await retrieve_response( client, @@ -447,6 +474,13 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: attachments=query_request.attachments or [], ) + persist_user_conversation_details( + user_id=user_id, + conversation_id=conversation_id, + model=model_id, + provider_id=provider_id, + ) + # Update metrics for the LLM call metrics.llm_calls_total.labels(provider_id, model_id).inc() diff --git a/src/app/main.py b/src/app/main.py index fcbb7f5c2..0517ce6a3 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -7,6 +7,7 @@ from starlette.routing import Mount, Route, WebSocketRoute from app import routers +from app.database import initialize_database, create_tables from configuration import configuration from log import get_logger import metrics @@ -81,3 +82,6 @@ async def startup_event() -> None: await register_mcp_servers_async(logger, configuration.configuration) get_logger("app.endpoints.handlers") logger.info("App startup complete") + + initialize_database() + create_tables() diff --git a/src/configuration.py b/src/configuration.py index a5d7384cc..1b2e87c17 100644 --- a/src/configuration.py +++ b/src/configuration.py @@ -13,6 +13,7 @@ ModelContextProtocolServer, AuthenticationConfiguration, InferenceConfiguration, + DatabaseConfiguration, ) logger = logging.getLogger(__name__) @@ -113,5 +114,13 @@ def inference(self) -> InferenceConfiguration: ), "logic error: configuration is not loaded" return self._configuration.inference + @property + def database_configuration(self) -> DatabaseConfiguration: + """Return database configuration.""" + assert ( + self._configuration is not None + ), "logic error: configuration is not loaded" + return self._configuration.database + configuration: AppConfig = AppConfig() diff --git a/src/constants.py b/src/constants.py index dba3bc018..b8380f33c 100644 --- a/src/constants.py +++ b/src/constants.py @@ -51,3 +51,9 @@ DATA_COLLECTOR_COLLECTION_INTERVAL = 7200 # 2 hours in seconds DATA_COLLECTOR_CONNECTION_TIMEOUT = 30 DATA_COLLECTOR_RETRY_INTERVAL = 300 # 5 minutes in seconds + +# PostgreSQL connection constants +# See: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLMODE +POSTGRES_DEFAULT_SSL_MODE = "prefer" +# See: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-GSSENCMODE +POSTGRES_DEFAULT_GSS_ENCMODE = "prefer" diff --git a/src/models/config.py b/src/models/config.py index 961c37976..d0b374123 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -4,7 +4,7 @@ from typing import Optional from pydantic import BaseModel, model_validator, FilePath, AnyHttpUrl, PositiveInt -from typing_extensions import Self +from typing_extensions import Self, Literal import constants @@ -24,6 +24,78 @@ def check_tls_configuration(self) -> Self: return self +class SQLiteDatabaseConfiguration(BaseModel): + """SQLite database configuration.""" + + db_path: str + + +class PostgreSQLDatabaseConfiguration(BaseModel): + """PostgreSQL database configuration.""" + + host: str = "localhost" + port: int = 5432 + db: str + user: str + password: str + namespace: Optional[str] = "lightspeed-stack" + ssl_mode: str = constants.POSTGRES_DEFAULT_SSL_MODE + gss_encmode: str = constants.POSTGRES_DEFAULT_GSS_ENCMODE + ca_cert_path: Optional[FilePath] = None + + @model_validator(mode="after") + def check_postgres_configuration(self) -> Self: + """Check PostgreSQL configuration.""" + if self.port <= 0: + raise ValueError("Port value should not be negative") + if self.port > 65535: + raise ValueError("Port value should be less than 65536") + if self.ca_cert_path is not None and not self.ca_cert_path.exists(): + raise ValueError(f"CA certificate file does not exist: {self.ca_cert_path}") + return self + + +class DatabaseConfiguration(BaseModel): + """Database configuration.""" + + sqlite: Optional[SQLiteDatabaseConfiguration] = None + postgres: Optional[PostgreSQLDatabaseConfiguration] = None + + @model_validator(mode="after") + def check_database_configuration(self) -> Self: + """Check that exactly one database type is configured.""" + total_configured_dbs = sum([self.sqlite is not None, self.postgres is not None]) + + if total_configured_dbs == 0: + # Default to SQLite in a (hopefully) tmpfs if no database configuration is provided. + # This is good for backwards compatibility for deployments that do not mind having + # no persistent database. + sqlite_file_name = "/tmp/lightspeed-stack.db" + self.sqlite = SQLiteDatabaseConfiguration(db_path=sqlite_file_name) + elif total_configured_dbs > 1: + raise ValueError("Only one database configuration can be provided") + + return self + + @property + def db_type(self) -> Literal["sqlite", "postgres"]: + """Return the configured database type.""" + if self.sqlite is not None: + return "sqlite" + if self.postgres is not None: + return "postgres" + raise ValueError("No database configuration found") + + @property + def config(self) -> SQLiteDatabaseConfiguration | PostgreSQLDatabaseConfiguration: + """Return the active database configuration.""" + if self.sqlite is not None: + return self.sqlite + if self.postgres is not None: + return self.postgres + raise ValueError("No database configuration found") + + class ServiceConfiguration(BaseModel): """Service configuration.""" @@ -244,6 +316,7 @@ class Configuration(BaseModel): service: ServiceConfiguration llama_stack: LlamaStackConfiguration user_data_collection: UserDataCollection + database: DatabaseConfiguration = DatabaseConfiguration() mcp_servers: list[ModelContextProtocolServer] = [] authentication: Optional[AuthenticationConfiguration] = ( AuthenticationConfiguration() diff --git a/src/models/database/__init__.py b/src/models/database/__init__.py new file mode 100644 index 000000000..a01a86e5c --- /dev/null +++ b/src/models/database/__init__.py @@ -0,0 +1 @@ +"""Database models package.""" diff --git a/src/models/database/base.py b/src/models/database/base.py new file mode 100644 index 000000000..e6f8e48cc --- /dev/null +++ b/src/models/database/base.py @@ -0,0 +1,7 @@ +"""Base model for SQLAlchemy ORM classes.""" + +from sqlalchemy.orm import DeclarativeBase + + +class Base(DeclarativeBase): # pylint: disable=too-few-public-methods + """Base class for all SQLAlchemy ORM models.""" diff --git a/src/models/database/conversations.py b/src/models/database/conversations.py new file mode 100644 index 000000000..1cce8a64d --- /dev/null +++ b/src/models/database/conversations.py @@ -0,0 +1,36 @@ +"""User conversation models.""" + +from datetime import datetime + +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy import DateTime, func + +from models.database.base import Base + + +class UserConversation(Base): # pylint: disable=too-few-public-methods + """Model for storing user conversation metadata.""" + + __tablename__ = "user_conversation" + + # The conversation ID + id: Mapped[str] = mapped_column(primary_key=True) + + # The user ID associated with the conversation + user_id: Mapped[str] = mapped_column(index=True) + + # The last provider/model used in the conversation + last_used_model: Mapped[str] = mapped_column() + last_used_provider: Mapped[str] = mapped_column() + + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), # pylint: disable=not-callable + ) + last_message_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), # pylint: disable=not-callable + ) + + # The number of user messages in the conversation + message_count: Mapped[int] = mapped_column(default=0) diff --git a/src/models/requests.py b/src/models/requests.py index f5c77b668..56c801efd 100644 --- a/src/models/requests.py +++ b/src/models/requests.py @@ -184,6 +184,14 @@ class QueryRequest(BaseModel): }, } + @field_validator("conversation_id") + @classmethod + def check_uuid(cls, value: str | None) -> str | None: + """Check if conversation ID has the proper format.""" + if value and not suid.check_suid(value): + raise ValueError(f"Improper conversation ID '{value}'") + return value + def get_documents(self) -> list[Document]: """Return the list of documents from the attachments.""" if not self.attachments: diff --git a/src/models/responses.py b/src/models/responses.py index 76270739d..c00d61cdb 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -384,3 +384,89 @@ class ConversationDeleteResponse(BaseModel): ] } } + + +class ConversationDetails(BaseModel): + """Model representing the details of a user conversation. + + Attributes: + conversation_id: The conversation ID (UUID). + created_at: When the conversation was created. + last_message_at: When the last message was sent. + message_count: Number of user messages in the conversation. + model: The model used for the conversation. + + Example: + ```python + conversation = ConversationSummary( + conversation_id="123e4567-e89b-12d3-a456-426614174000" + created_at="2024-01-01T00:00:00Z", + last_message_at="2024-01-01T00:05:00Z", + message_count=5, + model="gemini/gemini-2.0-flash" + ) + ``` + """ + + conversation_id: str + created_at: Optional[str] = None + last_message_at: Optional[str] = None + message_count: Optional[int] = None + last_used_model: Optional[str] = None + last_used_provider: Optional[str] = None + + +class ConversationsListResponse(BaseModel): + """Model representing a response for listing conversations of a user. + + Attributes: + conversations: List of conversation details associated with the user. + + Example: + ```python + conversations_list = ConversationsListResponse( + conversations=[ + ConversationDetails( + conversation_id="123e4567-e89b-12d3-a456-426614174000", + created_at="2024-01-01T00:00:00Z", + last_message_at="2024-01-01T00:05:00Z", + message_count=5, + model="gemini/gemini-2.0-flash" + ), + ConversationDetails( + conversation_id="456e7890-e12b-34d5-a678-901234567890" + created_at="2024-01-01T01:00:00Z", + message_count=2, + model="gemini/gemini-2.5-flash" + ) + ] + ) + ``` + """ + + conversations: list[ConversationDetails] + + # provides examples for /docs endpoint + model_config = { + "json_schema_extra": { + "examples": [ + { + "conversations": [ + { + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "created_at": "2024-01-01T00:00:00Z", + "last_message_at": "2024-01-01T00:05:00Z", + "message_count": 5, + "model": "gemini/gemini-2.0-flash", + }, + { + "conversation_id": "456e7890-e12b-34d5-a678-901234567890", + "created_at": "2024-01-01T01:00:00Z", + "message_count": 2, + "model": "gemini/gemini-2.5-flash", + }, + ] + } + ] + } + } diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index 19f816de0..7e1c98bd5 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -8,6 +8,8 @@ import constants from models.requests import QueryRequest +from models.database.conversations import UserConversation +from app.database import get_session from configuration import AppConfig from utils.suid import get_suid from utils.types import GraniteToolParser @@ -16,6 +18,19 @@ logger = logging.getLogger("utils.endpoints") +def validate_conversation_ownership( + user_id: str, conversation_id: str +) -> UserConversation | None: + """Validate that the conversation belongs to the user.""" + with get_session() as session: + conversation = ( + session.query(UserConversation) + .filter_by(id=conversation_id, user_id=user_id) + .first() + ) + return conversation + + def check_configuration_loaded(config: AppConfig) -> None: """Check that configuration is loaded and raise exception when it is not.""" if config is None: diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index 1958d131a..d07416b78 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -7,9 +7,14 @@ from app.endpoints.conversations import ( get_conversation_endpoint_handler, delete_conversation_endpoint_handler, + get_conversations_list_endpoint_handler, simplify_session_data, ) -from models.responses import ConversationResponse, ConversationDeleteResponse +from models.responses import ( + ConversationResponse, + ConversationDeleteResponse, + ConversationsListResponse, +) from configuration import AppConfig MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") @@ -17,6 +22,46 @@ INVALID_CONVERSATION_ID = "invalid-id" +def create_mock_conversation( + mocker, + conversation_id, + created_at, + last_message_at, + message_count, + last_used_model, + last_used_provider, +): # pylint: disable=too-many-arguments,too-many-positional-arguments + """Helper function to create a mock conversation object with all required attributes.""" + mock_conversation = mocker.Mock() + mock_conversation.id = conversation_id + mock_conversation.created_at = mocker.Mock() + mock_conversation.created_at.isoformat.return_value = created_at + mock_conversation.last_message_at = mocker.Mock() + mock_conversation.last_message_at.isoformat.return_value = last_message_at + mock_conversation.message_count = message_count + mock_conversation.last_used_model = last_used_model + mock_conversation.last_used_provider = last_used_provider + return mock_conversation + + +def mock_database_session(mocker, query_result=None): + """Helper function to mock get_session with proper context manager support.""" + mock_session = mocker.Mock() + if query_result is not None: + mock_session.query.return_value.filter_by.return_value.all.return_value = ( + query_result + ) + + # Mock get_session to return a context manager + mock_session_context = mocker.MagicMock() + mock_session_context.__enter__.return_value = mock_session + mock_session_context.__exit__.return_value = None + mocker.patch( + "app.endpoints.conversations.get_session", return_value=mock_session_context + ) + return mock_session + + @pytest.fixture(name="setup_configuration") def setup_configuration_fixture(): """Set up configuration for tests.""" @@ -191,7 +236,7 @@ async def test_configuration_not_loaded(self, mocker): with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, _auth=MOCK_AUTH + VALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -205,7 +250,7 @@ async def test_invalid_conversation_id_format(self, mocker, setup_configuration) with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - INVALID_CONVERSATION_ID, _auth=MOCK_AUTH + INVALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST @@ -217,6 +262,7 @@ async def test_llama_stack_connection_error(self, mocker, setup_configuration): """Test the endpoint when LlamaStack connection fails.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations.validate_conversation_ownership") # Mock AsyncLlamaStackClientHolder to raise APIConnectionError mock_client = mocker.AsyncMock() @@ -229,7 +275,7 @@ async def test_llama_stack_connection_error(self, mocker, setup_configuration): # simulate situation when it is not possible to connect to Llama Stack with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, _auth=MOCK_AUTH + VALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE @@ -240,6 +286,7 @@ async def test_llama_stack_not_found_error(self, mocker, setup_configuration): """Test the endpoint when LlamaStack returns NotFoundError.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations.validate_conversation_ownership") # Mock AsyncLlamaStackClientHolder to raise NotFoundError mock_client = mocker.AsyncMock() @@ -253,7 +300,7 @@ async def test_llama_stack_not_found_error(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, _auth=MOCK_AUTH + VALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND @@ -266,12 +313,11 @@ async def test_session_retrieve_exception(self, mocker, setup_configuration): """Test the endpoint when session retrieval raises an exception.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations.validate_conversation_ownership") # Mock AsyncLlamaStackClientHolder to raise a general exception - mock_client = mocker.Mock() - mock_client.agents.session.retrieve.side_effect = HTTPException( - status_code=500, detail="Failed to get session" - ) + mock_client = mocker.AsyncMock() + mock_client.agents.session.list.side_effect = Exception("Failed to get session") mock_client_holder = mocker.patch( "app.endpoints.conversations.AsyncLlamaStackClientHolder" ) @@ -279,7 +325,7 @@ async def test_session_retrieve_exception(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, _auth=MOCK_AUTH + VALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -295,10 +341,7 @@ async def test_successful_conversation_retrieval( """Test successful conversation retrieval with simplified response structure.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - - # Mock session data with model_dump method - mock_session_obj = mocker.Mock() - mock_session_obj.model_dump.return_value = mock_session_data + mocker.patch("app.endpoints.conversations.validate_conversation_ownership") # Mock AsyncLlamaStackClientHolder mock_client = mocker.AsyncMock() @@ -317,7 +360,7 @@ async def test_successful_conversation_retrieval( mock_client_holder.return_value.get_client.return_value = mock_client response = await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, _auth=MOCK_AUTH + VALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert isinstance(response, ConversationResponse) @@ -338,7 +381,7 @@ async def test_configuration_not_loaded(self, mocker): with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, _auth=MOCK_AUTH + VALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -352,7 +395,7 @@ async def test_invalid_conversation_id_format(self, mocker, setup_configuration) with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - INVALID_CONVERSATION_ID, _auth=MOCK_AUTH + INVALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST @@ -364,6 +407,7 @@ async def test_llama_stack_connection_error(self, mocker, setup_configuration): """Test the endpoint when LlamaStack connection fails.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations.validate_conversation_ownership") # Mock AsyncLlamaStackClientHolder to raise APIConnectionError mock_client = mocker.AsyncMock() @@ -375,7 +419,7 @@ async def test_llama_stack_connection_error(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, _auth=MOCK_AUTH + VALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE @@ -386,6 +430,7 @@ async def test_llama_stack_not_found_error(self, mocker, setup_configuration): """Test the endpoint when LlamaStack returns NotFoundError.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations.validate_conversation_ownership") # Mock AsyncLlamaStackClientHolder to raise NotFoundError mock_client = mocker.AsyncMock() @@ -399,7 +444,7 @@ async def test_llama_stack_not_found_error(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, _auth=MOCK_AUTH + VALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND @@ -412,6 +457,7 @@ async def test_session_deletion_exception(self, mocker, setup_configuration): """Test the endpoint when session deletion raises an exception.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations.validate_conversation_ownership") # Mock AsyncLlamaStackClientHolder to raise a general exception mock_client = mocker.AsyncMock() @@ -425,7 +471,7 @@ async def test_session_deletion_exception(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, _auth=MOCK_AUTH + VALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -440,6 +486,7 @@ async def test_successful_conversation_deletion(self, mocker, setup_configuratio """Test successful conversation deletion.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations.validate_conversation_ownership") # Mock AsyncLlamaStackClientHolder mock_client = mocker.AsyncMock() @@ -450,7 +497,7 @@ async def test_successful_conversation_deletion(self, mocker, setup_configuratio mock_client_holder.return_value.get_client.return_value = mock_client response = await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, _auth=MOCK_AUTH + VALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert isinstance(response, ConversationDeleteResponse) @@ -460,3 +507,89 @@ async def test_successful_conversation_deletion(self, mocker, setup_configuratio mock_client.agents.session.delete.assert_called_once_with( agent_id=VALID_CONVERSATION_ID, session_id=VALID_CONVERSATION_ID ) + + +# Generated entirely by AI, no human review, so read with that in mind. +class TestGetConversationsListEndpoint: + """Test cases for the GET /conversations endpoint.""" + + def test_configuration_not_loaded(self, mocker): + """Test the endpoint when configuration is not loaded.""" + mocker.patch("app.endpoints.conversations.configuration", None) + + with pytest.raises(HTTPException) as exc_info: + get_conversations_list_endpoint_handler(auth=MOCK_AUTH) + + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert "Configuration is not loaded" in exc_info.value.detail["response"] + + def test_successful_conversations_list_retrieval(self, mocker, setup_configuration): + """Test successful retrieval of conversations list.""" + mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + + # Mock database session and query results + mock_conversations = [ + create_mock_conversation( + mocker, + "123e4567-e89b-12d3-a456-426614174000", + "2024-01-01T00:00:00Z", + "2024-01-01T00:05:00Z", + 5, + "gemini/gemini-2.0-flash", + "gemini", + ), + create_mock_conversation( + mocker, + "456e7890-e12b-34d5-a678-901234567890", + "2024-01-01T01:00:00Z", + "2024-01-01T01:02:00Z", + 2, + "gemini/gemini-2.5-flash", + "gemini", + ), + ] + mock_database_session(mocker, mock_conversations) + + response = get_conversations_list_endpoint_handler(auth=MOCK_AUTH) + + assert isinstance(response, ConversationsListResponse) + assert len(response.conversations) == 2 + assert ( + response.conversations[0].conversation_id + == "123e4567-e89b-12d3-a456-426614174000" + ) + assert ( + response.conversations[1].conversation_id + == "456e7890-e12b-34d5-a678-901234567890" + ) + + def test_empty_conversations_list(self, mocker, setup_configuration): + """Test when user has no conversations.""" + mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + + # Mock database session with no results + mock_database_session(mocker, []) + + response = get_conversations_list_endpoint_handler(auth=MOCK_AUTH) + + assert isinstance(response, ConversationsListResponse) + assert len(response.conversations) == 0 + assert response.conversations == [] + + def test_database_exception(self, mocker, setup_configuration): + """Test when database query raises an exception.""" + mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + + # Mock database session to raise exception + mock_session = mock_database_session(mocker) + mock_session.query.side_effect = Exception("Database error") + + with pytest.raises(HTTPException) as exc_info: + get_conversations_list_endpoint_handler(auth=MOCK_AUTH) + + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert "Unknown error" in exc_info.value.detail["response"] + assert ( + "Unknown error while getting conversations for user" + in exc_info.value.detail["cause"] + ) diff --git a/tests/unit/app/endpoints/test_feedback.py b/tests/unit/app/endpoints/test_feedback.py index 238b93344..dfd0cef20 100644 --- a/tests/unit/app/endpoints/test_feedback.py +++ b/tests/unit/app/endpoints/test_feedback.py @@ -76,8 +76,8 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data): # Call the endpoint handler result = feedback_endpoint_handler( feedback_request=feedback_request, - auth=["test-user", "", ""], _ensure_feedback_enabled=assert_feedback_enabled, + auth=("test_user_id", "test_username", "test_token"), ) # Assert that the expected response is returned @@ -100,8 +100,8 @@ def test_feedback_endpoint_handler_error(mocker): with pytest.raises(HTTPException) as exc_info: feedback_endpoint_handler( feedback_request=feedback_request, - auth=["test-user", "", ""], _ensure_feedback_enabled=assert_feedback_enabled, + auth=("test_user_id", "test_username", "test_token"), ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index db0ad233e..b9dc49daa 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -19,14 +19,24 @@ construct_transcripts_path, store_transcript, get_rag_toolgroups, + evaluate_model_hints, ) from models.requests import QueryRequest, Attachment from models.config import ModelContextProtocolServer +from models.database.conversations import UserConversation MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") +def mock_database_operations(mocker): + """Helper function to mock database operations for query endpoints.""" + mocker.patch( + "app.endpoints.query.validate_conversation_ownership", return_value=True + ) + mocker.patch("app.endpoints.query.persist_user_conversation_details") + + @pytest.fixture(name="setup_configuration") def setup_configuration_fixture(): """Set up configuration for tests.""" @@ -130,6 +140,9 @@ async def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): ) mock_transcript = mocker.patch("app.endpoints.query.store_transcript") + # Mock database operations + mock_database_operations(mocker) + query_request = QueryRequest(query=query) response = await query_endpoint_handler(query_request, auth=MOCK_AUTH) @@ -201,7 +214,9 @@ def test_select_model_and_provider_id_from_request(mocker): ) # Assert the model and provider from request take precedence from the configuration one - model_id, provider_id = select_model_and_provider_id(model_list, query_request) + model_id, provider_id = select_model_and_provider_id( + model_list, query_request.model, query_request.provider + ) assert model_id == "provider2/model2" assert provider_id == "provider2" @@ -234,7 +249,9 @@ def test_select_model_and_provider_id_from_configuration(mocker): query="What is OpenStack?", ) - model_id, provider_id = select_model_and_provider_id(model_list, query_request) + model_id, provider_id = select_model_and_provider_id( + model_list, query_request.model, query_request.provider + ) # Assert that the default model and provider from the configuration are returned assert model_id == "default_provider/default_model" @@ -257,7 +274,9 @@ def test_select_model_and_provider_id_first_from_list(mocker): query_request = QueryRequest(query="What is OpenStack?") - model_id, provider_id = select_model_and_provider_id(model_list, query_request) + model_id, provider_id = select_model_and_provider_id( + model_list, query_request.model, query_request.provider + ) # Assert return the first available LLM model when no model/provider is # specified in the request or in the configuration @@ -277,7 +296,9 @@ def test_select_model_and_provider_id_invalid_model(mocker): ) with pytest.raises(HTTPException) as exc_info: - select_model_and_provider_id(mock_client.models.list(), query_request) + select_model_and_provider_id( + mock_client.models.list(), query_request.model, query_request.provider + ) assert ( "Model invalid_model from provider provider1 not found in available models" @@ -294,7 +315,9 @@ def test_select_model_and_provider_id_no_available_models(mocker): query_request = QueryRequest(query="What is OpenStack?", model=None, provider=None) with pytest.raises(HTTPException) as exc_info: - select_model_and_provider_id(mock_client.models.list(), query_request) + select_model_and_provider_id( + mock_client.models.list(), query_request.model, query_request.provider + ) assert "No LLM model found in available models" in str(exc_info.value) @@ -1115,6 +1138,8 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): return_value=("test_model", "test_provider"), ) mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + # Mock database operations + mock_database_operations(mocker) _ = await query_endpoint_handler( QueryRequest(query="test query"), @@ -1152,6 +1177,8 @@ async def test_query_endpoint_handler_no_tools_true(mocker): return_value=("fake_model_id", "fake_provider_id"), ) mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + # Mock database operations + mock_database_operations(mocker) query_request = QueryRequest(query=query, no_tools=True) @@ -1189,6 +1216,8 @@ async def test_query_endpoint_handler_no_tools_false(mocker): return_value=("fake_model_id", "fake_provider_id"), ) mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + # Mock database operations + mock_database_operations(mocker) query_request = QueryRequest(query=query, no_tools=False) @@ -1320,3 +1349,88 @@ def test_no_tools_parameter_backward_compatibility(): # Test that QueryRequest can be created without no_tools parameter query_request_minimal = QueryRequest(query="Simple query") assert query_request_minimal.no_tools is False + + +@pytest.mark.parametrize( + "user_conversation,request_values,expected_values", + [ + # No user conversation, no request values + ( + None, + (None, None), + # Expect no values to be used + (None, None), + ), + # No user conversation, request values provided + ( + None, + ("foo", "bar"), + # Expect request values to be used + ("foo", "bar"), + ), + # User conversation exists, no request values + ( + UserConversation( + id="conv1", + user_id="user1", + last_used_provider="foo", + last_used_model="bar", + message_count=1, + ), + ( + None, + None, + ), + # Expect conversation values to be used + ( + "foo", + "bar", + ), + ), + # Request matches user conversation + ( + UserConversation( + id="conv1", + user_id="user1", + last_used_provider="foo", + last_used_model="bar", + message_count=1, + ), + ( + "foo", + "bar", + ), + # Expect request values to be used + ( + "foo", + "bar", + ), + ), + ], + ids=[ + "No user conversation, no request values", + "No user conversation, request values provided", + "User conversation exists, no request values", + "Request matches user conversation", + ], +) +def test_evaluate_model_hints( + user_conversation, + request_values, + expected_values, +): + """Test evaluate_model_hints function with various scenarios.""" + # Unpack fixtures + request_provider, request_model = request_values + expected_provider, expected_model = expected_values + + query_request = QueryRequest( + query="What is love?", + provider=request_provider, + model=request_model, + ) # pylint: disable=missing-kwoa + + model_id, provider_id = evaluate_model_hints(user_conversation, query_request) + + assert provider_id == expected_provider + assert model_id == expected_model diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index f1c2ddfec..e6dd73f6e 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -48,6 +48,15 @@ MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") +def mock_database_operations(mocker): + """Helper function to mock database operations for streaming query endpoints.""" + mocker.patch( + "app.endpoints.streaming_query.validate_conversation_ownership", + return_value=True, + ) + mocker.patch("app.endpoints.streaming_query.persist_user_conversation_details") + + SAMPLE_KNOWLEDGE_SEARCH_RESULTS = [ """knowledge_search tool found 2 chunks: BEGIN of knowledge_search tool results. @@ -264,6 +273,8 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) ) mock_transcript = mocker.patch("app.endpoints.streaming_query.store_transcript") + mock_database_operations(mocker) + query_request = QueryRequest(query=query) request = Request( @@ -1273,6 +1284,7 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker): mocker.patch( "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False ) + mock_database_operations(mocker) request = Request( scope={ @@ -1318,6 +1330,8 @@ async def test_streaming_query_endpoint_handler_no_tools_true(mocker): mocker.patch( "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False ) + # Mock database operations + mock_database_operations(mocker) query_request = QueryRequest(query="What is OpenStack?", no_tools=True) @@ -1363,6 +1377,8 @@ async def test_streaming_query_endpoint_handler_no_tools_false(mocker): mocker.patch( "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False ) + # Mock database operations + mock_database_operations(mocker) query_request = QueryRequest(query="What is OpenStack?", no_tools=False) diff --git a/tests/unit/models/test_config.py b/tests/unit/models/test_config.py index 2801f4ebf..4d9a29da3 100644 --- a/tests/unit/models/test_config.py +++ b/tests/unit/models/test_config.py @@ -499,6 +499,7 @@ def test_dump_configuration(tmp_path) -> None: assert "authentication" in content assert "customization" in content assert "inference" in content + assert "database" in content # check the whole deserialized JSON file content assert content == { @@ -550,6 +551,10 @@ def test_dump_configuration(tmp_path) -> None: "default_provider": "default_provider", "default_model": "default_model", }, + "database": { + "sqlite": {"db_path": "/tmp/lightspeed-stack.db"}, + "postgres": None, + }, } diff --git a/uv.lock b/uv.lock index d752feb8d..6b61e5015 100644 --- a/uv.lock +++ b/uv.lock @@ -597,6 +597,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/86/f1/62a193f0227cf15a920390abe675f386dec35f7ae3ffe6da582d3ade42c7/googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8", size = 294530, upload-time = "2025-04-14T10:17:01.271Z" }, ] +[[package]] +name = "greenlet" +version = "3.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/92/bb85bd6e80148a4d2e0c59f7c0c2891029f8fd510183afc7d8d2feeed9b6/greenlet-3.2.3.tar.gz", hash = "sha256:8b0dd8ae4c0d6f5e54ee55ba935eeb3d735a9b58a8a1e5b5cbab64e01a39f365", size = 185752, upload-time = "2025-06-05T16:16:09.955Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f3/94/ad0d435f7c48debe960c53b8f60fb41c2026b1d0fa4a99a1cb17c3461e09/greenlet-3.2.3-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:25ad29caed5783d4bd7a85c9251c651696164622494c00802a139c00d639242d", size = 271992, upload-time = "2025-06-05T16:11:23.467Z" }, + { url = "https://files.pythonhosted.org/packages/93/5d/7c27cf4d003d6e77749d299c7c8f5fd50b4f251647b5c2e97e1f20da0ab5/greenlet-3.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:88cd97bf37fe24a6710ec6a3a7799f3f81d9cd33317dcf565ff9950c83f55e0b", size = 638820, upload-time = "2025-06-05T16:38:52.882Z" }, + { url = "https://files.pythonhosted.org/packages/c6/7e/807e1e9be07a125bb4c169144937910bf59b9d2f6d931578e57f0bce0ae2/greenlet-3.2.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:baeedccca94880d2f5666b4fa16fc20ef50ba1ee353ee2d7092b383a243b0b0d", size = 653046, upload-time = "2025-06-05T16:41:36.343Z" }, + { url = "https://files.pythonhosted.org/packages/9d/ab/158c1a4ea1068bdbc78dba5a3de57e4c7aeb4e7fa034320ea94c688bfb61/greenlet-3.2.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:be52af4b6292baecfa0f397f3edb3c6092ce071b499dd6fe292c9ac9f2c8f264", size = 647701, upload-time = "2025-06-05T16:48:19.604Z" }, + { url = "https://files.pythonhosted.org/packages/cc/0d/93729068259b550d6a0288da4ff72b86ed05626eaf1eb7c0d3466a2571de/greenlet-3.2.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0cc73378150b8b78b0c9fe2ce56e166695e67478550769536a6742dca3651688", size = 649747, upload-time = "2025-06-05T16:13:04.628Z" }, + { url = "https://files.pythonhosted.org/packages/f6/f6/c82ac1851c60851302d8581680573245c8fc300253fc1ff741ae74a6c24d/greenlet-3.2.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:706d016a03e78df129f68c4c9b4c4f963f7d73534e48a24f5f5a7101ed13dbbb", size = 605461, upload-time = "2025-06-05T16:12:50.792Z" }, + { url = "https://files.pythonhosted.org/packages/98/82/d022cf25ca39cf1200650fc58c52af32c90f80479c25d1cbf57980ec3065/greenlet-3.2.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:419e60f80709510c343c57b4bb5a339d8767bf9aef9b8ce43f4f143240f88b7c", size = 1121190, upload-time = "2025-06-05T16:36:48.59Z" }, + { url = "https://files.pythonhosted.org/packages/f5/e1/25297f70717abe8104c20ecf7af0a5b82d2f5a980eb1ac79f65654799f9f/greenlet-3.2.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:93d48533fade144203816783373f27a97e4193177ebaaf0fc396db19e5d61163", size = 1149055, upload-time = "2025-06-05T16:12:40.457Z" }, + { url = "https://files.pythonhosted.org/packages/1f/8f/8f9e56c5e82eb2c26e8cde787962e66494312dc8cb261c460e1f3a9c88bc/greenlet-3.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:7454d37c740bb27bdeddfc3f358f26956a07d5220818ceb467a483197d84f849", size = 297817, upload-time = "2025-06-05T16:29:49.244Z" }, + { url = "https://files.pythonhosted.org/packages/b1/cf/f5c0b23309070ae93de75c90d29300751a5aacefc0a3ed1b1d8edb28f08b/greenlet-3.2.3-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:500b8689aa9dd1ab26872a34084503aeddefcb438e2e7317b89b11eaea1901ad", size = 270732, upload-time = "2025-06-05T16:10:08.26Z" }, + { url = "https://files.pythonhosted.org/packages/48/ae/91a957ba60482d3fecf9be49bc3948f341d706b52ddb9d83a70d42abd498/greenlet-3.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a07d3472c2a93117af3b0136f246b2833fdc0b542d4a9799ae5f41c28323faef", size = 639033, upload-time = "2025-06-05T16:38:53.983Z" }, + { url = "https://files.pythonhosted.org/packages/6f/df/20ffa66dd5a7a7beffa6451bdb7400d66251374ab40b99981478c69a67a8/greenlet-3.2.3-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:8704b3768d2f51150626962f4b9a9e4a17d2e37c8a8d9867bbd9fa4eb938d3b3", size = 652999, upload-time = "2025-06-05T16:41:37.89Z" }, + { url = "https://files.pythonhosted.org/packages/51/b4/ebb2c8cb41e521f1d72bf0465f2f9a2fd803f674a88db228887e6847077e/greenlet-3.2.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:5035d77a27b7c62db6cf41cf786cfe2242644a7a337a0e155c80960598baab95", size = 647368, upload-time = "2025-06-05T16:48:21.467Z" }, + { url = "https://files.pythonhosted.org/packages/8e/6a/1e1b5aa10dced4ae876a322155705257748108b7fd2e4fae3f2a091fe81a/greenlet-3.2.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2d8aa5423cd4a396792f6d4580f88bdc6efcb9205891c9d40d20f6e670992efb", size = 650037, upload-time = "2025-06-05T16:13:06.402Z" }, + { url = "https://files.pythonhosted.org/packages/26/f2/ad51331a157c7015c675702e2d5230c243695c788f8f75feba1af32b3617/greenlet-3.2.3-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2c724620a101f8170065d7dded3f962a2aea7a7dae133a009cada42847e04a7b", size = 608402, upload-time = "2025-06-05T16:12:51.91Z" }, + { url = "https://files.pythonhosted.org/packages/26/bc/862bd2083e6b3aff23300900a956f4ea9a4059de337f5c8734346b9b34fc/greenlet-3.2.3-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:873abe55f134c48e1f2a6f53f7d1419192a3d1a4e873bace00499a4e45ea6af0", size = 1119577, upload-time = "2025-06-05T16:36:49.787Z" }, + { url = "https://files.pythonhosted.org/packages/86/94/1fc0cc068cfde885170e01de40a619b00eaa8f2916bf3541744730ffb4c3/greenlet-3.2.3-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:024571bbce5f2c1cfff08bf3fbaa43bbc7444f580ae13b0099e95d0e6e67ed36", size = 1147121, upload-time = "2025-06-05T16:12:42.527Z" }, + { url = "https://files.pythonhosted.org/packages/27/1a/199f9587e8cb08a0658f9c30f3799244307614148ffe8b1e3aa22f324dea/greenlet-3.2.3-cp313-cp313-win_amd64.whl", hash = "sha256:5195fb1e75e592dd04ce79881c8a22becdfa3e6f500e7feb059b1e6fdd54d3e3", size = 297603, upload-time = "2025-06-05T16:20:12.651Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -888,6 +914,7 @@ dependencies = [ { name = "openai" }, { name = "prometheus-client" }, { name = "rich" }, + { name = "sqlalchemy" }, { name = "starlette" }, { name = "uvicorn" }, ] @@ -929,6 +956,7 @@ requires-dist = [ { name = "openai", specifier = "==1.99.1" }, { name = "prometheus-client", specifier = ">=0.22.1" }, { name = "rich", specifier = ">=14.0.0" }, + { name = "sqlalchemy", specifier = ">=2.0.42" }, { name = "starlette", specifier = ">=0.47.1" }, { name = "uvicorn", specifier = ">=0.34.3" }, ] @@ -2196,6 +2224,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c8/78/3565d011c61f5a43488987ee32b6f3f656e7f107ac2782dd57bdd7d91d9a/snowballstemmer-3.0.1-py3-none-any.whl", hash = "sha256:6cd7b3897da8d6c9ffb968a6781fa6532dce9c3618a4b127d920dab764a19064", size = 103274, upload-time = "2025-05-09T16:34:50.371Z" }, ] +[[package]] +name = "sqlalchemy" +version = "2.0.42" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "greenlet", marker = "platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5a/03/a0af991e3a43174d6b83fca4fb399745abceddd1171bdabae48ce877ff47/sqlalchemy-2.0.42.tar.gz", hash = "sha256:160bedd8a5c28765bd5be4dec2d881e109e33b34922e50a3b881a7681773ac5f", size = 9749972, upload-time = "2025-07-29T12:48:09.323Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/66/ac31a9821fc70a7376321fb2c70fdd7eadbc06dadf66ee216a22a41d6058/sqlalchemy-2.0.42-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:09637a0872689d3eb71c41e249c6f422e3e18bbd05b4cd258193cfc7a9a50da2", size = 2132203, upload-time = "2025-07-29T13:29:19.291Z" }, + { url = "https://files.pythonhosted.org/packages/fc/ba/fd943172e017f955d7a8b3a94695265b7114efe4854feaa01f057e8f5293/sqlalchemy-2.0.42-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a3cb3ec67cc08bea54e06b569398ae21623534a7b1b23c258883a7c696ae10df", size = 2120373, upload-time = "2025-07-29T13:29:21.049Z" }, + { url = "https://files.pythonhosted.org/packages/ea/a2/b5f7d233d063ffadf7e9fff3898b42657ba154a5bec95a96f44cba7f818b/sqlalchemy-2.0.42-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e87e6a5ef6f9d8daeb2ce5918bf5fddecc11cae6a7d7a671fcc4616c47635e01", size = 3317685, upload-time = "2025-07-29T13:26:40.837Z" }, + { url = "https://files.pythonhosted.org/packages/86/00/fcd8daab13a9119d41f3e485a101c29f5d2085bda459154ba354c616bf4e/sqlalchemy-2.0.42-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b718011a9d66c0d2f78e1997755cd965f3414563b31867475e9bc6efdc2281d", size = 3326967, upload-time = "2025-07-29T13:22:31.009Z" }, + { url = "https://files.pythonhosted.org/packages/a3/85/e622a273d648d39d6771157961956991a6d760e323e273d15e9704c30ccc/sqlalchemy-2.0.42-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:16d9b544873fe6486dddbb859501a07d89f77c61d29060bb87d0faf7519b6a4d", size = 3255331, upload-time = "2025-07-29T13:26:42.579Z" }, + { url = "https://files.pythonhosted.org/packages/3a/a0/2c2338b592c7b0a61feffd005378c084b4c01fabaf1ed5f655ab7bd446f0/sqlalchemy-2.0.42-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:21bfdf57abf72fa89b97dd74d3187caa3172a78c125f2144764a73970810c4ee", size = 3291791, upload-time = "2025-07-29T13:22:32.454Z" }, + { url = "https://files.pythonhosted.org/packages/41/19/b8a2907972a78285fdce4c880ecaab3c5067eb726882ca6347f7a4bf64f6/sqlalchemy-2.0.42-cp312-cp312-win32.whl", hash = "sha256:78b46555b730a24901ceb4cb901c6b45c9407f8875209ed3c5d6bcd0390a6ed1", size = 2096180, upload-time = "2025-07-29T13:16:08.952Z" }, + { url = "https://files.pythonhosted.org/packages/48/1f/67a78f3dfd08a2ed1c7be820fe7775944f5126080b5027cc859084f8e223/sqlalchemy-2.0.42-cp312-cp312-win_amd64.whl", hash = "sha256:4c94447a016f36c4da80072e6c6964713b0af3c8019e9c4daadf21f61b81ab53", size = 2123533, upload-time = "2025-07-29T13:16:11.705Z" }, + { url = "https://files.pythonhosted.org/packages/e9/7e/25d8c28b86730c9fb0e09156f601d7a96d1c634043bf8ba36513eb78887b/sqlalchemy-2.0.42-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:941804f55c7d507334da38133268e3f6e5b0340d584ba0f277dd884197f4ae8c", size = 2127905, upload-time = "2025-07-29T13:29:22.249Z" }, + { url = "https://files.pythonhosted.org/packages/e5/a1/9d8c93434d1d983880d976400fcb7895a79576bd94dca61c3b7b90b1ed0d/sqlalchemy-2.0.42-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:95d3d06a968a760ce2aa6a5889fefcbdd53ca935735e0768e1db046ec08cbf01", size = 2115726, upload-time = "2025-07-29T13:29:23.496Z" }, + { url = "https://files.pythonhosted.org/packages/a2/cc/d33646fcc24c87cc4e30a03556b611a4e7bcfa69a4c935bffb923e3c89f4/sqlalchemy-2.0.42-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4cf10396a8a700a0f38ccd220d940be529c8f64435c5d5b29375acab9267a6c9", size = 3246007, upload-time = "2025-07-29T13:26:44.166Z" }, + { url = "https://files.pythonhosted.org/packages/67/08/4e6c533d4c7f5e7c4cbb6fe8a2c4e813202a40f05700d4009a44ec6e236d/sqlalchemy-2.0.42-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9cae6c2b05326d7c2c7c0519f323f90e0fb9e8afa783c6a05bb9ee92a90d0f04", size = 3250919, upload-time = "2025-07-29T13:22:33.74Z" }, + { url = "https://files.pythonhosted.org/packages/5c/82/f680e9a636d217aece1b9a8030d18ad2b59b5e216e0c94e03ad86b344af3/sqlalchemy-2.0.42-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f50f7b20677b23cfb35b6afcd8372b2feb348a38e3033f6447ee0704540be894", size = 3180546, upload-time = "2025-07-29T13:26:45.648Z" }, + { url = "https://files.pythonhosted.org/packages/7d/a2/8c8f6325f153894afa3775584c429cc936353fb1db26eddb60a549d0ff4b/sqlalchemy-2.0.42-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9d88a1c0d66d24e229e3938e1ef16ebdbd2bf4ced93af6eff55225f7465cf350", size = 3216683, upload-time = "2025-07-29T13:22:34.977Z" }, + { url = "https://files.pythonhosted.org/packages/39/44/3a451d7fa4482a8ffdf364e803ddc2cfcafc1c4635fb366f169ecc2c3b11/sqlalchemy-2.0.42-cp313-cp313-win32.whl", hash = "sha256:45c842c94c9ad546c72225a0c0d1ae8ef3f7c212484be3d429715a062970e87f", size = 2093990, upload-time = "2025-07-29T13:16:13.036Z" }, + { url = "https://files.pythonhosted.org/packages/4b/9e/9bce34f67aea0251c8ac104f7bdb2229d58fb2e86a4ad8807999c4bee34b/sqlalchemy-2.0.42-cp313-cp313-win_amd64.whl", hash = "sha256:eb9905f7f1e49fd57a7ed6269bc567fcbbdac9feadff20ad6bd7707266a91577", size = 2120473, upload-time = "2025-07-29T13:16:14.502Z" }, + { url = "https://files.pythonhosted.org/packages/ee/55/ba2546ab09a6adebc521bf3974440dc1d8c06ed342cceb30ed62a8858835/sqlalchemy-2.0.42-py3-none-any.whl", hash = "sha256:defcdff7e661f0043daa381832af65d616e060ddb54d3fe4476f51df7eaa1835", size = 1922072, upload-time = "2025-07-29T13:09:17.061Z" }, +] + [[package]] name = "starlette" version = "0.47.2"