From fc40de9e8c80d96cac87bff6266d010129090605 Mon Sep 17 00:00:00 2001
From: bittergreen
Date: Wed, 4 Feb 2026 10:39:13 +0800
Subject: [PATCH 1/3] feat: Building fast-add related functions for memory versions.

---
 .../organize/history_manager.py | 156 ++++++++++++++++++
 1 file changed, 156 insertions(+)

diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py
index 1afdc9281..97bbe7483 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py
@@ -1,4 +1,5 @@
 import logging
+import time
 
 from typing import Literal
 
@@ -67,6 +68,36 @@ def _detach_related_content(new_item: TextualMemoryItem) -> None:
     return
 
 
+def _rebuild_fast_node_history(
+    item: TextualMemoryItem,
+    replacements: dict[int, list[ArchivedTextualMemory]],
+) -> None:
+    """
+    Reconstruct the history list of a fast node:
+    1. Replace resolved items with their evolved versions.
+    2. Deduplicate by ID while preserving the newest versions.
+    """
+    new_history = {}
+
+    def _add(history_item):
+        item_id = history_item.archived_memory_id
+        current = new_history.get(item_id)
+
+        if current is None or history_item.version > current.version:
+            new_history[item_id] = history_item
+
+    # Apply replacements and filter superseded items
+    for i, h in enumerate(item.metadata.history):
+        if i in replacements:
+            # This item is resolved; insert its replacements
+            for replacement_item in replacements[i]:
+                _add(replacement_item)
+        else:
+            _add(h)
+
+    item.metadata.history = list(new_history.values())
+
+
 class MemoryHistoryManager:
     def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None:
         """
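
The helper above is pure list surgery, so its "newest version per ID wins" rule can be exercised in isolation. A minimal sketch, using a trimmed stand-in for ArchivedTextualMemory (the real model carries more fields):

    from dataclasses import dataclass

    @dataclass
    class FakeArchived:  # sketch-only stand-in for ArchivedTextualMemory
        archived_memory_id: str
        version: int

    def rebuild(history, replacements):
        # Same rule as _rebuild_fast_node_history: the newest version per ID wins.
        new_history = {}
        for i, h in enumerate(history):
            for candidate in replacements.get(i, [h]):
                current = new_history.get(candidate.archived_memory_id)
                if current is None or candidate.version > current.version:
                    new_history[candidate.archived_memory_id] = candidate
        return list(new_history.values())

    history = [FakeArchived("fast-1", 1), FakeArchived("slow-2", 1)]
    replacements = {0: [FakeArchived("slow-2", 2), FakeArchived("slow-3", 1)]}
    print(rebuild(history, replacements))
    # [FakeArchived(archived_memory_id='slow-2', version=2),
    #  FakeArchived(archived_memory_id='slow-3', version=1)]

Note how the stale slow-2 v1 entry at index 1 is superseded by the v2 replacement coming from index 0.
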
@@ -79,6 +110,131 @@ def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None:
         self.nli_client = nli_client
         self.graph_db = graph_db
 
+    def _check_and_fetch_replacements(
+        self, item: TextualMemoryItem, pending_indices: list[int]
+    ) -> dict[int, list[ArchivedTextualMemory]]:
+        """
+        Check DB status for pending items. If 'deleted', fetch evolved nodes.
+
+        Returns:
+            Mapping from each resolved history index to the evolved ArchivedTextualMemory items that replace it.
+        """
+        pending_ids = [item.metadata.history[i].archived_memory_id for i in pending_indices]
+
+        # Batch-fetch pending nodes to check status
+        nodes_data = self.graph_db.get_nodes(ids=pending_ids) or []
+        nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n}
+
+        replacements = {}
+
+        for i in pending_indices:
+            h_item = item.metadata.history[i]
+            node_data = nodes_map.get(h_item.archived_memory_id)
+
+            if not node_data:
+                continue
+
+            metadata = node_data.get("metadata", {})
+            status = metadata.get("status")
+
+            # Condition: a fast node has been processed once it is marked as 'deleted'
+            if status == "deleted":
+                evolve_to_ids = metadata.get("evolve_to", [])
+
+                new_items = self._fetch_evolved_nodes(evolve_to_ids, h_item.update_type)
+                replacements[i] = new_items
+
+                logger.info(
+                    f"[MemoryHistoryManager] Resolved fast history item {h_item.archived_memory_id} -> {evolve_to_ids}"
+                )
+
+        return replacements
+
+    def _fetch_evolved_nodes(
+        self, evolve_to_ids: list[str], update_type: str
+    ) -> list[ArchivedTextualMemory]:
+        """Fetch the nodes that the fast node evolved into and convert them to archive format."""
+        if not evolve_to_ids:
+            return []
+
+        evolved_nodes = self.graph_db.get_nodes(ids=evolve_to_ids) or []
+        results = []
+
+        for enode in evolved_nodes:
+            if not enode or "id" not in enode:
+                continue
+
+            enode_meta = enode.get("metadata", {})
+
+            # Create a new archived memory inheriting the update_type (conflict/duplicate)
+            new_archived = ArchivedTextualMemory(
+                version=enode_meta.get("version", 1),
+                is_fast=enode_meta.get("is_fast", False),
+                memory=enode.get("memory", ""),
+                update_type=update_type,
+                archived_memory_id=enode.get("id"),
+                created_at=enode_meta.get("created_at"),
+            )
+            results.append(new_archived)
+
+        return results
+
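
Both helpers treat graph nodes as plain dicts. For orientation, the shape they expect for a resolved fast node looks roughly like this; the keys are the ones read above, the values are invented:

    # Illustrative only: the payload _check_and_fetch_replacements expects
    # graph_db.get_nodes to return for a processed fast node.
    resolved_fast_node = {
        "id": "mem-fast-001",                   # invented ID
        "memory": "User moved to Berlin.",
        "metadata": {
            "status": "deleted",                # 'deleted' => the fast node was processed
            "evolve_to": ["mem-fine-002"],      # IDs of the nodes it evolved into
            "version": 1,
            "is_fast": False,
            "created_at": "2026-02-04T10:39:13+08:00",
        },
    }
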
+    def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int = 30) -> None:
+        """
+        Scan the item's history. If any history item is marked as `is_fast`,
+        wait for it to be resolved (i.e., status becomes 'deleted' in the DB).
+        When resolved, replace the fast item with the nodes referenced in its `evolve_to` field.
+        Finally, deduplicate the history.
+
+        Args:
+            item: The memory item containing the history to check.
+            timeout_sec: Maximum time to wait for resolution, in seconds.
+        """
+        start_time = time.time()
+
+        # 1. Identify pending items (fast nodes)
+        pending_indices = [
+            i
+            for i, h in enumerate(item.metadata.history)
+            if getattr(h, "is_fast", False) and h.archived_memory_id
+        ]
+
+        while True:
+            if not pending_indices:
+                # All fast nodes resolved, or none existed
+                break
+
+            if time.time() - start_time > timeout_sec:
+                logger.warning(
+                    f"[MemoryHistoryManager] Timeout waiting for fast history resolution for item {item.id}"
+                )
+                # Remove pending fast nodes from history
+                item.metadata.history = [
+                    h
+                    for h in item.metadata.history
+                    if not (getattr(h, "is_fast", False) and h.archived_memory_id)
+                ]
+                break
+
+            # 2. Check status of the fast nodes and fetch replacements for evolved ones
+            replacements = self._check_and_fetch_replacements(item, pending_indices)
+
+            # 3. If we have any resolved items, rebuild the history
+            if replacements:
+                _rebuild_fast_node_history(item, replacements)
+
+            # Check if we are done (no pending items left)
+            pending_indices = [
+                i
+                for i, h in enumerate(item.metadata.history)
+                if getattr(h, "is_fast", False) and h.archived_memory_id
+            ]
+
+            if pending_indices:
+                time.sleep(1)  # Avoid polling the DB too frequently
+
+        return
+
     def resolve_history_via_nli(
         self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem]
     ) -> list[TextualMemoryItem]:
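
End to end, this patch gives MemoryHistoryManager a blocking entry point for readers of fast items. A hypothetical call site; the base_url, graph_db, and item here are placeholders, not part of the patch:

    manager = MemoryHistoryManager(
        nli_client=NLIClient(base_url="http://localhost:8080"),  # placeholder URL
        graph_db=graph_db,
    )

    # Polls the graph DB about once per second for up to 10 seconds, swapping
    # each resolved fast history entry for its evolved nodes; entries still
    # pending at the deadline are dropped from the history.
    manager.wait_and_update_fast_history(item, timeout_sec=10)
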
From 23f23975d62377d9126745ef48bc8ea20bfcac52 Mon Sep 17 00:00:00 2001
From: bittergreen
Date: Wed, 4 Feb 2026 10:46:39 +0800
Subject: [PATCH 2/3] feat: Building fast-add related functions for memory versions.

---
 .../organize/history_manager.py | 112 +++++++++---------
 1 file changed, 56 insertions(+), 56 deletions(-)

diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py
index 97bbe7483..4cdfac985 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py
@@ -179,62 +179,6 @@ def _fetch_evolved_nodes(
 
         return results
 
-    def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int = 30) -> None:
-        """
-        Scan the item's history. If any history item is marked as `is_fast`,
-        wait for it to be resolved (i.e., status becomes 'deleted' in the DB).
-        When resolved, replace the fast item with the nodes referenced in its `evolve_to` field.
-        Finally, deduplicate the history.
-
-        Args:
-            item: The memory item containing the history to check.
-            timeout_sec: Maximum time to wait for resolution, in seconds.
-        """
-        start_time = time.time()
-
-        # 1. Identify pending items (fast nodes)
-        pending_indices = [
-            i
-            for i, h in enumerate(item.metadata.history)
-            if getattr(h, "is_fast", False) and h.archived_memory_id
-        ]
-
-        while True:
-            if not pending_indices:
-                # All fast nodes resolved, or none existed
-                break
-
-            if time.time() - start_time > timeout_sec:
-                logger.warning(
-                    f"[MemoryHistoryManager] Timeout waiting for fast history resolution for item {item.id}"
-                )
-                # Remove pending fast nodes from history
-                item.metadata.history = [
-                    h
-                    for h in item.metadata.history
-                    if not (getattr(h, "is_fast", False) and h.archived_memory_id)
-                ]
-                break
-
-            # 2. Check status of the fast nodes and fetch replacements for evolved ones
-            replacements = self._check_and_fetch_replacements(item, pending_indices)
-
-            # 3. If we have any resolved items, rebuild the history
-            if replacements:
-                _rebuild_fast_node_history(item, replacements)
-
-            # Check if we are done (no pending items left)
-            pending_indices = [
-                i
-                for i, h in enumerate(item.metadata.history)
-                if getattr(h, "is_fast", False) and h.archived_memory_id
-            ]
-
-            if pending_indices:
-                time.sleep(1)  # Avoid polling the DB too frequently
-
-        return
-
     def resolve_history_via_nli(
         self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem]
     ) -> list[TextualMemoryItem]:
@@ -293,6 +237,62 @@ def resolve_history_via_nli(
 
         return duplicate_memories + conflict_memories
 
+    def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int = 30) -> None:
+        """
+        Scan the item's history. If any history item is marked as `is_fast`,
+        wait for it to be resolved (i.e., status becomes 'deleted' in the DB).
+        When resolved, replace the fast item with the nodes referenced in its `evolve_to` field.
+        Finally, deduplicate the history.
+
+        Args:
+            item: The memory item containing the history to check.
+            timeout_sec: Maximum time to wait for resolution, in seconds.
+        """
+        start_time = time.time()
+
+        # 1. Identify pending items (fast nodes)
+        pending_indices = [
+            i
+            for i, h in enumerate(item.metadata.history)
+            if getattr(h, "is_fast", False) and h.archived_memory_id
+        ]
+
+        while True:
+            if not pending_indices:
+                # All fast nodes resolved, or none existed
+                break
+
+            if time.time() - start_time > timeout_sec:
+                logger.warning(
+                    f"[MemoryHistoryManager] Timeout waiting for fast history resolution for item {item.id}"
+                )
+                # Remove pending fast nodes from history
+                item.metadata.history = [
+                    h
+                    for h in item.metadata.history
+                    if not (getattr(h, "is_fast", False) and h.archived_memory_id)
+                ]
+                break
+
+            # 2. Check status of the fast nodes and fetch replacements for evolved ones
+            replacements = self._check_and_fetch_replacements(item, pending_indices)
+
+            # 3. If we have any resolved items, rebuild the history
+            if replacements:
+                _rebuild_fast_node_history(item, replacements)
+
+            # Check if we are done (no pending items left)
+            pending_indices = [
+                i
+                for i, h in enumerate(item.metadata.history)
+                if getattr(h, "is_fast", False) and h.archived_memory_id
+            ]
+
+            if pending_indices:
+                time.sleep(1)  # Avoid polling the DB too frequently
+
+        return
+
     def mark_memory_status(
         self,
         memory_items: list[TextualMemoryItem],
From 0b6d9a357c31e19cea0a949f43a956ff34bb1bbb Mon Sep 17 00:00:00 2001
From: bittergreen
Date: Wed, 4 Feb 2026 11:30:26 +0800
Subject: [PATCH 3/3] feat: supporting memory versions in the fast-add process

---
 src/memos/api/handlers/component_init.py   |  9 +++-
 src/memos/configs/mem_reader.py            |  6 ++-
 src/memos/mem_reader/factory.py            | 18 ++++++++
 src/memos/mem_reader/multi_modal_struct.py | 44 +++++++++++++++++++
 src/memos/mem_reader/simple_struct.py      | 12 +++++
 .../init_components_for_scheduler.py       | 14 +++++-
 6 files changed, 100 insertions(+), 3 deletions(-)

diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py
index ba527d602..2e27e9da5 100644
--- a/src/memos/api/handlers/component_init.py
+++ b/src/memos/api/handlers/component_init.py
@@ -45,6 +45,7 @@
 from memos.memories.textual.simple_tree import SimpleTreeTextMemory
 from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager
 from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
+from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever
 from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
 
 
@@ -192,8 +193,14 @@ def init_server() -> dict[str, Any]:
     embedder = EmbedderFactory.from_config(embedder_config)
     nli_client = NLIClient(base_url=nli_client_config["base_url"])
     memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db)
+    pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder)
     # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection)
-    mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db)
+    mem_reader = MemReaderFactory.from_config(
+        mem_reader_config,
+        graph_db=graph_db,
+        pre_update_retriever=pre_update_retriever,
+        history_manager=memory_history_manager,
+    )
     reranker = RerankerFactory.from_config(reranker_config)
     feedback_reranker = RerankerFactory.from_config(feedback_reranker_config)
     internet_retriever = InternetRetrieverFactory.from_config(
diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py
index 4bd7953c0..98aff981f 100644
--- a/src/memos/configs/mem_reader.py
+++ b/src/memos/configs/mem_reader.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Any, ClassVar
+from typing import Any, ClassVar, Literal
 
 from pydantic import ConfigDict, Field, field_validator, model_validator
 
@@ -65,6 +65,10 @@ class MultiModalStructMemReaderConfig(BaseMemReaderConfig):
         default=None,
         description="Skills directory for the MemReader",
     )
+    memory_version_switch: Literal["on", "off"] = Field(
+        default="off",
+        description="Whether memory versioning is enabled ('on') or disabled ('off')",
+    )
 
 
 class StrategyStructMemReaderConfig(BaseMemReaderConfig):
diff --git a/src/memos/mem_reader/factory.py b/src/memos/mem_reader/factory.py
index 7bd551fb8..0907168a7 100644
--- a/src/memos/mem_reader/factory.py
+++ b/src/memos/mem_reader/factory.py
@@ -5,11 +5,15 @@
 from memos.mem_reader.multi_modal_struct import MultiModalStructMemReader
 from memos.mem_reader.simple_struct import SimpleStructMemReader
 from memos.mem_reader.strategy_struct import StrategyStructMemReader
+from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever
 from memos.memos_tools.singleton import singleton_factory
 
 
 if TYPE_CHECKING:
     from memos.graph_dbs.base import BaseGraphDB
+    from memos.memories.textual.tree_text_memory.organize.history_manager import (
+        MemoryHistoryManager,
+    )
     from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
 
 
@@ -29,6 +33,8 @@ def from_config(
         config_factory: MemReaderConfigFactory,
         graph_db: Optional["BaseGraphDB | None"] = None,
         searcher: Optional["Searcher | None"] = None,
+        pre_update_retriever: PreUpdateRetriever | None = None,
+        history_manager: Optional["MemoryHistoryManager | None"] = None,
     ) -> BaseMemReader:
         """
         Create a MemReader instance from configuration.
@@ -55,4 +61,16 @@ def from_config(
         if searcher is not None:
             reader.set_searcher(searcher)
 
+        if pre_update_retriever is not None:
+            if hasattr(reader, "set_pre_update_retriever"):
+                reader.set_pre_update_retriever(pre_update_retriever)
+            else:
+                reader.pre_update_retriever = pre_update_retriever
+
+        if history_manager is not None:
+            if hasattr(reader, "set_history_manager"):
+                reader.set_history_manager(history_manager)
+            else:
+                reader.history_manager = history_manager
+
         return reader
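
Since memory_version_switch defaults to "off", existing configs keep their current behavior. A sketch of opting in and wiring the new collaborators through the factory; only memory_version_switch comes from this patch, while the backend key, the other config fields, and the graph_db/embedder/nli_client objects are assumptions:

    config = MemReaderConfigFactory(
        backend="multi_modal_struct",            # assumed backend key
        config={"memory_version_switch": "on"},  # plus the reader's usual fields
    )

    mem_reader = MemReaderFactory.from_config(
        config,
        graph_db=graph_db,
        pre_update_retriever=PreUpdateRetriever(graph_db=graph_db, embedder=embedder),
        history_manager=MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db),
    )

The hasattr fallback in from_config means readers that predate the setter methods still receive the collaborators as plain attributes.
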
diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index f6a016556..cdfa14a69 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -52,6 +52,10 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
         simple_config = SimpleStructMemReaderConfig(**config_dict)
         super().__init__(simple_config)
 
+        self.pre_update_retriever = None
+        self.history_manager = None
+        self.memory_version_switch = getattr(config, "memory_version_switch", "off")
+
         # Initialize MultiModalParser for routing to different parsers
         self.multi_modal_parser = MultiModalParser(
             embedder=self.embedder,
@@ -808,6 +812,39 @@ def _process_tool_trajectory_fine(
 
         return fine_memory_items
 
+    def _fast_resolve_memory_duplicates_and_conflicts(
+        self, fast_memory_items: list[TextualMemoryItem], user_name: str
+    ) -> None:
+        """
+        1. Recall related memories.
+        2. Run a fast conflict/duplication check with the NLI model.
+        3. Attach conflicting/duplicate old memory contents onto the fast memory items.
+        4. Mark conflicting/duplicate old memory nodes as "resolving", making them invisible
+           to /search but still visible to recalls from other conflict/duplication checks.
+        """
+        if not self.pre_update_retriever or not self.history_manager:
+            logger.warning(
+                "[MultiModalStruct] PreUpdateRetriever or HistoryManager is not initialized."
+            )
+            return
+
+        for item in fast_memory_items:
+            try:
+                # Recall related memories
+                related = self.pre_update_retriever.retrieve(
+                    item=item,
+                    user_name=user_name,
+                )
+                # NLI check & attach contents
+                conflicting_or_duplicate_items = self.history_manager.resolve_history_via_nli(
+                    item, related
+                )
+                # Mark old nodes as "resolving"
+                self.history_manager.mark_memory_status(conflicting_or_duplicate_items, "resolving")
+
+            except Exception as e:
+                logger.warning(f"[MultiModalStruct] Fast recall failed: {e}")
+
     @timed
     def _process_multi_modal_data(
         self, scene_data_info: MessagesType, info, mode: str = "fine", **kwargs
@@ -856,6 +893,13 @@ def _process_multi_modal_data(
             scene_data_info, info, mode="fast", need_emb=False, **kwargs
         )
         fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
+
+        # Perform conflict/duplicate check against old memories
+        # TODO: find a better way to pass in the user_name
+        user_name = kwargs.get("user_name")
+        if self.memory_version_switch == "on":
+            self._fast_resolve_memory_duplicates_and_conflicts(fast_memory_items, user_name)
+
         if mode == "fast":
             return fast_memory_items
         else:
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index 2c4fee853..0f3fda1df 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -20,6 +20,10 @@
 if TYPE_CHECKING:
     from memos.graph_dbs.base import BaseGraphDB
+    from memos.memories.textual.tree_text_memory.organize.history_manager import (
+        MemoryHistoryManager,
+    )
+    from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever
     from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
 
 from memos.mem_reader.read_multi_modal import coerce_scene_data, detect_lang
 from memos.mem_reader.utils import (
@@ -184,6 +188,8 @@ def __init__(self, config: SimpleStructMemReaderConfig):
         # Initialize graph_db as None, can be set later via set_graph_db for
         # recall operations
         self.graph_db = None
+        self.pre_update_retriever = None
+        self.history_manager = None
 
     def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None:
         self.graph_db = graph_db
@@ -191,6 +197,12 @@ def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None:
     def set_searcher(self, searcher: "Searcher | None") -> None:
         self.searcher = searcher
 
+    def set_pre_update_retriever(self, pre_update_retriever: "PreUpdateRetriever | None") -> None:
+        self.pre_update_retriever = pre_update_retriever
+
+    def set_history_manager(self, history_manager: "MemoryHistoryManager | None") -> None:
+        self.history_manager = history_manager
+
     def _make_memory_item(
         self,
         value: str,
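
Read together with patch 1, the lifecycle looks two-phase: the fast path attaches duplicate/conflict history to new items and marks the old nodes "resolving", and consumers later block until background processing has evolved the fast nodes. A rough sketch; the call site, arguments, and the manager instance are illustrative, not from the patch:

    # Phase 1: fast add with versioning enabled ("user_name" arrives via kwargs).
    fast_items = reader._process_multi_modal_data(
        scene_data_info, info, mode="fast", user_name="alice"
    )

    # Phase 2: before serving a fast item, wait for its history to settle,
    # i.e. until each `is_fast` entry is replaced by its `evolve_to` nodes.
    for item in fast_items:
        memory_history_manager.wait_and_update_fast_history(item, timeout_sec=30)
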
diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
index b103acf3a..883c191fe 100644
--- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
+++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
@@ -12,6 +12,7 @@
 from memos.configs.reranker import RerankerConfigFactory
 from memos.configs.vec_db import VectorDBConfigFactory
 from memos.embedders.factory import EmbedderFactory
+from memos.extras.nli_model.client import NLIClient
 from memos.graph_dbs.factory import GraphStoreFactory
 from memos.llms.factory import LLMFactory
 from memos.log import get_logger
@@ -30,10 +31,12 @@
 )
 from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
 from memos.memories.textual.simple_tree import SimpleTreeTextMemory
+from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager
 from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
 from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import (
     InternetRetrieverFactory,
 )
+from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever
 from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
 
 
@@ -287,6 +290,7 @@ def init_components() -> dict[str, Any]:
     graph_db_config = build_graph_db_config()
     llm_config = build_llm_config()
     embedder_config = build_embedder_config()
+    nli_client_config = APIConfig.get_nli_config()
     mem_reader_config = build_mem_reader_config()
     reranker_config = build_reranker_config()
     feedback_reranker_config = build_feedback_reranker_config()
@@ -307,8 +311,16 @@ def init_components() -> dict[str, Any]:
     )
     llm = LLMFactory.from_config(llm_config)
     embedder = EmbedderFactory.from_config(embedder_config)
+    nli_client = NLIClient(base_url=nli_client_config["base_url"])
+    memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db)
+    pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder)
     # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection)
-    mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db)
+    mem_reader = MemReaderFactory.from_config(
+        mem_reader_config,
+        graph_db=graph_db,
+        pre_update_retriever=pre_update_retriever,
+        history_manager=memory_history_manager,
+    )
     reranker = RerankerFactory.from_config(reranker_config)
     feedback_reranker = RerankerFactory.from_config(feedback_reranker_config)
     internet_retriever = InternetRetrieverFactory.from_config(