diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index 4799542bf..f88824493 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -2505,6 +2505,7 @@ def export_graph(
         self,
         include_embedding: bool = False,
         user_name: str | None = None,
+        user_id: str | None = None,
         page: int | None = None,
         page_size: int | None = None,
         **kwargs,
@@ -2514,6 +2515,7 @@ def export_graph(
         Args:
             include_embedding (bool): Whether to include the large embedding field.
             user_name (str, optional): User name for filtering in non-multi-db mode
+            user_id (str, optional): User ID for filtering
             page (int, optional): Page number (starts from 1). If None, exports all data without pagination.
             page_size (int, optional): Number of items per page. If None, exports all data without pagination.

         Returns:
             {
             }
         """
         logger.info(
-            f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, page: {page}, page_size: {page_size}"
+            f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}"
         )
-        user_name = user_name if user_name else self._get_config_value("user_name")
+        user_id = user_id if user_id else self._get_config_value("user_id")

         # Determine if pagination is needed
         use_pagination = page is not None and page_size is not None
@@ -2550,11 +2552,26 @@ def export_graph(
         if use_pagination:
             pagination_clause = f"LIMIT {page_size} OFFSET {offset}"

+        # Build WHERE conditions
+        where_conditions = []
+        if user_name:
+            where_conditions.append(
+                f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype"
+            )
+        if user_id:
+            where_conditions.append(
+                f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype"
+            )
+
+        where_clause = ""
+        if where_conditions:
+            where_clause = f"WHERE {' AND '.join(where_conditions)}"
+
         if include_embedding:
             node_query = f"""
                 SELECT id, properties, embedding
                 FROM "{self.db_name}_graph"."Memory"
-                WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype
+                {where_clause}
                 ORDER BY id
                 {pagination_clause}
             """
@@ -2562,7 +2579,7 @@ def export_graph(
             node_query = f"""
                 SELECT id, properties
                 FROM "{self.db_name}_graph"."Memory"
-                WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype
+                {where_clause}
                 ORDER BY id
                 {pagination_clause}
             """
@@ -2619,11 +2636,24 @@ def export_graph(
         if use_pagination:
             edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}"

+        # Build Cypher WHERE conditions for edges
+        cypher_where_conditions = []
+        if user_name:
+            cypher_where_conditions.append(f"a.user_name = '{user_name}'")
+            cypher_where_conditions.append(f"b.user_name = '{user_name}'")
+        if user_id:
+            cypher_where_conditions.append(f"a.user_id = '{user_id}'")
+            cypher_where_conditions.append(f"b.user_id = '{user_id}'")
+
+        cypher_where_clause = ""
+        if cypher_where_conditions:
+            cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}"
+
         edge_query = f"""
             SELECT source, target, edge FROM (
                 SELECT * FROM cypher('{self.db_name}_graph', $$
                     MATCH (a:Memory)-[r]->(b:Memory)
-                    WHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'
+                    {cypher_where_clause}
                     RETURN a.id AS source, b.id AS target, type(r) as edge
                     ORDER BY a.id, b.id
                 $$) AS (source agtype, target agtype, edge agtype)
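Review note: the new `where_conditions` / `cypher_where_conditions` blocks splice `user_name` and `user_id` directly into the SQL and Cypher text, so a value containing a quote would break the statement (or inject into it). Below is a minimal sketch of the same composition using psycopg2 bound parameters instead; `build_user_filter` is a hypothetical helper, and it assumes the server accepts a text-to-agtype cast on a bound parameter:

```python
import json


def build_user_filter(user_name: str | None = None, user_id: str | None = None):
    """Compose the node-export WHERE clause with %s placeholders (sketch)."""
    conditions, params = [], []
    for key, value in (("user_name", user_name), ("user_id", user_id)):
        if value:
            conditions.append(
                f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = %s::agtype"
            )
            # agtype string literals are JSON-encoded, e.g. '"alice"'
            params.append(json.dumps(value))
    where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
    return where_clause, params


# Both filters present -> "WHERE ... = %s::agtype AND ... = %s::agtype"
clause, params = build_user_filter(user_name="alice", user_id="u-42")
```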
@@ -3399,7 +3429,7 @@ def add_nodes_batch(
             logger.warning("[add_nodes_batch] Empty nodes list, skipping")
             return

-        logger.info(f"[add_nodes_batch] Adding {len(nodes)} nodes")
+        logger.info(f"[add_nodes_batch] Processing {len(nodes)} nodes")

         # user_name comes from parameter; fallback to config if missing
         effective_user_name = user_name if user_name else self.config.user_name
@@ -3528,92 +3558,89 @@ def add_nodes_batch(
                     if graph_id:
                         node["properties"]["graph_id"] = str(graph_id)

-                # Batch insert using VALUES with multiple rows
-                # Use psycopg2.extras.execute_values for efficient batch insert
-                from psycopg2.extras import execute_values
-
-                if embedding_column and any(node["embedding_vector"] for node in nodes_group):
-                    # Prepare data tuples for batch insert with embedding
-                    data_tuples = []
-                    for node in nodes_group:
-                        # Each tuple: (id, properties_json, embedding_json)
-                        data_tuples.append(
-                            (
-                                node["id"],
-                                json.dumps(node["properties"]),
-                                json.dumps(node["embedding_vector"])
-                                if node["embedding_vector"]
-                                else None,
+                # Use PREPARE/EXECUTE for efficient batch insert
+                # Generate unique prepare statement name to avoid conflicts
+                prepare_name = f"insert_mem_{embedding_column or 'no_embedding'}_{int(time.time() * 1000000)}"
+
+                try:
+                    if embedding_column and any(
+                        node["embedding_vector"] for node in nodes_group
+                    ):
+                        # PREPARE statement for insert with embedding
+                        prepare_query = f"""
+                            PREPARE {prepare_name} AS
+                            INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column})
+                            VALUES (
+                                ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring),
+                                $2::text::agtype,
+                                $3::vector
                             )
+                        """
+                        logger.info(
+                            f"[add_nodes_batch] Preparing statement (with embedding): {prepare_name}"
                        )
+                        logger.info(
+                            f"[add_nodes_batch] Prepared query (with embedding): {prepare_query}"
                        )

-                    # Build the INSERT query template
-                    insert_query = f"""
-                        INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column})
-                        VALUES %s
-                    """
+                        cursor.execute(prepare_query)

-                    # Build the VALUES template for execute_values
-                    # Each row: (graph_id_function, agtype, vector)
-                    # Note: properties column is agtype, not jsonb
-                    template = f"""
-                        (
-                            ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring),
-                            %s::text::agtype,
-                            %s::vector
-                        )
-                    """
-                    # Execute batch insert
-                    execute_values(
-                        cursor,
-                        insert_query,
-                        data_tuples,
-                        template=template,
-                        page_size=100,  # Insert in batches of 100
-                    )
-                else:
-                    # Prepare data tuples for batch insert without embedding
-                    data_tuples = []
-                    for node in nodes_group:
-                        # Each tuple: (id, properties_json)
-                        data_tuples.append(
-                            (
-                                node["id"],
-                                json.dumps(node["properties"]),
+                        # Execute prepared statement for each node
+                        for node in nodes_group:
+                            properties_json = json.dumps(node["properties"])
+                            embedding_json = (
+                                json.dumps(node["embedding_vector"])
+                                if node["embedding_vector"]
+                                else None
                            )
+
+                            cursor.execute(
+                                f"EXECUTE {prepare_name}(%s, %s, %s)",
+                                (node["id"], properties_json, embedding_json),
+                            )
+                    else:
+                        # PREPARE statement for insert without embedding
+                        prepare_query = f"""
+                            PREPARE {prepare_name} AS
+                            INSERT INTO {self.db_name}_graph."Memory"(id, properties)
+                            VALUES (
+                                ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring),
+                                $2::text::agtype
+                            )
+                        """
+                        logger.info(
+                            f"[add_nodes_batch] Preparing statement (without embedding): {prepare_name}"
                        )
+                        logger.info(
+                            f"[add_nodes_batch] Prepared query (without embedding): {prepare_query}"
                        )
+                        cursor.execute(prepare_query)

-                    # Build the INSERT query template
-                    insert_query = f"""
-                        INSERT INTO {self.db_name}_graph."Memory"(id, properties)
-                        VALUES %s
-                    """
+                        # Execute prepared statement for each node
+                        for node in nodes_group:
+                            properties_json = json.dumps(node["properties"])

-                    # Build the VALUES template for execute_values
-                    # Note: properties column is agtype, not jsonb
-                    template = f"""
-                        (
-                            ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring),
-                            %s::text::agtype
+                            cursor.execute(
+                                f"EXECUTE {prepare_name}(%s, %s)", (node["id"], properties_json)
+                            )
+                finally:
+                    # DEALLOCATE prepared statement (always execute, even on error)
+                    try:
+                        cursor.execute(f"DEALLOCATE {prepare_name}")
+                        logger.info(
+                            f"[add_nodes_batch] Deallocated prepared statement: {prepare_name}"
+                        )
+                    except Exception as dealloc_error:
+                        logger.warning(
+                            f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}"
                        )
-                    """
-                    logger.info(f"[add_nodes_batch] Inserting insert_query:{insert_query}")
-                    logger.info(f"[add_nodes_batch] Inserting data_tuples:{data_tuples}")
-                    # Execute batch insert
-                    execute_values(
-                        cursor,
-                        insert_query,
-                        data_tuples,
-                        template=template,
-                        page_size=100,  # Insert in batches of 100
-                    )

                 logger.info(
                     f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}"
                 )

                 elapsed_time = time.time() - batch_start_time
                 logger.info(
-                    f"[add_nodes_batch] execute_values completed successfully in {elapsed_time:.2f}s"
+                    f"[add_nodes_batch] PREPARE/EXECUTE batch insert completed successfully in {elapsed_time:.2f}s"
                 )

             except Exception as e:
@@ -4887,7 +4914,6 @@ def delete_node_by_prams(
         memory_ids: list[str] | None = None,
         file_ids: list[str] | None = None,
         filter: dict | None = None,
-        batch_size: int = 100,
     ) -> int:
         """
         Delete nodes by memory_ids, file_ids, or filter.
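For reference, the PREPARE/EXECUTE/DEALLOCATE cycle that replaces `execute_values` above reads more clearly in isolation. A condensed sketch of the no-embedding path (function and argument names here are illustrative, not the module's API):

```python
import json
import time


def insert_nodes_prepared(cursor, graph_schema: str, nodes: list[dict]) -> None:
    """Prepare one INSERT, execute it once per node, then always deallocate."""
    prepare_name = f"insert_mem_{int(time.time() * 1_000_000)}"  # unique per call
    cursor.execute(f"""
        PREPARE {prepare_name} AS
        INSERT INTO {graph_schema}."Memory"(id, properties)
        VALUES (
            ag_catalog._make_graph_id('{graph_schema}'::name, 'Memory'::name, $1::text::cstring),
            $2::text::agtype
        )
    """)
    try:
        for node in nodes:
            cursor.execute(
                f"EXECUTE {prepare_name}(%s, %s)",
                (node["id"], json.dumps(node["properties"])),
            )
    finally:
        # Prepared statements are session-scoped; deallocate so a pooled
        # connection does not accumulate them.
        cursor.execute(f"DEALLOCATE {prepare_name}")
```

The trade-off versus `execute_values`: the statement is parsed and planned once, but each node still costs one `EXECUTE` round trip, which is worth keeping in mind for large `nodes_group` batches.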
@@ -4956,133 +4982,74 @@ def delete_node_by_prams(
         try:
             conn = self._get_connection()
             with conn.cursor() as cursor:
-                # Process memory_ids and filter_ids in batches
+                # Process memory_ids and filter_ids (all at once, no batching)
                 if all_memory_ids:
                     memory_ids_list = list(all_memory_ids)
-                    total_batches = (len(memory_ids_list) + batch_size - 1) // batch_size
                     logger.info(
-                        f"[delete_node_by_prams] memoryids Processing {len(memory_ids_list)} memory_ids in {total_batches} batches (batch_size={batch_size})"
+                        f"[delete_node_by_prams] Processing {len(memory_ids_list)} memory_ids"
                     )

-                    for batch_idx in range(total_batches):
-                        batch_start = batch_idx * batch_size
-                        batch_end = min(batch_start + batch_size, len(memory_ids_list))
-                        batch_ids = memory_ids_list[batch_start:batch_end]
-
-                        # Build conditions for this batch
-                        batch_conditions = []
-                        for node_id in batch_ids:
-                            batch_conditions.append(
-                                f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
-                            )
-                        batch_where = f"({' OR '.join(batch_conditions)})"
-
-                        # Add user_name filter if provided
-                        if user_name_conditions:
-                            user_name_where = " OR ".join(user_name_conditions)
-                            where_clause = f"({user_name_where}) AND ({batch_where})"
-                        else:
-                            where_clause = batch_where
-
-                        # Count before deletion
-                        count_query = f"""
-                            SELECT COUNT(*)
-                            FROM "{self.db_name}_graph"."Memory"
-                            WHERE {where_clause}
-                        """
-                        logger.info(
-                            f"[delete_node_by_prams] memoryids batch {batch_idx + 1}/{total_batches}: count_query: {count_query}"
+                    # Build conditions for all memory_ids
+                    id_conditions = []
+                    for node_id in memory_ids_list:
+                        id_conditions.append(
+                            f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
                        )
+                    id_where = f"({' OR '.join(id_conditions)})"

-                        cursor.execute(count_query)
-                        count_result = cursor.fetchone()
-                        expected_count = count_result[0] if count_result else 0
-
-                        if expected_count == 0:
-                            logger.info(
-                                f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: No nodes found, skipping"
-                            )
-                            continue
+                    # Add user_name filter if provided
+                    if user_name_conditions:
+                        user_name_where = " OR ".join(user_name_conditions)
+                        where_clause = f"({user_name_where}) AND ({id_where})"
+                    else:
+                        where_clause = id_where

-                        # Delete batch
-                        delete_query = f"""
-                            DELETE FROM "{self.db_name}_graph"."Memory"
-                            WHERE {where_clause}
-                        """
-                        logger.info(
-                            f"[delete_node_by_prams] memoryids batch {batch_idx + 1}/{total_batches}: delete_query: {delete_query}"
-                        )
+                    # Delete directly without counting
+                    delete_query = f"""
+                        DELETE FROM "{self.db_name}_graph"."Memory"
+                        WHERE {where_clause}
+                    """
+                    logger.info(f"[delete_node_by_prams] memory_ids delete_query: {delete_query}")

-                        logger.info(
-                            f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: Executing delete query for {len(batch_ids)} nodes"
-                        )
-                        cursor.execute(delete_query)
-                        batch_deleted = cursor.rowcount
-                        total_deleted_count += batch_deleted
+                    cursor.execute(delete_query)
+                    deleted_count = cursor.rowcount
+                    total_deleted_count += deleted_count

-                        logger.info(
-                            f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: Deleted {batch_deleted} nodes (batch size: {len(batch_ids)})"
-                        )
-
-                # Process file_ids in batches
-                if file_ids:
-                    total_file_batches = (len(file_ids) + batch_size - 1) // batch_size
                     logger.info(
-                        f"[delete_node_by_prams] Processing {len(file_ids)} file_ids in {total_file_batches} batches (batch_size={batch_size})"
+                        f"[delete_node_by_prams] Deleted {deleted_count} nodes by memory_ids"
                    )

-                    for batch_idx in range(total_file_batches):
-                        batch_start = batch_idx * batch_size
-                        batch_end = min(batch_start + batch_size, len(file_ids))
-                        batch_file_ids = file_ids[batch_start:batch_end]
-
-                        # Build conditions for this batch
-                        batch_conditions = []
-                        for file_id in batch_file_ids:
-                            batch_conditions.append(
-                                f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)"
-                            )
-                        batch_where = f"({' OR '.join(batch_conditions)})"
-
-                        # Add user_name filter if provided
-                        if user_name_conditions:
-                            user_name_where = " OR ".join(user_name_conditions)
-                            where_clause = f"({user_name_where}) AND ({batch_where})"
-                        else:
-                            where_clause = batch_where
-
-                        # Count before deletion
-                        count_query = f"""
-                            SELECT COUNT(*)
-                            FROM "{self.db_name}_graph"."Memory"
-                            WHERE {where_clause}
-                        """
+                # Process file_ids (all at once, no batching)
+                if file_ids:
+                    logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids")

-                        logger.info(
-                            f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: count_query: {count_query}"
+                    # Build conditions for all file_ids
+                    file_id_conditions = []
+                    for file_id in file_ids:
+                        file_id_conditions.append(
+                            f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)"
                        )
-                        cursor.execute(count_query)
-                        count_result = cursor.fetchone()
-                        expected_count = count_result[0] if count_result else 0
+                    file_id_where = f"({' OR '.join(file_id_conditions)})"

-                        if expected_count == 0:
-                            logger.info(
-                                f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: No nodes found, skipping"
-                            )
-                            continue
+                    # Add user_name filter if provided
+                    if user_name_conditions:
+                        user_name_where = " OR ".join(user_name_conditions)
+                        where_clause = f"({user_name_where}) AND ({file_id_where})"
+                    else:
+                        where_clause = file_id_where

-                        # Delete batch
-                        delete_query = f"""
-                            DELETE FROM "{self.db_name}_graph"."Memory"
-                            WHERE {where_clause}
-                        """
-                        cursor.execute(delete_query)
-                        batch_deleted = cursor.rowcount
-                        total_deleted_count += batch_deleted
+                    # Delete directly without counting
+                    delete_query = f"""
+                        DELETE FROM "{self.db_name}_graph"."Memory"
+                        WHERE {where_clause}
+                    """
+                    logger.info(f"[delete_node_by_prams] file_ids delete_query: {delete_query}")

-                        logger.info(
-                            f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: delete_query: {delete_query}"
-                        )
+                    cursor.execute(delete_query)
+                    deleted_count = cursor.rowcount
+                    total_deleted_count += deleted_count
+
+                    logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes by file_ids")

             elapsed_time = time.time() - batch_start_time
             logger.info(
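With `batch_size` removed, `delete_node_by_prams` now issues a single `DELETE` whose OR-chain grows linearly with the number of ids, so very large id lists produce a proportionally large statement. A caller that still wants bounded statement sizes can chunk on its side; a hedged sketch where `chunked` is a hypothetical helper and `db`/`memory_ids` are placeholders:

```python
from itertools import islice


def chunked(items, size):
    """Yield successive fixed-size slices of items."""
    it = iter(items)
    while batch := list(islice(it, size)):
        yield batch


total_deleted = 0
for ids in chunked(memory_ids, 500):
    total_deleted += db.delete_node_by_prams(memory_ids=ids)
```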
@@ -5109,6 +5076,7 @@ def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[
             - 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing)
             - 'exist_user_names': List of distinct user names (if all memory_ids exist)
         """
+        logger.info(f"[get_user_names_by_memory_ids] Querying memory_ids: {memory_ids}")
         if not memory_ids:
             return {"exist_user_names": []}
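A hypothetical caller sketch for the updated `export_graph`, combining the new `user_id` filter with the existing pagination (this assumes the documented return dict exposes "nodes" and "edges" lists; `db` and `handle` are placeholders):

```python
page, page_size = 1, 1000
while True:
    chunk = db.export_graph(user_id="u-42", page=page, page_size=page_size)
    nodes, edges = chunk.get("nodes", []), chunk.get("edges", [])
    if not nodes and not edges:
        break  # past the last page
    handle(nodes, edges)
    page += 1
```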