3 changes: 2 additions & 1 deletion src/memos/api/handlers/component_init.py
@@ -183,7 +183,8 @@ def init_server() -> dict[str, Any]:
else None
)
embedder = EmbedderFactory.from_config(embedder_config)
mem_reader = MemReaderFactory.from_config(mem_reader_config)
# Pass graph_db to mem_reader for recall operations (deduplication, conflict detection)
mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db)
reranker = RerankerFactory.from_config(reranker_config)
feedback_reranker = RerankerFactory.from_config(feedback_reranker_config)
internet_retriever = InternetRetrieverFactory.from_config(
19 changes: 14 additions & 5 deletions src/memos/graph_dbs/base.py
@@ -19,12 +19,13 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
"""

@abstractmethod
def update_node(self, id: str, fields: dict[str, Any]) -> None:
def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None:
"""
Update attributes of an existing node.
Args:
id: Node identifier to be updated.
fields: Dictionary of fields to update.
user_name (str, optional): User name used to scope the update; if None, the configured default user is used.
"""

@abstractmethod
@@ -70,7 +71,7 @@ def edge_exists(self, source_id: str, target_id: str, type: str) -> bool:

# Graph Query & Reasoning
@abstractmethod
def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | None:
def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None:
"""
Retrieve the metadata and content of a node.
Args:
@@ -82,7 +83,7 @@ def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] |

@abstractmethod
def get_nodes(
self, id: str, include_embedding: bool = False, **kwargs
self, ids: list, include_embedding: bool = False, **kwargs
) -> dict[str, Any] | None:
"""
Retrieve the metadata and memory of a list of nodes.
@@ -160,13 +161,17 @@ def search_by_embedding(self, vector: list[float], top_k: int = 5, **kwargs) ->
"""

@abstractmethod
def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:
def get_by_metadata(
self, filters: list[dict[str, Any]], status: str | None = None
) -> list[str]:
"""
Retrieve node IDs that match given metadata filters.

Args:
filters (dict[str, Any]): A dictionary of attribute-value filters.
Example: {"topic": "psychology", "importance": 2}
status (str, optional): Filter by status (e.g., 'activated', 'archived').
If None, no status filter is applied.

Returns:
list[str]: Node IDs whose metadata match the filter conditions.
@@ -239,13 +244,17 @@ def import_graph(self, data: dict[str, Any]) -> None:
"""

@abstractmethod
def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> list[dict]:
def get_all_memory_items(
self, scope: str, include_embedding: bool = False, status: str | None = None
) -> list[dict]:
"""
Retrieve all memory items of a specific memory_type.

Args:
scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
include_embedding (bool): Whether to include embeddings in the returned items.
status (str, optional): Filter by status (e.g., 'activated', 'archived').
If None, no status filter is applied.

Returns:
list[dict]: Full list of memory items under this scope.
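A minimal usage sketch of the extended interface, assuming `graph_db` is any concrete `BaseGraphDB` implementation; the 'activated' value comes from the docstrings above and the filter shape mirrors the Neo4j implementation further down, everything else is illustrative:

    # Hypothetical caller-side usage of the new `status` parameter.
    active_ids = graph_db.get_by_metadata(
        filters=[{"field": "tags", "op": "contains", "value": "AI"}],
        status="activated",  # only nodes whose status is 'activated'
    )

    # Fetch all long-term items while skipping archived ones; embeddings omitted for speed.
    long_term_items = graph_db.get_all_memory_items(
        scope="LongTermMemory",
        include_embedding=False,
        status="activated",
    )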
28 changes: 23 additions & 5 deletions src/memos/graph_dbs/neo4j.py
@@ -916,6 +916,7 @@ def get_by_metadata(
filter: dict | None = None,
knowledgebase_ids: list[str] | None = None,
user_name_flag: bool = True,
status: str | None = None,
) -> list[str]:
"""
TODO:
@@ -933,6 +934,8 @@
{"field": "tags", "op": "contains", "value": "AI"},
...
]
status (str, optional): Filter by status (e.g., 'activated', 'archived').
If None, no status filter is applied.

Returns:
list[str]: Node IDs whose metadata match the filter conditions. (AND logic).
@@ -942,15 +945,20 @@
- Can be used for faceted recall or prefiltering before embedding rerank.
"""
logger.info(
f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids}"
f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids},status: {status}"
)
print(
f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids}"
f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids},status: {status}"
)
user_name = user_name if user_name else self.config.user_name
where_clauses = []
params = {}

# Add status filter if provided
if status:
where_clauses.append("n.status = $status")
params["status"] = status

for i, f in enumerate(filters):
field = f["field"]
op = f.get("op", "=")
@@ -1272,27 +1280,32 @@ def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> No
def get_all_memory_items(
self,
scope: str,
include_embedding: bool = False,
filter: dict | None = None,
knowledgebase_ids: list[str] | None = None,
status: str | None = None,
**kwargs,
) -> list[dict]:
"""
Retrieve all memory items of a specific memory_type.

Args:
scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
include_embedding (bool): Whether to include embedding in results.
filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results.
Example: {"and": [{"id": "xxx"}, {"A": "yyy"}]} or {"or": [{"id": "xxx"}, {"A": "yyy"}]}
Returns:
knowledgebase_ids (list[str], optional): List of knowledgebase IDs to filter by.
status (str, optional): Filter by status (e.g., 'activated', 'archived').
If None, no status filter is applied.

Returns:
list[dict]: Full list of memory items under this scope.
"""
logger.info(
f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids}"
f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids},status: {status}"
)
print(
f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids}"
f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids},status: {status}"
)

user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name
@@ -1302,6 +1315,11 @@ def get_all_memory_items(
where_clauses = ["n.memory_type = $scope"]
params = {"scope": scope}

# Add status filter if provided
if status:
where_clauses.append("n.status = $status")
params["status"] = status

# Build user_name filter with knowledgebase_ids support (OR relationship) using common method
user_name_conditions, user_name_params = self._build_user_name_and_kb_ids_conditions_cypher(
user_name=user_name,
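For context, the new condition simply joins the accumulated WHERE clauses; a rough Python reconstruction of the pattern, where the MATCH/RETURN shape and the final AND join are assumed (only the `n.status = $status` fragment appears in the diff):

    where_clauses = ["n.memory_type = $scope"]
    params = {"scope": scope}

    if status:  # new in this change: skip archived (or otherwise inactive) nodes
        where_clauses.append("n.status = $status")
        params["status"] = status

    # Assumed final assembly; the real query text lies outside the visible hunk.
    # The node label 'Memory' is illustrative only.
    query = f"MATCH (n:Memory) WHERE {' AND '.join(where_clauses)} RETURN n.id AS id"
    # session.run(query, **params)  # parameterized, so the status value is never interpolated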
11 changes: 10 additions & 1 deletion src/memos/graph_dbs/polardb.py
@@ -2823,6 +2823,7 @@ def get_all_memory_items(
user_name: str | None = None,
filter: dict | None = None,
knowledgebase_ids: list | None = None,
status: str | None = None,
) -> list[dict]:
"""
Retrieve all memory items of a specific memory_type.
@@ -2831,12 +2832,16 @@
scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
include_embedding (bool): Whether to include embeddings in the returned items.
user_name (str, optional): User name for filtering in non-multi-db mode
filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results.
knowledgebase_ids (list, optional): List of knowledgebase IDs to filter by.
status (str, optional): Filter by status (e.g., 'activated', 'archived').
If None, no status filter is applied.

Returns:
list[dict]: Full list of memory items under this scope.
"""
logger.info(
f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}"
f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}, status: {status}"
)

user_name = user_name if user_name else self._get_config_value("user_name")
@@ -2867,6 +2872,8 @@ def get_all_memory_items(
if include_embedding:
# Build WHERE clause with user_name/knowledgebase_ids and filter
where_parts = [f"n.memory_type = '{scope}'"]
if status:
where_parts.append(f"n.status = '{status}'")
if user_name_where:
# user_name_where already contains parentheses if it's an OR condition
where_parts.append(user_name_where)
@@ -2927,6 +2934,8 @@
else:
# Build WHERE clause with user_name/knowledgebase_ids and filter
where_parts = [f"n.memory_type = '{scope}'"]
if status:
where_parts.append(f"n.status = '{status}'")
if user_name_where:
# user_name_where already contains parentheses if it's an OR condition
where_parts.append(user_name_where)
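The PolarDB version applies the same guard in both the embedding and non-embedding branches by appending to `where_parts`; condensed, the shared pattern looks roughly like this (the final query assembly is assumed, not shown in the hunk):

    where_parts = [f"n.memory_type = '{scope}'"]
    if status:
        where_parts.append(f"n.status = '{status}'")  # new status filter
    if user_name_where:
        # user_name_where already contains parentheses if it is an OR condition
        where_parts.append(user_name_where)

    where_sql = " AND ".join(where_parts)  # assumed: joined into the final SELECT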
3 changes: 2 additions & 1 deletion src/memos/mem_feedback/feedback.py
@@ -76,7 +76,8 @@ def __init__(self, config: MemFeedbackConfig):
self.llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config(config.extractor_llm)
self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder)
self.graph_store: PolarDBGraphDB = GraphStoreFactory.from_config(config.graph_db)
self.mem_reader = MemReaderFactory.from_config(config.mem_reader)
# Pass graph_store to mem_reader for recall operations (deduplication, conflict detection)
self.mem_reader = MemReaderFactory.from_config(config.mem_reader, graph_db=self.graph_store)

self.is_reorganize = config.reorganize
self.memory_manager: MemoryManager = MemoryManager(
23 changes: 22 additions & 1 deletion src/memos/mem_reader/base.py
@@ -1,17 +1,38 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import TYPE_CHECKING, Any

from memos.configs.mem_reader import BaseMemReaderConfig
from memos.memories.textual.item import TextualMemoryItem


if TYPE_CHECKING:
from memos.graph_dbs.base import BaseGraphDB


class BaseMemReader(ABC):
"""MemReader interface class for reading information."""

# Optional graph database for recall operations (for deduplication, conflict
# detection, etc.)
graph_db: "BaseGraphDB | None" = None

@abstractmethod
def __init__(self, config: BaseMemReaderConfig):
"""Initialize the MemReader with the given configuration."""

@abstractmethod
def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None:
"""
Set the graph database instance for recall operations.

This enables the mem-reader to perform:
- Semantic deduplication: avoid storing duplicate memories
- Conflict detection: detect contradictions with existing memories

Args:
graph_db: The graph database instance, or None to disable recall operations.
"""

@abstractmethod
def get_memory(
self, scene_data: list, type: str, info: dict[str, Any], mode: str = "fast"
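A minimal sketch of a concrete reader satisfying the new contract; the class, its extraction helper, and the dedup heuristic are hypothetical, and the real interface may declare further abstract methods below the visible hunk:

    from typing import Any

    from memos.configs.mem_reader import BaseMemReaderConfig
    from memos.graph_dbs.base import BaseGraphDB
    from memos.mem_reader.base import BaseMemReader


    class ExampleMemReader(BaseMemReader):
        """Illustrative reader; only the methods shown in this diff are implemented."""

        def __init__(self, config: BaseMemReaderConfig):
            self.config = config
            self.graph_db = None  # recall operations stay disabled until a store is attached

        def set_graph_db(self, graph_db: BaseGraphDB | None) -> None:
            self.graph_db = graph_db

        def get_memory(self, scene_data: list, type: str, info: dict[str, Any], mode: str = "fast"):
            candidates = self._extract(scene_data)  # hypothetical: [(memory_text, embedding), ...]
            if self.graph_db is None:
                return [text for text, _ in candidates]  # no store: skip recall checks
            kept = []
            for text, vector in candidates:
                # Example dedup heuristic: keep only candidates with no close existing match.
                if not self.graph_db.search_by_embedding(vector, top_k=1):
                    kept.append(text)
            return kept

        def _extract(self, scene_data: list) -> list[tuple[str, list[float]]]:
            # Placeholder; a real reader runs its LLM/embedding pipeline here.
            return [(str(chunk), [0.0]) for chunk in scene_data]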
32 changes: 29 additions & 3 deletions src/memos/mem_reader/factory.py
@@ -1,4 +1,4 @@
from typing import Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, Optional

from memos.configs.mem_reader import MemReaderConfigFactory
from memos.mem_reader.base import BaseMemReader
@@ -8,6 +8,10 @@
from memos.memos_tools.singleton import singleton_factory


if TYPE_CHECKING:
from memos.graph_dbs.base import BaseGraphDB


class MemReaderFactory(BaseMemReader):
"""Factory class for creating MemReader instances."""

@@ -19,9 +23,31 @@ class MemReaderFactory(BaseMemReader):

@classmethod
@singleton_factory()
def from_config(cls, config_factory: MemReaderConfigFactory) -> BaseMemReader:
def from_config(
cls,
config_factory: MemReaderConfigFactory,
graph_db: Optional["BaseGraphDB | None"] = None,
) -> BaseMemReader:
"""
Create a MemReader instance from configuration.

Args:
config_factory: Configuration factory for the MemReader.
graph_db: Optional graph database instance for recall operations
(deduplication, conflict detection). Can also be set later
via reader.set_graph_db().

Returns:
Configured MemReader instance.
"""
backend = config_factory.backend
if backend not in cls.backend_to_class:
raise ValueError(f"Invalid backend: {backend}")
reader_class = cls.backend_to_class[backend]
return reader_class(config_factory.config)
reader = reader_class(config_factory.config)

# Set graph_db if provided (for recall operations)
if graph_db is not None:
reader.set_graph_db(graph_db)

return reader
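Taken together with the handler changes above, either wiring style should work; a short usage sketch, assuming `mem_reader_config` and `graph_db` already exist:

    from memos.mem_reader.factory import MemReaderFactory

    # Attach the graph store at construction time, as component_init.py and feedback.py now do.
    mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db)

    # Or attach / detach it later on the same instance; note that from_config is wrapped in
    # @singleton_factory(), so repeated calls with the same config are expected to return
    # the same reader instance.
    mem_reader.set_graph_db(graph_db)  # enables deduplication and conflict detection
    mem_reader.set_graph_db(None)      # disables recall operations again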