diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index e943616da..f0be3d858 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -5,7 +5,7 @@ from typing import Any from memos.configs.graph_db import Neo4jGraphDBConfig -from memos.graph_dbs.neo4j import Neo4jGraphDB, _prepare_node_metadata +from memos.graph_dbs.neo4j import Neo4jGraphDB, _flatten_info_fields, _prepare_node_metadata from memos.log import get_logger from memos.vec_dbs.factory import VecDBFactory from memos.vec_dbs.item import VecDBItem @@ -104,6 +104,108 @@ def add_node( metadata=metadata, ) + def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = None) -> None: + if not nodes: + logger.warning("[add_nodes_batch] Empty nodes list, skipping") + return + + effective_user_name = user_name if user_name else self.config.user_name + + vec_items: list[VecDBItem] = [] + prepared_nodes: list[dict[str, Any]] = [] + + for node_data in nodes: + try: + node_id = node_data.get("id") + memory = node_data.get("memory") + metadata = node_data.get("metadata", {}) + + if node_id is None or memory is None: + logger.warning("[add_nodes_batch] Skip invalid node: missing id/memory") + continue + + if not self.config.use_multi_db and (self.config.user_name or effective_user_name): + metadata["user_name"] = effective_user_name + + metadata = _prepare_node_metadata(metadata) + metadata = _flatten_info_fields(metadata) + + embedding = metadata.pop("embedding", None) + if embedding is None: + raise ValueError(f"Missing 'embedding' in metadata for node {node_id}") + + vector_sync_status = "success" + vec_items.append( + VecDBItem( + id=node_id, + vector=embedding, + payload={ + "memory": memory, + "vector_sync": vector_sync_status, + **metadata, + }, + ) + ) + + created_at = metadata.pop("created_at") + updated_at = metadata.pop("updated_at") + metadata["vector_sync"] = vector_sync_status + + prepared_nodes.append( + { + "id": node_id, + "memory": memory, + "created_at": created_at, + "updated_at": updated_at, + "metadata": metadata, + } + ) + except Exception as e: + logger.error( + f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}", + exc_info=True, + ) + continue + + if not prepared_nodes: + logger.warning("[add_nodes_batch] No valid nodes to insert after preparation") + return + + try: + self.vec_db.add(vec_items) + except Exception as e: + logger.warning(f"[VecDB] batch insert failed: {e}") + for node in prepared_nodes: + node["metadata"]["vector_sync"] = "failed" + + query = """ + UNWIND $nodes AS node + MERGE (n:Memory {id: node.id}) + SET n.memory = node.memory, + n.created_at = datetime(node.created_at), + n.updated_at = datetime(node.updated_at), + n += node.metadata + """ + + nodes_data = [ + { + "id": node["id"], + "memory": node["memory"], + "created_at": node["created_at"], + "updated_at": node["updated_at"], + "metadata": node["metadata"], + } + for node in prepared_nodes + ] + + try: + with self.driver.session(database=self.db_name) as session: + session.run(query, nodes=nodes_data) + logger.info(f"[add_nodes_batch] Successfully inserted {len(prepared_nodes)} nodes") + except Exception as e: + logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True) + raise + def get_children_with_embeddings( self, id: str, user_name: str | None = None ) -> list[dict[str, Any]]: