From 73670c7cf5f1558f00543dd0d15793cb0bad3593 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 25 Dec 2025 19:57:36 +0800 Subject: [PATCH] add dedup to playground tree display --- src/memos/api/handlers/memory_handler.py | 13 +++++++++++++ src/memos/api/routers/server_router.py | 2 ++ 2 files changed, 15 insertions(+) diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index a33ee9254..5cfa98160 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -23,6 +23,10 @@ remove_embedding_recursive, sort_children_by_memory_type, ) +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( + cosine_similarity_matrix, + find_best_unrelated_subgroup, +) if TYPE_CHECKING: @@ -37,6 +41,7 @@ def handle_get_all_memories( mem_cube_id: str, memory_type: Literal["text_mem", "act_mem", "param_mem", "para_mem"], naive_mem_cube: Any, + embedder: Any, ) -> MemoryResponse: """ Main handler for getting all memories. @@ -59,6 +64,14 @@ def handle_get_all_memories( # Get all text memories from the graph database memories = naive_mem_cube.text_mem.get_all(user_name=mem_cube_id) + mems = [mem.get("memory", "") for mem in memories.get("nodes", [])] + embeddings = embedder.embed(mems) + similarity_matrix = cosine_similarity_matrix(embeddings) + selected_indices, _ = find_best_unrelated_subgroup( + embeddings, similarity_matrix, bar=0.9 + ) + memories["nodes"] = [memories["nodes"][i] for i in selected_indices] + # Format and convert to tree structure memories_cleaned = remove_embedding_recursive(memories) custom_type_ratios = { diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 37ca361ea..e87e006dd 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -83,6 +83,7 @@ naive_mem_cube = components["naive_mem_cube"] redis_client = components["redis_client"] status_tracker = TaskStatusTracker(redis_client=redis_client) +embedder = components["embedder"] # ============================================================================= @@ -294,6 +295,7 @@ def get_all_memories(memory_req: GetMemoryPlaygroundRequest): ), memory_type=memory_req.memory_type or "text_mem", naive_mem_cube=naive_mem_cube, + embedder=embedder, )