Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions src/memos/graph_dbs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,85 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N
- metadata: dict[str, Any] - Node metadata
user_name: Optional user name (will use config default if not provided)
"""

@abstractmethod
def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[dict[str, str]]:
"""
Get edges connected to a node, with optional type and direction filter.
Args:
id: Node ID to retrieve edges for.
type: Relationship type to match, or 'ANY' to match all.
direction: 'OUTGOING', 'INCOMING', or 'ANY'.
Returns:
List of edge dicts with 'from', 'to', and 'type' keys.
"""

@abstractmethod
def search_by_fulltext(self, query_words: list[str], top_k: int = 10, **kwargs) -> list[dict]:
"""
Full-text search for memory nodes.
Args:
query_words: List of words to search for.
top_k: Maximum number of results.
Returns:
List of dicts with 'id' and 'score'.
"""

@abstractmethod
def get_neighbors_by_tag(
self,
tags: list[str],
exclude_ids: list[str],
top_k: int = 5,
min_overlap: int = 1,
**kwargs,
) -> list[dict[str, Any]]:
"""
Find top-K neighbor nodes with maximum tag overlap.
Args:
tags: Tags to match.
exclude_ids: Node IDs to exclude.
top_k: Max neighbors to return.
min_overlap: Minimum overlapping tags required.
Returns:
List of node dicts.
"""

@abstractmethod
def delete_node_by_prams(
self,
memory_ids: list[str] | None = None,
writable_cube_ids: list[str] | None = None,
file_ids: list[str] | None = None,
filter: dict | None = None,
**kwargs,
) -> int:
"""
Delete nodes matching given parameters.
Returns:
Number of deleted nodes.
"""

@abstractmethod
def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> list[str]:
"""
Get distinct user names that own the given memory IDs.
"""

@abstractmethod
def exist_user_name(self, user_name: str) -> bool:
"""
Check if a user_name exists in the graph.
"""

@abstractmethod
def search_by_keywords_like(self, query_word: str, **kwargs) -> list[dict]:
"""
Search memories using SQL LIKE pattern matching.
"""

@abstractmethod
def search_by_keywords_tfidf(self, query_words: list[str], **kwargs) -> list[dict]:
"""
Search memories using TF-IDF fulltext scoring.
"""
9 changes: 1 addition & 8 deletions src/memos/graph_dbs/nebular.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from memos.configs.graph_db import NebulaGraphDBConfig
from memos.dependency import require_python_package
from memos.graph_dbs.base import BaseGraphDB
from memos.graph_dbs.utils import compose_node as _compose_node
from memos.log import get_logger
from memos.utils import timed

Expand Down Expand Up @@ -44,14 +45,6 @@ def _normalize(vec: list[float]) -> list[float]:
return (v / (norm if norm else 1.0)).tolist()


@timed
def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
node_id = item["id"]
memory = item["memory"]
metadata = item.get("metadata", {})
return node_id, memory, metadata


@timed
def _escape_str(value: str) -> str:
out = []
Expand Down
197 changes: 186 additions & 11 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,13 @@
from memos.configs.graph_db import Neo4jGraphDBConfig
from memos.dependency import require_python_package
from memos.graph_dbs.base import BaseGraphDB
from memos.graph_dbs.utils import compose_node as _compose_node
from memos.log import get_logger


logger = get_logger(__name__)


def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
node_id = item["id"]
memory = item["memory"]
metadata = item.get("metadata", {})
return node_id, memory, metadata


def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
"""
Ensure metadata has proper datetime fields and normalized types.
Expand Down Expand Up @@ -502,26 +496,36 @@ def edge_exists(
return result.single() is not None

# Graph Query & Reasoning
def get_node(self, id: str, **kwargs) -> dict[str, Any] | None:
def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None:
"""
Retrieve the metadata and memory of a node.
Args:
id: Node identifier.
Returns:
Dictionary of node fields, or None if not found.
"""
user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name
logger.info(f"[get_node] id: {id}")
user_name = kwargs.get("user_name")
where_user = ""
params = {"id": id}
if not self.config.use_multi_db and (self.config.user_name or user_name):
if user_name is not None:
where_user = " AND n.user_name = $user_name"
params["user_name"] = user_name

query = f"MATCH (n:Memory) WHERE n.id = $id {where_user} RETURN n"
logger.info(f"[get_node] query: {query}")

with self.driver.session(database=self.db_name) as session:
record = session.run(query, params).single()
return self._parse_node(dict(record["n"])) if record else None
if not record:
return None

node_dict = dict(record["n"])
if include_embedding is False:
for key in ("embedding", "embedding_1024", "embedding_3072", "embedding_768"):
node_dict.pop(key, None)

return self._parse_node(node_dict)

def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]:
"""
Expand Down Expand Up @@ -1940,3 +1944,174 @@ def exist_user_name(self, user_name: str) -> dict[str, bool]:
f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True
)
raise

def delete_node_by_mem_cube_id(
self,
mem_kube_id: dict | None = None,
delete_record_id: dict | None = None,
deleted_type: bool = False,
) -> int:
"""
Delete nodes by mem_kube_id (user_name) and delete_record_id.

