diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py
index debbb4e3c..2b3859252 100644
--- a/src/memos/graph_dbs/neo4j.py
+++ b/src/memos/graph_dbs/neo4j.py
@@ -1132,10 +1132,21 @@ def clear(self, user_name: str | None = None) -> None:
             logger.error(f"[ERROR] Failed to clear database '{self.db_name}': {e}")
             raise
 
-    def export_graph(self, **kwargs) -> dict[str, Any]:
+    def export_graph(
+        self,
+        page: int | None = None,
+        page_size: int | None = None,
+        **kwargs,
+    ) -> dict[str, Any]:
         """
         Export all graph nodes and edges in a structured form.
 
+        Args:
+            page (int, optional): Page number (starts from 1). If None, exports all data without pagination.
+            page_size (int, optional): Number of items per page. If None, exports all data without pagination.
+            **kwargs: Additional keyword arguments, including:
+                - user_name (str, optional): User name for filtering in non-multi-db mode
+
         Returns:
             {
                 "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
@@ -1143,6 +1154,18 @@ def export_graph(self, **kwargs) -> dict[str, Any]:
             }
         """
         user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name
+
+        # Determine if pagination is needed
+        use_pagination = page is not None and page_size is not None
+
+        # Validate pagination parameters if pagination is enabled
+        if use_pagination:
+            if page < 1:
+                page = 1
+            if page_size < 1:
+                page_size = 10
+            skip = (page - 1) * page_size
+
         with self.driver.session(database=self.db_name) as session:
             # Export nodes
             node_query = "MATCH (n:Memory)"
@@ -1154,13 +1177,23 @@ def export_graph(self, **kwargs) -> dict[str, Any]:
                 edge_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name"
                 params["user_name"] = user_name
 
-            node_result = session.run(f"{node_query} RETURN n", params)
+            # Add ORDER BY and pagination for nodes
+            node_query += " RETURN n ORDER BY n.id"
+            if use_pagination:
+                node_query += f" SKIP {skip} LIMIT {page_size}"
+
+            node_result = session.run(node_query, params)
             nodes = [self._parse_node(dict(record["n"])) for record in node_result]
 
             # Export edges
-            edge_result = session.run(
-                f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) AS type", params
+            # Add ORDER BY and pagination for edges
+            edge_query += (
+                " RETURN a.id AS source, b.id AS target, type(r) AS type ORDER BY a.id, b.id"
             )
+            if use_pagination:
+                edge_query += f" SKIP {skip} LIMIT {page_size}"
+
+            edge_result = session.run(edge_query, params)
             edges = [
                 {"source": record["source"], "target": record["target"], "type": record["type"]}
                 for record in edge_result
@@ -1646,7 +1679,7 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
 
     def delete_node_by_prams(
         self,
-        writable_cube_ids: list[str],
+        writable_cube_ids: list[str] | None = None,
         memory_ids: list[str] | None = None,
         file_ids: list[str] | None = None,
         filter: dict | None = None,
@@ -1655,7 +1688,8 @@ def delete_node_by_prams(
         Delete nodes by memory_ids, file_ids, or filter.
 
         Args:
-            writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter.
+            writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes.
+                If not provided, no user_name filter will be applied.
             memory_ids (list[str], optional): List of memory node IDs to delete.
             file_ids (list[str], optional): List of file node IDs to delete.
             filter (dict, optional): Filter dictionary to query matching nodes for deletion.
@@ -1670,20 +1704,18 @@ def delete_node_by_prams(
             f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
         )
 
-        # Validate writable_cube_ids
-        if not writable_cube_ids or len(writable_cube_ids) == 0:
-            raise ValueError("writable_cube_ids is required and cannot be empty")
-
         # Build WHERE conditions separately for memory_ids and file_ids
         where_clauses = []
         params = {}
 
         # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
+        # Only add user_name filter if writable_cube_ids is provided
         user_name_conditions = []
-        for idx, cube_id in enumerate(writable_cube_ids):
-            param_name = f"cube_id_{idx}"
-            user_name_conditions.append(f"n.user_name = ${param_name}")
-            params[param_name] = cube_id
+        if writable_cube_ids and len(writable_cube_ids) > 0:
+            for idx, cube_id in enumerate(writable_cube_ids):
+                param_name = f"cube_id_{idx}"
+                user_name_conditions.append(f"n.user_name = ${param_name}")
+                params[param_name] = cube_id
 
         # Handle memory_ids: query n.id
         if memory_ids and len(memory_ids) > 0:
@@ -1711,7 +1743,7 @@
                 filters=[],
                 user_name=None,
                 filter=filter,
-                knowledgebase_ids=writable_cube_ids,
+                knowledgebase_ids=writable_cube_ids if writable_cube_ids else None,
             )
 
         # If filter returned IDs, add condition for them
@@ -1730,9 +1762,14 @@
         # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
         data_conditions = " OR ".join([f"({clause})" for clause in where_clauses])
 
-        # Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
-        user_name_where = " OR ".join(user_name_conditions)
-        ids_where = f"({user_name_where}) AND ({data_conditions})"
+        # Build final WHERE clause
+        # If user_name_conditions exist, combine with data_conditions using AND
+        # Otherwise, use only data_conditions
+        if user_name_conditions:
+            user_name_where = " OR ".join(user_name_conditions)
+            ids_where = f"({user_name_where}) AND ({data_conditions})"
+        else:
+            ids_where = f"({data_conditions})"
         logger.info(
             f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
         )
@@ -1773,3 +1810,70 @@ def delete_node_by_prams(
         logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
 
         return deleted_count
+
+    def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]:
+        """Get user names by memory ids.
+
+        Args:
+            memory_ids: List of memory node IDs to query.
+
+        Returns:
+            dict[str, list[str]]: Dictionary with exactly one of the following keys:
+                - 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing)
+                - 'exist_user_names': List of distinct user names (if all memory_ids exist)
+        """
+        if not memory_ids:
+            return {"exist_user_names": []}
+
+        logger.info(f"[get_user_names_by_memory_ids] Checking {len(memory_ids)} memory_ids")
+
+        try:
+            with self.driver.session(database=self.db_name) as session:
+                # Query to check which memory_ids exist
+                check_query = """
+                    MATCH (n:Memory)
+                    WHERE n.id IN $memory_ids
+                    RETURN n.id AS id
+                """
+
+                check_result = session.run(check_query, memory_ids=memory_ids)
+                existing_ids = set()
+                for record in check_result:
+                    node_id = record["id"]
+                    existing_ids.add(node_id)
+
+                # Check if any memory_ids are missing
+                no_exist_list = [mid for mid in memory_ids if mid not in existing_ids]
+
+                # If any memory_ids are missing, return no_exist_memory_ids
+                if no_exist_list:
+                    logger.info(
+                        f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}"
+                    )
+                    return {"no_exist_memory_ids": no_exist_list}
+
+                # All memory_ids exist, query user_names
+                user_names_query = """
+                    MATCH (n:Memory)
+                    WHERE n.id IN $memory_ids
+                    RETURN DISTINCT n.user_name AS user_name
+                """
+                logger.info(f"[get_user_names_by_memory_ids] user_names_query: {user_names_query}")
+
+                user_names_result = session.run(user_names_query, memory_ids=memory_ids)
+                user_names = []
+                for record in user_names_result:
+                    user_name = record["user_name"]
+                    if user_name:
+                        user_names.append(user_name)
+
+                logger.info(
+                    f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names"
+                )
+
+                return {"exist_user_names": user_names}
+        except Exception as e:
+            logger.error(
+                f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
+            )
+            raise
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index b29dd26ce..fcb7e0caa 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -2505,16 +2505,16 @@ def export_graph(
         self,
         include_embedding: bool = False,
         user_name: str | None = None,
-        page: int = 1,
-        page_size: int = 10,
+        page: int | None = None,
+        page_size: int | None = None,
     ) -> dict[str, Any]:
         """
         Export all graph nodes and edges in a structured form.
 
         Args:
            include_embedding (bool): Whether to include the large embedding field.
            user_name (str, optional): User name for filtering in non-multi-db mode
-            page (int): Page number (starts from 1). Default is 1.
-            page_size (int): Number of items per page. Default is 1000.
+            page (int, optional): Page number (starts from 1). If None, exports all data without pagination.
+            page_size (int, optional): Number of items per page. If None, exports all data without pagination.
 
         Returns:
             {
@@ -2527,23 +2527,35 @@ def export_graph(
         )
         user_name = user_name if user_name else self._get_config_value("user_name")
 
-        # Validate pagination parameters
-        if page < 1:
-            page = 1
-        if page_size < 1:
-            page_size = 10
+        # Determine if pagination is needed
+        use_pagination = page is not None and page_size is not None
+
+        # Validate pagination parameters if pagination is enabled
+        if use_pagination:
+            if page < 1:
+                page = 1
+            if page_size < 1:
+                page_size = 10
+            offset = (page - 1) * page_size
+        else:
+            offset = None
 
         conn = None
         try:
             conn = self._get_connection()
             # Export nodes
+            # Build pagination clause if needed
+            pagination_clause = ""
+            if use_pagination:
+                pagination_clause = f"LIMIT {page_size} OFFSET {offset}"
+
             if include_embedding:
                 node_query = f"""
                     SELECT id, properties, embedding
                     FROM "{self.db_name}_graph"."Memory"
                     WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype
                     ORDER BY id
-                    LIMIT {page_size} OFFSET {(page - 1) * page_size}
+                    {pagination_clause}
                 """
             else:
                 node_query = f"""
@@ -2551,7 +2563,7 @@ def export_graph(
                     FROM "{self.db_name}_graph"."Memory"
                     WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype
                     ORDER BY id
-                    LIMIT {page_size} OFFSET {(page - 1) * page_size}
+                    {pagination_clause}
                 """
             logger.info(f"[export_graph nodes] Query: {node_query}")
             with conn.cursor() as cursor:
@@ -2601,6 +2613,11 @@ def export_graph(
             conn = self._get_connection()
             # Export edges using cypher query
             # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery
+            # Build pagination clause if needed
+            edge_pagination_clause = ""
+            if use_pagination:
+                edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}"
+
             edge_query = f"""
                 SELECT source, target, edge FROM (
                     SELECT * FROM cypher('{self.db_name}_graph', $$
                         MATCH (a:Memory)-[r]->(b:Memory)
                         RETURN a.id AS source, b.id AS target, type(r) AS edge
                         ORDER BY a.id, b.id
                     $$) AS (source agtype, target agtype, edge agtype)
                 ) AS edges
-                LIMIT {page_size} OFFSET {(page - 1) * page_size}
+                {edge_pagination_clause}
             """
             logger.info(f"[export_graph edges] Query: {edge_query}")
             with conn.cursor() as cursor:
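
Usage sketches for the changed surfaces follow. None of this code is part of the diff itself; instance names, user names, and IDs are illustrative assumptions. First, the new pagination contract of the Neo4j `export_graph`:

```python
# Sketch only: `graph_db` is assumed to be an already-configured Neo4jGraphDB
# instance and "alice" an existing user_name; both are illustrative.

# Omitting page/page_size preserves the old behavior: everything is exported.
full = graph_db.export_graph(user_name="alice")

# Passing both page and page_size enables SKIP/LIMIT pagination.
first_page = graph_db.export_graph(page=1, page_size=100, user_name="alice")

# Out-of-range values are coerced rather than rejected:
# page < 1 falls back to 1, page_size < 1 falls back to 10.
coerced = graph_db.export_graph(page=0, page_size=-5, user_name="alice")

for node in first_page["nodes"]:
    print(node["id"], node["memory"])
```

Because `use_pagination` requires both parameters, passing only one of `page`/`page_size` silently falls back to a full export; note also that nodes and edges are paged with the same `page`/`page_size`, so their page counts can diverge.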
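
Next, the relaxed `delete_node_by_prams` contract, where `writable_cube_ids` is now optional:

```python
# Sketch only: `graph_db` and the IDs below are illustrative.

# With writable_cube_ids, deletion is scoped to nodes whose user_name matches
# one of the cube IDs (OR across cubes, AND-ed with the data conditions).
deleted = graph_db.delete_node_by_prams(
    writable_cube_ids=["cube_a", "cube_b"],
    memory_ids=["mem-1", "mem-2"],
)

# Without writable_cube_ids (new in this diff), no user_name filter is applied
# and matching is driven solely by memory_ids / file_ids / filter, so the
# deletion can reach nodes across all users. Use with care.
deleted = graph_db.delete_node_by_prams(memory_ids=["mem-3"])
print(f"deleted {deleted} nodes")
```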
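
The new `get_user_names_by_memory_ids` returns one of two mutually exclusive shapes, so callers are expected to branch on whichever key is present:

```python
# Sketch only: `graph_db` and the IDs are illustrative.
result = graph_db.get_user_names_by_memory_ids(["mem-1", "mem-2"])

if "no_exist_memory_ids" in result:
    # At least one ID was missing; no user names come back in this shape.
    print("missing:", result["no_exist_memory_ids"])
else:
    # Every ID exists; the distinct owning user names are returned.
    print("owners:", result["exist_user_names"])

# An empty input short-circuits before a session is even opened.
assert graph_db.get_user_names_by_memory_ids([]) == {"exist_user_names": []}
```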
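
Finally, the PolarDB `export_graph` under the same optional-pagination contract. A drain loop along these lines collects large graphs page by page (the helper is a sketch under the stated assumptions, not part of the diff):

```python
# Sketch only: `polar_db` is assumed to be a configured PolarDBGraphDB instance.

def iter_all_nodes(db, user_name: str, page_size: int = 500):
    """Yield every node, walking pages until a short page signals the end.

    Nodes and edges share the same page/page_size in export_graph, so drain
    each collection independently if you need both.
    """
    page = 1
    while True:
        batch = db.export_graph(user_name=user_name, page=page, page_size=page_size)
        nodes = batch["nodes"]
        yield from nodes
        if len(nodes) < page_size:
            break
        page += 1

# Or export everything in one call by omitting page/page_size entirely.
everything = polar_db.export_graph(include_embedding=False, user_name="alice")
```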