Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 41 additions & 66 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5470,21 +5470,17 @@ def escape_user_name(un: str) -> str:
@timed
def delete_node_by_mem_cube_id(
self,
mem_kube_id: dict | None = None,
delete_record_id: dict | None = None,
deleted_type: bool = False,
mem_cube_id: str | None = None,
delete_record_id: str | None = None,
hard_delete: bool = False,
) -> int:
# Handle dict type parameters (extract value if dict)
if isinstance(mem_kube_id, dict):
# Try to get a value from dict, use first value if multiple
mem_kube_id = next(iter(mem_kube_id.values())) if mem_kube_id else None

if isinstance(delete_record_id, dict):
delete_record_id = next(iter(delete_record_id.values())) if delete_record_id else None
logger.info(
f"delete_node_by_mem_cube_id mem_cube_id:{mem_cube_id}, "
f"delete_record_id:{delete_record_id}, hard_delete:{hard_delete}"
)

# Validate required parameters
if not mem_kube_id:
logger.warning("[delete_node_by_mem_cube_id] mem_kube_id is required but not provided")
if not mem_cube_id:
logger.warning("[delete_node_by_mem_cube_id] mem_cube_id is required but not provided")
return 0

if not delete_record_id:
Expand All @@ -5493,32 +5489,18 @@ def delete_node_by_mem_cube_id(
)
return 0

# Convert to string if needed
mem_kube_id = str(mem_kube_id) if mem_kube_id else None
delete_record_id = str(delete_record_id) if delete_record_id else None

logger.info(
f"[delete_node_by_mem_cube_id] mem_kube_id={mem_kube_id}, "
f"delete_record_id={delete_record_id}, deleted_type={deleted_type}"
)

conn = None
try:
conn = self._get_connection()
with conn.cursor() as cursor:
# Build WHERE clause for user_name using parameter binding
# user_name must match mem_kube_id
user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"

# Prepare parameter for user_name
user_name_param = self.format_param_value(mem_kube_id)
user_name_param = self.format_param_value(mem_cube_id)

if deleted_type:
# Hard delete: WHERE user_name = mem_kube_id AND delete_record_id = $delete_record_id
if hard_delete:
delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype"
where_clause = f"{user_name_condition} AND {delete_record_id_condition}"

# Prepare parameters for WHERE clause (user_name and delete_record_id)
where_params = [user_name_param, self.format_param_value(delete_record_id)]

delete_query = f"""
Expand All @@ -5533,40 +5515,39 @@ def delete_node_by_mem_cube_id(
logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes")
return deleted_count
else:
# Soft delete: WHERE user_name = mem_kube_id (only user_name condition)
where_clause = user_name_condition
delete_time_empty_condition = (
"(ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) IS NULL "
"OR ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) = '\"\"'::agtype)"
)
delete_record_id_empty_condition = (
"(ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) IS NULL "
"OR ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = '\"\"'::agtype)"
)
where_clause = f"{user_name_condition} AND {delete_time_empty_condition} AND {delete_record_id_empty_condition}"

current_time = datetime.utcnow().isoformat()
# Build update properties JSON with status, delete_time, and delete_record_id
# Use PostgreSQL JSONB merge operator (||) to update properties
# Convert agtype to jsonb, merge with new values, then convert back to agtype
update_query = f"""
UPDATE "{self.db_name}_graph"."Memory"
SET properties = (
properties::jsonb || %s::jsonb
)::text::agtype
)::text::agtype,
deletetime = %s
WHERE {where_clause}
"""
# Create update JSON with the three fields to update
update_properties = {
"status": "deleted",
"delete_time": current_time,
"delete_record_id": delete_record_id,
}
logger.info(
f"[delete_node_by_mem_cube_id] Soft delete update_query: {update_query}"
)
logger.info(
f"[delete_node_by_mem_cube_id] update_properties: {update_properties}"
f"delete_node_by_mem_cube_id Soft delete update_query:{update_query},update_properties:{update_properties},deletetime:{current_time}"
)

# Combine update_properties JSON with user_name parameter (only user_name, no delete_record_id)
update_params = [json.dumps(update_properties), user_name_param]
update_params = [json.dumps(update_properties), current_time, user_name_param]
cursor.execute(update_query, update_params)
updated_count = cursor.rowcount

logger.info(
f"[delete_node_by_mem_cube_id] Soft deleted (updated) {updated_count} nodes"
f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes"
)
return updated_count

Expand All @@ -5579,61 +5560,55 @@ def delete_node_by_mem_cube_id(
self._return_connection(conn)

@timed
def recover_memory_by_mem_kube_id(
def recover_memory_by_mem_cube_id(
self,
mem_kube_id: str | None = None,
mem_cube_id: str | None = None,
delete_record_id: str | None = None,
) -> int:
"""
Recover memory nodes by mem_kube_id (user_name) and delete_record_id.
Recover memory nodes by mem_cube_id (user_name) and delete_record_id.

