140 changes: 122 additions & 18 deletions src/memos/graph_dbs/neo4j.py
@@ -1132,17 +1132,40 @@ def clear(self, user_name: str | None = None) -> None:
logger.error(f"[ERROR] Failed to clear database '{self.db_name}': {e}")
raise

def export_graph(self, **kwargs) -> dict[str, Any]:
def export_graph(
self,
page: int | None = None,
page_size: int | None = None,
**kwargs,
) -> dict[str, Any]:
"""
Export all graph nodes and edges in a structured form.

Args:
page (int, optional): Page number (starts from 1). If None, exports all data without pagination.
page_size (int, optional): Number of items per page. If None, exports all data without pagination.
**kwargs: Additional keyword arguments, including:
- user_name (str, optional): User name for filtering in non-multi-db mode

Returns:
{
"nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
"edges": [ { "source": ..., "target": ..., "type": ... }, ... ]
}
"""
user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name

# Determine if pagination is needed
use_pagination = page is not None and page_size is not None

# Validate pagination parameters if pagination is enabled
if use_pagination:
if page < 1:
page = 1
if page_size < 1:
page_size = 10
skip = (page - 1) * page_size

with self.driver.session(database=self.db_name) as session:
# Export nodes
node_query = "MATCH (n:Memory)"
@@ -1154,13 +1177,23 @@ def export_graph(self, **kwargs) -> dict[str, Any]:
edge_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name"
params["user_name"] = user_name

node_result = session.run(f"{node_query} RETURN n", params)
# Add ORDER BY and pagination for nodes
node_query += " RETURN n ORDER BY n.id"
if use_pagination:
node_query += f" SKIP {skip} LIMIT {page_size}"

node_result = session.run(node_query, params)
nodes = [self._parse_node(dict(record["n"])) for record in node_result]

# Export edges
edge_result = session.run(
f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) AS type", params
# Add ORDER BY and pagination for edges
edge_query += (
" RETURN a.id AS source, b.id AS target, type(r) AS type ORDER BY a.id, b.id"
)
if use_pagination:
edge_query += f" SKIP {skip} LIMIT {page_size}"

edge_result = session.run(edge_query, params)
edges = [
{"source": record["source"], "target": record["target"], "type": record["type"]}
for record in edge_result
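A minimal usage sketch of the new optional pagination (the `graph_db` variable and the `"alice"` user are illustrative, not part of the PR):

# Hypothetical caller of Neo4jGraphDB.export_graph; identifiers are assumptions.
full_dump = graph_db.export_graph(user_name="alice")  # page/page_size omitted -> full export
first_page = graph_db.export_graph(user_name="alice", page=1, page_size=100)   # SKIP 0 LIMIT 100
second_page = graph_db.export_graph(user_name="alice", page=2, page_size=100)  # SKIP 100 LIMIT 100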
@@ -1646,7 +1679,7 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:

def delete_node_by_prams(
self,
writable_cube_ids: list[str],
writable_cube_ids: list[str] | None = None,
memory_ids: list[str] | None = None,
file_ids: list[str] | None = None,
filter: dict | None = None,
@@ -1655,7 +1688,8 @@
Delete nodes by memory_ids, file_ids, or filter.

Args:
writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter.
writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes.
If not provided, no user_name filter will be applied.
memory_ids (list[str], optional): List of memory node IDs to delete.
file_ids (list[str], optional): List of file node IDs to delete.
filter (dict, optional): Filter dictionary to query matching nodes for deletion.
@@ -1670,20 +1704,18 @@
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
)

# Validate writable_cube_ids
if not writable_cube_ids or len(writable_cube_ids) == 0:
raise ValueError("writable_cube_ids is required and cannot be empty")

# Build WHERE conditions separately for memory_ids and file_ids
where_clauses = []
params = {}

# Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
# Only add user_name filter if writable_cube_ids is provided
user_name_conditions = []
for idx, cube_id in enumerate(writable_cube_ids):
param_name = f"cube_id_{idx}"
user_name_conditions.append(f"n.user_name = ${param_name}")
params[param_name] = cube_id
if writable_cube_ids and len(writable_cube_ids) > 0:
for idx, cube_id in enumerate(writable_cube_ids):
param_name = f"cube_id_{idx}"
user_name_conditions.append(f"n.user_name = ${param_name}")
params[param_name] = cube_id

# Handle memory_ids: query n.id
if memory_ids and len(memory_ids) > 0:
@@ -1711,7 +1743,7 @@ def delete_node_by_prams(
filters=[],
user_name=None,
filter=filter,
knowledgebase_ids=writable_cube_ids,
knowledgebase_ids=writable_cube_ids if writable_cube_ids else None,
)

# If filter returned IDs, add condition for them
@@ -1730,9 +1762,14 @@
# First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
data_conditions = " OR ".join([f"({clause})" for clause in where_clauses])

# Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
user_name_where = " OR ".join(user_name_conditions)
ids_where = f"({user_name_where}) AND ({data_conditions})"
# Build final WHERE clause
# If user_name_conditions exist, combine with data_conditions using AND
# Otherwise, use only data_conditions
if user_name_conditions:
user_name_where = " OR ".join(user_name_conditions)
ids_where = f"({user_name_where}) AND ({data_conditions})"
else:
ids_where = f"({data_conditions})"

logger.info(
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
@@ -1773,3 +1810,70 @@ def delete_node_by_prams(

logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
return deleted_count
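A short sketch of how the relaxed `writable_cube_ids` contract could be exercised (`graph_db`, the cube ids, and the memory ids are illustrative):

# Scoped delete: a node must match one of the cube ids AND one of the data conditions.
graph_db.delete_node_by_prams(writable_cube_ids=["cube_a"], memory_ids=["mem-1", "mem-2"])
# Unscoped delete (new behaviour): with writable_cube_ids omitted, only memory_ids/file_ids/filter apply.
graph_db.delete_node_by_prams(memory_ids=["mem-1"])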

def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]:
"""Get user names by memory ids.

Args:
memory_ids: List of memory node IDs to query.

Returns:
dict[str, list[str]]: Dictionary containing exactly one of the following keys:
- 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing)
- 'exist_user_names': List of distinct user names (if all memory_ids exist)
"""
if not memory_ids:
return {"exist_user_names": []}

logger.info(f"[get_user_names_by_memory_ids] Checking {len(memory_ids)} memory_ids")

try:
with self.driver.session(database=self.db_name) as session:
# Query to check which memory_ids exist
check_query = """
MATCH (n:Memory)
WHERE n.id IN $memory_ids
RETURN n.id AS id
"""

check_result = session.run(check_query, memory_ids=memory_ids)
existing_ids = set()
for record in check_result:
node_id = record["id"]
existing_ids.add(node_id)

# Check if any memory_ids are missing
no_exist_list = [mid for mid in memory_ids if mid not in existing_ids]

# If any memory_ids are missing, return no_exist_memory_ids
if no_exist_list:
logger.info(
f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}"
)
return {"no_exist_memory_ids": no_exist_list}

# All memory_ids exist, query user_names
user_names_query = """
MATCH (n:Memory)
WHERE n.id IN $memory_ids
RETURN DISTINCT n.user_name AS user_name
"""
logger.info(f"[get_user_names_by_memory_ids] user_names_query: {user_names_query}")

user_names_result = session.run(user_names_query, memory_ids=memory_ids)
user_names = []
for record in user_names_result:
user_name = record["user_name"]
if user_name:
user_names.append(user_name)

logger.info(
f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names"
)

return {"exist_user_names": user_names}
except Exception as e:
logger.error(
f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
)
raise
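A sketch of handling the two possible return shapes (all identifiers here are illustrative):

result = graph_db.get_user_names_by_memory_ids(["mem-1", "mem-2"])
if "no_exist_memory_ids" in result:
    missing = result["no_exist_memory_ids"]   # at least one id was not found
else:
    owners = result["exist_user_names"]       # every id exists; distinct owning user names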
41 changes: 29 additions & 12 deletions src/memos/graph_dbs/polardb.py
@@ -2505,16 +2505,16 @@ def export_graph(
self,
include_embedding: bool = False,
user_name: str | None = None,
page: int = 1,
page_size: int = 10,
page: int | None = None,
page_size: int | None = None,
) -> dict[str, Any]:
"""
Export all graph nodes and edges in a structured form.
Args:
include_embedding (bool): Whether to include the large embedding field.
user_name (str, optional): User name for filtering in non-multi-db mode
page (int): Page number (starts from 1). Default is 1.
page_size (int): Number of items per page. Default is 1000.
page (int, optional): Page number (starts from 1). If None, exports all data without pagination.
page_size (int, optional): Number of items per page. If None, exports all data without pagination.

Returns:
{
@@ -2527,31 +2527,43 @@
)
user_name = user_name if user_name else self._get_config_value("user_name")

# Validate pagination parameters
if page < 1:
page = 1
if page_size < 1:
page_size = 10
# Determine if pagination is needed
use_pagination = page is not None and page_size is not None

# Validate pagination parameters if pagination is enabled
if use_pagination:
if page < 1:
page = 1
if page_size < 1:
page_size = 10
offset = (page - 1) * page_size
else:
offset = None

conn = None
try:
conn = self._get_connection()
# Export nodes
# Build pagination clause if needed
pagination_clause = ""
if use_pagination:
pagination_clause = f"LIMIT {page_size} OFFSET {offset}"

if include_embedding:
node_query = f"""
SELECT id, properties, embedding
FROM "{self.db_name}_graph"."Memory"
WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype
ORDER BY id
LIMIT {page_size} OFFSET {(page - 1) * page_size}
{pagination_clause}
"""
else:
node_query = f"""
SELECT id, properties
FROM "{self.db_name}_graph"."Memory"
WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype
ORDER BY id
LIMIT {page_size} OFFSET {(page - 1) * page_size}
{pagination_clause}
"""
logger.info(f"[export_graph nodes] Query: {node_query}")
with conn.cursor() as cursor:
@@ -2601,6 +2613,11 @@ def export_graph(
conn = self._get_connection()
# Export edges using cypher query
# Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery
# Build pagination clause if needed
edge_pagination_clause = ""
if use_pagination:
edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}"

edge_query = f"""
SELECT source, target, edge FROM (
SELECT * FROM cypher('{self.db_name}_graph', $$
Expand All @@ -2610,7 +2627,7 @@ def export_graph(
ORDER BY a.id, b.id
$$) AS (source agtype, target agtype, edge agtype)
) AS edges
LIMIT {page_size} OFFSET {(page - 1) * page_size}
{edge_pagination_clause}
"""
logger.info(f"[export_graph edges] Query: {edge_query}")
with conn.cursor() as cursor:
Expand Down