Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions chat_api/chats/chats_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
from chat_api.chats.chats_reponse_model import ChatRequest, ChatUserQuery, chatRequestPayload, ChatResponsePayload
from chat_api.error_contant import ErrorConstant,ResponseError
from fastapi import HTTPException
from chat_api.threads.thread_response_model import ThreadResponse

from chat_api.chats.models import Chat
from chat_api.db import SessionLocal
from chat_api.chats.chats_repository import save_chat

from chat_api.threads.thread_service import create_thread
from chat_api.threads.threads_request_model import ThreadCreateRequest

from chat_api.threads.thread_service import get_thread_by_id


from chat_api.auth_utils import get_user_email_from_token

def merge_token_items(chat_list: list) -> list:
Expand All @@ -39,8 +44,13 @@ def merge_token_items(chat_list: list) -> list:
async def get_chat_stream(token: str, chat_request: ChatRequest):
email = get_user_email_from_token(token)

user_query = ChatUserQuery(role="user", content=chat_request.query)
chat_request_payload = chatRequestPayload(messages=[user_query]).model_dump()
if chat_request.thread_id is not None:
thread: ThreadResponse = await get_thread_by_id(chat_request.thread_id)
user_query_payload = get_previous_chats(thread)
else:
user_query_payload = chatRequestPayload(messages=[ChatUserQuery(role="user", content=chat_request.query)])

chat_request_payload = user_query_payload.model_dump()
url = get("OPENPECHA_AI_URL")
chat_list = []

Expand Down Expand Up @@ -85,4 +95,9 @@ def sse_frame_from_line(
payload = line[len("data:") :].strip() if line.startswith("data:") else line
on_json(json.loads(payload))
return (f"data: {payload}\n\n").encode("utf-8")


def get_previous_chats(thread: ThreadResponse) -> chatRequestPayload:
messages = thread.messages
messages = [ChatUserQuery(role=message.role, content=message.content) for message in messages]
chat_request_payload = chatRequestPayload(messages=messages)
return chat_request_payload
19 changes: 18 additions & 1 deletion tests/chats/test_chats_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,18 @@ async def _collect():
@patch("chat_api.chats.chats_services.save_chat")
@patch("chat_api.chats.chats_services.SessionLocal")
@patch("chat_api.chats.chats_services.create_thread")
@patch("chat_api.chats.chats_services.get_thread_by_id")
@patch("chat_api.chats.chats_services.httpx.AsyncClient")
def test_get_chat_stream_uses_existing_thread_id(
mock_async_client, mock_create_thread, mock_sessionlocal, mock_save_chat
mock_async_client, mock_get_thread_by_id, mock_create_thread, mock_sessionlocal, mock_save_chat
) -> None:
existing_thread_id = str(uuid4())
mock_get_thread_by_id.return_value = MagicMock(
messages=[
MagicMock(role="user", content="previous question"),
MagicMock(role="assistant", content="previous answer"),
]
)

stream_response = MagicMock()

Expand Down Expand Up @@ -170,9 +177,19 @@ async def _collect():

# Should not create a new thread
mock_create_thread.assert_not_called()
mock_get_thread_by_id.assert_awaited_once()
mock_save_chat.assert_called_once()
assert any(existing_thread_id.encode("utf-8") in c for c in chunks)

# Should use previous thread messages (not just current query) as the outbound payload
_, _, call_kwargs = client_instance.stream.mock_calls[0]
assert call_kwargs["json"] == {
"messages": [
{"role": "user", "content": "previous question"},
{"role": "assistant", "content": "previous answer"},
]
}


@patch("chat_api.chats.chats_services.save_chat")
@patch("chat_api.chats.chats_services.SessionLocal")
Expand Down
23 changes: 7 additions & 16 deletions tests/threads/test_threads_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from chat_api.threads.models import DeviceType
from chat_api.threads.threads_request_model import ThreadCreateRequest
from chat_api.threads.thread_service import create_thread, ThreadService
from chat_api.threads.thread_repository import ThreadRepository
from chat_api.threads.thread_enums import MessageRole
from chat_api.threads.thread_service import create_thread, transform_chats_to_messages


def _sessionlocal_cm(session):
Expand Down Expand Up @@ -78,9 +78,6 @@ def test_create_thread_application_not_found(

def test_transform_chats_to_messages_with_new_format() -> None:
"""Test transform_chats_to_messages with new list-based response format."""
mock_repository = MagicMock(spec=ThreadRepository)
service = ThreadService(mock_repository)

chat_id = uuid4()
mock_chat = MagicMock()
mock_chat.id = chat_id
Expand Down Expand Up @@ -112,18 +109,18 @@ def test_transform_chats_to_messages_with_new_format() -> None:
}
]

messages = service.transform_chats_to_messages([mock_chat])
messages = transform_chats_to_messages([mock_chat])

assert len(messages) == 2 # user message + assistant message

# Check user message
assert messages[0].role == "user"
assert messages[0].role == MessageRole.USER
assert messages[0].content == "what is emptiness"
assert messages[0].id == chat_id
assert messages[0].searchResults is None

# Check assistant message
assert messages[1].role == "assistant"
assert messages[1].role == MessageRole.ASSISTANT
assert messages[1].content == "Emptiness refers to the lack of inherent existence."
assert messages[1].id == chat_id
assert messages[1].searchResults is not None
Expand All @@ -135,9 +132,6 @@ def test_transform_chats_to_messages_with_new_format() -> None:

def test_transform_chats_to_messages_with_old_format() -> None:
"""Test transform_chats_to_messages with old dict-based response format for backward compatibility."""
mock_repository = MagicMock(spec=ThreadRepository)
service = ThreadService(mock_repository)

chat_id = uuid4()
mock_chat = MagicMock()
mock_chat.id = chat_id
Expand All @@ -154,7 +148,7 @@ def test_transform_chats_to_messages_with_old_format() -> None:
]
}

messages = service.transform_chats_to_messages([mock_chat])
messages = transform_chats_to_messages([mock_chat])

assert len(messages) == 2
assert messages[1].content == "test answer"
Expand All @@ -165,9 +159,6 @@ def test_transform_chats_to_messages_with_old_format() -> None:

def test_transform_chats_to_messages_no_search_results() -> None:
"""Test transform_chats_to_messages when there are no search results."""
mock_repository = MagicMock(spec=ThreadRepository)
service = ThreadService(mock_repository)

chat_id = uuid4()
mock_chat = MagicMock()
mock_chat.id = chat_id
Expand All @@ -184,7 +175,7 @@ def test_transform_chats_to_messages_no_search_results() -> None:
}
]

messages = service.transform_chats_to_messages([mock_chat])
messages = transform_chats_to_messages([mock_chat])

assert len(messages) == 2
assert messages[1].content == "Simple answer"
Expand Down
Loading