diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py
index 88875cacc..94988295b 100644
--- a/src/memos/api/handlers/formatters_handler.py
+++ b/src/memos/api/handlers/formatters_handler.py
@@ -29,7 +29,7 @@ def to_iter(running: Any) -> list[Any]:
     return list(running) if running else []
 
 
-def format_memory_item(memory_data: Any) -> dict[str, Any]:
+def format_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]:
     """
     Format a single memory item for API response.
 
@@ -47,7 +47,8 @@ def format_memory_item(memory_data: Any) -> dict[str, Any]:
     ref_id = f"[{memory_id.split('-')[0]}]"
 
     memory["ref_id"] = ref_id
-    memory["metadata"]["embedding"] = []
+    if not include_embedding:
+        memory["metadata"]["embedding"] = []
     memory["metadata"]["sources"] = []
     memory["metadata"]["usage"] = []
     memory["metadata"]["ref_id"] = ref_id
diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py
index f7d6ee2c8..3774410dc 100644
--- a/src/memos/api/handlers/search_handler.py
+++ b/src/memos/api/handlers/search_handler.py
@@ -5,9 +5,14 @@
 using dependency injection for better modularity and testability.
 """
 
+from typing import Any
+
 from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
 from memos.api.product_models import APISearchRequest, SearchResponse
 from memos.log import get_logger
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
+    cosine_similarity_matrix,
+)
 from memos.multi_mem_cube.composite_cube import CompositeCubeView
 from memos.multi_mem_cube.single_cube import SingleCubeView
 from memos.multi_mem_cube.views import MemCubeView
@@ -50,9 +55,19 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
         """
         self.logger.info(f"[SearchHandler] Search Req is: {search_req}")
 
+        # Increase the recall pool if deduplication is enabled, to ensure diversity
+        original_top_k = search_req.top_k
+        if search_req.dedup == "sim":
+            search_req.top_k = original_top_k * 5
+
         cube_view = self._build_cube_view(search_req)
         results = cube_view.search_memories(search_req)
+        if search_req.dedup == "sim":
+            results = self._dedup_text_memories(results, original_top_k)
+            self._strip_embeddings(results)
+            # Restore the original top_k for downstream logic and response metadata
+            search_req.top_k = original_top_k
 
         self.logger.info(
             f"[SearchHandler] Final search results: count={len(results)} results={results}"
@@ -63,6 +78,93 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
             data=results,
         )
 
+    def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> dict[str, Any]:
+        buckets = results.get("text_mem", [])
+        if not buckets:
+            return results
+
+        flat: list[tuple[int, dict[str, Any], float]] = []
+        for bucket_idx, bucket in enumerate(buckets):
+            for mem in bucket.get("memories", []):
+                score = mem.get("metadata", {}).get("relativity", 0.0)
+                flat.append((bucket_idx, mem, score))
+
+        if len(flat) <= 1:
+            return results
+
+        embeddings = self._extract_embeddings([mem for _, mem, _ in flat])
+        if embeddings is None:
+            documents = [mem.get("memory", "") for _, mem, _ in flat]
+            embeddings = self.searcher.embedder.embed(documents)
+
+        similarity_matrix = cosine_similarity_matrix(embeddings)
+
+        indices_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(buckets))}
+        for flat_index, (bucket_idx, _, _) in enumerate(flat):
+            indices_by_bucket[bucket_idx].append(flat_index)
+
+        selected_global: list[int] = []
+        selected_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(buckets))}
+
+        ordered_indices = sorted(range(len(flat)), key=lambda idx: flat[idx][2], reverse=True)
+        for idx in ordered_indices:
+            bucket_idx = flat[idx][0]
+            if len(selected_by_bucket[bucket_idx]) >= target_top_k:
+                continue
+            # Keep a candidate only if it stays at or below the 0.92 similarity threshold
+            if self._is_unrelated(idx, selected_global, similarity_matrix, 0.92):
+                selected_by_bucket[bucket_idx].append(idx)
+                selected_global.append(idx)
+
+        # The former 'filling' logic that pulled back similar items has been removed:
+        # only items that truly pass the 0.92 threshold are returned,
+        # up to target_top_k per bucket.
+
+        for bucket_idx, bucket in enumerate(buckets):
+            selected_indices = selected_by_bucket.get(bucket_idx, [])
+            bucket["memories"] = [flat[i][1] for i in selected_indices]
+        return results
+
+    @staticmethod
+    def _is_unrelated(
+        index: int,
+        selected_indices: list[int],
+        similarity_matrix: list[list[float]],
+        similarity_threshold: float,
+    ) -> bool:
+        return all(similarity_matrix[index][j] <= similarity_threshold for j in selected_indices)
+
+    @staticmethod
+    def _max_similarity(
+        index: int, selected_indices: list[int], similarity_matrix: list[list[float]]
+    ) -> float:
+        if not selected_indices:
+            return 0.0
+        return max(similarity_matrix[index][j] for j in selected_indices)
+
+    @staticmethod
+    def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | None:
+        embeddings: list[list[float]] = []
+        for mem in memories:
+            embedding = mem.get("metadata", {}).get("embedding")
+            if not embedding:
+                return None
+            embeddings.append(embedding)
+        return embeddings
+
+    @staticmethod
+    def _strip_embeddings(results: dict[str, Any]) -> None:
+        for bucket in results.get("text_mem", []):
+            for mem in bucket.get("memories", []):
+                metadata = mem.get("metadata", {})
+                if "embedding" in metadata:
+                    metadata["embedding"] = []
+        for bucket in results.get("tool_mem", []):
+            for mem in bucket.get("memories", []):
+                metadata = mem.get("metadata", {})
+                if "embedding" in metadata:
+                    metadata["embedding"] = []
+
     def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:
         """
         Normalize target cube ids from search_req.
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index 3c7070ec9..120da8b55 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -319,6 +319,15 @@ class APISearchRequest(BaseRequest):
         description="Number of textual memories to retrieve (top-K). Default: 10.",
     )
 
+    dedup: Literal["no", "sim"] | None = Field(
+        None,
+        description=(
+            "Optional dedup option for textual memories. "
+            "Use 'no' for no dedup, 'sim' for similarity dedup. "
+            "If None, the default exact-text dedup is applied."
+        ),
+    )
+
     pref_top_k: int = Field(
         6,
         ge=0,
diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index 48be9b72c..2ed1af53e 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -10,6 +10,7 @@
 from memos.mem_reader.read_multi_modal import MultiModalParser, detect_lang
 from memos.mem_reader.read_multi_modal.base import _derive_key
 from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader
+from memos.mem_reader.utils import parse_json_result
 from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
 from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
 from memos.types import MessagesType
@@ -377,7 +378,7 @@ def _get_llm_response(
         messages = [{"role": "user", "content": prompt}]
         try:
             response_text = self.llm.generate(messages)
-            response_json = self.parse_json_result(response_text)
+            response_json = parse_json_result(response_text)
         except Exception as e:
             logger.error(f"[LLM] Exception during chat generation: {e}")
             response_json = {
diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py
index c3f5891ae..7007f8418 100644
--- a/src/memos/mem_scheduler/optimized_scheduler.py
+++ b/src/memos/mem_scheduler/optimized_scheduler.py
@@ -186,10 +186,14 @@ def mix_search_memories(
             info=info,
             search_tool_memory=search_req.search_tool_memory,
             tool_mem_top_k=search_req.tool_mem_top_k,
+            dedup=search_req.dedup,
         )
         memories = merged_memories[: search_req.top_k]
-        formatted_memories = [format_textual_memory_item(item) for item in memories]
+        formatted_memories = [
+            format_textual_memory_item(item, include_embedding=search_req.dedup == "sim")
+            for item in memories
+        ]
         self.submit_memory_history_async_task(
             search_req=search_req,
             user_context=user_context,
@@ -233,7 +237,10 @@ def update_search_memories_to_redis(
                 mem_cube=self.mem_cube,
                 mode=SearchMode.FAST,
             )
-            formatted_memories = [format_textual_memory_item(data) for data in memories]
+            formatted_memories = [
+                format_textual_memory_item(data, include_embedding=search_req.dedup == "sim")
+                for data in memories
+            ]
         else:
             memories = [
                 TextualMemoryItem.from_dict(one) for one in memories_to_store["memories"]
diff --git a/src/memos/mem_scheduler/utils/api_utils.py b/src/memos/mem_scheduler/utils/api_utils.py
index c8d096517..3833b5926 100644
--- a/src/memos/mem_scheduler/utils/api_utils.py
+++ b/src/memos/mem_scheduler/utils/api_utils.py
@@ -6,14 +6,15 @@
 from memos.memories.textual.tree import TextualMemoryItem
 
 
-def format_textual_memory_item(memory_data: Any) -> dict[str, Any]:
+def format_textual_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]:
     """Format a single memory item for API response."""
     memory = memory_data.model_dump()
     memory_id = memory["id"]
     ref_id = f"[{memory_id.split('-')[0]}]"
 
     memory["ref_id"] = ref_id
-    memory["metadata"]["embedding"] = []
+    if not include_embedding:
+        memory["metadata"]["embedding"] = []
     memory["metadata"]["sources"] = []
     memory["metadata"]["ref_id"] = ref_id
     memory["metadata"]["id"] = memory_id
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py
index 22545496a..fb33a2d03 100644
--- a/src/memos/memories/textual/tree.py
+++ b/src/memos/memories/textual/tree.py
@@ -161,6 +161,7 @@ def search(
         user_name: str | None = None,
         search_tool_memory: bool = False,
         tool_mem_top_k: int = 6,
+        dedup: str | None = None,
         **kwargs,
     ) -> list[TextualMemoryItem]:
         """Search for memories based on a query.
@@ -207,6 +208,7 @@ def search(
             user_name=user_name,
             search_tool_memory=search_tool_memory,
             tool_mem_top_k=tool_mem_top_k,
+            dedup=dedup,
             **kwargs,
         )
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py
index d9398a22c..5a82883c8 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py
@@ -4,7 +4,6 @@
 from pathlib import Path
 from typing import Any
 
-
 import numpy as np
 
 from memos.dependency import require_python_package
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index dc47dd4d7..f3d6ba037 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -119,9 +119,13 @@ def post_retrieve(
         info=None,
         search_tool_memory: bool = False,
         tool_mem_top_k: int = 6,
+        dedup: str | None = None,
         plugin=False,
     ):
-        deduped = self._deduplicate_results(retrieved_results)
+        if dedup == "no":
+            deduped = retrieved_results
+        else:
+            deduped = self._deduplicate_results(retrieved_results)
         final_results = self._sort_and_trim(
             deduped, top_k, plugin, search_tool_memory, tool_mem_top_k
         )
@@ -141,6 +145,7 @@ def search(
         user_name: str | None = None,
         search_tool_memory: bool = False,
         tool_mem_top_k: int = 6,
+        dedup: str | None = None,
         **kwargs,
     ) -> list[TextualMemoryItem]:
         """
@@ -202,6 +207,7 @@ def search(
             plugin=kwargs.get("plugin", False),
             search_tool_memory=search_tool_memory,
             tool_mem_top_k=tool_mem_top_k,
+            dedup=dedup,
         )
 
         logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
@@ -284,49 +290,6 @@ def _parse_task(
 
         return parsed_goal, query_embedding, context, query
 
-    @timed
-    def _retrieve_simple(
-        self,
-        query: str,
-        top_k: int,
-        search_filter: dict | None = None,
-        user_name: str | None = None,
-        **kwargs,
-    ):
-        """Retrieve from by keywords and embedding"""
-        query_words = []
-        if self.tokenizer:
-            query_words = self.tokenizer.tokenize_mixed(query)
-        else:
-            query_words = query.strip().split()
-        query_words = [query, *query_words]
-        logger.info(f"[SIMPLESEARCH] Query words: {query_words}")
-        query_embeddings = self.embedder.embed(query_words)
-
-        items = self.graph_retriever.retrieve_from_mixed(
-            top_k=top_k * 2,
-            memory_scope=None,
-            query_embedding=query_embeddings,
-            search_filter=search_filter,
-            user_name=user_name,
-            use_fast_graph=self.use_fast_graph,
-        )
-        logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
-        documents = [getattr(item, "memory", "") for item in items]
-        documents_embeddings = self.embedder.embed(documents)
-        similarity_matrix = cosine_similarity_matrix(documents_embeddings)
-        selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
-        selected_items = [items[i] for i in selected_indices]
-        logger.info(
-            f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
-        )
-        return self.reranker.rerank(
-            query=query,
-            query_embedding=query_embeddings[0],
-            graph_results=selected_items,
-            top_k=top_k,
-        )
-
     @timed
     def _retrieve_paths(
         self,
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
index e1ce859bf..f4d6c4847 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
@@ -5,7 +5,10 @@
 from memos.llms.base import BaseLLM
 from memos.log import get_logger
 from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal
-from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
+    FastTokenizer,
+    parse_json_result,
+)
 from memos.memories.textual.tree_text_memory.retrieve.utils import TASK_PARSE_PROMPT
@@ -111,8 +114,10 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
         for attempt_times in range(attempts):
             try:
                 context = kwargs.get("context", "")
-                response = response.replace("```", "").replace("json", "").strip()
-                response_json = eval(response)
+                response_json = parse_json_result(response)
+                if not response_json:
+                    raise ValueError("Parsed JSON is empty")
+
                 return ParsedTaskGoal(
                     memories=response_json.get("memories", []),
                     keys=response_json.get("keys", []),
@@ -123,6 +128,8 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
                     context=context,
                 )
             except Exception as e:
-                raise ValueError(
-                    f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts + 1}"
-                ) from e
+                if attempt_times == attempts - 1:
+                    raise ValueError(
+                        f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts}"
+                    ) from e
+                continue
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index a920f7b0e..6c3cc0cc7 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -264,7 +264,10 @@ def _deep_search(
             search_filter=search_filter,
             info=info,
         )
-        formatted_memories = [format_memory_item(data) for data in enhanced_memories]
+        formatted_memories = [
+            format_memory_item(data, include_embedding=search_req.dedup == "sim")
+            for data in enhanced_memories
+        ]
         return formatted_memories
 
     def _agentic_search(
@@ -273,7 +276,10 @@ def _agentic_search(
         deepsearch_results = self.deepsearch_agent.run(
             search_req.query, user_id=user_context.mem_cube_id
         )
-        formatted_memories = [format_memory_item(data) for data in deepsearch_results]
+        formatted_memories = [
+            format_memory_item(data, include_embedding=search_req.dedup == "sim")
+            for data in deepsearch_results
+        ]
         return formatted_memories
 
     def _fine_search(
@@ -328,6 +334,7 @@ def _fine_search(
             top_k=search_req.top_k,
             user_name=user_context.mem_cube_id,
             info=info,
+            dedup=search_req.dedup,
         )
 
         # Enhance with query
@@ -378,8 +385,13 @@ def _dedup_by_content(memories: list) -> list:
                 unique_memories.append(mem)
             return unique_memories
 
-        deduped_memories = _dedup_by_content(enhanced_memories)
-        formatted_memories = [format_memory_item(data) for data in deduped_memories]
+        deduped_memories = (
+            enhanced_memories if search_req.dedup == "no" else _dedup_by_content(enhanced_memories)
+        )
+        formatted_memories = [
+            format_memory_item(data, include_embedding=search_req.dedup == "sim")
+            for data in deduped_memories
+        ]
 
         logger.info(f"Found {len(formatted_memories)} memories for user {search_req.user_id}")
 
@@ -463,9 +475,13 @@ def _fast_search(
             plugin=plugin,
             search_tool_memory=search_req.search_tool_memory,
             tool_mem_top_k=search_req.tool_mem_top_k,
+            dedup=search_req.dedup,
         )
 
-        formatted_memories = [format_memory_item(data) for data in search_results]
+        formatted_memories = [
+            format_memory_item(data, include_embedding=search_req.dedup == "sim")
+            for data in search_results
+        ]
 
         return formatted_memories
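
For reference while reviewing, here is a minimal standalone sketch of the greedy selection at the core of the new `SearchHandler._dedup_text_memories`. It collapses the per-bucket bookkeeping into a single candidate list, and `cosine_similarity_matrix` is re-implemented with numpy as an assumed pairwise-cosine helper standing in for the one imported from `retrieve_utils`; the toy embeddings and scores are illustrative only, not project data.

```python
import numpy as np


def cosine_similarity_matrix(embeddings: list[list[float]]) -> np.ndarray:
    # Assumed behavior of the retrieve_utils helper: pairwise cosine
    # similarity between all row vectors.
    m = np.asarray(embeddings, dtype=float)
    m = m / np.linalg.norm(m, axis=1, keepdims=True)
    return m @ m.T


def greedy_dedup(
    scores: list[float],
    embeddings: list[list[float]],
    top_k: int,
    threshold: float = 0.92,
) -> list[int]:
    """Pick up to top_k indices, best score first, skipping any candidate
    whose similarity to an already selected item exceeds the threshold."""
    sim = cosine_similarity_matrix(embeddings)
    selected: list[int] = []
    for idx in sorted(range(len(scores)), key=lambda i: scores[i], reverse=True):
        if len(selected) >= top_k:
            break
        if all(sim[idx][j] <= threshold for j in selected):
            selected.append(idx)
    return selected


if __name__ == "__main__":
    # The near-duplicate of the top-scoring item is dropped despite having
    # the second-best score; the dissimilar third item is kept instead.
    embeddings = [[1.0, 0.0], [0.99, 0.05], [0.0, 1.0]]
    scores = [0.9, 0.8, 0.5]
    print(greedy_dedup(scores, embeddings, top_k=2))  # -> [0, 2]
```

Because candidates that fail the threshold are skipped rather than back-filled, the handler widens the recall pool (`top_k * 5`) before this pass and caps selections per bucket, so fewer than `top_k` items may come back when the pool is highly redundant.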