diff --git a/chat_api/chats/chats_services.py b/chat_api/chats/chats_services.py index c0c3e1e..7b92c73 100644 --- a/chat_api/chats/chats_services.py +++ b/chat_api/chats/chats_services.py @@ -6,6 +6,7 @@ from chat_api.chats.chats_reponse_model import ChatRequest, ChatUserQuery, chatRequestPayload, ChatResponsePayload from chat_api.error_contant import ErrorConstant,ResponseError from fastapi import HTTPException +from chat_api.threads.thread_response_model import ThreadResponse from chat_api.chats.models import Chat from chat_api.db import SessionLocal @@ -13,6 +14,10 @@ from chat_api.threads.thread_service import create_thread from chat_api.threads.threads_request_model import ThreadCreateRequest + +from chat_api.threads.thread_service import get_thread_by_id + + from chat_api.auth_utils import get_user_email_from_token def merge_token_items(chat_list: list) -> list: @@ -39,8 +44,13 @@ def merge_token_items(chat_list: list) -> list: async def get_chat_stream(token: str, chat_request: ChatRequest): email = get_user_email_from_token(token) - user_query = ChatUserQuery(role="user", content=chat_request.query) - chat_request_payload = chatRequestPayload(messages=[user_query]).model_dump() + if chat_request.thread_id is not None: + thread: ThreadResponse = await get_thread_by_id(chat_request.thread_id) + user_query_payload = get_previous_chats(thread) + else: + user_query_payload = chatRequestPayload(messages=[ChatUserQuery(role="user", content=chat_request.query)]) + + chat_request_payload = user_query_payload.model_dump() url = get("OPENPECHA_AI_URL") chat_list = [] @@ -85,4 +95,9 @@ def sse_frame_from_line( payload = line[len("data:") :].strip() if line.startswith("data:") else line on_json(json.loads(payload)) return (f"data: {payload}\n\n").encode("utf-8") - \ No newline at end of file + +def get_previous_chats(thread: ThreadResponse) -> chatRequestPayload: + messages = thread.messages + messages = [ChatUserQuery(role=message.role, content=message.content) for message in messages] + chat_request_payload = chatRequestPayload(messages=messages) + return chat_request_payload \ No newline at end of file diff --git a/tests/chats/test_chats_services.py b/tests/chats/test_chats_services.py index 7a7c90e..d42cf92 100644 --- a/tests/chats/test_chats_services.py +++ b/tests/chats/test_chats_services.py @@ -127,11 +127,18 @@ async def _collect(): @patch("chat_api.chats.chats_services.save_chat") @patch("chat_api.chats.chats_services.SessionLocal") @patch("chat_api.chats.chats_services.create_thread") +@patch("chat_api.chats.chats_services.get_thread_by_id") @patch("chat_api.chats.chats_services.httpx.AsyncClient") def test_get_chat_stream_uses_existing_thread_id( - mock_async_client, mock_create_thread, mock_sessionlocal, mock_save_chat + mock_async_client, mock_get_thread_by_id, mock_create_thread, mock_sessionlocal, mock_save_chat ) -> None: existing_thread_id = str(uuid4()) + mock_get_thread_by_id.return_value = MagicMock( + messages=[ + MagicMock(role="user", content="previous question"), + MagicMock(role="assistant", content="previous answer"), + ] + ) stream_response = MagicMock() @@ -170,9 +177,19 @@ async def _collect(): # Should not create a new thread mock_create_thread.assert_not_called() + mock_get_thread_by_id.assert_awaited_once() mock_save_chat.assert_called_once() assert any(existing_thread_id.encode("utf-8") in c for c in chunks) + # Should use previous thread messages (not just current query) as the outbound payload + _, _, call_kwargs = client_instance.stream.mock_calls[0] + assert call_kwargs["json"] == { + "messages": [ + {"role": "user", "content": "previous question"}, + {"role": "assistant", "content": "previous answer"}, + ] + } + @patch("chat_api.chats.chats_services.save_chat") @patch("chat_api.chats.chats_services.SessionLocal") diff --git a/tests/threads/test_threads_services.py b/tests/threads/test_threads_services.py index 63968a7..ce5c4ba 100644 --- a/tests/threads/test_threads_services.py +++ b/tests/threads/test_threads_services.py @@ -7,8 +7,8 @@ from chat_api.threads.models import DeviceType from chat_api.threads.threads_request_model import ThreadCreateRequest -from chat_api.threads.thread_service import create_thread, ThreadService -from chat_api.threads.thread_repository import ThreadRepository +from chat_api.threads.thread_enums import MessageRole +from chat_api.threads.thread_service import create_thread, transform_chats_to_messages def _sessionlocal_cm(session): @@ -78,9 +78,6 @@ def test_create_thread_application_not_found( def test_transform_chats_to_messages_with_new_format() -> None: """Test transform_chats_to_messages with new list-based response format.""" - mock_repository = MagicMock(spec=ThreadRepository) - service = ThreadService(mock_repository) - chat_id = uuid4() mock_chat = MagicMock() mock_chat.id = chat_id @@ -112,18 +109,18 @@ def test_transform_chats_to_messages_with_new_format() -> None: } ] - messages = service.transform_chats_to_messages([mock_chat]) + messages = transform_chats_to_messages([mock_chat]) assert len(messages) == 2 # user message + assistant message # Check user message - assert messages[0].role == "user" + assert messages[0].role == MessageRole.USER assert messages[0].content == "what is emptiness" assert messages[0].id == chat_id assert messages[0].searchResults is None # Check assistant message - assert messages[1].role == "assistant" + assert messages[1].role == MessageRole.ASSISTANT assert messages[1].content == "Emptiness refers to the lack of inherent existence." assert messages[1].id == chat_id assert messages[1].searchResults is not None @@ -135,9 +132,6 @@ def test_transform_chats_to_messages_with_new_format() -> None: def test_transform_chats_to_messages_with_old_format() -> None: """Test transform_chats_to_messages with old dict-based response format for backward compatibility.""" - mock_repository = MagicMock(spec=ThreadRepository) - service = ThreadService(mock_repository) - chat_id = uuid4() mock_chat = MagicMock() mock_chat.id = chat_id @@ -154,7 +148,7 @@ def test_transform_chats_to_messages_with_old_format() -> None: ] } - messages = service.transform_chats_to_messages([mock_chat]) + messages = transform_chats_to_messages([mock_chat]) assert len(messages) == 2 assert messages[1].content == "test answer" @@ -165,9 +159,6 @@ def test_transform_chats_to_messages_with_old_format() -> None: def test_transform_chats_to_messages_no_search_results() -> None: """Test transform_chats_to_messages when there are no search results.""" - mock_repository = MagicMock(spec=ThreadRepository) - service = ThreadService(mock_repository) - chat_id = uuid4() mock_chat = MagicMock() mock_chat.id = chat_id @@ -184,7 +175,7 @@ def test_transform_chats_to_messages_no_search_results() -> None: } ] - messages = service.transform_chats_to_messages([mock_chat]) + messages = transform_chats_to_messages([mock_chat]) assert len(messages) == 2 assert messages[1].content == "Simple answer"