diff --git a/src/memos/graph_dbs/base.py b/src/memos/graph_dbs/base.py
index 130b66a3d..73fdd4015 100644
--- a/src/memos/graph_dbs/base.py
+++ b/src/memos/graph_dbs/base.py
@@ -272,3 +272,85 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N
                 - metadata: dict[str, Any] - Node metadata
             user_name: Optional user name (will use config default if not provided)
         """
+
+    @abstractmethod
+    def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[dict[str, str]]:
+        """
+        Get edges connected to a node, with optional type and direction filter.
+        Args:
+            id: Node ID to retrieve edges for.
+            type: Relationship type to match, or 'ANY' to match all.
+            direction: 'OUTGOING', 'INCOMING', or 'ANY'.
+        Returns:
+            List of edge dicts with 'from', 'to', and 'type' keys.
+        """
+
+    @abstractmethod
+    def search_by_fulltext(self, query_words: list[str], top_k: int = 10, **kwargs) -> list[dict]:
+        """
+        Full-text search for memory nodes.
+        Args:
+            query_words: List of words to search for.
+            top_k: Maximum number of results.
+        Returns:
+            List of dicts with 'id' and 'score'.
+        """
+
+    @abstractmethod
+    def get_neighbors_by_tag(
+        self,
+        tags: list[str],
+        exclude_ids: list[str],
+        top_k: int = 5,
+        min_overlap: int = 1,
+        **kwargs,
+    ) -> list[dict[str, Any]]:
+        """
+        Find top-K neighbor nodes with maximum tag overlap.
+        Args:
+            tags: Tags to match.
+            exclude_ids: Node IDs to exclude.
+            top_k: Max neighbors to return.
+            min_overlap: Minimum overlapping tags required.
+        Returns:
+            List of node dicts.
+        """
+
+    @abstractmethod
+    def delete_node_by_prams(
+        self,
+        memory_ids: list[str] | None = None,
+        writable_cube_ids: list[str] | None = None,
+        file_ids: list[str] | None = None,
+        filter: dict | None = None,
+        **kwargs,
+    ) -> int:
+        """
+        Delete nodes matching given parameters.
+        Returns:
+            Number of deleted nodes.
+        """
+
+    @abstractmethod
+    def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]:
+        """
+        Map each memory ID to its owning user name (None if the memory ID does not exist).
+        """
+
+    @abstractmethod
+    def exist_user_name(self, user_name: str) -> bool:
+        """
+        Check if a user_name exists in the graph.
+        """
+
+    @abstractmethod
+    def search_by_keywords_like(self, query_word: str, **kwargs) -> list[dict]:
+        """
+        Search memories using SQL LIKE pattern matching.
+        """
+
+    @abstractmethod
+    def search_by_keywords_tfidf(self, query_words: list[str], **kwargs) -> list[dict]:
+        """
+        Search memories using TF-IDF fulltext scoring.
+ """ diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 428d6d09e..289d3cab3 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -11,6 +11,7 @@ from memos.configs.graph_db import NebulaGraphDBConfig from memos.dependency import require_python_package from memos.graph_dbs.base import BaseGraphDB +from memos.graph_dbs.utils import compose_node as _compose_node from memos.log import get_logger from memos.utils import timed @@ -44,14 +45,6 @@ def _normalize(vec: list[float]) -> list[float]: return (v / (norm if norm else 1.0)).tolist() -@timed -def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: - node_id = item["id"] - memory = item["memory"] - metadata = item.get("metadata", {}) - return node_id, memory, metadata - - @timed def _escape_str(value: str) -> str: out = [] diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 70d40f13c..d716a9cce 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -7,19 +7,13 @@ from memos.configs.graph_db import Neo4jGraphDBConfig from memos.dependency import require_python_package from memos.graph_dbs.base import BaseGraphDB +from memos.graph_dbs.utils import compose_node as _compose_node from memos.log import get_logger logger = get_logger(__name__) -def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: - node_id = item["id"] - memory = item["memory"] - metadata = item.get("metadata", {}) - return node_id, memory, metadata - - def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: """ Ensure metadata has proper datetime fields and normalized types. @@ -502,7 +496,7 @@ def edge_exists( return result.single() is not None # Graph Query & Reasoning - def get_node(self, id: str, **kwargs) -> dict[str, Any] | None: + def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None: """ Retrieve the metadata and memory of a node. Args: @@ -510,18 +504,28 @@ def get_node(self, id: str, **kwargs) -> dict[str, Any] | None: Returns: Dictionary of node fields, or None if not found. """ - user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name + logger.info(f"[get_node] id: {id}") + user_name = kwargs.get("user_name") where_user = "" params = {"id": id} - if not self.config.use_multi_db and (self.config.user_name or user_name): + if user_name is not None: where_user = " AND n.user_name = $user_name" params["user_name"] = user_name query = f"MATCH (n:Memory) WHERE n.id = $id {where_user} RETURN n" + logger.info(f"[get_node] query: {query}") with self.driver.session(database=self.db_name) as session: record = session.run(query, params).single() - return self._parse_node(dict(record["n"])) if record else None + if not record: + return None + + node_dict = dict(record["n"]) + if include_embedding is False: + for key in ("embedding", "embedding_1024", "embedding_3072", "embedding_768"): + node_dict.pop(key, None) + + return self._parse_node(node_dict) def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]: """ @@ -1940,3 +1944,174 @@ def exist_user_name(self, user_name: str) -> dict[str, bool]: f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True ) raise + + def delete_node_by_mem_cube_id( + self, + mem_kube_id: dict | None = None, + delete_record_id: dict | None = None, + deleted_type: bool = False, + ) -> int: + """ + Delete nodes by mem_kube_id (user_name) and delete_record_id. 
+ + Args: + mem_kube_id: The mem_kube_id which corresponds to user_name in the table. + Can be dict or str. If dict, will extract the value. + delete_record_id: The delete_record_id to match. + Can be dict or str. If dict, will extract the value. + deleted_type: If True, performs hard delete (directly deletes records). + If False, performs soft delete (updates status to 'deleted' and sets delete_record_id and delete_time). + + Returns: + int: Number of nodes deleted or updated. + """ + # Handle dict type parameters (extract value if dict) + if isinstance(mem_kube_id, dict): + # Try to get a value from dict, use first value if multiple + mem_kube_id = next(iter(mem_kube_id.values())) if mem_kube_id else None + + if isinstance(delete_record_id, dict): + delete_record_id = next(iter(delete_record_id.values())) if delete_record_id else None + + # Validate required parameters + if not mem_kube_id: + logger.warning("[delete_node_by_mem_cube_id] mem_kube_id is required but not provided") + return 0 + + if not delete_record_id: + logger.warning( + "[delete_node_by_mem_cube_id] delete_record_id is required but not provided" + ) + return 0 + + # Convert to string if needed + mem_kube_id = str(mem_kube_id) if mem_kube_id else None + delete_record_id = str(delete_record_id) if delete_record_id else None + + logger.info( + f"[delete_node_by_mem_cube_id] mem_kube_id={mem_kube_id}, " + f"delete_record_id={delete_record_id}, deleted_type={deleted_type}" + ) + + try: + with self.driver.session(database=self.db_name) as session: + if deleted_type: + # Hard delete: WHERE user_name = mem_kube_id AND delete_record_id = $delete_record_id + query = """ + MATCH (n:Memory) + WHERE n.user_name = $mem_kube_id AND n.delete_record_id = $delete_record_id + DETACH DELETE n + """ + logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {query}") + + result = session.run( + query, mem_kube_id=mem_kube_id, delete_record_id=delete_record_id + ) + summary = result.consume() + deleted_count = summary.counters.nodes_deleted if summary.counters else 0 + + logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes") + return deleted_count + else: + # Soft delete: WHERE user_name = mem_kube_id (only user_name condition) + current_time = datetime.utcnow().isoformat() + + query = """ + MATCH (n:Memory) + WHERE n.user_name = $mem_kube_id + SET n.status = $status, + n.delete_record_id = $delete_record_id, + n.delete_time = $delete_time + RETURN count(n) AS updated_count + """ + logger.info(f"[delete_node_by_mem_cube_id] Soft delete query: {query}") + + result = session.run( + query, + mem_kube_id=mem_kube_id, + status="deleted", + delete_record_id=delete_record_id, + delete_time=current_time, + ) + record = result.single() + updated_count = record["updated_count"] if record else 0 + + logger.info( + f"[delete_node_by_mem_cube_id] Soft deleted (updated) {updated_count} nodes" + ) + return updated_count + + except Exception as e: + logger.error( + f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True + ) + raise + + def recover_memory_by_mem_kube_id( + self, + mem_kube_id: str | None = None, + delete_record_id: str | None = None, + ) -> int: + """ + Recover memory nodes by mem_kube_id (user_name) and delete_record_id. + + This function updates the status to 'activated', and clears delete_record_id and delete_time. + + Args: + mem_kube_id: The mem_kube_id which corresponds to user_name in the table. + delete_record_id: The delete_record_id to match. 
+ + Returns: + int: Number of nodes recovered (updated). + """ + # Validate required parameters + if not mem_kube_id: + logger.warning( + "[recover_memory_by_mem_kube_id] mem_kube_id is required but not provided" + ) + return 0 + + if not delete_record_id: + logger.warning( + "[recover_memory_by_mem_kube_id] delete_record_id is required but not provided" + ) + return 0 + + logger.info( + f"[recover_memory_by_mem_kube_id] mem_kube_id={mem_kube_id}, " + f"delete_record_id={delete_record_id}" + ) + + try: + with self.driver.session(database=self.db_name) as session: + query = """ + MATCH (n:Memory) + WHERE n.user_name = $mem_kube_id AND n.delete_record_id = $delete_record_id + SET n.status = $status, + n.delete_record_id = $delete_record_id_empty, + n.delete_time = $delete_time_empty + RETURN count(n) AS updated_count + """ + logger.info(f"[recover_memory_by_mem_kube_id] Update query: {query}") + + result = session.run( + query, + mem_kube_id=mem_kube_id, + delete_record_id=delete_record_id, + status="activated", + delete_record_id_empty="", + delete_time_empty="", + ) + record = result.single() + updated_count = record["updated_count"] if record else 0 + + logger.info( + f"[recover_memory_by_mem_kube_id] Recovered (updated) {updated_count} nodes" + ) + return updated_count + + except Exception as e: + logger.error( + f"[recover_memory_by_mem_kube_id] Failed to recover nodes: {e}", exc_info=True + ) + raise diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index f2182f6cd..411dbffe5 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -1056,3 +1056,58 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: logger.warning(f"Failed to fetch vector for node {new_node['id']}: {e}") new_node["metadata"]["embedding"] = None return new_node + + def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]: + """Get user names by memory ids. + + Args: + memory_ids: List of memory node IDs to query. + + Returns: + dict[str, str | None]: Dictionary mapping memory_id to user_name. 
+ - Key: memory_id + - Value: user_name if exists, None if memory_id does not exist + Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None} + """ + if not memory_ids: + return {} + + logger.info( + f"[ neo4j_community get_user_names_by_memory_ids] Querying memory_ids {memory_ids}" + ) + + try: + with self.driver.session(database=self.db_name) as session: + # Query to get memory_id and user_name pairs + query = """ + MATCH (n:Memory) + WHERE n.id IN $memory_ids + RETURN n.id AS memory_id, n.user_name AS user_name + """ + logger.info(f"[get_user_names_by_memory_ids] query: {query}") + + result = session.run(query, memory_ids=memory_ids) + result_dict = {} + + # Build result dictionary from query results + for record in result: + memory_id = record["memory_id"] + user_name = record["user_name"] + result_dict[memory_id] = user_name if user_name else None + + # Set None for memory_ids that were not found + for mid in memory_ids: + if mid not in result_dict: + result_dict[mid] = None + + logger.info( + f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, " + f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names" + ) + + return result_dict + except Exception as e: + logger.error( + f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True + ) + raise diff --git a/src/memos/graph_dbs/polardb/__init__.py b/src/memos/graph_dbs/polardb/__init__.py new file mode 100644 index 000000000..69d91c9d7 --- /dev/null +++ b/src/memos/graph_dbs/polardb/__init__.py @@ -0,0 +1,27 @@ +"""PolarDB graph database package using Apache AGE extension.""" + +from memos.graph_dbs.base import BaseGraphDB +from memos.graph_dbs.polardb.connection import ConnectionMixin +from memos.graph_dbs.polardb.edges import EdgeMixin +from memos.graph_dbs.polardb.filters import FilterMixin +from memos.graph_dbs.polardb.maintenance import MaintenanceMixin +from memos.graph_dbs.polardb.nodes import NodeMixin +from memos.graph_dbs.polardb.queries import QueryMixin +from memos.graph_dbs.polardb.schema import SchemaMixin +from memos.graph_dbs.polardb.search import SearchMixin +from memos.graph_dbs.polardb.traversal import TraversalMixin + + +class PolarDBGraphDB( + ConnectionMixin, + SchemaMixin, + NodeMixin, + EdgeMixin, + TraversalMixin, + SearchMixin, + FilterMixin, + QueryMixin, + MaintenanceMixin, + BaseGraphDB, +): + """PolarDB-based graph database using Apache AGE extension.""" diff --git a/src/memos/graph_dbs/polardb/connection.py b/src/memos/graph_dbs/polardb/connection.py new file mode 100644 index 000000000..42e5f082a --- /dev/null +++ b/src/memos/graph_dbs/polardb/connection.py @@ -0,0 +1,333 @@ +import time + +from contextlib import suppress + +from memos.configs.graph_db import PolarDBGraphDBConfig +from memos.dependency import require_python_package +from memos.log import get_logger + + +logger = get_logger(__name__) + + +class ConnectionMixin: + """Mixin class providing PolarDB connection pool management.""" + + @require_python_package( + import_name="psycopg2", + install_command="pip install psycopg2-binary", + install_link="https://pypi.org/project/psycopg2-binary/", + ) + def __init__(self, config: PolarDBGraphDBConfig): + """PolarDB-based implementation using Apache AGE. + + Tenant Modes: + - use_multi_db = True: + Dedicated Database Mode (Multi-Database Multi-Tenant). + Each tenant or logical scope uses a separate PolarDB database. 
+ `db_name` is the specific tenant database. + `user_name` can be None (optional). + + - use_multi_db = False: + Shared Database Multi-Tenant Mode. + All tenants share a single PolarDB database. + `db_name` is the shared database. + `user_name` is required to isolate each tenant's data at the node level. + All node queries will enforce `user_name` in WHERE conditions and store it in metadata, + but it will be removed automatically before returning to external consumers. + """ + import psycopg2 + import psycopg2.pool + + self.config = config + + # Handle both dict and object config + if isinstance(config, dict): + self.db_name = config.get("db_name") + self.user_name = config.get("user_name") + host = config.get("host") + port = config.get("port") + user = config.get("user") + password = config.get("password") + maxconn = config.get("maxconn", 100) # De + else: + self.db_name = config.db_name + self.user_name = config.user_name + host = config.host + port = config.port + user = config.user + password = config.password + maxconn = config.maxconn if hasattr(config, "maxconn") else 100 + """ + # Create connection + self.connection = psycopg2.connect( + host=host, port=port, user=user, password=password, dbname=self.db_name,minconn=10, maxconn=2000 + ) + """ + logger.info(f" db_name: {self.db_name} current maxconn is:'{maxconn}'") + + # Create connection pool + self.connection_pool = psycopg2.pool.ThreadedConnectionPool( + minconn=5, + maxconn=maxconn, + host=host, + port=port, + user=user, + password=password, + dbname=self.db_name, + connect_timeout=60, # Connection timeout in seconds + keepalives_idle=40, # Seconds of inactivity before sending keepalive (should be < server idle timeout) + keepalives_interval=15, # Seconds between keepalive retries + keepalives_count=5, # Number of keepalive retries before considering connection dead + ) + + # Keep a reference to the pool for cleanup + self._pool_closed = False + + """ + # Handle auto_create + # auto_create = config.get("auto_create", False) if isinstance(config, dict) else config.auto_create + # if auto_create: + # self._ensure_database_exists() + + # Create graph and tables + # self.create_graph() + # self.create_edge() + # self._create_graph() + + # Handle embedding_dimension + # embedding_dim = config.get("embedding_dimension", 1024) if isinstance(config,dict) else config.embedding_dimension + # self.create_index(dimensions=embedding_dim) + """ + + def _get_config_value(self, key: str, default=None): + """Safely get config value from either dict or object.""" + if isinstance(self.config, dict): + return self.config.get(key, default) + else: + return getattr(self.config, key, default) + + def _get_connection(self): + """ + Get a connection from the pool. + + This function: + 1. Gets a connection from ThreadedConnectionPool + 2. Checks if connection is closed or unhealthy + 3. Returns healthy connection or retries (max 3 times) + 4. 
Handles connection pool exhaustion gracefully + + Returns: + psycopg2 connection object + + Raises: + RuntimeError: If connection pool is closed or exhausted after retries + """ + logger.info(f" db_name: {self.db_name} pool maxconn is:'{self.connection_pool.maxconn}'") + if self._pool_closed: + raise RuntimeError("Connection pool has been closed") + + max_retries = 500 + import psycopg2.pool + + for attempt in range(max_retries): + conn = None + try: + # Try to get connection from pool + # This may raise PoolError if pool is exhausted + conn = self.connection_pool.getconn() + + # Check if connection is closed + if conn.closed != 0: + # Connection is closed, return it to pool with close flag and try again + logger.warning( + f"[_get_connection] Got closed connection, attempt {attempt + 1}/{max_retries}" + ) + try: + self.connection_pool.putconn(conn, close=True) + except Exception as e: + logger.warning( + f"[_get_connection] Failed to return closed connection to pool: {e}" + ) + with suppress(Exception): + conn.close() + + conn = None + if attempt < max_retries - 1: + # Exponential backoff: 0.1s, 0.2s, 0.4s + """time.sleep(0.1 * (2**attempt))""" + time.sleep(0.003) + continue + else: + raise RuntimeError("Pool returned a closed connection after all retries") + + # Set autocommit for PolarDB compatibility + conn.autocommit = True + + # Test connection health with SELECT 1 + try: + cursor = conn.cursor() + cursor.execute("SELECT 1") + cursor.fetchone() + cursor.close() + except Exception as health_check_error: + # Connection is not usable, return it to pool with close flag and try again + logger.warning( + f"[_get_connection] Connection health check failed (attempt {attempt + 1}/{max_retries}): {health_check_error}" + ) + try: + self.connection_pool.putconn(conn, close=True) + except Exception as putconn_error: + logger.warning( + f"[_get_connection] Failed to return unhealthy connection to pool: {putconn_error}" + ) + with suppress(Exception): + conn.close() + + conn = None + if attempt < max_retries - 1: + # Exponential backoff: 0.1s, 0.2s, 0.4s + """time.sleep(0.1 * (2**attempt))""" + time.sleep(0.003) + continue + else: + raise RuntimeError( + f"Failed to get a healthy connection from pool after {max_retries} attempts: {health_check_error}" + ) from health_check_error + + # Connection is healthy, return it + return conn + + except psycopg2.pool.PoolError as pool_error: + # Pool exhausted or other pool-related error + # Don't retry immediately for pool exhaustion - it's unlikely to resolve quickly + error_msg = str(pool_error).lower() + if "exhausted" in error_msg or "pool" in error_msg: + # Log pool status for debugging + try: + # Try to get pool stats if available + pool_info = f"Pool config: minconn={self.connection_pool.minconn}, maxconn={self.connection_pool.maxconn}" + logger.error( + f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries}). {pool_info}" + ) + except Exception: + logger.error( + f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries})" + ) + + # For pool exhaustion, wait longer before retry (connections may be returned) + if attempt < max_retries - 1: + # Longer backoff for pool exhaustion: 0.5s, 1.0s, 2.0s + wait_time = 0.5 * (2**attempt) + logger.info(f"[_get_connection] Waiting {wait_time}s before retry...") + """time.sleep(wait_time)""" + time.sleep(0.003) + continue + else: + raise RuntimeError( + f"Connection pool exhausted after {max_retries} attempts. 
" + f"This usually means connections are not being returned to the pool. " + f"Check for connection leaks in your code." + ) from pool_error + else: + # Other pool errors - retry with normal backoff + if attempt < max_retries - 1: + """time.sleep(0.1 * (2**attempt))""" + time.sleep(0.003) + continue + else: + raise RuntimeError( + f"Failed to get connection from pool: {pool_error}" + ) from pool_error + + except Exception as e: + # Other exceptions (not pool-related) + # Only try to return connection if we actually got one + # If getconn() failed (e.g., pool exhausted), conn will be None + if conn is not None: + try: + # Return connection to pool if it's valid + self.connection_pool.putconn(conn, close=True) + except Exception as putconn_error: + logger.warning( + f"[_get_connection] Failed to return connection after error: {putconn_error}" + ) + with suppress(Exception): + conn.close() + + if attempt >= max_retries - 1: + raise RuntimeError(f"Failed to get a valid connection from pool: {e}") from e + else: + # Exponential backoff: 0.1s, 0.2s, 0.4s + """time.sleep(0.1 * (2**attempt))""" + time.sleep(0.003) + continue + + # Should never reach here, but just in case + raise RuntimeError("Failed to get connection after all retries") + + def _return_connection(self, connection): + """ + Return a connection to the pool. + + This function safely returns a connection to the pool, handling: + - Closed connections (close them instead of returning) + - Pool closed state (close connection directly) + - None connections (no-op) + - putconn() failures (close connection as fallback) + + Args: + connection: psycopg2 connection object or None + """ + if self._pool_closed: + # Pool is closed, just close the connection if it exists + if connection: + try: + connection.close() + logger.debug("[_return_connection] Closed connection (pool is closed)") + except Exception as e: + logger.warning( + f"[_return_connection] Failed to close connection after pool closed: {e}" + ) + return + + if not connection: + # No connection to return - this is normal if _get_connection() failed + return + + try: + # Check if connection is closed + if hasattr(connection, "closed") and connection.closed != 0: + # Connection is closed, just close it explicitly and don't return to pool + logger.debug( + "[_return_connection] Connection is closed, closing it instead of returning to pool" + ) + try: + connection.close() + except Exception as e: + logger.warning(f"[_return_connection] Failed to close closed connection: {e}") + return + + # Connection is valid, return to pool + self.connection_pool.putconn(connection) + logger.debug("[_return_connection] Successfully returned connection to pool") + except Exception as e: + # If putconn fails, try to close the connection + # This prevents connection leaks if putconn() fails + logger.error( + f"[_return_connection] Failed to return connection to pool: {e}", exc_info=True + ) + try: + connection.close() + logger.debug( + "[_return_connection] Closed connection as fallback after putconn failure" + ) + except Exception as close_error: + logger.warning( + f"[_return_connection] Failed to close connection after putconn error: {close_error}" + ) + + def __del__(self): + """Close database connection when object is destroyed.""" + if hasattr(self, "connection") and self.connection: + self.connection.close() diff --git a/src/memos/graph_dbs/polardb/edges.py b/src/memos/graph_dbs/polardb/edges.py new file mode 100644 index 000000000..356685ca0 --- /dev/null +++ b/src/memos/graph_dbs/polardb/edges.py 
@@ -0,0 +1,267 @@ +import json +import time + +from memos.log import get_logger +from memos.utils import timed + + +logger = get_logger(__name__) + + +class EdgeMixin: + """Mixin for edge (relationship) operations.""" + + @timed + def create_edge(self): + """Create all valid edge types if they do not exist""" + + valid_rel_types = {"AGGREGATE_TO", "FOLLOWS", "INFERS", "MERGED_TO", "RELATE_TO", "PARENT"} + + for label_name in valid_rel_types: + conn = None + logger.info(f"Creating elabel: {label_name}") + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") + logger.info(f"Successfully created elabel: {label_name}") + except Exception as e: + if "already exists" in str(e): + logger.info(f"Label '{label_name}' already exists, skipping.") + else: + logger.warning(f"Failed to create label {label_name}: {e}") + logger.error(f"Failed to create elabel '{label_name}': {e}", exc_info=True) + finally: + self._return_connection(conn) + + @timed + def add_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: + logger.info( + f"polardb [add_edge] source_id: {source_id}, target_id: {target_id}, type: {type},user_name:{user_name}" + ) + + start_time = time.time() + if not source_id or not target_id: + logger.warning(f"Edge '{source_id}' and '{target_id}' are both None") + raise ValueError("[add_edge] source_id and target_id must be provided") + + source_exists = self.get_node(source_id) is not None + target_exists = self.get_node(target_id) is not None + + if not source_exists or not target_exists: + logger.warning( + "[add_edge] Source %s or target %s does not exist.", source_exists, target_exists + ) + raise ValueError("[add_edge] source_id and target_id must be provided") + + properties = {} + if user_name is not None: + properties["user_name"] = user_name + query = f""" + INSERT INTO {self.db_name}_graph."Edges"(source_id, target_id, edge_type, properties) + SELECT + '{source_id}', + '{target_id}', + '{type}', + jsonb_build_object('user_name', '{user_name}') + WHERE NOT EXISTS ( + SELECT 1 FROM {self.db_name}_graph."Edges" + WHERE source_id = '{source_id}' + AND target_id = '{target_id}' + AND edge_type = '{type}' + ); + """ + logger.info(f"polardb [add_edge] query: {query}, properties: {json.dumps(properties)}") + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, (source_id, target_id, type, json.dumps(properties))) + logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}") + + elapsed_time = time.time() - start_time + logger.info(f" polardb [add_edge] insert completed time in {elapsed_time:.2f}s") + except Exception as e: + logger.error(f"Failed to insert edge: {e}", exc_info=True) + raise + finally: + self._return_connection(conn) + + @timed + def delete_edge(self, source_id: str, target_id: str, type: str) -> None: + """ + Delete a specific edge between two nodes. + Args: + source_id: ID of the source node. + target_id: ID of the target node. + type: Relationship type to remove. 
+ """ + query = f""" + DELETE FROM "{self.db_name}_graph"."Edges" + WHERE source_id = %s AND target_id = %s AND edge_type = %s + """ + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, (source_id, target_id, type)) + logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") + finally: + self._return_connection(conn) + + @timed + def edge_exists( + self, + source_id: str, + target_id: str, + type: str = "ANY", + direction: str = "OUTGOING", + user_name: str | None = None, + ) -> bool: + """ + Check if an edge exists between two nodes. + Args: + source_id: ID of the source node. + target_id: ID of the target node. + type: Relationship type. Use "ANY" to match any relationship type. + direction: Direction of the edge. + Use "OUTGOING" (default), "INCOMING", or "ANY". + user_name (str, optional): User name for filtering in non-multi-db mode + Returns: + True if the edge exists, otherwise False. + """ + + # Prepare the relationship pattern + user_name = user_name if user_name else self.config.user_name + + # Prepare the match pattern with direction + if direction == "OUTGOING": + pattern = "(a:Memory)-[r]->(b:Memory)" + elif direction == "INCOMING": + pattern = "(a:Memory)<-[r]-(b:Memory)" + elif direction == "ANY": + pattern = "(a:Memory)-[r]-(b:Memory)" + else: + raise ValueError( + f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." + ) + query = f"SELECT * FROM cypher('{self.db_name}_graph', $$" + query += f"\nMATCH {pattern}" + query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" + query += f"\nAND a.id = '{source_id}' AND b.id = '{target_id}'" + if type != "ANY": + query += f"\n AND type(r) = '{type}'" + + query += "\nRETURN r" + query += "\n$$) AS (r agtype)" + + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + result = cursor.fetchone() + return result is not None and result[0] is not None + finally: + self._return_connection(conn) + + @timed + def get_edges( + self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None + ) -> list[dict[str, str]]: + """ + Get edges connected to a node, with optional type and direction filter. + + Args: + id: Node ID to retrieve edges for. + type: Relationship type to match, or 'ANY' to match all. + direction: 'OUTGOING', 'INCOMING', or 'ANY'. + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + List of edges: + [ + {"from": "source_id", "to": "target_id", "type": "RELATE"}, + ... + ] + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + if direction == "OUTGOING": + pattern = "(a:Memory)-[r]->(b:Memory)" + where_clause = f"a.id = '{id}'" + elif direction == "INCOMING": + pattern = "(a:Memory)<-[r]-(b:Memory)" + where_clause = f"a.id = '{id}'" + elif direction == "ANY": + pattern = "(a:Memory)-[r]-(b:Memory)" + where_clause = f"a.id = '{id}' OR b.id = '{id}'" + else: + raise ValueError("Invalid direction. 
Must be 'OUTGOING', 'INCOMING', or 'ANY'.") + + # Add type filter + if type != "ANY": + where_clause += f" AND type(r) = '{type}'" + + # Add user filter + where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'" + + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH {pattern} + WHERE {where_clause} + RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type + $$) AS (from_id agtype, to_id agtype, edge_type agtype) + """ + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() + + edges = [] + for row in results: + # Extract and clean from_id + from_id_raw = row[0].value if hasattr(row[0], "value") else row[0] + if ( + isinstance(from_id_raw, str) + and from_id_raw.startswith('"') + and from_id_raw.endswith('"') + ): + from_id = from_id_raw[1:-1] + else: + from_id = str(from_id_raw) + + # Extract and clean to_id + to_id_raw = row[1].value if hasattr(row[1], "value") else row[1] + if ( + isinstance(to_id_raw, str) + and to_id_raw.startswith('"') + and to_id_raw.endswith('"') + ): + to_id = to_id_raw[1:-1] + else: + to_id = str(to_id_raw) + + # Extract and clean edge_type + edge_type_raw = row[2].value if hasattr(row[2], "value") else row[2] + if ( + isinstance(edge_type_raw, str) + and edge_type_raw.startswith('"') + and edge_type_raw.endswith('"') + ): + edge_type = edge_type_raw[1:-1] + else: + edge_type = str(edge_type_raw) + + edges.append({"from": from_id, "to": to_id, "type": edge_type}) + return edges + + except Exception as e: + logger.error(f"Failed to get edges: {e}", exc_info=True) + return [] + finally: + self._return_connection(conn) diff --git a/src/memos/graph_dbs/polardb/filters.py b/src/memos/graph_dbs/polardb/filters.py new file mode 100644 index 000000000..a3f566202 --- /dev/null +++ b/src/memos/graph_dbs/polardb/filters.py @@ -0,0 +1,578 @@ +import json + +from typing import Any, Literal + +from memos.log import get_logger + + +logger = get_logger(__name__) + + +class FilterMixin: + """Mixin for filter condition building (WHERE clause builders).""" + + def _build_user_name_and_kb_ids_conditions( + self, + user_name: str | None, + knowledgebase_ids: list | None, + default_user_name: str | None = None, + mode: Literal["cypher", "sql"] = "sql", + ) -> list[str]: + """ + Build user_name and knowledgebase_ids conditions. 
+ + Args: + user_name: User name for filtering + knowledgebase_ids: List of knowledgebase IDs + default_user_name: Default user name from config if user_name is None + mode: 'cypher' for Cypher property access, 'sql' for AgType SQL access + + Returns: + List of condition strings (will be joined with OR) + """ + user_name_conditions = [] + effective_user_name = user_name if user_name else default_user_name + + def _fmt(value: str) -> str: + if mode == "cypher": + escaped = value.replace("'", "''") + return f"n.user_name = '{escaped}'" + return ( + f"ag_catalog.agtype_access_operator(properties::text::agtype, " + f"'\"user_name\"'::agtype) = '\"{value}\"'::agtype" + ) + + if effective_user_name: + user_name_conditions.append(_fmt(effective_user_name)) + + if knowledgebase_ids and isinstance(knowledgebase_ids, list) and len(knowledgebase_ids) > 0: + for kb_id in knowledgebase_ids: + if isinstance(kb_id, str): + user_name_conditions.append(_fmt(kb_id)) + + return user_name_conditions + + def _build_user_name_and_kb_ids_conditions_cypher( + self, user_name, knowledgebase_ids, default_user_name=None + ): + return self._build_user_name_and_kb_ids_conditions( + user_name, knowledgebase_ids, default_user_name, mode="cypher" + ) + + def _build_user_name_and_kb_ids_conditions_sql( + self, user_name, knowledgebase_ids, default_user_name=None + ): + return self._build_user_name_and_kb_ids_conditions( + user_name, knowledgebase_ids, default_user_name, mode="sql" + ) + + def _build_filter_conditions( + self, + filter: dict | None, + mode: Literal["cypher", "sql"] = "sql", + ) -> str | list[str]: + """ + Build filter conditions for Cypher or SQL queries. + + Args: + filter: Filter dictionary with "or" or "and" logic + mode: "cypher" for Cypher queries, "sql" for SQL queries + + Returns: + For mode="cypher": Filter WHERE clause string with " AND " prefix (empty string if no filter) + For mode="sql": List of filter WHERE clause strings (empty list if no filter) + """ + is_cypher = mode == "cypher" + filter = self.parse_filter(filter) + + if not filter: + return "" if is_cypher else [] + + # --- Dialect helpers --- + + def escape_string(value: str) -> str: + if is_cypher: + # Backslash escape for single quotes inside $$ dollar-quoted strings + return value.replace("'", "\\'") + else: + return value.replace("'", "''") + + def prop_direct(key: str) -> str: + """Property access expression for a direct (top-level) key.""" + if is_cypher: + return f"n.{key}" + else: + return f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"{key}\"'::agtype)" + + def prop_nested(info_field: str) -> str: + """Property access expression for a nested info.field key.""" + if is_cypher: + return f"n.info.{info_field}" + else: + return f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties::text::agtype, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype])" + + def prop_ref(key: str) -> str: + """Return the appropriate property access expression for a key (direct or nested).""" + if key.startswith("info."): + return prop_nested(key[5:]) + return prop_direct(key) + + def fmt_str_val(escaped_value: str) -> str: + """Format an escaped string value as a literal.""" + if is_cypher: + return f"'{escaped_value}'" + else: + return f"'\"{escaped_value}\"'::agtype" + + def fmt_non_str_val(value: Any) -> str: + """Format a non-string value as a literal.""" + if is_cypher: + return str(value) + else: + value_json = json.dumps(value) + return f"ag_catalog.agtype_in('{value_json}')" + + def 
fmt_array_eq_single_str(escaped_value: str) -> str: + """Format an array-equality check for a single string value: field = ['val'].""" + if is_cypher: + return f"['{escaped_value}']" + else: + return f"'[\"{escaped_value}\"]'::agtype" + + def fmt_array_eq_list(items: list, escape_fn) -> str: + """Format an array-equality check for a list of values.""" + if is_cypher: + escaped_items = [f"'{escape_fn(str(item))}'" for item in items] + return "[" + ", ".join(escaped_items) + "]" + else: + escaped_items = [escape_fn(str(item)) for item in items] + json_array = json.dumps(escaped_items) + return f"'{json_array}'::agtype" + + def fmt_array_eq_non_str(value: Any) -> str: + """Format an array-equality check for a single non-string value: field = [val].""" + if is_cypher: + return f"[{value}]" + else: + return f"'[{value}]'::agtype" + + def fmt_contains_str(escaped_value: str, prop_expr: str) -> str: + """Format a 'contains' check: array field contains a string value.""" + if is_cypher: + return f"'{escaped_value}' IN {prop_expr}" + else: + return f"{prop_expr} @> '[\"{escaped_value}\"]'::agtype" + + def fmt_contains_non_str(value: Any, prop_expr: str) -> str: + """Format a 'contains' check: array field contains a non-string value.""" + if is_cypher: + return f"{value} IN {prop_expr}" + else: + escaped_value = str(value).replace("'", "''") + return f"{prop_expr} @> '[\"{escaped_value}\"]'::agtype" + + def fmt_like(escaped_value: str, prop_expr: str) -> str: + """Format a 'like' (fuzzy match) check.""" + if is_cypher: + return f"{prop_expr} CONTAINS '{escaped_value}'" + else: + return f"{prop_expr}::text LIKE '%{escaped_value}%'" + + def fmt_datetime_cmp(prop_expr: str, cmp_op: str, escaped_value: str) -> str: + """Format a datetime comparison.""" + if is_cypher: + return f"{prop_expr}::timestamp {cmp_op} '{escaped_value}'::timestamp" + else: + return f"TRIM(BOTH '\"' FROM {prop_expr}::text)::timestamp {cmp_op} '{escaped_value}'::timestamp" + + def fmt_in_scalar_eq_str(escaped_value: str, prop_expr: str) -> str: + """Format scalar equality for 'in' operator with a string item.""" + return f"{prop_expr} = {fmt_str_val(escaped_value)}" + + def fmt_in_scalar_eq_non_str(item: Any, prop_expr: str) -> str: + """Format scalar equality for 'in' operator with a non-string item.""" + if is_cypher: + return f"{prop_expr} = {item}" + else: + return f"{prop_expr} = {item}::agtype" + + def fmt_in_array_contains_str(escaped_value: str, prop_expr: str) -> str: + """Format array-contains for 'in' operator with a string item.""" + if is_cypher: + return f"'{escaped_value}' IN {prop_expr}" + else: + return f"{prop_expr} @> '[\"{escaped_value}\"]'::agtype" + + def fmt_in_array_contains_non_str(item: Any, prop_expr: str) -> str: + """Format array-contains for 'in' operator with a non-string item.""" + if is_cypher: + return f"{item} IN {prop_expr}" + else: + escaped_value = str(item).replace("'", "''") + return f"{prop_expr} @> '[\"{escaped_value}\"]'::agtype" + + def escape_like_value(value: str) -> str: + """Escape a value for use in like/CONTAINS. 
SQL needs extra LIKE-char escaping.""" + escaped = escape_string(value) + if not is_cypher: + escaped = escaped.replace("%", "\\%").replace("_", "\\_") + return escaped + + def fmt_scalar_in_clause(items: list, prop_expr: str) -> str: + """Format a scalar IN clause for multiple values (cypher only has this path).""" + if is_cypher: + escaped_items = [ + f"'{escape_string(str(item))}'" if isinstance(item, str) else str(item) + for item in items + ] + array_str = "[" + ", ".join(escaped_items) + "]" + return f"{prop_expr} IN {array_str}" + else: + # SQL mode: use OR equality conditions + or_parts = [] + for item in items: + if isinstance(item, str): + escaped_value = escape_string(item) + or_parts.append(f"{prop_expr} = {fmt_str_val(escaped_value)}") + else: + or_parts.append(f"{prop_expr} = {item}::agtype") + return f"({' OR '.join(or_parts)})" + + # --- Main condition builder --- + + def build_filter_condition(condition_dict: dict) -> str: + """Build a WHERE condition for a single filter item.""" + condition_parts = [] + for key, value in condition_dict.items(): + is_info = key.startswith("info.") + info_field = key[5:] if is_info else None + prop_expr = prop_ref(key) + + # Check if value is a dict with comparison operators + if isinstance(value, dict): + for op, op_value in value.items(): + if op in ("gt", "lt", "gte", "lte"): + cmp_op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="} + cmp_op = cmp_op_map[op] + + # Determine if this is a datetime field + field_name = info_field if is_info else key + is_dt = field_name in ( + "created_at", + "updated_at", + ) or field_name.endswith("_at") + + if isinstance(op_value, str): + escaped_value = escape_string(op_value) + if is_dt: + condition_parts.append( + fmt_datetime_cmp(prop_expr, cmp_op, escaped_value) + ) + else: + condition_parts.append( + f"{prop_expr} {cmp_op} {fmt_str_val(escaped_value)}" + ) + else: + condition_parts.append( + f"{prop_expr} {cmp_op} {fmt_non_str_val(op_value)}" + ) + + elif op == "=": + # Equality operator + field_name = info_field if is_info else key + is_array_field = field_name in ("tags", "sources") + + if isinstance(op_value, str): + escaped_value = escape_string(op_value) + if is_array_field: + condition_parts.append( + f"{prop_expr} = {fmt_array_eq_single_str(escaped_value)}" + ) + else: + condition_parts.append( + f"{prop_expr} = {fmt_str_val(escaped_value)}" + ) + elif isinstance(op_value, list): + if is_array_field: + condition_parts.append( + f"{prop_expr} = {fmt_array_eq_list(op_value, escape_string)}" + ) + else: + if is_cypher: + condition_parts.append(f"{prop_expr} = {op_value}") + elif is_info: + # Info nested field: use ::agtype cast + condition_parts.append(f"{prop_expr} = {op_value}::agtype") + else: + # Direct field: convert to JSON string and then to agtype + value_json = json.dumps(op_value) + condition_parts.append( + f"{prop_expr} = ag_catalog.agtype_in('{value_json}')" + ) + else: + if is_array_field: + condition_parts.append( + f"{prop_expr} = {fmt_array_eq_non_str(op_value)}" + ) + else: + condition_parts.append( + f"{prop_expr} = {fmt_non_str_val(op_value)}" + ) + + elif op == "contains": + if isinstance(op_value, str): + escaped_value = escape_string(str(op_value)) + condition_parts.append(fmt_contains_str(escaped_value, prop_expr)) + else: + condition_parts.append(fmt_contains_non_str(op_value, prop_expr)) + + elif op == "in": + if not isinstance(op_value, list): + raise ValueError( + f"in operator only supports array format. 
" + f"Use {{'{key}': {{'in': ['{op_value}']}}}} instead of {{'{key}': {{'in': '{op_value}'}}}}" + ) + + field_name = info_field if is_info else key + is_arr = field_name in ("file_ids", "tags", "sources") + + if len(op_value) == 0: + condition_parts.append("false") + elif len(op_value) == 1: + item = op_value[0] + if is_arr: + if isinstance(item, str): + escaped_value = escape_string(str(item)) + condition_parts.append( + fmt_in_array_contains_str(escaped_value, prop_expr) + ) + else: + condition_parts.append( + fmt_in_array_contains_non_str(item, prop_expr) + ) + else: + if isinstance(item, str): + escaped_value = escape_string(item) + condition_parts.append( + fmt_in_scalar_eq_str(escaped_value, prop_expr) + ) + else: + condition_parts.append( + fmt_in_scalar_eq_non_str(item, prop_expr) + ) + else: + if is_arr: + # For array fields, use OR conditions with contains + or_conditions = [] + for item in op_value: + if isinstance(item, str): + escaped_value = escape_string(str(item)) + or_conditions.append( + fmt_in_array_contains_str(escaped_value, prop_expr) + ) + else: + or_conditions.append( + fmt_in_array_contains_non_str(item, prop_expr) + ) + if or_conditions: + condition_parts.append(f"({' OR '.join(or_conditions)})") + else: + # For scalar fields + if is_cypher: + # Cypher uses IN clause with array literal + condition_parts.append( + fmt_scalar_in_clause(op_value, prop_expr) + ) + else: + # SQL uses OR equality conditions + or_conditions = [] + for item in op_value: + if isinstance(item, str): + escaped_value = escape_string(item) + or_conditions.append( + fmt_in_scalar_eq_str(escaped_value, prop_expr) + ) + else: + or_conditions.append( + fmt_in_scalar_eq_non_str(item, prop_expr) + ) + if or_conditions: + condition_parts.append( + f"({' OR '.join(or_conditions)})" + ) + + elif op == "like": + if isinstance(op_value, str): + escaped_value = escape_like_value(op_value) + condition_parts.append(fmt_like(escaped_value, prop_expr)) + else: + if is_cypher: + condition_parts.append(f"{prop_expr} CONTAINS {op_value}") + else: + condition_parts.append(f"{prop_expr}::text LIKE '%{op_value}%'") + + # Simple equality (value is not a dict) + elif is_info: + if isinstance(value, str): + escaped_value = escape_string(value) + condition_parts.append(f"{prop_expr} = {fmt_str_val(escaped_value)}") + else: + condition_parts.append(f"{prop_expr} = {fmt_non_str_val(value)}") + else: + if isinstance(value, str): + escaped_value = escape_string(value) + condition_parts.append(f"{prop_expr} = {fmt_str_val(escaped_value)}") + else: + condition_parts.append(f"{prop_expr} = {fmt_non_str_val(value)}") + return " AND ".join(condition_parts) + + # --- Assemble final result based on filter structure and mode --- + + if is_cypher: + filter_where_clause = "" + if isinstance(filter, dict): + if "or" in filter: + or_conditions = [] + for condition in filter["or"]: + if isinstance(condition, dict): + condition_str = build_filter_condition(condition) + if condition_str: + or_conditions.append(f"({condition_str})") + if or_conditions: + filter_where_clause = " AND " + f"({' OR '.join(or_conditions)})" + elif "and" in filter: + and_conditions = [] + for condition in filter["and"]: + if isinstance(condition, dict): + condition_str = build_filter_condition(condition) + if condition_str: + and_conditions.append(f"({condition_str})") + if and_conditions: + filter_where_clause = " AND " + " AND ".join(and_conditions) + else: + condition_str = build_filter_condition(filter) + if condition_str: + filter_where_clause = " AND " + 
condition_str + return filter_where_clause + else: + filter_conditions: list[str] = [] + if isinstance(filter, dict): + if "or" in filter: + or_conditions = [] + for condition in filter["or"]: + if isinstance(condition, dict): + condition_str = build_filter_condition(condition) + if condition_str: + or_conditions.append(f"({condition_str})") + if or_conditions: + filter_conditions.append(f"({' OR '.join(or_conditions)})") + elif "and" in filter: + for condition in filter["and"]: + if isinstance(condition, dict): + condition_str = build_filter_condition(condition) + if condition_str: + filter_conditions.append(f"({condition_str})") + else: + condition_str = build_filter_condition(filter) + if condition_str: + filter_conditions.append(condition_str) + return filter_conditions + + def _build_filter_conditions_cypher( + self, + filter: dict | None, + ) -> str: + """ + Build filter conditions for Cypher queries. + + Args: + filter: Filter dictionary with "or" or "and" logic + + Returns: + Filter WHERE clause string (empty string if no filter) + """ + return self._build_filter_conditions(filter, mode="cypher") + + def _build_filter_conditions_sql( + self, + filter: dict | None, + ) -> list[str]: + """ + Build filter conditions for SQL queries. + + Args: + filter: Filter dictionary with "or" or "and" logic + + Returns: + List of filter WHERE clause strings (empty list if no filter) + """ + return self._build_filter_conditions(filter, mode="sql") + + def parse_filter( + self, + filter_dict: dict | None = None, + ): + if filter_dict is None: + return None + full_fields = { + "id", + "key", + "tags", + "type", + "usage", + "memory", + "status", + "sources", + "user_id", + "graph_id", + "user_name", + "background", + "confidence", + "created_at", + "session_id", + "updated_at", + "memory_type", + "node_type", + "info", + "source", + "file_ids", + } + + def process_condition(condition): + if not isinstance(condition, dict): + return condition + + new_condition = {} + + for key, value in condition.items(): + if key.lower() in ["or", "and"]: + if isinstance(value, list): + processed_items = [] + for item in value: + if isinstance(item, dict): + processed_item = {} + for item_key, item_value in item.items(): + if item_key not in full_fields and not item_key.startswith( + "info." 
+ ): + new_item_key = f"info.{item_key}" + else: + new_item_key = item_key + processed_item[new_item_key] = item_value + processed_items.append(processed_item) + else: + processed_items.append(item) + new_condition[key] = processed_items + else: + new_condition[key] = value + else: + if key not in full_fields and not key.startswith("info."): + new_key = f"info.{key}" + else: + new_key = key + + new_condition[new_key] = value + + return new_condition + + return process_condition(filter_dict) diff --git a/src/memos/graph_dbs/polardb/helpers.py b/src/memos/graph_dbs/polardb/helpers.py new file mode 100644 index 000000000..c8dd2b844 --- /dev/null +++ b/src/memos/graph_dbs/polardb/helpers.py @@ -0,0 +1,13 @@ +"""Module-level utility functions for PolarDB graph database.""" + +import random + + +def generate_vector(dim=1024, low=-0.2, high=0.2): + """Generate a random vector for testing purposes.""" + return [round(random.uniform(low, high), 6) for _ in range(dim)] + + +def escape_sql_string(value: str) -> str: + """Escape single quotes in SQL string.""" + return value.replace("'", "''") diff --git a/src/memos/graph_dbs/polardb/maintenance.py b/src/memos/graph_dbs/polardb/maintenance.py new file mode 100644 index 000000000..1bfa468d9 --- /dev/null +++ b/src/memos/graph_dbs/polardb/maintenance.py @@ -0,0 +1,769 @@ +import copy +import json +import time + +from typing import Any + +from memos.graph_dbs.utils import compose_node as _compose_node +from memos.graph_dbs.utils import prepare_node_metadata as _prepare_node_metadata +from memos.log import get_logger +from memos.utils import timed + + +logger = get_logger(__name__) + + +class MaintenanceMixin: + """Mixin for maintenance operations (import/export, clear, cleanup).""" + + @timed + def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: + """ + Import the entire graph from a serialized dictionary. + + Args: + data: A dictionary containing all nodes and edges to be loaded. + user_name (str, optional): User name for filtering in non-multi-db mode + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + # Import nodes + for node in data.get("nodes", []): + try: + id, memory, metadata = _compose_node(node) + metadata["user_name"] = user_name + metadata = _prepare_node_metadata(metadata) + metadata.update({"id": id, "memory": memory}) + + # Use add_node to insert node + self.add_node(id, memory, metadata) + + except Exception as e: + logger.error(f"Fail to load node: {node}, error: {e}") + + # Import edges + for edge in data.get("edges", []): + try: + source_id, target_id = edge["source"], edge["target"] + edge_type = edge["type"] + + # Use add_edge to insert edge + self.add_edge(source_id, target_id, edge_type, user_name) + + except Exception as e: + logger.error(f"Fail to load edge: {edge}, error: {e}") + + @timed + def export_graph( + self, + include_embedding: bool = False, + user_name: str | None = None, + user_id: str | None = None, + page: int | None = None, + page_size: int | None = None, + filter: dict | None = None, + **kwargs, + ) -> dict[str, Any]: + """ + Export all graph nodes and edges in a structured form. + Args: + include_embedding (bool): Whether to include the large embedding field. + user_name (str, optional): User name for filtering in non-multi-db mode + user_id (str, optional): User ID for filtering + page (int, optional): Page number (starts from 1). If None, exports all data without pagination. + page_size (int, optional): Number of items per page. 
If None, exports all data without pagination. + filter (dict, optional): Filter dictionary for metadata filtering. Supports "and", "or" logic and operators: + - "=": equality + - "in": value in list + - "contains": array contains value + - "gt", "lt", "gte", "lte": comparison operators + - "like": fuzzy matching + Example: {"and": [{"created_at": {"gte": "2025-01-01"}}, {"tags": {"contains": "AI"}}]} + + Returns: + { + "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ], + "edges": [ { "source": ..., "target": ..., "type": ... }, ... ], + "total_nodes": int, # Total number of nodes matching the filter criteria + "total_edges": int, # Total number of edges matching the filter criteria + } + """ + logger.info( + f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}" + ) + user_id = user_id if user_id else self._get_config_value("user_id") + + # Initialize total counts + total_nodes = 0 + total_edges = 0 + + # Determine if pagination is needed + use_pagination = page is not None and page_size is not None + + # Validate pagination parameters if pagination is enabled + if use_pagination: + if page < 1: + page = 1 + if page_size < 1: + page_size = 10 + offset = (page - 1) * page_size + else: + offset = None + + conn = None + try: + conn = self._get_connection() + # Build WHERE conditions + where_conditions = [] + if user_name: + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" + ) + if user_id: + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype" + ) + + # Build filter conditions using common method + filter_conditions = self._build_filter_conditions_sql(filter) + logger.info(f"[export_graph] filter_conditions: {filter_conditions}") + if filter_conditions: + where_conditions.extend(filter_conditions) + + where_clause = "" + if where_conditions: + where_clause = f"WHERE {' AND '.join(where_conditions)}" + + # Get total count of nodes before pagination + count_node_query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + {where_clause} + """ + logger.info(f"[export_graph nodes count] Query: {count_node_query}") + with conn.cursor() as cursor: + cursor.execute(count_node_query) + total_nodes = cursor.fetchone()[0] + + # Export nodes + # Build pagination clause if needed + pagination_clause = "" + if use_pagination: + pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + + if include_embedding: + node_query = f""" + SELECT id, properties, embedding + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY ag_catalog.agtype_access_operator(properties::text::agtype, '"created_at"'::agtype) DESC NULLS LAST, + id DESC + {pagination_clause} + """ + else: + node_query = f""" + SELECT id, properties + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY ag_catalog.agtype_access_operator(properties::text::agtype, '"created_at"'::agtype) DESC NULLS LAST, + id DESC + {pagination_clause} + """ + logger.info(f"[export_graph nodes] Query: {node_query}") + with conn.cursor() as cursor: + cursor.execute(node_query) + node_results = cursor.fetchall() + nodes = [] + + for row in node_results: + if include_embedding: + """row is (id, properties, embedding)""" + _, properties_json, embedding_json = row + else: + """row is (id, properties)""" + _, properties_json = row + embedding_json 
= None + + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except json.JSONDecodeError: + properties = {} + else: + properties = properties_json if properties_json else {} + + # Remove embedding field if include_embedding is False + if not include_embedding: + properties.pop("embedding", None) + elif include_embedding and embedding_json is not None: + properties["embedding"] = embedding_json + + nodes.append(self._parse_node(properties)) + + except Exception as e: + logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True) + raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e + finally: + self._return_connection(conn) + + conn = None + try: + conn = self._get_connection() + # Build Cypher WHERE conditions for edges + cypher_where_conditions = [] + if user_name: + cypher_where_conditions.append(f"a.user_name = '{user_name}'") + cypher_where_conditions.append(f"b.user_name = '{user_name}'") + if user_id: + cypher_where_conditions.append(f"a.user_id = '{user_id}'") + cypher_where_conditions.append(f"b.user_id = '{user_id}'") + + # Build filter conditions for edges (apply to both source and target nodes) + filter_where_clause = self._build_filter_conditions_cypher(filter) + logger.info(f"[export_graph edges] filter_where_clause: {filter_where_clause}") + if filter_where_clause: + # _build_filter_conditions_cypher returns a string that starts with " AND " if filter exists + # Remove the leading " AND " and replace n. with a. for source node and b. for target node + filter_clause = filter_where_clause.strip() + if filter_clause.startswith("AND "): + filter_clause = filter_clause[4:].strip() + # Replace n. with a. for source node and create a copy for target node + source_filter = filter_clause.replace("n.", "a.") + target_filter = filter_clause.replace("n.", "b.") + # Combine source and target filters with AND + combined_filter = f"({source_filter}) AND ({target_filter})" + cypher_where_conditions.append(combined_filter) + + cypher_where_clause = "" + if cypher_where_conditions: + cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}" + + # Get total count of edges before pagination + count_edge_query = f""" + SELECT COUNT(*) + FROM ( + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (a:Memory)-[r]->(b:Memory) + {cypher_where_clause} + RETURN a.id AS source, b.id AS target, type(r) as edge + $$) AS (source agtype, target agtype, edge agtype) + ) AS edges + """ + logger.info(f"[export_graph edges count] Query: {count_edge_query}") + with conn.cursor() as cursor: + cursor.execute(count_edge_query) + total_edges = cursor.fetchone()[0] + + # Export edges using cypher query + # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery + # Build pagination clause if needed + edge_pagination_clause = "" + if use_pagination: + edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + + edge_query = f""" + SELECT source, target, edge FROM ( + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (a:Memory)-[r]->(b:Memory) + {cypher_where_clause} + RETURN a.id AS source, b.id AS target, type(r) as edge + ORDER BY COALESCE(a.created_at, '1970-01-01T00:00:00') DESC, + COALESCE(b.created_at, '1970-01-01T00:00:00') DESC, + a.id DESC, b.id DESC + $$) AS (source agtype, target agtype, edge agtype) + ) AS edges + {edge_pagination_clause} + """ + logger.info(f"[export_graph edges] Query: {edge_query}") + with conn.cursor() as cursor: + 
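                    # Illustrative aside (hypothetical helper, not part of the patch): the three
                    # hand-written unquoting blocks below could be collapsed into one function, e.g.
                    #
                    #     def _strip_agtype_quotes(value) -> str:
                    #         raw = value.value if hasattr(value, "value") else value
                    #         raw = str(raw)
                    #         return raw[1:-1] if raw.startswith('"') and raw.endswith('"') else raw
                    #
                    # after which each row would reduce to
                    #     source, target, edge_type = (_strip_agtype_quotes(v) for v in row)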
cursor.execute(edge_query) + edge_results = cursor.fetchall() + edges = [] + + for row in edge_results: + source_agtype, target_agtype, edge_agtype = row + + # Extract and clean source + source_raw = ( + source_agtype.value + if hasattr(source_agtype, "value") + else str(source_agtype) + ) + if ( + isinstance(source_raw, str) + and source_raw.startswith('"') + and source_raw.endswith('"') + ): + source = source_raw[1:-1] + else: + source = str(source_raw) + + # Extract and clean target + target_raw = ( + target_agtype.value + if hasattr(target_agtype, "value") + else str(target_agtype) + ) + if ( + isinstance(target_raw, str) + and target_raw.startswith('"') + and target_raw.endswith('"') + ): + target = target_raw[1:-1] + else: + target = str(target_raw) + + # Extract and clean edge type + type_raw = ( + edge_agtype.value if hasattr(edge_agtype, "value") else str(edge_agtype) + ) + if ( + isinstance(type_raw, str) + and type_raw.startswith('"') + and type_raw.endswith('"') + ): + edge_type = type_raw[1:-1] + else: + edge_type = str(type_raw) + + edges.append( + { + "source": source, + "target": target, + "type": edge_type, + } + ) + + except Exception as e: + logger.error(f"[EXPORT GRAPH - EDGES] Exception: {e}", exc_info=True) + raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e + finally: + self._return_connection(conn) + + return { + "nodes": nodes, + "edges": edges, + "total_nodes": total_nodes, + "total_edges": total_edges, + } + + @timed + def clear(self, user_name: str | None = None) -> None: + """ + Clear the entire graph if the target database exists. + + Args: + user_name (str, optional): User name for filtering in non-multi-db mode + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + try: + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE n.user_name = '{user_name}' + DETACH DELETE n + $$) AS (result agtype) + """ + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + logger.info("Cleared all nodes from database.") + finally: + self._return_connection(conn) + + except Exception as e: + logger.error(f"[ERROR] Failed to clear database: {e}") + + def drop_database(self) -> None: + """Permanently delete the entire graph this instance is using.""" + return + if self._get_config_value("use_multi_db", True): + with self.connection.cursor() as cursor: + cursor.execute(f"SELECT drop_graph('{self.db_name}_graph', true)") + logger.info(f"Graph '{self.db_name}_graph' has been dropped.") + else: + raise ValueError( + f"Refusing to drop graph '{self.db_name}_graph' in " + f"Shared Database Multi-Tenant mode" + ) + + @timed + def remove_oldest_memory( + self, memory_type: str, keep_latest: int, user_name: str | None = None + ) -> None: + """ + Remove all WorkingMemory nodes except the latest `keep_latest` entries. + + Args: + memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). + keep_latest (int): Number of latest WorkingMemory entries to keep. 
+ user_name (str, optional): User name for filtering in non-multi-db mode + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + # Use actual OFFSET logic, consistent with nebular.py + # First find IDs to delete, then delete them + select_query = f""" + SELECT id FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '"memory_type"'::agtype) = %s::agtype + AND ag_catalog.agtype_access_operator(properties::text::agtype, '"user_name"'::agtype) = %s::agtype + ORDER BY ag_catalog.agtype_access_operator(properties::text::agtype, '"updated_at"'::agtype) DESC + OFFSET %s + """ + select_params = [ + self.format_param_value(memory_type), + self.format_param_value(user_name), + keep_latest, + ] + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Execute query to get IDs to delete + cursor.execute(select_query, select_params) + ids_to_delete = [row[0] for row in cursor.fetchall()] + + if not ids_to_delete: + logger.info(f"No {memory_type} memories to remove for user {user_name}") + return + + # Build delete query + placeholders = ",".join(["%s"] * len(ids_to_delete)) + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE id IN ({placeholders}) + """ + delete_params = ids_to_delete + + # Execute deletion + cursor.execute(delete_query, delete_params) + deleted_count = cursor.rowcount + logger.info( + f"Removed {deleted_count} oldest {memory_type} memories, " + f"keeping {keep_latest} latest for user {user_name}, " + f"removed ids: {ids_to_delete}" + ) + except Exception as e: + logger.error(f"[remove_oldest_memory] Failed: {e}", exc_info=True) + raise + finally: + self._return_connection(conn) + + def merge_nodes(self, id1: str, id2: str) -> str: + """Merge two similar or duplicate nodes into one.""" + raise NotImplementedError + + def deduplicate_nodes(self) -> None: + """Deduplicate redundant or semantically similar nodes.""" + raise NotImplementedError + + def detect_conflicts(self) -> list[tuple[str, str]]: + """Detect conflicting nodes based on logical or semantic inconsistency.""" + raise NotImplementedError + + def _convert_graph_edges(self, core_node: dict) -> dict: + data = copy.deepcopy(core_node) + id_map = {} + core_node = data.get("core_node", {}) + if not core_node: + return { + "core_node": None, + "neighbors": data.get("neighbors", []), + "edges": data.get("edges", []), + } + core_meta = core_node.get("metadata", {}) + if "graph_id" in core_meta and "id" in core_node: + id_map[core_meta["graph_id"]] = core_node["id"] + for neighbor in data.get("neighbors", []): + n_meta = neighbor.get("metadata", {}) + if "graph_id" in n_meta and "id" in neighbor: + id_map[n_meta["graph_id"]] = neighbor["id"] + for edge in data.get("edges", []): + src = edge.get("source") + tgt = edge.get("target") + if src in id_map: + edge["source"] = id_map[src] + if tgt in id_map: + edge["target"] = id_map[tgt] + return data + + @timed + def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]: + """Get user names by memory ids. + + Args: + memory_ids: List of memory node IDs to query. + + Returns: + dict[str, str | None]: Dictionary mapping memory_id to user_name. 
+ - Key: memory_id + - Value: user_name if exists, None if memory_id does not exist + Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None} + """ + logger.info(f"[get_user_names_by_memory_ids] Querying memory_ids {memory_ids}") + if not memory_ids: + return {} + + # Validate and normalize memory_ids + # Ensure all items are strings + normalized_memory_ids = [] + for mid in memory_ids: + if not isinstance(mid, str): + mid = str(mid) + # Remove any whitespace + mid = mid.strip() + if mid: + normalized_memory_ids.append(mid) + + if not normalized_memory_ids: + return {} + + # Escape special characters for JSON string format in agtype + def escape_memory_id(mid: str) -> str: + """Escape special characters in memory_id for JSON string format.""" + # Escape backslashes first, then double quotes + mid_str = mid.replace("\\", "\\\\") + mid_str = mid_str.replace('"', '\\"') + return mid_str + + # Build OR conditions for each memory_id + id_conditions = [] + for mid in normalized_memory_ids: + # Escape special characters + escaped_mid = escape_memory_id(mid) + id_conditions.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"id\"'::agtype) = '\"{escaped_mid}\"'::agtype" + ) + + where_clause = f"({' OR '.join(id_conditions)})" + + # Query to get memory_id and user_name pairs + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties::text::agtype, '\"id\"'::agtype)::text AS memory_id, + ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype)::text AS user_name + FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + + logger.info(f"[get_user_names_by_memory_ids] query: {query}") + conn = None + result_dict = {} + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() + + # Build result dictionary from query results + for row in results: + memory_id_raw = row[0] + user_name_raw = row[1] + + # Remove quotes if present + if isinstance(memory_id_raw, str): + memory_id = memory_id_raw.strip('"').strip("'") + else: + memory_id = str(memory_id_raw).strip('"').strip("'") + + if isinstance(user_name_raw, str): + user_name = user_name_raw.strip('"').strip("'") + else: + user_name = ( + str(user_name_raw).strip('"').strip("'") if user_name_raw else None + ) + + result_dict[memory_id] = user_name if user_name else None + + # Set None for memory_ids that were not found + for mid in normalized_memory_ids: + if mid not in result_dict: + result_dict[mid] = None + + logger.info( + f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, " + f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names" + ) + + return result_dict + except Exception as e: + logger.error( + f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True + ) + raise + finally: + self._return_connection(conn) + + def exist_user_name(self, user_name: str) -> dict[str, bool]: + """Check if user name exists in the graph. + + Args: + user_name: User name to check. + + Returns: + dict[str, bool]: Dictionary with user_name as key and bool as value indicating existence. 
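+ Example: {"zhangsan": True} when at least one Memory node carries that user_name, otherwise {"zhangsan": False}.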
+ """ + logger.info(f"[exist_user_name] Querying user_name {user_name}") + if not user_name: + return {user_name: False} + + # Escape special characters for JSON string format in agtype + def escape_user_name(un: str) -> str: + """Escape special characters in user_name for JSON string format.""" + # Escape backslashes first, then double quotes + un_str = un.replace("\\", "\\\\") + un_str = un_str.replace('"', '\\"') + return un_str + + # Escape special characters + escaped_un = escape_user_name(user_name) + + # Query to check if user_name exists + query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = '\"{escaped_un}\"'::agtype + """ + logger.info(f"[exist_user_name] query: {query}") + result_dict = {} + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + count = cursor.fetchone()[0] + result = count > 0 + result_dict[user_name] = result + return result_dict + except Exception as e: + logger.error( + f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True + ) + raise + finally: + self._return_connection(conn) + + @timed + def delete_node_by_prams( + self, + writable_cube_ids: list[str] | None = None, + memory_ids: list[str] | None = None, + file_ids: list[str] | None = None, + filter: dict | None = None, + ) -> int: + """ + Delete nodes by memory_ids, file_ids, or filter. + + Args: + writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes. + If not provided, no user_name filter will be applied. + memory_ids (list[str], optional): List of memory node IDs to delete. + file_ids (list[str], optional): List of file node IDs to delete. + filter (dict, optional): Filter dictionary for metadata filtering. + Filter conditions are directly used in DELETE WHERE clause without pre-querying. + + Returns: + int: Number of nodes deleted. 
+ """ + batch_start_time = time.time() + logger.info( + f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}" + ) + + # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id) + # Only add user_name filter if writable_cube_ids is provided + user_name_conditions = [] + if writable_cube_ids and len(writable_cube_ids) > 0: + for cube_id in writable_cube_ids: + # Use agtype_access_operator with VARIADIC ARRAY format for consistency + user_name_conditions.append( + f"agtype_access_operator(VARIADIC ARRAY[properties::text::agtype, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype" + ) + + # Build filter conditions using common method (no query, direct use in WHERE clause) + filter_conditions = [] + if filter: + filter_conditions = self._build_filter_conditions_sql(filter) + logger.info(f"[delete_node_by_prams] filter_conditions: {filter_conditions}") + + # If no conditions to delete, return 0 + if not memory_ids and not file_ids and not filter_conditions: + logger.warning( + "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)" + ) + return 0 + + conn = None + total_deleted_count = 0 + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Build WHERE conditions list + where_conditions = [] + + # Add memory_ids conditions + if memory_ids: + logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids") + id_conditions = [] + for node_id in memory_ids: + id_conditions.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" + ) + where_conditions.append(f"({' OR '.join(id_conditions)})") + + # Add file_ids conditions + if file_ids: + logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids") + file_id_conditions = [] + for file_id in file_ids: + file_id_conditions.append( + f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties::text::agtype, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)" + ) + where_conditions.append(f"({' OR '.join(file_id_conditions)})") + + # Add filter conditions + if filter_conditions: + logger.info("[delete_node_by_prams] Processing filter conditions") + where_conditions.extend(filter_conditions) + + # Add user_name filter if provided + if user_name_conditions: + user_name_where = " OR ".join(user_name_conditions) + where_conditions.append(f"({user_name_where})") + + # Build final WHERE clause + if not where_conditions: + logger.warning("[delete_node_by_prams] No WHERE conditions to delete") + return 0 + + where_clause = " AND ".join(where_conditions) + + # Delete directly without counting + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") + + cursor.execute(delete_query) + deleted_count = cursor.rowcount + total_deleted_count = deleted_count + + logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes") + + elapsed_time = time.time() - batch_start_time + logger.info( + f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes" + ) + except Exception as e: + logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True) + raise + finally: + self._return_connection(conn) + + logger.info(f"[delete_node_by_prams] Successfully deleted {total_deleted_count} nodes") + return total_deleted_count diff 
--git a/src/memos/graph_dbs/polardb/nodes.py b/src/memos/graph_dbs/polardb/nodes.py new file mode 100644 index 000000000..8f0064efc --- /dev/null +++ b/src/memos/graph_dbs/polardb/nodes.py @@ -0,0 +1,711 @@ +import json +import time + +from datetime import datetime +from typing import Any + +from memos.graph_dbs.polardb.helpers import generate_vector +from memos.graph_dbs.utils import prepare_node_metadata as _prepare_node_metadata +from memos.log import get_logger +from memos.utils import timed + + +logger = get_logger(__name__) + + +class NodeMixin: + """Mixin for node (memory) CRUD operations.""" + + @timed + def add_node( + self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None + ) -> None: + """Add a memory node to the graph.""" + logger.info(f"[add_node] id: {id}, memory: {memory}, metadata: {metadata}") + + # user_name comes from metadata; fallback to config if missing + metadata["user_name"] = user_name if user_name else self.config.user_name + + metadata = _prepare_node_metadata(metadata) + + # Merge node and set metadata + created_at = metadata.pop("created_at", datetime.utcnow().isoformat()) + updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat()) + + # Prepare properties + properties = { + "id": id, + "memory": memory, + "created_at": created_at, + "updated_at": updated_at, + "delete_time": "", + "delete_record_id": "", + **metadata, + } + + # Generate embedding if not provided + if "embedding" not in properties or not properties["embedding"]: + properties["embedding"] = generate_vector( + self._get_config_value("embedding_dimension", 1024) + ) + + # serialization - JSON-serialize sources and usage fields + for field_name in ["sources", "usage"]: + if properties.get(field_name): + if isinstance(properties[field_name], list): + for idx in range(len(properties[field_name])): + # Serialize only when element is not a string + if not isinstance(properties[field_name][idx], str): + properties[field_name][idx] = json.dumps(properties[field_name][idx]) + elif isinstance(properties[field_name], str): + # If already a string, leave as-is + pass + + # Extract embedding for separate column + embedding_vector = properties.pop("embedding", []) + if not isinstance(embedding_vector, list): + embedding_vector = [] + + # Select column name based on embedding dimension + embedding_column = "embedding" # default column + if len(embedding_vector) == 3072: + embedding_column = "embedding_3072" + elif len(embedding_vector) == 1024: + embedding_column = "embedding" + elif len(embedding_vector) == 768: + embedding_column = "embedding_768" + + conn = None + insert_query = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Delete existing record first (if any) + delete_query = f""" + DELETE FROM {self.db_name}_graph."Memory" + WHERE id = %s + """ + cursor.execute(delete_query, (id,)) + properties["graph_id"] = str(id) + + # Then insert new record + if embedding_vector: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) + VALUES ( + %s, + %s, + %s + ) + """ + cursor.execute( + insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) + ) + logger.info( + f"[add_node] [embedding_vector-true] insert_query: {insert_query}, properties: {json.dumps(properties)}" + ) + else: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties) + VALUES ( + %s, + %s + ) + """ + cursor.execute(insert_query, (id, json.dumps(properties))) + logger.info( + 
f"[add_node] [embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" + ) + except Exception as e: + logger.error(f"[add_node] Failed to add node: {e}", exc_info=True) + raise + finally: + if insert_query: + logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}") + self._return_connection(conn) + + @timed + def add_nodes_batch( + self, + nodes: list[dict[str, Any]], + user_name: str | None = None, + ) -> None: + """ + Batch add multiple memory nodes to the graph. + + Args: + nodes: List of node dictionaries, each containing: + - id: str - Node ID + - memory: str - Memory content + - metadata: dict[str, Any] - Node metadata + user_name: Optional user name (will use config default if not provided) + """ + batch_start_time = time.time() + if not nodes: + logger.warning("[add_nodes_batch] Empty nodes list, skipping") + return + + logger.info(f"[add_nodes_batch] Processing only first node (total nodes: {len(nodes)})") + + # user_name comes from parameter; fallback to config if missing + effective_user_name = user_name if user_name else self.config.user_name + + # Prepare all nodes + prepared_nodes = [] + for node_data in nodes: + try: + id = node_data["id"] + memory = node_data["memory"] + metadata = node_data.get("metadata", {}) + + logger.debug(f"[add_nodes_batch] Processing node id: {id}") + + # Set user_name in metadata + metadata["user_name"] = effective_user_name + + metadata = _prepare_node_metadata(metadata) + + # Merge node and set metadata + created_at = metadata.pop("created_at", datetime.utcnow().isoformat()) + updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat()) + + # Prepare properties + properties = { + "id": id, + "memory": memory, + "created_at": created_at, + "updated_at": updated_at, + "delete_time": "", + "delete_record_id": "", + **metadata, + } + + # Generate embedding if not provided + if "embedding" not in properties or not properties["embedding"]: + properties["embedding"] = generate_vector( + self._get_config_value("embedding_dimension", 1024) + ) + + # Serialization - JSON-serialize sources and usage fields + for field_name in ["sources", "usage"]: + if properties.get(field_name): + if isinstance(properties[field_name], list): + for idx in range(len(properties[field_name])): + # Serialize only when element is not a string + if not isinstance(properties[field_name][idx], str): + properties[field_name][idx] = json.dumps( + properties[field_name][idx] + ) + elif isinstance(properties[field_name], str): + # If already a string, leave as-is + pass + + # Extract embedding for separate column + embedding_vector = properties.pop("embedding", []) + if not isinstance(embedding_vector, list): + embedding_vector = [] + + # Select column name based on embedding dimension + embedding_column = "embedding" # default column + if len(embedding_vector) == 3072: + embedding_column = "embedding_3072" + elif len(embedding_vector) == 1024: + embedding_column = "embedding" + elif len(embedding_vector) == 768: + embedding_column = "embedding_768" + + prepared_nodes.append( + { + "id": id, + "memory": memory, + "properties": properties, + "embedding_vector": embedding_vector, + "embedding_column": embedding_column, + } + ) + except Exception as e: + logger.error( + f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}", + exc_info=True, + ) + # Continue with other nodes + continue + + if not prepared_nodes: + logger.warning("[add_nodes_batch] No valid nodes to insert after preparation") + 
return + + # Group nodes by embedding column to optimize batch inserts + nodes_by_embedding_column = {} + for node in prepared_nodes: + col = node["embedding_column"] + if col not in nodes_by_embedding_column: + nodes_by_embedding_column[col] = [] + nodes_by_embedding_column[col].append(node) + + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Process each group separately + for embedding_column, nodes_group in nodes_by_embedding_column.items(): + # Batch delete existing records using IN clause + ids_to_delete = [node["id"] for node in nodes_group] + if ids_to_delete: + delete_query = f""" + DELETE FROM {self.db_name}_graph."Memory" + WHERE id = ANY(%s::text[]) + """ + cursor.execute(delete_query, (ids_to_delete,)) + + # Set graph_id in properties (using text ID directly) + for node in nodes_group: + node["properties"]["graph_id"] = str(node["id"]) + + # Use PREPARE/EXECUTE for efficient batch insert + # Generate unique prepare statement name to avoid conflicts + prepare_name = f"insert_mem_{embedding_column or 'no_embedding'}_{int(time.time() * 1000000)}" + + try: + if embedding_column and any( + node["embedding_vector"] for node in nodes_group + ): + # PREPARE statement for insert with embedding + prepare_query = f""" + PREPARE {prepare_name} AS + INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) + VALUES ( + $1, + $2::jsonb, + $3::vector + ) + """ + logger.info( + f"[add_nodes_batch] embedding Preparing prepare_name: {prepare_name}" + ) + logger.info( + f"[add_nodes_batch] embedding Preparing prepare_query: {prepare_query}" + ) + + cursor.execute(prepare_query) + + # Execute prepared statement for each node + for node in nodes_group: + properties_json = json.dumps(node["properties"]) + embedding_json = ( + json.dumps(node["embedding_vector"]) + if node["embedding_vector"] + else None + ) + + cursor.execute( + f"EXECUTE {prepare_name}(%s, %s, %s)", + (node["id"], properties_json, embedding_json), + ) + else: + # PREPARE statement for insert without embedding + prepare_query = f""" + PREPARE {prepare_name} AS + INSERT INTO {self.db_name}_graph."Memory"(id, properties) + VALUES ( + $1, + $2::jsonb + ) + """ + logger.info( + f"[add_nodes_batch] without embedding Preparing prepare_name: {prepare_name}" + ) + logger.info( + f"[add_nodes_batch] without embedding Preparing prepare_query: {prepare_query}" + ) + cursor.execute(prepare_query) + + # Execute prepared statement for each node + for node in nodes_group: + properties_json = json.dumps(node["properties"]) + + cursor.execute( + f"EXECUTE {prepare_name}(%s, %s)", (node["id"], properties_json) + ) + finally: + # DEALLOCATE prepared statement (always execute, even on error) + try: + cursor.execute(f"DEALLOCATE {prepare_name}") + logger.info( + f"[add_nodes_batch] Deallocated prepared statement: {prepare_name}" + ) + except Exception as dealloc_error: + logger.warning( + f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}" + ) + + logger.info( + f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}" + ) + elapsed_time = time.time() - batch_start_time + logger.info( + f"[add_nodes_batch] PREPARE/EXECUTE batch insert completed successfully in {elapsed_time:.2f}s" + ) + + except Exception as e: + logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True) + raise + finally: + self._return_connection(conn) + + @timed + def get_node( + self, id: str, include_embedding: bool = False, user_name: str | None = 
None + ) -> dict[str, Any] | None: + """ + Retrieve a Memory node by its unique ID. + + Args: + id (str): Node ID (Memory.id) + include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + dict: Node properties as key-value pairs, or None if not found. + """ + logger.info( + f"polardb [get_node] id: {id}, include_embedding: {include_embedding}, user_name: {user_name}" + ) + start_time = time.time() + select_fields = "id, properties, embedding" if include_embedding else "id, properties" + + query = f""" + SELECT {select_fields} + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) = %s::agtype + """ + params = [self.format_param_value(id)] + + # Only add user filter when user_name is provided + if user_name is not None: + query += "\nAND ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = %s::agtype" + params.append(self.format_param_value(user_name)) + + logger.info(f"polardb [get_node] query: {query},params: {params}") + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + + if result: + if include_embedding: + _, properties_json, embedding_json = result + else: + _, properties_json = result + embedding_json = None + + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse properties for node {id}") + properties = {} + else: + properties = properties_json if properties_json else {} + + # Parse embedding from JSONB if it exists and include_embedding is True + if include_embedding and embedding_json is not None: + try: + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json + ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {id}") + + elapsed_time = time.time() - start_time + logger.info( + f" polardb [get_node] get_node completed time in {elapsed_time:.2f}s" + ) + return self._parse_node( + { + "id": id, + "memory": properties.get("memory", ""), + **properties, + } + ) + return None + + except Exception as e: + logger.error(f"[get_node] Failed to retrieve node '{id}': {e}", exc_info=True) + return None + finally: + self._return_connection(conn) + + @timed + def get_nodes( + self, ids: list[str], user_name: str | None = None, **kwargs + ) -> list[dict[str, Any]]: + """ + Retrieve the metadata and memory of a list of nodes. + Args: + ids: List of Node identifier. + Returns: + list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'. + + Notes: + - Assumes all provided IDs are valid and exist. + - Returns empty list if input is empty. 
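+ - Pass include_embedding=True via kwargs to keep the stored embedding in the returned metadata.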
+ """ + logger.info(f"get_nodes ids:{ids},user_name:{user_name}") + if not ids: + return [] + + # Build WHERE clause using IN operator with agtype array + # Use ANY operator with array for better performance + placeholders = ",".join(["%s"] * len(ids)) + params = [self.format_param_value(id_val) for id_val in ids] + + query = f""" + SELECT id, properties, embedding + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '\"id\"'::agtype) = ANY(ARRAY[{placeholders}]::agtype[]) + """ + + # Only add user_name filter if provided + if user_name is not None: + query += " AND ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = %s::agtype" + params.append(self.format_param_value(user_name)) + + logger.info(f"get_nodes query:{query},params:{params}") + + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + + nodes = [] + for row in results: + node_id, properties_json, embedding_json = row + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse properties for node {node_id}") + properties = {} + else: + properties = properties_json if properties_json else {} + + # Parse embedding from JSONB if it exists + if embedding_json is not None and kwargs.get("include_embedding"): + try: + # remove embedding + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json + ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") + nodes.append( + self._parse_node( + { + "id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } + ) + ) + return nodes + finally: + self._return_connection(conn) + + @timed + def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: + """ + Update node fields in PolarDB, auto-converting `created_at` and `updated_at` to datetime type if present. 
+ """ + if not fields: + return + + user_name = user_name if user_name else self.config.user_name + + # Get the current node + current_node = self.get_node(id, user_name=user_name) + if not current_node: + return + + # Update properties but keep original id and memory fields + properties = current_node["metadata"].copy() + original_id = properties.get("id", id) # Preserve original ID + original_memory = current_node.get("memory", "") # Preserve original memory + + # If fields include memory, use it; otherwise keep original memory + if "memory" in fields: + original_memory = fields.pop("memory") + + properties.update(fields) + properties["id"] = original_id # Ensure ID is not overwritten + properties["memory"] = original_memory # Ensure memory is not overwritten + + # Handle embedding field + embedding_vector = None + if "embedding" in fields: + embedding_vector = fields.pop("embedding") + if not isinstance(embedding_vector, list): + embedding_vector = None + + # Build update query + if embedding_vector is not None: + query = f""" + UPDATE "{self.db_name}_graph"."Memory" + SET properties = %s, embedding = %s + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) = %s::agtype + """ + params = [ + json.dumps(properties), + json.dumps(embedding_vector), + self.format_param_value(id), + ] + else: + query = f""" + UPDATE "{self.db_name}_graph"."Memory" + SET properties = %s + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) = %s::agtype + """ + params = [json.dumps(properties), self.format_param_value(id)] + + # Only add user filter when user_name is provided + if user_name is not None: + query += "\nAND ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = %s::agtype" + params.append(self.format_param_value(user_name)) + + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + except Exception as e: + logger.error(f"[update_node] Failed to update node '{id}': {e}", exc_info=True) + raise + finally: + self._return_connection(conn) + + @timed + def delete_node(self, id: str, user_name: str | None = None) -> None: + """ + Delete a node from the graph. + Args: + id: Node identifier to delete. 
+ user_name (str, optional): User name for filtering in non-multi-db mode + """ + query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) = %s::agtype + """ + params = [self.format_param_value(id)] + + # Only add user filter when user_name is provided + if user_name is not None: + query += "\nAND ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = %s::agtype" + params.append(self.format_param_value(user_name)) + + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + except Exception as e: + logger.error(f"[delete_node] Failed to delete node '{id}': {e}", exc_info=True) + raise + finally: + self._return_connection(conn) + + def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: + """Parse node data from database format to standard format.""" + node = node_data.copy() + + # Strip wrapping quotes from agtype string values (idempotent) + for k, v in list(node.items()): + if isinstance(v, str) and len(v) >= 2 and v[0] == v[-1] and v[0] in ("'", '"'): + node[k] = v[1:-1] + + # Convert datetime to string + for time_field in ("created_at", "updated_at"): + if time_field in node and hasattr(node[time_field], "isoformat"): + node[time_field] = node[time_field].isoformat() + + # Deserialize sources from JSON strings back to dict objects + if "sources" in node and node.get("sources"): + sources = node["sources"] + if isinstance(sources, list): + deserialized_sources = [] + for source_item in sources: + if isinstance(source_item, str): + try: + parsed = json.loads(source_item) + deserialized_sources.append(parsed) + except (json.JSONDecodeError, TypeError): + deserialized_sources.append({"type": "doc", "content": source_item}) + elif isinstance(source_item, dict): + deserialized_sources.append(source_item) + else: + deserialized_sources.append({"type": "doc", "content": str(source_item)}) + node["sources"] = deserialized_sources + + return {"id": node.pop("id", None), "memory": node.pop("memory", ""), "metadata": node} + + def _build_node_from_agtype(self, node_agtype, embedding=None): + """ + Parse the cypher-returned column `n` (agtype or JSON string) + into a standard node and merge embedding into properties. 
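+ Returns None when the value cannot be interpreted as a vertex carrying a 'properties' map.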
+ """ + try: + # String case: '{"id":...,"label":[...],"properties":{...}}::vertex' + if isinstance(node_agtype, str): + json_str = node_agtype.replace("::vertex", "") + obj = json.loads(json_str) + if not (isinstance(obj, dict) and "properties" in obj): + return None + props = obj["properties"] + # agtype case: has `value` attribute + elif node_agtype and hasattr(node_agtype, "value"): + val = node_agtype.value + if not (isinstance(val, dict) and "properties" in val): + return None + props = val["properties"] + else: + return None + + if embedding is not None: + if isinstance(embedding, str): + try: + embedding = json.loads(embedding) + except (json.JSONDecodeError, TypeError): + logger.warning("Failed to parse embedding for node") + props["embedding"] = embedding + + return self._parse_node(props) + except Exception: + return None + + def format_param_value(self, value: str | None) -> str: + """Format parameter value to handle both quoted and unquoted formats""" + # Handle None value + if value is None: + logger.warning("format_param_value: value is None") + return "null" + + # Remove outer quotes if they exist + if value.startswith('"') and value.endswith('"'): + # Already has double quotes, return as is + return value + else: + # Add double quotes + return f'"{value}"' diff --git a/src/memos/graph_dbs/polardb/queries.py b/src/memos/graph_dbs/polardb/queries.py new file mode 100644 index 000000000..a6b4b56a5 --- /dev/null +++ b/src/memos/graph_dbs/polardb/queries.py @@ -0,0 +1,659 @@ +import json + +from typing import Any + +from memos.log import get_logger +from memos.utils import timed + + +logger = get_logger(__name__) + + +class QueryMixin: + """Mixin for query operations (metadata, counts, grouped queries).""" + + @timed + def get_by_metadata( + self, + filters: list[dict[str, Any]], + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list | None = None, + user_name_flag: bool = True, + ) -> list[str]: + """ + Retrieve node IDs that match given metadata filters. + Supports exact match. + + Args: + filters: List of filter dicts like: + [ + {"field": "key", "op": "in", "value": ["A", "B"]}, + {"field": "confidence", "op": ">=", "value": 80}, + {"field": "tags", "op": "contains", "value": "AI"}, + ... + ] + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + list[str]: Node IDs whose metadata match the filter conditions. (AND logic). 
+ """ + logger.info(f"[get_by_metadata] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}") + + user_name = user_name if user_name else self._get_config_value("user_name") + + # Build WHERE conditions for cypher query + where_conditions = [] + + for f in filters: + field = f["field"] + op = f.get("op", "=") + value = f["value"] + + # Format value + if isinstance(value, str): + # Escape single quotes using backslash when inside $$ dollar-quoted strings + # In $$ delimiters, Cypher string literals can use \' to escape single quotes + escaped_str = value.replace("'", "\\'") + escaped_value = f"'{escaped_str}'" + elif isinstance(value, list): + # Handle list values - use double quotes for Cypher arrays + list_items = [] + for v in value: + if isinstance(v, str): + # Escape double quotes in string values for Cypher + escaped_str = v.replace('"', '\\"') + list_items.append(f'"{escaped_str}"') + else: + list_items.append(str(v)) + escaped_value = f"[{', '.join(list_items)}]" + else: + escaped_value = f"'{value}'" if isinstance(value, str) else str(value) + # Build WHERE conditions + if op == "=": + where_conditions.append(f"n.{field} = {escaped_value}") + elif op == "in": + where_conditions.append(f"n.{field} IN {escaped_value}") + """ + # where_conditions.append(f"{escaped_value} IN n.{field}") + """ + elif op == "contains": + where_conditions.append(f"{escaped_value} IN n.{field}") + """ + # where_conditions.append(f"size(filter(n.{field}, t -> t IN {escaped_value})) > 0") + """ + elif op == "starts_with": + where_conditions.append(f"n.{field} STARTS WITH {escaped_value}") + elif op == "ends_with": + where_conditions.append(f"n.{field} ENDS WITH {escaped_value}") + elif op == "like": + where_conditions.append(f"n.{field} CONTAINS {escaped_value}") + elif op in [">", ">=", "<", "<="]: + where_conditions.append(f"n.{field} {op} {escaped_value}") + else: + raise ValueError(f"Unsupported operator: {op}") + + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self._get_config_value("user_name"), + ) + logger.info(f"[get_by_metadata] user_name_conditions: {user_name_conditions}") + + # Add user_name WHERE clause + if user_name_conditions: + if len(user_name_conditions) == 1: + where_conditions.append(user_name_conditions[0]) + else: + where_conditions.append(f"({' OR '.join(user_name_conditions)})") + + # Build filter conditions using common method + filter_where_clause = self._build_filter_conditions_cypher(filter) + logger.info(f"[get_by_metadata] filter_where_clause: {filter_where_clause}") + + where_str = " AND ".join(where_conditions) + filter_where_clause + + # Use cypher query + cypher_query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE {where_str} + RETURN n.id AS id + $$) AS (id agtype) + """ + + ids = [] + conn = None + logger.info(f"[get_by_metadata] cypher_query: {cypher_query}") + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + ids = [str(item[0]).strip('"') for item in results] + except Exception as e: + logger.error(f"Failed to get metadata: {e}, query is {cypher_query}") 
+ finally: + self._return_connection(conn) + + return ids + + @timed + def get_grouped_counts( + self, + group_fields: list[str], + where_clause: str = "", + params: dict[str, Any] | None = None, + user_name: str | None = None, + ) -> list[dict[str, Any]]: + """ + Count nodes grouped by any fields. + + Args: + group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"] + where_clause (str, optional): Extra WHERE condition. E.g., + "WHERE n.status = 'activated'" + params (dict, optional): Parameters for WHERE clause. + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...] + """ + if not group_fields: + raise ValueError("group_fields cannot be empty") + + user_name = user_name if user_name else self._get_config_value("user_name") + + # Build user clause + user_clause = f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" + if where_clause: + where_clause = where_clause.strip() + if where_clause.upper().startswith("WHERE"): + where_clause += f" AND {user_clause}" + else: + where_clause = f"WHERE {where_clause} AND {user_clause}" + else: + where_clause = f"WHERE {user_clause}" + + # Inline parameters if provided + if params and isinstance(params, dict): + for key, value in params.items(): + # Handle different value types appropriately + if isinstance(value, str): + value = f"'{value}'" + where_clause = where_clause.replace(f"${key}", str(value)) + + # Handle user_name parameter in where_clause + if "user_name = %s" in where_clause: + where_clause = where_clause.replace( + "user_name = %s", + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype", + ) + + # Build return fields and group by fields + return_fields = [] + group_by_fields = [] + + for field in group_fields: + alias = field.replace(".", "_") + return_fields.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"{field}\"'::agtype)::text AS {alias}" + ) + group_by_fields.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"{field}\"'::agtype)::text" + ) + + # Full SQL query construction + query = f""" + SELECT {", ".join(return_fields)}, COUNT(*) AS count + FROM "{self.db_name}_graph"."Memory" + {where_clause} + GROUP BY {", ".join(group_by_fields)} + """ + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Handle parameterized query + if params and isinstance(params, list): + cursor.execute(query, params) + else: + cursor.execute(query) + results = cursor.fetchall() + + output = [] + for row in results: + group_values = {} + for i, field in enumerate(group_fields): + value = row[i] + if hasattr(value, "value"): + group_values[field] = value.value + else: + group_values[field] = str(value) + count_value = row[-1] # Last column is count + output.append({**group_values, "count": int(count_value)}) + + return output + + except Exception as e: + logger.error(f"Failed to get grouped counts: {e}", exc_info=True) + return [] + finally: + self._return_connection(conn) + + def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: + """Get count of memory nodes by type.""" + user_name = user_name if user_name else self._get_config_value("user_name") + query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + WHERE 
ag_catalog.agtype_access_operator(properties::text::agtype, '"memory_type"'::agtype) = %s::agtype + """ + query += "\nAND ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = %s::agtype" + params = [self.format_param_value(memory_type), self.format_param_value(user_name)] + + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return result[0] if result else 0 + except Exception as e: + logger.error(f"[get_memory_count] Failed: {e}") + return -1 + finally: + self._return_connection(conn) + + @timed + def node_not_exist(self, scope: str, user_name: str | None = None) -> int: + """Check if a node with given scope exists.""" + user_name = user_name if user_name else self._get_config_value("user_name") + query = f""" + SELECT id + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '"memory_type"'::agtype) = %s::agtype + """ + query += "\nAND ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = %s::agtype" + query += "\nLIMIT 1" + params = [self.format_param_value(scope), self.format_param_value(user_name)] + + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return 1 if result else 0 + except Exception as e: + logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) + raise + finally: + self._return_connection(conn) + + @timed + def count_nodes(self, scope: str, user_name: str | None = None) -> int: + user_name = user_name if user_name else self.config.user_name + + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE n.memory_type = '{scope}' + AND n.user_name = '{user_name}' + RETURN count(n) + $$) AS (count agtype) + """ + conn = None + try: + conn = self._get_connection() + cursor = conn.cursor() + cursor.execute(query) + row = cursor.fetchone() + cursor.close() + conn.commit() + return int(row[0]) if row else 0 + except Exception: + if conn: + conn.rollback() + raise + finally: + self._return_connection(conn) + + @timed + def get_all_memory_items( + self, + scope: str, + include_embedding: bool = False, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list | None = None, + status: str | None = None, + ) -> list[dict]: + """ + Retrieve all memory items of a specific memory_type. + + Args: + scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. + include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode + filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results. + knowledgebase_ids (list, optional): List of knowledgebase IDs to filter by. + status (str, optional): Filter by status (e.g., 'activated', 'archived'). + If None, no status filter is applied. + + Returns: + list[dict]: Full list of memory items under this scope. 
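+ Note: both query paths below cap the result set at 100 rows via LIMIT in the Cypher query.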
+ """ + logger.info( + f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}, status: {status}" + ) + + user_name = user_name if user_name else self._get_config_value("user_name") + if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: + raise ValueError(f"Unsupported memory type scope: {scope}") + + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self._get_config_value("user_name"), + ) + + # Build user_name WHERE clause + if user_name_conditions: + if len(user_name_conditions) == 1: + user_name_where = user_name_conditions[0] + else: + user_name_where = f"({' OR '.join(user_name_conditions)})" + else: + user_name_where = "" + + # Build filter conditions using common method + filter_where_clause = self._build_filter_conditions_cypher(filter) + logger.info(f"[get_all_memory_items] filter_where_clause: {filter_where_clause}") + + # Use cypher query to retrieve memory items + if include_embedding: + # Build WHERE clause with user_name/knowledgebase_ids and filter + where_parts = [f"n.memory_type = '{scope}'"] + if status: + where_parts.append(f"n.status = '{status}'") + if user_name_where: + # user_name_where already contains parentheses if it's an OR condition + where_parts.append(user_name_where) + if filter_where_clause: + # filter_where_clause already contains " AND " prefix, so we just append it + where_clause = " AND ".join(where_parts) + filter_where_clause + else: + where_clause = " AND ".join(where_parts) + + cypher_query = f""" + WITH t as ( + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE {where_clause} + RETURN id(n) as id1,n + LIMIT 100 + $$) AS (id1 agtype,n agtype) + ) + SELECT + m.embedding, + t.n + FROM t, + {self.db_name}_graph."Memory" m + WHERE t.id1 = m.id; + """ + nodes = [] + node_ids = set() + conn = None + logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}") + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + + for row in results: + if isinstance(row, list | tuple) and len(row) >= 2: + embedding_val, node_val = row[0], row[1] + else: + embedding_val, node_val = None, row[0] + + node = self._build_node_from_agtype(node_val, embedding_val) + if node: + node_id = node["id"] + if node_id not in node_ids: + nodes.append(node) + node_ids.add(node_id) + + except Exception as e: + logger.error(f"Failed to get memories: {e}", exc_info=True) + finally: + self._return_connection(conn) + + return nodes + else: + # Build WHERE clause with user_name/knowledgebase_ids and filter + where_parts = [f"n.memory_type = '{scope}'"] + if status: + where_parts.append(f"n.status = '{status}'") + if user_name_where: + # user_name_where already contains parentheses if it's an OR condition + where_parts.append(user_name_where) + if filter_where_clause: + # filter_where_clause already contains " AND " prefix, so we just append it + where_clause = " AND ".join(where_parts) + filter_where_clause + else: + where_clause = " AND ".join(where_parts) + + cypher_query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE {where_clause} + RETURN properties(n) as props + LIMIT 100 + $$) AS (nprops agtype) + """ + + nodes = [] + conn = None + logger.info(f"[get_all_memory_items] cypher_query: 
{cypher_query}") + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + + for row in results: + memory_data = json.loads(row[0]) if isinstance(row[0], str) else row[0] + nodes.append(self._parse_node(memory_data)) + + except Exception as e: + logger.error(f"Failed to get memories: {e}", exc_info=True) + finally: + self._return_connection(conn) + + return nodes + + @timed + def get_structure_optimization_candidates( + self, scope: str, include_embedding: bool = False, user_name: str | None = None + ) -> list[dict]: + """ + Find nodes that are likely candidates for structure optimization: + - Isolated nodes, nodes with empty background, or nodes with exactly one child. + - Plus: the child of any parent node that has exactly one child. + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + # Build return fields based on include_embedding flag + if include_embedding: + return_fields = "id(n) as id1,n" + return_fields_agtype = " id1 agtype,n agtype" + else: + # Build field list without embedding + return_fields = ",".join( + [ + "n.id AS id", + "n.memory AS memory", + "n.user_name AS user_name", + "n.user_id AS user_id", + "n.session_id AS session_id", + "n.status AS status", + "n.key AS key", + "n.confidence AS confidence", + "n.tags AS tags", + "n.created_at AS created_at", + "n.updated_at AS updated_at", + "n.memory_type AS memory_type", + "n.sources AS sources", + "n.source AS source", + "n.node_type AS node_type", + "n.visibility AS visibility", + "n.usage AS usage", + "n.background AS background", + "n.graph_id as graph_id", + ] + ) + fields = [ + "id", + "memory", + "user_name", + "user_id", + "session_id", + "status", + "key", + "confidence", + "tags", + "created_at", + "updated_at", + "memory_type", + "sources", + "source", + "node_type", + "visibility", + "usage", + "background", + "graph_id", + ] + return_fields_agtype = ", ".join([f"{field} agtype" for field in fields]) + + # Use OPTIONAL MATCH to find isolated nodes (no parents or children) + cypher_query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE n.memory_type = '{scope}' + AND n.status = 'activated' + AND n.user_name = '{user_name}' + OPTIONAL MATCH (n)-[:PARENT]->(c:Memory) + OPTIONAL MATCH (p:Memory)-[:PARENT]->(n) + WITH n, c, p + WHERE c IS NULL AND p IS NULL + RETURN {return_fields} + $$) AS ({return_fields_agtype}) + """ + if include_embedding: + cypher_query = f""" + WITH t as ( + {cypher_query} + ) + SELECT + m.embedding, + t.n + FROM t, + {self.db_name}_graph."Memory" m + WHERE t.id1 = m.id + """ + logger.info(f"[get_structure_optimization_candidates] query: {cypher_query}") + + candidates = [] + node_ids = set() + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + logger.info(f"Found {len(results)} structure optimization candidates") + for row in results: + if include_embedding: + # When include_embedding=True, return full node object + """ + if isinstance(row, (list, tuple)) and len(row) >= 2: + """ + if isinstance(row, list | tuple) and len(row) >= 2: + embedding_val, node_val = row[0], row[1] + else: + embedding_val, node_val = None, row[0] + + node = self._build_node_from_agtype(node_val, embedding_val) + if node: + node_id = node["id"] + if node_id not in node_ids: + candidates.append(node) + node_ids.add(node_id) + else: + # When include_embedding=False, return field 
dictionary + # Define field names matching the RETURN clause + field_names = [ + "id", + "memory", + "user_name", + "user_id", + "session_id", + "status", + "key", + "confidence", + "tags", + "created_at", + "updated_at", + "memory_type", + "sources", + "source", + "node_type", + "visibility", + "usage", + "background", + "graph_id", + ] + + # Convert row to dictionary + node_data = {} + for i, field_name in enumerate(field_names): + if i < len(row): + value = row[i] + # Handle special fields + if field_name in ["tags", "sources", "usage"] and isinstance( + value, str + ): + try: + # Try parsing JSON string + node_data[field_name] = json.loads(value) + except (json.JSONDecodeError, TypeError): + node_data[field_name] = value + else: + node_data[field_name] = value + + # Parse node using _parse_node_new + try: + node = self._parse_node(node_data) + node_id = node["id"] + + if node_id not in node_ids: + candidates.append(node) + node_ids.add(node_id) + logger.debug(f"Parsed node successfully: {node_id}") + except Exception as e: + logger.error(f"Failed to parse node: {e}") + + except Exception as e: + logger.error(f"Failed to get structure optimization candidates: {e}", exc_info=True) + finally: + self._return_connection(conn) + + return candidates diff --git a/src/memos/graph_dbs/polardb/schema.py b/src/memos/graph_dbs/polardb/schema.py new file mode 100644 index 000000000..179119de8 --- /dev/null +++ b/src/memos/graph_dbs/polardb/schema.py @@ -0,0 +1,172 @@ +from memos.log import get_logger +from memos.utils import timed + + +logger = get_logger(__name__) + + +class SchemaMixin: + """Mixin for schema and extension management.""" + + def _ensure_database_exists(self): + """Create database if it doesn't exist.""" + try: + # For PostgreSQL/PolarDB, we need to connect to a default database first + # This is a simplified implementation - in production you might want to handle this differently + logger.info(f"Using database '{self.db_name}'") + except Exception as e: + logger.error(f"Failed to access database '{self.db_name}': {e}") + raise + + @timed + def _create_graph(self): + """Create PostgreSQL schema and table for graph storage.""" + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Create schema if it doesn't exist + cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.db_name}_graph";') + logger.info(f"Schema '{self.db_name}_graph' ensured.") + + # Create Memory table if it doesn't exist + cursor.execute(f""" + CREATE TABLE IF NOT EXISTS "{self.db_name}_graph"."Memory" ( + id TEXT PRIMARY KEY, + properties JSONB NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + """) + logger.info(f"Memory table created in schema '{self.db_name}_graph'.") + + # Add embedding column if it doesn't exist (using JSONB for compatibility) + try: + cursor.execute(f""" + ALTER TABLE "{self.db_name}_graph"."Memory" + ADD COLUMN IF NOT EXISTS embedding JSONB; + """) + logger.info("Embedding column added to Memory table.") + except Exception as e: + logger.warning(f"Failed to add embedding column: {e}") + + # Create indexes + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_properties + ON "{self.db_name}_graph"."Memory" USING GIN (properties); + """) + + # Create vector index for embedding field + try: + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_embedding + ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops) + WITH (lists = 
100); + """) + logger.info("Vector index created for Memory table.") + except Exception as e: + logger.warning(f"Vector index creation failed (might not be supported): {e}") + + logger.info("Indexes created for Memory table.") + + except Exception as e: + logger.error(f"Failed to create graph schema: {e}") + raise e + finally: + self._return_connection(conn) + + def create_index( + self, + label: str = "Memory", + vector_property: str = "embedding", + dimensions: int = 1024, + index_name: str = "memory_vector_index", + ) -> None: + """ + Create indexes for embedding and other fields. + Note: This creates PostgreSQL indexes on the underlying tables. + """ + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Create indexes on the underlying PostgreSQL tables + # Apache AGE stores data in regular PostgreSQL tables + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_properties + ON "{self.db_name}_graph"."Memory" USING GIN (properties); + """) + + # Try to create vector index, but don't fail if it doesn't work + try: + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_embedding + ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops); + """) + except Exception as ve: + logger.warning(f"Vector index creation failed (might not be supported): {ve}") + + logger.debug("Indexes created successfully.") + except Exception as e: + logger.warning(f"Failed to create indexes: {e}") + finally: + self._return_connection(conn) + + @timed + def create_extension(self): + extensions = [("polar_age", "Graph engine"), ("vector", "Vector engine")] + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Ensure in the correct database context + cursor.execute("SELECT current_database();") + current_db = cursor.fetchone()[0] + logger.info(f"Current database context: {current_db}") + + for ext_name, ext_desc in extensions: + try: + cursor.execute(f"create extension if not exists {ext_name};") + logger.info(f"Extension '{ext_name}' ({ext_desc}) ensured.") + except Exception as e: + if "already exists" in str(e): + logger.info(f"Extension '{ext_name}' ({ext_desc}) already exists.") + else: + logger.warning( + f"Failed to create extension '{ext_name}' ({ext_desc}): {e}" + ) + logger.error( + f"Failed to create extension '{ext_name}': {e}", exc_info=True + ) + except Exception as e: + logger.warning(f"Failed to access database context: {e}") + logger.error(f"Failed to access database context: {e}", exc_info=True) + finally: + self._return_connection(conn) + + @timed + def create_graph(self): + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(f""" + SELECT COUNT(*) FROM ag_catalog.ag_graph + WHERE name = '{self.db_name}_graph'; + """) + graph_exists = cursor.fetchone()[0] > 0 + + if graph_exists: + logger.info(f"Graph '{self.db_name}_graph' already exists.") + else: + cursor.execute(f"select create_graph('{self.db_name}_graph');") + logger.info(f"Graph database '{self.db_name}_graph' created.") + except Exception as e: + logger.warning(f"Failed to create graph '{self.db_name}_graph': {e}") + logger.error(f"Failed to create graph '{self.db_name}_graph': {e}", exc_info=True) + finally: + self._return_connection(conn) diff --git a/src/memos/graph_dbs/polardb/search.py b/src/memos/graph_dbs/polardb/search.py new file mode 100644 index 000000000..3fb9bca8c --- 
/dev/null +++ b/src/memos/graph_dbs/polardb/search.py @@ -0,0 +1,377 @@ +import time + +from memos.graph_dbs.utils import convert_to_vector +from memos.log import get_logger +from memos.utils import timed + + +logger = get_logger(__name__) + + +class SearchMixin: + """Mixin for search operations (keyword, fulltext, embedding).""" + + def _build_search_where_clauses_sql( + self, + scope: str | None = None, + status: str | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + ) -> list[str]: + """Build common WHERE clauses for SQL-based search methods.""" + where_clauses = [] + + if scope: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype" + ) + if status: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"status\"'::agtype) = '\"{status}\"'::agtype" + ) + else: + where_clauses.append( + "ag_catalog.agtype_access_operator(properties::text::agtype, '\"status\"'::agtype) = '\"activated\"'::agtype" + ) + + # Build user_name filter with knowledgebase_ids support (OR relationship) + user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self.config.user_name, + ) + if user_name_conditions: + if len(user_name_conditions) == 1: + where_clauses.append(user_name_conditions[0]) + else: + where_clauses.append(f"({' OR '.join(user_name_conditions)})") + + # Add search_filter conditions + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"{key}\"'::agtype) = '\"{value}\"'::agtype" + ) + else: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"{key}\"'::agtype) = {value}::agtype" + ) + + # Build filter conditions + filter_conditions = self._build_filter_conditions_sql(filter) + where_clauses.extend(filter_conditions) + + return where_clauses + + @timed + def search_by_keywords_like( + self, + query_word: str, + scope: str | None = None, + status: str | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: + where_clauses = self._build_search_where_clauses_sql( + scope=scope, + status=status, + search_filter=search_filter, + user_name=user_name, + filter=filter, + knowledgebase_ids=knowledgebase_ids, + ) + + # Method-specific: LIKE pattern match + where_clauses.append("""(properties -> '"memory"')::text LIKE %s""") + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) AS old_id, + agtype_object_field_text(properties, 'memory') as memory_text + FROM "{self.db_name}_graph"."Memory" + {where_clause} + """ + + params = (query_word,) + logger.info( + f"[search_by_keywords_LIKE start:] user_name: {user_name}, query: {query}, params: {params}" + ) + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] + id_val = str(oldid).strip('"') + output.append({"id": id_val}) + logger.info( + f"[search_by_keywords_LIKE end:] 
user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + ) + return output + finally: + self._return_connection(conn) + + @timed + def search_by_keywords_tfidf( + self, + query_words: list[str], + scope: str | None = None, + status: str | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + tsvector_field: str = "properties_tsvector_zh", + tsquery_config: str = "jiebaqry", + **kwargs, + ) -> list[dict]: + where_clauses = self._build_search_where_clauses_sql( + scope=scope, + status=status, + search_filter=search_filter, + user_name=user_name, + filter=filter, + knowledgebase_ids=knowledgebase_ids, + ) + + # Method-specific: TF-IDF fulltext search condition + tsquery_string = " | ".join(query_words) + where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") + + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) AS old_id, + agtype_object_field_text(properties, 'memory') as memory_text + FROM "{self.db_name}_graph"."Memory" + {where_clause} + """ + + params = (tsquery_string,) + logger.info( + f"[search_by_keywords_TFIDF start:] user_name: {user_name}, query: {query}, params: {params}" + ) + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] + id_val = str(oldid).strip('"') + output.append({"id": id_val}) + + logger.info( + f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + ) + return output + finally: + self._return_connection(conn) + + @timed + def search_by_fulltext( + self, + query_words: list[str], + top_k: int = 10, + scope: str | None = None, + status: str | None = None, + threshold: float | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + tsvector_field: str = "properties_tsvector_zh", + tsquery_config: str = "jiebacfg", + **kwargs, + ) -> list[dict]: + """ + Full-text search over memory nodes using PostgreSQL's full-text search capabilities. + + Args: + query_words: list of query words (combined with OR in the tsquery) + top_k: maximum number of results to return + scope: memory type filter (memory_type) + status: status filter, defaults to "activated" + threshold: similarity threshold filter + search_filter: additional property filter conditions + user_name: username filter + knowledgebase_ids: knowledgebase ids filter + filter: filter conditions with 'and' or 'or' logic for search results. + tsvector_field: full-text index field name, defaults to properties_tsvector_zh + tsquery_config: full-text search configuration, defaults to jiebacfg (Chinese word segmentation) + **kwargs: other parameters (e.g.
cube_name) + + Returns: + list[dict]: result list containing id and score + """ + logger.info( + f"[search_by_fulltext] query_words: {query_words}, top_k: {top_k}, scope: {scope}, filter: {filter}" + ) + start_time = time.time() + where_clauses = self._build_search_where_clauses_sql( + scope=scope, + status=status, + search_filter=search_filter, + user_name=user_name, + filter=filter, + knowledgebase_ids=knowledgebase_ids, + ) + + # Method-specific: fulltext search condition + tsquery_string = " | ".join(query_words) + + where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") + + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + logger.info(f"[search_by_fulltext] where_clause: {where_clause}") + + # Build fulltext search query + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) AS old_id, + agtype_object_field_text(properties, 'memory') as memory_text, + ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY rank DESC + LIMIT {top_k}; + """ + + params = [tsquery_string, tsquery_string] + logger.info(f"[search_by_fulltext] query: {query}, params: {params}") + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] # old_id + rank = row[2] # rank score + + id_val = str(oldid).strip('"') + score_val = float(rank) + + # Apply threshold filter if specified + if threshold is None or score_val >= threshold: + output.append({"id": id_val, "score": score_val}) + elapsed_time = time.time() - start_time + logger.info( + f" polardb [search_by_fulltext] query completed time in {elapsed_time:.2f}s" + ) + return output[:top_k] + finally: + self._return_connection(conn) + + @timed + def search_by_embedding( + self, + vector: list[float], + top_k: int = 5, + scope: str | None = None, + status: str | None = None, + threshold: float | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: + """ + Retrieve node IDs based on vector similarity using PostgreSQL vector operations. 
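+
+        Example (illustrative sketch; assumes an initialized PolarDB-backed graph store `db`
+        and a 1024-dimensional query embedding `query_vec` -- both hypothetical names):
+
+            hits = db.search_by_embedding(vector=query_vec, top_k=5, threshold=0.5)
+            # e.g. [{"id": "<memory-id>", "score": 0.83}, ...]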
+ """ + logger.info( + f"[search_by_embedding] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}" + ) + where_clauses = self._build_search_where_clauses_sql( + scope=scope, + status=status, + search_filter=search_filter, + user_name=user_name, + filter=filter, + knowledgebase_ids=knowledgebase_ids, + ) + # Method-specific: require embedding column + where_clauses.append("embedding is not null") + + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + # Keep original simple query structure but add dynamic WHERE clause + query = f""" + WITH t AS ( + SELECT id, + properties, + timeline, + ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) AS old_id, + (1 - (embedding <=> %s::vector(1024))) AS scope + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY scope DESC + LIMIT {top_k} + ) + SELECT * + FROM t + WHERE scope > 0.1; + """ + # Convert vector to string format for PostgreSQL vector type + # PostgreSQL vector type expects a string format like '[1,2,3]' + vector_str = convert_to_vector(vector) + # Use string format directly in query instead of parameterized query + # Replace %s with the vector string, but need to quote it properly + # PostgreSQL vector type needs the string to be quoted + query = query.replace("%s::vector(1024)", f"'{vector_str}'::vector(1024)") + params = [] + + logger.info(f"[search_by_embedding] query: {query}, params: {params}") + + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + try: + # If params is empty, execute query directly without parameters + if params: + cursor.execute(query, params) + else: + cursor.execute(query) + except Exception as e: + logger.error(f"[search_by_embedding] Error executing query: {e}") + raise + results = cursor.fetchall() + output = [] + for row in results: + if len(row) < 5: + logger.warning(f"Row has {len(row)} columns, expected 5. 
Row: {row}") + continue + oldid = row[3] # old_id + score = row[4] # scope + id_val = str(oldid).strip('"') + score_val = float(score) + score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score + if threshold is None or score_val >= threshold: + output.append({"id": id_val, "score": score_val}) + return output[:top_k] + except Exception as e: + logger.error(f"[search_by_embedding] Error: {type(e).__name__}: {e}", exc_info=True) + return [] + finally: + self._return_connection(conn) + + # Backward-compatible aliases for renamed methods (typo -> correct) + seach_by_keywords_like = search_by_keywords_like + seach_by_keywords_tfidf = search_by_keywords_tfidf diff --git a/src/memos/graph_dbs/polardb/traversal.py b/src/memos/graph_dbs/polardb/traversal.py new file mode 100644 index 000000000..29c67d477 --- /dev/null +++ b/src/memos/graph_dbs/polardb/traversal.py @@ -0,0 +1,433 @@ +import json + +from typing import Any, Literal + +from memos.log import get_logger +from memos.utils import timed + + +logger = get_logger(__name__) + + +class TraversalMixin: + """Mixin for graph traversal operations.""" + + def get_neighbors( + self, id: str, type: str, direction: Literal["in", "out", "both"] = "out" + ) -> list[str]: + """Get connected node IDs in a specific direction and relationship type.""" + raise NotImplementedError + + @timed + def get_children_with_embeddings( + self, id: str, user_name: str | None = None + ) -> list[dict[str, Any]]: + """Get children nodes with their embeddings.""" + user_name = user_name if user_name else self._get_config_value("user_name") + where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" + + query = f""" + WITH t as ( + SELECT * + FROM cypher('{self.db_name}_graph', $$ + MATCH (p:Memory)-[r:PARENT]->(c:Memory) + WHERE p.id = '{id}' {where_user} + RETURN id(c) as cid, c.id AS id, c.memory AS memory + $$) as (cid agtype, id agtype, memory agtype) + ) + SELECT t.id, m.embedding, t.memory FROM t, + "{self.db_name}_graph"."Memory" m + WHERE t.cid::graphid = m.id; + """ + + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() + + children = [] + for row in results: + # Handle child_id - remove possible quotes + child_id_raw = row[0].value if hasattr(row[0], "value") else str(row[0]) + if isinstance(child_id_raw, str): + # If string starts and ends with quotes, remove quotes + if child_id_raw.startswith('"') and child_id_raw.endswith('"'): + child_id = child_id_raw[1:-1] + else: + child_id = child_id_raw + else: + child_id = str(child_id_raw) + + # Handle embedding - get from database embedding column + embedding_raw = row[1] + embedding = [] + if embedding_raw is not None: + try: + if isinstance(embedding_raw, str): + # If it is a JSON string, parse it + embedding = json.loads(embedding_raw) + elif isinstance(embedding_raw, list): + # If already a list, use directly + embedding = embedding_raw + else: + # Try converting to list + embedding = list(embedding_raw) + except (json.JSONDecodeError, TypeError, ValueError) as e: + logger.warning( + f"Failed to parse embedding for child node {child_id}: {e}" + ) + embedding = [] + + # Handle memory - remove possible quotes + memory_raw = row[2].value if hasattr(row[2], "value") else str(row[2]) + if isinstance(memory_raw, str): + # If string starts and ends with quotes, remove quotes + if memory_raw.startswith('"') and memory_raw.endswith('"'): + memory = memory_raw[1:-1] + else: + memory = memory_raw + 
else: + memory = str(memory_raw) + + children.append({"id": child_id, "embedding": embedding, "memory": memory}) + + return children + + except Exception as e: + logger.error(f"[get_children_with_embeddings] Failed: {e}", exc_info=True) + return [] + finally: + self._return_connection(conn) + + def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: + """Get the path of nodes from source to target within a limited depth.""" + raise NotImplementedError + + @timed + def get_subgraph( + self, + center_id: str, + depth: int = 2, + center_status: str = "activated", + user_name: str | None = None, + ) -> dict[str, Any]: + """ + Retrieve a local subgraph centered at a given node. + Args: + center_id: The ID of the center node. + depth: The hop distance for neighbors. + center_status: Required status for center node. + user_name (str, optional): User name for filtering in non-multi-db mode + Returns: + { + "core_node": {...}, + "neighbors": [...], + "edges": [...] + } + """ + logger.info(f"[get_subgraph] center_id: {center_id}") + if not 1 <= depth <= 5: + raise ValueError("depth must be 1-5") + + user_name = user_name if user_name else self._get_config_value("user_name") + + if center_id.startswith('"') and center_id.endswith('"'): + center_id = center_id[1:-1] + # Use UNION ALL for better performance: separate queries for depth 1 and depth 2 + if depth == 1: + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH(center: Memory)-[r]->(neighbor:Memory) + WHERE + center.id = '{center_id}' + AND center.status = '{center_status}' + AND center.user_name = '{user_name}' + RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r) + $$ ) as (centers agtype, neighbors agtype, rels agtype); + """ + else: + # For depth >= 2, use UNION ALL to combine depth 1 and depth 2 queries + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH(center: Memory)-[r]->(neighbor:Memory) + WHERE + center.id = '{center_id}' + AND center.status = '{center_status}' + AND center.user_name = '{user_name}' + RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r) + UNION ALL + MATCH(center: Memory)-[r]->(n:Memory)-[r1]->(neighbor:Memory) + WHERE + center.id = '{center_id}' + AND center.status = '{center_status}' + AND center.user_name = '{user_name}' + RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r1) + $$ ) as (centers agtype, neighbors agtype, rels agtype); + """ + conn = None + logger.info(f"[get_subgraph] Query: {query}") + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() + + if not results: + return {"core_node": None, "neighbors": [], "edges": []} + + # Merge results from all UNION ALL rows + all_centers_list = [] + all_neighbors_list = [] + all_edges_list = [] + + for result in results: + if not result or not result[0]: + continue + + centers_data = result[0] if result[0] else "[]" + neighbors_data = result[1] if result[1] else "[]" + edges_data = result[2] if result[2] else "[]" + + # Parse JSON data + try: + # Clean ::vertex and ::edge suffixes in data + if isinstance(centers_data, str): + centers_data = centers_data.replace("::vertex", "") + if isinstance(neighbors_data, str): + neighbors_data = neighbors_data.replace("::vertex", "") + if isinstance(edges_data, str): + edges_data = edges_data.replace("::edge", "") + + centers_list = ( + json.loads(centers_data) + if isinstance(centers_data, str) + else 
centers_data + ) + neighbors_list = ( + json.loads(neighbors_data) + if isinstance(neighbors_data, str) + else neighbors_data + ) + edges_list = ( + json.loads(edges_data) if isinstance(edges_data, str) else edges_data + ) + + # Collect data from this row + if isinstance(centers_list, list): + all_centers_list.extend(centers_list) + if isinstance(neighbors_list, list): + all_neighbors_list.extend(neighbors_list) + if isinstance(edges_list, list): + all_edges_list.extend(edges_list) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse JSON data: {e}") + continue + + # Deduplicate centers by ID + centers_dict = {} + for center_data in all_centers_list: + if isinstance(center_data, dict) and "properties" in center_data: + center_id_key = center_data["properties"].get("id") + if center_id_key and center_id_key not in centers_dict: + centers_dict[center_id_key] = center_data + + # Parse center node (use first center) + core_node = None + if centers_dict: + center_data = next(iter(centers_dict.values())) + if isinstance(center_data, dict) and "properties" in center_data: + core_node = self._parse_node(center_data["properties"]) + + # Deduplicate neighbors by ID + neighbors_dict = {} + for neighbor_data in all_neighbors_list: + if isinstance(neighbor_data, dict) and "properties" in neighbor_data: + neighbor_id = neighbor_data["properties"].get("id") + if neighbor_id and neighbor_id not in neighbors_dict: + neighbors_dict[neighbor_id] = neighbor_data + + # Parse neighbor nodes + neighbors = [] + for neighbor_data in neighbors_dict.values(): + if isinstance(neighbor_data, dict) and "properties" in neighbor_data: + neighbor_parsed = self._parse_node(neighbor_data["properties"]) + neighbors.append(neighbor_parsed) + + # Deduplicate edges by (source, target, type) + edges_dict = {} + for edge_group in all_edges_list: + if isinstance(edge_group, list): + for edge_data in edge_group: + if isinstance(edge_data, dict): + edge_key = ( + edge_data.get("start_id", ""), + edge_data.get("end_id", ""), + edge_data.get("label", ""), + ) + if edge_key not in edges_dict: + edges_dict[edge_key] = { + "type": edge_data.get("label", ""), + "source": edge_data.get("start_id", ""), + "target": edge_data.get("end_id", ""), + } + elif isinstance(edge_group, dict): + # Handle single edge (not in a list) + edge_key = ( + edge_group.get("start_id", ""), + edge_group.get("end_id", ""), + edge_group.get("label", ""), + ) + if edge_key not in edges_dict: + edges_dict[edge_key] = { + "type": edge_group.get("label", ""), + "source": edge_group.get("start_id", ""), + "target": edge_group.get("end_id", ""), + } + + edges = list(edges_dict.values()) + + return self._convert_graph_edges( + {"core_node": core_node, "neighbors": neighbors, "edges": edges} + ) + + except Exception as e: + logger.error(f"Failed to get subgraph: {e}", exc_info=True) + return {"core_node": None, "neighbors": [], "edges": []} + finally: + self._return_connection(conn) + + def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: + """Get the ordered context chain starting from a node.""" + raise NotImplementedError + + @timed + def get_neighbors_by_tag( + self, + tags: list[str], + exclude_ids: list[str], + top_k: int = 5, + min_overlap: int = 1, + include_embedding: bool = False, + user_name: str | None = None, + ) -> list[dict[str, Any]]: + """ + Find top-K neighbor nodes with maximum tag overlap. + + Args: + tags: The list of tags to match. + exclude_ids: Node IDs to exclude (e.g., local cluster). 
+ top_k: Max number of neighbors to return. + min_overlap: Minimum number of overlapping tags required. + include_embedding: whether to include the embedding vector in returned nodes + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + List of dicts with node details and overlap count. + """ + if not tags: + return [] + + user_name = user_name if user_name else self._get_config_value("user_name") + + # Build query conditions - more relaxed filters + where_clauses = [] + params = [] + + # Exclude specified IDs - use id in properties + if exclude_ids: + exclude_conditions = [] + for exclude_id in exclude_ids: + exclude_conditions.append( + "ag_catalog.agtype_access_operator(properties::text::agtype, '\"id\"'::agtype) != %s::agtype" + ) + params.append(self.format_param_value(exclude_id)) + where_clauses.append(f"({' AND '.join(exclude_conditions)})") + + # Status filter - keep only 'activated' + where_clauses.append( + "ag_catalog.agtype_access_operator(properties::text::agtype, '\"status\"'::agtype) = '\"activated\"'::agtype" + ) + + # Type filter - exclude 'reasoning' type + where_clauses.append( + "ag_catalog.agtype_access_operator(properties::text::agtype, '\"node_type\"'::agtype) != '\"reasoning\"'::agtype" + ) + + # User filter + where_clauses.append( + "ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = %s::agtype" + ) + params.append(self.format_param_value(user_name)) + + # Memory-type filter - exclude WorkingMemory nodes from tag matching + where_clauses.append( + "ag_catalog.agtype_access_operator(properties::text::agtype, '\"memory_type\"'::agtype) != '\"WorkingMemory\"'::agtype" + ) + + where_clause = " AND ".join(where_clauses) + + # Fetch all candidate nodes + query = f""" + SELECT id, properties, embedding + FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + + logger.debug(f"[get_neighbors_by_tag] query: {query}, params: {params}") + + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + + nodes_with_overlap = [] + for row in results: + node_id, properties_json, embedding_json = row + properties = properties_json if properties_json else {} + + # Parse embedding + if include_embedding and embedding_json is not None: + try: + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json + ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") + + # Compute tag overlap + node_tags = properties.get("tags", []) + if isinstance(node_tags, str): + try: + node_tags = json.loads(node_tags) + except (json.JSONDecodeError, TypeError): + node_tags = [] + + overlap_tags = [tag for tag in tags if tag in node_tags] + overlap_count = len(overlap_tags) + + if overlap_count >= min_overlap: + node_data = self._parse_node( + { + "id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } + ) + nodes_with_overlap.append((node_data, overlap_count)) + + # Sort by overlap count and return top_k items + nodes_with_overlap.sort(key=lambda x: x[1], reverse=True) + return [node for node, _ in nodes_with_overlap[:top_k]] + + except Exception as e: + logger.error(f"Failed to get neighbors by tag: {e}", exc_info=True) + return [] + finally: + self._return_connection(conn) diff --git a/src/memos/graph_dbs/utils.py b/src/memos/graph_dbs/utils.py new file mode 100644 index 000000000..d4975075c --- /dev/null +++ 
b/src/memos/graph_dbs/utils.py @@ -0,0 +1,62 @@ +"""Shared utilities for graph database backends.""" + +from datetime import datetime +from typing import Any + +import numpy as np + + +def compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: + """Extract id, memory, and metadata from a node dict.""" + node_id = item["id"] + memory = item["memory"] + metadata = item.get("metadata", {}) + return node_id, memory, metadata + + +def prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: + """ + Ensure metadata has proper datetime fields and normalized types. + + - Fill `created_at` and `updated_at` if missing (in ISO 8601 format). + - Convert embedding to list of float if present. + """ + now = datetime.utcnow().isoformat() + + # Fill timestamps if missing + metadata.setdefault("created_at", now) + metadata.setdefault("updated_at", now) + + # Normalize embedding type + embedding = metadata.get("embedding") + if embedding and isinstance(embedding, list): + metadata["embedding"] = [float(x) for x in embedding] + + return metadata + + +def convert_to_vector(embedding_list): + """Convert an embedding list to PostgreSQL vector string format.""" + if not embedding_list: + return None + if isinstance(embedding_list, np.ndarray): + embedding_list = embedding_list.tolist() + return "[" + ",".join(str(float(x)) for x in embedding_list) + "]" + + +def detect_embedding_field(embedding_list): + """Detect the embedding field name based on vector dimension.""" + if not embedding_list: + return None + dim = len(embedding_list) + if dim == 1024: + return "embedding" + return None + + +def clean_properties(props): + """Remove vector fields from properties dict.""" + vector_keys = {"embedding", "embedding_1024", "embedding_3072", "embedding_768"} + if not isinstance(props, dict): + return {} + return {k: v for k, v in props.items() if k not in vector_keys}
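
Usage sketch (illustrative only; assumes the helpers defined in src/memos/graph_dbs/utils.py above and a hypothetical node payload) showing how a backend might normalize a node before persisting it:

    from memos.graph_dbs.utils import (
        clean_properties,
        compose_node,
        convert_to_vector,
        prepare_node_metadata,
    )

    # Hypothetical node dict as produced upstream by the memory pipeline.
    item = {
        "id": "node-123",
        "memory": "User prefers morning meetings.",
        "metadata": {"tags": ["preference"], "embedding": [0.1] * 1024},
    }

    node_id, memory, metadata = compose_node(item)
    metadata = prepare_node_metadata(metadata)  # fills created_at/updated_at, casts embedding values to float
    vector_literal = convert_to_vector(metadata["embedding"])  # "[0.1,0.1,...]" string for a pgvector column
    properties = clean_properties({**metadata, "memory": memory})  # drops embedding fields before storing as JSONB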