Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/memos/api/handlers/component_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from memos.memories.textual.simple_tree import SimpleTreeTextMemory
from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer


Expand Down Expand Up @@ -192,8 +193,14 @@ def init_server() -> dict[str, Any]:
embedder = EmbedderFactory.from_config(embedder_config)
nli_client = NLIClient(base_url=nli_client_config["base_url"])
memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db)
pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder)
# Pass graph_db to mem_reader for recall operations (deduplication, conflict detection)
mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db)
mem_reader = MemReaderFactory.from_config(
mem_reader_config,
graph_db=graph_db,
pre_update_retriever=pre_update_retriever,
history_manager=memory_history_manager,
)
reranker = RerankerFactory.from_config(reranker_config)
feedback_reranker = RerankerFactory.from_config(feedback_reranker_config)
internet_retriever = InternetRetrieverFactory.from_config(
Expand Down
6 changes: 5 additions & 1 deletion src/memos/configs/mem_reader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, ClassVar
from typing import Any, ClassVar, Literal

from pydantic import ConfigDict, Field, field_validator, model_validator

Expand Down Expand Up @@ -65,6 +65,10 @@ class MultiModalStructMemReaderConfig(BaseMemReaderConfig):
default=None,
description="Skills directory for the MemReader",
)
memory_version_switch: Literal["on", "off"] = Field(
default="off",
description="Turn on memory version or off",
)


class StrategyStructMemReaderConfig(BaseMemReaderConfig):
Expand Down
18 changes: 18 additions & 0 deletions src/memos/mem_reader/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
from memos.mem_reader.multi_modal_struct import MultiModalStructMemReader
from memos.mem_reader.simple_struct import SimpleStructMemReader
from memos.mem_reader.strategy_struct import StrategyStructMemReader
from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever
from memos.memos_tools.singleton import singleton_factory


if TYPE_CHECKING:
from memos.graph_dbs.base import BaseGraphDB
from memos.memories.textual.tree_text_memory.organize.history_manager import (
MemoryHistoryManager,
)
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher


