diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py
new file mode 100644
index 000000000..e780cf394
--- /dev/null
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py
@@ -0,0 +1,264 @@
+import concurrent.futures
+import re
+
+from typing import Any
+
+from memos.context.context import ContextThreadPoolExecutor
+from memos.log import get_logger
+from memos.mem_reader.read_multi_modal.utils import detect_lang
+from memos.memories.textual.item import TextualMemoryItem
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
+
+
+logger = get_logger(__name__)
+
+
+class PreUpdateRecaller:
+    def __init__(self, graph_db, embedder):
+        """
+        The PreUpdateRecaller is designed for the /add phase.
+        It recalls potentially duplicate or conflicting memories against the new content being added.
+
+        Args:
+            graph_db: The graph database instance (Neo4j, PolarDB, etc.)
+            embedder: The embedder instance for vector search
+        """
+        self.graph_db = graph_db
+        self.embedder = embedder
+        # Use existing tokenizer for keyword extraction
+        self.tokenizer = FastTokenizer(use_jieba=True, use_stopwords=True)
+
+    def _adjust_perspective(self, text: str, role: str | None, lang: str) -> str:
+        """
+        For better search results, we adjust the perspective
+        from first person to third person based on role and language.
+        "I" -> "User" (if role is user)
+        "I" -> "Assistant" (if role is assistant)
+        """
+        if not role:
+            return text
+
+        role = role.lower()
+        replacements = []
+
+        # Determine replacements based on language and role
+        if lang == "zh":
+            if role == "user":
+                replacements = [("我", "用户")]
+            elif role == "assistant":
+                replacements = [("我", "助手")]
+        else:  # default to en
+            if role == "user":
+                replacements = [
+                    (r"\bI\b", "User"),
+                    (r"\bme\b", "User"),
+                    (r"\bmy\b", "User's"),
+                    (r"\bmine\b", "User's"),
+                    (r"\bmyself\b", "User himself"),
+                ]
+            elif role == "assistant":
+                replacements = [
+                    (r"\bI\b", "Assistant"),
+                    (r"\bme\b", "Assistant"),
+                    (r"\bmy\b", "Assistant's"),
+                    (r"\bmine\b", "Assistant's"),
+                    (r"\bmyself\b", "Assistant himself"),
+                ]
+
+        adjusted_text = text
+        for pattern, repl in replacements:
+            if lang == "zh":
+                adjusted_text = adjusted_text.replace(pattern, repl)
+            else:
+                adjusted_text = re.sub(pattern, repl, adjusted_text, flags=re.IGNORECASE)
+
+        return adjusted_text
+
+    def _preprocess_query(self, item: TextualMemoryItem) -> str:
+        """
+        Preprocess the query item:
+        1. Extract language and role from metadata/sources
+        2. Adjust perspective (I -> User/Assistant) based on role/lang
+        """
+        raw_text = item.memory or ""
+        if not raw_text.strip():
+            return ""
+
+        # Extract lang/role
+        lang = None
+        role = None
+        sources = item.metadata.sources
+
+        if sources:
+            source_list = sources if isinstance(sources, list) else [sources]
+            for source in source_list:
+                if hasattr(source, "lang") and source.lang:
+                    lang = source.lang
+                elif isinstance(source, dict) and source.get("lang"):
+                    lang = source.get("lang")
+
+                if hasattr(source, "role") and source.role:
+                    role = source.role
+                elif isinstance(source, dict) and source.get("role"):
+                    role = source.get("role")
+
+                if lang and role:
+                    break
+
+        if lang is None:
+            lang = detect_lang(raw_text)
+
+        # Adjust perspective
+        return self._adjust_perspective(raw_text, role, lang)
+
+    def _get_full_memories(
+        self, candidate_ids: list[str], user_name: str
+    ) -> list[TextualMemoryItem]:
+        """
+        Retrieve full memories for the given candidate ids.
+        """
+        full_recalled_memories = self.graph_db.get_nodes(candidate_ids, user_name=user_name)
+        return [TextualMemoryItem.from_dict(item) for item in full_recalled_memories]
+
+    def vector_search(
+        self,
+        query_text: str,
+        query_embedding: list[float] | None,
+        user_name: str,
+        top_k: int,
+        search_filter: dict[str, Any] | None = None,
+        threshold: float = 0.5,
+    ) -> list[dict]:
+        try:
+            # Use the pre-computed embedding if available (matches the raw/clean query);
+            # otherwise embed the switched query for a better semantic match
+            q_embed = query_embedding if query_embedding else self.embedder.embed([query_text])[0]
+
+            # Assuming graph_db.search_by_embedding returns a list of dicts or items
+            results = self.graph_db.search_by_embedding(
+                vector=q_embed,
+                top_k=top_k,
+                status=None,
+                threshold=threshold,
+                user_name=user_name,
+                filter=search_filter,
+            )
+            return results
+        except Exception as e:
+            logger.error(f"[PreUpdateRecaller] Vector search failed: {e}")
+            return []
+
+    def keyword_search(
+        self,
+        query_text: str,
+        user_name: str,
+        top_k: int,
+        search_filter: dict[str, Any] | None = None,
+    ) -> list[dict]:
+        try:
+            # 1. Tokenize using the existing tokenizer
+            keywords = self.tokenizer.tokenize_mixed(query_text)
+            if not keywords:
+                return []
+
+            results = []
+
+            # 2. Try seach_by_keywords_tfidf (PolarDB specific)
+            if hasattr(self.graph_db, "seach_by_keywords_tfidf"):
+                try:
+                    results = self.graph_db.seach_by_keywords_tfidf(
+                        query_words=keywords, user_name=user_name, filter=search_filter
+                    )
+                except Exception as e:
+                    logger.warning(f"[PreUpdateRecaller] seach_by_keywords_tfidf failed: {e}")
+
+            # 3. Fallback to search_by_fulltext
+            if not results and hasattr(self.graph_db, "search_by_fulltext"):
+                try:
+                    results = self.graph_db.search_by_fulltext(
+                        query_words=keywords, top_k=top_k, user_name=user_name, filter=search_filter
+                    )
+                except Exception as e:
+                    logger.warning(f"[PreUpdateRecaller] search_by_fulltext failed: {e}")
+
+            return results[:top_k]
+
+        except Exception as e:
+            logger.error(f"[PreUpdateRecaller] Keyword search failed: {e}")
+            return []
+
+    def recall(
+        self, item: TextualMemoryItem, user_name: str, top_k: int = 10, sim_threshold: float = 0.5
+    ) -> list[TextualMemoryItem]:
+        """
+        Recall related memories for a TextualMemoryItem using hybrid search (vector + keyword).
+        May return between top_k and 2 * top_k items, since results from both searches are merged.
+        Designed for low latency.
+
+        Args:
+            item: The memory item to find related memories for
+            user_name: User identifier for scoping the search
+            top_k: Max number of results to return per search
+            sim_threshold: Minimum similarity threshold for vector search
+
+        Returns:
+            List of TextualMemoryItem
+        """
+        # 1. Preprocess
+        switched_query = self._preprocess_query(item)
+
+        # 2. Recall
+        futures = []
+        common_filter = {
+            "status": {"in": ["activated", "resolving"]},
+            "memory_type": {"in": ["LongTermMemory", "UserMemory", "WorkingMemory"]},
+        }
+
+        with ContextThreadPoolExecutor(max_workers=3, thread_name_prefix="fast_recall") as executor:
+            # Task A: Vector Search (Semantic)
+            query_embedding = (
+                item.metadata.embedding if hasattr(item.metadata, "embedding") else None
+            )
+            futures.append(
+                executor.submit(
+                    self.vector_search,
+                    switched_query,
+                    query_embedding,
+                    user_name,
+                    top_k,
+                    common_filter,
+                    sim_threshold,
+                )
+            )
+
+            # Task B: Keyword Search
+            futures.append(
+                executor.submit(
+                    self.keyword_search, switched_query, user_name, top_k, common_filter
+                )
+            )
+
+        # 3. Collect Results
+        retrieved_ids = set()  # for deduplicating ids
+        for future in concurrent.futures.as_completed(futures):
+            try:
+                res = future.result()
+                if not res:
+                    continue
+
+                for r in res:
+                    retrieved_ids.add(r["id"])
+
+            except Exception as e:
+                logger.error(f"[PreUpdateRecaller] Search future task failed: {e}")
+
+        retrieved_ids = list(retrieved_ids)
+
+        if not retrieved_ids:
+            return []
+
+        # 4. Retrieve full memories from the recalled ids
+        # TODO: We should modify the db functions to support returning arbitrary fields, instead of searching twice.
+        final_memories = self._get_full_memories(retrieved_ids, user_name)
+
+        return final_memories
diff --git a/tests/memories/textual/test_pre_update_recaller.py b/tests/memories/textual/test_pre_update_recaller.py
new file mode 100644
index 000000000..c5bb3a0eb
--- /dev/null
+++ b/tests/memories/textual/test_pre_update_recaller.py
@@ -0,0 +1,150 @@
+import unittest
+import uuid
+
+from dotenv import load_dotenv
+
+from memos.api.handlers.config_builders import build_embedder_config, build_graph_db_config
+from memos.embedders.factory import EmbedderFactory
+from memos.graph_dbs.factory import GraphStoreFactory
+from memos.memories.textual.item import (
+    SourceMessage,
+    TextualMemoryItem,
+    TreeNodeTextualMemoryMetadata,
+)
+from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRecaller
+
+
+# Load environment variables
+load_dotenv()
+
+
+class TestPreUpdateRecaller(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        # Initialize graph_db and embedder using factories
+        # We assume environment variables are set for these to work
+        try:
+            cls.graph_db_config = build_graph_db_config()
+            cls.graph_db = GraphStoreFactory.from_config(cls.graph_db_config)
+
+            cls.embedder_config = build_embedder_config()
+            cls.embedder = EmbedderFactory.from_config(cls.embedder_config)
+        except Exception as e:
+            raise unittest.SkipTest(
+                f"Skipping test because initialization failed (likely missing env vars): {e}"
+            ) from e
+
+        cls.recaller = PreUpdateRecaller(cls.graph_db, cls.embedder)
+
+        # Use a unique user name to isolate tests
+        cls.user_name = "test_pre_update_recaller_user_" + str(uuid.uuid4())[:8]
+
+    def setUp(self):
+        # Add some data to the db
+        self.added_ids = []
+
+        # Create a memory item to add
+        self.memory_text = "The user likes to eat apples."
+        self.embedding = self.embedder.embed([self.memory_text])[0]
+
+        # We pass metadata as a plain dict to simulate what might be passed or stored;
+        # add_node expects metadata as a dict.
+        metadata = {
+            "memory_type": "LongTermMemory",
+            "status": "activated",
+            "embedding": self.embedding,
+            "created_at": "2023-01-01T00:00:00",
+            "updated_at": "2023-01-01T00:00:00",
+            "tags": ["food", "fruit"],
+            "key": "user_preference",
+            "sources": [],
+        }
+
+        node_id = str(uuid.uuid4())
+        self.graph_db.add_node(node_id, self.memory_text, metadata, user_name=self.user_name)
+        self.added_ids.append(node_id)
+
+        # Add another one
+        self.memory_text_2 = "The user has a dog named Rex."
+        self.embedding_2 = self.embedder.embed([self.memory_text_2])[0]
+        metadata_2 = {
+            "memory_type": "LongTermMemory",
+            "status": "activated",
+            "embedding": self.embedding_2,
+            "created_at": "2023-01-01T00:00:00",
+            "updated_at": "2023-01-01T00:00:00",
+            "tags": ["pet", "dog"],
+            "key": "user_pet",
+            "sources": [],
+        }
+        node_id_2 = str(uuid.uuid4())
+        self.graph_db.add_node(node_id_2, self.memory_text_2, metadata_2, user_name=self.user_name)
+        self.added_ids.append(node_id_2)
+
+    def tearDown(self):
+        """Clean up test data."""
+        for node_id in self.added_ids:
+            try:
+                self.graph_db.delete_node(node_id, user_name=self.user_name)
+            except Exception as e:
+                print(f"Error deleting node {node_id}: {e}")
+
+    def test_recall_vector_search(self):
+        """Test recalling using vector search (implicit in the recall method)."""
+        # "I like apples" -> perspective adjustment should match "The user likes to eat apples"
+        query_text = "I like apples"
+
+        # Create metadata with a source to trigger perspective adjustment
+        # role="user" means "I" -> "User"
+        source = SourceMessage(role="user", lang="en")
+        metadata = TreeNodeTextualMemoryMetadata(sources=[source], memory_type="WorkingMemory")
+
+        item = TextualMemoryItem(memory=query_text, metadata=metadata)
+
+        # The recall method does both vector and keyword search
+        results = self.recaller.recall(item, self.user_name, top_k=5)
+
+        # Verify we got results
+        self.assertTrue(len(results) > 0, "Should return at least one result")
+        found_texts = [r.memory for r in results]
+
+        # Check that the relevant memory is found:
+        # "The user likes to eat apples." should be returned.
+        # We check for "apples" to be safe.
+        self.assertTrue(
+            any("apples" in t for t in found_texts),
+            f"Expected 'apples' in results, got: {found_texts}",
+        )
+
+    def test_recall_keyword_search(self):
+        """Test recalling where keyword search might be more relevant."""
+        # "Rex" is a specific name
+        query_text = "What is the name of my dog?"
+        source = SourceMessage(role="user", lang="en")
+        metadata = TreeNodeTextualMemoryMetadata(sources=[source], memory_type="WorkingMemory")
+
+        item = TextualMemoryItem(memory=query_text, metadata=metadata)
+
+        results = self.recaller.recall(item, self.user_name, top_k=5)
+
+        found_texts = [r.memory for r in results]
+        self.assertTrue(
+            any("Rex" in t for t in found_texts), f"Expected 'Rex' in results, got: {found_texts}"
+        )
+
+    def test_perspective_adjustment(self):
+        """Unit test for the _adjust_perspective method specifically."""
+        text = "I went to the store myself."
+        adjusted = self.recaller._adjust_perspective(text, "user", "en")
+        # I -> User, myself -> User himself
+        self.assertIn("User", adjusted)
+        self.assertIn("User himself", adjusted)
+
+        text_zh = "我喜欢吃苹果"
+        adjusted_zh = self.recaller._adjust_perspective(text_zh, "user", "zh")
+        # 我 -> 用户
+        self.assertIn("用户", adjusted_zh)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/memories/textual/test_pre_update_recaller_latency.py b/tests/memories/textual/test_pre_update_recaller_latency.py
new file mode 100644
index 000000000..1cb9e5ecf
--- /dev/null
+++ b/tests/memories/textual/test_pre_update_recaller_latency.py
@@ -0,0 +1,183 @@
+import time
+import unittest
+import uuid
+
+import numpy as np
+
+from dotenv import load_dotenv
+
+from memos.api.handlers.config_builders import build_embedder_config, build_graph_db_config
+from memos.embedders.factory import EmbedderFactory
+from memos.graph_dbs.factory import GraphStoreFactory
+from memos.memories.textual.item import (
+    SourceMessage,
+    TextualMemoryItem,
+    TreeNodeTextualMemoryMetadata,
+)
+from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRecaller
+
+
+# Load environment variables
+load_dotenv()
+
+
+class TestPreUpdateRecallerLatency(unittest.TestCase):
+    """
+    Performance and latency tests for PreUpdateRecaller.
+    These tests are designed to measure latency and might take longer to run.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        # Initialize graph_db and embedder using factories
+        try:
+            cls.graph_db_config = build_graph_db_config()
+            cls.graph_db = GraphStoreFactory.from_config(cls.graph_db_config)
+
+            cls.embedder_config = build_embedder_config()
+            cls.embedder = EmbedderFactory.from_config(cls.embedder_config)
+        except Exception as e:
+            raise unittest.SkipTest(
+                f"Skipping test because initialization failed (likely missing env vars): {e}"
+            ) from e
+
+        cls.recaller = PreUpdateRecaller(cls.graph_db, cls.embedder)
+
+        # Use a unique user name to isolate tests
+        cls.user_name = "test_pre_update_recaller_latency_user_" + str(uuid.uuid4())[:8]
+
+    def setUp(self):
+        # Add a batch of data for latency testing
+        self.added_ids = []
+        self.num_items = 20
+
+        print(f"\nPopulating database with {self.num_items} items for latency test...")
+        for i in range(self.num_items):
+            text = f"This is memory item number {i}. The user might enjoy topic {i % 5}."
+            embedding = self.embedder.embed([text])[0]
+            metadata = {
+                "memory_type": "LongTermMemory",
+                "status": "activated",
+                "embedding": embedding,
+                "created_at": "2023-01-01T00:00:00",
+                "updated_at": "2023-01-01T00:00:00",
+                "tags": [f"tag_{i}"],
+                "key": f"key_{i}",
+                "sources": [],
+            }
+            node_id = str(uuid.uuid4())
+            self.graph_db.add_node(node_id, text, metadata, user_name=self.user_name)
+            self.added_ids.append(node_id)
+
+    def tearDown(self):
+        """Clean up test data."""
+        print("Cleaning up test data...")
+        for node_id in self.added_ids:
+            try:
+                self.graph_db.delete_node(node_id, user_name=self.user_name)
+            except Exception as e:
+                print(f"Error deleting node {node_id}: {e}")
+
+    def measure_network_rtt(self, trials=10):
+        """Measure average network round-trip time."""
+        print(f"Measuring Network RTT (using {trials} probes)...")
+        latencies = []
+
+        # Try to use raw driver for minimal overhead if available (Neo4j specific)
+        if hasattr(self.graph_db, "driver") and hasattr(self.graph_db, "db_name"):
+            print("Using Neo4j driver for direct ping...")
+            try:
+                with self.graph_db.driver.session(database=self.graph_db.db_name) as session:
+                    # Warmup
+                    session.run("RETURN 1").single()
+
+                    for _ in range(trials):
+                        start = time.time()
+                        session.run("RETURN 1").single()
+                        latencies.append((time.time() - start) * 1000)
+            except Exception as e:
+                print(f"Direct driver ping failed: {e}. Falling back to get_node.")
+                latencies = []
+
+        if not latencies:
+            # Fallback to get_node with non-existent ID
+            print("Using get_node for ping...")
+            for _ in range(trials):
+                probe_id = str(uuid.uuid4())
+                start = time.time()
+                self.graph_db.get_node(probe_id, user_name=self.user_name)
+                latencies.append((time.time() - start) * 1000)
+
+        avg_rtt = np.mean(latencies)
+        print(f"Average Network RTT: {avg_rtt:.2f} ms")
+        return avg_rtt
+
+    def test_recall_latency(self):
+        """Test and report recall latency statistics."""
+        avg_rtt = self.measure_network_rtt()
+
+        queries = [
+            "I enjoy topic 1",
+            "What about topic 3?",
+            "Do I have any preferences?",
+            "Tell me about memory item 5",
+        ]
+
+        latencies = []
+
+        # Warmup
+        print("Warming up...")
+        warmup_item = TextualMemoryItem(
+            memory="warmup query",
+            metadata=TreeNodeTextualMemoryMetadata(
+                sources=[SourceMessage(role="user", lang="en")], memory_type="WorkingMemory"
+            ),
+        )
+        self.recaller.recall(warmup_item, self.user_name, top_k=5)
+
+        print(f"Running {len(queries)} queries...")
+        for q in queries:
+            # Pre-calculate embedding to exclude from latency measurement
+            q_embedding = self.embedder.embed([q])[0]
+
+            item = TextualMemoryItem(
+                memory=q,
+                metadata=TreeNodeTextualMemoryMetadata(
+                    sources=[SourceMessage(role="user", lang="en")],
+                    memory_type="WorkingMemory",
+                    embedding=q_embedding,
+                ),
+            )
+
+            start_time = time.time()
+            results = self.recaller.recall(item, self.user_name, top_k=5)
+            end_time = time.time()
+
+            duration_ms = (end_time - start_time) * 1000
+            latencies.append(duration_ms)
+            print(f"Query: '{q}' -> Found {len(results)} results in {duration_ms:.2f} ms")
+
+            # Assert that we actually found results (sanity check)
+            if "preferences" not in q:  # The preferences query might return 0
+                self.assertTrue(len(results) > 0, f"Expected results for query: {q}")
+
+        # Report Results
+        avg_latency = np.mean(latencies)
+        p95_latency = np.percentile(latencies, 95)
+        min_latency = np.min(latencies)
+        max_latency = np.max(latencies)
+        internal_processing = avg_latency - avg_rtt
+
+        print("\n--- Latency Results ---")
+        print(f"Average Network RTT: {avg_rtt:.2f} ms")
{avg_rtt:.2f} ms") + print(f"Average Total Latency: {avg_latency:.2f} ms") + print(f"Estimated Internal Processing: {internal_processing:.2f} ms") + print(f"95th Percentile: {p95_latency:.2f} ms") + print(f"Min Latency: {min_latency:.2f} ms") + print(f"Max Latency: {max_latency:.2f} ms") + + self.assertLess(internal_processing, 200, "Internal processing should be under 200ms") + + +if __name__ == "__main__": + unittest.main()