From ce297bae605ced0a7ce58ce1c5cf70f897f87c99 Mon Sep 17 00:00:00 2001
From: "glin1993@outlook.com" <>
Date: Thu, 25 Dec 2025 17:24:12 +0800
Subject: [PATCH 1/3] Add dedup option to search pipeline

---
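Reviewer sketch (this note sits below the "---", so git am ignores it): the
'sim' mode introduced in this patch keeps the highest-scoring member of each
group of near-duplicate memories. A minimal, self-contained version of that
greedy selection, assuming numpy embeddings and a local cosine helper in
place of the repo's cosine_similarity_matrix:

    import numpy as np

    def dedup_by_similarity(items: list[tuple[str, float]],
                            embs: np.ndarray,
                            threshold: float = 0.85) -> list[tuple[str, float]]:
        """items: (text, score) pairs; embs: row-aligned embedding matrix."""
        # Visit candidates best-score-first so the kept representative of each
        # near-duplicate group is always the highest-scoring one.
        order = sorted(range(len(items)), key=lambda i: items[i][1], reverse=True)
        normed = embs[order] / np.linalg.norm(embs[order], axis=1, keepdims=True)
        sim = normed @ normed.T  # all pairwise cosine similarities
        kept: list[int] = []
        for i in range(len(order)):
            # Keep i only if it is sufficiently different from everything kept.
            if all(sim[i][j] <= threshold for j in kept):
                kept.append(i)
        return [items[order[i]] for i in kept]

Because selection is greedy in score order, ties between near-duplicates
always resolve in favor of the higher-scored memory.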
 src/memos/api/product_models.py              |  9 +++
 src/memos/api/start_api.py                   | 11 +++-
 src/memos/mem_os/core.py                     |  2 +
 .../mem_scheduler/optimized_scheduler.py     |  1 +
 src/memos/memories/textual/tree.py           |  2 +
 .../retrieve/advanced_searcher.py            |  2 +
 .../tree_text_memory/retrieve/searcher.py    | 56 +++++++++++++++----
 src/memos/multi_mem_cube/single_cube.py      | 25 ++++++++-
 8 files changed, 96 insertions(+), 12 deletions(-)

diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index adcb68a96..1e9e7326c 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -319,6 +319,15 @@ class APISearchRequest(BaseRequest):
         description="Number of textual memories to retrieve (top-K). Default: 10.",
     )

+    dedup: Literal["no", "sim"] | None = Field(
+        None,
+        description=(
+            "Optional dedup option for textual memories. "
+            "Use 'no' for no dedup, 'sim' for similarity dedup. "
+            "If None, default exact-text dedup is applied."
+        ),
+    )
+
     pref_top_k: int = Field(
         6,
         ge=0,
diff --git a/src/memos/api/start_api.py b/src/memos/api/start_api.py
index cbcdf6ce2..15145664d 100644
--- a/src/memos/api/start_api.py
+++ b/src/memos/api/start_api.py
@@ -1,7 +1,7 @@
 import logging
 import os

-from typing import Any, Generic, TypeVar
+from typing import Any, Generic, Literal, TypeVar

 from dotenv import load_dotenv
 from fastapi import FastAPI
@@ -145,6 +145,14 @@ class SearchRequest(BaseRequest):
         description="List of cube IDs to search in",
         json_schema_extra={"example": ["cube123", "cube456"]},
     )
+    dedup: Literal["no", "sim"] | None = Field(
+        None,
+        description=(
+            "Optional dedup option for textual memories. "
+            "Use 'no' for no dedup, 'sim' for similarity dedup. "
+            "If None, default exact-text dedup is applied."
+        ),
+    )


 class MemCubeRegister(BaseRequest):
@@ -349,6 +357,7 @@ async def search_memories(search_req: SearchRequest):
             query=search_req.query,
             user_id=search_req.user_id,
             install_cube_ids=search_req.install_cube_ids,
+            dedup=search_req.dedup,
         )
         return SearchResponse(message="Search completed successfully", data=result)
diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py
index 1a88fa831..30efa487a 100644
--- a/src/memos/mem_os/core.py
+++ b/src/memos/mem_os/core.py
@@ -551,6 +551,7 @@ def search(
         internet_search: bool = False,
         moscube: bool = False,
         session_id: str | None = None,
+        dedup: str | None = None,
         **kwargs,
     ) -> MOSSearchResult:
         """
@@ -625,6 +626,7 @@ def search_textual_memory(cube_id, cube):
                 },
                 moscube=moscube,
                 search_filter=search_filter,
+                dedup=dedup,
             )
             search_time_end = time.time()
             logger.info(
diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py
index c3f5891ae..7a4110b90 100644
--- a/src/memos/mem_scheduler/optimized_scheduler.py
+++ b/src/memos/mem_scheduler/optimized_scheduler.py
@@ -186,6 +186,7 @@ def mix_search_memories(
             info=info,
             search_tool_memory=search_req.search_tool_memory,
             tool_mem_top_k=search_req.tool_mem_top_k,
+            dedup=search_req.dedup,
         )
         memories = merged_memories[: search_req.top_k]

diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py
index 22545496a..fb33a2d03 100644
--- a/src/memos/memories/textual/tree.py
+++ b/src/memos/memories/textual/tree.py
@@ -161,6 +161,7 @@ def search(
         user_name: str | None = None,
         search_tool_memory: bool = False,
         tool_mem_top_k: int = 6,
+        dedup: str | None = None,
         **kwargs,
     ) -> list[TextualMemoryItem]:
         """Search for memories based on a query.
@@ -207,6 +208,7 @@ def search(
             user_name=user_name,
             search_tool_memory=search_tool_memory,
             tool_mem_top_k=tool_mem_top_k,
+            dedup=dedup,
             **kwargs,
         )
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py
index e58ebcdd1..64b216fcb 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py
@@ -239,6 +239,7 @@ def deep_search(
         user_name: str | None = None,
         **kwargs,
     ):
+        dedup = kwargs.get("dedup")
         previous_retrieval_phrases = [query]
         retrieved_memories = self.retrieve(
             query=query,
@@ -254,6 +255,7 @@ def deep_search(
             top_k=top_k,
             user_name=user_name,
             info=info,
+            dedup=dedup,
         )
         if len(memories) == 0:
             logger.warning("Requirements not met; returning memories as-is.")
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index dc47dd4d7..b3f0c6e83 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -119,9 +119,15 @@ def post_retrieve(
         info=None,
         search_tool_memory: bool = False,
         tool_mem_top_k: int = 6,
+        dedup: str | None = None,
         plugin=False,
     ):
-        deduped = self._deduplicate_results(retrieved_results)
+        if dedup == "no":
+            deduped = retrieved_results
+        elif dedup == "sim":
+            deduped = self._deduplicate_similar_results(retrieved_results)
+        else:
+            deduped = self._deduplicate_results(retrieved_results)
         final_results = self._sort_and_trim(
             deduped, top_k, plugin, search_tool_memory, tool_mem_top_k
         )
@@ -141,6 +147,7 @@ def search(
         user_name: str | None = None,
         search_tool_memory: bool = False,
         tool_mem_top_k: int = 6,
+        dedup: str | None = None,
         **kwargs,
     ) -> list[TextualMemoryItem]:
         """
@@ -173,7 +180,11 @@ def search(
         if kwargs.get("plugin", False):
             logger.info(f"[SEARCH] Retrieve from plugin: {query}")
             retrieved_results = self._retrieve_simple(
-                query=query, top_k=top_k, search_filter=search_filter, user_name=user_name
+                query=query,
+                top_k=top_k,
+                search_filter=search_filter,
+                user_name=user_name,
+                dedup=dedup,
             )
         else:
             retrieved_results = self.retrieve(
@@ -202,6 +213,7 @@ def search(
             plugin=kwargs.get("plugin", False),
             search_tool_memory=search_tool_memory,
             tool_mem_top_k=tool_mem_top_k,
+            dedup=None if kwargs.get("plugin", False) and dedup == "sim" else dedup,
         )

         logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
@@ -291,6 +303,7 @@ def _retrieve_simple(
         top_k: int,
         search_filter: dict | None = None,
         user_name: str | None = None,
+        dedup: str | None = None,
         **kwargs,
     ):
         """Retrieve from by keywords and embedding"""
@@ -710,14 +723,17 @@ def _retrieve_simple(
             user_name=user_name,
         )
         logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
-        documents = [getattr(item, "memory", "") for item in items]
-        documents_embeddings = self.embedder.embed(documents)
-        similarity_matrix = cosine_similarity_matrix(documents_embeddings)
-        selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
-        selected_items = [items[i] for i in selected_indices]
-        logger.info(
-            f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
-        )
+        if dedup == "no":
+            selected_items = items
+        else:
+            documents = [getattr(item, "memory", "") for item in items]
+            documents_embeddings = self.embedder.embed(documents)
+            similarity_matrix = cosine_similarity_matrix(documents_embeddings)
+            selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
+            selected_items = [items[i] for i in selected_indices]
+            logger.info(
+                f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
+            )
         return self.reranker.rerank(
             query=query,
             query_embedding=query_embeddings[0],
@@ -734,6 +750,26 @@ def _deduplicate_results(self, results):
                 deduped[item.memory] = (item, score)
         return list(deduped.values())

+    @timed
+    def _deduplicate_similar_results(
+        self, results: list[tuple[TextualMemoryItem, float]], similarity_threshold: float = 0.85
+    ):
+        """Deduplicate results by semantic similarity while keeping higher scores."""
+        if len(results) <= 1:
+            return results
+
+        sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)
+        documents = [getattr(item, "memory", "") for item, _ in sorted_results]
+        embeddings = self.embedder.embed(documents)
+        similarity_matrix = cosine_similarity_matrix(embeddings)
+
+        selected_indices: list[int] = []
+        for i in range(len(sorted_results)):
+            if all(similarity_matrix[i][j] <= similarity_threshold for j in selected_indices):
+                selected_indices.append(i)
+
+        return [sorted_results[i] for i in selected_indices]
+
     @timed
     def _sort_and_trim(
         self, results, top_k, plugin=False, search_tool_memory=False, tool_mem_top_k=6
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index a920f7b0e..906416461 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -23,6 +23,9 @@
     MEM_READ_TASK_LABEL,
     PREF_ADD_TASK_LABEL,
 )
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
+    cosine_similarity_matrix,
+)
 from memos.multi_mem_cube.views import MemCubeView
 from memos.templates.mem_reader_prompts import PROMPT_MAPPING
 from memos.types.general_types import (
@@ -263,6 +266,7 @@ def _deep_search(
             moscube=search_req.moscube,
             search_filter=search_filter,
             info=info,
+            dedup=search_req.dedup,
         )
         formatted_memories = [format_memory_item(data) for data in enhanced_memories]
         return formatted_memories
@@ -328,6 +332,7 @@ def _fine_search(
             top_k=search_req.top_k,
             user_name=user_context.mem_cube_id,
             info=info,
+            dedup=search_req.dedup,
         )

         # Enhance with query
@@ -378,7 +383,24 @@ def _dedup_by_content(memories: list) -> list:
                     unique_memories.append(mem)
             return unique_memories

-        deduped_memories = _dedup_by_content(enhanced_memories)
+        def _dedup_by_similarity(memories: list) -> list:
+            if len(memories) <= 1:
+                return memories
+            documents = [getattr(mem, "memory", "") for mem in memories]
+            embeddings = self.searcher.embedder.embed(documents)
+            similarity_matrix = cosine_similarity_matrix(embeddings)
+            selected_indices = []
+            for i in range(len(memories)):
+                if all(similarity_matrix[i][j] <= 0.85 for j in selected_indices):
+                    selected_indices.append(i)
+            return [memories[i] for i in selected_indices]
+
+        if search_req.dedup == "no":
+            deduped_memories = enhanced_memories
+        elif search_req.dedup == "sim":
+            deduped_memories = _dedup_by_similarity(enhanced_memories)
+        else:
+            deduped_memories = _dedup_by_content(enhanced_memories)
         formatted_memories = [format_memory_item(data) for data in deduped_memories]

         logger.info(f"Found {len(formatted_memories)} memories for user {search_req.user_id}")
@@ -463,6 +485,7 @@ def _fast_search(
             plugin=plugin,
             search_tool_memory=search_req.search_tool_memory,
             tool_mem_top_k=search_req.tool_mem_top_k,
+            dedup=search_req.dedup,
         )

         formatted_memories = [format_memory_item(data) for data in search_results]

From 66db2e00bfcfe14bb634fe8bd17400555225cc4e Mon Sep 17 00:00:00 2001
From: "glin1993@outlook.com" <>
Date: Thu, 25 Dec 2025 17:36:55 +0800
Subject: [PATCH 2/3] Fix dedup handling in simple search

---
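Reviewer sketch (below the "---", ignored by git am): this revision hoists
'sim' dedup out of the retrieval stack into SearchHandler, so the formatters
gain an include_embedding flag and the handler reuses stored vectors when it
can. A small illustration of that reuse-or-recompute decision, where embed_fn
stands in for the repo's self.searcher.embedder.embed:

    from typing import Any, Callable

    def embeddings_for(
        memories: list[dict[str, Any]],
        embed_fn: Callable[[list[str]], list[list[float]]],
    ) -> list[list[float]]:
        # Prefer vectors already carried in metadata; fall back to one batched
        # re-embedding call only if any item lost its vector during formatting.
        stored = [mem.get("metadata", {}).get("embedding") for mem in memories]
        if all(stored):
            return stored
        return embed_fn([mem.get("memory", "") for mem in memories])

Keeping embeddings through formatting trades a slightly larger in-process
payload for skipping a second embedder round-trip in the common case.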
 src/memos/api/handlers/formatters_handler.py |   5 +-
 src/memos/api/handlers/search_handler.py     | 112 ++++++++++++++++++
 src/memos/api/start_api.py                   |  11 +-
 src/memos/mem_os/core.py                     |   2 -
 .../mem_scheduler/optimized_scheduler.py     |  10 +-
 src/memos/mem_scheduler/utils/api_utils.py   |   5 +-
 .../retrieve/advanced_searcher.py            |   2 -
 .../tree_text_memory/retrieve/searcher.py    |  93 ++------------
 src/memos/multi_mem_cube/single_cube.py      |  45 +++----
 9 files changed, 156 insertions(+), 129 deletions(-)

diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py
index 88875cacc..94988295b 100644
--- a/src/memos/api/handlers/formatters_handler.py
+++ b/src/memos/api/handlers/formatters_handler.py
@@ -29,7 +29,7 @@ def to_iter(running: Any) -> list[Any]:
     return list(running) if running else []


-def format_memory_item(memory_data: Any) -> dict[str, Any]:
+def format_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]:
     """
     Format a single memory item for API response.

@@ -47,7 +47,8 @@ def format_memory_item(memory_data: Any) -> dict[str, Any]:
     ref_id = f"[{memory_id.split('-')[0]}]"

     memory["ref_id"] = ref_id
-    memory["metadata"]["embedding"] = []
+    if not include_embedding:
+        memory["metadata"]["embedding"] = []
     memory["metadata"]["sources"] = []
     memory["metadata"]["usage"] = []
     memory["metadata"]["ref_id"] = ref_id
diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py
index f7d6ee2c8..4aa993a06 100644
--- a/src/memos/api/handlers/search_handler.py
+++ b/src/memos/api/handlers/search_handler.py
@@ -5,9 +5,14 @@
 using dependency injection for better modularity and testability.
 """

+from typing import Any
+
 from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
 from memos.api.product_models import APISearchRequest, SearchResponse
 from memos.log import get_logger
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
+    cosine_similarity_matrix,
+)
 from memos.multi_mem_cube.composite_cube import CompositeCubeView
 from memos.multi_mem_cube.single_cube import SingleCubeView
 from memos.multi_mem_cube.views import MemCubeView
@@ -53,6 +58,9 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
         cube_view = self._build_cube_view(search_req)
         results = cube_view.search_memories(search_req)

+        if search_req.dedup == "sim":
+            results = self._dedup_text_memories(results, search_req.top_k)
+            self._strip_embeddings(results)
+
         self.logger.info(
             f"[SearchHandler] Final search results: count={len(results)} results={results}"
         )
@@ -63,6 +71,110 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
             data=results,
         )

+    def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> dict[str, Any]:
+        buckets = results.get("text_mem", [])
+        if not buckets:
+            return results
+
+        flat: list[tuple[int, dict[str, Any], float]] = []
+        for bucket_idx, bucket in enumerate(buckets):
+            for mem in bucket.get("memories", []):
+                score = mem.get("metadata", {}).get("relativity", 0.0)
+                flat.append((bucket_idx, mem, score))
+
+        if len(flat) <= 1:
+            return results
+
+        embeddings = self._extract_embeddings([mem for _, mem, _ in flat])
+        if embeddings is None:
+            documents = [mem.get("memory", "") for _, mem, _ in flat]
+            embeddings = self.searcher.embedder.embed(documents)
+
+        similarity_matrix = cosine_similarity_matrix(embeddings)
+
+        indices_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(buckets))}
+        for flat_index, (bucket_idx, _, _) in enumerate(flat):
+            indices_by_bucket[bucket_idx].append(flat_index)
+
+        selected_global: list[int] = []
+        selected_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(buckets))}
+
+        ordered_indices = sorted(range(len(flat)), key=lambda idx: flat[idx][2], reverse=True)
+        for idx in ordered_indices:
+            bucket_idx = flat[idx][0]
+            if len(selected_by_bucket[bucket_idx]) >= target_top_k:
+                continue
+            if self._is_unrelated(idx, selected_global, similarity_matrix, 0.85):
+                selected_by_bucket[bucket_idx].append(idx)
+                selected_global.append(idx)
+
+        for bucket_idx in range(len(buckets)):
+            if len(selected_by_bucket[bucket_idx]) >= min(
+                target_top_k, len(indices_by_bucket[bucket_idx])
+            ):
+                continue
+            remaining_indices = [
+                idx
+                for idx in indices_by_bucket.get(bucket_idx, [])
+                if idx not in selected_by_bucket[bucket_idx]
+            ]
+            if not remaining_indices:
+                continue
+            # Fill to target_top_k with the least-similar candidates to preserve diversity.
+            remaining_indices.sort(
+                key=lambda idx: self._max_similarity(idx, selected_global, similarity_matrix)
+            )
+            for idx in remaining_indices:
+                if len(selected_by_bucket[bucket_idx]) >= target_top_k:
+                    break
+                selected_by_bucket[bucket_idx].append(idx)
+                selected_global.append(idx)
+
+        for bucket_idx, bucket in enumerate(buckets):
+            selected_indices = selected_by_bucket.get(bucket_idx, [])
+            bucket["memories"] = [flat[i][1] for i in selected_indices[:target_top_k]]
+        return results
+
+    @staticmethod
+    def _is_unrelated(
+        index: int,
+        selected_indices: list[int],
+        similarity_matrix: list[list[float]],
+        similarity_threshold: float,
+    ) -> bool:
+        return all(similarity_matrix[index][j] <= similarity_threshold for j in selected_indices)
+
+    @staticmethod
+    def _max_similarity(
+        index: int, selected_indices: list[int], similarity_matrix: list[list[float]]
+    ) -> float:
+        if not selected_indices:
+            return 0.0
+        return max(similarity_matrix[index][j] for j in selected_indices)
+
+    @staticmethod
+    def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | None:
+        embeddings: list[list[float]] = []
+        for mem in memories:
+            embedding = mem.get("metadata", {}).get("embedding")
+            if not embedding:
+                return None
+            embeddings.append(embedding)
+        return embeddings
+
+    @staticmethod
+    def _strip_embeddings(results: dict[str, Any]) -> None:
+        for bucket in results.get("text_mem", []):
+            for mem in bucket.get("memories", []):
+                metadata = mem.get("metadata", {})
+                if "embedding" in metadata:
+                    metadata["embedding"] = []
+        for bucket in results.get("tool_mem", []):
+            for mem in bucket.get("memories", []):
+                metadata = mem.get("metadata", {})
+                if "embedding" in metadata:
+                    metadata["embedding"] = []
+
     def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:
         """
         Normalize target cube ids from search_req.
diff --git a/src/memos/api/start_api.py b/src/memos/api/start_api.py
index 15145664d..cbcdf6ce2 100644
--- a/src/memos/api/start_api.py
+++ b/src/memos/api/start_api.py
@@ -1,7 +1,7 @@
 import logging
 import os

-from typing import Any, Generic, Literal, TypeVar
+from typing import Any, Generic, TypeVar

 from dotenv import load_dotenv
 from fastapi import FastAPI
@@ -145,14 +145,6 @@ class SearchRequest(BaseRequest):
         description="List of cube IDs to search in",
         json_schema_extra={"example": ["cube123", "cube456"]},
     )
-    dedup: Literal["no", "sim"] | None = Field(
-        None,
-        description=(
-            "Optional dedup option for textual memories. "
-            "Use 'no' for no dedup, 'sim' for similarity dedup. "
-            "If None, default exact-text dedup is applied."
-        ),
-    )


 class MemCubeRegister(BaseRequest):
@@ -357,7 +349,6 @@ async def search_memories(search_req: SearchRequest):
             query=search_req.query,
             user_id=search_req.user_id,
             install_cube_ids=search_req.install_cube_ids,
-            dedup=search_req.dedup,
         )
         return SearchResponse(message="Search completed successfully", data=result)
diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py
index 30efa487a..1a88fa831 100644
--- a/src/memos/mem_os/core.py
+++ b/src/memos/mem_os/core.py
@@ -551,7 +551,6 @@ def search(
         internet_search: bool = False,
         moscube: bool = False,
         session_id: str | None = None,
-        dedup: str | None = None,
         **kwargs,
     ) -> MOSSearchResult:
         """
@@ -626,7 +625,6 @@ def search_textual_memory(cube_id, cube):
                 },
                 moscube=moscube,
                 search_filter=search_filter,
-                dedup=dedup,
             )
             search_time_end = time.time()
             logger.info(
diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py
index 7a4110b90..7007f8418 100644
--- a/src/memos/mem_scheduler/optimized_scheduler.py
+++ b/src/memos/mem_scheduler/optimized_scheduler.py
@@ -190,7 +190,10 @@ def mix_search_memories(
             dedup=search_req.dedup,
         )
         memories = merged_memories[: search_req.top_k]
-        formatted_memories = [format_textual_memory_item(item) for item in memories]
+        formatted_memories = [
+            format_textual_memory_item(item, include_embedding=search_req.dedup == "sim")
+            for item in memories
+        ]
         self.submit_memory_history_async_task(
             search_req=search_req,
             user_context=user_context,
@@ -234,7 +237,10 @@ def update_search_memories_to_redis(
                 mem_cube=self.mem_cube,
                 mode=SearchMode.FAST,
             )
-            formatted_memories = [format_textual_memory_item(data) for data in memories]
+            formatted_memories = [
+                format_textual_memory_item(data, include_embedding=search_req.dedup == "sim")
+                for data in memories
+            ]
         else:
             memories = [
                 TextualMemoryItem.from_dict(one) for one in memories_to_store["memories"]
diff --git a/src/memos/mem_scheduler/utils/api_utils.py b/src/memos/mem_scheduler/utils/api_utils.py
index c8d096517..3833b5926 100644
--- a/src/memos/mem_scheduler/utils/api_utils.py
+++ b/src/memos/mem_scheduler/utils/api_utils.py
@@ -6,14 +6,15 @@
 from memos.memories.textual.tree import TextualMemoryItem


-def format_textual_memory_item(memory_data: Any) -> dict[str, Any]:
+def format_textual_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]:
     """Format a single memory item for API response."""
     memory = memory_data.model_dump()
     memory_id = memory["id"]
     ref_id = f"[{memory_id.split('-')[0]}]"

     memory["ref_id"] = ref_id
-    memory["metadata"]["embedding"] = []
+    if not include_embedding:
+        memory["metadata"]["embedding"] = []
     memory["metadata"]["sources"] = []
     memory["metadata"]["ref_id"] = ref_id
     memory["metadata"]["id"] = memory_id
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py
index 64b216fcb..e58ebcdd1 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py
@@ -239,7 +239,6 @@ def deep_search(
         user_name: str | None = None,
         **kwargs,
     ):
-        dedup = kwargs.get("dedup")
         previous_retrieval_phrases = [query]
         retrieved_memories = self.retrieve(
             query=query,
@@ -255,7 +254,6 @@ def deep_search(
             top_k=top_k,
             user_name=user_name,
             info=info,
-            dedup=dedup,
         )
         if len(memories) == 0:
             logger.warning("Requirements not met; returning memories as-is.")
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index b3f0c6e83..f3d6ba037 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -124,8 +124,6 @@ def post_retrieve(
     ):
         if dedup == "no":
             deduped = retrieved_results
-        elif dedup == "sim":
-            deduped = self._deduplicate_similar_results(retrieved_results)
         else:
             deduped = self._deduplicate_results(retrieved_results)
         final_results = self._sort_and_trim(
@@ -180,11 +178,7 @@ def search(
         if kwargs.get("plugin", False):
             logger.info(f"[SEARCH] Retrieve from plugin: {query}")
             retrieved_results = self._retrieve_simple(
-                query=query,
-                top_k=top_k,
-                search_filter=search_filter,
-                user_name=user_name,
-                dedup=dedup,
+                query=query, top_k=top_k, search_filter=search_filter, user_name=user_name
             )
         else:
             retrieved_results = self.retrieve(
@@ -213,7 +207,7 @@ def search(
             plugin=kwargs.get("plugin", False),
             search_tool_memory=search_tool_memory,
             tool_mem_top_k=tool_mem_top_k,
-            dedup=None if kwargs.get("plugin", False) and dedup == "sim" else dedup,
+            dedup=dedup,
         )

         logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
@@ -296,50 +290,6 @@ def _parse_task(

         return parsed_goal, query_embedding, context, query

-    @timed
-    def _retrieve_simple(
-        self,
-        query: str,
-        top_k: int,
-        search_filter: dict | None = None,
-        user_name: str | None = None,
-        dedup: str | None = None,
-        **kwargs,
-    ):
-        """Retrieve from by keywords and embedding"""
-        query_words = []
-        if self.tokenizer:
-            query_words = self.tokenizer.tokenize_mixed(query)
-        else:
-            query_words = query.strip().split()
-        query_words = [query, *query_words]
-        logger.info(f"[SIMPLESEARCH] Query words: {query_words}")
-        query_embeddings = self.embedder.embed(query_words)
-
-        items = self.graph_retriever.retrieve_from_mixed(
-            top_k=top_k * 2,
-            memory_scope=None,
-            query_embedding=query_embeddings,
-            search_filter=search_filter,
-            user_name=user_name,
-            use_fast_graph=self.use_fast_graph,
-        )
-        logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
-        documents = [getattr(item, "memory", "") for item in items]
-        documents_embeddings = self.embedder.embed(documents)
-        similarity_matrix = cosine_similarity_matrix(documents_embeddings)
-        selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
-        selected_items = [items[i] for i in selected_indices]
-        logger.info(
-            f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
-        )
-        return self.reranker.rerank(
-            query=query,
-            query_embedding=query_embeddings[0],
-            graph_results=selected_items,
-            top_k=top_k,
-        )
-
     @timed
     def _retrieve_paths(
         self,
@@ -723,17 +673,14 @@ def _retrieve_simple(
             user_name=user_name,
         )
         logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
-        if dedup == "no":
-            selected_items = items
-        else:
-            documents = [getattr(item, "memory", "") for item in items]
-            documents_embeddings = self.embedder.embed(documents)
-            similarity_matrix = cosine_similarity_matrix(documents_embeddings)
-            selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
-            selected_items = [items[i] for i in selected_indices]
-            logger.info(
-                f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
-            )
+        documents = [getattr(item, "memory", "") for item in items]
+        documents_embeddings = self.embedder.embed(documents)
+        similarity_matrix = cosine_similarity_matrix(documents_embeddings)
+        selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
+        selected_items = [items[i] for i in selected_indices]
+        logger.info(
+            f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
+        )
         return self.reranker.rerank(
             query=query,
             query_embedding=query_embeddings[0],
@@ -750,26 +697,6 @@ def _deduplicate_results(self, results):
                 deduped[item.memory] = (item, score)
         return list(deduped.values())

-    @timed
-    def _deduplicate_similar_results(
-        self, results: list[tuple[TextualMemoryItem, float]], similarity_threshold: float = 0.85
-    ):
-        """Deduplicate results by semantic similarity while keeping higher scores."""
-        if len(results) <= 1:
-            return results
-
-        sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)
-        documents = [getattr(item, "memory", "") for item, _ in sorted_results]
-        embeddings = self.embedder.embed(documents)
-        similarity_matrix = cosine_similarity_matrix(embeddings)
-
-        selected_indices: list[int] = []
-        for i in range(len(sorted_results)):
-            if all(similarity_matrix[i][j] <= similarity_threshold for j in selected_indices):
-                selected_indices.append(i)
-
-        return [sorted_results[i] for i in selected_indices]
-
     @timed
     def _sort_and_trim(
         self, results, top_k, plugin=False, search_tool_memory=False, tool_mem_top_k=6
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index 906416461..6c3cc0cc7 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -23,9 +23,6 @@
     MEM_READ_TASK_LABEL,
     PREF_ADD_TASK_LABEL,
 )
-from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
-    cosine_similarity_matrix,
-)
 from memos.multi_mem_cube.views import MemCubeView
 from memos.templates.mem_reader_prompts import PROMPT_MAPPING
 from memos.types.general_types import (
@@ -266,9 +263,11 @@ def _deep_search(
             moscube=search_req.moscube,
             search_filter=search_filter,
             info=info,
-            dedup=search_req.dedup,
         )
-        formatted_memories = [format_memory_item(data) for data in enhanced_memories]
+        formatted_memories = [
+            format_memory_item(data, include_embedding=search_req.dedup == "sim")
+            for data in enhanced_memories
+        ]
         return formatted_memories

     def _agentic_search(
@@ -277,7 +276,10 @@
         deepsearch_results = self.deepsearch_agent.run(
             search_req.query, user_id=user_context.mem_cube_id
         )
-        formatted_memories = [format_memory_item(data) for data in deepsearch_results]
+        formatted_memories = [
+            format_memory_item(data, include_embedding=search_req.dedup == "sim")
+            for data in deepsearch_results
+        ]
         return formatted_memories

     def _fine_search(
@@ -383,25 +385,13 @@ def _dedup_by_content(memories: list) -> list:
                     unique_memories.append(mem)
             return unique_memories

-        def _dedup_by_similarity(memories: list) -> list:
-            if len(memories) <= 1:
-                return memories
-            documents = [getattr(mem, "memory", "") for mem in memories]
-            embeddings = self.searcher.embedder.embed(documents)
-            similarity_matrix = cosine_similarity_matrix(embeddings)
-            selected_indices = []
-            for i in range(len(memories)):
-                if all(similarity_matrix[i][j] <= 0.85 for j in selected_indices):
-                    selected_indices.append(i)
-            return [memories[i] for i in selected_indices]
-
-        if search_req.dedup == "no":
-            deduped_memories = enhanced_memories
-        elif search_req.dedup == "sim":
-            deduped_memories = _dedup_by_similarity(enhanced_memories)
-        else:
-            deduped_memories = _dedup_by_content(enhanced_memories)
-        formatted_memories = [format_memory_item(data) for data in deduped_memories]
+        deduped_memories = (
+            enhanced_memories if search_req.dedup == "no" else _dedup_by_content(enhanced_memories)
+        )
+        formatted_memories = [
+            format_memory_item(data, include_embedding=search_req.dedup == "sim")
+            for data in deduped_memories
+        ]

         logger.info(f"Found {len(formatted_memories)} memories for user {search_req.user_id}")
@@ -488,7 +478,10 @@ def _fast_search(
             dedup=search_req.dedup,
         )

-        formatted_memories = [format_memory_item(data) for data in search_results]
+        formatted_memories = [
+            format_memory_item(data, include_embedding=search_req.dedup == "sim")
+            for data in search_results
+        ]

         return formatted_memories

From 6d7f410811d300f7561da7da5d35d7dd913ac46b Mon Sep 17 00:00:00 2001
From: "glin1993@outlook.com" <>
Date: Fri, 26 Dec 2025 10:47:45 +0800
Subject: [PATCH 3/3] feat: optimize memory search deduplication and fix
 parsing bugs

- Tune similarity threshold to 0.92 for 'dedup=sim' to preserve subtle
  semantic nuances.
- Implement recall expansion (5x Top-K) when deduplicating to ensure output
  diversity.
- Remove aggressive filling logic to strictly enforce the similarity
  threshold.
- Fix attribute error in MultiModalStructMemReader by correctly importing
  parse_json_result.
- Replace fragile eval() with robust parse_json_result in TaskGoalParser to
  handle JSON booleans.
---
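Reviewer sketch (below the "---", ignored by git am): with the filling pass
removed, the handler compensates by over-fetching before dedup, as this patch
describes. The intended flow, with search_fn and dedup_fn as stand-ins for
cube_view.search_memories and _dedup_text_memories:

    def search_with_sim_dedup(search_fn, dedup_fn, top_k: int):
        # Widen the recall pool so the strict 0.92 threshold can discard
        # near-duplicates without starving the final list.
        candidates = search_fn(top_k=top_k * 5)
        survivors = dedup_fn(candidates, target_top_k=top_k)
        # Fewer than top_k items may survive; that is now intentional: only
        # sufficiently distinct memories are returned.
        return survivors

The 5x factor trades extra retrieval and embedding work for a better chance
that top_k genuinely distinct memories exist in the candidate pool.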
 src/memos/api/handlers/search_handler.py     | 38 +++++++------------
 src/memos/mem_reader/multi_modal_struct.py   |  3 +-
 .../retrieve/retrieve_utils.py               |  1 -
 .../retrieve/task_goal_parser.py             | 19 +++++++---
 4 files changed, 29 insertions(+), 32 deletions(-)

diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py
index 4aa993a06..3774410dc 100644
--- a/src/memos/api/handlers/search_handler.py
+++ b/src/memos/api/handlers/search_handler.py
@@ -55,12 +55,19 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
         """
         self.logger.info(f"[SearchHandler] Search Req is: {search_req}")

+        # Increase recall pool if deduplication is enabled to ensure diversity
+        original_top_k = search_req.top_k
+        if search_req.dedup == "sim":
+            search_req.top_k = original_top_k * 5
+
         cube_view = self._build_cube_view(search_req)
         results = cube_view.search_memories(search_req)

         if search_req.dedup == "sim":
-            results = self._dedup_text_memories(results, search_req.top_k)
+            results = self._dedup_text_memories(results, original_top_k)
             self._strip_embeddings(results)
+        # Restore original top_k for downstream logic or response metadata
+        search_req.top_k = original_top_k

         self.logger.info(
             f"[SearchHandler] Final search results: count={len(results)} results={results}"
@@ -104,35 +111,18 @@ def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> di
             bucket_idx = flat[idx][0]
             if len(selected_by_bucket[bucket_idx]) >= target_top_k:
                 continue
-            if self._is_unrelated(idx, selected_global, similarity_matrix, 0.85):
+            # Use 0.92 threshold strictly
+            if self._is_unrelated(idx, selected_global, similarity_matrix, 0.92):
                 selected_by_bucket[bucket_idx].append(idx)
                 selected_global.append(idx)

-        for bucket_idx in range(len(buckets)):
-            if len(selected_by_bucket[bucket_idx]) >= min(
-                target_top_k, len(indices_by_bucket[bucket_idx])
-            ):
-                continue
-            remaining_indices = [
-                idx
-                for idx in indices_by_bucket.get(bucket_idx, [])
-                if idx not in selected_by_bucket[bucket_idx]
-            ]
-            if not remaining_indices:
-                continue
-            # Fill to target_top_k with the least-similar candidates to preserve diversity.
-            remaining_indices.sort(
-                key=lambda idx: self._max_similarity(idx, selected_global, similarity_matrix)
-            )
-            for idx in remaining_indices:
-                if len(selected_by_bucket[bucket_idx]) >= target_top_k:
-                    break
-                selected_by_bucket[bucket_idx].append(idx)
-                selected_global.append(idx)
+        # Removed the 'filling' logic that was pulling back similar items.
+        # Now it will only return items that truly pass the 0.92 threshold,
+        # up to target_top_k.

         for bucket_idx, bucket in enumerate(buckets):
             selected_indices = selected_by_bucket.get(bucket_idx, [])
-            bucket["memories"] = [flat[i][1] for i in selected_indices[:target_top_k]]
+            bucket["memories"] = [flat[i][1] for i in selected_indices]
         return results

     @staticmethod
diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index 48be9b72c..2ed1af53e 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -10,6 +10,7 @@
 from memos.mem_reader.read_multi_modal import MultiModalParser, detect_lang
 from memos.mem_reader.read_multi_modal.base import _derive_key
 from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader
+from memos.mem_reader.utils import parse_json_result
 from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
 from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
 from memos.types import MessagesType
@@ -377,7 +378,7 @@ def _get_llm_response(
         messages = [{"role": "user", "content": prompt}]
         try:
             response_text = self.llm.generate(messages)
-            response_json = self.parse_json_result(response_text)
+            response_json = parse_json_result(response_text)
         except Exception as e:
             logger.error(f"[LLM] Exception during chat generation: {e}")
             response_json = {
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py
index d9398a22c..5a82883c8 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py
@@ -4,7 +4,6 @@
 from pathlib import Path
 from typing import Any

-
 import numpy as np

 from memos.dependency import require_python_package
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
index e1ce859bf..f4d6c4847 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
@@ -5,7 +5,10 @@
 from memos.llms.base import BaseLLM
 from memos.log import get_logger
 from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal
-from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
+    FastTokenizer,
+    parse_json_result,
+)
 from memos.memories.textual.tree_text_memory.retrieve.utils import TASK_PARSE_PROMPT


@@ -111,8 +114,10 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
         for attempt_times in range(attempts):
             try:
                 context = kwargs.get("context", "")
-                response = response.replace("```", "").replace("json", "").strip()
-                response_json = eval(response)
+                response_json = parse_json_result(response)
+                if not response_json:
+                    raise ValueError("Parsed JSON is empty")
+
                 return ParsedTaskGoal(
                     memories=response_json.get("memories", []),
                     keys=response_json.get("keys", []),
@@ -123,6 +128,8 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
                     context=context,
                 )
             except Exception as e:
-                raise ValueError(
-                    f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts + 1}"
-                ) from e
+                if attempt_times == attempts - 1:
+                    raise ValueError(
+                        f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts}"
+                    ) from e
+                continue
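
Reviewer sketch: parse_json_result is preferred over eval() here because LLM
output is JSON, not Python ('true', 'false', and 'null' are not Python
literals, and eval() executes arbitrary expressions). A hypothetical
fence-tolerant parser with the same shape as the repo helper, assuming it
returns None on failure so the retry loop above can continue:

    import json

    def parse_json_result_sketch(text: str) -> dict | None:
        # Strip optional markdown code fences, then parse strict JSON.
        cleaned = text.strip()
        if cleaned.startswith("```"):
            cleaned = cleaned.removeprefix("```json").removeprefix("```")
            cleaned = cleaned.removesuffix("```")
        try:
            return json.loads(cleaned.strip())
        except json.JSONDecodeError:
            return None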