diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index e5af52f87..93eff185b 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -5,12 +5,12 @@ using dependency injection for better modularity and testability. """ -import time +import copy +import math from typing import Any from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies -from memos.api.handlers.formatters_handler import rerank_knowledge_mem from memos.api.product_models import APISearchRequest, SearchResponse from memos.log import get_logger from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( @@ -58,32 +58,41 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse """ self.logger.info(f"[SearchHandler] Search Req is: {search_req}") - # Increase recall pool if deduplication is enabled to ensure diversity - original_top_k = search_req.top_k - if search_req.dedup == "sim": - search_req.top_k = original_top_k * 5 + # Use deepcopy to avoid modifying the original request object + search_req_local = copy.deepcopy(search_req) + original_top_k = search_req_local.top_k + + # Expand top_k for deduplication (5x to ensure enough candidates) + if search_req_local.dedup in ("sim", "mmr"): + search_req_local.top_k = original_top_k * 5 + + # Create new searcher with include_embedding for MMR deduplication + searcher_to_use = self.searcher + if search_req_local.dedup == "mmr": + text_mem = getattr(self.naive_mem_cube, "text_mem", None) + if text_mem is not None: + # Create new searcher instance with include_embedding=True + searcher_to_use = text_mem.get_searcher( + manual_close_internet=not getattr(self.searcher, "internet_retriever", None), + moscube=False, + process_llm=getattr(self.mem_reader, "llm", None), + ) + # Override include_embedding for this searcher + if hasattr(searcher_to_use, "graph_retriever"): + searcher_to_use.graph_retriever.include_embedding = True - cube_view = self._build_cube_view(search_req) + # Search and deduplicate + cube_view = self._build_cube_view(search_req_local, searcher_to_use) + results = cube_view.search_memories(search_req_local) - results = cube_view.search_memories(search_req) - if search_req.dedup == "sim": + if search_req_local.dedup == "sim": results = self._dedup_text_memories(results, original_top_k) self._strip_embeddings(results) - # Restore original top_k for downstream logic or response metadata - search_req.top_k = original_top_k - - start_time = time.time() - text_mem = results["text_mem"] - results["text_mem"] = rerank_knowledge_mem( - self.reranker, - query=search_req.query, - text_mem=text_mem, - top_k=original_top_k, - file_mem_proportion=0.5, - ) - rerank_time = time.time() - start_time + elif search_req_local.dedup == "mmr": + pref_top_k = getattr(search_req_local, "pref_top_k", 6) + results = self._mmr_dedup_text_memories(results, original_top_k, pref_top_k) + self._strip_embeddings(results) - self.logger.info(f"[Knowledge_replace_memory_time] Rerank time: {rerank_time} seconds") self.logger.info( f"[SearchHandler] Final search results: count={len(results)} results={results}" ) @@ -140,6 +149,205 @@ def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> di bucket["memories"] = [flat[i][1] for i in selected_indices] return results + def _mmr_dedup_text_memories( + self, results: dict[str, Any], text_top_k: int, pref_top_k: int = 6 + ) -> dict[str, Any]: + """ + MMR-based deduplication with 
progressive penalty for high similarity.
+
+        Performs deduplication on both text_mem and preference memories together.
+        Other memory types (tool_mem, etc.) are not modified.
+
+        Args:
+            results: Search results containing text_mem and preference buckets
+            text_top_k: Target number of text memories to return per bucket
+            pref_top_k: Target number of preference memories to return per bucket
+
+        Algorithm:
+            1. Prefill the most relevant candidates (capped at 2)
+            2. MMR selection: balance relevance vs diversity
+            3. Re-sort by original relevance for better generation quality
+        """
+        text_buckets = results.get("text_mem", [])
+        pref_buckets = results.get("preference", [])
+
+        # Early return if there are no memories to deduplicate
+        if not text_buckets and not pref_buckets:
+            return results
+
+        # Flatten all memories with their type and scores
+        # flat structure: (memory_type, bucket_idx, mem, score)
+        flat: list[tuple[str, int, dict[str, Any], float]] = []
+
+        # Flatten text memories
+        for bucket_idx, bucket in enumerate(text_buckets):
+            for mem in bucket.get("memories", []):
+                score = mem.get("metadata", {}).get("relativity", 0.0)
+                flat.append(("text", bucket_idx, mem, float(score) if score is not None else 0.0))
+
+        # Flatten preference memories
+        for bucket_idx, bucket in enumerate(pref_buckets):
+            for mem in bucket.get("memories", []):
+                score = mem.get("metadata", {}).get("relativity", 0.0)
+                flat.append(
+                    ("preference", bucket_idx, mem, float(score) if score is not None else 0.0)
+                )
+
+        if len(flat) <= 1:
+            return results
+
+        # Get or compute embeddings
+        embeddings = self._extract_embeddings([mem for _, _, mem, _ in flat])
+        if embeddings is None:
+            self.logger.warning("[SearchHandler] Embeddings are missing; recomputing them")
+            documents = [mem.get("memory", "") for _, _, mem, _ in flat]
+            embeddings = self.searcher.embedder.embed(documents)
+
+        # Compute the similarity matrix using the NumPy-optimized helper
+        # Returns a numpy array, but compatible with list[i][j] indexing
+        similarity_matrix = cosine_similarity_matrix(embeddings)
+
+        # Initialize selection tracking for both text and preference
+        text_indices_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(text_buckets))}
+        pref_indices_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(pref_buckets))}
+
+        for flat_index, (mem_type, bucket_idx, _, _) in enumerate(flat):
+            if mem_type == "text":
+                text_indices_by_bucket[bucket_idx].append(flat_index)
+            elif mem_type == "preference":
+                pref_indices_by_bucket[bucket_idx].append(flat_index)
+
+        selected_global: list[int] = []
+        text_selected_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(text_buckets))}
+        pref_selected_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(pref_buckets))}
+        selected_texts: set[str] = set()  # Track exact text content to avoid duplicates
+
+        # Phase 1: Prefill the most relevant candidates
+        # The prefill count is capped at 2 and bounded by each type's top_k
+        prefill_top_n = min(2, text_top_k, pref_top_k) if pref_buckets else min(2, text_top_k)
+        ordered_by_relevance = sorted(range(len(flat)), key=lambda idx: flat[idx][3], reverse=True)
+        for idx in ordered_by_relevance:
+            if len(selected_global) >= prefill_top_n:
+                break
+            mem_type, bucket_idx, mem, _ = flat[idx]
+
+            # Skip if the exact text already exists in the selected set
+            mem_text = mem.get("memory", "").strip()
+            if mem_text in selected_texts:
+                continue
+
+            # Skip if highly similar (Dice + TF-IDF + 2-gram combined, with embedding filter)
+            if SearchHandler._is_text_highly_similar_optimized(
+                idx, 
mem_text, selected_global, similarity_matrix, flat, threshold=0.9
+            ):
+                continue
+
+            # Check bucket capacity with the correct top_k for each type
+            if mem_type == "text" and len(text_selected_by_bucket[bucket_idx]) < text_top_k:
+                selected_global.append(idx)
+                text_selected_by_bucket[bucket_idx].append(idx)
+                selected_texts.add(mem_text)
+            elif mem_type == "preference" and len(pref_selected_by_bucket[bucket_idx]) < pref_top_k:
+                selected_global.append(idx)
+                pref_selected_by_bucket[bucket_idx].append(idx)
+                selected_texts.add(mem_text)
+
+        # Phase 2: MMR selection for the remaining slots
+        lambda_relevance = 0.8
+        similarity_threshold = 0.9  # Exponential penalty starts above this similarity
+        alpha_exponential = 10.0  # Exponential penalty coefficient
+        remaining = set(range(len(flat))) - set(selected_global)
+
+        while remaining:
+            best_idx: int | None = None
+            best_mmr: float | None = None
+
+            for idx in remaining:
+                mem_type, bucket_idx, mem, _ = flat[idx]
+
+                # Check bucket capacity with the correct top_k for each type
+                if (
+                    mem_type == "text" and len(text_selected_by_bucket[bucket_idx]) >= text_top_k
+                ) or (
+                    mem_type == "preference"
+                    and len(pref_selected_by_bucket[bucket_idx]) >= pref_top_k
+                ):
+                    continue
+
+                # If the exact text already exists, skip this candidate entirely
+                mem_text = mem.get("memory", "").strip()
+                if mem_text in selected_texts:
+                    continue  # Duplicate text does not participate in the MMR competition
+
+                # Skip if highly similar (Dice + TF-IDF + 2-gram combined, with embedding filter)
+                if SearchHandler._is_text_highly_similar_optimized(
+                    idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.9
+                ):
+                    continue  # Highly similar text does not participate in the MMR competition
+
+                relevance = flat[idx][3]
+                max_sim = (
+                    0.0
+                    if not selected_global
+                    else max(similarity_matrix[idx][j] for j in selected_global)
+                )
+
+                # Exponential penalty for similarity above similarity_threshold (0.9)
+                if max_sim > similarity_threshold:
+                    penalty_multiplier = math.exp(
+                        alpha_exponential * (max_sim - similarity_threshold)
+                    )
+                    diversity = max_sim * penalty_multiplier
+                else:
+                    diversity = max_sim
+
+                mmr_score = lambda_relevance * relevance - (1.0 - lambda_relevance) * diversity
+
+                if best_mmr is None or mmr_score > best_mmr:
+                    best_mmr = mmr_score
+                    best_idx = idx
+
+            if best_idx is None:
+                break
+
+            mem_type, bucket_idx, mem, _ = flat[best_idx]
+
+            # Add to the selected set and track the text
+            mem_text = mem.get("memory", "").strip()
+            selected_global.append(best_idx)
+            selected_texts.add(mem_text)
+
+            if mem_type == "text":
+                text_selected_by_bucket[bucket_idx].append(best_idx)
+            elif mem_type == "preference":
+                pref_selected_by_bucket[bucket_idx].append(best_idx)
+            remaining.remove(best_idx)
+
+            # Early termination: all buckets are full
+            text_all_full = all(
+                len(text_selected_by_bucket[b_idx]) >= min(text_top_k, len(bucket_indices))
+                for b_idx, bucket_indices in text_indices_by_bucket.items()
+            )
+            pref_all_full = all(
+                len(pref_selected_by_bucket[b_idx]) >= min(pref_top_k, len(bucket_indices))
+                for b_idx, bucket_indices in pref_indices_by_bucket.items()
+            )
+            if text_all_full and pref_all_full:
+                break
+
+        # Phase 3: Re-sort by original relevance and fill back into the buckets
+        for bucket_idx, bucket in enumerate(text_buckets):
+            selected_indices = text_selected_by_bucket.get(bucket_idx, [])
+            selected_indices = sorted(selected_indices, key=lambda i: flat[i][3], reverse=True)
+            bucket["memories"] = [flat[i][2] for i in selected_indices]
+
+        for bucket_idx, bucket in 
enumerate(pref_buckets):
+            selected_indices = pref_selected_by_bucket.get(bucket_idx, [])
+            selected_indices = sorted(selected_indices, key=lambda i: flat[i][3], reverse=True)
+            bucket["memories"] = [flat[i][2] for i in selected_indices]
+
+        return results
+
     @staticmethod
     def _is_unrelated(
         index: int,
@@ -180,6 +388,168 @@ def _strip_embeddings(results: dict[str, Any]) -> None:
             if "embedding" in metadata:
                 metadata["embedding"] = []
 
+    @staticmethod
+    def _dice_similarity(text1: str, text2: str) -> float:
+        """
+        Calculate the Dice coefficient (character-level, fastest).
+
+        Dice = 2 * |A ∩ B| / (|A| + |B|)
+        Speed: O(n + m), ~0.05-0.1ms per comparison
+
+        Args:
+            text1: First text string
+            text2: Second text string
+
+        Returns:
+            Dice similarity score between 0.0 and 1.0
+        """
+        if not text1 or not text2:
+            return 0.0
+
+        chars1 = set(text1)
+        chars2 = set(text2)
+
+        intersection = len(chars1 & chars2)
+        return 2 * intersection / (len(chars1) + len(chars2))
+
+    @staticmethod
+    def _bigram_similarity(text1: str, text2: str) -> float:
+        """
+        Calculate character-level 2-gram Jaccard similarity.
+
+        Speed: O(n + m), ~0.1-0.2ms per comparison
+        Considers local order (stricter than Dice).
+
+        Args:
+            text1: First text string
+            text2: Second text string
+
+        Returns:
+            Jaccard similarity score between 0.0 and 1.0
+        """
+        if not text1 or not text2:
+            return 0.0
+
+        # Generate 2-grams
+        bigrams1 = {text1[i : i + 2] for i in range(len(text1) - 1)} if len(text1) >= 2 else {text1}
+        bigrams2 = {text2[i : i + 2] for i in range(len(text2) - 1)} if len(text2) >= 2 else {text2}
+
+        intersection = len(bigrams1 & bigrams2)
+        union = len(bigrams1 | bigrams2)
+
+        return intersection / union if union > 0 else 0.0
+
+    @staticmethod
+    def _tfidf_similarity(text1: str, text2: str) -> float:
+        """
+        Calculate TF-IDF cosine similarity (character-level, no sklearn).
+
+        Speed: O(n + m), ~0.3-0.5ms per comparison
+        Considers character frequency weighting.
+
+        Args:
+            text1: First text string
+            text2: Second text string
+
+        Returns:
+            Cosine similarity score between 0.0 and 1.0
+        """
+        if not text1 or not text2:
+            return 0.0
+
+        from collections import Counter
+
+        # Character frequency (TF)
+        tf1 = Counter(text1)
+        tf2 = Counter(text2)
+
+        # All unique characters (vocabulary)
+        vocab = set(tf1.keys()) | set(tf2.keys())
+
+        # Heuristic IDF for the two-document case: characters shared by both
+        # texts get weight 1.0, characters unique to one text get 1.5
+        # (a crude stand-in for log-scaled document frequency)
+        idf = {char: (1.0 if char in tf1 and char in tf2 else 1.5) for char in vocab}
+
+        # TF-IDF vectors
+        vec1 = {char: tf1.get(char, 0) * idf[char] for char in vocab}
+        vec2 = {char: tf2.get(char, 0) * idf[char] for char in vocab}
+
+        # Cosine similarity
+        dot_product = sum(vec1[char] * vec2[char] for char in vocab)
+        norm1 = math.sqrt(sum(v * v for v in vec1.values()))
+        norm2 = math.sqrt(sum(v * v for v in vec2.values()))
+
+        if norm1 == 0 or norm2 == 0:
+            return 0.0
+
+        return dot_product / (norm1 * norm2)
+
+    @staticmethod
+    def _is_text_highly_similar_optimized(
+        candidate_idx: int,
+        candidate_text: str,
+        selected_global: list[int],
+        similarity_matrix,
+        flat: list,
+        threshold: float = 0.9,
+    ) -> bool:
+        """
+        Multi-algorithm text similarity check with embedding pre-filtering.
+
+        Strategy:
+        1. Only compare with the single highest embedding-similarity item (not every selected item)
+        2. Only perform text comparison if that embedding similarity exceeds 0.9
+        3. 
Use a weighted combination of three algorithms:
+           - Dice (40%): Fastest, character-level set similarity
+           - TF-IDF (35%): Considers character frequency weighting
+           - 2-gram (25%): Considers local character order
+
+        Combined formula:
+            combined_score = 0.40 * dice + 0.35 * tfidf + 0.25 * bigram
+
+        Comparing against only the top match reduces text comparisons from O(N) to O(1)
+        per candidate. Expected speedup: 100-200x over an LCS-based approach.
+
+        Args:
+            candidate_idx: Index of candidate memory in flat list
+            candidate_text: Text content of candidate memory
+            selected_global: List of already selected memory indices
+            similarity_matrix: Precomputed embedding similarity matrix
+            flat: Flat list of all memories
+            threshold: Combined similarity threshold (default 0.9)
+
+        Returns:
+            True if candidate is highly similar to any selected memory
+        """
+        if not selected_global:
+            return False
+
+        # Find the already-selected memory with the highest embedding similarity
+        max_sim_idx = max(selected_global, key=lambda j: similarity_matrix[candidate_idx][j])
+        max_sim = similarity_matrix[candidate_idx][max_sim_idx]
+
+        # If the highest embedding similarity is 0.9 or below, skip text comparison entirely
+        if max_sim <= 0.9:
+            return False
+
+        # Get the text of the most similar memory
+        most_similar_mem = flat[max_sim_idx][2]
+        most_similar_text = most_similar_mem.get("memory", "").strip()
+
+        # Calculate the three similarity scores
+        dice_sim = SearchHandler._dice_similarity(candidate_text, most_similar_text)
+        tfidf_sim = SearchHandler._tfidf_similarity(candidate_text, most_similar_text)
+        bigram_sim = SearchHandler._bigram_similarity(candidate_text, most_similar_text)
+
+        # Weighted combination: Dice (40%) + TF-IDF (35%) + 2-gram (25%)
+        # Dice has the highest weight (fastest and most reliable)
+        # TF-IDF considers frequency (handles repeated characters well)
+        # 2-gram considers order (catches local pattern similarity)
+        combined_score = 0.40 * dice_sim + 0.35 * tfidf_sim + 0.25 * bigram_sim
+
+        return combined_score >= threshold
+
     def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:
         """
         Normalize target cube ids from search_req.
