Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
a18af8b
test: add mmr dedup
hijzy Jan 25, 2026
a2029d8
test: decrease lambda
hijzy Jan 25, 2026
cb345c7
test: add tag penalty
hijzy Jan 26, 2026
a0a6a35
test: increase lambda_relevance
hijzy Jan 26, 2026
a40fa5a
test: fix top 5 candidates
hijzy Jan 27, 2026
692d1f1
test: adjust alpha_tag
hijzy Jan 27, 2026
5ef4942
test: increase lambda_relevance
hijzy Jan 27, 2026
2885d57
test: decrease lambda_relevance
hijzy Jan 27, 2026
0819a46
test: adjust params
hijzy Jan 27, 2026
767a0cc
test: delete tag penalty
hijzy Jan 27, 2026
69143b7
test: delete fix top5
hijzy Jan 27, 2026
ff9b03f
test: readd fix top5
hijzy Jan 27, 2026
47ef19b
test: delete threshold-based penalties
hijzy Jan 27, 2026
0eaa06f
test: delete threshold-based penalties
hijzy Jan 27, 2026
a58777b
test: restore best score version, add resort
hijzy Jan 27, 2026
9c0f8df
test: add diversity
hijzy Jan 27, 2026
5772a0f
test: reformat
hijzy Jan 27, 2026
d6f596b
fix: fix Nan and 0 embedding
hijzy Jan 28, 2026
659231d
feat: add mmr deduplication
hijzy Jan 28, 2026
6645c3e
feat: add preference memory deduplication
hijzy Jan 28, 2026
d7bc36b
fix: recall less preference memory
hijzy Jan 28, 2026
0fc1b6e
test: memory text deduplication
hijzy Jan 28, 2026
5c52691
test: memory text deduplication
hijzy Jan 28, 2026
ce8fb96
test: restore
hijzy Jan 28, 2026
a483011
test: restore
hijzy Jan 28, 2026
f0f685c
test: add 2 gram dedup
hijzy Jan 28, 2026
55e0288
test: add more dedup
hijzy Jan 28, 2026
9020198
test: 增大 prefill
hijzy Jan 28, 2026
6060ce1
test: 调整阈值参数
hijzy Jan 28, 2026
ff2bdb5
Merge branch 'main' into mmr
hijzy Jan 28, 2026
f060028
Merge branch 'dev-20260126-v2.0.4' into mmr
hijzy Jan 28, 2026
4f96a8c
fix: reformat
hijzy Jan 28, 2026
13ce942
Merge remote-tracking branch 'origin/mmr' into mmr
hijzy Jan 28, 2026
5c9532e
fix: reformat
hijzy Jan 28, 2026
4dd692a
fix: reformat
hijzy Jan 28, 2026
8039c02
fix: reformat
hijzy Jan 28, 2026
726b097
fix: reformat
hijzy Jan 28, 2026
5009281
fix: reformat
hijzy Jan 28, 2026
23ac282
Merge branch 'dev-20260126-v2.0.4' into mmr
CaralHsi Jan 28, 2026
e2c28ba
fix: use deepcopy, add log
hijzy Jan 28, 2026
b1e67ae
Merge remote-tracking branch 'origin/mmr' into mmr
hijzy Jan 28, 2026
e4a8831
fix: reformat
hijzy Jan 28, 2026
e64eff3
fix: simplify code
hijzy Jan 29, 2026
0f2ad58
test: adjust params
hijzy Jan 29, 2026
d886c24
feat: add relativity threshold
hijzy Feb 2, 2026
a02a71f
fix: fix preference memory key
hijzy Feb 2, 2026
f717677
fix: initialize 'filtered' as list
hijzy Feb 3, 2026
98e882d
fix: adjust relativity threshold
hijzy Feb 3, 2026
dcc3d98
fix: reformat
hijzy Feb 3, 2026
fc08ccd
Merge branch 'dev-20260202-v2.0.5' into dev
hijzy Feb 4, 2026
35a4418
Merge branch 'dev-20260202-v2.0.5' into dev
hijzy Feb 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 207 additions & 23 deletions src/memos/api/handlers/search_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,37 +60,25 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse

