diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 56f8ac195..76af6decf 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -183,7 +183,8 @@ def init_server() -> dict[str, Any]: else None ) embedder = EmbedderFactory.from_config(embedder_config) - mem_reader = MemReaderFactory.from_config(mem_reader_config) + # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) + mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db) reranker = RerankerFactory.from_config(reranker_config) feedback_reranker = RerankerFactory.from_config(feedback_reranker_config) internet_retriever = InternetRetrieverFactory.from_config( diff --git a/src/memos/graph_dbs/base.py b/src/memos/graph_dbs/base.py index b76ed9d08..130b66a3d 100644 --- a/src/memos/graph_dbs/base.py +++ b/src/memos/graph_dbs/base.py @@ -19,12 +19,13 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: """ @abstractmethod - def update_node(self, id: str, fields: dict[str, Any]) -> None: + def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: """ Update attributes of an existing node. Args: id: Node identifier to be updated. fields: Dictionary of fields to update. + user_name: Optional user name used to scope the update. """ @abstractmethod @@ -70,7 +71,7 @@ def edge_exists(self, source_id: str, target_id: str, type: str) -> bool: # Graph Query & Reasoning @abstractmethod - def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | None: + def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None: """ Retrieve the metadata and content of a node. Args: @@ -82,7 +83,7 @@ def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | @abstractmethod def get_nodes( - self, id: str, include_embedding: bool = False, **kwargs + self, ids: list, include_embedding: bool = False, **kwargs ) -> dict[str, Any] | None: """ Retrieve the metadata and memory of a list of nodes. @@ -160,13 +161,17 @@ def search_by_embedding(self, vector: list[float], top_k: int = 5, **kwargs) -> """ @abstractmethod - def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: + def get_by_metadata( + self, filters: list[dict[str, Any]], status: str | None = None + ) -> list[str]: """ Retrieve node IDs that match given metadata filters. Args: filters (dict[str, Any]): A dictionary of attribute-value filters. Example: {"topic": "psychology", "importance": 2} + status (str, optional): Filter by status (e.g., 'activated', 'archived'). + If None, no status filter is applied. Returns: list[str]: Node IDs whose metadata match the filter conditions. @@ -239,13 +244,17 @@ def import_graph(self, data: dict[str, Any]) -> None: """ @abstractmethod - def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> list[dict]: + def get_all_memory_items( + self, scope: str, include_embedding: bool = False, status: str | None = None + ) -> list[dict]: """ Retrieve all memory items of a specific memory_type. Args: scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. include_embedding: with/without embedding + status (str, optional): Filter by status (e.g., 'activated', 'archived'). + If None, no status filter is applied. Returns: list[dict]: Full list of memory items under this scope. 
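Reviewer note: a minimal usage sketch (not part of the patch) of the widened `BaseGraphDB` surface above — the new `status` filter on `get_by_metadata` / `get_all_memory_items` and the `user_name` pass-through on `update_node`. The `graph_db` instance and all literal values are illustrative.

```python
# Hypothetical caller, assuming `graph_db` is any configured BaseGraphDB backend.
filters = [{"field": "memory_type", "op": "=", "value": "LongTermMemory"}]

# Only recall memories that are still active; status=None keeps the old behavior.
active_ids = graph_db.get_by_metadata(filters, status="activated")

# Archive a node that was merged away, scoped to one user via the new parameter.
for node_id in active_ids:
    graph_db.update_node(node_id, {"status": "archived"}, user_name="alice")
```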
diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 64aedc8f4..8698b6f73 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -916,6 +916,7 @@ def get_by_metadata( filter: dict | None = None, knowledgebase_ids: list[str] | None = None, user_name_flag: bool = True, + status: str | None = None, ) -> list[str]: """ TODO: @@ -933,6 +934,8 @@ {"field": "tags", "op": "contains", "value": "AI"}, ... ] + status (str, optional): Filter by status (e.g., 'activated', 'archived'). + If None, no status filter is applied. Returns: list[str]: Node IDs whose metadata match the filter conditions. (AND logic). @@ -942,15 +945,20 @@ - Can be used for faceted recall or prefiltering before embedding rerank. """ logger.info( - f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids}" + f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids},status: {status}" ) print( - f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids}" + f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids},status: {status}" ) user_name = user_name if user_name else self.config.user_name where_clauses = [] params = {} + # Add status filter if provided + if status: + where_clauses.append("n.status = $status") + params["status"] = status + for i, f in enumerate(filters): field = f["field"] op = f.get("op", "=") @@ -1272,8 +1280,10 @@ def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> No def get_all_memory_items( self, scope: str, + include_embedding: bool = False, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, + status: str | None = None, **kwargs, ) -> list[dict]: """ @@ -1281,18 +1291,21 @@ Args: scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. + include_embedding (bool): Whether to include embedding in results. filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results. Example: {"and": [{"id": "xxx"}, {"A": "yyy"}]} or {"or": [{"id": "xxx"}, {"A": "yyy"}]} - Returns: + knowledgebase_ids (list[str], optional): List of knowledgebase IDs to filter by. + status (str, optional): Filter by status (e.g., 'activated', 'archived'). + If None, no status filter is applied. + Returns: list[dict]: Full list of memory items under this scope. 
""" logger.info( - f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids}" + f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids},status: {status}" ) print( - f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids}" + f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids},status: {status}" ) user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name @@ -1302,6 +1315,11 @@ def get_all_memory_items( where_clauses = ["n.memory_type = $scope"] params = {"scope": scope} + # Add status filter if provided + if status: + where_clauses.append("n.status = $status") + params["status"] = status + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method user_name_conditions, user_name_params = self._build_user_name_and_kb_ids_conditions_cypher( user_name=user_name, diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index e67f866ac..4b739bb0f 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -2823,6 +2823,7 @@ def get_all_memory_items( user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list | None = None, + status: str | None = None, ) -> list[dict]: """ Retrieve all memory items of a specific memory_type. @@ -2831,12 +2832,16 @@ def get_all_memory_items( scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. include_embedding: with/without embedding user_name (str, optional): User name for filtering in non-multi-db mode + filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results. + knowledgebase_ids (list, optional): List of knowledgebase IDs to filter by. + status (str, optional): Filter by status (e.g., 'activated', 'archived'). + If None, no status filter is applied. Returns: list[dict]: Full list of memory items under this scope. 
""" logger.info( - f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}" + f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}, status: {status}" ) user_name = user_name if user_name else self._get_config_value("user_name") @@ -2867,6 +2872,8 @@ def get_all_memory_items( if include_embedding: # Build WHERE clause with user_name/knowledgebase_ids and filter where_parts = [f"n.memory_type = '{scope}'"] + if status: + where_parts.append(f"n.status = '{status}'") if user_name_where: # user_name_where already contains parentheses if it's an OR condition where_parts.append(user_name_where) @@ -2927,6 +2934,8 @@ def get_all_memory_items( else: # Build WHERE clause with user_name/knowledgebase_ids and filter where_parts = [f"n.memory_type = '{scope}'"] + if status: + where_parts.append(f"n.status = '{status}'") if user_name_where: # user_name_where already contains parentheses if it's an OR condition where_parts.append(user_name_where) diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index 15d7c336a..1d199c6cb 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -76,7 +76,8 @@ def __init__(self, config: MemFeedbackConfig): self.llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config(config.extractor_llm) self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder) self.graph_store: PolarDBGraphDB = GraphStoreFactory.from_config(config.graph_db) - self.mem_reader = MemReaderFactory.from_config(config.mem_reader) + # Pass graph_store to mem_reader for recall operations (deduplication, conflict detection) + self.mem_reader = MemReaderFactory.from_config(config.mem_reader, graph_db=self.graph_store) self.is_reorganize = config.reorganize self.memory_manager: MemoryManager = MemoryManager( diff --git a/src/memos/mem_reader/base.py b/src/memos/mem_reader/base.py index 391270bcf..87bf43b0f 100644 --- a/src/memos/mem_reader/base.py +++ b/src/memos/mem_reader/base.py @@ -1,17 +1,38 @@ from abc import ABC, abstractmethod -from typing import Any +from typing import TYPE_CHECKING, Any from memos.configs.mem_reader import BaseMemReaderConfig from memos.memories.textual.item import TextualMemoryItem +if TYPE_CHECKING: + from memos.graph_dbs.base import BaseGraphDB + + class BaseMemReader(ABC): """MemReader interface class for reading information.""" + # Optional graph database for recall operations (for deduplication, conflict + # detection .etc) + graph_db: "BaseGraphDB | None" = None + @abstractmethod def __init__(self, config: BaseMemReaderConfig): """Initialize the MemReader with the given configuration.""" + @abstractmethod + def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None: + """ + Set the graph database instance for recall operations. + + This enables the mem-reader to perform: + - Semantic deduplication: avoid storing duplicate memories + - Conflict detection: detect contradictions with existing memories + + Args: + graph_db: The graph database instance, or None to disable recall operations. 
+ """ + @abstractmethod def get_memory( self, scene_data: list, type: str, info: dict[str, Any], mode: str = "fast" diff --git a/src/memos/mem_reader/factory.py b/src/memos/mem_reader/factory.py index ff24e5c77..2749327bf 100644 --- a/src/memos/mem_reader/factory.py +++ b/src/memos/mem_reader/factory.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, Optional from memos.configs.mem_reader import MemReaderConfigFactory from memos.mem_reader.base import BaseMemReader @@ -8,6 +8,10 @@ from memos.memos_tools.singleton import singleton_factory +if TYPE_CHECKING: + from memos.graph_dbs.base import BaseGraphDB + + class MemReaderFactory(BaseMemReader): """Factory class for creating MemReader instances.""" @@ -19,9 +23,31 @@ class MemReaderFactory(BaseMemReader): @classmethod @singleton_factory() - def from_config(cls, config_factory: MemReaderConfigFactory) -> BaseMemReader: + def from_config( + cls, + config_factory: MemReaderConfigFactory, + graph_db: Optional["BaseGraphDB | None"] = None, + ) -> BaseMemReader: + """ + Create a MemReader instance from configuration. + + Args: + config_factory: Configuration factory for the MemReader. + graph_db: Optional graph database instance for recall operations + (deduplication, conflict detection). Can also be set later + via reader.set_graph_db(). + + Returns: + Configured MemReader instance. + """ backend = config_factory.backend if backend not in cls.backend_to_class: raise ValueError(f"Invalid backend: {backend}") reader_class = cls.backend_to_class[backend] - return reader_class(config_factory.config) + reader = reader_class(config_factory.config) + + # Set graph_db if provided (for recall operations) + if graph_db is not None: + reader.set_graph_db(graph_db) + + return reader diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 3bf6d4927..9edcd0a55 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -13,6 +13,7 @@ from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader from memos.mem_reader.utils import parse_json_result from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata +from memos.templates.mem_reader_prompts import MEMORY_MERGE_PROMPT_EN, MEMORY_MERGE_PROMPT_ZH from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH from memos.types import MessagesType from memos.utils import timed @@ -393,6 +394,7 @@ def _get_llm_response( ], "summary": mem_str, } + logger.info(f"[MultiModalFine] Task {messages}, Result {response_json}") return response_json def _determine_prompt_type(self, sources: list) -> str: @@ -413,11 +415,184 @@ def _determine_prompt_type(self, sources: list) -> str: return prompt_type + def _get_maybe_merged_memory( + self, + extracted_memory_dict: dict, + mem_text: str, + sources: list, + **kwargs, + ) -> dict: + """ + Check if extracted memory should be merged with similar existing memories. + If merge is needed, return merged memory dict with merged_from field. + Otherwise, return original memory dict. + + Args: + extracted_memory_dict: The extracted memory dict from LLM response + mem_text: The memory text content + sources: Source messages for language detection + **kwargs: Additional parameters (merge_similarity_threshold, etc.) 
+ + Returns: + Memory dict (possibly merged) with merged_from field if merged + """ + # If no graph_db or user_name, return original + if not self.graph_db or "user_name" not in kwargs: + return extracted_memory_dict + user_name = kwargs.get("user_name") + + # Detect language, falling back to detect_lang if sources carry no lang + lang = None + if sources: + for source in sources: + if hasattr(source, "lang") and source.lang: + lang = source.lang + break + elif isinstance(source, dict) and source.get("lang"): + lang = source.get("lang") + break + if lang is None: + lang = detect_lang(mem_text) + + # Search for similar memories + merge_threshold = kwargs.get("merge_similarity_threshold", 0.3) + + try: + search_results = self.graph_db.search_by_embedding( + vector=self.embedder.embed(mem_text)[0], + top_k=20, + status="activated", + threshold=merge_threshold, + user_name=user_name, + filter={ + "or": [ + {"memory_type": "LongTermMemory"}, + {"memory_type": "UserMemory"}, + {"memory_type": "WorkingMemory"}, + ] + }, + ) + + if not search_results: + return extracted_memory_dict + + # Get full memory details + similar_memory_ids = [r["id"] for r in search_results if r.get("id")] + similar_memories_list = [ + self.graph_db.get_node(mem_id, include_embedding=False, user_name=user_name) + for mem_id in similar_memory_ids + ] + + # Filter out None and mode:fast memories + filtered_similar = [] + for mem in similar_memories_list: + if not mem: + continue + mem_metadata = mem.get("metadata", {}) + tags = mem_metadata.get("tags", []) + if isinstance(tags, list) and "mode:fast" in tags: + continue + filtered_similar.append( + { + "id": mem.get("id"), + "memory": mem.get("memory", ""), + } + ) + logger.info( + f"Valid similar memories for {mem_text} is " + f"{len(filtered_similar)}: {filtered_similar}" + ) + + if not filtered_similar: + return extracted_memory_dict + + # Create a temporary TextualMemoryItem for merge check + temp_memory_item = TextualMemoryItem( + memory=mem_text, + metadata=TreeNodeTextualMemoryMetadata( + user_id="", + session_id="", + memory_type=extracted_memory_dict.get("memory_type", "LongTermMemory"), + status="activated", + tags=extracted_memory_dict.get("tags", []), + key=extracted_memory_dict.get("key", ""), + ), + ) + + # Try to merge with LLM + merge_result = self._merge_memories_with_llm( + temp_memory_item, filtered_similar, lang=lang + ) + + if merge_result: + # Return merged memory dict + merged_dict = extracted_memory_dict.copy() + merged_content = merge_result.get("value", mem_text) + merged_dict["value"] = merged_content + merged_from_ids = merge_result.get("merged_from", []) + merged_dict["merged_from"] = merged_from_ids + return merged_dict + else: + return extracted_memory_dict + + except Exception as e: + logger.error(f"[MultiModalFine] Error in get_maybe_merged_memory: {e}") + # On error, return original + return extracted_memory_dict + + def _merge_memories_with_llm( + self, + new_memory: TextualMemoryItem, + similar_memories: list[dict], + lang: str = "en", + ) -> dict | None: + """ + Use LLM to merge new memory with similar existing memories. 
+ + Args: + new_memory: The newly extracted memory item + similar_memories: List of similar memories from graph_db (with id and memory fields) + lang: Language code ("en" or "zh") + + Returns: + Merged memory dict with merged_from field, or None if no merge needed + """ + if not similar_memories: + return None + + # Build merge prompt using template + similar_memories_text = "\n".join( + [f"[{mem['id']}]: {mem['memory']}" for mem in similar_memories] + ) + + merge_prompt_template = MEMORY_MERGE_PROMPT_ZH if lang == "zh" else MEMORY_MERGE_PROMPT_EN + merge_prompt = merge_prompt_template.format( + new_memory=new_memory.memory, + similar_memories=similar_memories_text, + ) + + try: + response_text = self.llm.generate([{"role": "user", "content": merge_prompt}]) + merge_result = parse_json_result(response_text) + + if merge_result.get("should_merge", False): + return { + "value": merge_result.get("value", new_memory.memory), + "merged_from": merge_result.get( + "merged_from", [mem["id"] for mem in similar_memories] + ), + } + except Exception as e: + logger.error(f"[MultiModalFine] Error in merge LLM call: {e}") + + return None + def _process_string_fine( self, fast_memory_items: list[TextualMemoryItem], info: dict[str, Any], custom_tags: list[str] | None = None, + **kwargs, ) -> list[TextualMemoryItem]: """ Process fast mode memory items through LLM to generate fine mode memories. @@ -454,6 +629,7 @@ def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]: # Determine prompt type based on sources prompt_type = self._determine_prompt_type(sources) + # ========== Stage 1: Normal extraction (without reference) ========== try: resp = self._get_llm_response(mem_str, custom_tags, sources, prompt_type) except Exception as e: @@ -463,39 +639,61 @@ def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]: if resp.get("memory list", []): for m in resp.get("memory list", []): try: + # Check and merge with similar memories if needed + m_maybe_merged = self._get_maybe_merged_memory( + extracted_memory_dict=m, + mem_text=m.get("value", ""), + sources=sources, + original_query=mem_str, + **kwargs, + ) # Normalize memory_type (same as simple_struct) memory_type = ( - m.get("memory_type", "LongTermMemory") + m_maybe_merged.get("memory_type", "LongTermMemory") .replace("长期记忆", "LongTermMemory") .replace("用户记忆", "UserMemory") ) - # Create fine mode memory item (same as simple_struct) node = self._make_memory_item( - value=m.get("value", ""), + value=m_maybe_merged.get("value", ""), info=info_per_item, memory_type=memory_type, - tags=m.get("tags", []), - key=m.get("key", ""), + tags=m_maybe_merged.get("tags", []), + key=m_maybe_merged.get("key", ""), sources=sources, # Preserve sources from fast item background=resp.get("summary", ""), **extra_kwargs, ) + # Add merged_from to info if present + if "merged_from" in m_maybe_merged: + node.metadata.info = node.metadata.info or {} + node.metadata.info["merged_from"] = m_maybe_merged["merged_from"] fine_items.append(node) except Exception as e: logger.error(f"[MultiModalFine] parse error: {e}") elif resp.get("value") and resp.get("key"): try: - # Create fine mode memory item (same as simple_struct) + # Check and merge with similar memories if needed + resp_maybe_merged = self._get_maybe_merged_memory( + extracted_memory_dict=resp, + mem_text=resp.get("value", "").strip(), + sources=sources, + original_query=mem_str, + **kwargs, + ) node = self._make_memory_item( - value=resp.get("value", "").strip(), + 
value=resp_maybe_merged.get("value", "").strip(), info=info_per_item, memory_type="LongTermMemory", - tags=resp.get("tags", []), - key=resp.get("key", None), + tags=resp_maybe_merged.get("tags", []), + key=resp_maybe_merged.get("key", None), sources=sources, # Preserve sources from fast item background=resp.get("summary", ""), **extra_kwargs, ) + # Add merged_from to info if present + if "merged_from" in resp_maybe_merged: + node.metadata.info = node.metadata.info or {} + node.metadata.info["merged_from"] = resp_maybe_merged["merged_from"] fine_items.append(node) except Exception as e: logger.error(f"[MultiModalFine] parse error: {e}") @@ -533,9 +731,7 @@ def _get_llm_tool_trajectory_response(self, mem_str: str) -> dict: return [] def _process_tool_trajectory_fine( - self, - fast_memory_items: list[TextualMemoryItem], - info: dict[str, Any], + self, fast_memory_items: list[TextualMemoryItem], info: dict[str, Any], **kwargs ) -> list[TextualMemoryItem]: """ Process tool trajectory memory items through LLM to generate fine mode memories. @@ -618,10 +814,10 @@ def _process_multi_modal_data( with ContextThreadPoolExecutor(max_workers=2) as executor: future_string = executor.submit( - self._process_string_fine, fast_memory_items, info, custom_tags + self._process_string_fine, fast_memory_items, info, custom_tags, **kwargs ) future_tool = executor.submit( - self._process_tool_trajectory_fine, fast_memory_items, info + self._process_tool_trajectory_fine, fast_memory_items, info, **kwargs ) # Collect results @@ -648,9 +844,7 @@ def _process_multi_modal_data( @timed def _process_transfer_multi_modal_data( - self, - raw_node: TextualMemoryItem, - custom_tags: list[str] | None = None, + self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None, **kwargs ) -> list[TextualMemoryItem]: """ Process transfer for multimodal data. 
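Reviewer note: how the pieces above are meant to be wired together — a sketch under the assumption that `mem_reader_config` and `graph_db_config` are valid factory configs (the `GraphStoreFactory` import path is inferred from feedback.py and may differ). Handing the reader a graph store is what arms `_get_maybe_merged_memory`; without it the method returns the extracted dict unchanged.

```python
from memos.graph_dbs.factory import GraphStoreFactory  # assumed import path
from memos.mem_reader.factory import MemReaderFactory

graph_db = GraphStoreFactory.from_config(graph_db_config)

# Either inject the store at construction time ...
mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db)
# ... or attach it later on an existing reader.
mem_reader.set_graph_db(graph_db)

# user_name flows get_memory -> _read_memory -> _get_maybe_merged_memory, where
# it scopes search_by_embedding/get_node during deduplication and merge.
memories = mem_reader.get_memory(
    scene_data,
    type="chat",
    info={"user_id": "u1", "session_id": "s1"},
    mode="fine",
    user_name="u1",
)
```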
@@ -674,9 +868,11 @@ def _process_transfer_multi_modal_data( # Part A: call llm in parallel using thread pool with ContextThreadPoolExecutor(max_workers=2) as executor: future_string = executor.submit( - self._process_string_fine, [raw_node], info, custom_tags + self._process_string_fine, [raw_node], info, custom_tags, **kwargs + ) + future_tool = executor.submit( + self._process_tool_trajectory_fine, [raw_node], info, **kwargs ) - future_tool = executor.submit(self._process_tool_trajectory_fine, [raw_node], info) # Collect results fine_memory_items_string_parser = future_string.result() @@ -710,7 +906,12 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[list[Any]]: return scene_data def _read_memory( - self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine" + self, + messages: list[MessagesType], + type: str, + info: dict[str, Any], + mode: str = "fine", + **kwargs, ) -> list[list[TextualMemoryItem]]: list_scene_data_info = self.get_scene_data_info(messages, type) @@ -718,7 +919,9 @@ def _read_memory( # Process Q&A pairs concurrently with context propagation with ContextThreadPoolExecutor() as executor: futures = [ - executor.submit(self._process_multi_modal_data, scene_data_info, info, mode=mode) + executor.submit( + self._process_multi_modal_data, scene_data_info, info, mode=mode, **kwargs + ) for scene_data_info in list_scene_data_info ] for future in concurrent.futures.as_completed(futures): @@ -736,6 +939,7 @@ def fine_transfer_simple_mem( input_memories: list[TextualMemoryItem], type: str, custom_tags: list[str] | None = None, + **kwargs, ) -> list[list[TextualMemoryItem]]: if not input_memories: return [] @@ -746,7 +950,7 @@ def fine_transfer_simple_mem( with ContextThreadPoolExecutor() as executor: futures = [ executor.submit( - self._process_transfer_multi_modal_data, scene_data_info, custom_tags + self._process_transfer_multi_modal_data, scene_data_info, custom_tags, **kwargs ) for scene_data_info in input_memories ] diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index fa72bd063..3e33538e0 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -5,7 +5,7 @@ import traceback from abc import ABC -from typing import Any, TypeAlias +from typing import TYPE_CHECKING, Any, TypeAlias from tqdm import tqdm @@ -16,6 +16,10 @@ from memos.embedders.factory import EmbedderFactory from memos.llms.factory import LLMFactory from memos.mem_reader.base import BaseMemReader + + +if TYPE_CHECKING: + from memos.graph_dbs.base import BaseGraphDB from memos.mem_reader.read_multi_modal import coerce_scene_data, detect_lang from memos.mem_reader.utils import ( count_tokens_text, @@ -176,6 +180,12 @@ def __init__(self, config: SimpleStructMemReaderConfig): self.chat_window_max_tokens = getattr(self.config, "chat_window_max_tokens", 1024) self._count_tokens = count_tokens_text self.searcher = None + # Initialize graph_db as None, can be set later via set_graph_db for + # recall operations + self.graph_db = None + + def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None: + self.graph_db = graph_db def _make_memory_item( self, @@ -351,7 +361,7 @@ def _build_fast_node(w): return chat_read_nodes def _process_transfer_chat_data( - self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None + self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None, **kwargs ): raw_memory = raw_node.memory response_json = self._get_llm_response(raw_memory, 
custom_tags) @@ -390,7 +400,12 @@ return chat_read_nodes def get_memory( - self, scene_data: SceneDataInput, type: str, info: dict[str, Any], mode: str = "fine" + self, + scene_data: SceneDataInput, + type: str, + info: dict[str, Any], + mode: str = "fine", + user_name: str | None = None, ) -> list[list[TextualMemoryItem]]: """ Extract and classify memory content from scene_data. @@ -409,6 +424,8 @@ - chunk_overlap: Overlap for small chunks (default: 50) mode: mem-reader mode, fast for quick process while fine for better understanding via calling llm + user_name: the user_name that will be inserted into the + database later; it may be used in recall. Returns: list[list[TextualMemoryItem]] containing memory content with summaries as keys and original text as values Raises: @@ -432,7 +449,7 @@ # Backward compatibility, after coercing scene_data, we only tackle # with standard scene_data type: MessagesType standard_scene_data = coerce_scene_data(scene_data, type) - return self._read_memory(standard_scene_data, type, info, mode) + return self._read_memory(standard_scene_data, type, info, mode, user_name=user_name) def rewrite_memories( self, messages: list[dict], memory_list: list[TextualMemoryItem], user_only: bool = True @@ -558,7 +575,12 @@ return memory_list def _read_memory( - self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine" + self, + messages: list[MessagesType], + type: str, + info: dict[str, Any], + mode: str = "fine", + **kwargs, ) -> list[list[TextualMemoryItem]]: """ 1. raw file: @@ -647,6 +669,7 @@ fine_transfer_simple_mem( input_memories: list[TextualMemoryItem], type: str, custom_tags: list[str] | None = None, + **kwargs, ) -> list[list[TextualMemoryItem]]: if not input_memories: return [] @@ -663,7 +686,7 @@ # Process Q&A pairs concurrently with context propagation with ContextThreadPoolExecutor() as executor: futures = [ - executor.submit(processing_func, scene_data_info, custom_tags) + executor.submit(processing_func, scene_data_info, custom_tags, **kwargs) for scene_data_info in input_memories ] for future in concurrent.futures.as_completed(futures): @@ -867,6 +890,6 @@ return doc_nodes def _process_transfer_doc_data( - self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None + self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None, **kwargs ): raise NotImplementedError diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index 8fd60153d..3a12a9c79 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -305,7 +305,8 @@ def init_components() -> dict[str, Any]: ) llm = LLMFactory.from_config(llm_config) embedder = EmbedderFactory.from_config(embedder_config) - mem_reader = MemReaderFactory.from_config(mem_reader_config) + # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) + mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db) reranker = RerankerFactory.from_config(reranker_config) feedback_reranker = RerankerFactory.from_config(feedback_reranker_config) internet_retriever = InternetRetrieverFactory.from_config( diff --git 
a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 9b19e9ecb..d4ac09cc3 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -877,6 +877,7 @@ def _process_memories_with_reader( memory_items, type="chat", custom_tags=custom_tags, + user_name=user_name, ) except Exception as e: logger.warning(f"{e}: Fail to transfer mem: {memory_items}") @@ -897,6 +898,38 @@ def _process_memories_with_reader( f"Added {len(enhanced_mem_ids)} enhanced memories: {enhanced_mem_ids}" ) + # Mark merged_from memories as archived when provided in memory metadata + if self.mem_reader.graph_db: + for memory in flattened_memories: + merged_from = (memory.metadata.info or {}).get("merged_from") + if merged_from: + old_ids = ( + merged_from + if isinstance(merged_from, (list | tuple | set)) + else [merged_from] + ) + for old_id in old_ids: + try: + self.mem_reader.graph_db.update_node( + str(old_id), {"status": "archived"}, user_name=user_name + ) + logger.info( + f"[Scheduler] Archived merged_from memory: {old_id}" + ) + except Exception as e: + logger.warning( + f"[Scheduler] Failed to archive merged_from memory {old_id}: {e}" + ) + else: + # Check if any memory has merged_from but graph_db is unavailable + has_merged_from = any( + (m.metadata.info or {}).get("merged_from") for m in flattened_memories + ) + if has_merged_from: + logger.warning( + "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving." + ) + # LOGGING BLOCK START # This block is replicated from _add_message_consumer to ensure consistent logging cloud_env = is_cloud_env() diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 2db2fd08b..4541b118b 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -77,6 +77,7 @@ def retrieve( include_embedding=self.include_embedding, user_name=user_name, filter=search_filter, + status="activated", ) return [TextualMemoryItem.from_dict(record) for record in working_memories[:top_k]] @@ -247,7 +248,7 @@ def process_node(node): # Load nodes and post-filter node_dicts = self.graph_store.get_nodes( - list(candidate_ids), include_embedding=self.include_embedding + list(candidate_ids), include_embedding=self.include_embedding, user_name=user_name ) final_nodes = [] @@ -277,7 +278,9 @@ def process_node(node): {"field": "key", "op": "in", "value": parsed_goal.keys}, {"field": "memory_type", "op": "=", "value": memory_scope}, ] - key_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name) + key_ids = self.graph_store.get_by_metadata( + key_filters, user_name=user_name, status="activated" + ) candidate_ids.update(key_ids) # 2) tag-based OR branch @@ -286,7 +289,9 @@ def process_node(node): {"field": "tags", "op": "contains", "value": parsed_goal.tags}, {"field": "memory_type", "op": "=", "value": memory_scope}, ] - tag_ids = self.graph_store.get_by_metadata(tag_filters, user_name=user_name) + tag_ids = self.graph_store.get_by_metadata( + tag_filters, user_name=user_name, status="activated" + ) candidate_ids.update(tag_ids) # No matches → return empty @@ -422,9 +427,11 @@ def _bm25_recall( value = search_filter[key] key_filters.append({"field": key, "op": "=", "value": value}) corpus_name += "".join(list(search_filter.values())) - candidate_ids = self.graph_store.get_by_metadata(key_filters, 
user_name=user_name) + candidate_ids = self.graph_store.get_by_metadata( + key_filters, user_name=user_name, status="activated" + ) node_dicts = self.graph_store.get_nodes( - list(candidate_ids), include_embedding=self.include_embedding + list(candidate_ids), include_embedding=self.include_embedding, user_name=user_name ) bm25_query = " ".join(list({query, *parsed_goal.keys})) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 6c3cc0cc7..426cf32be 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -802,6 +802,7 @@ def _process_text_mem( "session_id": target_session_id, }, mode=extract_mode, + user_name=user_context.mem_cube_id, ) flattened_local = [mm for m in memories_local for mm in m] @@ -831,6 +832,36 @@ def _process_text_mem( sync_mode=sync_mode, ) + # Mark merged_from memories as archived when provided in add_req.info + if sync_mode == "sync" and extract_mode == "fine": + for memory in flattened_local: + merged_from = (memory.metadata.info or {}).get("merged_from") + if merged_from: + old_ids = ( + merged_from + if isinstance(merged_from, (list | tuple | set)) + else [merged_from] + ) + if self.mem_reader and self.mem_reader.graph_db: + for old_id in old_ids: + try: + self.mem_reader.graph_db.update_node( + str(old_id), + {"status": "archived"}, + user_name=user_context.mem_cube_id, + ) + self.logger.info( + f"[SingleCubeView] Archived merged_from memory: {old_id}" + ) + except Exception as e: + self.logger.warning( + f"[SingleCubeView] Failed to archive merged_from memory {old_id}: {e}" + ) + else: + self.logger.warning( + "[SingleCubeView] merged_from provided but graph_db is unavailable; skip archiving." + ) + text_memories = [ { "memory": memory.memory, diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index 20f8150b7..2a2df0e0b 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -223,6 +223,7 @@ 您的输出:""" + SIMPLE_STRUCT_DOC_READER_PROMPT = """You are an expert text analyst for a search and retrieval system. Your task is to process a document chunk and generate a single, structured JSON object. @@ -866,10 +867,174 @@ Important: Output **only** the JSON. No extra text. """ +MEMORY_MERGE_PROMPT_EN = """You are a memory consolidation expert. Given a new memory and a set of similar existing memories, determine whether they should be merged. + +Before generating the value, you must complete the following reasoning steps (done in internal reasoning, no need to output them): +1. Identify the “fact units” contained in the new memory, for example: +• Identity-type facts: name, occupation, place of residence, etc. +• Stable preference-type facts: things the user likes/dislikes long-term, frequently visited places, etc. +• Relationship-type facts: relationships with someone (friend, colleague, fixed activity partner, etc.) +• One-off event/plan-type facts: events on a specific day, temporary plans for this weekend, etc. +2. 
For each fact unit, determine: +• Which existing memories are expressing “the same kind of fact” +• Whether the corresponding fact in the new memory is just a “repeated confirmation” of that fact, rather than “new factual content” + +Merge rules (must be followed when generating value): +• The merged value: +• Must not repeat the same meaning (each fact should be described only once) +• Must not repeat the same fact just because it was mentioned multiple times or at different times +• Unless time itself changes the meaning (for example, “used to dislike → now likes”), do not keep specific time information +• If the new memory contains multiple different types of facts (for example: “name + hobby + plan for this weekend”): +• You may output multiple merge results; each merge result should focus on only one type of fact (for example: one about “name”, one about “hobby”) +• Do not force unrelated facts into the same value +• One-off events/plans (such as “going skiing this weekend”, “attending a party on Sunday”): +• If there is no directly related and complementary event memory in the existing memories, treat it as an independent memory and do not merge it with identity/stable preference-type memories +• Do not merge a “temporary plan” and a “long-term preference” into the same value just because they are related (e.g. a plan to ski vs. a long-term preference for skiing) + +Output format requirements: +• You must return a single JSON object. +• If a merge occurred: +• “value”: The merged memory content (only describe the final conclusion, preserving all “semantically unique” information, without repetition) +• “merged_from”: A list of IDs of the similar memories that were merged +• “should_merge”: true +• If the new memory cannot be merged with any existing memories, return: +• “should_merge”: false + +Example: +New memory: +The user’s name is Tom, the user likes skiing, and plans to go skiing this weekend. + +Similar existing memories: +xxxx-xxxx-xxxx-xxxx-01: The user’s name is Tom +xxxx-xxxx-xxxx-xxxx-10: The user likes skiing +xxxx-xxxx-xxxx-xxxx-11: The user lives by the sea + +Expected return value: +{{ +"value": "The user's name is Tom and the user likes skiing", +"merged_from": ["xxxx-xxxx-xxxx-xxxx-01", "xxxx-xxxx-xxxx-xxxx-10"], +"should_merge": true +}} + +New memory: +The user is going to attend a party on Sunday. + +Similar existing memories: +xxxx-xxxx-xxxx-xxxx-01: The user read a book yesterday. + +Expected return value: +{{ +"should_merge": false +}} + +If the new memory largely overlaps with or complements the existing memories, merge them into an integrated memory and return a JSON object: +• “value”: The merged memory content +• “merged_from”: A list of IDs of the similar memories that were merged +• “should_merge”: true + +If the new memory is unique and should remain independent, return: +{{ +"should_merge": false +}} + +You must only return a valid JSON object in the final output, and no additional content (no natural language explanations, no extra fields). + +New memory: +{new_memory} + +Similar existing memories: +{similar_memories} + +Only return a valid JSON object, and do not include any other content. +""" + +MEMORY_MERGE_PROMPT_ZH = """ +你是一个记忆整合专家。给定一个新记忆和相似的现有记忆,判断它们是否应该合并。 + +在生成 value 之前,必须先完成以下判断步骤(在内在推理中完成,不需要输出): +1. 识别新记忆中包含的「事实单元」,例如: + - 身份信息类:名字、职业、居住地等 + - 稳定偏好类:长期喜欢/不喜欢的事物、常去地点等 + - 关系类:与某人的关系(朋友、同事、固定搭子等) + - 一次性事件/计划类:某天要参加的活动、本周末的临时安排等 +2. 
对每个事实单元,判断: + - 哪些 existing memories 在表达“同一类事实”, + - 新记忆中对应的事实是否只是对该事实的「重复确认」,而不是“新的事实内容” + +合并规则(生成 value 时必须遵守): +- 合并后的 value: + - 不要重复表达同一语义(同一事实只描述一次) + - 不要因为多次提及或不同时间而重复同一事实 + - 除非时间本身改变了语义(例如“从不喜欢 → 现在开始喜欢”),否则不要保留具体时间信息 +- 如果新记忆中包含多个不同类型的事实(例如“名字 + 爱好 + 本周计划”): + - 不要合并就好 + - 不要把彼此无关的事实硬塞进同一个 value 中 +- 一次性事件/计划(如“本周末去滑雪”“周天参加聚会”): + - 如果 existing memories 中没有与之直接相关、可互补的事件记忆,则视为独立记忆,不要与身份/长期偏好类记忆合并 + - 不要因为它和某个长期偏好有关(例如喜欢滑雪),就把“临时计划”和“长期偏好”合在一个 value 里 + +输出格式要求: +- 你需要返回一个 JSON 对象。 +- 若发生了合并: + - "value": 合并后的记忆内容(只描述最终结论,保留所有「语义上独特」的信息,不重复) + - "merged_from": 被合并的相似记忆 ID 列表 + - "should_merge": true +- 若新记忆无法与现有记忆合并,返回: + - "should_merge": false + +示例: +新记忆: +用户的名字是Tom,用户喜欢滑雪,并计划周末去滑雪 + +相似的现有记忆: +xxxx-xxxx-xxxx-xxxx-01: 用户的名字是Tom +xxxx-xxxx-xxxx-xxxx-10: 用户喜欢滑雪 +xxxx-xxxx-xxxx-xxxx-11: 用户住在海边 + +应该的返回值: +{{ + "value": "用户的名字是Tom,用户喜欢滑雪", + "merged_from": ["xxxx-xxxx-xxxx-xxxx-01", "xxxx-xxxx-xxxx-xxxx-10"], + "should_merge": true +}} + +新记忆: +用户周天要参加一个聚会 + +相似的现有记忆: +xxxx-xxxx-xxxx-xxxx-01: 用户昨天读了一本书 + +应该的返回值: +{{ + "should_merge": false +}} + +如果新记忆与现有记忆大量重叠或互补,将它们合并为一个整合的记忆,并返回一个JSON对象: +- "value": 合并后的记忆内容 +- "merged_from": 被合并的相似记忆ID列表 +- "should_merge": true + +如果新记忆是独特的,应该保持独立,返回: +{{ + "should_merge": false +}} + +最终只返回有效的 JSON 对象,不要任何额外内容(不要自然语言解释、不要多余字段)。 + +新记忆: +{new_memory} + +相似的现有记忆: +{similar_memories} + +只返回有效的JSON对象,不要其他内容。""" + # Prompt mapping for specialized tasks (e.g., hallucination filtering) PROMPT_MAPPING = { "hallucination_filter": SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT, "rewrite": SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT, "rewrite_user_only": SIMPLE_STRUCT_REWRITE_MEMORY_USER_ONLY_PROMPT, "add_before_search": SIMPLE_STRUCT_ADD_BEFORE_SEARCH_PROMPT, + "memory_merge_en": MEMORY_MERGE_PROMPT_EN, + "memory_merge_zh": MEMORY_MERGE_PROMPT_ZH, }
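Reviewer note: a small sketch (not part of the patch) of the contract between `_merge_memories_with_llm` and the merge prompts above; the memory IDs and texts are invented. The templates escape their JSON braces as `{{ }}`, so only `new_memory` and `similar_memories` are live placeholders.

```python
from memos.templates.mem_reader_prompts import MEMORY_MERGE_PROMPT_EN

similar = [
    {"id": "mem-01", "memory": "The user's name is Tom"},
    {"id": "mem-10", "memory": "The user likes skiing"},
]
prompt = MEMORY_MERGE_PROMPT_EN.format(
    new_memory="The user's name is Tom and the user likes skiing",
    similar_memories="\n".join(f"[{m['id']}]: {m['memory']}" for m in similar),
)
# The model must answer with a single JSON object, e.g.:
# {"value": "The user's name is Tom and the user likes skiing",
#  "merged_from": ["mem-01", "mem-10"],
#  "should_merge": true}
# The caller gates on "should_merge", rewrites the extracted memory with "value",
# and later archives every ID listed in "merged_from".
```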