This function updates the status to 'activated', and clears delete_record_id and delete_time.

Args:
mem_kube_id: The mem_kube_id which corresponds to user_name in the table.
mem_cube_id: The mem_cube_id which corresponds to user_name in the table.
delete_record_id: The delete_record_id to match.

Returns:
int: Number of nodes recovered (updated).
"""
logger.info(
f"recover_memory_by_mem_kube_id mem_kube_id:{mem_kube_id},delete_record_id:{delete_record_id}"
f"recover_memory_by_mem_cube_id mem_cube_id:{mem_cube_id},delete_record_id:{delete_record_id}"
)
# Validate required parameters
if not mem_kube_id:
logger.warning(
"[recover_memory_by_mem_kube_id] mem_kube_id is required but not provided"
)
if not mem_cube_id:
logger.warning("recover_memory_by_mem_cube_id mem_cube_id is required but not provided")
return 0

if not delete_record_id:
logger.warning(
"[recover_memory_by_mem_kube_id] delete_record_id is required but not provided"
"recover_memory_by_mem_cube_id delete_record_id is required but not provided"
)
return 0

logger.info(
f"[recover_memory_by_mem_kube_id] mem_kube_id={mem_kube_id}, "
f"recover_memory_by_mem_cube_id mem_cube_id={mem_cube_id}, "
f"delete_record_id={delete_record_id}"
)

conn = None
try:
conn = self._get_connection()
with conn.cursor() as cursor:
# Build WHERE clause for user_name and delete_record_id using parameter binding
user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype"
where_clause = f"{user_name_condition} AND {delete_record_id_condition}"

# Prepare parameters for WHERE clause
where_params = [
self.format_param_value(mem_kube_id),
self.format_param_value(mem_cube_id),
self.format_param_value(delete_record_id),
]

# Build update properties: status='activated', delete_record_id='', delete_time=''
# Use PostgreSQL JSONB merge operator (||) to update properties
update_properties = {
"status": "activated",
"delete_record_id": "",
Expand All @@ -5644,28 +5619,28 @@ def recover_memory_by_mem_kube_id(
UPDATE "{self.db_name}_graph"."Memory"
SET properties = (
properties::jsonb || %s::jsonb
)::text::agtype
)::text::agtype,
deletetime = NULL
WHERE {where_clause}
"""

logger.info(f"[recover_memory_by_mem_kube_id] Update query: {update_query}")
logger.info(f"[recover_memory_by_mem_cube_id] Update query: {update_query}")
logger.info(
f"[recover_memory_by_mem_kube_id] update_properties: {update_properties}"
f"[recover_memory_by_mem_cube_id] update_properties: {update_properties}"
)

# Combine update_properties JSON with where_params
update_params = [json.dumps(update_properties), *where_params]
cursor.execute(update_query, update_params)
updated_count = cursor.rowcount

logger.info(
f"[recover_memory_by_mem_kube_id] Recovered (updated) {updated_count} nodes"
f"[recover_memory_by_mem_cube_id] Recovered (updated) {updated_count} nodes"
)
return updated_count

except Exception as e:
logger.error(
f"[recover_memory_by_mem_kube_id] Failed to recover nodes: {e}", exc_info=True
f"[recover_memory_by_mem_cube_id] Failed to recover nodes: {e}", exc_info=True
)
raise
finally:
Expand Down