diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 4430a1501..a8aaf61cf 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -65,6 +65,7 @@ consume_tokens, get_available_quotas, ) +from utils.suid import normalize_conversation_id from utils.token_counter import TokenCounter, extract_and_update_token_metrics from utils.transcripts import store_transcript from utils.types import TurnSummary, content_to_str @@ -109,14 +110,23 @@ def persist_user_conversation_details( topic_summary: Optional[str], ) -> None: """Associate conversation to user in the database.""" + # Normalize the conversation ID (strip 'conv_' prefix if present) + normalized_id = normalize_conversation_id(conversation_id) + logger.debug( + "persist_user_conversation_details - original conv_id: %s, normalized: %s, user: %s", + conversation_id, + normalized_id, + user_id, + ) + with get_session() as session: existing_conversation = ( - session.query(UserConversation).filter_by(id=conversation_id).first() + session.query(UserConversation).filter_by(id=normalized_id).first() ) if not existing_conversation: conversation = UserConversation( - id=conversation_id, + id=normalized_id, user_id=user_id, last_used_model=model, last_used_provider=provider_id, @@ -125,15 +135,24 @@ def persist_user_conversation_details( ) session.add(conversation) logger.debug( - "Associated conversation %s to user %s", conversation_id, user_id + "Associated conversation %s to user %s", normalized_id, user_id ) else: existing_conversation.last_used_model = model existing_conversation.last_used_provider = provider_id existing_conversation.last_message_at = datetime.now(UTC) existing_conversation.message_count += 1 + logger.debug( + "Updating existing conversation in DB - ID: %s, User: %s, Messages: %d", + normalized_id, + user_id, + existing_conversation.message_count, + ) session.commit() + logger.debug( + "Successfully committed conversation %s to database", normalized_id + ) def evaluate_model_hints( @@ -257,9 +276,13 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 logger.debug( "Conversation ID specified in query: %s", query_request.conversation_id ) + # Normalize the conversation ID for database lookup (strip conv_ prefix if present) + normalized_conv_id_for_lookup = normalize_conversation_id( + query_request.conversation_id + ) user_conversation = validate_conversation_ownership( user_id=user_id, - conversation_id=query_request.conversation_id, + conversation_id=normalized_conv_id_for_lookup, others_allowed=( Action.QUERY_OTHERS_CONVERSATIONS in request.state.authorized_actions ), diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py index 5e0a8c87c..cbaa3b987 100644 --- a/src/app/endpoints/query_v2.py +++ b/src/app/endpoints/query_v2.py @@ -439,15 +439,13 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche conversation_id, ) - # Normalize conversation ID before returning (remove conv_ prefix for consistency) - normalized_conversation_id = ( - normalize_conversation_id(conversation_id) - if conversation_id - else conversation_id + return ( + summary, + normalize_conversation_id(conversation_id), + referenced_documents, + token_usage, ) - return (summary, normalized_conversation_id, referenced_documents, token_usage) - def parse_referenced_documents_from_responses_api( response: OpenAIResponseObject, # pylint: disable=unused-argument diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py index eb4e73c5a..aac676f23 100644 --- a/src/app/endpoints/streaming_query_v2.py +++ b/src/app/endpoints/streaming_query_v2.py @@ -456,6 +456,6 @@ async def retrieve_response( # pylint: disable=too-many-locals response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response) # async for chunk in response_stream: # logger.error("Chunk: %s", chunk.model_dump_json()) - # Return the normalized conversation_id (already normalized above) + # Return the normalized conversation_id # The response_generator will emit it in the start event - return response_stream, conversation_id + return response_stream, normalize_conversation_id(conversation_id) diff --git a/src/utils/suid.py b/src/utils/suid.py index 0c5742e5c..23d1b46af 100644 --- a/src/utils/suid.py +++ b/src/utils/suid.py @@ -23,75 +23,31 @@ def check_suid(suid: str) -> bool: Returns True if the string is a valid UUID or a llama-stack conversation ID. Parameters: - suid (str | bytes): UUID value to validate — accepts a UUID string, - its byte representation, or a llama-stack conversation ID (conv_xxx), - or a plain hex string (database format). + suid (str): UUID value to validate — accepts a UUID string, + or a llama-stack conversation ID (48-char hex, optionally with conv_ prefix). Notes: - Validation is performed by: - 1. For llama-stack conversation IDs starting with 'conv_': - - Strips the 'conv_' prefix - - Validates at least 32 hex characters follow (may have additional suffix) - - Extracts first 32 hex chars as the UUID part - - Converts to UUID format by inserting hyphens at standard positions - - Validates the resulting UUID structure - 2. For plain hex strings (database format, 32+ chars without conv_ prefix): - - Validates it's a valid hex string - - Extracts first 32 chars as UUID part - - Converts to UUID format and validates - 3. For standard UUIDs: attempts to construct uuid.UUID(suid) - Invalid formats or types result in False. + Validation accepts: + 1. Standard UUID format (e.g., '550e8400-e29b-41d4-a716-446655440000') + 2. 48-character hex string (llama-stack format) + 3. 'conv_' prefix + 48-character hex string (53 chars total) """ - try: - # Accept llama-stack conversation IDs (conv_ format) - if isinstance(suid, str) and suid.startswith("conv_"): - # Extract the hex string after 'conv_' - hex_part = suid[5:] # Remove 'conv_' prefix - - # Verify it's a valid hex string - # llama-stack may use 32 hex chars (UUID) or 36 hex chars (UUID + suffix) - if len(hex_part) < 32: - return False - - # Verify all characters are valid hex - try: - int(hex_part, 16) - except ValueError: - return False - - # Extract the first 32 hex characters (the UUID part) - uuid_hex = hex_part[:32] - - # Convert to UUID format with hyphens: 8-4-4-4-12 - uuid_str = ( - f"{uuid_hex[:8]}-{uuid_hex[8:12]}-{uuid_hex[12:16]}-" - f"{uuid_hex[16:20]}-{uuid_hex[20:]}" - ) - - # Validate it's a proper UUID - uuid.UUID(uuid_str) + if not isinstance(suid, str): + return False + + # Strip 'conv_' prefix if present + hex_part = suid[5:] if suid.startswith("conv_") else suid + + # Check for 48-char hex string (llama-stack conversation ID format) + if len(hex_part) == 48: + try: + int(hex_part, 16) return True + except ValueError: + return False - # Check if it's a plain hex string (database format without conv_ prefix) - if isinstance(suid, str) and len(suid) >= 32: - try: - int(suid, 16) - # Extract the first 32 hex characters (the UUID part) - uuid_hex = suid[:32] - - # Convert to UUID format with hyphens: 8-4-4-4-12 - uuid_str = ( - f"{uuid_hex[:8]}-{uuid_hex[8:12]}-{uuid_hex[12:16]}-" - f"{uuid_hex[16:20]}-{uuid_hex[20:]}" - ) - - # Validate it's a proper UUID - uuid.UUID(uuid_str) - return True - except ValueError: - pass # Not a valid hex string, try standard UUID validation - - # accepts strings and bytes only for UUID validation + # Check for standard UUID format + try: uuid.UUID(suid) return True except (ValueError, TypeError): diff --git a/tests/unit/utils/test_suid.py b/tests/unit/utils/test_suid.py index 65722d299..42fdb0724 100644 --- a/tests/unit/utils/test_suid.py +++ b/tests/unit/utils/test_suid.py @@ -1,5 +1,9 @@ """Unit tests for functions defined in utils.suid module.""" +from typing import Any + +import pytest + from utils import suid @@ -12,16 +16,62 @@ def test_get_suid(self) -> None: assert suid.check_suid(suid_value), "Generated SUID is not valid" assert isinstance(suid_value, str), "SUID should be a string" - def test_check_suid_valid(self) -> None: + def test_check_suid_valid_uuid(self) -> None: """Test that check_suid returns True for a valid UUID.""" valid_suid = "123e4567-e89b-12d3-a456-426614174000" + assert suid.check_suid(valid_suid), "check_suid should return True for UUID" + + def test_check_suid_valid_48char_hex(self) -> None: + """Test that check_suid returns True for a 48-char hex string.""" + valid_hex = "e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3c" + assert len(valid_hex) == 48 assert suid.check_suid( - valid_suid - ), "check_suid should return True for a valid SUID" + valid_hex + ), "check_suid should return True for 48-char hex" + + def test_check_suid_valid_conv_prefix(self) -> None: + """Test that check_suid returns True for conv_ + 48-char hex string.""" + valid_conv = "conv_e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3c" + assert len(valid_conv) == 53 + assert suid.check_suid( + valid_conv + ), "check_suid should return True for conv_ prefixed hex" + + def test_check_suid_invalid_string(self) -> None: + """Test that check_suid returns False for an invalid string.""" + assert not suid.check_suid("invalid-uuid") - def test_check_suid_invalid(self) -> None: - """Test that check_suid returns False for an invalid UUID.""" - invalid_suid = "invalid-uuid" + def test_check_suid_valid_32char_hex_uuid(self) -> None: + """Test that check_suid returns True for 32-char hex (valid UUID format).""" + # 32-char hex is a valid UUID format (without hyphens) + assert suid.check_suid("e6afd7aaa97b49ce8f4f96a801b07893") + + def test_check_suid_invalid_hex_wrong_length(self) -> None: + """Test that check_suid returns False for hex string with wrong length.""" + # 47 chars (not 48, not valid UUID) + assert not suid.check_suid("e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3") + # 49 chars (not 48, not valid UUID) + assert not suid.check_suid("e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3c1") + + def test_check_suid_invalid_conv_prefix_wrong_length(self) -> None: + """Test that check_suid returns False for conv_ with wrong hex length.""" + # conv_ + 47 chars (not 48) + assert not suid.check_suid( + "conv_e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3" + ) + # conv_ + 49 chars (not 48) assert not suid.check_suid( - invalid_suid - ), "check_suid should return False for an invalid SUID" + "conv_e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3c1" + ) + + def test_check_suid_invalid_non_hex_chars(self) -> None: + """Test that check_suid returns False for strings with non-hex characters.""" + # 48 chars but contains 'g' and 'z' + invalid_hex = "g6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53ezz" + assert len(invalid_hex) == 48 + assert not suid.check_suid(invalid_hex) + + @pytest.mark.parametrize("invalid_type", [None, 123, [], {}]) + def test_check_suid_invalid_type(self, invalid_type: Any) -> None: + """Test that check_suid returns False for non-string types.""" + assert not suid.check_suid(invalid_type)