Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions scripts/generate_openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
# it is needed to read proper configuration in order to start the app to generate schema
from configuration import configuration

from client import LlamaStackClientHolder
from client import AsyncLlamaStackClientHolder

cfg_file = "lightspeed-stack.yaml"
configuration.load_configuration(cfg_file)

# Llama Stack client needs to be loaded before REST API is fully initialized
LlamaStackClientHolder().load(configuration.configuration.llama_stack)
import asyncio # noqa: E402

asyncio.run(AsyncLlamaStackClientHolder().load(configuration.configuration.llama_stack))

from app.main import app # noqa: E402 pylint: disable=C0413

Expand Down
9 changes: 5 additions & 4 deletions src/app/endpoints/authorized.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""Handler for REST API call to authorized endpoint."""

import asyncio
import logging
from typing import Any

from fastapi import APIRouter, Request
from fastapi import APIRouter, Depends

from auth import get_auth_dependency
from models.responses import AuthorizedResponse, UnauthorizedResponse, ForbiddenResponse
Expand All @@ -31,8 +30,10 @@


@router.post("/authorized", responses=authorized_responses)
def authorized_endpoint_handler(_request: Request) -> AuthorizedResponse:
async def authorized_endpoint_handler(
auth: Any = Depends(auth_dependency),
) -> AuthorizedResponse:
"""Handle request to the /authorized endpoint."""
# Ignore the user token, we should not return it in the response
user_id, user_name, _ = asyncio.run(auth_dependency(_request))
user_id, user_name, _ = auth
return AuthorizedResponse(user_id=user_id, username=user_name)
16 changes: 9 additions & 7 deletions src/app/endpoints/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fastapi import APIRouter, HTTPException, status, Depends

from client import LlamaStackClientHolder
from client import AsyncLlamaStackClientHolder
from configuration import configuration
from models.responses import ConversationResponse, ConversationDeleteResponse
from auth import get_auth_dependency
Expand Down Expand Up @@ -110,7 +110,7 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:


@router.get("/conversations/{conversation_id}", responses=conversation_responses)
def get_conversation_endpoint_handler(
async def get_conversation_endpoint_handler(
conversation_id: str,
_auth: Any = Depends(auth_dependency),
) -> ConversationResponse:
Expand All @@ -132,9 +132,9 @@ def get_conversation_endpoint_handler(
logger.info("Retrieving conversation %s", conversation_id)

try:
client = LlamaStackClientHolder().get_client()
client = AsyncLlamaStackClientHolder().get_client()

session_data = client.agents.session.list(agent_id=agent_id).data[0]
session_data = (await client.agents.session.list(agent_id=agent_id)).data[0]

logger.info("Successfully retrieved conversation %s", conversation_id)

Expand Down Expand Up @@ -179,7 +179,7 @@ def get_conversation_endpoint_handler(
@router.delete(
"/conversations/{conversation_id}", responses=conversation_delete_responses
)
def delete_conversation_endpoint_handler(
async def delete_conversation_endpoint_handler(
conversation_id: str,
_auth: Any = Depends(auth_dependency),
) -> ConversationDeleteResponse:
Expand All @@ -201,10 +201,12 @@ def delete_conversation_endpoint_handler(

try:
# Get Llama Stack client
client = LlamaStackClientHolder().get_client()
client = AsyncLlamaStackClientHolder().get_client()
# Delete session using the conversation_id as session_id
# In this implementation, conversation_id and session_id are the same
client.agents.session.delete(agent_id=agent_id, session_id=conversation_id)
await client.agents.session.delete(
agent_id=agent_id, session_id=conversation_id
)

logger.info("Successfully deleted conversation %s", conversation_id)

Expand Down
8 changes: 4 additions & 4 deletions src/app/endpoints/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from llama_stack_client import APIConnectionError
from fastapi import APIRouter, HTTPException, Request, status

from client import LlamaStackClientHolder
from client import AsyncLlamaStackClientHolder
from configuration import configuration
from models.responses import ModelsResponse
from utils.endpoints import check_configuration_loaded
Expand Down Expand Up @@ -43,7 +43,7 @@


@router.get("/models", responses=models_responses)
def models_endpoint_handler(_request: Request) -> ModelsResponse:
async def models_endpoint_handler(_request: Request) -> ModelsResponse:
"""Handle requests to the /models endpoint."""
check_configuration_loaded(configuration)

Expand All @@ -52,9 +52,9 @@ def models_endpoint_handler(_request: Request) -> ModelsResponse:

try:
# try to get Llama Stack client
client = LlamaStackClientHolder().get_client()
client = AsyncLlamaStackClientHolder().get_client()
# retrieve models
models = client.models.list()
models = await client.models.list()
m = [dict(m) for m in models]
return ModelsResponse(models=m)

Expand Down
85 changes: 21 additions & 64 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
"""Handler for REST API call to provide answer to query."""

from contextlib import suppress
from datetime import datetime, UTC
import json
import logging
import os
from pathlib import Path
from typing import Annotated, Any

from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client import APIConnectionError
from llama_stack_client import LlamaStackClient # type: ignore
from llama_stack_client import AsyncLlamaStackClient # type: ignore
from llama_stack_client.types import UserMessage, Shield # type: ignore
from llama_stack_client.types.agents.turn_create_params import (
ToolgroupAgentToolGroupWithArgs,
Expand All @@ -20,18 +18,17 @@

from fastapi import APIRouter, HTTPException, status, Depends

from client import LlamaStackClientHolder
from auth import get_auth_dependency
from auth.interface import AuthTuple
from client import AsyncLlamaStackClientHolder
from configuration import configuration
import metrics
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
from models.requests import QueryRequest, Attachment
import constants
from auth import get_auth_dependency
from auth.interface import AuthTuple
from utils.endpoints import check_configuration_loaded, get_system_prompt
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.suid import get_suid
from utils.types import GraniteToolParser

logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["query"])
Expand Down Expand Up @@ -68,53 +65,8 @@ def is_transcripts_enabled() -> bool:
return configuration.user_data_collection_configuration.transcripts_enabled


def get_agent(  # pylint: disable=too-many-arguments,too-many-positional-arguments
    client: LlamaStackClient,
    model_id: str,
    system_prompt: str,
    available_input_shields: list[str],
    available_output_shields: list[str],
    conversation_id: str | None,
    no_tools: bool = False,
) -> tuple[Agent, str, str]:
    """Get existing agent or create a new one with session persistence.

    Returns a tuple of (agent, conversation_id, session_id). When the
    server already knows an agent for ``conversation_id``, that id is
    adopted and its first persisted session reused; otherwise a brand new
    conversation/session pair is started.
    """
    # Probe the server for an agent matching the caller's conversation id;
    # a ValueError from retrieve() means "not found" and is swallowed.
    previous_agent_id = None
    if conversation_id:
        try:
            previous_agent_id = client.agents.retrieve(
                agent_id=conversation_id
            ).agent_id
        except ValueError:
            pass

    logger.debug("Creating new agent")
    # TODO(lucasagomes): move to ReActAgent
    agent = Agent(
        client,
        model=model_id,
        instructions=system_prompt,
        input_shields=available_input_shields or [],
        output_shields=available_output_shields or [],
        tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id),
        enable_session_persistence=True,
    )

    agent.initialize()

    if previous_agent_id and conversation_id:
        # Resuming: give the fresh agent the old conversation's id, delete
        # the agent that was just registered (now an orphan), and pick up
        # the first persisted session of the original conversation.
        orphan_agent_id = agent.agent_id
        agent.agent_id = conversation_id
        client.agents.delete(agent_id=orphan_agent_id)
        sessions_response = client.agents.session.list(agent_id=conversation_id)
        logger.info("session response: %s", sessions_response)
        session_id = str(sessions_response.data[0]["session_id"])
    else:
        # Fresh conversation: the new agent's id doubles as the conversation id.
        conversation_id = agent.agent_id
        session_id = agent.create_session(get_suid())

    return agent, conversation_id, session_id


@router.post("/query", responses=query_response)
def query_endpoint_handler(
async def query_endpoint_handler(
query_request: QueryRequest,
auth: Annotated[AuthTuple, Depends(auth_dependency)],
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
Expand All @@ -129,11 +81,11 @@ def query_endpoint_handler(

try:
# try to get Llama Stack client
client = LlamaStackClientHolder().get_client()
client = AsyncLlamaStackClientHolder().get_client()
model_id, provider_id = select_model_and_provider_id(
client.models.list(), query_request
await client.models.list(), query_request
)
response, conversation_id = retrieve_response(
response, conversation_id = await retrieve_response(
client,
model_id,
query_request,
Expand Down Expand Up @@ -253,19 +205,21 @@ def is_input_shield(shield: Shield) -> bool:
return _is_inout_shield(shield) or not is_output_shield(shield)


def retrieve_response( # pylint: disable=too-many-locals
client: LlamaStackClient,
async def retrieve_response( # pylint: disable=too-many-locals
client: AsyncLlamaStackClient,
model_id: str,
query_request: QueryRequest,
token: str,
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
"""Retrieve response from LLMs and agents."""
available_input_shields = [
shield.identifier for shield in filter(is_input_shield, client.shields.list())
shield.identifier
for shield in filter(is_input_shield, await client.shields.list())
]
available_output_shields = [
shield.identifier for shield in filter(is_output_shield, client.shields.list())
shield.identifier
for shield in filter(is_output_shield, await client.shields.list())
]
if not available_input_shields and not available_output_shields:
logger.info("No available shields. Disabling safety")
Expand All @@ -284,7 +238,7 @@ def retrieve_response( # pylint: disable=too-many-locals
if query_request.attachments:
validate_attachments_metadata(query_request.attachments)

agent, conversation_id, session_id = get_agent(
agent, conversation_id, session_id = await get_agent(
client,
model_id,
system_prompt,
Expand All @@ -294,6 +248,7 @@ def retrieve_response( # pylint: disable=too-many-locals
query_request.no_tools or False,
)

logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id)
# bypass tools and MCP servers if no_tools is True
if query_request.no_tools:
mcp_headers = {}
Expand All @@ -318,15 +273,17 @@ def retrieve_response( # pylint: disable=too-many-locals
),
}

vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
vector_db_ids = [
vector_db.identifier for vector_db in await client.vector_dbs.list()
]
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]
# Convert empty list to None for consistency with existing behavior
if not toolgroups:
toolgroups = None

response = agent.create_turn(
response = await agent.create_turn(
messages=[UserMessage(role="user", content=query_request.query)],
session_id=session_id,
documents=query_request.get_documents(),
Expand Down
53 changes: 2 additions & 51 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
"""Handler for REST API call to provide answer to streaming query."""

import ast
from contextlib import suppress
import json
import re
import logging
from typing import Annotated, Any, AsyncIterator, Iterator

from llama_stack_client import APIConnectionError
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
from llama_stack_client import AsyncLlamaStackClient # type: ignore
from llama_stack_client.types import UserMessage # type: ignore

Expand All @@ -25,10 +23,8 @@
from configuration import configuration
import metrics
from models.requests import QueryRequest
from utils.endpoints import check_configuration_loaded, get_system_prompt
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.suid import get_suid
from utils.types import GraniteToolParser

from app.endpoints.query import (
get_rag_toolgroups,
Expand All @@ -45,50 +41,6 @@
auth_dependency = get_auth_dependency()


# # pylint: disable=R0913,R0917
async def get_agent(
    client: AsyncLlamaStackClient,
    model_id: str,
    system_prompt: str,
    available_input_shields: list[str],
    available_output_shields: list[str],
    conversation_id: str | None,
    no_tools: bool = False,
) -> tuple[AsyncAgent, str, str]:
    """Get existing agent or create a new one with session persistence.

    Args:
        client: Async Llama Stack client used for all agent/session calls.
        model_id: Identifier of the model the agent should run.
        system_prompt: Instructions passed to the agent.
        available_input_shields: Shield identifiers applied to user input.
        available_output_shields: Shield identifiers applied to model output.
        conversation_id: Existing conversation (agent) id, or None for a new one.
        no_tools: When True, no tool parser is attached to the agent.

    Returns:
        Tuple of (agent, conversation_id, session_id).
    """
    # Probe the server for an agent matching the conversation id; retrieve()
    # raising ValueError means "not found" and is deliberately suppressed.
    existing_agent_id = None
    if conversation_id:
        with suppress(ValueError):
            agent_response = await client.agents.retrieve(agent_id=conversation_id)
            existing_agent_id = agent_response.agent_id

    # A new AsyncAgent object is always constructed, even when resuming.
    logger.debug("Creating new agent")
    agent = AsyncAgent(
        client,  # type: ignore[arg-type]
        model=model_id,
        instructions=system_prompt,
        input_shields=available_input_shields if available_input_shields else [],
        output_shields=available_output_shields if available_output_shields else [],
        tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id),
        enable_session_persistence=True,
    )

    await agent.initialize()

    if existing_agent_id and conversation_id:
        # Resuming: adopt the prior agent id on the new object (private
        # attribute -- AsyncAgent exposes no public setter here), delete the
        # just-registered orphan agent, and reuse the first persisted session.
        orphan_agent_id = agent.agent_id
        agent._agent_id = conversation_id  # type: ignore[assignment] # pylint: disable=protected-access
        await client.agents.delete(agent_id=orphan_agent_id)
        sessions_response = await client.agents.session.list(agent_id=conversation_id)
        logger.info("session response: %s", sessions_response)
        session_id = str(sessions_response.data[0]["session_id"])
    else:
        # Fresh conversation: the new agent's id becomes the conversation id.
        conversation_id = agent.agent_id
        session_id = await agent.create_session(get_suid())

    return agent, conversation_id, session_id


METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n")


Expand Down Expand Up @@ -556,8 +508,7 @@ async def retrieve_response(
query_request.no_tools or False,
)

logger.debug("Session ID: %s", conversation_id)

logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id)
# bypass tools and MCP servers if no_tools is True
if query_request.no_tools:
mcp_headers = {}
Expand Down
Loading