Args:
mem_kube_id: The mem_kube_id which corresponds to user_name in the table.
Can be dict or str. If dict, will extract the value.
delete_record_id: The delete_record_id to match.
Can be dict or str. If dict, will extract the value.
deleted_type: If True, performs hard delete (directly deletes records).
If False, performs soft delete (updates status to 'deleted' and sets delete_record_id and delete_time).

Returns:
int: Number of nodes deleted or updated.
"""
# Handle dict type parameters (extract value if dict)
if isinstance(mem_kube_id, dict):
# Try to get a value from dict, use first value if multiple
mem_kube_id = next(iter(mem_kube_id.values())) if mem_kube_id else None

if isinstance(delete_record_id, dict):
delete_record_id = next(iter(delete_record_id.values())) if delete_record_id else None

# Validate required parameters
if not mem_kube_id:
logger.warning("[delete_node_by_mem_cube_id] mem_kube_id is required but not provided")
return 0

if not delete_record_id:
logger.warning(
"[delete_node_by_mem_cube_id] delete_record_id is required but not provided"
)
return 0

# Convert to string if needed
mem_kube_id = str(mem_kube_id) if mem_kube_id else None
delete_record_id = str(delete_record_id) if delete_record_id else None

logger.info(
f"[delete_node_by_mem_cube_id] mem_kube_id={mem_kube_id}, "
f"delete_record_id={delete_record_id}, deleted_type={deleted_type}"
)

try:
with self.driver.session(database=self.db_name) as session:
if deleted_type:
# Hard delete: WHERE user_name = mem_kube_id AND delete_record_id = $delete_record_id
query = """
MATCH (n:Memory)
WHERE n.user_name = $mem_kube_id AND n.delete_record_id = $delete_record_id
DETACH DELETE n
"""
logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {query}")

result = session.run(
query, mem_kube_id=mem_kube_id, delete_record_id=delete_record_id
)
summary = result.consume()
deleted_count = summary.counters.nodes_deleted if summary.counters else 0

logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes")
return deleted_count
else:
# Soft delete: WHERE user_name = mem_kube_id (only user_name condition)
current_time = datetime.utcnow().isoformat()

query = """
MATCH (n:Memory)
WHERE n.user_name = $mem_kube_id
SET n.status = $status,
n.delete_record_id = $delete_record_id,
n.delete_time = $delete_time
RETURN count(n) AS updated_count
"""
logger.info(f"[delete_node_by_mem_cube_id] Soft delete query: {query}")

result = session.run(
query,
mem_kube_id=mem_kube_id,
status="deleted",
delete_record_id=delete_record_id,
delete_time=current_time,
)
record = result.single()
updated_count = record["updated_count"] if record else 0

logger.info(
f"[delete_node_by_mem_cube_id] Soft deleted (updated) {updated_count} nodes"
)
return updated_count

except Exception as e:
logger.error(
f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True
)
raise

def recover_memory_by_mem_kube_id(
self,
mem_kube_id: str | None = None,
delete_record_id: str | None = None,
) -> int:
"""
Recover memory nodes by mem_kube_id (user_name) and delete_record_id.

This function updates the status to 'activated', and clears delete_record_id and delete_time.

Args:
mem_kube_id: The mem_kube_id which corresponds to user_name in the table.
delete_record_id: The delete_record_id to match.

Returns:
int: Number of nodes recovered (updated).
"""
# Validate required parameters
if not mem_kube_id:
logger.warning(
"[recover_memory_by_mem_kube_id] mem_kube_id is required but not provided"
)
return 0

if not delete_record_id:
logger.warning(
"[recover_memory_by_mem_kube_id] delete_record_id is required but not provided"
)
return 0

logger.info(
f"[recover_memory_by_mem_kube_id] mem_kube_id={mem_kube_id}, "
f"delete_record_id={delete_record_id}"
)

try:
with self.driver.session(database=self.db_name) as session:
query = """
MATCH (n:Memory)
WHERE n.user_name = $mem_kube_id AND n.delete_record_id = $delete_record_id
SET n.status = $status,
n.delete_record_id = $delete_record_id_empty,
n.delete_time = $delete_time_empty
RETURN count(n) AS updated_count
"""
logger.info(f"[recover_memory_by_mem_kube_id] Update query: {query}")

result = session.run(
query,
mem_kube_id=mem_kube_id,
delete_record_id=delete_record_id,
status="activated",
delete_record_id_empty="",
delete_time_empty="",
)
record = result.single()
updated_count = record["updated_count"] if record else 0

logger.info(
f"[recover_memory_by_mem_kube_id] Recovered (updated) {updated_count} nodes"
)
return updated_count

except Exception as e:
logger.error(
f"[recover_memory_by_mem_kube_id] Failed to recover nodes: {e}", exc_info=True
)
raise
Loading