diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py index ba8440855..776a19f7f 100644 --- a/src/app/endpoints/conversations.py +++ b/src/app/endpoints/conversations.py @@ -20,7 +20,11 @@ ConversationsListResponse, ConversationDetails, ) -from utils.endpoints import check_configuration_loaded, validate_conversation_ownership +from utils.endpoints import ( + check_configuration_loaded, + delete_conversation, + validate_conversation_ownership, +) from utils.suid import check_suid logger = logging.getLogger("app.endpoints.handlers") @@ -247,7 +251,7 @@ async def get_conversation_endpoint_handler( user_id, _, _ = auth - validate_conversation_ownership( + user_conversation = validate_conversation_ownership( user_id=user_id, conversation_id=conversation_id, others_allowed=( @@ -255,6 +259,20 @@ async def get_conversation_endpoint_handler( ), ) + if user_conversation is None: + logger.warning( + "User %s attempted to read conversation %s they don't own", + user_id, + conversation_id, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "response": "Access denied", + "cause": "You do not have permission to read this conversation", + }, + ) + agent_id = conversation_id logger.info("Retrieving conversation %s", conversation_id) @@ -355,7 +373,7 @@ async def delete_conversation_endpoint_handler( user_id, _, _ = auth - validate_conversation_ownership( + user_conversation = validate_conversation_ownership( user_id=user_id, conversation_id=conversation_id, others_allowed=( @@ -363,6 +381,20 @@ async def delete_conversation_endpoint_handler( ), ) + if user_conversation is None: + logger.warning( + "User %s attempted to delete conversation %s they don't own", + user_id, + conversation_id, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "response": "Access denied", + "cause": "You do not have permission to delete this conversation", + }, + ) + agent_id = conversation_id logger.info("Deleting conversation %s", conversation_id) @@ -387,6 +419,8 @@ async def delete_conversation_endpoint_handler( logger.info("Successfully deleted conversation %s", conversation_id) + delete_conversation(conversation_id=conversation_id) + return ConversationDeleteResponse( conversation_id=conversation_id, success=True, diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index 34c0147de..3da5f228b 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -18,6 +18,23 @@ logger = logging.getLogger("utils.endpoints") +def delete_conversation(conversation_id: str) -> None: + """Delete a conversation according to its ID.""" + with get_session() as session: + db_conversation = ( + session.query(UserConversation).filter_by(id=conversation_id).first() + ) + if db_conversation: + session.delete(db_conversation) + session.commit() + logger.info("Deleted conversation %s from local database", conversation_id) + else: + logger.info( + "Conversation %s not found in local database, it may have already been deleted", + conversation_id, + ) + + def validate_conversation_ownership( user_id: str, conversation_id: str, others_allowed: bool = False ) -> UserConversation | None: diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index f4fcd4c80..cde19bf9a 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -563,6 +563,9 @@ async def test_successful_conversation_deletion( mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") + # Mock the delete_conversation function + mocker.patch("app.endpoints.conversations.delete_conversation") + # Mock AsyncLlamaStackClientHolder mock_client = mocker.AsyncMock() # Ensure the endpoint sees an existing session so it proceeds to delete