Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
consume_tokens,
get_available_quotas,
)
from utils.suid import normalize_conversation_id
from utils.token_counter import TokenCounter, extract_and_update_token_metrics
from utils.transcripts import store_transcript
from utils.types import TurnSummary, content_to_str
Expand Down Expand Up @@ -109,14 +110,23 @@ def persist_user_conversation_details(
topic_summary: Optional[str],
) -> None:
"""Associate conversation to user in the database."""
# Normalize the conversation ID (strip 'conv_' prefix if present)
normalized_id = normalize_conversation_id(conversation_id)
logger.debug(
"persist_user_conversation_details - original conv_id: %s, normalized: %s, user: %s",
conversation_id,
normalized_id,
user_id,
)

with get_session() as session:
existing_conversation = (
session.query(UserConversation).filter_by(id=conversation_id).first()
session.query(UserConversation).filter_by(id=normalized_id).first()
)

if not existing_conversation:
conversation = UserConversation(
id=conversation_id,
id=normalized_id,
user_id=user_id,
last_used_model=model,
last_used_provider=provider_id,
Expand All @@ -125,15 +135,24 @@ def persist_user_conversation_details(
)
session.add(conversation)
logger.debug(
"Associated conversation %s to user %s", conversation_id, user_id
"Associated conversation %s to user %s", normalized_id, user_id
)
else:
existing_conversation.last_used_model = model
existing_conversation.last_used_provider = provider_id
existing_conversation.last_message_at = datetime.now(UTC)
existing_conversation.message_count += 1
logger.debug(
"Updating existing conversation in DB - ID: %s, User: %s, Messages: %d",
normalized_id,
user_id,
existing_conversation.message_count,
)

session.commit()
logger.debug(
"Successfully committed conversation %s to database", normalized_id
)


def evaluate_model_hints(
Expand Down Expand Up @@ -257,9 +276,13 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
logger.debug(
"Conversation ID specified in query: %s", query_request.conversation_id
)
# Normalize the conversation ID for database lookup (strip conv_ prefix if present)
normalized_conv_id_for_lookup = normalize_conversation_id(
query_request.conversation_id
)
user_conversation = validate_conversation_ownership(
user_id=user_id,
conversation_id=query_request.conversation_id,
conversation_id=normalized_conv_id_for_lookup,
others_allowed=(
Action.QUERY_OTHERS_CONVERSATIONS in request.state.authorized_actions
),
Expand Down
12 changes: 5 additions & 7 deletions src/app/endpoints/query_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,15 +439,13 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
conversation_id,
)

# Normalize conversation ID before returning (remove conv_ prefix for consistency)
normalized_conversation_id = (
normalize_conversation_id(conversation_id)
if conversation_id
else conversation_id
return (
summary,
normalize_conversation_id(conversation_id),
referenced_documents,
token_usage,
)

return (summary, normalized_conversation_id, referenced_documents, token_usage)


def parse_referenced_documents_from_responses_api(
response: OpenAIResponseObject, # pylint: disable=unused-argument
Expand Down
4 changes: 2 additions & 2 deletions src/app/endpoints/streaming_query_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,6 @@ async def retrieve_response( # pylint: disable=too-many-locals
response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response)
# async for chunk in response_stream:
# logger.error("Chunk: %s", chunk.model_dump_json())
# Return the normalized conversation_id (already normalized above)
# Return the normalized conversation_id
# The response_generator will emit it in the start event
return response_stream, conversation_id
return response_stream, normalize_conversation_id(conversation_id)
84 changes: 20 additions & 64 deletions src/utils/suid.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,75 +23,31 @@ def check_suid(suid: str) -> bool:
Returns True if the string is a valid UUID or a llama-stack conversation ID.

Parameters:
suid (str | bytes): UUID value to validate — accepts a UUID string,
its byte representation, or a llama-stack conversation ID (conv_xxx),
or a plain hex string (database format).
suid (str): UUID value to validate — accepts a UUID string,
or a llama-stack conversation ID (48-char hex, optionally with conv_ prefix).

Notes:
Validation is performed by:
1. For llama-stack conversation IDs starting with 'conv_':
- Strips the 'conv_' prefix
- Validates at least 32 hex characters follow (may have additional suffix)
- Extracts first 32 hex chars as the UUID part
- Converts to UUID format by inserting hyphens at standard positions
- Validates the resulting UUID structure
2. For plain hex strings (database format, 32+ chars without conv_ prefix):
- Validates it's a valid hex string
- Extracts first 32 chars as UUID part
- Converts to UUID format and validates
3. For standard UUIDs: attempts to construct uuid.UUID(suid)
Invalid formats or types result in False.
Validation accepts:
1. Standard UUID format (e.g., '550e8400-e29b-41d4-a716-446655440000')
2. 48-character hex string (llama-stack format)
3. 'conv_' prefix + 48-character hex string (53 chars total)
"""
try:
# Accept llama-stack conversation IDs (conv_<hex> format)
if isinstance(suid, str) and suid.startswith("conv_"):
# Extract the hex string after 'conv_'
hex_part = suid[5:] # Remove 'conv_' prefix

# Verify it's a valid hex string
# llama-stack may use 32 hex chars (UUID) or 36 hex chars (UUID + suffix)
if len(hex_part) < 32:
return False

# Verify all characters are valid hex
try:
int(hex_part, 16)
except ValueError:
return False

# Extract the first 32 hex characters (the UUID part)
uuid_hex = hex_part[:32]

# Convert to UUID format with hyphens: 8-4-4-4-12
uuid_str = (
f"{uuid_hex[:8]}-{uuid_hex[8:12]}-{uuid_hex[12:16]}-"
f"{uuid_hex[16:20]}-{uuid_hex[20:]}"
)

# Validate it's a proper UUID
uuid.UUID(uuid_str)
if not isinstance(suid, str):
return False

# Strip 'conv_' prefix if present
hex_part = suid[5:] if suid.startswith("conv_") else suid

# Check for 48-char hex string (llama-stack conversation ID format)
if len(hex_part) == 48:
try:
int(hex_part, 16)
return True
except ValueError:
return False

# Check if it's a plain hex string (database format without conv_ prefix)
if isinstance(suid, str) and len(suid) >= 32:
try:
int(suid, 16)
# Extract the first 32 hex characters (the UUID part)
uuid_hex = suid[:32]

# Convert to UUID format with hyphens: 8-4-4-4-12
uuid_str = (
f"{uuid_hex[:8]}-{uuid_hex[8:12]}-{uuid_hex[12:16]}-"
f"{uuid_hex[16:20]}-{uuid_hex[20:]}"
)

# Validate it's a proper UUID
uuid.UUID(uuid_str)
return True
except ValueError:
pass # Not a valid hex string, try standard UUID validation

# accepts strings and bytes only for UUID validation
# Check for standard UUID format
try:
uuid.UUID(suid)
return True
except (ValueError, TypeError):
Expand Down
66 changes: 58 additions & 8 deletions tests/unit/utils/test_suid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
"""Unit tests for functions defined in utils.suid module."""

from typing import Any

import pytest

from utils import suid


Expand All @@ -12,16 +16,62 @@ def test_get_suid(self) -> None:
assert suid.check_suid(suid_value), "Generated SUID is not valid"
assert isinstance(suid_value, str), "SUID should be a string"

def test_check_suid_valid(self) -> None:
def test_check_suid_valid_uuid(self) -> None:
"""Test that check_suid returns True for a valid UUID."""
valid_suid = "123e4567-e89b-12d3-a456-426614174000"
assert suid.check_suid(valid_suid), "check_suid should return True for UUID"

def test_check_suid_valid_48char_hex(self) -> None:
"""Test that check_suid returns True for a 48-char hex string."""
valid_hex = "e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3c"
assert len(valid_hex) == 48
assert suid.check_suid(
valid_suid
), "check_suid should return True for a valid SUID"
valid_hex
), "check_suid should return True for 48-char hex"

def test_check_suid_valid_conv_prefix(self) -> None:
"""Test that check_suid returns True for conv_ + 48-char hex string."""
valid_conv = "conv_e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3c"
assert len(valid_conv) == 53
assert suid.check_suid(
valid_conv
), "check_suid should return True for conv_ prefixed hex"

def test_check_suid_invalid_string(self) -> None:
"""Test that check_suid returns False for an invalid string."""
assert not suid.check_suid("invalid-uuid")

def test_check_suid_invalid(self) -> None:
"""Test that check_suid returns False for an invalid UUID."""
invalid_suid = "invalid-uuid"
def test_check_suid_valid_32char_hex_uuid(self) -> None:
"""Test that check_suid returns True for 32-char hex (valid UUID format)."""
# 32-char hex is a valid UUID format (without hyphens)
assert suid.check_suid("e6afd7aaa97b49ce8f4f96a801b07893")

def test_check_suid_invalid_hex_wrong_length(self) -> None:
"""Test that check_suid returns False for hex string with wrong length."""
# 47 chars (not 48, not valid UUID)
assert not suid.check_suid("e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3")
# 49 chars (not 48, not valid UUID)
assert not suid.check_suid("e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3c1")

def test_check_suid_invalid_conv_prefix_wrong_length(self) -> None:
"""Test that check_suid returns False for conv_ with wrong hex length."""
# conv_ + 47 chars (not 48)
assert not suid.check_suid(
"conv_e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3"
)
# conv_ + 49 chars (not 48)
assert not suid.check_suid(
invalid_suid
), "check_suid should return False for an invalid SUID"
"conv_e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3c1"
)

def test_check_suid_invalid_non_hex_chars(self) -> None:
"""Test that check_suid returns False for strings with non-hex characters."""
# 48 chars but contains 'g' and 'z'
invalid_hex = "g6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53ezz"
assert len(invalid_hex) == 48
assert not suid.check_suid(invalid_hex)

@pytest.mark.parametrize("invalid_type", [None, 123, [], {}])
def test_check_suid_invalid_type(self, invalid_type: Any) -> None:
"""Test that check_suid returns False for non-string types."""
assert not suid.check_suid(invalid_type)
Loading