@@ -192,8 +562,9 @@ def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:
 
         return [search_req.user_id]
 
-    def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView:
+    def _build_cube_view(self, search_req: APISearchRequest, searcher=None) -> MemCubeView:
         cube_ids = self._resolve_cube_ids(search_req)
+        searcher_to_use = searcher if searcher is not None else self.searcher
 
         if len(cube_ids) == 1:
             cube_id = cube_ids[0]
@@ -203,7 +574,7 @@ def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView:
                 mem_reader=self.mem_reader,
                 mem_scheduler=self.mem_scheduler,
                 logger=self.logger,
-                searcher=self.searcher,
+                searcher=searcher_to_use,
                 deepsearch_agent=self.deepsearch_agent,
             )
         else:
@@ -214,7 +585,7 @@ def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView:
                     mem_reader=self.mem_reader,
                     mem_scheduler=self.mem_scheduler,
                     logger=self.logger,
-                    searcher=self.searcher,
+                    searcher=searcher_to_use,
                     deepsearch_agent=self.deepsearch_agent,
                 )
                 for cube_id in cube_ids
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index e6c4ae23d..d8fa784a3 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -319,11 +319,11 @@ class APISearchRequest(BaseRequest):
         description="Number of textual memories to retrieve (top-K). 
Default: 10.", ) - dedup: Literal["no", "sim"] | None = Field( - None, + dedup: Literal["no", "sim", "mmr"] | None = Field( + "mmr", description=( "Optional dedup option for textual memories. " - "Use 'no' for no dedup, 'sim' for similarity dedup. " + "Use 'no' for no dedup, 'sim' for similarity dedup, 'mmr' for MMR-based dedup. " "If None, default exact-text dedup is applied." ), ) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py index 5a82883c8..1c887355c 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py @@ -466,7 +466,12 @@ def find_best_unrelated_subgroup(sentences: list, similarity_matrix: list, bar: def cosine_similarity_matrix(embeddings: list[list[float]]) -> list[list[float]]: - norms = np.linalg.norm(embeddings, axis=1, keepdims=True) - x_normalized = embeddings / norms + embeddings_array = np.asarray(embeddings) + norms = np.linalg.norm(embeddings_array, axis=1, keepdims=True) + # Handle zero vectors to avoid division by zero + norms[norms == 0] = 1.0 + x_normalized = embeddings_array / norms similarity_matrix = np.dot(x_normalized, x_normalized.T) + # Handle any NaN or Inf values + similarity_matrix = np.nan_to_num(similarity_matrix, nan=0.0, posinf=0.0, neginf=0.0) return similarity_matrix diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index c75fc23c6..4a21d3218 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -266,7 +266,7 @@ def _deep_search( info=info, ) formatted_memories = [ - format_memory_item(data, include_embedding=search_req.dedup == "sim") + format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr")) for data in enhanced_memories ] return formatted_memories @@ -278,7 +278,7 @@ def _agentic_search( search_req.query, user_id=user_context.mem_cube_id ) formatted_memories = [ - format_memory_item(data, include_embedding=search_req.dedup == "sim") + format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr")) for data in deepsearch_results ] return formatted_memories @@ -390,7 +390,7 @@ def _dedup_by_content(memories: list) -> list: enhanced_memories if search_req.dedup == "no" else _dedup_by_content(enhanced_memories) ) formatted_memories = [ - format_memory_item(data, include_embedding=search_req.dedup == "sim") + format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr")) for data in deduped_memories ] @@ -482,7 +482,7 @@ def _fast_search( ) formatted_memories = [ - format_memory_item(data, include_embedding=search_req.dedup == "sim") + format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr")) for data in search_results ]