Merged
219 changes: 137 additions & 82 deletions src/memos/graph_dbs/polardb.py
@@ -4869,6 +4869,7 @@ def delete_node_by_prams(
         memory_ids: list[str] | None = None,
         file_ids: list[str] | None = None,
         filter: dict | None = None,
+        batch_size: int = 100,
     ) -> int:
         """
         Delete nodes by memory_ids, file_ids, or filter.
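
The only signature change in this hunk is the new batch_size keyword (default 100). A minimal usage sketch follows; the graph instance and the IDs are invented for illustration and are not part of this PR:

    # Hypothetical caller: delete a large set of memories 50 at a time.
    ids = [f"mem-{i}" for i in range(1000)]  # made-up IDs
    deleted = graph.delete_node_by_prams(memory_ids=ids, batch_size=50)
    print(f"deleted {deleted} nodes")  # the method returns the total rowcount across batches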
@@ -4898,31 +4899,6 @@ def delete_node_by_prams(
                 f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype"
             )
 
-        # Build WHERE conditions separately for memory_ids and file_ids
-        where_conditions = []
-
-        # Handle memory_ids: query properties.id
-        if memory_ids and len(memory_ids) > 0:
-            memory_id_conditions = []
-            for node_id in memory_ids:
-                memory_id_conditions.append(
-                    f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
-                )
-            if memory_id_conditions:
-                where_conditions.append(f"({' OR '.join(memory_id_conditions)})")
-
-        # Check if any file_id is in the file_ids array field (OR relationship)
-        if file_ids and len(file_ids) > 0:
-            file_id_conditions = []
-            for file_id in file_ids:
-                # Format: agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '"file_ids"'::agtype]), '"file_id"'::agtype)
-                file_id_conditions.append(
-                    f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)"
-                )
-            if file_id_conditions:
-                # Use OR to match any file_id in the array
-                where_conditions.append(f"({' OR '.join(file_id_conditions)})")
-
         # Query nodes by filter if provided
         filter_ids = set()
         if filter:
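
The block removed above folded every memory_id and file_id into a single WHERE clause, so a call with thousands of IDs produced one enormous SQL statement. The batched rewrite in the next hunk caps each statement at batch_size predicates. For reference, here is a sketch of the per-ID predicate those f-strings render, using a made-up node_id rather than a value from this PR:

    # Illustration only: what one rendered memory_id predicate looks like.
    node_id = "a1b2c3d4-0000-4000-8000-000000000000"  # made-up UUID
    predicate = f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
    print(predicate)
    # ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = '"a1b2c3d4-0000-4000-8000-000000000000"'::agtype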
@@ -4943,86 +4919,165 @@ def delete_node_by_prams
                     "[delete_node_by_prams] Filter parsed to None, skipping filter query"
                 )
 
-        # If filter returned IDs, add condition for them
+        # Combine all IDs that need to be deleted
+        all_memory_ids = set()
+        if memory_ids:
+            all_memory_ids.update(memory_ids)
         if filter_ids:
-            filter_id_conditions = []
-            for node_id in filter_ids:
-                filter_id_conditions.append(
-                    f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
-                )
-            if filter_id_conditions:
-                where_conditions.append(f"({' OR '.join(filter_id_conditions)})")
+            all_memory_ids.update(filter_ids)
 
-        # If no conditions (except user_name), return 0
-        if not where_conditions:
+        # If no conditions to delete, return 0
+        if not all_memory_ids and not file_ids:
             logger.warning(
                 "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
             )
             return 0
 
-        # Build WHERE clause
-        # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
-        data_conditions = " OR ".join([f"({cond})" for cond in where_conditions])
+        conn = None
+        total_deleted_count = 0
+        try:
+            conn = self._get_connection()
+            with conn.cursor() as cursor:
+                # Process memory_ids and filter_ids in batches
+                if all_memory_ids:
+                    memory_ids_list = list(all_memory_ids)
+                    total_batches = (len(memory_ids_list) + batch_size - 1) // batch_size
+                    logger.info(
+                        f"[delete_node_by_prams] memoryids Processing {len(memory_ids_list)} memory_ids in {total_batches} batches (batch_size={batch_size})"
+                    )
 
-        # Build final WHERE clause
-        # If user_name_conditions exist, combine with data_conditions using AND
-        # Otherwise, use only data_conditions
-        if user_name_conditions:
-            user_name_where = " OR ".join(user_name_conditions)
-            where_clause = f"({user_name_where}) AND ({data_conditions})"
-        else:
-            where_clause = f"({data_conditions})"
+                    for batch_idx in range(total_batches):
+                        batch_start = batch_idx * batch_size
+                        batch_end = min(batch_start + batch_size, len(memory_ids_list))
+                        batch_ids = memory_ids_list[batch_start:batch_end]
 
-        # Use SQL DELETE query for better performance
-        # First count matching nodes to get accurate count
-        count_query = f"""
-            SELECT COUNT(*)
-            FROM "{self.db_name}_graph"."Memory"
-            WHERE {where_clause}
-        """
-        logger.info(f"[delete_node_by_prams] count_query: {count_query}")
+                        # Build conditions for this batch
+                        batch_conditions = []
+                        for node_id in batch_ids:
+                            batch_conditions.append(
+                                f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
+                            )
+                        batch_where = f"({' OR '.join(batch_conditions)})"
 
-        # Then delete nodes
-        delete_query = f"""
-            DELETE FROM "{self.db_name}_graph"."Memory"
-            WHERE {where_clause}
-        """
+                        # Add user_name filter if provided
+                        if user_name_conditions:
+                            user_name_where = " OR ".join(user_name_conditions)
+                            where_clause = f"({user_name_where}) AND ({batch_where})"
+                        else:
+                            where_clause = batch_where
 
-        logger.info(
-            f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
-        )
-        logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
+                        # Count before deletion
+                        count_query = f"""
+                            SELECT COUNT(*)
+                            FROM "{self.db_name}_graph"."Memory"
+                            WHERE {where_clause}
+                        """
+                        logger.info(
+                            f"[delete_node_by_prams] memoryids batch {batch_idx + 1}/{total_batches}: count_query: {count_query}"
+                        )
 
-        conn = None
-        deleted_count = 0
-        try:
-            conn = self._get_connection()
-            with conn.cursor() as cursor:
-                # Count nodes before deletion
-                cursor.execute(count_query)
-                count_result = cursor.fetchone()
-                expected_count = count_result[0] if count_result else 0
+                        cursor.execute(count_query)
+                        count_result = cursor.fetchone()
+                        expected_count = count_result[0] if count_result else 0
 
-                logger.info(
-                    f"[delete_node_by_prams] Found {expected_count} nodes matching the criteria"
-                )
+                        if expected_count == 0:
+                            logger.info(
+                                f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: No nodes found, skipping"
+                            )
+                            continue
+
+                        # Delete batch
+                        delete_query = f"""
+                            DELETE FROM "{self.db_name}_graph"."Memory"
+                            WHERE {where_clause}
+                        """
+                        logger.info(
+                            f"[delete_node_by_prams] memoryids batch {batch_idx + 1}/{total_batches}: delete_query: {delete_query}"
+                        )
+
+                        logger.info(
+                            f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: Executing delete query for {len(batch_ids)} nodes"
+                        )
+                        cursor.execute(delete_query)
+                        batch_deleted = cursor.rowcount
+                        total_deleted_count += batch_deleted
+
+                        logger.info(
+                            f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: Deleted {batch_deleted} nodes (batch size: {len(batch_ids)})"
+                        )
+
+                # Process file_ids in batches
+                if file_ids:
+                    total_file_batches = (len(file_ids) + batch_size - 1) // batch_size
+                    logger.info(
+                        f"[delete_node_by_prams] Processing {len(file_ids)} file_ids in {total_file_batches} batches (batch_size={batch_size})"
+                    )
+
+                    for batch_idx in range(total_file_batches):
+                        batch_start = batch_idx * batch_size
+                        batch_end = min(batch_start + batch_size, len(file_ids))
+                        batch_file_ids = file_ids[batch_start:batch_end]
+
+                        # Build conditions for this batch
+                        batch_conditions = []
+                        for file_id in batch_file_ids:
+                            batch_conditions.append(
+                                f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)"
+                            )
+                        batch_where = f"({' OR '.join(batch_conditions)})"
+
+                        # Add user_name filter if provided
+                        if user_name_conditions:
+                            user_name_where = " OR ".join(user_name_conditions)
+                            where_clause = f"({user_name_where}) AND ({batch_where})"
+                        else:
+                            where_clause = batch_where
+
+                        # Count before deletion
+                        count_query = f"""
+                            SELECT COUNT(*)
+                            FROM "{self.db_name}_graph"."Memory"
+                            WHERE {where_clause}
+                        """
+
+                        logger.info(
+                            f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: count_query: {count_query}"
+                        )
+                        cursor.execute(count_query)
+                        count_result = cursor.fetchone()
+                        expected_count = count_result[0] if count_result else 0
+
+                        if expected_count == 0:
+                            logger.info(
+                                f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: No nodes found, skipping"
+                            )
+                            continue
+
+                        # Delete batch
+                        delete_query = f"""
+                            DELETE FROM "{self.db_name}_graph"."Memory"
+                            WHERE {where_clause}
+                        """
+                        cursor.execute(delete_query)
+                        batch_deleted = cursor.rowcount
+                        total_deleted_count += batch_deleted
+
+                        logger.info(
+                            f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: delete_query: {delete_query}"
+                        )
 
-                # Delete nodes
-                cursor.execute(delete_query)
-                # Use rowcount to get actual deleted count
-                deleted_count = cursor.rowcount
                 elapsed_time = time.time() - batch_start_time
                 logger.info(
-                    f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, deleted {deleted_count} nodes"
+                    f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes"
                )
         except Exception as e:
             logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
             raise
         finally:
             self._return_connection(conn)
 
-        logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
-        return deleted_count
+        logger.info(f"[delete_node_by_prams] Successfully deleted {total_deleted_count} nodes")
+        return total_deleted_count
 
     @timed
     def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]:
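
Both batching loops added above rely on the same ceiling-division idiom, (n + batch_size - 1) // batch_size, followed by list slicing. A self-contained sketch of just that arithmetic; iter_batches is an invented name, not a helper from this PR:

    def iter_batches(items: list, batch_size: int = 100):
        """Yield consecutive slices of at most batch_size items."""
        total_batches = (len(items) + batch_size - 1) // batch_size  # ceiling division
        for batch_idx in range(total_batches):
            start = batch_idx * batch_size
            yield items[start : start + batch_size]

    # 250 items with batch_size=100 split into batches of 100, 100, 50.
    assert [len(b) for b in iter_batches(list(range(250)))] == [100, 100, 50]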