Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 15 additions & 53 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -1953,33 +1953,15 @@ def exist_user_name(self, user_name: str) -> dict[str, bool]:

def delete_node_by_mem_cube_id(
self,
mem_cube_id: dict | None = None,
delete_record_id: dict | None = None,
mem_cube_id: str | None = None,
delete_record_id: str | None = None,
hard_delete: bool = False,
) -> int:
"""
Delete nodes by mem_cube_id (user_name) and delete_record_id.

Args:
mem_cube_id: The mem_cube_id which corresponds to user_name in the table.
Can be dict or str. If dict, will extract the value.
delete_record_id: The delete_record_id to match.
Can be dict or str. If dict, will extract the value.
hard_delete: If True, performs hard delete (directly deletes records).
If False, performs soft delete (updates status to 'deleted' and sets delete_record_id and delete_time).

Returns:
int: Number of nodes deleted or updated.
"""
# Handle dict type parameters (extract value if dict)
if isinstance(mem_cube_id, dict):
# Try to get a value from dict, use first value if multiple
mem_cube_id = next(iter(mem_cube_id.values())) if mem_cube_id else None

if isinstance(delete_record_id, dict):
delete_record_id = next(iter(delete_record_id.values())) if delete_record_id else None
logger.info(
f"delete_node_by_mem_cube_id mem_cube_id:{mem_cube_id}, "
f"delete_record_id:{delete_record_id}, hard_delete:{hard_delete}"
)

# Validate required parameters
if not mem_cube_id:
logger.warning("[delete_node_by_mem_cube_id] mem_cube_id is required but not provided")
return 0
Expand All @@ -1990,19 +1972,9 @@ def delete_node_by_mem_cube_id(
)
return 0

# Convert to string if needed
mem_cube_id = str(mem_cube_id) if mem_cube_id else None
delete_record_id = str(delete_record_id) if delete_record_id else None

logger.info(
f"[delete_node_by_mem_cube_id] mem_cube_id={mem_cube_id}, "
f"delete_record_id={delete_record_id}, hard_delete={hard_delete}"
)

