diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 58a271f34..fa57dd9eb 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -210,6 +210,8 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An query=chat_req.query, full_response=response, async_mode="async", + manager_user_id=chat_req.manager_user_id, + project_id=chat_req.project_id, ) end = time.time() self.logger.info(f"[Cloud Service] Chat Add Time: {end - start} seconds") @@ -382,6 +384,8 @@ def generate_chat_response() -> Generator[str, None, None]: query=chat_req.query, full_response=full_response, async_mode="async", + manager_user_id=chat_req.manager_user_id, + project_id=chat_req.project_id, ) end = time.time() self.logger.info( @@ -563,6 +567,8 @@ def generate_chat_response() -> Generator[str, None, None]: query=chat_req.query, full_response=None, async_mode="sync", + manager_user_id=chat_req.manager_user_id, + project_id=chat_req.project_id, ) # Extract memories from search results (second search) @@ -731,6 +737,8 @@ def generate_chat_response() -> Generator[str, None, None]: query=chat_req.query, full_response=full_response, async_mode="sync", + manager_user_id=chat_req.manager_user_id, + project_id=chat_req.project_id, ) except Exception as e: @@ -917,6 +925,8 @@ def generate_chat_response() -> Generator[str, None, None]: query=chat_req.query, full_response=full_response, async_mode="async", + manager_user_id=chat_req.manager_user_id, + project_id=chat_req.project_id, ) end = time.time() self.logger.info( @@ -1309,6 +1319,8 @@ async def _add_conversation_to_memory( writable_cube_ids: list[str], session_id: str, query: str, + manager_user_id: str | None = None, + project_id: str | None = None, clean_response: str | None = None, async_mode: Literal["async", "sync"] = "sync", ) -> None: @@ -1333,6 +1345,8 @@ async def _add_conversation_to_memory( session_id=session_id, messages=messages, async_mode=async_mode, + manager_user_id=manager_user_id, + project_id=project_id, ) self.add_handler.handle_add_memories(add_req) @@ -1540,6 +1554,8 @@ def _start_add_to_memory( query: str, full_response: str | None = None, async_mode: Literal["async", "sync"] = "sync", + manager_user_id: str | None = None, + project_id: str | None = None, ) -> None: def run_async_in_thread(): try: @@ -1557,6 +1573,8 @@ def run_async_in_thread(): query=query, clean_response=clean_response, async_mode=async_mode, + manager_user_id=manager_user_id, + project_id=project_id, ) ) finally: @@ -1580,6 +1598,8 @@ def run_async_in_thread(): query=query, clean_response=clean_response, async_mode=async_mode, + manager_user_id=manager_user_id, + project_id=project_id, ) ) task.add_done_callback( diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index 54f795696..a3430d475 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -421,13 +421,27 @@ def handle_get_memories_dashboard( filter=get_mem_req.filter, memory_type=text_memory_type, ) - text_memories, total_text_nodes = text_memories_info["nodes"], text_memories_info["total_nodes"] + text_memories, _ = text_memories_info["nodes"], text_memories_info["total_nodes"] + + # Group text memories by cube_id from metadata.user_name + text_mem_by_cube: dict[str, list] = {} + for memory in text_memories: + cube_id = memory.get("metadata", {}).get("user_name", get_mem_req.mem_cube_id) + if cube_id not in text_mem_by_cube: + text_mem_by_cube[cube_id] = [] + text_mem_by_cube[cube_id].append(memory) + + # If no memories found, create a default entry with the requested cube_id + if not text_mem_by_cube and get_mem_req.mem_cube_id: + text_mem_by_cube[get_mem_req.mem_cube_id] = [] + results["text_mem"] = [ { - "cube_id": get_mem_req.mem_cube_id, - "memories": text_memories, - "total_nodes": total_text_nodes, + "cube_id": cube_id, + "memories": memories, + "total_nodes": len(memories), } + for cube_id, memories in text_mem_by_cube.items() ] if get_mem_req.include_tool_memory: @@ -439,18 +453,32 @@ def handle_get_memories_dashboard( filter=get_mem_req.filter, memory_type=["ToolSchemaMemory", "ToolTrajectoryMemory"], ) - tool_memories, total_tool_nodes = ( + tool_memories, _ = ( tool_memories_info["nodes"], tool_memories_info["total_nodes"], ) + # Group tool memories by cube_id from metadata.user_name + tool_mem_by_cube: dict[str, list] = {} + for memory in tool_memories: + cube_id = memory.get("metadata", {}).get("user_name", get_mem_req.mem_cube_id) + if cube_id not in tool_mem_by_cube: + tool_mem_by_cube[cube_id] = [] + tool_mem_by_cube[cube_id].append(memory) + + # If no memories found, create a default entry with the requested cube_id + if not tool_mem_by_cube and get_mem_req.mem_cube_id: + tool_mem_by_cube[get_mem_req.mem_cube_id] = [] + results["tool_mem"] = [ { - "cube_id": get_mem_req.mem_cube_id, - "memories": tool_memories, - "total_nodes": total_tool_nodes, + "cube_id": cube_id, + "memories": memories, + "total_nodes": len(memories), } + for cube_id, memories in tool_mem_by_cube.items() ] + if get_mem_req.include_skill_memory: skill_memories_info = naive_mem_cube.text_mem.get_all( user_name=get_mem_req.mem_cube_id, @@ -460,18 +488,32 @@ def handle_get_memories_dashboard( filter=get_mem_req.filter, memory_type=["SkillMemory"], ) - skill_memories, total_skill_nodes = ( + skill_memories, _ = ( skill_memories_info["nodes"], skill_memories_info["total_nodes"], ) + # Group skill memories by cube_id from metadata.user_name + skill_mem_by_cube: dict[str, list] = {} + for memory in skill_memories: + cube_id = memory.get("metadata", {}).get("user_name", get_mem_req.mem_cube_id) + if cube_id not in skill_mem_by_cube: + skill_mem_by_cube[cube_id] = [] + skill_mem_by_cube[cube_id].append(memory) + + # If no memories found, create a default entry with the requested cube_id + if not skill_mem_by_cube and get_mem_req.mem_cube_id: + skill_mem_by_cube[get_mem_req.mem_cube_id] = [] + results["skill_mem"] = [ { - "cube_id": get_mem_req.mem_cube_id, - "memories": skill_memories, - "total_nodes": total_skill_nodes, + "cube_id": cube_id, + "memories": memories, + "total_nodes": len(memories), } + for cube_id, memories in skill_mem_by_cube.items() ] + preferences: list[TextualMemoryItem] = [] total_preference_nodes = 0 @@ -507,13 +549,28 @@ def handle_get_memories_dashboard( ) format_preferences = [format_memory_item(item, save_sources=False) for item in preferences] - results = post_process_pref_mem( - results, format_preferences, get_mem_req.mem_cube_id, get_mem_req.include_preference - ) - if total_preference_nodes > 0 and results.get("pref_mem", []): - results["pref_mem"][0]["total_nodes"] = total_preference_nodes + # Group preferences by cube_id from metadata.mem_cube_id + pref_mem_by_cube: dict[str, list] = {} + for pref in format_preferences: + cube_id = pref.get("metadata", {}).get("mem_cube_id", get_mem_req.mem_cube_id) + if cube_id not in pref_mem_by_cube: + pref_mem_by_cube[cube_id] = [] + pref_mem_by_cube[cube_id].append(pref) - # Filter to only keep text_mem, pref_mem, tool_mem + # If no preferences found, create a default entry with the requested cube_id + if not pref_mem_by_cube and get_mem_req.mem_cube_id: + pref_mem_by_cube[get_mem_req.mem_cube_id] = [] + + results["pref_mem"] = [ + { + "cube_id": cube_id, + "memories": memories, + "total_nodes": len(memories), + } + for cube_id, memories in pref_mem_by_cube.items() + ] + + # Filter to only keep text_mem, pref_mem, tool_mem, skill_mem filtered_results = { "text_mem": results.get("text_mem", []), "pref_mem": results.get("pref_mem", []), diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index feaf55680..d056bca6a 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -96,6 +96,8 @@ class ChatRequest(BaseRequest): temperature: float | None = Field(None, description="Temperature for sampling") top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter") add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat") + manager_user_id: str | None = Field(None, description="Manager User ID") + project_id: str | None = Field(None, description="Project ID") # ==== Filter conditions ==== filter: dict[str, Any] | None = Field( @@ -771,6 +773,8 @@ class APIChatCompleteRequest(BaseRequest): temperature: float | None = Field(None, description="Temperature for sampling") top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter") add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat") + manager_user_id: str | None = Field(None, description="Manager User ID") + project_id: str | None = Field(None, description="Project ID") # ==== Filter conditions ==== filter: dict[str, Any] | None = Field( diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 29dad54f6..f0a23e39b 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -5650,7 +5650,7 @@ def recover_memory_by_mem_cube_id( SET properties = ( properties::jsonb || %s::jsonb )::text::agtype, - deletetime = NULL + deletetime = NULL WHERE {where_clause} """