diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index a35caf4d9..54f795696 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -9,7 +9,6 @@ from memos.api.handlers.formatters_handler import ( format_memory_item, post_process_pref_mem, - post_process_textual_mem, ) from memos.api.product_models import ( DeleteMemoryRequest, @@ -251,22 +250,68 @@ def handle_get_memories( get_mem_req: GetMemoryRequest, naive_mem_cube: NaiveMemCube ) -> GetMemoryResponse: results: dict[str, Any] = {"text_mem": [], "pref_mem": [], "tool_mem": [], "skill_mem": []} - memories = naive_mem_cube.text_mem.get_all( + text_memory_type = ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"] + text_memories_info = naive_mem_cube.text_mem.get_all( user_name=get_mem_req.mem_cube_id, user_id=get_mem_req.user_id, page=get_mem_req.page, page_size=get_mem_req.page_size, filter=get_mem_req.filter, - )["nodes"] + memory_type=text_memory_type, + ) + text_memories, total_text_nodes = text_memories_info["nodes"], text_memories_info["total_nodes"] + results["text_mem"] = [ + { + "cube_id": get_mem_req.mem_cube_id, + "memories": text_memories, + "total_nodes": total_text_nodes, + } + ] - results = post_process_textual_mem(results, memories, get_mem_req.mem_cube_id) + if get_mem_req.include_tool_memory: + tool_memories_info = naive_mem_cube.text_mem.get_all( + user_name=get_mem_req.mem_cube_id, + user_id=get_mem_req.user_id, + page=get_mem_req.page, + page_size=get_mem_req.page_size, + filter=get_mem_req.filter, + memory_type=["ToolSchemaMemory", "ToolTrajectoryMemory"], + ) + tool_memories, total_tool_nodes = ( + tool_memories_info["nodes"], + tool_memories_info["total_nodes"], + ) - if not get_mem_req.include_tool_memory: - results["tool_mem"] = [] - if not get_mem_req.include_skill_memory: - results["skill_mem"] = [] + results["tool_mem"] = [ + { + "cube_id": get_mem_req.mem_cube_id, + "memories": tool_memories, + "total_nodes": total_tool_nodes, + } + ] + if get_mem_req.include_skill_memory: + skill_memories_info = naive_mem_cube.text_mem.get_all( + user_name=get_mem_req.mem_cube_id, + user_id=get_mem_req.user_id, + page=get_mem_req.page, + page_size=get_mem_req.page_size, + filter=get_mem_req.filter, + memory_type=["SkillMemory"], + ) + skill_memories, total_skill_nodes = ( + skill_memories_info["nodes"], + skill_memories_info["total_nodes"], + ) + results["skill_mem"] = [ + { + "cube_id": get_mem_req.mem_cube_id, + "memories": skill_memories, + "total_nodes": total_skill_nodes, + } + ] preferences: list[TextualMemoryItem] = [] + total_preference_nodes = 0 format_preferences = [] if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: @@ -295,7 +340,7 @@ def handle_get_memories( filter_params.update(filter_copy) - preferences, _ = naive_mem_cube.pref_mem.get_memory_by_filter( + preferences, total_preference_nodes = naive_mem_cube.pref_mem.get_memory_by_filter( filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size ) format_preferences = [format_memory_item(item, save_sources=False) for item in preferences] @@ -303,6 +348,8 @@ def handle_get_memories( results = post_process_pref_mem( results, format_preferences, get_mem_req.mem_cube_id, get_mem_req.include_preference ) + if total_preference_nodes > 0 and results.get("pref_mem", []): + results["pref_mem"][0]["total_nodes"] = total_preference_nodes # Filter to only keep text_mem, pref_mem, tool_mem filtered_results = { @@ -365,22 +412,68 @@ def handle_get_memories_dashboard( get_mem_req: GetMemoryDashboardRequest, naive_mem_cube: NaiveMemCube ) -> GetMemoryResponse: results: dict[str, Any] = {"text_mem": [], "pref_mem": [], "tool_mem": [], "skill_mem": []} - memories = naive_mem_cube.text_mem.get_all( + text_memory_type = ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"] + text_memories_info = naive_mem_cube.text_mem.get_all( user_name=get_mem_req.mem_cube_id, user_id=get_mem_req.user_id, page=get_mem_req.page, page_size=get_mem_req.page_size, filter=get_mem_req.filter, - )["nodes"] + memory_type=text_memory_type, + ) + text_memories, total_text_nodes = text_memories_info["nodes"], text_memories_info["total_nodes"] + results["text_mem"] = [ + { + "cube_id": get_mem_req.mem_cube_id, + "memories": text_memories, + "total_nodes": total_text_nodes, + } + ] - results = post_process_textual_mem(results, memories, get_mem_req.mem_cube_id) + if get_mem_req.include_tool_memory: + tool_memories_info = naive_mem_cube.text_mem.get_all( + user_name=get_mem_req.mem_cube_id, + user_id=get_mem_req.user_id, + page=get_mem_req.page, + page_size=get_mem_req.page_size, + filter=get_mem_req.filter, + memory_type=["ToolSchemaMemory", "ToolTrajectoryMemory"], + ) + tool_memories, total_tool_nodes = ( + tool_memories_info["nodes"], + tool_memories_info["total_nodes"], + ) - if not get_mem_req.include_tool_memory: - results["tool_mem"] = [] - if not get_mem_req.include_skill_memory: - results["skill_mem"] = [] + results["tool_mem"] = [ + { + "cube_id": get_mem_req.mem_cube_id, + "memories": tool_memories, + "total_nodes": total_tool_nodes, + } + ] + if get_mem_req.include_skill_memory: + skill_memories_info = naive_mem_cube.text_mem.get_all( + user_name=get_mem_req.mem_cube_id, + user_id=get_mem_req.user_id, + page=get_mem_req.page, + page_size=get_mem_req.page_size, + filter=get_mem_req.filter, + memory_type=["SkillMemory"], + ) + skill_memories, total_skill_nodes = ( + skill_memories_info["nodes"], + skill_memories_info["total_nodes"], + ) + results["skill_mem"] = [ + { + "cube_id": get_mem_req.mem_cube_id, + "memories": skill_memories, + "total_nodes": total_skill_nodes, + } + ] preferences: list[TextualMemoryItem] = [] + total_preference_nodes = 0 format_preferences = [] if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: @@ -409,7 +502,7 @@ def handle_get_memories_dashboard( filter_params.update(filter_copy) - preferences, _ = naive_mem_cube.pref_mem.get_memory_by_filter( + preferences, total_preference_nodes = naive_mem_cube.pref_mem.get_memory_by_filter( filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size ) format_preferences = [format_memory_item(item, save_sources=False) for item in preferences] @@ -417,6 +510,8 @@ def handle_get_memories_dashboard( results = post_process_pref_mem( results, format_preferences, get_mem_req.mem_cube_id, get_mem_req.include_preference ) + if total_preference_nodes > 0 and results.get("pref_mem", []): + results["pref_mem"][0]["total_nodes"] = total_preference_nodes # Filter to only keep text_mem, pref_mem, tool_mem filtered_results = { diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 72e43d15f..90326a044 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -368,13 +368,19 @@ def get_all( page: int | None = None, page_size: int | None = None, filter: dict | None = None, + memory_type: list[str] | None = None, ) -> dict: """Get all memories. Returns: list[TextualMemoryItem]: List of all memories. """ graph_output = self.graph_store.export_graph( - user_name=user_name, user_id=user_id, page=page, page_size=page_size, filter=filter + user_name=user_name, + user_id=user_id, + page=page, + page_size=page_size, + filter=filter, + memory_type=memory_type, ) return graph_output