[RHDHPAI-978] Topic summary of initial query #564
@@ -46,6 +46,8 @@
 from utils.endpoints import (
     check_configuration_loaded,
     get_agent,
+    get_topic_summary_system_prompt,
+    get_temp_agent,
     get_system_prompt,
     store_conversation_into_cache,
     validate_conversation_ownership,
@@ -98,7 +100,11 @@ def is_transcripts_enabled() -> bool:
 def persist_user_conversation_details(
-    user_id: str, conversation_id: str, model: str, provider_id: str
+    user_id: str,
+    conversation_id: str,
+    model: str,
+    provider_id: str,
+    topic_summary: Optional[str],
 ) -> None:

Comment on lines +103 to 108
Contributor
Guard against None for the non-null DB column topic_summary. UserConversation.topic_summary is defined as non-null with default="". Passing None here can violate NOT NULL. Normalize to an empty string at creation and modernize the type hint. Apply:

-def persist_user_conversation_details(
-    user_id: str,
-    conversation_id: str,
-    model: str,
-    provider_id: str,
-    topic_summary: Optional[str],
-) -> None:
+def persist_user_conversation_details(
+    user_id: str,
+    conversation_id: str,
+    model: str,
+    provider_id: str,
+    topic_summary: str | None,
+) -> None:
 @@
-            conversation = UserConversation(
+            conversation = UserConversation(
                 id=conversation_id,
                 user_id=user_id,
                 last_used_model=model,
                 last_used_provider=provider_id,
-                topic_summary=topic_summary,
+                topic_summary=topic_summary or "",
                 message_count=1,
             )

Also applies to: 116-123
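For readers outside the repository, a minimal sketch of the constraint this comment relies on, assuming a SQLAlchemy 2.0 declarative mapping (the class and column names follow the comment; the table name and everything else are illustrative, not the project's actual model file). A Python-side default does not protect against an explicit None:

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class UserConversation(Base):
    """Sketch only: just the column under discussion; the real model has more fields."""

    __tablename__ = "user_conversation"  # assumed table name

    id: Mapped[str] = mapped_column(primary_key=True)
    # Non-nullable with an empty-string default. The default only applies when the
    # attribute is left unset; passing topic_summary=None explicitly sends NULL to
    # the database and trips the NOT NULL constraint, hence `topic_summary or ""`.
    topic_summary: Mapped[str] = mapped_column(nullable=False, default="")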
     """Associate conversation to user in the database."""
     with get_session() as session:

@@ -112,6 +118,7 @@ def persist_user_conversation_details(
                 user_id=user_id,
                 last_used_model=model,
                 last_used_provider=provider_id,
+                topic_summary=topic_summary,
                 message_count=1,
             )
             session.add(conversation)
@@ -169,9 +176,42 @@ def evaluate_model_hints(
     return model_id, provider_id


+async def get_topic_summary(
+    question: str, client: AsyncLlamaStackClient, model_id: str
+) -> str:
+    """Get a topic summary for a question.
+
+    Args:
+        question: The question to be validated.
+        client: The AsyncLlamaStackClient to use for the request.
+        model_id: The ID of the model to use.
+    Returns:
+        str: The topic summary for the question.
+    """
+    topic_summary_system_prompt = get_topic_summary_system_prompt(configuration)
+    agent, session_id, _ = await get_temp_agent(
+        client, model_id, topic_summary_system_prompt
+    )
+    response = await agent.create_turn(
+        messages=[UserMessage(role="user", content=question)],
+        session_id=session_id,
+        stream=False,
+        toolgroups=None,
+    )
+    response = cast(Turn, response)
+    return (
+        interleaved_content_as_str(response.output_message.content)
+        if (
+            getattr(response, "output_message", None) is not None
+            and getattr(response.output_message, "content", None) is not None
+        )
+        else ""
+    )
+
+
 @router.post("/query", responses=query_response)
 @authorize(Action.QUERY)
-async def query_endpoint_handler(
+async def query_endpoint_handler(  # pylint: disable=R0914
     request: Request,
     query_request: QueryRequest,
     auth: Annotated[AuthTuple, Depends(auth_dependency)],
@@ -200,7 +240,7 @@ async def query_endpoint_handler(
     # log Llama Stack configuration
     logger.info("Llama stack config: %s", configuration.llama_stack_configuration)

-    user_id, _, _, token = auth
+    user_id, _, _skip_userid_check, token = auth

     user_conversation: UserConversation | None = None
     if query_request.conversation_id:
@@ -251,6 +291,16 @@ async def query_endpoint_handler(
     # Update metrics for the LLM call
     metrics.llm_calls_total.labels(provider_id, model_id).inc()

+    # Get the initial topic summary for the conversation
+    topic_summary = None
+    with get_session() as session:
+        existing_conversation = (
+            session.query(UserConversation).filter_by(id=conversation_id).first()
+        )
+        if not existing_conversation:
+            topic_summary = await get_topic_summary(
+                query_request.query, client, model_id
+            )

Comment on lines +294 to +303
Contributor
Use llama_stack_model_id instead of model_id for topic summary generation. The code passes model_id here, while the handler already derives llama_stack_model_id for its inference calls; use the same fully qualified identifier so the topic summary is generated by the intended model. Apply this diff:

     # Get the initial topic summary for the conversation
     topic_summary = None
     with get_session() as session:
         existing_conversation = (
             session.query(UserConversation).filter_by(id=conversation_id).first()
         )
         if not existing_conversation:
             topic_summary = await get_topic_summary(
-                query_request.query, client, model_id
+                query_request.query, client, llama_stack_model_id
             )
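As a rough sketch of the distinction the reviewer draws, assuming the common Llama Stack convention of a provider-scoped model identifier (the exact composition in this handler is not shown in the excerpt, so the format below is an assumption):

# Illustrative values; in the handler they come from the request and configuration.
provider_id = "openai"
model_id = "gpt-4o-mini"

# Assumed convention: the identifier known to Llama Stack combines provider and
# model. If the main inference path resolves the model through this form, the
# topic-summary call should use the same identifier rather than the bare model_id.
llama_stack_model_id = f"{provider_id}/{model_id}"

# After the suggested change, the call would read (sketch, not the PR's literal code):
# topic_summary = await get_topic_summary(query_request.query, client, llama_stack_model_id)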
     # Convert RAG chunks to dictionary format once for reuse
     logger.info("Processing RAG chunks...")
     rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks]
@@ -278,6 +328,7 @@ async def query_endpoint_handler(
         conversation_id=conversation_id,
         model=model_id,
         provider_id=provider_id,
+        topic_summary=topic_summary,
     )

     store_conversation_into_cache(
@@ -288,6 +339,8 @@ async def query_endpoint_handler(
         model_id,
         query_request.query,
         summary.llm_response,
+        _skip_userid_check,
+        topic_summary,
     )

     # Convert tool calls to response format
Contributor
Fix type inconsistency in example data. The last_message_timestamp field is defined as type number in the ConversationData schema (line 1442), but the example shows a string value "2024-01-01T00:00:00Z". This inconsistency could mislead API consumers. Apply this diff to use a numeric timestamp:

   "conversations": [
     {
       "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
       "topic_summary": "This is a topic summary",
-      "last_message_timestamp": "2024-01-01T00:00:00Z"
+      "last_message_timestamp": 1704067200
     }
   ]

Alternatively, if timestamps should be ISO 8601 strings throughout the API, update the schema definition at line 1442 to use "type": "string" with "format": "date-time".
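If the numeric schema is kept, the replacement value in the suggestion is simply the original ISO 8601 example expressed as Unix epoch seconds; a quick standard-library check confirms it:

from datetime import datetime, timezone

# "2024-01-01T00:00:00Z" expressed as Unix epoch seconds, matching the numeric
# `last_message_timestamp` type declared in the ConversationData schema.
epoch_seconds = int(datetime(2024, 1, 1, tzinfo=timezone.utc).timestamp())
print(epoch_seconds)  # 1704067200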