From 0f59cd8e04c7af541b3204ceb465a10bdbabc291 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Thu, 25 Dec 2025 17:12:17 +0800 Subject: [PATCH] feat: add batch delete --- src/memos/graph_dbs/polardb.py | 219 +++++++++++++++++++++------------ 1 file changed, 137 insertions(+), 82 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 1d19dc98d..b29dd26ce 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -4869,6 +4869,7 @@ def delete_node_by_prams( memory_ids: list[str] | None = None, file_ids: list[str] | None = None, filter: dict | None = None, + batch_size: int = 100, ) -> int: """ Delete nodes by memory_ids, file_ids, or filter. @@ -4898,31 +4899,6 @@ def delete_node_by_prams( f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype" ) - # Build WHERE conditions separately for memory_ids and file_ids - where_conditions = [] - - # Handle memory_ids: query properties.id - if memory_ids and len(memory_ids) > 0: - memory_id_conditions = [] - for node_id in memory_ids: - memory_id_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" - ) - if memory_id_conditions: - where_conditions.append(f"({' OR '.join(memory_id_conditions)})") - - # Check if any file_id is in the file_ids array field (OR relationship) - if file_ids and len(file_ids) > 0: - file_id_conditions = [] - for file_id in file_ids: - # Format: agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '"file_ids"'::agtype]), '"file_id"'::agtype) - file_id_conditions.append( - f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)" - ) - if file_id_conditions: - # Use OR to match any file_id in the array - where_conditions.append(f"({' OR '.join(file_id_conditions)})") - # Query nodes by filter if provided filter_ids = set() if filter: @@ -4943,77 +4919,156 @@ def delete_node_by_prams( "[delete_node_by_prams] Filter parsed to None, skipping filter query" ) - # If filter returned IDs, add condition for them + # Combine all IDs that need to be deleted + all_memory_ids = set() + if memory_ids: + all_memory_ids.update(memory_ids) if filter_ids: - filter_id_conditions = [] - for node_id in filter_ids: - filter_id_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" - ) - if filter_id_conditions: - where_conditions.append(f"({' OR '.join(filter_id_conditions)})") + all_memory_ids.update(filter_ids) - # If no conditions (except user_name), return 0 - if not where_conditions: + # If no conditions to delete, return 0 + if not all_memory_ids and not file_ids: logger.warning( "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)" ) return 0 - # Build WHERE clause - # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match) - data_conditions = " OR ".join([f"({cond})" for cond in where_conditions]) + conn = None + total_deleted_count = 0 + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Process memory_ids and filter_ids in batches + if all_memory_ids: + memory_ids_list = list(all_memory_ids) + total_batches = (len(memory_ids_list) + batch_size - 1) // batch_size + logger.info( + f"[delete_node_by_prams] memoryids Processing {len(memory_ids_list)} memory_ids in {total_batches} batches (batch_size={batch_size})" + ) - # Build final WHERE clause - # If user_name_conditions exist, combine with data_conditions using AND - # Otherwise, use only data_conditions - if user_name_conditions: - user_name_where = " OR ".join(user_name_conditions) - where_clause = f"({user_name_where}) AND ({data_conditions})" - else: - where_clause = f"({data_conditions})" + for batch_idx in range(total_batches): + batch_start = batch_idx * batch_size + batch_end = min(batch_start + batch_size, len(memory_ids_list)) + batch_ids = memory_ids_list[batch_start:batch_end] - # Use SQL DELETE query for better performance - # First count matching nodes to get accurate count - count_query = f""" - SELECT COUNT(*) - FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - logger.info(f"[delete_node_by_prams] count_query: {count_query}") + # Build conditions for this batch + batch_conditions = [] + for node_id in batch_ids: + batch_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" + ) + batch_where = f"({' OR '.join(batch_conditions)})" - # Then delete nodes - delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ + # Add user_name filter if provided + if user_name_conditions: + user_name_where = " OR ".join(user_name_conditions) + where_clause = f"({user_name_where}) AND ({batch_where})" + else: + where_clause = batch_where - logger.info( - f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" - ) - logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") + # Count before deletion + count_query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + logger.info( + f"[delete_node_by_prams] memoryids batch {batch_idx + 1}/{total_batches}: count_query: {count_query}" + ) - conn = None - deleted_count = 0 - try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Count nodes before deletion - cursor.execute(count_query) - count_result = cursor.fetchone() - expected_count = count_result[0] if count_result else 0 + cursor.execute(count_query) + count_result = cursor.fetchone() + expected_count = count_result[0] if count_result else 0 - logger.info( - f"[delete_node_by_prams] Found {expected_count} nodes matching the criteria" - ) + if expected_count == 0: + logger.info( + f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: No nodes found, skipping" + ) + continue + + # Delete batch + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + logger.info( + f"[delete_node_by_prams] memoryids batch {batch_idx + 1}/{total_batches}: delete_query: {delete_query}" + ) + + logger.info( + f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: Executing delete query for {len(batch_ids)} nodes" + ) + cursor.execute(delete_query) + batch_deleted = cursor.rowcount + total_deleted_count += batch_deleted + + logger.info( + f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: Deleted {batch_deleted} nodes (batch size: {len(batch_ids)})" + ) + + # Process file_ids in batches + if file_ids: + total_file_batches = (len(file_ids) + batch_size - 1) // batch_size + logger.info( + f"[delete_node_by_prams] Processing {len(file_ids)} file_ids in {total_file_batches} batches (batch_size={batch_size})" + ) + + for batch_idx in range(total_file_batches): + batch_start = batch_idx * batch_size + batch_end = min(batch_start + batch_size, len(file_ids)) + batch_file_ids = file_ids[batch_start:batch_end] + + # Build conditions for this batch + batch_conditions = [] + for file_id in batch_file_ids: + batch_conditions.append( + f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)" + ) + batch_where = f"({' OR '.join(batch_conditions)})" + + # Add user_name filter if provided + if user_name_conditions: + user_name_where = " OR ".join(user_name_conditions) + where_clause = f"({user_name_where}) AND ({batch_where})" + else: + where_clause = batch_where + + # Count before deletion + count_query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + + logger.info( + f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: count_query: {count_query}" + ) + cursor.execute(count_query) + count_result = cursor.fetchone() + expected_count = count_result[0] if count_result else 0 + + if expected_count == 0: + logger.info( + f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: No nodes found, skipping" + ) + continue + + # Delete batch + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + cursor.execute(delete_query) + batch_deleted = cursor.rowcount + total_deleted_count += batch_deleted + + logger.info( + f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: delete_query: {delete_query}" + ) - # Delete nodes - cursor.execute(delete_query) - # Use rowcount to get actual deleted count - deleted_count = cursor.rowcount elapsed_time = time.time() - batch_start_time logger.info( - f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, deleted {deleted_count} nodes" + f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes" ) except Exception as e: logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True) @@ -5021,8 +5076,8 @@ def delete_node_by_prams( finally: self._return_connection(conn) - logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes") - return deleted_count + logger.info(f"[delete_node_by_prams] Successfully deleted {total_deleted_count} nodes") + return total_deleted_count @timed def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]: