diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 8f1955d16..054c7a050 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -1953,33 +1953,15 @@ def exist_user_name(self, user_name: str) -> dict[str, bool]: def delete_node_by_mem_cube_id( self, - mem_cube_id: dict | None = None, - delete_record_id: dict | None = None, + mem_cube_id: str | None = None, + delete_record_id: str | None = None, hard_delete: bool = False, ) -> int: - """ - Delete nodes by mem_cube_id (user_name) and delete_record_id. - - Args: - mem_cube_id: The mem_cube_id which corresponds to user_name in the table. - Can be dict or str. If dict, will extract the value. - delete_record_id: The delete_record_id to match. - Can be dict or str. If dict, will extract the value. - hard_delete: If True, performs hard delete (directly deletes records). - If False, performs soft delete (updates status to 'deleted' and sets delete_record_id and delete_time). - - Returns: - int: Number of nodes deleted or updated. 
- """ - # Handle dict type parameters (extract value if dict) - if isinstance(mem_cube_id, dict): - # Try to get a value from dict, use first value if multiple - mem_cube_id = next(iter(mem_cube_id.values())) if mem_cube_id else None - - if isinstance(delete_record_id, dict): - delete_record_id = next(iter(delete_record_id.values())) if delete_record_id else None + logger.info( + f"delete_node_by_mem_cube_id mem_cube_id:{mem_cube_id}, " + f"delete_record_id:{delete_record_id}, hard_delete:{hard_delete}" + ) - # Validate required parameters if not mem_cube_id: logger.warning("[delete_node_by_mem_cube_id] mem_cube_id is required but not provided") return 0 @@ -1990,19 +1972,9 @@ def delete_node_by_mem_cube_id( ) return 0 - # Convert to string if needed - mem_cube_id = str(mem_cube_id) if mem_cube_id else None - delete_record_id = str(delete_record_id) if delete_record_id else None - - logger.info( - f"[delete_node_by_mem_cube_id] mem_cube_id={mem_cube_id}, " - f"delete_record_id={delete_record_id}, hard_delete={hard_delete}" - ) - try: with self.driver.session(database=self.db_name) as session: if hard_delete: - # Hard delete: WHERE user_name = mem_cube_id AND delete_record_id = $delete_record_id query = """ MATCH (n:Memory) WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id @@ -2019,12 +1991,13 @@ def delete_node_by_mem_cube_id( logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes") return deleted_count else: - # Soft delete: WHERE user_name = mem_cube_id (only user_name condition) current_time = datetime.utcnow().isoformat() query = """ MATCH (n:Memory) WHERE n.user_name = $mem_cube_id + AND (n.delete_time IS NULL OR n.delete_time = "") + AND (n.delete_record_id IS NULL OR n.delete_record_id = "") SET n.status = $status, n.delete_record_id = $delete_record_id, n.delete_time = $delete_time @@ -2043,7 +2016,7 @@ def delete_node_by_mem_cube_id( updated_count = record["updated_count"] if record else 0 logger.info( - 
f"[delete_node_by_mem_cube_id] Soft deleted (updated) {updated_count} nodes" + f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes" ) return updated_count @@ -2058,33 +2031,22 @@ def recover_memory_by_mem_cube_id( mem_cube_id: str | None = None, delete_record_id: str | None = None, ) -> int: - """ - Recover memory nodes by mem_cube_id (user_name) and delete_record_id. - - This function updates the status to 'activated', and clears delete_record_id and delete_time. - - Args: - mem_cube_id: The mem_cube_id which corresponds to user_name in the table. - delete_record_id: The delete_record_id to match. - - Returns: - int: Number of nodes recovered (updated). - """ + logger.info( + f"recover_memory_by_mem_cube_id mem_cube_id:{mem_cube_id},delete_record_id:{delete_record_id}" + ) # Validate required parameters if not mem_cube_id: - logger.warning( - "[recover_memory_by_mem_cube_id] mem_cube_id is required but not provided" - ) + logger.warning("recover_memory_by_mem_cube_id mem_cube_id is required but not provided") return 0 if not delete_record_id: logger.warning( - "[recover_memory_by_mem_cube_id] delete_record_id is required but not provided" + "recover_memory_by_mem_cube_id delete_record_id is required but not provided" ) return 0 logger.info( - f"[recover_memory_by_mem_cube_id] mem_cube_id={mem_cube_id}, " + f"recover_memory_by_mem_cube_id mem_cube_id={mem_cube_id}, " f"delete_record_id={delete_record_id}" ) diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index 411dbffe5..e34313fa2 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -620,7 +620,9 @@ def get_all_memory_items( with self.driver.session(database=self.db_name) as session: results = session.run(query, params) - return [self._parse_node(dict(record["n"])) for record in results] + nodes_data = [dict(record["n"]) for record in results] + # Use batch parsing to fetch all embeddings at once + 
return self._parse_nodes(nodes_data) def get_by_metadata( self, @@ -1057,6 +1059,53 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: new_node["metadata"]["embedding"] = None return new_node + def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Parse multiple Neo4j nodes and batch fetch embeddings from vector DB.""" + if not nodes_data: + return [] + + # First, parse all nodes without embeddings + parsed_nodes = [] + node_ids = [] + for node_data in nodes_data: + node = node_data.copy() + + # Convert Neo4j datetime to string + for time_field in ("created_at", "updated_at"): + if time_field in node and hasattr(node[time_field], "isoformat"): + node[time_field] = node[time_field].isoformat() + node.pop("user_name", None) + # serialization + if node.get("sources"): + for idx in range(len(node["sources"])): + if not ( + isinstance(node["sources"][idx], str) + and node["sources"][idx][0] == "{" + and node["sources"][idx][-1] == "}" + ): + break + node["sources"][idx] = json.loads(node["sources"][idx]) + + node_id = node.pop("id") + node_ids.append(node_id) + parsed_nodes.append({"id": node_id, "memory": node.pop("memory", ""), "metadata": node}) + + # Batch fetch all embeddings at once + vec_items_map = {} + if node_ids: + try: + vec_items = self.vec_db.get_by_ids(node_ids) + vec_items_map = {v.id: v.vector for v in vec_items if v and v.vector} + except Exception as e: + logger.warning(f"Failed to batch fetch vectors for {len(node_ids)} nodes: {e}") + + # Merge embeddings into parsed nodes + for parsed_node in parsed_nodes: + node_id = parsed_node["id"] + parsed_node["metadata"]["embedding"] = vec_items_map.get(node_id, None) + + return parsed_nodes + + def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]: + """Get user names by memory ids. 
@@ -1111,3 +1160,159 @@ def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True ) raise + + def delete_node_by_mem_cube_id( + self, + mem_cube_id: str | None = None, + delete_record_id: str | None = None, + hard_delete: bool = False, + ) -> int: + logger.info( + f"delete_node_by_mem_cube_id mem_cube_id:{mem_cube_id}, " + f"delete_record_id:{delete_record_id}, hard_delete:{hard_delete}" + ) + + if not mem_cube_id: + logger.warning("[delete_node_by_mem_cube_id] mem_cube_id is required but not provided") + return 0 + + if not delete_record_id: + logger.warning( + "[delete_node_by_mem_cube_id] delete_record_id is required but not provided" + ) + return 0 + + try: + with self.driver.session(database=self.db_name) as session: + if hard_delete: + query_get_ids = """ + MATCH (n:Memory) + WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id + RETURN n.id AS id + """ + result = session.run( + query_get_ids, mem_cube_id=mem_cube_id, delete_record_id=delete_record_id + ) + node_ids = [record["id"] for record in result] + + # Delete from Neo4j + query = """ + MATCH (n:Memory) + WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id + DETACH DELETE n + """ + logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {query}") + + result = session.run( + query, mem_cube_id=mem_cube_id, delete_record_id=delete_record_id + ) + summary = result.consume() + deleted_count = summary.counters.nodes_deleted if summary.counters else 0 + + # Delete from vector DB + if node_ids and self.vec_db: + try: + self.vec_db.delete(node_ids) + logger.info( + f"[delete_node_by_mem_cube_id] Deleted {len(node_ids)} vectors from VecDB" + ) + except Exception as e: + logger.warning( + f"[delete_node_by_mem_cube_id] Failed to delete vectors from VecDB: {e}" + ) + + logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes") + return 
deleted_count + else: + current_time = datetime.utcnow().isoformat() + + query = """ + MATCH (n:Memory) + WHERE n.user_name = $mem_cube_id + AND (n.delete_time IS NULL OR n.delete_time = "") + AND (n.delete_record_id IS NULL OR n.delete_record_id = "") + SET n.status = $status, + n.delete_record_id = $delete_record_id, + n.delete_time = $delete_time + RETURN count(n) AS updated_count + """ + logger.info(f"[delete_node_by_mem_cube_id] Soft delete query: {query}") + + result = session.run( + query, + mem_cube_id=mem_cube_id, + status="deleted", + delete_record_id=delete_record_id, + delete_time=current_time, + ) + record = result.single() + updated_count = record["updated_count"] if record else 0 + + logger.info( + f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes" + ) + return updated_count + + except Exception as e: + logger.error( + f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True + ) + raise + + def recover_memory_by_mem_cube_id( + self, + mem_cube_id: str | None = None, + delete_record_id: str | None = None, + ) -> int: + logger.info( + f"recover_memory_by_mem_cube_id mem_cube_id:{mem_cube_id},delete_record_id:{delete_record_id}" + ) + # Validate required parameters + if not mem_cube_id: + logger.warning("recover_memory_by_mem_cube_id mem_cube_id is required but not provided") + return 0 + + if not delete_record_id: + logger.warning( + "recover_memory_by_mem_cube_id delete_record_id is required but not provided" + ) + return 0 + + logger.info( + f"recover_memory_by_mem_cube_id mem_cube_id={mem_cube_id}, " + f"delete_record_id={delete_record_id}" + ) + + try: + with self.driver.session(database=self.db_name) as session: + query = """ + MATCH (n:Memory) + WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id + SET n.status = $status, + n.delete_record_id = $delete_record_id_empty, + n.delete_time = $delete_time_empty + RETURN count(n) AS updated_count + """ + 
logger.info(f"[recover_memory_by_mem_cube_id] Update query: {query}") + + result = session.run( + query, + mem_cube_id=mem_cube_id, + delete_record_id=delete_record_id, + status="activated", + delete_record_id_empty="", + delete_time_empty="", + ) + record = result.single() + updated_count = record["updated_count"] if record else 0 + + logger.info( + f"[recover_memory_by_mem_cube_id] Recovered (updated) {updated_count} nodes" + ) + return updated_count + + except Exception as e: + logger.error( + f"[recover_memory_by_mem_cube_id] Failed to recover nodes: {e}", exc_info=True + ) + raise diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 18778532f..409b3a967 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -5149,6 +5149,9 @@ def parse_filter( "info", "source", "file_ids", + "project_id", + "manager_user_id", + "delete_time", } def process_condition(condition): @@ -5470,27 +5473,15 @@ def escape_user_name(un: str) -> str: @timed def delete_node_by_mem_cube_id( self, - mem_cube_id: dict | None = None, - delete_record_id: dict | None = None, + mem_cube_id: str | None = None, + delete_record_id: str | None = None, hard_delete: bool = False, ) -> int: - """ - (inner) Delete memory nodes by mem_cube_id (user_name) and delete_record_id. Record id is inner field, just for delete and recover memory, not for user to set. - - Args: - mem_cube_id: The mem_cube_id which corresponds to user_name in the table. - delete_record_id: The delete_record_id to match. - hard_delete: Whether to hard delete the nodes. 
- """ - # Handle dict type parameters (extract value if dict) - if isinstance(mem_cube_id, dict): - # Try to get a value from dict, use first value if multiple - mem_cube_id = next(iter(mem_cube_id.values())) if mem_cube_id else None - - if isinstance(delete_record_id, dict): - delete_record_id = next(iter(delete_record_id.values())) if delete_record_id else None + logger.info( + f"delete_node_by_mem_cube_id mem_cube_id:{mem_cube_id}, " + f"delete_record_id:{delete_record_id}, hard_delete:{hard_delete}" + ) - # Validate required parameters if not mem_cube_id: logger.warning("[delete_node_by_mem_cube_id] mem_cube_id is required but not provided") return 0 @@ -5501,32 +5492,18 @@ def delete_node_by_mem_cube_id( ) return 0 - # Convert to string if needed - mem_cube_id = str(mem_cube_id) if mem_cube_id else None - delete_record_id = str(delete_record_id) if delete_record_id else None - - logger.info( - f"[delete_node_by_mem_cube_id] mem_cube_id={mem_cube_id}, " - f"delete_record_id={delete_record_id}, hard_delete={hard_delete}" - ) - conn = None try: conn = self._get_connection() with conn.cursor() as cursor: - # Build WHERE clause for user_name using parameter binding - # user_name must match mem_cube_id user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - # Prepare parameter for user_name user_name_param = self.format_param_value(mem_cube_id) if hard_delete: - # Hard delete: WHERE user_name = mem_cube_id AND delete_record_id = $delete_record_id delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype" where_clause = f"{user_name_condition} AND {delete_record_id_condition}" - # Prepare parameters for WHERE clause (user_name and delete_record_id) where_params = [user_name_param, self.format_param_value(delete_record_id)] delete_query = f""" @@ -5541,40 +5518,39 @@ def delete_node_by_mem_cube_id( logger.info(f"[delete_node_by_mem_cube_id] Hard deleted 
{deleted_count} nodes") return deleted_count else: - # Soft delete: WHERE user_name = mem_cube_id (only user_name condition) - where_clause = user_name_condition + delete_time_empty_condition = ( + "(ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) IS NULL " + "OR ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) = '\"\"'::agtype)" + ) + delete_record_id_empty_condition = ( + "(ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) IS NULL " + "OR ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = '\"\"'::agtype)" + ) + where_clause = f"{user_name_condition} AND {delete_time_empty_condition} AND {delete_record_id_empty_condition}" current_time = datetime.utcnow().isoformat() - # Build update properties JSON with status, delete_time, and delete_record_id - # Use PostgreSQL JSONB merge operator (||) to update properties - # Convert agtype to jsonb, merge with new values, then convert back to agtype update_query = f""" UPDATE "{self.db_name}_graph"."Memory" SET properties = ( properties::jsonb || %s::jsonb - )::text::agtype + )::text::agtype, + deletetime = %s WHERE {where_clause} """ - # Create update JSON with the three fields to update update_properties = { "status": "deleted", "delete_time": current_time, "delete_record_id": delete_record_id, } logger.info( - f"[delete_node_by_mem_cube_id] Soft delete update_query: {update_query}" - ) - logger.info( - f"[delete_node_by_mem_cube_id] update_properties: {update_properties}" + f"delete_node_by_mem_cube_id Soft delete update_query:{update_query},update_properties:{update_properties},deletetime:{current_time}" ) - - # Combine update_properties JSON with user_name parameter (only user_name, no delete_record_id) - update_params = [json.dumps(update_properties), user_name_param] + update_params = [json.dumps(update_properties), current_time, user_name_param] cursor.execute(update_query, update_params) updated_count = 
cursor.rowcount logger.info( - f"[delete_node_by_mem_cube_id] Soft deleted (updated) {updated_count} nodes" + f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes" ) return updated_count @@ -5592,36 +5568,22 @@ def recover_memory_by_mem_cube_id( mem_cube_id: str | None = None, delete_record_id: str | None = None, ) -> int: - """ - (inner) Recover memory nodes by mem_cube_id (user_name) and delete_record_id. Record id is inner field, just for delete and recover memory, not for user to set. - - This function updates the status to 'activated', and clears delete_record_id and delete_time. - - Args: - mem_cube_id: The mem_cube_id which corresponds to user_name in the table. - delete_record_id: The delete_record_id to match. - - Returns: - int: Number of nodes recovered (updated). - """ logger.info( f"recover_memory_by_mem_cube_id mem_cube_id:{mem_cube_id},delete_record_id:{delete_record_id}" ) # Validate required parameters if not mem_cube_id: - logger.warning( - "[recover_memory_by_mem_cube_id] mem_cube_id is required but not provided" - ) + logger.warning("recover_memory_by_mem_cube_id mem_cube_id is required but not provided") return 0 if not delete_record_id: logger.warning( - "[recover_memory_by_mem_cube_id] delete_record_id is required but not provided" + "recover_memory_by_mem_cube_id delete_record_id is required but not provided" ) return 0 logger.info( - f"[recover_memory_by_mem_cube_id] mem_cube_id={mem_cube_id}, " + f"recover_memory_by_mem_cube_id mem_cube_id={mem_cube_id}, " f"delete_record_id={delete_record_id}" ) @@ -5629,19 +5591,15 @@ def recover_memory_by_mem_cube_id( try: conn = self._get_connection() with conn.cursor() as cursor: - # Build WHERE clause for user_name and delete_record_id using parameter binding user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = 
%s::agtype" where_clause = f"{user_name_condition} AND {delete_record_id_condition}" - # Prepare parameters for WHERE clause where_params = [ self.format_param_value(mem_cube_id), self.format_param_value(delete_record_id), ] - # Build update properties: status='activated', delete_record_id='', delete_time='' - # Use PostgreSQL JSONB merge operator (||) to update properties update_properties = { "status": "activated", "delete_record_id": "", @@ -5652,7 +5610,8 @@ def recover_memory_by_mem_cube_id( UPDATE "{self.db_name}_graph"."Memory" SET properties = ( properties::jsonb || %s::jsonb - )::text::agtype + )::text::agtype, + deletetime = NULL WHERE {where_clause} """ @@ -5661,7 +5620,6 @@ def recover_memory_by_mem_cube_id( f"[recover_memory_by_mem_cube_id] update_properties: {update_properties}" ) - # Combine update_properties JSON with where_params update_params = [json.dumps(update_properties), *where_params] cursor.execute(update_query, update_params) updated_count = cursor.rowcount