diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py
index 6eda1e2aa..e1a71737a 100644
--- a/src/memos/api/handlers/search_handler.py
+++ b/src/memos/api/handlers/search_handler.py
@@ -60,37 +60,25 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
         # Use deepcopy to avoid modifying the original request object
         search_req_local = copy.deepcopy(search_req)
-        original_top_k = search_req_local.top_k

         # Expand top_k for deduplication (5x to ensure enough candidates)
         if search_req_local.dedup in ("sim", "mmr"):
-            search_req_local.top_k = original_top_k * 5
-
-            # Create new searcher with include_embedding for MMR deduplication
-            searcher_to_use = self.searcher
-            if search_req_local.dedup == "mmr":
-                text_mem = getattr(self.naive_mem_cube, "text_mem", None)
-                if text_mem is not None:
-                    # Create new searcher instance with include_embedding=True
-                    searcher_to_use = text_mem.get_searcher(
-                        manual_close_internet=not getattr(self.searcher, "internet_retriever", None),
-                        moscube=False,
-                        process_llm=getattr(self.mem_reader, "llm", None),
-                    )
-                    # Override include_embedding for this searcher
-                    if hasattr(searcher_to_use, "graph_retriever"):
-                        searcher_to_use.graph_retriever.include_embedding = True
+            search_req_local.top_k = search_req_local.top_k * 5

         # Search and deduplicate
-        cube_view = self._build_cube_view(search_req_local, searcher_to_use)
+        cube_view = self._build_cube_view(search_req_local)
         results = cube_view.search_memories(search_req_local)

+        if not search_req_local.relativity:
+            search_req_local.relativity = 0
+        self.logger.info(f"[SearchHandler] Relativity filter: {search_req_local.relativity}")
+        results = self._apply_relativity_threshold(results, search_req_local.relativity)

         if search_req_local.dedup == "sim":
-            results = self._dedup_text_memories(results, original_top_k)
+            results = self._dedup_text_memories(results, search_req.top_k)
             self._strip_embeddings(results)
         elif search_req_local.dedup == "mmr":
             pref_top_k = getattr(search_req_local, "pref_top_k", 6)
-            results = self._mmr_dedup_text_memories(results, original_top_k, pref_top_k)
+            results = self._mmr_dedup_text_memories(results, search_req.top_k, pref_top_k)
             self._strip_embeddings(results)

         self.logger.info(
@@ -102,6 +90,40 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
             data=results,
         )

+    @staticmethod
+    def _apply_relativity_threshold(results: dict[str, Any], relativity: float) -> dict[str, Any]:
+        if relativity <= 0:
+            return results
+
+        for key in ("text_mem", "pref_mem"):
+            buckets = results.get(key)
+            if not isinstance(buckets, list):
+                continue
+
+            for bucket in buckets:
+                memories = bucket.get("memories")
+                if not isinstance(memories, list):
+                    continue
+
+                filtered: list[dict[str, Any]] = []
+                for mem in memories:
+                    if not isinstance(mem, dict):
+                        continue
+                    meta = mem.get("metadata", {})
+                    score = meta.get("relativity", 0.0) if isinstance(meta, dict) else 0.0
+                    try:
+                        score_val = float(score) if score is not None else 0.0
+                    except (TypeError, ValueError):
+                        score_val = 0.0
+                    if score_val >= relativity:
+                        filtered.append(mem)
+
+                bucket["memories"] = filtered
+                if "total_nodes" in bucket:
+                    bucket["total_nodes"] = len(filtered)
+
+        return results
+
     def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> dict[str, Any]:
         buckets = results.get("text_mem", [])
         if not buckets:
@@ -169,7 +191,7 @@ def _mmr_dedup_text_memories(
         3. Re-sort by original relevance for better generation quality
         """
         text_buckets = results.get("text_mem", [])
-        pref_buckets = results.get("preference", [])
+        pref_buckets = results.get("pref_mem", [])

         # Early return if no memories to deduplicate
         if not text_buckets and not pref_buckets:
@@ -238,7 +260,7 @@ def _mmr_dedup_text_memories(

             # Skip if highly similar (Dice + TF-IDF + 2-gram combined, with embedding filter)
             if SearchHandler._is_text_highly_similar_optimized(
-                idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.9
+                idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.92
             ):
                 continue
@@ -281,7 +303,7 @@ def _mmr_dedup_text_memories(

             # Skip if highly similar (Dice + TF-IDF + 2-gram combined, with embedding filter)
             if SearchHandler._is_text_highly_similar_optimized(
-                idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.9
+                idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.92
             ):
                 continue  # Skip highly similar text, don't participate in MMR competition
@@ -547,6 +569,168 @@ def _is_text_highly_similar_optimized(

         return combined_score >= threshold

+    @staticmethod
+    def _dice_similarity(text1: str, text2: str) -> float:
+        """
+        Calculate the Dice coefficient (character-level, fastest).
+
+        Dice = 2 * |A ∩ B| / (|A| + |B|)
+        Speed: O(n + m), ~0.05-0.1ms per comparison
+
+        Args:
+            text1: First text string
+            text2: Second text string
+
+        Returns:
+            Dice similarity score between 0.0 and 1.0
+        """
+        if not text1 or not text2:
+            return 0.0
+
+        chars1 = set(text1)
+        chars2 = set(text2)
+
+        intersection = len(chars1 & chars2)
+        return 2 * intersection / (len(chars1) + len(chars2))
+
+    @staticmethod
+    def _bigram_similarity(text1: str, text2: str) -> float:
+        """
+        Calculate character-level 2-gram Jaccard similarity.
+
+        Speed: O(n + m), ~0.1-0.2ms per comparison
+        Considers local order (stricter than Dice).
+
+        Args:
+            text1: First text string
+            text2: Second text string
+
+        Returns:
+            Jaccard similarity score between 0.0 and 1.0
+        """
+        if not text1 or not text2:
+            return 0.0
+
+        # Generate 2-grams
+        bigrams1 = {text1[i : i + 2] for i in range(len(text1) - 1)} if len(text1) >= 2 else {text1}
+        bigrams2 = {text2[i : i + 2] for i in range(len(text2) - 1)} if len(text2) >= 2 else {text2}
+
+        intersection = len(bigrams1 & bigrams2)
+        union = len(bigrams1 | bigrams2)
+
+        return intersection / union if union > 0 else 0.0
+
+    @staticmethod
+    def _tfidf_similarity(text1: str, text2: str) -> float:
+        """
+        Calculate TF-IDF cosine similarity (character-level, no sklearn).
+
+        Speed: O(n + m), ~0.3-0.5ms per comparison
+        Considers character frequency weighting.
+
+        Args:
+            text1: First text string
+            text2: Second text string
+
+        Returns:
+            Cosine similarity score between 0.0 and 1.0
+        """
+        if not text1 or not text2:
+            return 0.0
+
+        from collections import Counter
+
+        # Character frequency (TF)
+        tf1 = Counter(text1)
+        tf2 = Counter(text2)
+
+        # All unique characters (vocabulary)
+        vocab = set(tf1.keys()) | set(tf2.keys())
+
+        # Simplified IDF for two documents: characters shared by both texts
+        # get weight 1.0, characters unique to one text get weight 1.5,
+        # so non-overlapping characters contribute more to the comparison
+        idf = {char: (1.0 if char in tf1 and char in tf2 else 1.5) for char in vocab}
+
+        # TF-IDF vectors
+        vec1 = {char: tf1.get(char, 0) * idf[char] for char in vocab}
+        vec2 = {char: tf2.get(char, 0) * idf[char] for char in vocab}
+
+        # Cosine similarity
+        dot_product = sum(vec1[char] * vec2[char] for char in vocab)
+        norm1 = math.sqrt(sum(v * v for v in vec1.values()))
+        norm2 = math.sqrt(sum(v * v for v in vec2.values()))
+
+        if norm1 == 0 or norm2 == 0:
+            return 0.0
+
+        return dot_product / (norm1 * norm2)
+
+    @staticmethod
+    def _is_text_highly_similar_optimized(
+        candidate_idx: int,
+        candidate_text: str,
+        selected_global: list[int],
+        similarity_matrix,
+        flat: list,
+        threshold: float = 0.92,
+    ) -> bool:
+        """
+        Multi-algorithm text similarity check with embedding pre-filtering.
+
+        Strategy:
+        1. Only compare with the single already-selected item that has the
+           highest embedding similarity (not every selected item)
+        2. Only perform text comparison if that embedding similarity exceeds 0.9
+        3. Use a weighted combination of three algorithms:
+           - Dice (40%): Fastest, character-level set similarity
+           - TF-IDF (35%): Considers character frequency weighting
+           - 2-gram (25%): Considers local character order
+
+        Combined formula:
+            combined_score = 0.40 * dice + 0.35 * tfidf + 0.25 * bigram
+
+        This reduces comparisons from O(N) to O(1) per candidate, with embedding
+        pre-filtering. Expected speedup: 100-200x compared to the LCS approach.
+
+        Args:
+            candidate_idx: Index of candidate memory in flat list
+            candidate_text: Text content of candidate memory
+            selected_global: List of already selected memory indices
+            similarity_matrix: Precomputed embedding similarity matrix
+            flat: Flat list of all memories
+            threshold: Combined similarity threshold (default 0.92)
+
+        Returns:
+            True if candidate is highly similar to any selected memory
+        """
+        if not selected_global:
+            return False
+
+        # Find the already-selected memory with the highest embedding similarity
+        max_sim_idx = max(selected_global, key=lambda j: similarity_matrix[candidate_idx][j])
+        max_sim = similarity_matrix[candidate_idx][max_sim_idx]
+
+        # If the highest embedding similarity is 0.9 or below, skip text comparison entirely
+        if max_sim <= 0.9:
+            return False
+
+        # Get the text of the most similar memory
+        most_similar_mem = flat[max_sim_idx][2]
+        most_similar_text = most_similar_mem.get("memory", "").strip()
+
+        # Calculate the three similarity scores
+        dice_sim = SearchHandler._dice_similarity(candidate_text, most_similar_text)
+        tfidf_sim = SearchHandler._tfidf_similarity(candidate_text, most_similar_text)
+        bigram_sim = SearchHandler._bigram_similarity(candidate_text, most_similar_text)
+
+        # Weighted combination: Dice (40%) + TF-IDF (35%) + 2-gram (25%)
+        # Dice has the highest weight (fastest and most reliable)
+        # TF-IDF considers frequency (handles repeated characters well)
+        # 2-gram considers order (catches local pattern similarity)
+        combined_score = 0.40 * dice_sim + 0.35 * tfidf_sim + 0.25 * bigram_sim
+
+        return combined_score >= threshold
+
     def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:
         """
         Normalize target cube ids from search_req.
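Reviewer note: a minimal standalone sketch of the weighted score behind the new 0.92 threshold. The helpers below re-implement the same character-level math as the static methods added above; all names are local to this snippet and nothing is imported from the PR.

import math
from collections import Counter

def dice(a: str, b: str) -> float:
    # Character-level Dice coefficient: 2|A ∩ B| / (|A| + |B|)
    s1, s2 = set(a), set(b)
    return 2 * len(s1 & s2) / (len(s1) + len(s2))

def bigram(a: str, b: str) -> float:
    # Jaccard over character 2-grams (sensitive to local order)
    g1 = {a[i : i + 2] for i in range(len(a) - 1)} or {a}
    g2 = {b[i : i + 2] for i in range(len(b) - 1)} or {b}
    return len(g1 & g2) / len(g1 | g2)

def tfidf(a: str, b: str) -> float:
    # Cosine over char-frequency vectors; chars unique to one text get weight 1.5
    t1, t2 = Counter(a), Counter(b)
    vocab = set(t1) | set(t2)
    idf = {c: 1.0 if c in t1 and c in t2 else 1.5 for c in vocab}
    v1 = {c: t1.get(c, 0) * idf[c] for c in vocab}
    v2 = {c: t2.get(c, 0) * idf[c] for c in vocab}
    dot = sum(v1[c] * v2[c] for c in vocab)
    n1 = math.sqrt(sum(x * x for x in v1.values()))
    n2 = math.sqrt(sum(x * x for x in v2.values()))
    return dot / (n1 * n2) if n1 and n2 else 0.0

a = "User prefers dark mode in the editor"
b = "User prefers dark mode in the editor."
score = 0.40 * dice(a, b) + 0.35 * tfidf(a, b) + 0.25 * bigram(a, b)
print(f"combined = {score:.3f}")  # near-duplicates land well above 0.92 and are skipped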
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index d8fa784a3..5e871a448 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -319,6 +319,16 @@ class APISearchRequest(BaseRequest):
         description="Number of textual memories to retrieve (top-K). Default: 10.",
     )

+    relativity: float = Field(
+        0.57,
+        ge=0,
+        description=(
+            "Relevance threshold for recalled memories. "
+            "Only memories with metadata.relativity >= relativity will be returned. "
+            "Use 0 to disable threshold filtering. Default: 0.57."
+        ),
+    )
+
     dedup: Literal["no", "sim", "mmr"] | None = Field(
         "mmr",
         description=(
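Reviewer note: a minimal usage sketch for the new field. The query field name and any other required fields of APISearchRequest are assumptions, since the rest of the model is not shown in this diff.

from memos.api.product_models import APISearchRequest

# relativity=0 turns the threshold filter off; otherwise the 0.57 default applies
req = APISearchRequest(query="user preferences", top_k=10, relativity=0.57, dedup="mmr")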
""" + # Use parameter if provided, otherwise fall back to instance attribute + include_emb = include_embedding if include_embedding is not None else self.include_embedding + searcher = Searcher( self.dispatcher_llm, self.graph_store, @@ -197,7 +201,7 @@ def search( search_strategy=self.search_strategy, manual_close_internet=manual_close_internet, tokenizer=self.tokenizer, - include_embedding=self.include_embedding, + include_embedding=include_emb, ) return searcher.search( query, diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index bd026a51d..15e3b1bb9 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -480,6 +480,7 @@ def _fast_search( include_skill_memory=search_req.include_skill_memory, skill_mem_top_k=search_req.skill_mem_top_k, dedup=search_req.dedup, + include_embedding=(search_req.dedup == "mmr"), ) formatted_memories = [