try:
with self.driver.session(database=self.db_name) as session:
if hard_delete:
# Hard delete: WHERE user_name = mem_cube_id AND delete_record_id = $delete_record_id
query = """
MATCH (n:Memory)
WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id
Expand All @@ -2019,12 +1991,13 @@ def delete_node_by_mem_cube_id(
logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes")
return deleted_count
else:
# Soft delete: WHERE user_name = mem_cube_id (only user_name condition)
current_time = datetime.utcnow().isoformat()

query = """
MATCH (n:Memory)
WHERE n.user_name = $mem_cube_id
AND (n.delete_time IS NULL OR n.delete_time = "")
AND (n.delete_record_id IS NULL OR n.delete_record_id = "")
SET n.status = $status,
n.delete_record_id = $delete_record_id,
n.delete_time = $delete_time
Expand All @@ -2043,7 +2016,7 @@ def delete_node_by_mem_cube_id(
updated_count = record["updated_count"] if record else 0

logger.info(
f"[delete_node_by_mem_cube_id] Soft deleted (updated) {updated_count} nodes"
f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes"
)
return updated_count

Expand All @@ -2058,33 +2031,22 @@ def recover_memory_by_mem_cube_id(
mem_cube_id: str | None = None,
delete_record_id: str | None = None,
) -> int:
"""
Recover memory nodes by mem_cube_id (user_name) and delete_record_id.

This function updates the status to 'activated', and clears delete_record_id and delete_time.

Args:
mem_cube_id: The mem_cube_id which corresponds to user_name in the table.
delete_record_id: The delete_record_id to match.

Returns:
int: Number of nodes recovered (updated).
"""
logger.info(
f"recover_memory_by_mem_cube_id mem_cube_id:{mem_cube_id},delete_record_id:{delete_record_id}"
)
# Validate required parameters
if not mem_cube_id:
logger.warning(
"[recover_memory_by_mem_cube_id] mem_cube_id is required but not provided"
)
logger.warning("recover_memory_by_mem_cube_id mem_cube_id is required but not provided")
return 0

if not delete_record_id:
logger.warning(
"[recover_memory_by_mem_cube_id] delete_record_id is required but not provided"
"recover_memory_by_mem_cube_id delete_record_id is required but not provided"
)
return 0

logger.info(
f"[recover_memory_by_mem_cube_id] mem_cube_id={mem_cube_id}, "
f"recover_memory_by_mem_cube_id mem_cube_id={mem_cube_id}, "
f"delete_record_id={delete_record_id}"
)

Expand Down
207 changes: 206 additions & 1 deletion src/memos/graph_dbs/neo4j_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,9 @@ def get_all_memory_items(

with self.driver.session(database=self.db_name) as session:
results = session.run(query, params)
return [self._parse_node(dict(record["n"])) for record in results]
nodes_data = [dict(record["n"]) for record in results]
# Use batch parsing to fetch all embeddings at once
return self._parse_nodes(nodes_data)

def get_by_metadata(
self,
Expand Down Expand Up @@ -1057,6 +1059,53 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
new_node["metadata"]["embedding"] = None
return new_node

def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Parse multiple Neo4j nodes and batch fetch embeddings from vector DB."""
if not nodes_data:
return []

# First, parse all nodes without embeddings
parsed_nodes = []
node_ids = []
for node_data in nodes_data:
node = node_data.copy()

# Convert Neo4j datetime to string
for time_field in ("created_at", "updated_at"):
if time_field in node and hasattr(node[time_field], "isoformat"):
node[time_field] = node[time_field].isoformat()
node.pop("user_name", None)
# serialization
if node.get("sources"):
for idx in range(len(node["sources"])):
if not (
isinstance(node["sources"][idx], str)
and node["sources"][idx][0] == "{"
and node["sources"][idx][0] == "}"
):
break
node["sources"][idx] = json.loads(node["sources"][idx])

node_id = node.pop("id")
node_ids.append(node_id)
parsed_nodes.append({"id": node_id, "memory": node.pop("memory", ""), "metadata": node})

# Batch fetch all embeddings at once
vec_items_map = {}
if node_ids:
try:
vec_items = self.vec_db.get_by_ids(node_ids)
vec_items_map = {v.id: v.vector for v in vec_items if v and v.vector}
except Exception as e:
logger.warning(f"Failed to batch fetch vectors for {len(node_ids)} nodes: {e}")

# Merge embeddings into parsed nodes
for parsed_node in parsed_nodes:
node_id = parsed_node["id"]
parsed_node["metadata"]["embedding"] = vec_items_map.get(node_id, None)

return parsed_nodes

def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]:
"""Get user names by memory ids.

Expand Down Expand Up @@ -1111,3 +1160,159 @@ def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str |
f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
)
raise

def delete_node_by_mem_cube_id(
    self,
    mem_cube_id: str | None = None,
    delete_record_id: str | None = None,
    hard_delete: bool = False,
) -> int:
    """Delete Memory nodes belonging to a mem cube (matched via ``user_name``).

    Args:
        mem_cube_id: Matches the node's ``user_name`` property. Required.
        delete_record_id: Identifier tagging this delete operation. Required.
        hard_delete: If True, physically removes nodes whose
            ``delete_record_id`` already equals the given id (and removes
            their vectors from the vector DB). If False, soft-deletes by
            stamping ``status``/``delete_record_id``/``delete_time`` on nodes
            not yet soft-deleted.

    Returns:
        int: Number of nodes deleted (hard) or updated (soft); 0 when a
        required argument is missing.

    Raises:
        Exception: Re-raises any Neo4j session error after logging it.
    """
    logger.info(
        f"delete_node_by_mem_cube_id mem_cube_id:{mem_cube_id}, "
        f"delete_record_id:{delete_record_id}, hard_delete:{hard_delete}"
    )

    # Missing required input is treated as a no-op, not an error.
    if not mem_cube_id:
        logger.warning("[delete_node_by_mem_cube_id] mem_cube_id is required but not provided")
        return 0

    if not delete_record_id:
        logger.warning(
            "[delete_node_by_mem_cube_id] delete_record_id is required but not provided"
        )
        return 0

    try:
        with self.driver.session(database=self.db_name) as session:
            if hard_delete:
                # Collect the matching ids first so their vectors can be
                # removed from the vector DB after the graph delete.
                query_get_ids = """
                MATCH (n:Memory)
                WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id
                RETURN n.id AS id
                """
                result = session.run(
                    query_get_ids, mem_cube_id=mem_cube_id, delete_record_id=delete_record_id
                )
                node_ids = [record["id"] for record in result]

                # Delete from Neo4j; DETACH also removes relationships.
                query = """
                MATCH (n:Memory)
                WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id
                DETACH DELETE n
                """
                logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {query}")

                result = session.run(
                    query, mem_cube_id=mem_cube_id, delete_record_id=delete_record_id
                )
                # consume() yields the result summary with delete counters.
                summary = result.consume()
                deleted_count = summary.counters.nodes_deleted if summary.counters else 0

                # Best-effort vector cleanup: a failure here is logged, not
                # raised, so the graph delete still counts as successful.
                if node_ids and self.vec_db:
                    try:
                        self.vec_db.delete(node_ids)
                        logger.info(
                            f"[delete_node_by_mem_cube_id] Deleted {len(node_ids)} vectors from VecDB"
                        )
                    except Exception as e:
                        logger.warning(
                            f"[delete_node_by_mem_cube_id] Failed to delete vectors from VecDB: {e}"
                        )

                logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes")
                return deleted_count
            else:
                # NOTE(review): utcnow() is naive (no tz info) and deprecated
                # in 3.12 — presumably delete_time is interpreted as UTC
                # downstream; confirm before changing the stored format.
                current_time = datetime.utcnow().isoformat()

                # Soft delete: only touch nodes not already soft-deleted
                # (i.e. delete_time and delete_record_id are NULL/empty).
                query = """
                MATCH (n:Memory)
                WHERE n.user_name = $mem_cube_id
                AND (n.delete_time IS NULL OR n.delete_time = "")
                AND (n.delete_record_id IS NULL OR n.delete_record_id = "")
                SET n.status = $status,
                    n.delete_record_id = $delete_record_id,
                    n.delete_time = $delete_time
                RETURN count(n) AS updated_count
                """
                logger.info(f"[delete_node_by_mem_cube_id] Soft delete query: {query}")

                result = session.run(
                    query,
                    mem_cube_id=mem_cube_id,
                    status="deleted",
                    delete_record_id=delete_record_id,
                    delete_time=current_time,
                )
                record = result.single()
                updated_count = record["updated_count"] if record else 0

                logger.info(
                    f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes"
                )
                return updated_count

    except Exception as e:
        logger.error(
            f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True
        )
        raise

def recover_memory_by_mem_cube_id(
    self,
    mem_cube_id: str | None = None,
    delete_record_id: str | None = None,
) -> int:
    """Restore soft-deleted Memory nodes for a mem cube.

    Matches nodes by ``user_name`` and ``delete_record_id``, flips their
    ``status`` back to ``"activated"``, and clears ``delete_record_id`` and
    ``delete_time``.

    Args:
        mem_cube_id: Matches the node's ``user_name`` property. Required.
        delete_record_id: The delete operation to undo. Required.

    Returns:
        int: Number of nodes recovered; 0 when a required argument is missing.

    Raises:
        Exception: Re-raises any Neo4j session error after logging it.
    """
    logger.info(
        f"recover_memory_by_mem_cube_id mem_cube_id:{mem_cube_id},delete_record_id:{delete_record_id}"
    )

    # Both identifiers are mandatory; missing input is a no-op.
    if not mem_cube_id:
        logger.warning("recover_memory_by_mem_cube_id mem_cube_id is required but not provided")
        return 0

    if not delete_record_id:
        logger.warning(
            "recover_memory_by_mem_cube_id delete_record_id is required but not provided"
        )
        return 0

    logger.info(
        f"recover_memory_by_mem_cube_id mem_cube_id={mem_cube_id}, "
        f"delete_record_id={delete_record_id}"
    )

    try:
        with self.driver.session(database=self.db_name) as session:
            query = """
            MATCH (n:Memory)
            WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id
            SET n.status = $status,
                n.delete_record_id = $delete_record_id_empty,
                n.delete_time = $delete_time_empty
            RETURN count(n) AS updated_count
            """
            logger.info(f"[recover_memory_by_mem_cube_id] Update query: {query}")

            # Clearing the delete markers is done by writing empty strings,
            # matching the "NULL or empty" guard used by the soft delete.
            params = {
                "mem_cube_id": mem_cube_id,
                "delete_record_id": delete_record_id,
                "status": "activated",
                "delete_record_id_empty": "",
                "delete_time_empty": "",
            }
            outcome = session.run(query, **params)

            row = outcome.single()
            updated_count = row["updated_count"] if row else 0

            logger.info(
                f"[recover_memory_by_mem_cube_id] Recovered (updated) {updated_count} nodes"
            )
            return updated_count

    except Exception as e:
        logger.error(
            f"[recover_memory_by_mem_cube_id] Failed to recover nodes: {e}", exc_info=True
        )
        raise
Loading