From 132b56d661496936ee7ecb50622be1c10b27ce25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Thu, 5 Feb 2026 10:12:43 +0800 Subject: [PATCH 1/4] feat: optimize polardb && neo4j --- src/memos/graph_dbs/neo4j.py | 102 +++------ src/memos/graph_dbs/neo4j_community.py | 298 +++++++++++++++++++++---- src/memos/graph_dbs/polardb.py | 12 - 3 files changed, 287 insertions(+), 125 deletions(-) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 2bd2e5a46..ef9673c2b 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -1953,35 +1953,17 @@ def exist_user_name(self, user_name: str) -> dict[str, bool]: def delete_node_by_mem_cube_id( self, - mem_kube_id: dict | None = None, - delete_record_id: dict | None = None, - deleted_type: bool = False, + mem_cube_id: str | None = None, + delete_record_id: str | None = None, + hard_delete: bool = False, ) -> int: - """ - Delete nodes by mem_kube_id (user_name) and delete_record_id. - - Args: - mem_kube_id: The mem_kube_id which corresponds to user_name in the table. - Can be dict or str. If dict, will extract the value. - delete_record_id: The delete_record_id to match. - Can be dict or str. If dict, will extract the value. - deleted_type: If True, performs hard delete (directly deletes records). - If False, performs soft delete (updates status to 'deleted' and sets delete_record_id and delete_time). - - Returns: - int: Number of nodes deleted or updated. - """ - # Handle dict type parameters (extract value if dict) - if isinstance(mem_kube_id, dict): - # Try to get a value from dict, use first value if multiple - mem_kube_id = next(iter(mem_kube_id.values())) if mem_kube_id else None - - if isinstance(delete_record_id, dict): - delete_record_id = next(iter(delete_record_id.values())) if delete_record_id else None + logger.info( + f"delete_node_by_mem_cube_id mem_cube_id:{mem_cube_id}, " + f"delete_record_id:{delete_record_id}, hard_delete:{hard_delete}" + ) - # Validate required parameters - if not mem_kube_id: - logger.warning("[delete_node_by_mem_cube_id] mem_kube_id is required but not provided") + if not mem_cube_id: + logger.warning("[delete_node_by_mem_cube_id] mem_cube_id is required but not provided") return 0 if not delete_record_id: @@ -1990,28 +1972,18 @@ def delete_node_by_mem_cube_id( ) return 0 - # Convert to string if needed - mem_kube_id = str(mem_kube_id) if mem_kube_id else None - delete_record_id = str(delete_record_id) if delete_record_id else None - - logger.info( - f"[delete_node_by_mem_cube_id] mem_kube_id={mem_kube_id}, " - f"delete_record_id={delete_record_id}, deleted_type={deleted_type}" - ) - try: with self.driver.session(database=self.db_name) as session: - if deleted_type: - # Hard delete: WHERE user_name = mem_kube_id AND delete_record_id = $delete_record_id + if hard_delete: query = """ MATCH (n:Memory) - WHERE n.user_name = $mem_kube_id AND n.delete_record_id = $delete_record_id + WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id DETACH DELETE n """ logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {query}") result = session.run( - query, mem_kube_id=mem_kube_id, delete_record_id=delete_record_id + query, mem_cube_id=mem_cube_id, delete_record_id=delete_record_id ) summary = result.consume() deleted_count = summary.counters.nodes_deleted if summary.counters else 0 @@ -2019,12 +1991,13 @@ def delete_node_by_mem_cube_id( logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} 
nodes") return deleted_count else: - # Soft delete: WHERE user_name = mem_kube_id (only user_name condition) current_time = datetime.utcnow().isoformat() query = """ MATCH (n:Memory) - WHERE n.user_name = $mem_kube_id + WHERE n.user_name = $mem_cube_id + AND (n.delete_time IS NULL OR n.delete_time = "") + AND (n.delete_record_id IS NULL OR n.delete_record_id = "") SET n.status = $status, n.delete_record_id = $delete_record_id, n.delete_time = $delete_time @@ -2034,7 +2007,7 @@ def delete_node_by_mem_cube_id( result = session.run( query, - mem_kube_id=mem_kube_id, + mem_cube_id=mem_cube_id, status="deleted", delete_record_id=delete_record_id, delete_time=current_time, @@ -2043,7 +2016,7 @@ def delete_node_by_mem_cube_id( updated_count = record["updated_count"] if record else 0 logger.info( - f"[delete_node_by_mem_cube_id] Soft deleted (updated) {updated_count} nodes" + f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes" ) return updated_count @@ -2053,38 +2026,27 @@ def delete_node_by_mem_cube_id( ) raise - def recover_memory_by_mem_kube_id( + def recover_memory_by_mem_cube_id( self, - mem_kube_id: str | None = None, + mem_cube_id: str | None = None, delete_record_id: str | None = None, ) -> int: - """ - Recover memory nodes by mem_kube_id (user_name) and delete_record_id. - - This function updates the status to 'activated', and clears delete_record_id and delete_time. - - Args: - mem_kube_id: The mem_kube_id which corresponds to user_name in the table. - delete_record_id: The delete_record_id to match. - - Returns: - int: Number of nodes recovered (updated). - """ + logger.info( + f"recover_memory_by_mem_cube_id mem_cube_id:{mem_cube_id},delete_record_id:{delete_record_id}" + ) # Validate required parameters - if not mem_kube_id: - logger.warning( - "[recover_memory_by_mem_kube_id] mem_kube_id is required but not provided" - ) + if not mem_cube_id: + logger.warning("recover_memory_by_mem_cube_id mem_cube_id is required but not provided") return 0 if not delete_record_id: logger.warning( - "[recover_memory_by_mem_kube_id] delete_record_id is required but not provided" + "recover_memory_by_mem_cube_id delete_record_id is required but not provided" ) return 0 logger.info( - f"[recover_memory_by_mem_kube_id] mem_kube_id={mem_kube_id}, " + f"recover_memory_by_mem_cube_id mem_cube_id={mem_cube_id}, " f"delete_record_id={delete_record_id}" ) @@ -2092,17 +2054,17 @@ def recover_memory_by_mem_kube_id( with self.driver.session(database=self.db_name) as session: query = """ MATCH (n:Memory) - WHERE n.user_name = $mem_kube_id AND n.delete_record_id = $delete_record_id + WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id SET n.status = $status, n.delete_record_id = $delete_record_id_empty, n.delete_time = $delete_time_empty RETURN count(n) AS updated_count """ - logger.info(f"[recover_memory_by_mem_kube_id] Update query: {query}") + logger.info(f"[recover_memory_by_mem_cube_id] Update query: {query}") result = session.run( query, - mem_kube_id=mem_kube_id, + mem_cube_id=mem_cube_id, delete_record_id=delete_record_id, status="activated", delete_record_id_empty="", @@ -2112,12 +2074,12 @@ def recover_memory_by_mem_kube_id( updated_count = record["updated_count"] if record else 0 logger.info( - f"[recover_memory_by_mem_kube_id] Recovered (updated) {updated_count} nodes" + f"[recover_memory_by_mem_cube_id] Recovered (updated) {updated_count} nodes" ) return updated_count except Exception as e: logger.error( - f"[recover_memory_by_mem_kube_id] Failed to 
recover nodes: {e}", exc_info=True + f"[recover_memory_by_mem_cube_id] Failed to recover nodes: {e}", exc_info=True ) - raise + raise \ No newline at end of file diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index 411dbffe5..63fed3d65 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -10,7 +10,6 @@ from memos.vec_dbs.factory import VecDBFactory from memos.vec_dbs.item import VecDBItem - logger = get_logger(__name__) @@ -34,11 +33,11 @@ def __init__(self, config: Neo4jGraphDBConfig): super().__init__(config) def create_index( - self, - label: str = "Memory", - vector_property: str = "embedding", - dimensions: int = 1536, - index_name: str = "memory_vector_index", + self, + label: str = "Memory", + vector_property: str = "embedding", + dimensions: int = 1536, + index_name: str = "memory_vector_index", ) -> None: """ Create the vector index for embedding and datetime indexes for created_at and updated_at fields. @@ -47,7 +46,7 @@ def create_index( self._create_basic_property_indexes() def add_node( - self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None + self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None ) -> None: user_name = user_name if user_name else self.config.user_name if not self.config.use_multi_db and (self.config.user_name or user_name): @@ -216,7 +215,7 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N raise def get_children_with_embeddings( - self, id: str, user_name: str | None = None + self, id: str, user_name: str | None = None ) -> list[dict[str, Any]]: user_name = user_name if user_name else self.config.user_name where_user = "" @@ -248,17 +247,17 @@ def get_children_with_embeddings( # Search / recall operations def search_by_embedding( - self, - vector: list[float], - top_k: int = 5, - scope: str | None = None, - status: str | None = None, - threshold: float | None = None, - search_filter: dict | None = None, - user_name: str | None = None, - filter: dict | None = None, - knowledgebase_ids: list[str] | None = None, - **kwargs, + self, + vector: list[float], + top_k: int = 5, + scope: str | None = None, + status: str | None = None, + threshold: float | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, ) -> list[dict]: """ Retrieve node IDs based on vector similarity using external vector DB. @@ -420,10 +419,10 @@ def _normalize_date_string(self, date_str: str) -> str: return date_str def _build_filter_conditions_cypher( - self, - filter: dict | None, - param_counter_start: int = 0, - node_alias: str = "node", + self, + filter: dict | None, + param_counter_start: int = 0, + node_alias: str = "node", ) -> tuple[list[str], dict[str, Any]]: """ Build filter conditions for Cypher queries with date normalization. @@ -546,11 +545,11 @@ def _normalize_condition_dates(self, condition: dict) -> dict: return normalized def get_all_memory_items( - self, - scope: str, - filter: dict | None = None, - knowledgebase_ids: list[str] | None = None, - **kwargs, + self, + scope: str, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, ) -> list[dict]: """ Retrieve all memory items of a specific memory_type. 
@@ -620,14 +619,16 @@ def get_all_memory_items(
 
         with self.driver.session(database=self.db_name) as session:
            results = session.run(query, params)
-            return [self._parse_node(dict(record["n"])) for record in results]
+            nodes_data = [dict(record["n"]) for record in results]
+            # Use batch parsing to fetch all embeddings at once
+            return self._parse_nodes(nodes_data)
 
     def get_by_metadata(
-        self,
-        filters: list[dict[str, Any]],
-        user_name: str | None = None,
-        filter: dict | None = None,
-        knowledgebase_ids: list[str] | None = None,
+            self,
+            filters: list[dict[str, Any]],
+            user_name: str | None = None,
+            filter: dict | None = None,
+            knowledgebase_ids: list[str] | None = None,
     ) -> list[str]:
         """
         Retrieve node IDs that match given metadata filters.
@@ -713,7 +714,7 @@ def get_by_metadata(
 
         if filter:
             # Helper function to build a single filter condition
             def build_filter_condition(
-                condition_dict: dict, param_counter: list
+                    condition_dict: dict, param_counter: list
             ) -> tuple[str, dict]:
                 """Build a WHERE condition for a single filter item.
@@ -818,11 +819,11 @@ def build_filter_condition
         return [record["id"] for record in result]
 
     def delete_node_by_prams(
-        self,
-        writable_cube_ids: list[str],
-        memory_ids: list[str] | None = None,
-        file_ids: list[str] | None = None,
-        filter: dict | None = None,
+            self,
+            writable_cube_ids: list[str],
+            memory_ids: list[str] | None = None,
+            file_ids: list[str] | None = None,
+            filter: dict | None = None,
     ) -> int:
         """
         Delete nodes by memory_ids, file_ids, or filter.
@@ -1041,9 +1042,9 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
         if node["sources"]:
             for idx in range(len(node["sources"])):
                 if not (
-                    isinstance(node["sources"][idx], str)
-                    and node["sources"][idx][0] == "{"
-                    and node["sources"][idx][0] == "}"
+                        isinstance(node["sources"][idx], str)
+                        and node["sources"][idx][0] == "{"
+                        and node["sources"][idx][-1] == "}"
                 ):
                     break
                 node["sources"][idx] = json.loads(node["sources"][idx])
@@ -1057,6 +1058,60 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
             new_node["metadata"]["embedding"] = None
         return new_node
 
+    def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        """Parse multiple Neo4j nodes and batch fetch embeddings from vector DB."""
+        if not nodes_data:
+            return []
+
+        # First, parse all nodes without embeddings
+        parsed_nodes = []
+        node_ids = []
+        for node_data in nodes_data:
+            node = node_data.copy()
+
+            # Convert Neo4j datetime to string
+            for time_field in ("created_at", "updated_at"):
+                if time_field in node and hasattr(node[time_field], "isoformat"):
+                    node[time_field] = node[time_field].isoformat()
+            node.pop("user_name", None)
+            # Deserialize JSON-encoded entries in sources
+            if node.get("sources"):
+                for idx in range(len(node["sources"])):
+                    if not (
+                            isinstance(node["sources"][idx], str)
+                            and node["sources"][idx][0] == "{"
+                            and node["sources"][idx][-1] == "}"
+                    ):
+                        break
+                    node["sources"][idx] = json.loads(node["sources"][idx])
+
+            node_id = node.pop("id")
+            node_ids.append(node_id)
+            parsed_nodes.append({
+                "id": node_id,
+                "memory": node.pop("memory", ""),
+                "metadata": node
+            })
+
+        # Batch fetch all embeddings at once
+        vec_items_map = {}
+        if node_ids:
+            try:
+                vec_items = self.vec_db.get_by_ids(node_ids)
+                vec_items_map = {v.id: v.vector for v in vec_items if v and v.vector}
+            except Exception as e:
+                logger.warning(f"Failed to batch fetch vectors for {len(node_ids)} nodes: {e}")
+
+        # Merge embeddings into parsed nodes
+        for parsed_node in parsed_nodes:
+            node_id = parsed_node["id"]
+ if node_id in vec_items_map: + parsed_node["metadata"]["embedding"] = vec_items_map[node_id] + else: + parsed_node["metadata"]["embedding"] = None + + return parsed_nodes + def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]: """Get user names by memory ids. @@ -1111,3 +1166,160 @@ def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True ) raise + + def delete_node_by_mem_cube_id( + self, + mem_cube_id: str | None = None, + delete_record_id: str | None = None, + hard_delete: bool = False, + ) -> int: + logger.info( + f"delete_node_by_mem_cube_id mem_cube_id:{mem_cube_id}, " + f"delete_record_id:{delete_record_id}, hard_delete:{hard_delete}" + ) + + if not mem_cube_id: + logger.warning("[delete_node_by_mem_cube_id] mem_cube_id is required but not provided") + return 0 + + if not delete_record_id: + logger.warning( + "[delete_node_by_mem_cube_id] delete_record_id is required but not provided" + ) + return 0 + + try: + with self.driver.session(database=self.db_name) as session: + if hard_delete: + query_get_ids = """ + MATCH (n:Memory) + WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id + RETURN n.id AS id + """ + result = session.run( + query_get_ids, mem_cube_id=mem_cube_id, delete_record_id=delete_record_id + ) + node_ids = [record["id"] for record in result] + + # Delete from Neo4j + query = """ + MATCH (n:Memory) + WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id + DETACH DELETE n + """ + logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {query}") + + result = session.run( + query, mem_cube_id=mem_cube_id, delete_record_id=delete_record_id + ) + summary = result.consume() + deleted_count = summary.counters.nodes_deleted if summary.counters else 0 + + # Delete from vector DB + if node_ids and self.vec_db: + try: + self.vec_db.delete(node_ids) + logger.info( + f"[delete_node_by_mem_cube_id] Deleted {len(node_ids)} vectors from VecDB" + ) + except Exception as e: + logger.warning( + f"[delete_node_by_mem_cube_id] Failed to delete vectors from VecDB: {e}" + ) + + logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes") + return deleted_count + else: + current_time = datetime.utcnow().isoformat() + + query = """ + MATCH (n:Memory) + WHERE n.user_name = $mem_cube_id + AND (n.delete_time IS NULL OR n.delete_time = "") + AND (n.delete_record_id IS NULL OR n.delete_record_id = "") + SET n.status = $status, + n.delete_record_id = $delete_record_id, + n.delete_time = $delete_time + RETURN count(n) AS updated_count + """ + logger.info(f"[delete_node_by_mem_cube_id] Soft delete query: {query}") + + result = session.run( + query, + mem_cube_id=mem_cube_id, + status="deleted", + delete_record_id=delete_record_id, + delete_time=current_time, + ) + record = result.single() + updated_count = record["updated_count"] if record else 0 + + logger.info( + f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes" + ) + return updated_count + + except Exception as e: + logger.error( + f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True + ) + raise + + def recover_memory_by_mem_cube_id( + self, + mem_cube_id: str | None = None, + delete_record_id: str | None = None, + ) -> int: + + logger.info( + f"recover_memory_by_mem_cube_id mem_cube_id:{mem_cube_id},delete_record_id:{delete_record_id}" + ) + # Validate required parameters + if not 
mem_cube_id: + logger.warning("recover_memory_by_mem_cube_id mem_cube_id is required but not provided") + return 0 + + if not delete_record_id: + logger.warning( + "recover_memory_by_mem_cube_id delete_record_id is required but not provided" + ) + return 0 + + logger.info( + f"recover_memory_by_mem_cube_id mem_cube_id={mem_cube_id}, " + f"delete_record_id={delete_record_id}" + ) + + try: + with self.driver.session(database=self.db_name) as session: + query = """ + MATCH (n:Memory) + WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id + SET n.status = $status, + n.delete_record_id = $delete_record_id_empty, + n.delete_time = $delete_time_empty + RETURN count(n) AS updated_count + """ + logger.info(f"[recover_memory_by_mem_cube_id] Update query: {query}") + + result = session.run( + query, + mem_cube_id=mem_cube_id, + delete_record_id=delete_record_id, + status="activated", + delete_record_id_empty="", + delete_time_empty="", + ) + record = result.single() + updated_count = record["updated_count"] if record else 0 + + logger.info( + f"[recover_memory_by_mem_cube_id] Recovered (updated) {updated_count} nodes" + ) + return updated_count + + except Exception as e: + logger.error( + f"[recover_memory_by_mem_cube_id] Failed to recover nodes: {e}", exc_info=True + ) + raise diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 0cdd089e4..fe6f0a024 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -5565,18 +5565,6 @@ def recover_memory_by_mem_cube_id( mem_cube_id: str | None = None, delete_record_id: str | None = None, ) -> int: - """ - Recover memory nodes by mem_cube_id (user_name) and delete_record_id. - - This function updates the status to 'activated', and clears delete_record_id and delete_time. - - Args: - mem_cube_id: The mem_cube_id which corresponds to user_name in the table. - delete_record_id: The delete_record_id to match. - - Returns: - int: Number of nodes recovered (updated). 
- """ logger.info( f"recover_memory_by_mem_cube_id mem_cube_id:{mem_cube_id},delete_record_id:{delete_record_id}" ) From 9add6a34cb54b276f2bab382667976e760ce672e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Thu, 5 Feb 2026 21:32:06 +0800 Subject: [PATCH 2/4] fix: full_fields --- src/memos/graph_dbs/polardb.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index fe6f0a024..409b3a967 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -5149,6 +5149,9 @@ def parse_filter( "info", "source", "file_ids", + "project_id", + "manager_user_id", + "delete_time", } def process_condition(condition): From 118a73bc4265f6aaff11cdc5220957bbc85b5f09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Thu, 5 Feb 2026 23:49:29 +0800 Subject: [PATCH 3/4] fix: optimize --- src/memos/graph_dbs/neo4j_community.py | 115 ++++++++++++------------- 1 file changed, 54 insertions(+), 61 deletions(-) diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index 63fed3d65..e34313fa2 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -10,6 +10,7 @@ from memos.vec_dbs.factory import VecDBFactory from memos.vec_dbs.item import VecDBItem + logger = get_logger(__name__) @@ -33,11 +34,11 @@ def __init__(self, config: Neo4jGraphDBConfig): super().__init__(config) def create_index( - self, - label: str = "Memory", - vector_property: str = "embedding", - dimensions: int = 1536, - index_name: str = "memory_vector_index", + self, + label: str = "Memory", + vector_property: str = "embedding", + dimensions: int = 1536, + index_name: str = "memory_vector_index", ) -> None: """ Create the vector index for embedding and datetime indexes for created_at and updated_at fields. @@ -46,7 +47,7 @@ def create_index( self._create_basic_property_indexes() def add_node( - self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None + self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None ) -> None: user_name = user_name if user_name else self.config.user_name if not self.config.use_multi_db and (self.config.user_name or user_name): @@ -215,7 +216,7 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N raise def get_children_with_embeddings( - self, id: str, user_name: str | None = None + self, id: str, user_name: str | None = None ) -> list[dict[str, Any]]: user_name = user_name if user_name else self.config.user_name where_user = "" @@ -247,17 +248,17 @@ def get_children_with_embeddings( # Search / recall operations def search_by_embedding( - self, - vector: list[float], - top_k: int = 5, - scope: str | None = None, - status: str | None = None, - threshold: float | None = None, - search_filter: dict | None = None, - user_name: str | None = None, - filter: dict | None = None, - knowledgebase_ids: list[str] | None = None, - **kwargs, + self, + vector: list[float], + top_k: int = 5, + scope: str | None = None, + status: str | None = None, + threshold: float | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, ) -> list[dict]: """ Retrieve node IDs based on vector similarity using external vector DB. 
@@ -419,10 +420,10 @@ def _normalize_date_string(self, date_str: str) -> str: return date_str def _build_filter_conditions_cypher( - self, - filter: dict | None, - param_counter_start: int = 0, - node_alias: str = "node", + self, + filter: dict | None, + param_counter_start: int = 0, + node_alias: str = "node", ) -> tuple[list[str], dict[str, Any]]: """ Build filter conditions for Cypher queries with date normalization. @@ -545,11 +546,11 @@ def _normalize_condition_dates(self, condition: dict) -> dict: return normalized def get_all_memory_items( - self, - scope: str, - filter: dict | None = None, - knowledgebase_ids: list[str] | None = None, - **kwargs, + self, + scope: str, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, ) -> list[dict]: """ Retrieve all memory items of a specific memory_type. @@ -624,11 +625,11 @@ def get_all_memory_items( return self._parse_nodes(nodes_data) def get_by_metadata( - self, - filters: list[dict[str, Any]], - user_name: str | None = None, - filter: dict | None = None, - knowledgebase_ids: list[str] | None = None, + self, + filters: list[dict[str, Any]], + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, ) -> list[str]: """ Retrieve node IDs that match given metadata filters. @@ -714,7 +715,7 @@ def get_by_metadata( if filter: # Helper function to build a single filter condition def build_filter_condition( - condition_dict: dict, param_counter: list + condition_dict: dict, param_counter: list ) -> tuple[str, dict]: """Build a WHERE condition for a single filter item. @@ -819,11 +820,11 @@ def build_filter_condition( return [record["id"] for record in result] def delete_node_by_prams( - self, - writable_cube_ids: list[str], - memory_ids: list[str] | None = None, - file_ids: list[str] | None = None, - filter: dict | None = None, + self, + writable_cube_ids: list[str], + memory_ids: list[str] | None = None, + file_ids: list[str] | None = None, + filter: dict | None = None, ) -> int: """ Delete nodes by memory_ids, file_ids, or filter. 
@@ -1042,9 +1043,9 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
         if node["sources"]:
             for idx in range(len(node["sources"])):
                 if not (
-                        isinstance(node["sources"][idx], str)
-                        and node["sources"][idx][0] == "{"
-                        and node["sources"][idx][-1] == "}"
+                    isinstance(node["sources"][idx], str)
+                    and node["sources"][idx][0] == "{"
+                    and node["sources"][idx][-1] == "}"
                 ):
                     break
                 node["sources"][idx] = json.loads(node["sources"][idx])
@@ -1078,20 +1079,16 @@ def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]]
             if node.get("sources"):
                 for idx in range(len(node["sources"])):
                     if not (
-                            isinstance(node["sources"][idx], str)
-                            and node["sources"][idx][0] == "{"
-                            and node["sources"][idx][-1] == "}"
+                        isinstance(node["sources"][idx], str)
+                        and node["sources"][idx][0] == "{"
+                        and node["sources"][idx][-1] == "}"
                     ):
                         break
                     node["sources"][idx] = json.loads(node["sources"][idx])
 
             node_id = node.pop("id")
             node_ids.append(node_id)
-            parsed_nodes.append({
-                "id": node_id,
-                "memory": node.pop("memory", ""),
-                "metadata": node
-            })
+            parsed_nodes.append({"id": node_id, "memory": node.pop("memory", ""), "metadata": node})
 
         # Batch fetch all embeddings at once
         vec_items_map = {}
@@ -1105,10 +1102,7 @@ def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]]
 
         # Merge embeddings into parsed nodes
         for parsed_node in parsed_nodes:
             node_id = parsed_node["id"]
-            if node_id in vec_items_map:
-                parsed_node["metadata"]["embedding"] = vec_items_map[node_id]
-            else:
-                parsed_node["metadata"]["embedding"] = None
+            parsed_node["metadata"]["embedding"] = vec_items_map.get(node_id, None)
 
         return parsed_nodes
 
@@ -1168,10 +1162,10 @@ def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str |
         raise
 
     def delete_node_by_mem_cube_id(
-            self,
-            mem_cube_id: str | None = None,
-            delete_record_id: str | None = None,
-            hard_delete: bool = False,
+        self,
+        mem_cube_id: str | None = None,
+        delete_record_id: str | None = None,
+        hard_delete: bool = False,
     ) -> int:
         logger.info(
             f"delete_node_by_mem_cube_id mem_cube_id:{mem_cube_id}, "
             f"delete_record_id:{delete_record_id}, hard_delete:{hard_delete}"
         )
@@ -1266,11 +1260,10 @@ def delete_node_by_mem_cube_id(
         raise
 
     def recover_memory_by_mem_cube_id(
-            self,
-            mem_cube_id: str | None = None,
-            delete_record_id: str | None = None,
+        self,
+        mem_cube_id: str | None = None,
+        delete_record_id: str | None = None,
     ) -> int:
-
         logger.info(
             f"recover_memory_by_mem_cube_id mem_cube_id:{mem_cube_id},delete_record_id:{delete_record_id}"
         )

From 5aee334650951e33ba88c437614d04685318f525 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com>
Date: Fri, 6 Feb 2026 00:03:05 +0800
Subject: [PATCH 4/4] fix: format

---
 src/memos/graph_dbs/neo4j.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py
index ef9673c2b..054c7a050 100644
--- a/src/memos/graph_dbs/neo4j.py
+++ b/src/memos/graph_dbs/neo4j.py
@@ -2082,4 +2082,4 @@ def recover_memory_by_mem_cube_id(
         logger.error(
             f"[recover_memory_by_mem_cube_id] Failed to recover nodes: {e}", exc_info=True
         )
-        raise
\ No newline at end of file
+        raise
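
For reviewers, a minimal usage sketch of the soft-delete / recover / hard-delete flow added in this series. How the graph DB instance is constructed is an assumption (any implementation in this package exposing these methods behaves the same way), and the CUBE_ID / RECORD_ID values are hypothetical; the method names, parameters, and return values match the patches above.

    # Hypothetical setup: `db` is a graph DB instance from this package
    # (e.g. the neo4j_community implementation); construction is not shown here.
    CUBE_ID = "cube-42"           # maps to the `user_name` property on Memory nodes
    RECORD_ID = "del-2026-02-05"  # tags this delete operation so it can be undone

    # Soft delete: marks active nodes (no prior delete_time / delete_record_id)
    # as status="deleted" and stamps delete_record_id / delete_time on them.
    updated = db.delete_node_by_mem_cube_id(
        mem_cube_id=CUBE_ID, delete_record_id=RECORD_ID, hard_delete=False
    )

    # Recover: restores only the nodes tagged with this record back to
    # status="activated" and clears both delete markers.
    recovered = db.recover_memory_by_mem_cube_id(
        mem_cube_id=CUBE_ID, delete_record_id=RECORD_ID
    )

    # Hard delete: permanently removes nodes previously soft-deleted under this
    # record (DETACH DELETE in Neo4j; the community variant also purges their
    # vectors from the external vector DB). All three calls return a node count,
    # and return 0 with a warning if either argument is missing.
    removed = db.delete_node_by_mem_cube_id(
        mem_cube_id=CUBE_ID, delete_record_id=RECORD_ID, hard_delete=True
    )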