Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions src/memos/api/handlers/chat_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import asyncio
import json
import os
import re
import time
import traceback
Expand All @@ -23,6 +24,7 @@
APIADDRequest,
APIChatCompleteRequest,
APISearchRequest,
ChatBusinessRequest,
ChatPlaygroundRequest,
ChatRequest,
)
Expand Down Expand Up @@ -759,6 +761,195 @@ def generate_chat_response() -> Generator[str, None, None]:
)
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err

def handle_chat_stream_for_business_user(
    self, chat_req: ChatBusinessRequest
) -> StreamingResponse:
    """Chat API for business users, streamed as Server-Sent Events.

    Flow:
      1. Authorize ``chat_req.business_key`` against the JSON list held in
         the ``BUSINESS_CHAT_KEYS`` environment variable.
      2. Validate the requested model (if any) against ``self.chat_llms``.
      3. Optionally (``need_search``) search memories and fold them into the
         system prompt; internet-sourced memories are always excluded.
      4. Stream LLM chunks as SSE events (``reasoning`` for <think> content,
         ``text`` otherwise) and, optionally, persist the exchange.

    Args:
        chat_req: The business chat request payload.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.

    Raises:
        HTTPException: 403 for an invalid business_key, 400 for an unknown
            model, 404 for a ValueError during setup, 500 otherwise.
    """
    self.logger.info(f"[ChatBusinessHandler] Chat Req is: {chat_req}")

    # Validate business_key permission. BUSINESS_CHAT_KEYS must be a JSON
    # list; a malformed value is treated as "nobody is allowed" instead of
    # failing every request with an unrelated 500.
    business_chat_keys = os.environ.get("BUSINESS_CHAT_KEYS", "[]")
    try:
        allowed_keys = json.loads(business_chat_keys)
    except json.JSONDecodeError:
        self.logger.error(
            "[ChatBusinessHandler] BUSINESS_CHAT_KEYS is not valid JSON; denying access"
        )
        allowed_keys = []

    if not allowed_keys or chat_req.business_key not in allowed_keys:
        self.logger.warning(
            f"[ChatBusinessHandler] Unauthorized access attempt with business_key: {chat_req.business_key}"
        )
        raise HTTPException(
            status_code=403,
            detail="Access denied: Invalid business_key. You do not have permission to use this service.",
        )

    # Validate the model BEFORE the stream starts, so the client receives a
    # real 400 response instead of a 200 SSE stream carrying an error
    # payload. (Also fixes the "not suport" typo in the error detail.)
    if chat_req.model_name_or_path and chat_req.model_name_or_path not in self.chat_llms:
        raise HTTPException(
            status_code=400,
            detail=f"Model {chat_req.model_name_or_path} not supported, choose from {list(self.chat_llms.keys())}",
        )

    try:

        def generate_chat_response() -> Generator[str, None, None]:
            """Yield the chat response as SSE-formatted chunks."""
            try:
                if chat_req.need_search:
                    # Resolve readable cube IDs (for search): explicit list,
                    # else the single mem cube, else the user's own id.
                    readable_cube_ids = chat_req.readable_cube_ids or (
                        [chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id]
                    )

                    search_req = APISearchRequest(
                        query=chat_req.query,
                        user_id=chat_req.user_id,
                        readable_cube_ids=readable_cube_ids,
                        mode=chat_req.mode,
                        internet_search=chat_req.internet_search,
                        top_k=chat_req.top_k,
                        chat_history=chat_req.history,
                        session_id=chat_req.session_id,
                        include_preference=chat_req.include_preference,
                        pref_top_k=chat_req.pref_top_k,
                        filter=chat_req.filter,
                    )

                    search_response = self.search_handler.handle_search_memories(search_req)

                    # Extract memories from search results.
                    memories_list = []
                    if search_response.data and search_response.data.get("text_mem"):
                        text_mem_results = search_response.data["text_mem"]
                        if text_mem_results and text_mem_results[0].get("memories"):
                            memories_list = text_mem_results[0]["memories"]

                    # Internet-derived memories are always dropped for
                    # business chat, regardless of search settings.
                    memories_list = [
                        mem
                        for mem in memories_list
                        if mem.get("metadata", {}).get("memory_type") != "OuterMemory"
                    ]

                    # Filter memories by relevance threshold.
                    filtered_memories = self._filter_memories_by_threshold(memories_list)

                    # Build the system prompt with the retrieved memories.
                    system_prompt = self._build_system_prompt(
                        query=chat_req.query,
                        memories=filtered_memories,
                        pref_string=search_response.data.get("pref_string", ""),
                        base_prompt=chat_req.system_prompt,
                    )

                    self.logger.info(
                        f"[ChatBusinessHandler] chat stream user_id: {chat_req.user_id}, readable_cube_ids: {readable_cube_ids}, "
                        f"current_system_prompt: {system_prompt}"
                    )
                else:
                    # No search requested: prompt from the base prompt only.
                    system_prompt = self._build_system_prompt(
                        query=chat_req.query,
                        memories=None,
                        pref_string=None,
                        base_prompt=chat_req.system_prompt,
                    )

                # Keep at most the 20 most recent history messages.
                history_info = chat_req.history[-20:] if chat_req.history else []
                current_messages = [
                    {"role": "system", "content": system_prompt},
                    *history_info,
                    {"role": "user", "content": chat_req.query},
                ]

                # Default to the first configured LLM when none is requested
                # (the explicit request was validated before streaming began).
                model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys()))
                self.logger.info(f"[ChatBusinessHandler] Chat Stream Model: {model}")

                start = time.time()
                response_stream = self.chat_llms[model].generate_stream(
                    current_messages, model_name_or_path=model
                )

                # Stream the response, routing <think>...</think> content to
                # 'reasoning' events and everything else to 'text' events.
                full_response = ""
                in_think = False

                for chunk in response_stream:
                    if chunk == "<think>":
                        in_think = True
                        continue
                    if chunk == "</think>":
                        in_think = False
                        continue

                    if in_think:
                        yield f"data: {json.dumps({'type': 'reasoning', 'data': chunk}, ensure_ascii=False)}\n\n"
                        continue

                    full_response += chunk
                    yield f"data: {json.dumps({'type': 'text', 'data': chunk}, ensure_ascii=False)}\n\n"

                end = time.time()
                self.logger.info(
                    f"[ChatBusinessHandler] Chat Stream Time: {end - start} seconds"
                )

                self.logger.info(
                    f"[ChatBusinessHandler] Chat Stream LLM Input: {json.dumps(current_messages, ensure_ascii=False)} Chat Stream LLM Response: {full_response}"
                )

                current_messages.append({"role": "assistant", "content": full_response})
                if chat_req.add_message_on_answer:
                    # Resolve writable cube IDs (for add), mirroring the
                    # readable-cube resolution above.
                    writable_cube_ids = chat_req.writable_cube_ids or (
                        [chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id]
                    )
                    start = time.time()
                    self._start_add_to_memory(
                        user_id=chat_req.user_id,
                        writable_cube_ids=writable_cube_ids,
                        session_id=chat_req.session_id or "default_session",
                        query=chat_req.query,
                        full_response=full_response,
                        async_mode="async",
                    )
                    end = time.time()
                    self.logger.info(
                        f"[ChatBusinessHandler] Chat Stream Add Time: {end - start} seconds"
                    )
            except Exception as e:
                # Headers are already sent by the time the generator runs,
                # so errors are reported in-band as an SSE 'error' event.
                self.logger.error(
                    f"[ChatBusinessHandler] Error in chat stream: {e}", exc_info=True
                )
                error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n"
                yield error_data

        return StreamingResponse(
            generate_chat_response(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream",
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "*",
                "Access-Control-Allow-Methods": "*",
            },
        )

    except ValueError as err:
        raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
    except Exception as err:
        self.logger.error(
            f"[ChatBusinessHandler] Failed to start chat stream: {traceback.format_exc()}"
        )
        raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err

def _dedup_and_supplement_memories(
self, first_filtered_memories: list, second_filtered_memories: list
) -> list:
Expand Down
5 changes: 3 additions & 2 deletions src/memos/api/handlers/memory_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def handle_get_subgraph(
query: str,
top_k: int,
naive_mem_cube: Any,
search_type: Literal["embedding", "fulltext"],
) -> MemoryResponse:
"""
Main handler for getting memory subgraph based on query.
Expand All @@ -128,7 +129,7 @@ def handle_get_subgraph(
try:
# Get relevant subgraph from text memory
memories = naive_mem_cube.text_mem.get_relevant_subgraph(
query, top_k=top_k, user_name=mem_cube_id
query, top_k=top_k, user_name=mem_cube_id, search_type=search_type
)

# Format and convert to tree structure
Expand All @@ -139,7 +140,7 @@ def handle_get_subgraph(
"UserMemory": 0.40,
}
tree_result, node_type_count = convert_graph_to_tree_forworkmem(
memories_cleaned, target_node_count=150, type_ratios=custom_type_ratios
memories_cleaned, target_node_count=200, type_ratios=custom_type_ratios
)
# Ensure all node IDs are unique in the tree structure
tree_result = ensure_unique_tree_ids(tree_result)
Expand Down
31 changes: 31 additions & 0 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class GetMemoryPlaygroundRequest(BaseRequest):
)
mem_cube_ids: list[str] | None = Field(None, description="Cube IDs")
search_query: str | None = Field(None, description="Search query")
search_type: Literal["embedding", "fulltext"] = Field("fulltext", description="Search type")


# Start API Models
Expand Down Expand Up @@ -167,6 +168,13 @@ class ChatPlaygroundRequest(ChatRequest):
)


class ChatBusinessRequest(ChatRequest):
    """Request model for chat operations for business user."""

    # Credential checked against the BUSINESS_CHAT_KEYS environment variable
    # in the chat handler; unknown keys are rejected with 403.
    business_key: str = Field(..., description="Business User Key")
    # When enabled, the handler runs a memory search and injects the results
    # into the system prompt before chatting.
    need_search: bool = Field(default=False, description="Whether to need search before chat")


class ChatCompleteRequest(BaseRequest):
"""Request model for chat operations. will (Deprecated), instead use APIChatCompleteRequest."""

Expand Down Expand Up @@ -1217,3 +1225,26 @@ class ExistMemCubeIdRequest(BaseRequest):

class ExistMemCubeIdResponse(BaseResponse[dict[str, bool]]):
"""Response model for checking if mem cube id exists."""


class DeleteMemoryByRecordIdRequest(BaseRequest):
    """Request model for deleting memory by record id."""

    # Cube that owns the memory nodes (used as user_name in the graph DB).
    mem_cube_id: str = Field(..., description="Mem cube ID")
    # Internal record id identifying which nodes to delete; not user-facing.
    record_id: str = Field(..., description="Record ID")
    # NOTE(review): presumably False keeps the nodes recoverable via the
    # recover endpoint while True removes them permanently — confirm in graph_db.
    hard_delete: bool = Field(default=False, description="Hard delete")


class DeleteMemoryByRecordIdResponse(BaseResponse[dict]):
    """Response model for deleting memory by record id."""
    # data carries a plain status dict, e.g. {"status": "success"}, as
    # returned by the /delete_memory_by_record_id router handler.


class RecoverMemoryByRecordIdRequest(BaseRequest):
    """Request model for recovering memory by record id."""

    # Cube whose memories should be recovered (used as user_name in the graph DB).
    mem_cube_id: str = Field(..., description="Mem cube ID")
    # Internal record id from a prior delete; identifies the nodes to restore.
    # Not meant to be set by end users (see the router docstring).
    delete_record_id: str = Field(..., description="Delete record ID")


class RecoverMemoryByRecordIdResponse(BaseResponse[dict]):
    """Response model for recovering memory by record id."""
    # data carries a plain status dict, e.g. {"status": "success"}, as
    # returned by the /recover_memory_by_record_id router handler.
60 changes: 58 additions & 2 deletions src/memos/api/routers/server_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@
APIChatCompleteRequest,
APIFeedbackRequest,
APISearchRequest,
ChatBusinessRequest,
ChatPlaygroundRequest,
ChatRequest,
DeleteMemoryByRecordIdRequest,
DeleteMemoryByRecordIdResponse,
DeleteMemoryRequest,
DeleteMemoryResponse,
ExistMemCubeIdRequest,
Expand All @@ -41,6 +44,8 @@
GetUserNamesByMemoryIdsRequest,
GetUserNamesByMemoryIdsResponse,
MemoryResponse,
RecoverMemoryByRecordIdRequest,
RecoverMemoryByRecordIdResponse,
SearchResponse,
StatusResponse,
SuggestionRequest,
Expand Down Expand Up @@ -290,8 +295,9 @@ def get_all_memories(memory_req: GetMemoryPlaygroundRequest):
memory_req.mem_cube_ids[0] if memory_req.mem_cube_ids else memory_req.user_id
),
query=memory_req.search_query,
top_k=20,
top_k=200,
naive_mem_cube=naive_mem_cube,
search_type=memory_req.search_type,
)
else:
return handlers.memory_handler.handle_get_all_memories(
Expand Down Expand Up @@ -394,9 +400,59 @@ def get_user_names_by_memory_ids(request: GetUserNamesByMemoryIdsRequest):
response_model=ExistMemCubeIdResponse,
)
def exist_mem_cube_id(request: ExistMemCubeIdRequest):
    """(inner) Check if mem cube id exists."""
    # The cube id doubles as the user_name in the graph database.
    exists = graph_db.exist_user_name(user_name=request.mem_cube_id)
    return ExistMemCubeIdResponse(code=200, message="Successfully", data=exists)


@router.post("/chat/stream/business_user", summary="Chat with MemOS for business user")
def chat_stream_business_user(chat_req: ChatBusinessRequest):
    """(inner) Chat with MemOS for a specific business user. Returns SSE stream."""
    # Chat may be disabled at startup; surface that as 503 instead of crashing.
    handler = chat_handler
    if handler is None:
        raise HTTPException(
            status_code=503,
            detail="Chat service is not available. Chat handler not initialized.",
        )
    return handler.handle_chat_stream_for_business_user(chat_req)


@router.post(
    "/delete_memory_by_record_id",
    summary="Delete memory by record id",
    response_model=DeleteMemoryByRecordIdResponse,
)
def delete_memory_by_record_id(memory_req: DeleteMemoryByRecordIdRequest):
    """(inner) Delete memory nodes by mem_cube_id (user_name) and delete_record_id. Record id is inner field, just for delete and recover memory, not for user to set."""
    # Delegate the actual node removal to the graph database layer.
    graph_db.delete_node_by_mem_cube_id(
        mem_cube_id=memory_req.mem_cube_id,
        delete_record_id=memory_req.record_id,
        hard_delete=memory_req.hard_delete,
    )
    # The call either succeeds or raises, so a fixed success envelope suffices.
    return DeleteMemoryByRecordIdResponse(
        code=200, message="Called Successfully", data={"status": "success"}
    )


@router.post(
    "/recover_memory_by_record_id",
    summary="Recover memory by record id",
    response_model=RecoverMemoryByRecordIdResponse,
)
def recover_memory_by_record_id(memory_req: RecoverMemoryByRecordIdRequest):
    """(inner) Recover memory nodes by mem_cube_id (user_name) and delete_record_id. Record id is inner field, just for delete and recover memory, not for user to set."""
    # Restore previously soft-deleted nodes via the graph database layer.
    graph_db.recover_memory_by_mem_cube_id(
        mem_cube_id=memory_req.mem_cube_id,
        delete_record_id=memory_req.delete_record_id,
    )
    # The call either succeeds or raises, so a fixed success envelope suffices.
    return RecoverMemoryByRecordIdResponse(
        code=200, message="Called Successfully", data={"status": "success"}
    )
Loading