3 changes: 2 additions & 1 deletion src/memos/api/handlers/component_init.py
@@ -183,7 +183,8 @@ def init_server() -> dict[str, Any]:
         else None
     )
     embedder = EmbedderFactory.from_config(embedder_config)
-    mem_reader = MemReaderFactory.from_config(mem_reader_config)
+    # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection)
+    mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db)
     reranker = RerankerFactory.from_config(reranker_config)
     feedback_reranker = RerankerFactory.from_config(feedback_reranker_config)
     internet_retriever = InternetRetrieverFactory.from_config(
2 changes: 1 addition & 1 deletion src/memos/graph_dbs/base.py
@@ -82,7 +82,7 @@ def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] |
 
     @abstractmethod
     def get_nodes(
-        self, id: str, include_embedding: bool = False, **kwargs
+        self, ids: list, include_embedding: bool = False, **kwargs
     ) -> dict[str, Any] | None:
         """
         Retrieve the metadata and memory of a list of nodes.
3 changes: 2 additions & 1 deletion src/memos/mem_feedback/feedback.py
@@ -76,7 +76,8 @@ def __init__(self, config: MemFeedbackConfig):
         self.llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config(config.extractor_llm)
         self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder)
         self.graph_store: PolarDBGraphDB = GraphStoreFactory.from_config(config.graph_db)
-        self.mem_reader = MemReaderFactory.from_config(config.mem_reader)
+        # Pass graph_store to mem_reader for recall operations (deduplication, conflict detection)
+        self.mem_reader = MemReaderFactory.from_config(config.mem_reader, graph_db=self.graph_store)
 
         self.is_reorganize = config.reorganize
         self.memory_manager: MemoryManager = MemoryManager(
23 changes: 22 additions & 1 deletion src/memos/mem_reader/base.py
@@ -1,17 +1,38 @@
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from memos.configs.mem_reader import BaseMemReaderConfig
 from memos.memories.textual.item import TextualMemoryItem
 
 
+if TYPE_CHECKING:
+    from memos.graph_dbs.base import BaseGraphDB
+
+
 class BaseMemReader(ABC):
     """MemReader interface class for reading information."""
 
+    # Optional graph database for recall operations (deduplication, conflict
+    # detection, etc.)
+    graph_db: "BaseGraphDB | None" = None
+
     @abstractmethod
     def __init__(self, config: BaseMemReaderConfig):
         """Initialize the MemReader with the given configuration."""
 
+    @abstractmethod
+    def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None:
+        """
+        Set the graph database instance for recall operations.
+
+        This enables the mem-reader to perform:
+        - Semantic deduplication: avoid storing duplicate memories
+        - Conflict detection: detect contradictions with existing memories
+
+        Args:
+            graph_db: The graph database instance, or None to disable recall operations.
+        """
+
     @abstractmethod
     def get_memory(
         self, scene_data: list, type: str, info: dict[str, Any], mode: str = "fast"
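
The new `set_graph_db` hook keeps the graph store optional. A minimal sketch of a conforming subclass; the class name and the elided extraction logic are hypothetical, not part of this PR:

```python
from typing import Any

from memos.mem_reader.base import BaseMemReader


class SketchMemReader(BaseMemReader):  # hypothetical subclass for illustration
    def __init__(self, config):
        self.config = config
        self.graph_db = None  # recall disabled until a store is attached

    def set_graph_db(self, graph_db) -> None:
        # Store the handle; extraction paths check self.graph_db before
        # attempting deduplication or conflict detection.
        self.graph_db = graph_db

    def get_memory(self, scene_data: list, type: str, info: dict[str, Any], mode: str = "fast"):
        recall_enabled = self.graph_db is not None  # gate recall on the optional store
        return []  # extraction logic elided
```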
32 changes: 29 additions & 3 deletions src/memos/mem_reader/factory.py
@@ -1,4 +1,4 @@
-from typing import Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
 
 from memos.configs.mem_reader import MemReaderConfigFactory
 from memos.mem_reader.base import BaseMemReader
@@ -8,6 +8,10 @@
 from memos.memos_tools.singleton import singleton_factory
 
 
+if TYPE_CHECKING:
+    from memos.graph_dbs.base import BaseGraphDB
+
+
 class MemReaderFactory(BaseMemReader):
     """Factory class for creating MemReader instances."""
 
@@ -19,9 +23,31 @@ class MemReaderFactory(BaseMemReader):
 
     @classmethod
     @singleton_factory()
-    def from_config(cls, config_factory: MemReaderConfigFactory) -> BaseMemReader:
+    def from_config(
+        cls,
+        config_factory: MemReaderConfigFactory,
+        graph_db: Optional["BaseGraphDB | None"] = None,
+    ) -> BaseMemReader:
+        """
+        Create a MemReader instance from configuration.
+
+        Args:
+            config_factory: Configuration factory for the MemReader.
+            graph_db: Optional graph database instance for recall operations
+                (deduplication, conflict detection). Can also be set later
+                via reader.set_graph_db().
+
+        Returns:
+            Configured MemReader instance.
+        """
         backend = config_factory.backend
         if backend not in cls.backend_to_class:
             raise ValueError(f"Invalid backend: {backend}")
         reader_class = cls.backend_to_class[backend]
-        return reader_class(config_factory.config)
+        reader = reader_class(config_factory.config)
+
+        # Set graph_db if provided (for recall operations)
+        if graph_db is not None:
+            reader.set_graph_db(graph_db)
+
+        return reader
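
The factory makes both wiring styles equivalent. A short sketch, assuming the config objects and `graph_db` are built as in `component_init.py`:

```python
# Eager wiring: attach the graph store at construction time.
reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db)

# Late binding: construct first, attach the store once it exists.
reader = MemReaderFactory.from_config(mem_reader_config)
reader.set_graph_db(graph_db)
```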
70 changes: 61 additions & 9 deletions src/memos/mem_reader/multi_modal_struct.py
@@ -316,6 +316,7 @@ def _get_llm_response(
         custom_tags: list[str] | None = None,
         sources: list | None = None,
         prompt_type: str = "chat",
+        related_memories: str | None = None,
     ) -> dict:
         """
         Override parent method to improve language detection by using actual text content
@@ -326,6 +327,7 @@
             custom_tags: Optional custom tags
             sources: Optional list of SourceMessage objects to extract text content from
             prompt_type: Type of prompt to use ("chat" or "doc")
+            related_memories: Related memories recalled from the graph
 
         Returns:
             LLM response dictionary
@@ -360,7 +362,9 @@
         else:
             template = PROMPT_DICT["chat"][lang]
             examples = PROMPT_DICT["chat"][f"{lang}_example"]
-        prompt = template.replace("${conversation}", mem_str)
+        prompt = template.replace("${conversation}", mem_str).replace(
+            "${reference}", related_memories or ""
+        )
 
         custom_tags_prompt = (
             PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags))
@@ -418,6 +422,7 @@ def _process_string_fine(
         fast_memory_items: list[TextualMemoryItem],
         info: dict[str, Any],
         custom_tags: list[str] | None = None,
+        **kwargs,
     ) -> list[TextualMemoryItem]:
         """
         Process fast mode memory items through LLM to generate fine mode memories.
@@ -454,8 +459,40 @@ def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]:
             # Determine prompt type based on sources
             prompt_type = self._determine_prompt_type(sources)
 
+            # Recall related memories from the graph for deduplication / conflict checks
+            related_memories = None
+            memory_ids = []
+            if self.graph_db:
+                if "user_name" in kwargs:
+                    memory_ids = self.graph_db.search_by_embedding(
+                        vector=self.embedder.embed(mem_str)[0],
+                        top_k=20,
+                        status="activated",
+                        user_name=kwargs.get("user_name"),
+                        filter={
+                            "or": [
+                                {"memory_type": "LongTermMemory"},
+                                {"memory_type": "UserMemory"},
+                                {"memory_type": "WorkingMemory"},
+                            ]
+                        },
+                    )
+                    memory_ids = {r["id"] for r in memory_ids if r.get("id")}
+                    related_memories_list = self.graph_db.get_nodes(
+                        list(memory_ids),
+                        include_embedding=False,
+                        user_name=kwargs.get("user_name"),
+                    )
+                    related_memories = "\n".join(
+                        f"{mem['id']}: {mem['memory']}" for mem in related_memories_list
+                    )
+                else:
+                    logger.warning("user_name missing from kwargs while graph_db is set")
+
             try:
-                resp = self._get_llm_response(mem_str, custom_tags, sources, prompt_type)
+                resp = self._get_llm_response(
+                    mem_str, custom_tags, sources, prompt_type, related_memories
+                )
             except Exception as e:
                 logger.error(f"[MultiModalFine] Error calling LLM: {e}")
                 return fine_items
@@ -469,6 +506,11 @@
                             .replace("长期记忆", "LongTermMemory")
                             .replace("用户记忆", "UserMemory")
                         )
+                        if "merged_from" in m:
+                            for merged_id in m["merged_from"]:
+                                if merged_id not in memory_ids:
+                                    logger.warning(f"merged_from id not valid: {merged_id}")
+                            info_per_item["merged_from"] = m["merged_from"]
                         # Create fine mode memory item (same as simple_struct)
                         node = self._make_memory_item(
                             value=m.get("value", ""),
@@ -485,6 +527,11 @@
                 logger.error(f"[MultiModalFine] parse error: {e}")
             elif resp.get("value") and resp.get("key"):
                 try:
+                    if "merged_from" in resp:
+                        for merged_id in resp["merged_from"]:
+                            if merged_id not in memory_ids:
+                                logger.warning(f"merged_from id not valid: {merged_id}")
+                        info_per_item["merged_from"] = resp["merged_from"]
                     # Create fine mode memory item (same as simple_struct)
                     node = self._make_memory_item(
                         value=resp.get("value", "").strip(),
@@ -533,9 +580,7 @@ def _get_llm_tool_trajectory_response(self, mem_str: str) -> dict:
             return []
 
     def _process_tool_trajectory_fine(
-        self,
-        fast_memory_items: list[TextualMemoryItem],
-        info: dict[str, Any],
+        self, fast_memory_items: list[TextualMemoryItem], info: dict[str, Any], **kwargs
     ) -> list[TextualMemoryItem]:
         """
         Process tool trajectory memory items through LLM to generate fine mode memories.
@@ -618,10 +663,10 @@ def _process_multi_modal_data(
 
         with ContextThreadPoolExecutor(max_workers=2) as executor:
             future_string = executor.submit(
-                self._process_string_fine, fast_memory_items, info, custom_tags
+                self._process_string_fine, fast_memory_items, info, custom_tags, **kwargs
            )
             future_tool = executor.submit(
-                self._process_tool_trajectory_fine, fast_memory_items, info
+                self._process_tool_trajectory_fine, fast_memory_items, info, **kwargs
             )
 
             # Collect results
@@ -710,15 +755,22 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[list[Any]]:
         return scene_data
 
     def _read_memory(
-        self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine"
+        self,
+        messages: list[MessagesType],
+        type: str,
+        info: dict[str, Any],
+        mode: str = "fine",
+        **kwargs,
     ) -> list[list[TextualMemoryItem]]:
         list_scene_data_info = self.get_scene_data_info(messages, type)
 
         memory_list = []
         # Process Q&A pairs concurrently with context propagation
         with ContextThreadPoolExecutor() as executor:
             futures = [
-                executor.submit(self._process_multi_modal_data, scene_data_info, info, mode=mode)
+                executor.submit(
+                    self._process_multi_modal_data, scene_data_info, info, mode=mode, **kwargs
+                )
                 for scene_data_info in list_scene_data_info
             ]
             for future in concurrent.futures.as_completed(futures):
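
To make the recall path above concrete, here is the data flow it implements, shown as a standalone sketch. The graph-store calls mirror the diff, while the user name, query text, and ids are made up:

```python
# 1. Embed the working text and search the graph for candidate neighbors.
hits = graph_db.search_by_embedding(
    vector=embedder.embed("user prefers window seats")[0],
    top_k=20,
    status="activated",
    user_name="alice",
    filter={"or": [
        {"memory_type": "LongTermMemory"},
        {"memory_type": "UserMemory"},
        {"memory_type": "WorkingMemory"},
    ]},
)
memory_ids = {r["id"] for r in hits if r.get("id")}

# 2. Fetch the matching nodes and flatten them into "<id>: <memory>" lines.
nodes = graph_db.get_nodes(list(memory_ids), include_embedding=False, user_name="alice")
related_memories = "\n".join(f"{n['id']}: {n['memory']}" for n in nodes)

# 3. These lines replace ${reference} in the chat prompt. If the LLM reports a
#    merge, e.g. {"merged_from": ["mem-42"]}, each id is checked against
#    memory_ids before being recorded in the new item's metadata.
```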
32 changes: 27 additions & 5 deletions src/memos/mem_reader/simple_struct.py
@@ -5,7 +5,7 @@
 import traceback
 
 from abc import ABC
-from typing import Any, TypeAlias
+from typing import TYPE_CHECKING, Any, TypeAlias
 
 from tqdm import tqdm
 
@@ -16,6 +16,10 @@
 from memos.embedders.factory import EmbedderFactory
 from memos.llms.factory import LLMFactory
 from memos.mem_reader.base import BaseMemReader
+
+
+if TYPE_CHECKING:
+    from memos.graph_dbs.base import BaseGraphDB
 from memos.mem_reader.read_multi_modal import coerce_scene_data, detect_lang
 from memos.mem_reader.utils import (
     count_tokens_text,
@@ -176,6 +180,12 @@ def __init__(self, config: SimpleStructMemReaderConfig):
         self.chat_window_max_tokens = getattr(self.config, "chat_window_max_tokens", 1024)
         self._count_tokens = count_tokens_text
         self.searcher = None
+        # Initialize graph_db as None; it can be set later via set_graph_db for
+        # recall operations
+        self.graph_db = None
+
+    def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None:
+        self.graph_db = graph_db
 
     def _make_memory_item(
         self,
@@ -218,7 +228,7 @@ def _get_llm_response(self, mem_str: str, custom_tags: list[str] | None) -> dict
         lang = detect_lang(mem_str)
         template = PROMPT_DICT["chat"][lang]
         examples = PROMPT_DICT["chat"][f"{lang}_example"]
-        prompt = template.replace("${conversation}", mem_str)
+        prompt = template.replace("${conversation}", mem_str).replace("${reference}", "")
 
         custom_tags_prompt = (
             PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags))
@@ -390,7 +400,12 @@ def _process_transfer_chat_data(
         return chat_read_nodes
 
     def get_memory(
-        self, scene_data: SceneDataInput, type: str, info: dict[str, Any], mode: str = "fine"
+        self,
+        scene_data: SceneDataInput,
+        type: str,
+        info: dict[str, Any],
+        mode: str = "fine",
+        user_name: str | None = None,
     ) -> list[list[TextualMemoryItem]]:
         """
         Extract and classify memory content from scene_data.
@@ -409,6 +424,8 @@
                 - chunk_overlap: Overlap for small chunks (default: 50)
             mode: mem-reader mode, fast for quick process while fine for
                 better understanding via calling llm
+            user_name: the user_name that will later be inserted into the
+                database; may be used in recall.
 
         Returns:
             list[list[TextualMemoryItem]] containing memory content with summaries as keys and original text as values
 
         Raises:
@@ -432,7 +449,7 @@
         # Backward compatibility, after coercing scene_data, we only tackle
         # with standard scene_data type: MessagesType
         standard_scene_data = coerce_scene_data(scene_data, type)
-        return self._read_memory(standard_scene_data, type, info, mode)
+        return self._read_memory(standard_scene_data, type, info, mode, user_name=user_name)
 
     def rewrite_memories(
         self, messages: list[dict], memory_list: list[TextualMemoryItem], user_only: bool = True
@@ -558,7 +575,12 @@
         return memory_list
 
     def _read_memory(
-        self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine"
+        self,
+        messages: list[MessagesType],
+        type: str,
+        info: dict[str, Any],
+        mode: str = "fine",
+        **kwargs,
     ) -> list[list[TextualMemoryItem]]:
         """
         1. raw file:
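
A hypothetical call showing the new `user_name` parameter end to end; the `scene_data` payload and `info` keys are illustrative only:

```python
memories = reader.get_memory(
    scene_data=[[{"role": "user", "content": "I moved to Berlin last month."}]],
    type="chat",
    info={"user_id": "u-123", "session_id": "s-456"},
    mode="fine",
    user_name="u-123",  # forwarded to _read_memory so recall can scope to this user
)
```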
@@ -305,7 +305,8 @@ def init_components() -> dict[str, Any]:
     )
     llm = LLMFactory.from_config(llm_config)
     embedder = EmbedderFactory.from_config(embedder_config)
-    mem_reader = MemReaderFactory.from_config(mem_reader_config)
+    # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection)
+    mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db)
     reranker = RerankerFactory.from_config(reranker_config)
     feedback_reranker = RerankerFactory.from_config(feedback_reranker_config)
     internet_retriever = InternetRetrieverFactory.from_config(