# Use deepcopy to avoid modifying the original request object
search_req_local = copy.deepcopy(search_req)
original_top_k = search_req_local.top_k

# Expand top_k for deduplication (5x to ensure enough candidates)
if search_req_local.dedup in ("sim", "mmr"):
search_req_local.top_k = original_top_k * 5

# Create new searcher with include_embedding for MMR deduplication
searcher_to_use = self.searcher
if search_req_local.dedup == "mmr":
text_mem = getattr(self.naive_mem_cube, "text_mem", None)
if text_mem is not None:
# Create new searcher instance with include_embedding=True
searcher_to_use = text_mem.get_searcher(
manual_close_internet=not getattr(self.searcher, "internet_retriever", None),
moscube=False,
process_llm=getattr(self.mem_reader, "llm", None),
)
# Override include_embedding for this searcher
if hasattr(searcher_to_use, "graph_retriever"):
searcher_to_use.graph_retriever.include_embedding = True
search_req_local.top_k = search_req_local.top_k * 5

# Search and deduplicate
cube_view = self._build_cube_view(search_req_local, searcher_to_use)
cube_view = self._build_cube_view(search_req_local)
results = cube_view.search_memories(search_req_local)
if not search_req_local.relativity:
search_req_local.relativity = 0
self.logger.info(f"[SearchHandler] Relativity filter: {search_req_local.relativity}")
results = self._apply_relativity_threshold(results, search_req_local.relativity)

if search_req_local.dedup == "sim":
results = self._dedup_text_memories(results, original_top_k)
results = self._dedup_text_memories(results, search_req.top_k)
self._strip_embeddings(results)
elif search_req_local.dedup == "mmr":
pref_top_k = getattr(search_req_local, "pref_top_k", 6)
results = self._mmr_dedup_text_memories(results, original_top_k, pref_top_k)
results = self._mmr_dedup_text_memories(results, search_req.top_k, pref_top_k)
self._strip_embeddings(results)

self.logger.info(
Expand All @@ -102,6 +90,40 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
data=results,
)

@staticmethod
def _apply_relativity_threshold(results: dict[str, Any], relativity: float) -> dict[str, Any]:
if relativity <= 0:
return results

for key in ("text_mem", "pref_mem"):
buckets = results.get(key)
if not isinstance(buckets, list):
continue

for bucket in buckets:
memories = bucket.get("memories")
if not isinstance(memories, list):
continue

filtered: list[dict[str, Any]] = []
for mem in memories:
if not isinstance(mem, dict):
continue
meta = mem.get("metadata", {})
score = meta.get("relativity", 0.0) if isinstance(meta, dict) else 0.0
try:
score_val = float(score) if score is not None else 0.0
except (TypeError, ValueError):
score_val = 0.0
if score_val >= relativity:
filtered.append(mem)

bucket["memories"] = filtered
if "total_nodes" in bucket:
bucket["total_nodes"] = len(filtered)

return results