Expand All @@ -29,6 +33,8 @@ def from_config(
config_factory: MemReaderConfigFactory,
graph_db: Optional["BaseGraphDB | None"] = None,
searcher: Optional["Searcher | None"] = None,
pre_update_retriever: PreUpdateRetriever | None = None,
history_manager: Optional["MemoryHistoryManager | None"] = None,
) -> BaseMemReader:
"""
Create a MemReader instance from configuration.
Expand All @@ -55,4 +61,16 @@ def from_config(
if searcher is not None:
reader.set_searcher(searcher)

if pre_update_retriever is not None:
if hasattr(reader, "set_pre_update_retriever"):
reader.set_pre_update_retriever(pre_update_retriever)
else:
reader.pre_update_retriever = pre_update_retriever

if history_manager is not None:
if hasattr(reader, "set_history_manager"):
reader.set_history_manager(history_manager)
else:
reader.history_manager = history_manager

return reader
44 changes: 44 additions & 0 deletions src/memos/mem_reader/multi_modal_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
simple_config = SimpleStructMemReaderConfig(**config_dict)
super().__init__(simple_config)

self.pre_update_retriever = None
self.history_manager = None
self.memory_version_switch = getattr(config, "memory_version_switch", "off")

# Initialize MultiModalParser for routing to different parsers
self.multi_modal_parser = MultiModalParser(
embedder=self.embedder,
Expand Down Expand Up @@ -808,6 +812,39 @@ def _process_tool_trajectory_fine(

return fine_memory_items

def _fast_resolve_memory_duplicates_and_conflicts(
self, fast_memory_items: list[TextualMemoryItem], user_name: str
) -> None:
"""
1. Recall related memories
2. Fast conflict/duplication check with NLI model
3. Attach conflicting/duplicate old memory contents onto fast memory items
4. Mark conflicting/duplicate old memory nodes as "resolving", making them invisible to /search,
but still visible for other conflict/duplication checks' recalls.
"""
if not self.pre_update_retriever or not self.history_manager:
logger.warning(
"[MultiModalStruct] PreUpdateRetriever or HistoryManager is not initialized."
)
return

for item in fast_memory_items:
try:
# recall related memories
related = self.pre_update_retriever.retrieve(
item=item,
user_name=user_name,
)
# NLI check & attaching contents
conflicting_or_duplicate_items = self.history_manager.resolve_history_via_nli(
item, related
)
# mark delete
self.history_manager.mark_memory_status(conflicting_or_duplicate_items, "resolving")

except Exception as e:
logger.warning(f"[MultiModalStruct] Fast recall failed: {e}")

@timed
def _process_multi_modal_data(
self, scene_data_info: MessagesType, info, mode: str = "fine", **kwargs
Expand Down Expand Up @@ -856,6 +893,13 @@ def _process_multi_modal_data(
scene_data_info, info, mode="fast", need_emb=False, **kwargs
)
fast_memory_items = self._concat_multi_modal_memories(all_memory_items)

# Perform conflict/duplicate check with old memories
# TODO: find a better way to pass in the user_name
user_name = kwargs.get("user_name")
if self.memory_version_switch == "on":
self._fast_resolve_memory_duplicates_and_conflicts(fast_memory_items, user_name)

if mode == "fast":
return fast_memory_items
else:
Expand Down
12 changes: 12 additions & 0 deletions src/memos/mem_reader/simple_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@

if TYPE_CHECKING:
from memos.graph_dbs.base import BaseGraphDB
from memos.memories.textual.tree_text_memory.organize.history_manager import (
MemoryHistoryManager,
)
from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
from memos.mem_reader.read_multi_modal import coerce_scene_data, detect_lang
from memos.mem_reader.utils import (
Expand Down Expand Up @@ -184,13 +188,21 @@ def __init__(self, config: SimpleStructMemReaderConfig):
# Initialize graph_db as None, can be set later via set_graph_db for
# recall operations
self.graph_db = None
self.pre_update_retriever = None
self.history_manager = None

def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None:
self.graph_db = graph_db

def set_searcher(self, searcher: "Searcher | None") -> None:
self.searcher = searcher

def set_pre_update_retriever(self, pre_update_retriever: "PreUpdateRetriever | None") -> None:
self.pre_update_retriever = pre_update_retriever

def set_history_manager(self, history_manager: "MemoryHistoryManager | None") -> None:
self.history_manager = history_manager

def _make_memory_item(
self,
value: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from memos.configs.reranker import RerankerConfigFactory
from memos.configs.vec_db import VectorDBConfigFactory
from memos.embedders.factory import EmbedderFactory
from memos.extras.nli_model.client import NLIClient
from memos.graph_dbs.factory import GraphStoreFactory
from memos.llms.factory import LLMFactory
from memos.log import get_logger
Expand All @@ -30,10 +31,12 @@
)
from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
from memos.memories.textual.simple_tree import SimpleTreeTextMemory
from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import (
InternetRetrieverFactory,
)
from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer


Expand Down Expand Up @@ -287,6 +290,7 @@ def init_components() -> dict[str, Any]:
graph_db_config = build_graph_db_config()
llm_config = build_llm_config()
embedder_config = build_embedder_config()
nli_client_config = APIConfig.get_nli_config()
mem_reader_config = build_mem_reader_config()
reranker_config = build_reranker_config()
feedback_reranker_config = build_feedback_reranker_config()
Expand All @@ -307,8 +311,16 @@ def init_components() -> dict[str, Any]:
)
llm = LLMFactory.from_config(llm_config)
embedder = EmbedderFactory.from_config(embedder_config)
nli_client = NLIClient(base_url=nli_client_config["base_url"])
memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db)
pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder)
# Pass graph_db to mem_reader for recall operations (deduplication, conflict detection)
mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db)
mem_reader = MemReaderFactory.from_config(
mem_reader_config,
graph_db=graph_db,
pre_update_retriever=pre_update_retriever,
history_manager=memory_history_manager,
)
reranker = RerankerFactory.from_config(reranker_config)
feedback_reranker = RerankerFactory.from_config(feedback_reranker_config)
internet_retriever = InternetRetrieverFactory.from_config(
Expand Down
Loading