def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> dict[str, Any]:
buckets = results.get("text_mem", [])
if not buckets:
Expand Down Expand Up @@ -169,7 +191,7 @@ def _mmr_dedup_text_memories(
3. Re-sort by original relevance for better generation quality
"""
text_buckets = results.get("text_mem", [])
pref_buckets = results.get("preference", [])
pref_buckets = results.get("pref_mem", [])

# Early return if no memories to deduplicate
if not text_buckets and not pref_buckets:
Expand Down Expand Up @@ -238,7 +260,7 @@ def _mmr_dedup_text_memories(

# Skip if highly similar (Dice + TF-IDF + 2-gram combined, with embedding filter)
if SearchHandler._is_text_highly_similar_optimized(
idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.9
idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.92
):
continue

Expand Down Expand Up @@ -281,7 +303,7 @@ def _mmr_dedup_text_memories(

# Skip if highly similar (Dice + TF-IDF + 2-gram combined, with embedding filter)
if SearchHandler._is_text_highly_similar_optimized(
idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.9
idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.92
):
continue # Skip highly similar text, don't participate in MMR competition

Expand Down Expand Up @@ -547,6 +569,168 @@ def _is_text_highly_similar_optimized(

return combined_score >= threshold

@staticmethod
def _dice_similarity(text1: str, text2: str) -> float:
"""
Calculate Dice coefficient (character-level, fastest).

Dice = 2 * |A ∩ B| / (|A| + |B|)
Speed: O(n + m), ~0.05-0.1ms per comparison

Args:
text1: First text string
text2: Second text string

Returns:
Dice similarity score between 0.0 and 1.0
"""
if not text1 or not text2:
return 0.0

chars1 = set(text1)
chars2 = set(text2)

intersection = len(chars1 & chars2)
return 2 * intersection / (len(chars1) + len(chars2))

@staticmethod
def _bigram_similarity(text1: str, text2: str) -> float:
"""
Calculate character-level 2-gram Jaccard similarity.

Speed: O(n + m), ~0.1-0.2ms per comparison
Considers local order (more strict than Dice).

Args:
text1: First text string
text2: Second text string

Returns:
Jaccard similarity score between 0.0 and 1.0
"""
if not text1 or not text2:
return 0.0

# Generate 2-grams
bigrams1 = {text1[i : i + 2] for i in range(len(text1) - 1)} if len(text1) >= 2 else {text1}
bigrams2 = {text2[i : i + 2] for i in range(len(text2) - 1)} if len(text2) >= 2 else {text2}

intersection = len(bigrams1 & bigrams2)
union = len(bigrams1 | bigrams2)

return intersection / union if union > 0 else 0.0

@staticmethod
def _tfidf_similarity(text1: str, text2: str) -> float:
"""
Calculate TF-IDF cosine similarity (character-level, no sklearn).

Speed: O(n + m), ~0.3-0.5ms per comparison
Considers character frequency weighting.

Args:
text1: First text string
text2: Second text string

Returns:
Cosine similarity score between 0.0 and 1.0
"""
if not text1 or not text2:
return 0.0

from collections import Counter

# Character frequency (TF)
tf1 = Counter(text1)
tf2 = Counter(text2)

# All unique characters (vocabulary)
vocab = set(tf1.keys()) | set(tf2.keys())

# Simple IDF: log(2 / df) where df is document frequency
# For two documents, IDF is log(2/1)=0.693 if char appears in one doc,
# or log(2/2)=0 if appears in both (we use log(2/1) for simplicity)
idf = {char: (1.0 if char in tf1 and char in tf2 else 1.5) for char in vocab}

# TF-IDF vectors
vec1 = {char: tf1.get(char, 0) * idf[char] for char in vocab}
vec2 = {char: tf2.get(char, 0) * idf[char] for char in vocab}

# Cosine similarity
dot_product = sum(vec1[char] * vec2[char] for char in vocab)
norm1 = math.sqrt(sum(v * v for v in vec1.values()))
norm2 = math.sqrt(sum(v * v for v in vec2.values()))

if norm1 == 0 or norm2 == 0:
return 0.0

return dot_product / (norm1 * norm2)

@staticmethod
def _is_text_highly_similar_optimized(
    candidate_idx: int,
    candidate_text: str,
    selected_global: list[int],
    similarity_matrix,
    flat: list,
    threshold: float = 0.92,
) -> bool:
    """
    Multi-algorithm text similarity check with embedding pre-filtering.

    Strategy:
    1. Compare only against the single already-selected item with the
       highest embedding similarity, not against every selected memory.
    2. Run the text comparison only when that embedding similarity
       exceeds 0.9; below that the candidate is treated as dissimilar.
    3. Combine three text-similarity algorithms with fixed weights:
       - Dice (40%): fastest, character-level set similarity
       - TF-IDF (35%): considers character frequency weighting
       - 2-gram (25%): considers local character order

    Combined formula:
        combined_score = 0.40 * dice + 0.35 * tfidf + 0.25 * bigram

    This reduces comparisons from O(N) to O(1) per candidate thanks to the
    embedding pre-filter.

    Args:
        candidate_idx: Index of the candidate memory in ``flat``.
        candidate_text: Text content of the candidate memory.
        selected_global: Indices of already-selected memories.
        similarity_matrix: Precomputed embedding similarity matrix.
        flat: Flat list of all memories; item [2] of each entry holds the
            memory dict whose "memory" key is the text.
        threshold: Combined text-similarity threshold (default 0.92).

    Returns:
        True if the candidate is highly similar to the closest selected
        memory, False otherwise (including when nothing is selected yet).
    """
    if not selected_global:
        return False

    # Find the already-selected memory with highest embedding similarity
    max_sim_idx = max(selected_global, key=lambda j: similarity_matrix[candidate_idx][j])
    max_sim = similarity_matrix[candidate_idx][max_sim_idx]

    # Embedding pre-filter: only run the (more expensive) text comparison
    # when the best embedding similarity exceeds 0.9
    if max_sim <= 0.9:
        return False

    # Get text of most similar memory
    most_similar_mem = flat[max_sim_idx][2]
    most_similar_text = most_similar_mem.get("memory", "").strip()

    # Calculate three similarity scores
    dice_sim = SearchHandler._dice_similarity(candidate_text, most_similar_text)
    tfidf_sim = SearchHandler._tfidf_similarity(candidate_text, most_similar_text)
    bigram_sim = SearchHandler._bigram_similarity(candidate_text, most_similar_text)

    # Weighted combination: Dice (40%) + TF-IDF (35%) + 2-gram (25%)
    # Dice has highest weight (fastest and most reliable)
    # TF-IDF considers frequency (handles repeated characters well)
    # 2-gram considers order (catches local pattern similarity)
    combined_score = 0.40 * dice_sim + 0.35 * tfidf_sim + 0.25 * bigram_sim

    return combined_score >= threshold

def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:
"""
Normalize target cube ids from search_req.
Expand Down
10 changes: 10 additions & 0 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,16 @@ class APISearchRequest(BaseRequest):
description="Number of textual memories to retrieve (top-K). Default: 10.",
)

# Minimum relevance score a recalled memory must reach to be returned.
# NOTE: the documented default must match the declared default (0.57);
# the previous description incorrectly claimed "Default: 0.3".
relativity: float = Field(
    0.57,
    ge=0,
    description=(
        "Relevance threshold for recalled memories. "
        "Only memories with metadata.relativity >= relativity will be returned. "
        "Use 0 to disable threshold filtering. Default: 0.57."
    ),
)

dedup: Literal["no", "sim", "mmr"] | None = Field(
"mmr",
description=(
Expand Down
6 changes: 5 additions & 1 deletion src/memos/memories/textual/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def search(
include_skill_memory: bool = False,
skill_mem_top_k: int = 3,
dedup: str | None = None,
include_embedding: bool | None = None,
**kwargs,
) -> list[TextualMemoryItem]:
"""Search for memories based on a query.
Expand All @@ -187,6 +188,9 @@ def search(
Returns:
list[TextualMemoryItem]: List of matching memories.
"""
# Use parameter if provided, otherwise fall back to instance attribute
include_emb = include_embedding if include_embedding is not None else self.include_embedding

searcher = Searcher(
self.dispatcher_llm,
self.graph_store,
Expand All @@ -197,7 +201,7 @@ def search(
search_strategy=self.search_strategy,
manual_close_internet=manual_close_internet,
tokenizer=self.tokenizer,
include_embedding=self.include_embedding,
include_embedding=include_emb,
)
return searcher.search(
query,
Expand Down
1 change: 1 addition & 0 deletions src/memos/multi_mem_cube/single_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ def _fast_search(
include_skill_memory=search_req.include_skill_memory,
skill_mem_top_k=search_req.skill_mem_top_k,
dedup=search_req.dedup,
include_embedding=(search_req.dedup == "mmr"),
)

formatted_memories = [
Expand Down