diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
new file mode 100644
index 000000000..becc2f783
--- /dev/null
+++ b/DEVELOPMENT.md
@@ -0,0 +1,168 @@
+# Development Workflow
+
+## 🎯 Main repository for development work
+
+**Use:** `/home/krolik/MemOSina`
+
+## 📋 Workflow for changes
+
+### 1. Make changes locally
+```bash
+cd /home/krolik/MemOSina
+git checkout -b feature/my-feature
+# Edit the code
+```
+
+### 2. Commit and push
+```bash
+git add .
+git commit -m "feat: description of the change"
+git push origin feature/my-feature
+```
+
+### 3. CI/CD runs automatically
+GitHub Actions runs all checks:
+- **16 matrix builds:**
+  - 4 OSes: ubuntu, windows, macos-14, macos-15
+  - 4 Python versions: 3.10, 3.11, 3.12, 3.13
+
+- **Checks:**
+  - ✅ Dependency installation
+  - ✅ sdist and wheel build
+  - ✅ Ruff linting (`ruff check`)
+  - ✅ Ruff formatting (`ruff format --check`)
+  - ✅ PyTest unit tests
+
+### 4. Update krolik-server
+After pushing to GitHub:
+```bash
+cd ~/krolik-server/services/memos-core
+git pull origin main  # or the branch you need
+cd ../..
+docker compose build --no-cache memos-api memos-mcp
+docker compose restart memos-api memos-mcp
+```
+
+## 🔒 Branch Protection (main branch)
+
+✅ **Configured:**
+- CI checks are required for Python 3.10, 3.11, 3.12, 3.13 on ubuntu-latest
+- Strict mode: the branch must be up to date
+- Force pushes are disallowed
+- Branch deletion is disallowed
+
+## 🧪 Local checks before committing
+
+### Pre-commit hooks (optional)
+```bash
+# Install pre-commit
+pip install --user pre-commit
+
+# In the MemOSina directory
+cd /home/krolik/MemOSina
+pre-commit install
+
+# Run manually
+pre-commit run --all-files
+```
+
+### Manual checks with Ruff
+```bash
+# In the container or locally
+cd /home/krolik/MemOSina
+
+# Style check
+ruff check .
+
+# Auto-fix
+ruff check . --fix
+
+# Formatting check
+ruff format --check .
+
+# Auto-format
+ruff format .
+```
+
+## 📊 Checking CI status
+
+```bash
+cd /home/krolik/MemOSina
+
+# List recent runs
+gh run list --limit 10
+
+# Status for a specific branch
+gh run list --branch feature/my-feature
+
+# View logs of the latest run
+gh run view --log
+```
+
+## 🔄 Syncing with upstream MemOS
+
+```bash
+cd /home/krolik/MemOSina
+
+# Add the upstream remote (if not added yet)
+git remote add upstream https://github.com/MemTensor/MemOS.git
+
+# Fetch updates
+git fetch upstream
+
+# Merge into main
+git checkout main
+git merge upstream/main
+
+# Resolve conflicts if any
+# git add .
+# git commit
+
+# Push to the fork
+git push origin main
+```
+
+## 📁 Repository layout
+
+```
+/home/krolik/
+├── MemOSina/              ⭐ MAIN - all development happens here
+│   ├── .github/workflows/ - CI/CD configuration
+│   ├── src/memos/         - Source code with patches
+│   └── tests/             - Tests
+│
+├── memos-pr-work/         🔧 For preparing PRs to upstream
+│   └── (PR branches: fix/*, feat/*)
+│
+└── krolik-server/
+    ├── services/
+    │   └── memos-core/    📦 Git submodule → MemOSina
+    └── docker-compose.yml
+```
+
+## ✅ Quality guarantees
+
+With this setup, every commit to main goes through:
+- ✅ 16 matrix builds (4 OSes × 4 Python versions)
+- ✅ Ruff checks (lint and formatting)
+- ✅ Unit tests
+- ✅ Dependency checks
+
+**Your fork is now held to the same quality bar as upstream MemOS!**
+
+## 🚀 Quick Reference
+
+| Task | Command |
+|------|---------|
+| Create a branch | `git checkout -b feature/name` |
+| Push changes | `git push origin feature/name` |
+| Check CI | `gh run list --branch feature/name` |
+| Update the submodule | `cd ~/krolik-server/services/memos-core && git pull` |
+| Rebuild containers | `docker compose build --no-cache memos-api memos-mcp` |
+| Restart services | `docker compose restart memos-api memos-mcp` |
+| Check code with Ruff | `ruff check . && ruff format --check .` |
+
+---
+
+**Make all changes in `/home/krolik/MemOSina`.**
+**CI/CD guards quality before anything reaches upstream!**
diff --git a/docker/Dockerfile.krolik b/docker/Dockerfile.krolik
new file mode 100644
index 000000000..c475a6d30
--- /dev/null
+++ b/docker/Dockerfile.krolik
@@ -0,0 +1,65 @@
+# MemOS with Krolik Security Extensions
+#
+# This Dockerfile builds MemOS with authentication, rate limiting, and admin API.
+# It uses the overlay pattern to keep customizations separate from base code.
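+#
+# Illustrative overlay layout (an assumption, not verified against this repo;
+# exact file names may differ):
+#   overlays/krolik/api/server_api_ext.py -> shadows src/memos/api/server_api_ext.py
+# Files under overlays/krolik/ mirror the src/memos/ package layout and are
+# copied over it near the end of the build, so overlay-only edits reuse the
+# cached dependency layers.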
+ +FROM python:3.11-slim + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + g++ \ + build-essential \ + libffi-dev \ + python3-dev \ + curl \ + libpq-dev \ + && rm -rf /var/lib/apt/lists/* + +# Create non-root user +RUN groupadd -r memos && useradd -r -g memos -u 1000 memos + +WORKDIR /app + +# Use official Hugging Face +ENV HF_ENDPOINT=https://huggingface.co + +# Copy base MemOS source +COPY src/ ./src/ +COPY pyproject.toml ./ + +# Install base dependencies +RUN pip install --upgrade pip && \ + pip install --no-cache-dir poetry && \ + poetry config virtualenvs.create false && \ + poetry install --no-dev --extras "tree-mem mem-scheduler" + +# Install additional dependencies for Krolik +RUN pip install --no-cache-dir \ + sentence-transformers \ + torch \ + transformers \ + psycopg2-binary \ + redis + +# Apply Krolik overlay (AFTER base install to allow easy updates) +COPY overlays/krolik/ ./src/memos/ + +# Create data directory +RUN mkdir -p /data/memos && chown -R memos:memos /data/memos +RUN chown -R memos:memos /app + +# Set Python path +ENV PYTHONPATH=/app/src + +# Switch to non-root user +USER memos + +EXPOSE 8000 + +# Healthcheck +HEALTHCHECK --interval=30s --timeout=10s --retries=3 --start-period=60s \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Use extended entry point with security features +CMD ["gunicorn", "memos.api.server_api_ext:app", "--preload", "-w", "2", "-k", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8000", "--timeout", "120"] diff --git a/docker/requirements.txt b/docker/requirements.txt index 340f4e140..be72527af 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -98,6 +98,7 @@ rich-toolkit==0.15.1 rignore==0.7.6 rpds-py==0.28.0 safetensors==0.6.2 +schedule==1.2.2 scikit-learn==1.7.2 scipy==1.16.3 sentry-sdk==2.44.0 diff --git a/examples/mcp/mcp_serve.py b/examples/mcp/mcp_serve.py new file mode 100644 index 000000000..901524b12 --- /dev/null +++ b/examples/mcp/mcp_serve.py @@ -0,0 +1,614 @@ +import asyncio +import os + +from typing import Any + +from dotenv import load_dotenv +from fastmcp import FastMCP + +# Assuming these are your imports +from memos.mem_os.main import MOS +from memos.mem_os.utils.default_config import get_default +from memos.mem_user.user_manager import UserRole + + +load_dotenv() + + +def load_default_config(user_id="default_user"): + """ + Load MOS configuration from environment variables. + + IMPORTANT for Neo4j Community Edition: + Community Edition does not support administrative commands like 'CREATE DATABASE'. 
+ To avoid errors, ensure the following environment variables are set correctly: + - NEO4J_DB_NAME=neo4j (Must use the default database) + - NEO4J_AUTO_CREATE=false (Disable automatic database creation) + - NEO4J_USE_MULTI_DB=false (Disable multi-tenant database mode) + """ + # Define mapping between environment variables and configuration parameters + # We support both clean names and MOS_ prefixed names for compatibility + env_mapping = { + "OPENAI_API_KEY": "openai_api_key", + "OPENAI_API_BASE": "openai_api_base", + "MOS_TEXT_MEM_TYPE": "text_mem_type", + "NEO4J_URI": "neo4j_uri", + "NEO4J_USER": "neo4j_user", + "NEO4J_PASSWORD": "neo4j_password", + "NEO4J_DB_NAME": "neo4j_db_name", + "NEO4J_AUTO_CREATE": "neo4j_auto_create", + "NEO4J_USE_MULTI_DB": "use_multi_db", + "MOS_NEO4J_SHARED_DB": "mos_shared_db", # Special handle later + "MODEL_NAME": "model_name", + "MOS_CHAT_MODEL": "model_name", + "EMBEDDER_MODEL": "embedder_model", + "MOS_EMBEDDER_MODEL": "embedder_model", + "CHUNK_SIZE": "chunk_size", + "CHUNK_OVERLAP": "chunk_overlap", + "ENABLE_MEM_SCHEDULER": "enable_mem_scheduler", + "MOS_ENABLE_SCHEDULER": "enable_mem_scheduler", + "ENABLE_ACTIVATION_MEMORY": "enable_activation_memory", + "TEMPERATURE": "temperature", + "MOS_CHAT_TEMPERATURE": "temperature", + "MAX_TOKENS": "max_tokens", + "MOS_MAX_TOKENS": "max_tokens", + "TOP_P": "top_p", + "MOS_TOP_P": "top_p", + "TOP_K": "top_k", + "MOS_TOP_K": "top_k", + "SCHEDULER_TOP_K": "scheduler_top_k", + "MOS_SCHEDULER_TOP_K": "scheduler_top_k", + "SCHEDULER_TOP_N": "scheduler_top_n", + } + + # Fields that should always be kept as strings (not converted to numbers) + string_only_fields = { + "openai_api_key", + "openai_api_base", + "neo4j_uri", + "neo4j_user", + "neo4j_password", + "neo4j_db_name", + "text_mem_type", + "model_name", + "embedder_model", + } + + kwargs = {"user_id": user_id} + for env_key, param_key in env_mapping.items(): + val = os.getenv(env_key) + if val is not None: + # Strip quotes if they exist (sometimes happens with .env) + if (val.startswith('"') and val.endswith('"')) or ( + val.startswith("'") and val.endswith("'") + ): + val = val[1:-1] + + # Handle boolean conversions + if val.lower() in ("true", "false"): + kwargs[param_key] = val.lower() == "true" + # Keep certain fields as strings + elif param_key in string_only_fields: + kwargs[param_key] = val + else: + # Try numeric conversions (int first, then float) + try: + if "." 
in val: + kwargs[param_key] = float(val) + else: + kwargs[param_key] = int(val) + except ValueError: + kwargs[param_key] = val + + # Logic handle for MOS_NEO4J_SHARED_DB vs use_multi_db + if "mos_shared_db" in kwargs: + kwargs["use_multi_db"] = not kwargs.pop("mos_shared_db") + + # Extract mandatory or special params + openai_api_key = kwargs.pop("openai_api_key", os.getenv("OPENAI_API_KEY")) + openai_api_base = kwargs.pop("openai_api_base", "https://api.openai.com/v1") + text_mem_type = kwargs.pop("text_mem_type", "tree_text") + + # Ensure embedder_model has a default value if not set + if "embedder_model" not in kwargs: + kwargs["embedder_model"] = os.getenv("EMBEDDER_MODEL", "nomic-embed-text:latest") + + config, cube = get_default( + openai_api_key=openai_api_key, + openai_api_base=openai_api_base, + text_mem_type=text_mem_type, + **kwargs, + ) + return config, cube + + +class MOSMCPStdioServer: + def __init__(self): + self.mcp = FastMCP("MOS Memory System") + config, cube = load_default_config() + self.mos_core = MOS(config=config) + self.mos_core.register_mem_cube(cube) + self._setup_tools() + + +class MOSMCPServer: + """MCP Server that accepts an existing MOS instance.""" + + def __init__(self, mos_instance: MOS | None = None): + self.mcp = FastMCP("MOS Memory System") + if mos_instance is None: + # Fall back to creating from default config + config, cube = load_default_config() + self.mos_core = MOS(config=config) + self.mos_core.register_mem_cube(cube) + else: + self.mos_core = mos_instance + self._setup_tools() + + def _setup_tools(self): + """Setup MCP tools""" + + @self.mcp.tool() + async def chat(query: str, user_id: str | None = None) -> str: + """ + Chat with MOS system using memory-enhanced responses. + + This method provides intelligent responses by searching through user's memory cubes + and incorporating relevant context. It supports both standard chat mode and enhanced + Chain of Thought (CoT) mode for complex queries when PRO_MODE is enabled. + + Args: + query (str): The user's query or question to be answered + user_id (str, optional): User ID for the chat session. If not provided, uses the default user + + Returns: + str: AI-generated response incorporating relevant memories and context + """ + try: + response = self.mos_core.chat(query, user_id) + return response + except Exception as e: + import traceback + + error_details = traceback.format_exc() + return f"Chat error: {e!s}\nTraceback:\n{error_details}" + + @self.mcp.tool() + async def create_user( + user_id: str, role: str = "USER", user_name: str | None = None + ) -> str: + """ + Create a new user in the MOS system. + + This method creates a new user account with specified role and name. + Users can have different access levels and can own or access memory cubes. + + Args: + user_id (str): Unique identifier for the user + role (str): User role - "USER" for regular users, "ADMIN" for administrators + user_name (str, optional): Display name for the user. If not provided, uses user_id + + Returns: + str: Success message with the created user ID + """ + try: + user_role = UserRole.ADMIN if role.upper() == "ADMIN" else UserRole.USER + created_user_id = self.mos_core.create_user(user_id, user_role, user_name) + return f"User created successfully: {created_user_id}" + except Exception as e: + return f"Error creating user: {e!s}" + + @self.mcp.tool() + async def create_cube( + cube_name: str, owner_id: str, cube_path: str | None = None, cube_id: str | None = None + ) -> str: + """ + Create a new memory cube for a user. 
+ + Memory cubes are containers that store different types of memories (textual, activation, parametric). + Each cube can be owned by a user and shared with other users. + + Args: + cube_name (str): Human-readable name for the memory cube + owner_id (str): User ID of the cube owner who has full control + cube_path (str, optional): File system path where cube data will be stored + cube_id (str, optional): Custom unique identifier for the cube. If not provided, one will be generated + + Returns: + str: Success message with the created cube ID + """ + try: + created_cube_id = self.mos_core.create_cube_for_user( + cube_name, owner_id, cube_path, cube_id + ) + return f"Cube created successfully: {created_cube_id}" + except Exception as e: + return f"Error creating cube: {e!s}" + + @self.mcp.tool() + async def register_cube( + cube_name_or_path: str, cube_id: str | None = None, user_id: str | None = None + ) -> str: + """ + Register an existing memory cube with the MOS system. + + This method loads and registers a memory cube from a file path or creates a new one + if the path doesn't exist. The cube becomes available for memory operations. + + Args: + cube_name_or_path (str): File path to the memory cube or name for a new cube + cube_id (str, optional): Custom identifier for the cube. If not provided, one will be generated + user_id (str, optional): User ID to associate with the cube. If not provided, uses default user + + Returns: + str: Success message with the registered cube ID + """ + try: + if not os.path.exists(cube_name_or_path): + _, cube = load_default_config(user_id=user_id) + cube_to_register = cube + else: + cube_to_register = cube_name_or_path + self.mos_core.register_mem_cube( + cube_to_register, mem_cube_id=cube_id, user_id=user_id + ) + return f"Cube registered successfully: {cube_id or cube_to_register}" + except Exception as e: + return f"Error registering cube: {e!s}" + + @self.mcp.tool() + async def unregister_cube(cube_id: str, user_id: str | None = None) -> str: + """ + Unregister a memory cube from the MOS system. + + This method removes a memory cube from the active session, making it unavailable + for memory operations. The cube data remains intact on disk. + + Args: + cube_id (str): Unique identifier of the cube to unregister + user_id (str, optional): User ID for access validation. If not provided, uses default user + + Returns: + str: Success message confirming the cube was unregistered + """ + try: + self.mos_core.unregister_mem_cube(cube_id, user_id) + return f"Cube unregistered successfully: {cube_id}" + except Exception as e: + return f"Error unregistering cube: {e!s}" + + @self.mcp.tool() + async def search_memories( + query: str, user_id: str | None = None, cube_ids: list[str] | None = None + ) -> dict[str, Any]: + """ + Search for memories across user's accessible memory cubes. + + This method performs semantic search through textual memories stored in the specified + cubes, returning relevant memories based on the query. Results are ranked by relevance. + + Args: + query (str): Search query to find relevant memories + user_id (str, optional): User ID whose cubes to search. If not provided, uses default user + cube_ids (list[str], optional): Specific cube IDs to search. 
If not provided, searches all user's cubes + + Returns: + dict: Search results containing text_mem, act_mem, and para_mem categories with relevant memories + """ + try: + result = self.mos_core.search(query, user_id, install_cube_ids=cube_ids) + return result + except Exception as e: + import traceback + + error_details = traceback.format_exc() + return {"error": str(e), "traceback": error_details} + + @self.mcp.tool() + async def add_memory( + memory_content: str | None = None, + doc_path: str | None = None, + messages: list[dict[str, str]] | None = None, + cube_id: str | None = None, + user_id: str | None = None, + ) -> str: + """ + Add memories to a memory cube. + + This method can add memories from different sources: direct text content, document files, + or conversation messages. The memories are processed and stored in the specified cube. + + Args: + memory_content (str, optional): Direct text content to add as memory + doc_path (str, optional): Path to a document file to process and add as memories + messages (list[dict[str, str]], optional): List of conversation messages to add as memories + cube_id (str, optional): Target cube ID. If not provided, uses user's default cube + user_id (str, optional): User ID for access validation. If not provided, uses default user + + Returns: + str: Success message confirming memories were added + """ + try: + self.mos_core.add( + messages=messages, + memory_content=memory_content, + doc_path=doc_path, + mem_cube_id=cube_id, + user_id=user_id, + ) + return "Memory added successfully" + except Exception as e: + return f"Error adding memory: {e!s}" + + @self.mcp.tool() + async def get_memory( + cube_id: str, memory_id: str, user_id: str | None = None + ) -> dict[str, Any]: + """ + Retrieve a specific memory from a memory cube. + + This method fetches a single memory item by its unique identifier from the specified cube. + + Args: + cube_id (str): Unique identifier of the cube containing the memory + memory_id (str): Unique identifier of the specific memory to retrieve + user_id (str, optional): User ID for access validation. If not provided, uses default user + + Returns: + dict: Memory content with metadata including memory text, creation time, and source + """ + try: + memory = self.mos_core.get(cube_id, memory_id, user_id) + return {"memory": str(memory)} + except Exception as e: + return {"error": str(e)} + + @self.mcp.tool() + async def update_memory( + cube_id: str, memory_id: str, memory_content: str, user_id: str | None = None + ) -> str: + """ + Update an existing memory in a memory cube. + + This method modifies the content of a specific memory while preserving its metadata. + Note: Update functionality may not be supported by all memory backends (e.g., tree_text). + + Args: + cube_id (str): Unique identifier of the cube containing the memory + memory_id (str): Unique identifier of the memory to update + memory_content (str): New content to replace the existing memory + user_id (str, optional): User ID for access validation. 
If not provided, uses default user + + Returns: + str: Success message confirming the memory was updated + """ + try: + from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata + + metadata = TextualMemoryMetadata( + user_id=user_id or self.mos_core.user_id, + session_id=self.mos_core.session_id, + source="mcp_update", + ) + memory_item = TextualMemoryItem(memory=memory_content, metadata=metadata) + + self.mos_core.update(cube_id, memory_id, memory_item, user_id) + return f"Memory updated successfully: {memory_id}" + except Exception as e: + return f"Error updating memory: {e!s}" + + @self.mcp.tool() + async def delete_memory(cube_id: str, memory_id: str, user_id: str | None = None) -> str: + """ + Delete a specific memory from a memory cube. + + This method permanently removes a memory item from the specified cube. + The operation cannot be undone. + + Args: + cube_id (str): Unique identifier of the cube containing the memory + memory_id (str): Unique identifier of the memory to delete + user_id (str, optional): User ID for access validation. If not provided, uses default user + + Returns: + str: Success message confirming the memory was deleted + """ + try: + self.mos_core.delete(cube_id, memory_id, user_id) + return f"Memory deleted successfully: {memory_id}" + except Exception as e: + return f"Error deleting memory: {e!s}" + + @self.mcp.tool() + async def delete_all_memories(cube_id: str, user_id: str | None = None) -> str: + """ + Delete all memories from a memory cube. + + This method permanently removes all memory items from the specified cube. + The operation cannot be undone and will clear all textual memories. + + Args: + cube_id (str): Unique identifier of the cube to clear + user_id (str, optional): User ID for access validation. If not provided, uses default user + + Returns: + str: Success message confirming all memories were deleted + """ + try: + self.mos_core.delete_all(cube_id, user_id) + return f"All memories deleted successfully from cube: {cube_id}" + except Exception as e: + return f"Error deleting all memories: {e!s}" + + @self.mcp.tool() + async def clear_chat_history(user_id: str | None = None) -> str: + """ + Clear the chat history for a user. + + This method resets the conversation history, removing all previous messages + while keeping the memory cubes and stored memories intact. + + Args: + user_id (str, optional): User ID whose chat history to clear. If not provided, uses default user + + Returns: + str: Success message confirming chat history was cleared + """ + try: + self.mos_core.clear_messages(user_id) + target_user = user_id or self.mos_core.user_id + return f"Chat history cleared for user: {target_user}" + except Exception as e: + return f"Error clearing chat history: {e!s}" + + @self.mcp.tool() + async def dump_cube( + dump_dir: str, user_id: str | None = None, cube_id: str | None = None + ) -> str: + """ + Export a memory cube to a directory. + + This method creates a backup or export of a memory cube, including all memories + and metadata, to the specified directory for backup or migration purposes. + + Args: + dump_dir (str): Directory path where the cube data will be exported + user_id (str, optional): User ID for access validation. If not provided, uses default user + cube_id (str, optional): Cube ID to export. 
If not provided, uses user's default cube + + Returns: + str: Success message with the export directory path + """ + try: + self.mos_core.dump(dump_dir, user_id, cube_id) + return f"Cube dumped successfully to: {dump_dir}" + except Exception as e: + return f"Error dumping cube: {e!s}" + + @self.mcp.tool() + async def share_cube(cube_id: str, target_user_id: str) -> str: + """ + Share a memory cube with another user. + + This method grants access to a memory cube to another user, allowing them + to read and search through the memories stored in that cube. + + Args: + cube_id (str): Unique identifier of the cube to share + target_user_id (str): User ID of the person to share the cube with + + Returns: + str: Success message confirming the cube was shared or error message if failed + """ + try: + success = self.mos_core.share_cube_with_user(cube_id, target_user_id) + if success: + return f"Cube {cube_id} shared successfully with user {target_user_id}" + else: + return f"Failed to share cube {cube_id} with user {target_user_id}" + except Exception as e: + return f"Error sharing cube: {e!s}" + + @self.mcp.tool() + async def get_user_info(user_id: str | None = None) -> dict[str, Any]: + """ + Get detailed information about a user and their accessible memory cubes. + + This method returns comprehensive user information including profile details, + role, creation time, and a list of all memory cubes the user can access. + + Args: + user_id (str, optional): User ID to get information for. If not provided, uses current user + + Returns: + dict: User information including user_id, user_name, role, created_at, and accessible_cubes + """ + try: + if user_id and user_id != self.mos_core.user_id: + # Temporarily switch user + original_user = self.mos_core.user_id + self.mos_core.user_id = user_id + user_info = self.mos_core.get_user_info() + self.mos_core.user_id = original_user + return user_info + else: + return self.mos_core.get_user_info() + except Exception as e: + return {"error": str(e)} + + @self.mcp.tool() + async def control_memory_scheduler(action: str) -> str: + """ + Control the memory scheduler service. + + The memory scheduler is responsible for processing and organizing memories + in the background. This method allows starting or stopping the scheduler service. + + Args: + action (str): Action to perform - "start" to enable the scheduler, "stop" to disable it + + Returns: + str: Success message confirming the scheduler action or error message if failed + """ + try: + if action.lower() == "start": + success = self.mos_core.mem_scheduler_on() + return ( + "Memory scheduler started" + if success + else "Failed to start memory scheduler" + ) + elif action.lower() == "stop": + success = self.mos_core.mem_scheduler_off() + return ( + "Memory scheduler stopped" if success else "Failed to stop memory scheduler" + ) + else: + return "Invalid action. 
Use 'start' or 'stop'" + except Exception as e: + return f"Error controlling memory scheduler: {e!s}" + + +def _run_mcp(self, transport: str = "stdio", **kwargs): + if transport == "stdio": + self.mcp.run(transport="stdio") + elif transport == "http": + host = kwargs.get("host", "localhost") + port = kwargs.get("port", 8000) + asyncio.run(self.mcp.run_http_async(host=host, port=port)) + elif transport == "sse": + host = kwargs.get("host", "localhost") + port = kwargs.get("port", 8000) + self.mcp.run(transport="sse", host=host, port=port) + else: + raise ValueError(f"Unsupported transport: {transport}") + + +MOSMCPStdioServer.run = _run_mcp +MOSMCPServer.run = _run_mcp + + +# Usage example +if __name__ == "__main__": + import argparse + + from dotenv import load_dotenv + + load_dotenv() + + # Parse command line arguments + parser = argparse.ArgumentParser(description="MOS MCP Server") + parser.add_argument( + "--transport", + choices=["stdio", "http", "sse"], + default="stdio", + help="Transport method (default: stdio)", + ) + parser.add_argument("--host", default="localhost", help="Host for HTTP/SSE transport") + parser.add_argument("--port", type=int, default=8000, help="Port for HTTP/SSE transport") + + args = parser.parse_args() + + # Create and run MCP server + server = MOSMCPStdioServer() + server.run(transport=args.transport, host=args.host, port=args.port) diff --git a/scripts/tools/verify_age_fix.py b/scripts/tools/verify_age_fix.py new file mode 100644 index 000000000..749f20813 --- /dev/null +++ b/scripts/tools/verify_age_fix.py @@ -0,0 +1,98 @@ +import logging +import os +import sys + + +logging.basicConfig(level=logging.INFO) + +# Ensure /app/src is in path +sys.path.append("/app/src") + +# --- Test PolarDBGraphDB --- +try: + print("\n[Test 1] Testing PolarDBGraphDB...") + # Import from graph_dbs.polardb + # Class name is PolarDBGraphDB + from memos.configs.graph_db import PolarDBConfig + from memos.graph_dbs.polardb import PolarDBGraphDB + print("Successfully imported PolarDBGraphDB") +except ImportError as e: + print(f"Failed to import PolarDBGraphDB: {e}") + sys.exit(1) + +# Credentials from docker inspect +config = PolarDBConfig( + host="postgres", + port=5432, + user="memos", + password="K2DscvW8JoBmSpEV4WIM856E6XtVl0s", + db_name="memos", + auto_create=False, + use_multi_db=False, # Shared DB mode usually + user_name="memos_default" +) + +try: + print("Initializing PolarDBGraphDB...") + db = PolarDBGraphDB(config) + print("Initialized.") + + print("Checking connection (via simple query)...") + # node_not_exist uses agtype_access_operator + count = db.node_not_exist("memo") + print(f"node_not_exist result: {count}") + + # Try get_node + node = db.get_node("dummy_id_12345") + print(f"get_node result: {node}") + + print("SUCCESS: PolarDBGraphDB test passed.") + +except Exception as e: + print(f"FAILURE PolarDBGraphDB: {e}") + import traceback + traceback.print_exc() + + +# --- Test Embedder --- +print("\n[Test 2] Testing UniversalAPIEmbedder (VoyageAI)...") +try: + from memos.configs.embedder import UniversalAPIEmbedderConfig + from memos.embedders.universal_api import UniversalAPIEmbedder + print("Successfully imported UniversalAPIEmbedder") + + # Values from our api_config.py logic + # api_config.py defaults for voyageai: + # provider="openai" + # base_url="https://api.voyageai.com/v1" + # api_key="pa-7v..." 
(VOYAGE_API_KEY from env) + + # We need to manually set these or load from env + # Env var VOYAGE_API_KEY should be present in container + voyage_key = os.getenv("VOYAGE_API_KEY", "missing_key") + + embedder_config = UniversalAPIEmbedderConfig( + provider="openai", + api_key=voyage_key, + base_url="https://api.voyageai.com/v1", + model_name_or_path="voyage-4-lite" + ) + + print(f"Initializing Embedder with Base URL: {embedder_config.base_url}") + embedder = UniversalAPIEmbedder(embedder_config) + + print("Generating embedding for 'Hellos World'...") + # embed method returns list[list[float]] + embeddings = embedder.embed(["Hellos World"]) + + print(f"Embeddings generated. Count: {len(embeddings)}") + if len(embeddings) > 0: + print(f"Embedding vector length: {len(embeddings[0])}") + print("SUCCESS: Embedder test passed.") + else: + print("FAILURE: No embeddings returned.") + +except Exception as e: + print(f"FAILURE Embedder: {e}") + import traceback + traceback.print_exc() diff --git a/src/memos/api/config.py b/src/memos/api/config.py index d27c391ab..6be376b8a 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -429,17 +429,36 @@ def get_feedback_reranker_config() -> dict[str, Any]: @staticmethod def get_embedder_config() -> dict[str, Any]: """Get embedder configuration.""" + print(f"DEBUG: get_embedder_config called. BACKEND={os.getenv('MOS_EMBEDDER_BACKEND')}") embedder_backend = os.getenv("MOS_EMBEDDER_BACKEND", "ollama") - if embedder_backend == "universal_api": + # Map voyageai to universal_api + if embedder_backend in ["universal_api", "voyageai"]: + # Default provider is openai (compatible client) + provider = os.getenv("MOS_EMBEDDER_PROVIDER", "openai") + + # Handle API Key + api_key = os.getenv("MOS_EMBEDDER_API_KEY") + if not api_key and embedder_backend == "voyageai": + api_key = os.getenv("VOYAGE_API_KEY") + if not api_key: + api_key = "sk-xxxx" + + # Handle Base URL + base_url = os.getenv("MOS_EMBEDDER_API_BASE") + if not base_url and embedder_backend == "voyageai": + base_url = "https://api.voyageai.com/v1" + if not base_url: + base_url = "http://openai.com" + return { "backend": "universal_api", "config": { - "provider": os.getenv("MOS_EMBEDDER_PROVIDER", "openai"), - "api_key": os.getenv("MOS_EMBEDDER_API_KEY", "sk-xxxx"), + "provider": provider, + "api_key": api_key, "model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"), "headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")), - "base_url": os.getenv("MOS_EMBEDDER_API_BASE", "http://openai.com"), + "base_url": base_url, "backup_client": os.getenv("MOS_EMBEDDER_BACKUP_CLIENT", "false").lower() == "true", "backup_base_url": os.getenv( diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 13dd92189..ba527d602 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -43,6 +43,7 @@ ) from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory +from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer @@ -190,6 +191,7 @@ def init_server() -> dict[str, Any]: ) embedder = EmbedderFactory.from_config(embedder_config) nli_client = NLIClient(base_url=nli_client_config["base_url"]) 
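+    # Assumed intent of the line below: the history manager reuses the NLI
+    # client and graph store constructed above, so its history bookkeeping runs
+    # against the same backends as the other server components.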
+ memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db) reranker = RerankerFactory.from_config(reranker_config) @@ -393,4 +395,5 @@ def init_server() -> dict[str, Any]: "redis_client": redis_client, "deepsearch_agent": deepsearch_agent, "nli_client": nli_client, + "memory_history_manager": memory_history_manager, } diff --git a/src/memos/api/mcp_serve.py b/src/memos/api/mcp_serve.py index ce2e41390..f9f09ec4a 100644 --- a/src/memos/api/mcp_serve.py +++ b/src/memos/api/mcp_serve.py @@ -59,6 +59,16 @@ def load_default_config(user_id="default_user"): "SCHEDULER_TOP_K": "scheduler_top_k", "MOS_SCHEDULER_TOP_K": "scheduler_top_k", "SCHEDULER_TOP_N": "scheduler_top_n", + # Graph DB backend selection (neo4j, polardb, etc.) + "GRAPH_DB_BACKEND": "graph_db_backend", + "NEO4J_BACKEND": "graph_db_backend", + # PolarDB connection (Postgres + Apache AGE) + "POLAR_DB_HOST": "polar_db_host", + "POLAR_DB_PORT": "polar_db_port", + "POLAR_DB_USER": "polar_db_user", + "POLAR_DB_PASSWORD": "polar_db_password", + "POLAR_DB_DB_NAME": "polar_db_name", + "EMBEDDING_DIMENSION": "embedding_dimension", } # Fields that should always be kept as strings (not converted to numbers) @@ -72,6 +82,11 @@ def load_default_config(user_id="default_user"): "text_mem_type", "model_name", "embedder_model", + "graph_db_backend", + "polar_db_host", + "polar_db_user", + "polar_db_password", + "polar_db_name", } kwargs = {"user_id": user_id} @@ -122,15 +137,6 @@ def load_default_config(user_id="default_user"): return config, cube -class MOSMCPStdioServer: - def __init__(self): - self.mcp = FastMCP("MOS Memory System") - config, cube = load_default_config() - self.mos_core = MOS(config=config) - self.mos_core.register_mem_cube(cube) - self._setup_tools() - - class MOSMCPServer: """MCP Server that accepts an existing MOS instance.""" @@ -584,7 +590,6 @@ def _run_mcp(self, transport: str = "stdio", **kwargs): raise ValueError(f"Unsupported transport: {transport}") -MOSMCPStdioServer.run = _run_mcp MOSMCPServer.run = _run_mcp @@ -610,5 +615,5 @@ def _run_mcp(self, transport: str = "stdio", **kwargs): args = parser.parse_args() # Create and run MCP server - server = MOSMCPStdioServer() + server = MOSMCPServer() server.run(transport=args.transport, host=args.host, port=args.port) diff --git a/src/memos/chunkers/base.py b/src/memos/chunkers/base.py index c2a783baa..25d517c98 100644 --- a/src/memos/chunkers/base.py +++ b/src/memos/chunkers/base.py @@ -1,3 +1,4 @@ +import re from abc import ABC, abstractmethod from memos.configs.chunker import BaseChunkerConfig @@ -22,3 +23,28 @@ def __init__(self, config: BaseChunkerConfig): @abstractmethod def chunk(self, text: str) -> list[Chunk]: """Chunk the given text into smaller chunks.""" + + def protect_urls(self, text: str) -> tuple[str, dict[str, str]]: + """Protect URLs in text from being split during chunking. 
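+
+        Example (illustrative; the placeholder format matches the code below):
+            "see https://example.com/a?x=1 for details"
+            -> ("see __URL_0__ for details", {"__URL_0__": "https://example.com/a?x=1"})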
+ + Returns: + tuple: (Text with URLs replaced by placeholders, URL mapping dictionary) + """ + url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+' + url_map = {} + + def replace_url(match): + url = match.group(0) + placeholder = f"__URL_{len(url_map)}__" + url_map[placeholder] = url + return placeholder + + protected_text = re.sub(url_pattern, replace_url, text) + return protected_text, url_map + + def restore_urls(self, text: str, url_map: dict[str, str]) -> str: + """Restore protected URLs in text back to their original form.""" + restored_text = text + for placeholder, url in url_map.items(): + restored_text = restored_text.replace(placeholder, url) + return restored_text diff --git a/src/memos/chunkers/charactertext_chunker.py b/src/memos/chunkers/charactertext_chunker.py index 15c0958ba..25739d96f 100644 --- a/src/memos/chunkers/charactertext_chunker.py +++ b/src/memos/chunkers/charactertext_chunker.py @@ -36,6 +36,8 @@ def __init__( def chunk(self, text: str, **kwargs) -> list[str] | list[Chunk]: """Chunk the given text into smaller chunks based on sentences.""" - chunks = self.chunker.split_text(text) + protected_text, url_map = self.protect_urls(text) + chunks = self.chunker.split_text(protected_text) + chunks = [self.restore_urls(chunk, url_map) for chunk in chunks] logger.debug(f"Generated {len(chunks)} chunks from input text") return chunks diff --git a/src/memos/chunkers/markdown_chunker.py b/src/memos/chunkers/markdown_chunker.py index b7771ac35..17a6f4632 100644 --- a/src/memos/chunkers/markdown_chunker.py +++ b/src/memos/chunkers/markdown_chunker.py @@ -1,3 +1,5 @@ +import re + from memos.configs.chunker import MarkdownChunkerConfig from memos.dependency import require_python_package from memos.log import get_logger @@ -22,6 +24,7 @@ def __init__( chunk_size: int = 1000, chunk_overlap: int = 200, recursive: bool = False, + auto_fix_headers: bool = True, ): from langchain_text_splitters import ( MarkdownHeaderTextSplitter, @@ -29,6 +32,7 @@ def __init__( ) self.config = config + self.auto_fix_headers = auto_fix_headers self.chunker = MarkdownHeaderTextSplitter( headers_to_split_on=config.headers_to_split_on if config @@ -46,17 +50,88 @@ def __init__( def chunk(self, text: str, **kwargs) -> list[str] | list[Chunk]: """Chunk the given text into smaller chunks based on sentences.""" - md_header_splits = self.chunker.split_text(text) + # Protect URLs first + protected_text, url_map = self.protect_urls(text) + # Auto-detect and fix malformed header hierarchy if enabled + if self.auto_fix_headers and self._detect_malformed_headers(protected_text): + logger.info("detected malformed header hierarchy, attempting to fix...") + protected_text = self._fix_header_hierarchy(protected_text) + logger.info("Header hierarchy fix completed") + + md_header_splits = self.chunker.split_text(protected_text) chunks = [] if self.chunker_recursive: md_header_splits = self.chunker_recursive.split_documents(md_header_splits) for doc in md_header_splits: try: chunk = " ".join(list(doc.metadata.values())) + "\n" + doc.page_content + chunk = self.restore_urls(chunk, url_map) chunks.append(chunk) except Exception as e: logger.warning(f"warning chunking document: {e}") - chunks.append(doc.page_content) + restored_chunk = self.restore_urls(doc.page_content, url_map) + chunks.append(restored_chunk) logger.info(f"Generated chunks: {chunks[:5]}") logger.debug(f"Generated {len(chunks)} chunks from input text") return chunks + + def _detect_malformed_headers(self, text: str) -> bool: + """Detect if markdown has 
improper header hierarchy usage.""" + header_levels = [] + pattern = re.compile(r'^#{1,6}\s+.+') + for line in text.split('\n'): + stripped_line = line.strip() + if pattern.match(stripped_line): + hash_match = re.match(r'^(#+)', stripped_line) + if hash_match: + level = len(hash_match.group(1)) + header_levels.append(level) + + total_headers = len(header_levels) + if total_headers == 0: + return False + + level1_count = sum(1 for level in header_levels if level == 1) + + # >90% are level-1 when total > 5, or all headers are level-1 when total <= 5 + if total_headers > 5: + level1_ratio = level1_count / total_headers + if level1_ratio > 0.9: + logger.warning( + f"Detected header hierarchy issue: {level1_count}/{total_headers} " + f"({level1_ratio:.1%}) of headers are level 1" + ) + return True + elif level1_count == total_headers: + logger.warning( + f"Detected header hierarchy issue: all {total_headers} headers are level 1" + ) + return True + return False + + def _fix_header_hierarchy(self, text: str) -> str: + """Fix markdown header hierarchy by keeping first header and incrementing the rest.""" + header_pattern = re.compile(r'^(#{1,6})\s+(.+)$') + lines = text.split('\n') + fixed_lines = [] + first_valid_header = False + + for line in lines: + stripped_line = line.strip() + header_match = header_pattern.match(stripped_line) + if header_match: + current_hashes, title_content = header_match.groups() + current_level = len(current_hashes) + + if not first_valid_header: + fixed_line = f"{current_hashes} {title_content}" + first_valid_header = True + else: + new_level = min(current_level + 1, 6) + new_hashes = '#' * new_level + fixed_line = f"{new_hashes} {title_content}" + fixed_lines.append(fixed_line) + else: + fixed_lines.append(line) + + return '\n'.join(fixed_lines) diff --git a/src/memos/chunkers/sentence_chunker.py b/src/memos/chunkers/sentence_chunker.py index f39dfb8e2..b02ef34a5 100644 --- a/src/memos/chunkers/sentence_chunker.py +++ b/src/memos/chunkers/sentence_chunker.py @@ -43,11 +43,12 @@ def __init__(self, config: SentenceChunkerConfig): def chunk(self, text: str) -> list[str] | list[Chunk]: """Chunk the given text into smaller chunks based on sentences.""" - chonkie_chunks = self.chunker.chunk(text) + protected_text, url_map = self.protect_urls(text) + chonkie_chunks = self.chunker.chunk(protected_text) chunks = [] for c in chonkie_chunks: - chunk = Chunk(text=c.text, token_count=c.token_count, sentences=c.sentences) + chunk = self.restore_urls(c.text, url_map) chunks.append(chunk) logger.debug(f"Generated {len(chunks)} chunks from input text") diff --git a/src/memos/chunkers/simple_chunker.py b/src/memos/chunkers/simple_chunker.py index cc0dc40d0..1e8bc211b 100644 --- a/src/memos/chunkers/simple_chunker.py +++ b/src/memos/chunkers/simple_chunker.py @@ -20,12 +20,27 @@ def _simple_split_text(self, text: str, chunk_size: int, chunk_overlap: int) -> Returns: List of text chunks """ - if not text or len(text) <= chunk_size: - return [text] if text.strip() else [] + import re + + # Protect URLs from being split + url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+' + url_map = {} + + def replace_url(match): + url = match.group(0) + placeholder = f"__URL_{len(url_map)}__" + url_map[placeholder] = url + return placeholder + + protected_text = re.sub(url_pattern, replace_url, text) + + if not protected_text or len(protected_text) <= chunk_size: + chunks = [protected_text] if protected_text.strip() else [] + return [self._restore_urls(c, url_map) for c in chunks] chunks = [] start = 0 - 
text_len = len(text) + text_len = len(protected_text) while start < text_len: # Calculate end position @@ -35,16 +50,22 @@ def _simple_split_text(self, text: str, chunk_size: int, chunk_overlap: int) -> if end < text_len: # Try to break at newline, sentence end, or space for separator in ["\n\n", "\n", "。", "!", "?", ". ", "! ", "? ", " "]: - last_sep = text.rfind(separator, start, end) + last_sep = protected_text.rfind(separator, start, end) if last_sep != -1: end = last_sep + len(separator) break - chunk = text[start:end].strip() + chunk = protected_text[start:end].strip() if chunk: chunks.append(chunk) # Move start position with overlap start = max(start + 1, end - chunk_overlap) - return chunks + return [self._restore_urls(c, url_map) for c in chunks] + + @staticmethod + def _restore_urls(text: str, url_map: dict[str, str]) -> str: + for placeholder, url in url_map.items(): + text = text.replace(placeholder, url) + return text diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 3b4bace0e..5ce9faad1 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -211,6 +211,58 @@ def validate_config(self): return self +class PostgresGraphDBConfig(BaseConfig): + """ + PostgreSQL + pgvector configuration for MemOS. + + Uses standard PostgreSQL with pgvector extension for vector search. + Does NOT require Apache AGE or other graph extensions. + + Schema: + - memos_memories: Main table for memory nodes (id, memory, properties JSONB, embedding vector) + - memos_edges: Edge table for relationships (source_id, target_id, type) + + Example: + --- + host = "postgres" + port = 5432 + user = "n8n" + password = "secret" + db_name = "n8n" + schema_name = "memos" + user_name = "default" + """ + + host: str = Field(..., description="Database host") + port: int = Field(default=5432, description="Database port") + user: str = Field(..., description="Database user") + password: str = Field(..., description="Database password") + db_name: str = Field(..., description="Database name") + schema_name: str = Field(default="memos", description="Schema name for MemOS tables") + user_name: str | None = Field( + default=None, + description="Logical user/tenant ID for data isolation", + ) + use_multi_db: bool = Field( + default=False, + description="If False: use single database with logical isolation by user_name", + ) + embedding_dimension: int = Field(default=768, description="Dimension of vector embedding (768 for all-mpnet-base-v2)") + maxconn: int = Field( + default=20, + description="Maximum number of connections in the connection pool", + ) + + @model_validator(mode="after") + def validate_config(self): + """Validate config.""" + if not self.db_name: + raise ValueError("`db_name` must be provided") + if not self.use_multi_db and not self.user_name: + raise ValueError("In single-database mode, `user_name` must be provided") + return self + + class GraphDBConfigFactory(BaseModel): backend: str = Field(..., description="Backend for graph database") config: dict[str, Any] = Field(..., description="Configuration for the graph database backend") @@ -220,6 +272,7 @@ class GraphDBConfigFactory(BaseModel): "neo4j-community": Neo4jCommunityGraphDBConfig, "nebular": NebulaGraphDBConfig, "polardb": PolarDBGraphDBConfig, + "postgres": PostgresGraphDBConfig, } @field_validator("backend") diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index a28f3bdce..f5f6dec33 100644 --- a/src/memos/configs/mem_scheduler.py +++ 
b/src/memos/configs/mem_scheduler.py @@ -250,8 +250,9 @@ def validate_partial_initialization(self) -> "AuthConfig": "All configuration components are None. This may indicate missing environment variables or configuration files." ) elif failed_components: - logger.warning( - f"Failed to initialize components: {', '.join(failed_components)}. Successfully initialized: {', '.join(initialized_components)}" + logger.info( + f"Components not configured: {', '.join(failed_components)}. " + f"Successfully initialized: {', '.join(initialized_components)}" ) return self diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py index 538d913ea..0d5a7df87 100644 --- a/src/memos/embedders/universal_api.py +++ b/src/memos/embedders/universal_api.py @@ -14,8 +14,26 @@ logger = get_logger(__name__) +def _sanitize_unicode(text: str) -> str: + """ + Remove Unicode surrogates and other problematic characters. + Surrogates (U+D800-U+DFFF) cause UnicodeEncodeError with some APIs. + """ + try: + # Encode with 'surrogatepass' then decode, replacing invalid chars + cleaned = text.encode("utf-8", errors="surrogatepass").decode("utf-8", errors="replace") + # Replace replacement char with empty string for cleaner output + return cleaned.replace("\ufffd", "") + except Exception: + # Fallback: remove all non-BMP characters + return "".join(c for c in text if ord(c) < 0x10000) + + class UniversalAPIEmbedder(BaseEmbedder): def __init__(self, config: UniversalAPIEmbedderConfig): + print( + f"DEBUG: UniversalAPIEmbedder init. Config provider={config.provider}, base_url={config.base_url}" + ) self.provider = config.provider self.config = config @@ -54,6 +72,8 @@ def __init__(self, config: UniversalAPIEmbedderConfig): def embed(self, texts: list[str]) -> list[list[float]]: if isinstance(texts, str): texts = [texts] + # Sanitize Unicode to prevent encoding errors with emoji/surrogates + texts = [_sanitize_unicode(t) for t in texts] # Truncate texts if max_tokens is configured texts = self._truncate_texts(texts) logger.info(f"Embeddings request with input: {texts}") diff --git a/src/memos/graph_dbs/base.py b/src/memos/graph_dbs/base.py index 130b66a3d..bda0fbadd 100644 --- a/src/memos/graph_dbs/base.py +++ b/src/memos/graph_dbs/base.py @@ -272,3 +272,93 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N - metadata: dict[str, Any] - Node metadata user_name: Optional user name (will use config default if not provided) """ + + @abstractmethod + def get_edges( + self, id: str, type: str = "ANY", direction: str = "ANY" + ) -> list[dict[str, str]]: + """ + Get edges connected to a node, with optional type and direction filter. + Args: + id: Node ID to retrieve edges for. + type: Relationship type to match, or 'ANY' to match all. + direction: 'OUTGOING', 'INCOMING', or 'ANY'. + Returns: + List of edge dicts with 'from', 'to', and 'type' keys. + """ + + @abstractmethod + def search_by_fulltext( + self, query_words: list[str], top_k: int = 10, **kwargs + ) -> list[dict]: + """ + Full-text search for memory nodes. + Args: + query_words: List of words to search for. + top_k: Maximum number of results. + Returns: + List of dicts with 'id' and 'score'. + """ + + @abstractmethod + def get_neighbors_by_tag( + self, + tags: list[str], + exclude_ids: list[str], + top_k: int = 5, + min_overlap: int = 1, + **kwargs, + ) -> list[dict[str, Any]]: + """ + Find top-K neighbor nodes with maximum tag overlap. + Args: + tags: Tags to match. + exclude_ids: Node IDs to exclude. 
+ top_k: Max neighbors to return. + min_overlap: Minimum overlapping tags required. + Returns: + List of node dicts. + """ + + @abstractmethod + def delete_node_by_prams( + self, + memory_ids: list[str] | None = None, + writable_cube_ids: list[str] | None = None, + file_ids: list[str] | None = None, + filter: dict | None = None, + **kwargs, + ) -> int: + """ + Delete nodes matching given parameters. + Returns: + Number of deleted nodes. + """ + + @abstractmethod + def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> list[str]: + """ + Get distinct user names that own the given memory IDs. + """ + + @abstractmethod + def exist_user_name(self, user_name: str) -> bool: + """ + Check if a user_name exists in the graph. + """ + + @abstractmethod + def search_by_keywords_like( + self, query_word: str, **kwargs + ) -> list[dict]: + """ + Search memories using SQL LIKE pattern matching. + """ + + @abstractmethod + def search_by_keywords_tfidf( + self, query_words: list[str], **kwargs + ) -> list[dict]: + """ + Search memories using TF-IDF fulltext scoring. + """ diff --git a/src/memos/graph_dbs/factory.py b/src/memos/graph_dbs/factory.py index ec9cbcda0..c207e3190 100644 --- a/src/memos/graph_dbs/factory.py +++ b/src/memos/graph_dbs/factory.py @@ -6,6 +6,7 @@ from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB from memos.graph_dbs.polardb import PolarDBGraphDB +from memos.graph_dbs.postgres import PostgresGraphDB class GraphStoreFactory(BaseGraphDB): @@ -16,6 +17,7 @@ class GraphStoreFactory(BaseGraphDB): "neo4j-community": Neo4jCommunityGraphDB, "nebular": NebulaGraphDB, "polardb": PolarDBGraphDB, + "postgres": PostgresGraphDB, } @classmethod diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 428d6d09e..289d3cab3 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -11,6 +11,7 @@ from memos.configs.graph_db import NebulaGraphDBConfig from memos.dependency import require_python_package from memos.graph_dbs.base import BaseGraphDB +from memos.graph_dbs.utils import compose_node as _compose_node from memos.log import get_logger from memos.utils import timed @@ -44,14 +45,6 @@ def _normalize(vec: list[float]) -> list[float]: return (v / (norm if norm else 1.0)).tolist() -@timed -def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: - node_id = item["id"] - memory = item["memory"] - metadata = item.get("metadata", {}) - return node_id, memory, metadata - - @timed def _escape_str(value: str) -> str: out = [] diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 70d40f13c..d716a9cce 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -7,19 +7,13 @@ from memos.configs.graph_db import Neo4jGraphDBConfig from memos.dependency import require_python_package from memos.graph_dbs.base import BaseGraphDB +from memos.graph_dbs.utils import compose_node as _compose_node from memos.log import get_logger logger = get_logger(__name__) -def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: - node_id = item["id"] - memory = item["memory"] - metadata = item.get("metadata", {}) - return node_id, memory, metadata - - def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: """ Ensure metadata has proper datetime fields and normalized types. 
@@ -502,7 +496,7 @@ def edge_exists( return result.single() is not None # Graph Query & Reasoning - def get_node(self, id: str, **kwargs) -> dict[str, Any] | None: + def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None: """ Retrieve the metadata and memory of a node. Args: @@ -510,18 +504,28 @@ def get_node(self, id: str, **kwargs) -> dict[str, Any] | None: Returns: Dictionary of node fields, or None if not found. """ - user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name + logger.info(f"[get_node] id: {id}") + user_name = kwargs.get("user_name") where_user = "" params = {"id": id} - if not self.config.use_multi_db and (self.config.user_name or user_name): + if user_name is not None: where_user = " AND n.user_name = $user_name" params["user_name"] = user_name query = f"MATCH (n:Memory) WHERE n.id = $id {where_user} RETURN n" + logger.info(f"[get_node] query: {query}") with self.driver.session(database=self.db_name) as session: record = session.run(query, params).single() - return self._parse_node(dict(record["n"])) if record else None + if not record: + return None + + node_dict = dict(record["n"]) + if include_embedding is False: + for key in ("embedding", "embedding_1024", "embedding_3072", "embedding_768"): + node_dict.pop(key, None) + + return self._parse_node(node_dict) def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]: """ @@ -1940,3 +1944,174 @@ def exist_user_name(self, user_name: str) -> dict[str, bool]: f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True ) raise + + def delete_node_by_mem_cube_id( + self, + mem_kube_id: dict | None = None, + delete_record_id: dict | None = None, + deleted_type: bool = False, + ) -> int: + """ + Delete nodes by mem_kube_id (user_name) and delete_record_id. + + Args: + mem_kube_id: The mem_kube_id which corresponds to user_name in the table. + Can be dict or str. If dict, will extract the value. + delete_record_id: The delete_record_id to match. + Can be dict or str. If dict, will extract the value. + deleted_type: If True, performs hard delete (directly deletes records). + If False, performs soft delete (updates status to 'deleted' and sets delete_record_id and delete_time). + + Returns: + int: Number of nodes deleted or updated. 
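+
+        Example (illustrative call; the identifiers are made up):
+            # Soft-delete everything owned by "cube_42" and stamp it with
+            # record "del_001" so it can be recovered later:
+            db.delete_node_by_mem_cube_id("cube_42", "del_001", deleted_type=False)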
+ """ + # Handle dict type parameters (extract value if dict) + if isinstance(mem_kube_id, dict): + # Try to get a value from dict, use first value if multiple + mem_kube_id = next(iter(mem_kube_id.values())) if mem_kube_id else None + + if isinstance(delete_record_id, dict): + delete_record_id = next(iter(delete_record_id.values())) if delete_record_id else None + + # Validate required parameters + if not mem_kube_id: + logger.warning("[delete_node_by_mem_cube_id] mem_kube_id is required but not provided") + return 0 + + if not delete_record_id: + logger.warning( + "[delete_node_by_mem_cube_id] delete_record_id is required but not provided" + ) + return 0 + + # Convert to string if needed + mem_kube_id = str(mem_kube_id) if mem_kube_id else None + delete_record_id = str(delete_record_id) if delete_record_id else None + + logger.info( + f"[delete_node_by_mem_cube_id] mem_kube_id={mem_kube_id}, " + f"delete_record_id={delete_record_id}, deleted_type={deleted_type}" + ) + + try: + with self.driver.session(database=self.db_name) as session: + if deleted_type: + # Hard delete: WHERE user_name = mem_kube_id AND delete_record_id = $delete_record_id + query = """ + MATCH (n:Memory) + WHERE n.user_name = $mem_kube_id AND n.delete_record_id = $delete_record_id + DETACH DELETE n + """ + logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {query}") + + result = session.run( + query, mem_kube_id=mem_kube_id, delete_record_id=delete_record_id + ) + summary = result.consume() + deleted_count = summary.counters.nodes_deleted if summary.counters else 0 + + logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes") + return deleted_count + else: + # Soft delete: WHERE user_name = mem_kube_id (only user_name condition) + current_time = datetime.utcnow().isoformat() + + query = """ + MATCH (n:Memory) + WHERE n.user_name = $mem_kube_id + SET n.status = $status, + n.delete_record_id = $delete_record_id, + n.delete_time = $delete_time + RETURN count(n) AS updated_count + """ + logger.info(f"[delete_node_by_mem_cube_id] Soft delete query: {query}") + + result = session.run( + query, + mem_kube_id=mem_kube_id, + status="deleted", + delete_record_id=delete_record_id, + delete_time=current_time, + ) + record = result.single() + updated_count = record["updated_count"] if record else 0 + + logger.info( + f"[delete_node_by_mem_cube_id] Soft deleted (updated) {updated_count} nodes" + ) + return updated_count + + except Exception as e: + logger.error( + f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True + ) + raise + + def recover_memory_by_mem_kube_id( + self, + mem_kube_id: str | None = None, + delete_record_id: str | None = None, + ) -> int: + """ + Recover memory nodes by mem_kube_id (user_name) and delete_record_id. + + This function updates the status to 'activated', and clears delete_record_id and delete_time. + + Args: + mem_kube_id: The mem_kube_id which corresponds to user_name in the table. + delete_record_id: The delete_record_id to match. + + Returns: + int: Number of nodes recovered (updated). 
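+
+        Example (illustrative; pairs with delete_node_by_mem_cube_id above):
+            # Restore the nodes soft-deleted under record "del_001":
+            db.recover_memory_by_mem_kube_id("cube_42", "del_001")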
+ """ + # Validate required parameters + if not mem_kube_id: + logger.warning( + "[recover_memory_by_mem_kube_id] mem_kube_id is required but not provided" + ) + return 0 + + if not delete_record_id: + logger.warning( + "[recover_memory_by_mem_kube_id] delete_record_id is required but not provided" + ) + return 0 + + logger.info( + f"[recover_memory_by_mem_kube_id] mem_kube_id={mem_kube_id}, " + f"delete_record_id={delete_record_id}" + ) + + try: + with self.driver.session(database=self.db_name) as session: + query = """ + MATCH (n:Memory) + WHERE n.user_name = $mem_kube_id AND n.delete_record_id = $delete_record_id + SET n.status = $status, + n.delete_record_id = $delete_record_id_empty, + n.delete_time = $delete_time_empty + RETURN count(n) AS updated_count + """ + logger.info(f"[recover_memory_by_mem_kube_id] Update query: {query}") + + result = session.run( + query, + mem_kube_id=mem_kube_id, + delete_record_id=delete_record_id, + status="activated", + delete_record_id_empty="", + delete_time_empty="", + ) + record = result.single() + updated_count = record["updated_count"] if record else 0 + + logger.info( + f"[recover_memory_by_mem_kube_id] Recovered (updated) {updated_count} nodes" + ) + return updated_count + + except Exception as e: + logger.error( + f"[recover_memory_by_mem_kube_id] Failed to recover nodes: {e}", exc_info=True + ) + raise diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index f2182f6cd..411dbffe5 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -1056,3 +1056,58 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: logger.warning(f"Failed to fetch vector for node {new_node['id']}: {e}") new_node["metadata"]["embedding"] = None return new_node + + def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]: + """Get user names by memory ids. + + Args: + memory_ids: List of memory node IDs to query. + + Returns: + dict[str, str | None]: Dictionary mapping memory_id to user_name. 
+ - Key: memory_id + - Value: user_name if exists, None if memory_id does not exist + Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None} + """ + if not memory_ids: + return {} + + logger.info( + f"[ neo4j_community get_user_names_by_memory_ids] Querying memory_ids {memory_ids}" + ) + + try: + with self.driver.session(database=self.db_name) as session: + # Query to get memory_id and user_name pairs + query = """ + MATCH (n:Memory) + WHERE n.id IN $memory_ids + RETURN n.id AS memory_id, n.user_name AS user_name + """ + logger.info(f"[get_user_names_by_memory_ids] query: {query}") + + result = session.run(query, memory_ids=memory_ids) + result_dict = {} + + # Build result dictionary from query results + for record in result: + memory_id = record["memory_id"] + user_name = record["user_name"] + result_dict[memory_id] = user_name if user_name else None + + # Set None for memory_ids that were not found + for mid in memory_ids: + if mid not in result_dict: + result_dict[mid] = None + + logger.info( + f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, " + f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names" + ) + + return result_dict + except Exception as e: + logger.error( + f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True + ) + raise diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py deleted file mode 100644 index b9c8ca84b..000000000 --- a/src/memos/graph_dbs/polardb.py +++ /dev/null @@ -1,5443 +0,0 @@ -import json -import random -import textwrap -import time - -from contextlib import suppress -from datetime import datetime -from typing import Any, Literal - -import numpy as np - -from memos.configs.graph_db import PolarDBGraphDBConfig -from memos.dependency import require_python_package -from memos.graph_dbs.base import BaseGraphDB -from memos.log import get_logger -from memos.utils import timed - - -logger = get_logger(__name__) - - -def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: - node_id = item["id"] - memory = item["memory"] - metadata = item.get("metadata", {}) - return node_id, memory, metadata - - -def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: - """ - Ensure metadata has proper datetime fields and normalized types. - - - Fill `created_at` and `updated_at` if missing (in ISO 8601 format). - - Convert embedding to list of float if present. 
- """ - now = datetime.utcnow().isoformat() - - # Fill timestamps if missing - metadata.setdefault("created_at", now) - metadata.setdefault("updated_at", now) - - # Normalize embedding type - embedding = metadata.get("embedding") - if embedding and isinstance(embedding, list): - metadata["embedding"] = [float(x) for x in embedding] - - return metadata - - -def generate_vector(dim=1024, low=-0.2, high=0.2): - """Generate a random vector for testing purposes.""" - return [round(random.uniform(low, high), 6) for _ in range(dim)] - - -def find_embedding(metadata): - def find_embedding(item): - """Find an embedding vector within nested structures""" - for key in ["embedding", "embedding_1024", "embedding_3072", "embedding_768"]: - if key in item and isinstance(item[key], list): - return item[key] - if "metadata" in item and key in item["metadata"]: - return item["metadata"][key] - if "properties" in item and key in item["properties"]: - return item["properties"][key] - return None - - -def detect_embedding_field(embedding_list): - if not embedding_list: - return None - dim = len(embedding_list) - if dim == 1024: - return "embedding" - else: - logger.warning(f"Unknown embedding dimension {dim}, skipping this vector") - return None - - -def convert_to_vector(embedding_list): - if not embedding_list: - return None - if isinstance(embedding_list, np.ndarray): - embedding_list = embedding_list.tolist() - return "[" + ",".join(str(float(x)) for x in embedding_list) + "]" - - -def clean_properties(props): - """Remove vector fields""" - vector_keys = {"embedding", "embedding_1024", "embedding_3072", "embedding_768"} - if not isinstance(props, dict): - return {} - return {k: v for k, v in props.items() if k not in vector_keys} - - -def escape_sql_string(value: str) -> str: - """Escape single quotes in SQL string.""" - return value.replace("'", "''") - - -class PolarDBGraphDB(BaseGraphDB): - """PolarDB-based implementation using Apache AGE graph database extension.""" - - @require_python_package( - import_name="psycopg2", - install_command="pip install psycopg2-binary", - install_link="https://pypi.org/project/psycopg2-binary/", - ) - def __init__(self, config: PolarDBGraphDBConfig): - """PolarDB-based implementation using Apache AGE. - - Tenant Modes: - - use_multi_db = True: - Dedicated Database Mode (Multi-Database Multi-Tenant). - Each tenant or logical scope uses a separate PolarDB database. - `db_name` is the specific tenant database. - `user_name` can be None (optional). - - - use_multi_db = False: - Shared Database Multi-Tenant Mode. - All tenants share a single PolarDB database. - `db_name` is the shared database. - `user_name` is required to isolate each tenant's data at the node level. - All node queries will enforce `user_name` in WHERE conditions and store it in metadata, - but it will be removed automatically before returning to external consumers. 
- """ - import psycopg2 - import psycopg2.pool - - self.config = config - - # Handle both dict and object config - if isinstance(config, dict): - self.db_name = config.get("db_name") - self.user_name = config.get("user_name") - host = config.get("host") - port = config.get("port") - user = config.get("user") - password = config.get("password") - maxconn = config.get("maxconn", 100) # De - else: - self.db_name = config.db_name - self.user_name = config.user_name - host = config.host - port = config.port - user = config.user - password = config.password - maxconn = config.maxconn if hasattr(config, "maxconn") else 100 - """ - # Create connection - self.connection = psycopg2.connect( - host=host, port=port, user=user, password=password, dbname=self.db_name,minconn=10, maxconn=2000 - ) - """ - logger.info(f" db_name: {self.db_name} current maxconn is:'{maxconn}'") - - # Create connection pool - self.connection_pool = psycopg2.pool.ThreadedConnectionPool( - minconn=5, - maxconn=maxconn, - host=host, - port=port, - user=user, - password=password, - dbname=self.db_name, - connect_timeout=60, # Connection timeout in seconds - keepalives_idle=40, # Seconds of inactivity before sending keepalive (should be < server idle timeout) - keepalives_interval=15, # Seconds between keepalive retries - keepalives_count=5, # Number of keepalive retries before considering connection dead - ) - - # Keep a reference to the pool for cleanup - self._pool_closed = False - - """ - # Handle auto_create - # auto_create = config.get("auto_create", False) if isinstance(config, dict) else config.auto_create - # if auto_create: - # self._ensure_database_exists() - - # Create graph and tables - # self.create_graph() - # self.create_edge() - # self._create_graph() - - # Handle embedding_dimension - # embedding_dim = config.get("embedding_dimension", 1024) if isinstance(config,dict) else config.embedding_dimension - # self.create_index(dimensions=embedding_dim) - """ - - def _get_config_value(self, key: str, default=None): - """Safely get config value from either dict or object.""" - if isinstance(self.config, dict): - return self.config.get(key, default) - else: - return getattr(self.config, key, default) - - def _get_connection_old(self): - """Get a connection from the pool.""" - if self._pool_closed: - raise RuntimeError("Connection pool has been closed") - conn = self.connection_pool.getconn() - # Set autocommit for PolarDB compatibility - conn.autocommit = True - return conn - - def _get_connection(self): - """ - Get a connection from the pool. - - This function: - 1. Gets a connection from ThreadedConnectionPool - 2. Checks if connection is closed or unhealthy - 3. Returns healthy connection or retries (max 3 times) - 4. 
Handles connection pool exhaustion gracefully - - Returns: - psycopg2 connection object - - Raises: - RuntimeError: If connection pool is closed or exhausted after retries - """ - logger.info(f" db_name: {self.db_name} pool maxconn is:'{self.connection_pool.maxconn}'") - if self._pool_closed: - raise RuntimeError("Connection pool has been closed") - - max_retries = 500 - import psycopg2.pool - - for attempt in range(max_retries): - conn = None - try: - # Try to get connection from pool - # This may raise PoolError if pool is exhausted - conn = self.connection_pool.getconn() - - # Check if connection is closed - if conn.closed != 0: - # Connection is closed, return it to pool with close flag and try again - logger.warning( - f"[_get_connection] Got closed connection, attempt {attempt + 1}/{max_retries}" - ) - try: - self.connection_pool.putconn(conn, close=True) - except Exception as e: - logger.warning( - f"[_get_connection] Failed to return closed connection to pool: {e}" - ) - with suppress(Exception): - conn.close() - - conn = None - if attempt < max_retries - 1: - # Exponential backoff: 0.1s, 0.2s, 0.4s - """time.sleep(0.1 * (2**attempt))""" - time.sleep(0.003) - continue - else: - raise RuntimeError("Pool returned a closed connection after all retries") - - # Set autocommit for PolarDB compatibility - conn.autocommit = True - - # Test connection health with SELECT 1 - try: - cursor = conn.cursor() - cursor.execute("SELECT 1") - cursor.fetchone() - cursor.close() - except Exception as health_check_error: - # Connection is not usable, return it to pool with close flag and try again - logger.warning( - f"[_get_connection] Connection health check failed (attempt {attempt + 1}/{max_retries}): {health_check_error}" - ) - try: - self.connection_pool.putconn(conn, close=True) - except Exception as putconn_error: - logger.warning( - f"[_get_connection] Failed to return unhealthy connection to pool: {putconn_error}" - ) - with suppress(Exception): - conn.close() - - conn = None - if attempt < max_retries - 1: - # Exponential backoff: 0.1s, 0.2s, 0.4s - """time.sleep(0.1 * (2**attempt))""" - time.sleep(0.003) - continue - else: - raise RuntimeError( - f"Failed to get a healthy connection from pool after {max_retries} attempts: {health_check_error}" - ) from health_check_error - - # Connection is healthy, return it - return conn - - except psycopg2.pool.PoolError as pool_error: - # Pool exhausted or other pool-related error - # Don't retry immediately for pool exhaustion - it's unlikely to resolve quickly - error_msg = str(pool_error).lower() - if "exhausted" in error_msg or "pool" in error_msg: - # Log pool status for debugging - try: - # Try to get pool stats if available - pool_info = f"Pool config: minconn={self.connection_pool.minconn}, maxconn={self.connection_pool.maxconn}" - logger.error( - f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries}). {pool_info}" - ) - except Exception: - logger.error( - f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries})" - ) - - # For pool exhaustion, wait longer before retry (connections may be returned) - if attempt < max_retries - 1: - # Longer backoff for pool exhaustion: 0.5s, 1.0s, 2.0s - wait_time = 0.5 * (2**attempt) - logger.info(f"[_get_connection] Waiting {wait_time}s before retry...") - """time.sleep(wait_time)""" - time.sleep(0.003) - continue - else: - raise RuntimeError( - f"Connection pool exhausted after {max_retries} attempts. 
" - f"This usually means connections are not being returned to the pool. " - f"Check for connection leaks in your code." - ) from pool_error - else: - # Other pool errors - retry with normal backoff - if attempt < max_retries - 1: - """time.sleep(0.1 * (2**attempt))""" - time.sleep(0.003) - continue - else: - raise RuntimeError( - f"Failed to get connection from pool: {pool_error}" - ) from pool_error - - except Exception as e: - # Other exceptions (not pool-related) - # Only try to return connection if we actually got one - # If getconn() failed (e.g., pool exhausted), conn will be None - if conn is not None: - try: - # Return connection to pool if it's valid - self.connection_pool.putconn(conn, close=True) - except Exception as putconn_error: - logger.warning( - f"[_get_connection] Failed to return connection after error: {putconn_error}" - ) - with suppress(Exception): - conn.close() - - if attempt >= max_retries - 1: - raise RuntimeError(f"Failed to get a valid connection from pool: {e}") from e - else: - # Exponential backoff: 0.1s, 0.2s, 0.4s - """time.sleep(0.1 * (2**attempt))""" - time.sleep(0.003) - continue - - # Should never reach here, but just in case - raise RuntimeError("Failed to get connection after all retries") - - def _return_connection(self, connection): - """ - Return a connection to the pool. - - This function safely returns a connection to the pool, handling: - - Closed connections (close them instead of returning) - - Pool closed state (close connection directly) - - None connections (no-op) - - putconn() failures (close connection as fallback) - - Args: - connection: psycopg2 connection object or None - """ - if self._pool_closed: - # Pool is closed, just close the connection if it exists - if connection: - try: - connection.close() - logger.debug("[_return_connection] Closed connection (pool is closed)") - except Exception as e: - logger.warning( - f"[_return_connection] Failed to close connection after pool closed: {e}" - ) - return - - if not connection: - # No connection to return - this is normal if _get_connection() failed - return - - try: - # Check if connection is closed - if hasattr(connection, "closed") and connection.closed != 0: - # Connection is closed, just close it explicitly and don't return to pool - logger.debug( - "[_return_connection] Connection is closed, closing it instead of returning to pool" - ) - try: - connection.close() - except Exception as e: - logger.warning(f"[_return_connection] Failed to close closed connection: {e}") - return - - # Connection is valid, return to pool - self.connection_pool.putconn(connection) - logger.debug("[_return_connection] Successfully returned connection to pool") - except Exception as e: - # If putconn fails, try to close the connection - # This prevents connection leaks if putconn() fails - logger.error( - f"[_return_connection] Failed to return connection to pool: {e}", exc_info=True - ) - try: - connection.close() - logger.debug( - "[_return_connection] Closed connection as fallback after putconn failure" - ) - except Exception as close_error: - logger.warning( - f"[_return_connection] Failed to close connection after putconn error: {close_error}" - ) - - def _return_connection_old(self, connection): - """Return a connection to the pool.""" - if not self._pool_closed and connection: - self.connection_pool.putconn(connection) - - def _ensure_database_exists(self): - """Create database if it doesn't exist.""" - try: - # For PostgreSQL/PolarDB, we need to connect to a default database first - # This is a 
simplified implementation - in production you might want to handle this differently - logger.info(f"Using database '{self.db_name}'") - except Exception as e: - logger.error(f"Failed to access database '{self.db_name}': {e}") - raise - - @timed - def _create_graph(self): - """Create PostgreSQL schema and table for graph storage.""" - # Get a connection from the pool - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Create schema if it doesn't exist - cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.db_name}_graph";') - logger.info(f"Schema '{self.db_name}_graph' ensured.") - - # Create Memory table if it doesn't exist - cursor.execute(f""" - CREATE TABLE IF NOT EXISTS "{self.db_name}_graph"."Memory" ( - id TEXT PRIMARY KEY, - properties JSONB NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ); - """) - logger.info(f"Memory table created in schema '{self.db_name}_graph'.") - - # Add embedding column if it doesn't exist (using JSONB for compatibility) - try: - cursor.execute(f""" - ALTER TABLE "{self.db_name}_graph"."Memory" - ADD COLUMN IF NOT EXISTS embedding JSONB; - """) - logger.info("Embedding column added to Memory table.") - except Exception as e: - logger.warning(f"Failed to add embedding column: {e}") - - # Create indexes - cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_properties - ON "{self.db_name}_graph"."Memory" USING GIN (properties); - """) - - # Create vector index for embedding field - try: - cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_embedding - ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops) - WITH (lists = 100); - """) - logger.info("Vector index created for Memory table.") - except Exception as e: - logger.warning(f"Vector index creation failed (might not be supported): {e}") - - logger.info("Indexes created for Memory table.") - - except Exception as e: - logger.error(f"Failed to create graph schema: {e}") - raise e - finally: - self._return_connection(conn) - - def create_index( - self, - label: str = "Memory", - vector_property: str = "embedding", - dimensions: int = 1024, - index_name: str = "memory_vector_index", - ) -> None: - """ - Create indexes for embedding and other fields. - Note: This creates PostgreSQL indexes on the underlying tables. 
- """ - # Get a connection from the pool - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Create indexes on the underlying PostgreSQL tables - # Apache AGE stores data in regular PostgreSQL tables - cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_properties - ON "{self.db_name}_graph"."Memory" USING GIN (properties); - """) - - # Try to create vector index, but don't fail if it doesn't work - try: - cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_embedding - ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops); - """) - except Exception as ve: - logger.warning(f"Vector index creation failed (might not be supported): {ve}") - - logger.debug("Indexes created successfully.") - except Exception as e: - logger.warning(f"Failed to create indexes: {e}") - finally: - self._return_connection(conn) - - def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: - """Get count of memory nodes by type.""" - user_name = user_name if user_name else self._get_config_value("user_name") - query = f""" - SELECT COUNT(*) - FROM "{self.db_name}_graph"."Memory" - WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype - """ - query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - params = [self.format_param_value(memory_type), self.format_param_value(user_name)] - - # Get a connection from the pool - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - return result[0] if result else 0 - except Exception as e: - logger.error(f"[get_memory_count] Failed: {e}") - return -1 - finally: - self._return_connection(conn) - - @timed - def node_not_exist(self, scope: str, user_name: str | None = None) -> int: - """Check if a node with given scope exists.""" - user_name = user_name if user_name else self._get_config_value("user_name") - query = f""" - SELECT id - FROM "{self.db_name}_graph"."Memory" - WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype - """ - query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - query += "\nLIMIT 1" - params = [self.format_param_value(scope), self.format_param_value(user_name)] - - # Get a connection from the pool - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - return 1 if result else 0 - except Exception as e: - logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) - raise - finally: - self._return_connection(conn) - - @timed - def remove_oldest_memory( - self, memory_type: str, keep_latest: int, user_name: str | None = None - ) -> None: - """ - Remove all WorkingMemory nodes except the latest `keep_latest` entries. - - Args: - memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). - keep_latest (int): Number of latest WorkingMemory entries to keep. 
- user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self._get_config_value("user_name") - - # Use actual OFFSET logic, consistent with nebular.py - # First find IDs to delete, then delete them - select_query = f""" - SELECT id FROM "{self.db_name}_graph"."Memory" - WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype - AND ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = %s::agtype - ORDER BY ag_catalog.agtype_access_operator(properties, '"updated_at"'::agtype) DESC - OFFSET %s - """ - select_params = [ - self.format_param_value(memory_type), - self.format_param_value(user_name), - keep_latest, - ] - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Execute query to get IDs to delete - cursor.execute(select_query, select_params) - ids_to_delete = [row[0] for row in cursor.fetchall()] - - if not ids_to_delete: - logger.info(f"No {memory_type} memories to remove for user {user_name}") - return - - # Build delete query - placeholders = ",".join(["%s"] * len(ids_to_delete)) - delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE id IN ({placeholders}) - """ - delete_params = ids_to_delete - - # Execute deletion - cursor.execute(delete_query, delete_params) - deleted_count = cursor.rowcount - logger.info( - f"Removed {deleted_count} oldest {memory_type} memories, " - f"keeping {keep_latest} latest for user {user_name}, " - f"removed ids: {ids_to_delete}" - ) - except Exception as e: - logger.error(f"[remove_oldest_memory] Failed: {e}", exc_info=True) - raise - finally: - self._return_connection(conn) - - @timed - def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: - """ - Update node fields in PolarDB, auto-converting `created_at` and `updated_at` to datetime type if present. 
- """ - if not fields: - return - - user_name = user_name if user_name else self.config.user_name - - # Get the current node - current_node = self.get_node(id, user_name=user_name) - if not current_node: - return - - # Update properties but keep original id and memory fields - properties = current_node["metadata"].copy() - original_id = properties.get("id", id) # Preserve original ID - original_memory = current_node.get("memory", "") # Preserve original memory - - # If fields include memory, use it; otherwise keep original memory - if "memory" in fields: - original_memory = fields.pop("memory") - - properties.update(fields) - properties["id"] = original_id # Ensure ID is not overwritten - properties["memory"] = original_memory # Ensure memory is not overwritten - - # Handle embedding field - embedding_vector = None - if "embedding" in fields: - embedding_vector = fields.pop("embedding") - if not isinstance(embedding_vector, list): - embedding_vector = None - - # Build update query - if embedding_vector is not None: - query = f""" - UPDATE "{self.db_name}_graph"."Memory" - SET properties = %s, embedding = %s - WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype - """ - params = [ - json.dumps(properties), - json.dumps(embedding_vector), - self.format_param_value(id), - ] - else: - query = f""" - UPDATE "{self.db_name}_graph"."Memory" - SET properties = %s - WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype - """ - params = [json.dumps(properties), self.format_param_value(id)] - - # Only add user filter when user_name is provided - if user_name is not None: - query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - params.append(self.format_param_value(user_name)) - - # Get a connection from the pool - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - except Exception as e: - logger.error(f"[update_node] Failed to update node '{id}': {e}", exc_info=True) - raise - finally: - self._return_connection(conn) - - @timed - def delete_node(self, id: str, user_name: str | None = None) -> None: - """ - Delete a node from the graph. - Args: - id: Node identifier to delete. 
- user_name (str, optional): User name for filtering in non-multi-db mode - """ - query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype - """ - params = [self.format_param_value(id)] - - # Only add user filter when user_name is provided - if user_name is not None: - query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - params.append(self.format_param_value(user_name)) - - # Get a connection from the pool - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - except Exception as e: - logger.error(f"[delete_node] Failed to delete node '{id}': {e}", exc_info=True) - raise - finally: - self._return_connection(conn) - - @timed - def create_extension(self): - extensions = [("polar_age", "Graph engine"), ("vector", "Vector engine")] - # Get a connection from the pool - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Ensure in the correct database context - cursor.execute("SELECT current_database();") - current_db = cursor.fetchone()[0] - logger.info(f"Current database context: {current_db}") - - for ext_name, ext_desc in extensions: - try: - cursor.execute(f"create extension if not exists {ext_name};") - logger.info(f"Extension '{ext_name}' ({ext_desc}) ensured.") - except Exception as e: - if "already exists" in str(e): - logger.info(f"Extension '{ext_name}' ({ext_desc}) already exists.") - else: - logger.warning( - f"Failed to create extension '{ext_name}' ({ext_desc}): {e}" - ) - logger.error( - f"Failed to create extension '{ext_name}': {e}", exc_info=True - ) - except Exception as e: - logger.warning(f"Failed to access database context: {e}") - logger.error(f"Failed to access database context: {e}", exc_info=True) - finally: - self._return_connection(conn) - - @timed - def create_graph(self): - # Get a connection from the pool - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(f""" - SELECT COUNT(*) FROM ag_catalog.ag_graph - WHERE name = '{self.db_name}_graph'; - """) - graph_exists = cursor.fetchone()[0] > 0 - - if graph_exists: - logger.info(f"Graph '{self.db_name}_graph' already exists.") - else: - cursor.execute(f"select create_graph('{self.db_name}_graph');") - logger.info(f"Graph database '{self.db_name}_graph' created.") - except Exception as e: - logger.warning(f"Failed to create graph '{self.db_name}_graph': {e}") - logger.error(f"Failed to create graph '{self.db_name}_graph': {e}", exc_info=True) - finally: - self._return_connection(conn) - - @timed - def create_edge(self): - """Create all valid edge types if they do not exist""" - - valid_rel_types = {"AGGREGATE_TO", "FOLLOWS", "INFERS", "MERGED_TO", "RELATE_TO", "PARENT"} - - for label_name in valid_rel_types: - conn = None - logger.info(f"Creating elabel: {label_name}") - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") - logger.info(f"Successfully created elabel: {label_name}") - except Exception as e: - if "already exists" in str(e): - logger.info(f"Label '{label_name}' already exists, skipping.") - else: - logger.warning(f"Failed to create label {label_name}: {e}") - logger.error(f"Failed to create elabel '{label_name}': {e}", exc_info=True) - finally: - self._return_connection(conn) - - @timed - def add_edge( - self, source_id: str, 
target_id: str, type: str, user_name: str | None = None - ) -> None: - logger.info( - f"polardb [add_edge] source_id: {source_id}, target_id: {target_id}, type: {type},user_name:{user_name}" - ) - - start_time = time.time() - if not source_id or not target_id: - logger.warning(f"Edge '{source_id}' and '{target_id}' are both None") - raise ValueError("[add_edge] source_id and target_id must be provided") - - source_exists = self.get_node(source_id) is not None - target_exists = self.get_node(target_id) is not None - - if not source_exists or not target_exists: - logger.warning( - "[add_edge] Source %s or target %s does not exist.", source_exists, target_exists - ) - raise ValueError("[add_edge] source_id and target_id must be provided") - - properties = {} - if user_name is not None: - properties["user_name"] = user_name - query = f""" - INSERT INTO {self.db_name}_graph."{type}"(id, start_id, end_id, properties) - SELECT - ag_catalog._next_graph_id('{self.db_name}_graph'::name, '{type}'), - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{source_id}'::text::cstring), - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{target_id}'::text::cstring), - jsonb_build_object('user_name', '{user_name}')::text::agtype - WHERE NOT EXISTS ( - SELECT 1 FROM {self.db_name}_graph."{type}" - WHERE start_id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{source_id}'::text::cstring) - AND end_id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{target_id}'::text::cstring) - ); - """ - logger.info(f"polardb [add_edge] query: {query}, properties: {json.dumps(properties)}") - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, (source_id, target_id, type, json.dumps(properties))) - logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}") - - elapsed_time = time.time() - start_time - logger.info(f" polardb [add_edge] insert completed time in {elapsed_time:.2f}s") - except Exception as e: - logger.error(f"Failed to insert edge: {e}", exc_info=True) - raise - finally: - self._return_connection(conn) - - @timed - def delete_edge(self, source_id: str, target_id: str, type: str) -> None: - """ - Delete a specific edge between two nodes. - Args: - source_id: ID of the source node. - target_id: ID of the target node. - type: Relationship type to remove. - """ - query = f""" - DELETE FROM "{self.db_name}_graph"."Edges" - WHERE source_id = %s AND target_id = %s AND edge_type = %s - """ - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, (source_id, target_id, type)) - logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") - finally: - self._return_connection(conn) - - @timed - def edge_exists_old( - self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING" - ) -> bool: - """ - Check if an edge exists between two nodes. - Args: - source_id: ID of the source node. - target_id: ID of the target node. - type: Relationship type. Use "ANY" to match any relationship type. - direction: Direction of the edge. - Use "OUTGOING" (default), "INCOMING", or "ANY". - Returns: - True if the edge exists, otherwise False. 
- """ - where_clauses = [] - params = [] - # SELECT * FROM - # cypher('memtensor_memos_graph', $$ - # MATCH(a: Memory - # {id: "13bb9df6-0609-4442-8bed-bba77dadac92"})-[r] - (b:Memory {id: "2dd03a5b-5d5f-49c9-9e0a-9a2a2899b98d"}) - # RETURN - # r - # $$) AS(r - # agtype); - - if direction == "OUTGOING": - where_clauses.append("source_id = %s AND target_id = %s") - params.extend([source_id, target_id]) - elif direction == "INCOMING": - where_clauses.append("source_id = %s AND target_id = %s") - params.extend([target_id, source_id]) - elif direction == "ANY": - where_clauses.append( - "((source_id = %s AND target_id = %s) OR (source_id = %s AND target_id = %s))" - ) - params.extend([source_id, target_id, target_id, source_id]) - else: - raise ValueError( - f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." - ) - - if type != "ANY": - where_clauses.append("edge_type = %s") - params.append(type) - - where_clause = " AND ".join(where_clauses) - - query = f""" - SELECT 1 FROM "{self.db_name}_graph"."Edges" - WHERE {where_clause} - LIMIT 1 - """ - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - return result is not None - finally: - self._return_connection(conn) - - @timed - def edge_exists( - self, - source_id: str, - target_id: str, - type: str = "ANY", - direction: str = "OUTGOING", - user_name: str | None = None, - ) -> bool: - """ - Check if an edge exists between two nodes. - Args: - source_id: ID of the source node. - target_id: ID of the target node. - type: Relationship type. Use "ANY" to match any relationship type. - direction: Direction of the edge. - Use "OUTGOING" (default), "INCOMING", or "ANY". - user_name (str, optional): User name for filtering in non-multi-db mode - Returns: - True if the edge exists, otherwise False. - """ - - # Prepare the relationship pattern - user_name = user_name if user_name else self.config.user_name - - # Prepare the match pattern with direction - if direction == "OUTGOING": - pattern = "(a:Memory)-[r]->(b:Memory)" - elif direction == "INCOMING": - pattern = "(a:Memory)<-[r]-(b:Memory)" - elif direction == "ANY": - pattern = "(a:Memory)-[r]-(b:Memory)" - else: - raise ValueError( - f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." - ) - query = f"SELECT * FROM cypher('{self.db_name}_graph', $$" - query += f"\nMATCH {pattern}" - query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" - query += f"\nAND a.id = '{source_id}' AND b.id = '{target_id}'" - if type != "ANY": - query += f"\n AND type(r) = '{type}'" - - query += "\nRETURN r" - query += "\n$$) AS (r agtype)" - - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - result = cursor.fetchone() - return result is not None and result[0] is not None - finally: - self._return_connection(conn) - - @timed - def get_node( - self, id: str, include_embedding: bool = False, user_name: str | None = None - ) -> dict[str, Any] | None: - """ - Retrieve a Memory node by its unique ID. - - Args: - id (str): Node ID (Memory.id) - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - dict: Node properties as key-value pairs, or None if not found. 
- """ - logger.info( - f"polardb [get_node] id: {id}, include_embedding: {include_embedding}, user_name: {user_name}" - ) - start_time = time.time() - select_fields = "id, properties, embedding" if include_embedding else "id, properties" - - query = f""" - SELECT {select_fields} - FROM "{self.db_name}_graph"."Memory" - WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype - """ - params = [self.format_param_value(id)] - - # Only add user filter when user_name is provided - if user_name is not None: - query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - params.append(self.format_param_value(user_name)) - - logger.info(f"polardb [get_node] query: {query},params: {params}") - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - - if result: - if include_embedding: - _, properties_json, embedding_json = result - else: - _, properties_json = result - embedding_json = None - - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse properties for node {id}") - properties = {} - else: - properties = properties_json if properties_json else {} - - # Parse embedding from JSONB if it exists and include_embedding is True - if include_embedding and embedding_json is not None: - try: - embedding = ( - json.loads(embedding_json) - if isinstance(embedding_json, str) - else embedding_json - ) - properties["embedding"] = embedding - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse embedding for node {id}") - - elapsed_time = time.time() - start_time - logger.info( - f" polardb [get_node] get_node completed time in {elapsed_time:.2f}s" - ) - return self._parse_node( - { - "id": id, - "memory": properties.get("memory", ""), - **properties, - } - ) - return None - - except Exception as e: - logger.error(f"[get_node] Failed to retrieve node '{id}': {e}", exc_info=True) - return None - finally: - self._return_connection(conn) - - @timed - def get_nodes( - self, ids: list[str], user_name: str | None = None, **kwargs - ) -> list[dict[str, Any]]: - """ - Retrieve the metadata and memory of a list of nodes. - Args: - ids: List of Node identifier. - Returns: - list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'. - - Notes: - - Assumes all provided IDs are valid and exist. - - Returns empty list if input is empty. 
- """ - logger.info(f"get_nodes ids:{ids},user_name:{user_name}") - if not ids: - return [] - - # Build WHERE clause using IN operator with agtype array - # Use ANY operator with array for better performance - placeholders = ",".join(["%s"] * len(ids)) - params = [self.format_param_value(id_val) for id_val in ids] - - query = f""" - SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" - WHERE ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = ANY(ARRAY[{placeholders}]::agtype[]) - """ - - # Only add user_name filter if provided - if user_name is not None: - query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - params.append(self.format_param_value(user_name)) - - logger.info(f"get_nodes query:{query},params:{params}") - - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - - nodes = [] - for row in results: - node_id, properties_json, embedding_json = row - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse properties for node {node_id}") - properties = {} - else: - properties = properties_json if properties_json else {} - - # Parse embedding from JSONB if it exists - if embedding_json is not None and kwargs.get("include_embedding"): - try: - # remove embedding - embedding = ( - json.loads(embedding_json) - if isinstance(embedding_json, str) - else embedding_json - ) - properties["embedding"] = embedding - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse embedding for node {node_id}") - nodes.append( - self._parse_node( - { - "id": properties.get("id", node_id), - "memory": properties.get("memory", ""), - "metadata": properties, - } - ) - ) - return nodes - finally: - self._return_connection(conn) - - @timed - def get_edges_old( - self, id: str, type: str = "ANY", direction: str = "ANY" - ) -> list[dict[str, str]]: - """ - Get edges connected to a node, with optional type and direction filter. - - Args: - id: Node ID to retrieve edges for. - type: Relationship type to match, or 'ANY' to match all. - direction: 'OUTGOING', 'INCOMING', or 'ANY'. - - Returns: - List of edges: - [ - {"from": "source_id", "to": "target_id", "type": "RELATE"}, - ... 
- ] - """ - - # Create a simple edge table to store relationships (if not exists) - try: - with self.connection.cursor() as cursor: - # Create edge table - cursor.execute(f""" - CREATE TABLE IF NOT EXISTS "{self.db_name}_graph"."Edges" ( - id SERIAL PRIMARY KEY, - source_id TEXT NOT NULL, - target_id TEXT NOT NULL, - edge_type TEXT NOT NULL, - properties JSONB, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (source_id) REFERENCES "{self.db_name}_graph"."Memory"(id), - FOREIGN KEY (target_id) REFERENCES "{self.db_name}_graph"."Memory"(id) - ); - """) - - # Create indexes - cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_edges_source - ON "{self.db_name}_graph"."Edges" (source_id); - """) - cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_edges_target - ON "{self.db_name}_graph"."Edges" (target_id); - """) - cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_edges_type - ON "{self.db_name}_graph"."Edges" (edge_type); - """) - except Exception as e: - logger.warning(f"Failed to create edges table: {e}") - - # Query edges - where_clauses = [] - params = [id] - - if type != "ANY": - where_clauses.append("edge_type = %s") - params.append(type) - - if direction == "OUTGOING": - where_clauses.append("source_id = %s") - elif direction == "INCOMING": - where_clauses.append("target_id = %s") - else: # ANY - where_clauses.append("(source_id = %s OR target_id = %s)") - params.append(id) # Add second parameter for ANY direction - - where_clause = " AND ".join(where_clauses) - - query = f""" - SELECT source_id, target_id, edge_type - FROM "{self.db_name}_graph"."Edges" - WHERE {where_clause} - """ - - with self.connection.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - - edges = [] - for row in results: - source_id, target_id, edge_type = row - edges.append({"from": source_id, "to": target_id, "type": edge_type}) - return edges - - def get_neighbors( - self, id: str, type: str, direction: Literal["in", "out", "both"] = "out" - ) -> list[str]: - """Get connected node IDs in a specific direction and relationship type.""" - raise NotImplementedError - - @timed - def get_neighbors_by_tag_old( - self, - tags: list[str], - exclude_ids: list[str], - top_k: int = 5, - min_overlap: int = 1, - ) -> list[dict[str, Any]]: - """ - Find top-K neighbor nodes with maximum tag overlap. - - Args: - tags: The list of tags to match. - exclude_ids: Node IDs to exclude (e.g., local cluster). - top_k: Max number of neighbors to return. - min_overlap: Minimum number of overlapping tags required. - - Returns: - List of dicts with node details and overlap count. 
- """ - # Build query conditions - where_clauses = [] - params = [] - - # Exclude specified IDs - if exclude_ids: - placeholders = ",".join(["%s"] * len(exclude_ids)) - where_clauses.append(f"id NOT IN ({placeholders})") - params.extend(exclude_ids) - - # Status filter - where_clauses.append("properties->>'status' = %s") - params.append("activated") - - # Type filter - where_clauses.append("properties->>'type' != %s") - params.append("reasoning") - - where_clauses.append("properties->>'memory_type' != %s") - params.append("WorkingMemory") - - # User filter - if not self._get_config_value("use_multi_db", True) and self._get_config_value("user_name"): - where_clauses.append("properties->>'user_name' = %s") - params.append(self._get_config_value("user_name")) - - where_clause = " AND ".join(where_clauses) - - # Get all candidate nodes - query = f""" - SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - - with self.connection.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - - nodes_with_overlap = [] - for row in results: - node_id, properties_json, embedding_json = row - properties = properties_json if properties_json else {} - - # Parse embedding - if embedding_json is not None: - try: - embedding = ( - json.loads(embedding_json) - if isinstance(embedding_json, str) - else embedding_json - ) - properties["embedding"] = embedding - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse embedding for node {node_id}") - - # Compute tag overlap - node_tags = properties.get("tags", []) - if isinstance(node_tags, str): - try: - node_tags = json.loads(node_tags) - except (json.JSONDecodeError, TypeError): - node_tags = [] - - overlap_tags = [tag for tag in tags if tag in node_tags] - overlap_count = len(overlap_tags) - - if overlap_count >= min_overlap: - node_data = self._parse_node( - { - "id": properties.get("id", node_id), - "memory": properties.get("memory", ""), - "metadata": properties, - } - ) - nodes_with_overlap.append((node_data, overlap_count)) - - # Sort by overlap count and return top_k - nodes_with_overlap.sort(key=lambda x: x[1], reverse=True) - return [node for node, _ in nodes_with_overlap[:top_k]] - - @timed - def get_children_with_embeddings( - self, id: str, user_name: str | None = None - ) -> list[dict[str, Any]]: - """Get children nodes with their embeddings.""" - user_name = user_name if user_name else self._get_config_value("user_name") - where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" - - query = f""" - WITH t as ( - SELECT * - FROM cypher('{self.db_name}_graph', $$ - MATCH (p:Memory)-[r:PARENT]->(c:Memory) - WHERE p.id = '{id}' {where_user} - RETURN id(c) as cid, c.id AS id, c.memory AS memory - $$) as (cid agtype, id agtype, memory agtype) - ) - SELECT t.id, m.embedding, t.memory FROM t, - "{self.db_name}_graph"."Memory" m - WHERE t.cid::graphid = m.id; - """ - - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() - - children = [] - for row in results: - # Handle child_id - remove possible quotes - child_id_raw = row[0].value if hasattr(row[0], "value") else str(row[0]) - if isinstance(child_id_raw, str): - # If string starts and ends with quotes, remove quotes - if child_id_raw.startswith('"') and child_id_raw.endswith('"'): - child_id = child_id_raw[1:-1] - else: - child_id = child_id_raw - else: - child_id = str(child_id_raw) - - # Handle 
embedding - get from database embedding column - embedding_raw = row[1] - embedding = [] - if embedding_raw is not None: - try: - if isinstance(embedding_raw, str): - # If it is a JSON string, parse it - embedding = json.loads(embedding_raw) - elif isinstance(embedding_raw, list): - # If already a list, use directly - embedding = embedding_raw - else: - # Try converting to list - embedding = list(embedding_raw) - except (json.JSONDecodeError, TypeError, ValueError) as e: - logger.warning( - f"Failed to parse embedding for child node {child_id}: {e}" - ) - embedding = [] - - # Handle memory - remove possible quotes - memory_raw = row[2].value if hasattr(row[2], "value") else str(row[2]) - if isinstance(memory_raw, str): - # If string starts and ends with quotes, remove quotes - if memory_raw.startswith('"') and memory_raw.endswith('"'): - memory = memory_raw[1:-1] - else: - memory = memory_raw - else: - memory = str(memory_raw) - - children.append({"id": child_id, "embedding": embedding, "memory": memory}) - - return children - - except Exception as e: - logger.error(f"[get_children_with_embeddings] Failed: {e}", exc_info=True) - return [] - finally: - self._return_connection(conn) - - def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: - """Get the path of nodes from source to target within a limited depth.""" - raise NotImplementedError - - @timed - def get_subgraph( - self, - center_id: str, - depth: int = 2, - center_status: str = "activated", - user_name: str | None = None, - ) -> dict[str, Any]: - """ - Retrieve a local subgraph centered at a given node. - Args: - center_id: The ID of the center node. - depth: The hop distance for neighbors. - center_status: Required status for center node. - user_name (str, optional): User name for filtering in non-multi-db mode - Returns: - { - "core_node": {...}, - "neighbors": [...], - "edges": [...] 
- } - """ - logger.info(f"[get_subgraph] center_id: {center_id}") - if not 1 <= depth <= 5: - raise ValueError("depth must be 1-5") - - user_name = user_name if user_name else self._get_config_value("user_name") - - if center_id.startswith('"') and center_id.endswith('"'): - center_id = center_id[1:-1] - # Use a simplified query to get the subgraph (temporarily only direct neighbors) - """ - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH(center: Memory)-[r * 1..{depth}]->(neighbor:Memory) - WHERE - center.id = '{center_id}' - AND center.status = '{center_status}' - AND center.user_name = '{user_name}' - RETURN - collect(DISTINCT - center), collect(DISTINCT - neighbor), collect(DISTINCT - r) - $$ ) as (centers agtype, neighbors agtype, rels agtype); - """ - # Use UNION ALL for better performance: separate queries for depth 1 and depth 2 - if depth == 1: - query = f""" - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH(center: Memory)-[r]->(neighbor:Memory) - WHERE - center.id = '{center_id}' - AND center.status = '{center_status}' - AND center.user_name = '{user_name}' - RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r) - $$ ) as (centers agtype, neighbors agtype, rels agtype); - """ - else: - # For depth >= 2, use UNION ALL to combine depth 1 and depth 2 queries - query = f""" - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH(center: Memory)-[r]->(neighbor:Memory) - WHERE - center.id = '{center_id}' - AND center.status = '{center_status}' - AND center.user_name = '{user_name}' - RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r) - UNION ALL - MATCH(center: Memory)-[r]->(n:Memory)-[r1]->(neighbor:Memory) - WHERE - center.id = '{center_id}' - AND center.status = '{center_status}' - AND center.user_name = '{user_name}' - RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r1) - $$ ) as (centers agtype, neighbors agtype, rels agtype); - """ - conn = None - logger.info(f"[get_subgraph] Query: {query}") - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() - - if not results: - return {"core_node": None, "neighbors": [], "edges": []} - - # Merge results from all UNION ALL rows - all_centers_list = [] - all_neighbors_list = [] - all_edges_list = [] - - for result in results: - if not result or not result[0]: - continue - - centers_data = result[0] if result[0] else "[]" - neighbors_data = result[1] if result[1] else "[]" - edges_data = result[2] if result[2] else "[]" - - # Parse JSON data - try: - # Clean ::vertex and ::edge suffixes in data - if isinstance(centers_data, str): - centers_data = centers_data.replace("::vertex", "") - if isinstance(neighbors_data, str): - neighbors_data = neighbors_data.replace("::vertex", "") - if isinstance(edges_data, str): - edges_data = edges_data.replace("::edge", "") - - centers_list = ( - json.loads(centers_data) - if isinstance(centers_data, str) - else centers_data - ) - neighbors_list = ( - json.loads(neighbors_data) - if isinstance(neighbors_data, str) - else neighbors_data - ) - edges_list = ( - json.loads(edges_data) if isinstance(edges_data, str) else edges_data - ) - - # Collect data from this row - if isinstance(centers_list, list): - all_centers_list.extend(centers_list) - if isinstance(neighbors_list, list): - all_neighbors_list.extend(neighbors_list) - if isinstance(edges_list, list): - all_edges_list.extend(edges_list) - except json.JSONDecodeError as e: - 
logger.error(f"Failed to parse JSON data: {e}") - continue - - # Deduplicate centers by ID - centers_dict = {} - for center_data in all_centers_list: - if isinstance(center_data, dict) and "properties" in center_data: - center_id_key = center_data["properties"].get("id") - if center_id_key and center_id_key not in centers_dict: - centers_dict[center_id_key] = center_data - - # Parse center node (use first center) - core_node = None - if centers_dict: - center_data = next(iter(centers_dict.values())) - if isinstance(center_data, dict) and "properties" in center_data: - core_node = self._parse_node(center_data["properties"]) - - # Deduplicate neighbors by ID - neighbors_dict = {} - for neighbor_data in all_neighbors_list: - if isinstance(neighbor_data, dict) and "properties" in neighbor_data: - neighbor_id = neighbor_data["properties"].get("id") - if neighbor_id and neighbor_id not in neighbors_dict: - neighbors_dict[neighbor_id] = neighbor_data - - # Parse neighbor nodes - neighbors = [] - for neighbor_data in neighbors_dict.values(): - if isinstance(neighbor_data, dict) and "properties" in neighbor_data: - neighbor_parsed = self._parse_node(neighbor_data["properties"]) - neighbors.append(neighbor_parsed) - - # Deduplicate edges by (source, target, type) - edges_dict = {} - for edge_group in all_edges_list: - if isinstance(edge_group, list): - for edge_data in edge_group: - if isinstance(edge_data, dict): - edge_key = ( - edge_data.get("start_id", ""), - edge_data.get("end_id", ""), - edge_data.get("label", ""), - ) - if edge_key not in edges_dict: - edges_dict[edge_key] = { - "type": edge_data.get("label", ""), - "source": edge_data.get("start_id", ""), - "target": edge_data.get("end_id", ""), - } - elif isinstance(edge_group, dict): - # Handle single edge (not in a list) - edge_key = ( - edge_group.get("start_id", ""), - edge_group.get("end_id", ""), - edge_group.get("label", ""), - ) - if edge_key not in edges_dict: - edges_dict[edge_key] = { - "type": edge_group.get("label", ""), - "source": edge_group.get("start_id", ""), - "target": edge_group.get("end_id", ""), - } - - edges = list(edges_dict.values()) - - return self._convert_graph_edges( - {"core_node": core_node, "neighbors": neighbors, "edges": edges} - ) - - except Exception as e: - logger.error(f"Failed to get subgraph: {e}", exc_info=True) - return {"core_node": None, "neighbors": [], "edges": []} - finally: - self._return_connection(conn) - - def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: - """Get the ordered context chain starting from a node.""" - raise NotImplementedError - - @timed - def seach_by_keywords_like( - self, - query_word: str, - scope: str | None = None, - status: str | None = None, - search_filter: dict | None = None, - user_name: str | None = None, - filter: dict | None = None, - knowledgebase_ids: list[str] | None = None, - **kwargs, - ) -> list[dict]: - where_clauses = [] - - if scope: - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype" - ) - if status: - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype" - ) - else: - where_clauses.append( - "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" - ) - - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( - user_name=user_name, - 
knowledgebase_ids=knowledgebase_ids, - default_user_name=self.config.user_name, - ) - - # Add OR condition if we have any user_name conditions - if user_name_conditions: - if len(user_name_conditions) == 1: - where_clauses.append(user_name_conditions[0]) - else: - where_clauses.append(f"({' OR '.join(user_name_conditions)})") - - # Add search_filter conditions - if search_filter: - for key, value in search_filter.items(): - if isinstance(value, str): - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype" - ) - else: - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" - ) - - # Build filter conditions using common method - filter_conditions = self._build_filter_conditions_sql(filter) - where_clauses.extend(filter_conditions) - - # Build key - where_clauses.append("""(properties -> '"memory"')::text LIKE %s""") - where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - - query = f""" - SELECT - ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, - agtype_object_field_text(properties, 'memory') as memory_text - FROM "{self.db_name}_graph"."Memory" - {where_clause} - """ - - params = (query_word,) - logger.info( - f"[seach_by_keywords_LIKE start:] user_name: {user_name}, query: {query}, params: {params}" - ) - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - oldid = row[0] - id_val = str(oldid) - output.append({"id": id_val}) - logger.info( - f"[seach_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" - ) - return output - finally: - self._return_connection(conn) - - @timed - def seach_by_keywords_tfidf( - self, - query_words: list[str], - scope: str | None = None, - status: str | None = None, - search_filter: dict | None = None, - user_name: str | None = None, - filter: dict | None = None, - knowledgebase_ids: list[str] | None = None, - tsvector_field: str = "properties_tsvector_zh", - tsquery_config: str = "jiebaqry", - **kwargs, - ) -> list[dict]: - where_clauses = [] - - if scope: - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype" - ) - if status: - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype" - ) - else: - where_clauses.append( - "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" - ) - - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( - user_name=user_name, - knowledgebase_ids=knowledgebase_ids, - default_user_name=self.config.user_name, - ) - - # Add OR condition if we have any user_name conditions - if user_name_conditions: - if len(user_name_conditions) == 1: - where_clauses.append(user_name_conditions[0]) - else: - where_clauses.append(f"({' OR '.join(user_name_conditions)})") - - # Add search_filter conditions - if search_filter: - for key, value in search_filter.items(): - if isinstance(value, str): - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype" - ) - else: - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = 
{value}::agtype" - ) - - # Build filter conditions using common method - filter_conditions = self._build_filter_conditions_sql(filter) - where_clauses.extend(filter_conditions) - # Add fulltext search condition - # Convert query_text to OR query format: "word1 | word2 | word3" - tsquery_string = " | ".join(query_words) - - where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") - - where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - - # Build fulltext search query - query = f""" - SELECT - ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, - agtype_object_field_text(properties, 'memory') as memory_text - FROM "{self.db_name}_graph"."Memory" - {where_clause} - """ - - params = (tsquery_string,) - logger.info( - f"[seach_by_keywords_TFIDF start:] user_name: {user_name}, query: {query}, params: {params}" - ) - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - oldid = row[0] - id_val = str(oldid) - output.append({"id": id_val}) - - logger.info( - f"[seach_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" - ) - return output - finally: - self._return_connection(conn) - - @timed - def search_by_fulltext( - self, - query_words: list[str], - top_k: int = 10, - scope: str | None = None, - status: str | None = None, - threshold: float | None = None, - search_filter: dict | None = None, - user_name: str | None = None, - filter: dict | None = None, - knowledgebase_ids: list[str] | None = None, - tsvector_field: str = "properties_tsvector_zh", - tsquery_config: str = "jiebacfg", - **kwargs, - ) -> list[dict]: - """ - Full-text search functionality using PostgreSQL's full-text search capabilities. - - Args: - query_text: query text - top_k: maximum number of results to return - scope: memory type filter (memory_type) - status: status filter, defaults to "activated" - threshold: similarity threshold filter - search_filter: additional property filter conditions - user_name: username filter - knowledgebase_ids: knowledgebase ids filter - filter: filter conditions with 'and' or 'or' logic for search results. - tsvector_field: full-text index field name, defaults to properties_tsvector_zh_1 - tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation) - **kwargs: other parameters (e.g. 
cube_name) - - Returns: - list[dict]: result list containing id and score - """ - logger.info( - f"[search_by_fulltext] query_words: {query_words},top_k:{top_k},scope:{scope},status:{status},threshold:{threshold},search_filter:{search_filter},user_name:{user_name},knowledgebase_ids:{knowledgebase_ids},filter:{filter}" - ) - # Build WHERE clause dynamically, same as search_by_embedding - start_time = time.time() - where_clauses = [] - - if scope: - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype" - ) - if status: - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype" - ) - else: - where_clauses.append( - "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" - ) - - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( - user_name=user_name, - knowledgebase_ids=knowledgebase_ids, - default_user_name=self.config.user_name, - ) - logger.info(f"[search_by_fulltext] user_name_conditions: {user_name_conditions}") - - # Add OR condition if we have any user_name conditions - if user_name_conditions: - if len(user_name_conditions) == 1: - where_clauses.append(user_name_conditions[0]) - else: - where_clauses.append(f"({' OR '.join(user_name_conditions)})") - - # Add search_filter conditions - if search_filter: - for key, value in search_filter.items(): - if isinstance(value, str): - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype" - ) - else: - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" - ) - - # Build filter conditions using common method - filter_conditions = self._build_filter_conditions_sql(filter) - logger.info(f"[search_by_fulltext] filter_conditions: {filter_conditions}") - - where_clauses.extend(filter_conditions) - # Add fulltext search condition - # Convert query_text to OR query format: "word1 | word2 | word3" - tsquery_string = " | ".join(query_words) - - where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") - - where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - - logger.info(f"[search_by_fulltext] where_clause: {where_clause}") - - # Build fulltext search query - query = f""" - SELECT - ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, - agtype_object_field_text(properties, 'memory') as memory_text, - ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank - FROM "{self.db_name}_graph"."Memory" - {where_clause} - ORDER BY rank DESC - LIMIT {top_k}; - """ - - params = [tsquery_string, tsquery_string] - logger.info(f"[search_by_fulltext] query: {query}, params: {params}") - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - oldid = row[0] # old_id - rank = row[2] # rank score - - id_val = str(oldid) - score_val = float(rank) - - # Apply threshold filter if specified - if threshold is None or score_val >= threshold: - output.append({"id": id_val, "score": score_val}) - elapsed_time = time.time() - start_time - logger.info( - f" polardb [search_by_fulltext] query completed time in {elapsed_time:.2f}s" - ) - return output[:top_k] - finally: - 
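# search_by_fulltext joins the query words with " | " and ranks matches via
# ts_rank/to_tsquery. A minimal standalone sketch of that pattern; the documents
# table, the doc_tsv column, and the 'simple' text-search configuration are
# placeholders, and psycopg2 is assumed as the driver.
import psycopg2


def fulltext_search(dsn, words, top_k=10):
    """Return ids of rows whose tsvector column matches any of the given words, ranked."""
    tsquery = " | ".join(words)  # OR-combine terms: "word1 | word2 | word3"
    sql = """
        SELECT id, ts_rank(doc_tsv, to_tsquery('simple', %s)) AS rank
        FROM documents
        WHERE doc_tsv @@ to_tsquery('simple', %s)
        ORDER BY rank DESC
        LIMIT %s;
    """
    with psycopg2.connect(dsn) as conn, conn.cursor() as cur:
        cur.execute(sql, (tsquery, tsquery, top_k))
        return [{"id": row[0], "score": float(row[1])} for row in cur.fetchall()]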
self._return_connection(conn) - - @timed - def search_by_embedding( - self, - vector: list[float], - top_k: int = 5, - scope: str | None = None, - status: str | None = None, - threshold: float | None = None, - search_filter: dict | None = None, - user_name: str | None = None, - filter: dict | None = None, - knowledgebase_ids: list[str] | None = None, - **kwargs, - ) -> list[dict]: - """ - Retrieve node IDs based on vector similarity using PostgreSQL vector operations. - """ - # Build WHERE clause dynamically like nebular.py - logger.info( - f"[search_by_embedding] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}" - ) - where_clauses = [] - if scope: - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype" - ) - if status: - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype" - ) - else: - where_clauses.append( - "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" - ) - where_clauses.append("embedding is not null") - # Add user_name filter like nebular.py - - """ - # user_name = self._get_config_value("user_name") - # if not self.config.use_multi_db and user_name: - # if kwargs.get("cube_name"): - # where_clauses.append(f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{kwargs['cube_name']}\"'::agtype") - # else: - # where_clauses.append(f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype") - """ - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( - user_name=user_name, - knowledgebase_ids=knowledgebase_ids, - default_user_name=self.config.user_name, - ) - - # Add OR condition if we have any user_name conditions - if user_name_conditions: - if len(user_name_conditions) == 1: - where_clauses.append(user_name_conditions[0]) - else: - where_clauses.append(f"({' OR '.join(user_name_conditions)})") - - # Add search_filter conditions like nebular.py - if search_filter: - for key, value in search_filter.items(): - if isinstance(value, str): - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype" - ) - else: - where_clauses.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" - ) - - # Build filter conditions using common method - filter_conditions = self._build_filter_conditions_sql(filter) - logger.info(f"[search_by_embedding] filter_conditions: {filter_conditions}") - where_clauses.extend(filter_conditions) - - where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - - # Keep original simple query structure but add dynamic WHERE clause - query = f""" - WITH t AS ( - SELECT id, - properties, - timeline, - ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, - (1 - (embedding <=> %s::vector(1024))) AS scope - FROM "{self.db_name}_graph"."Memory" - {where_clause} - ORDER BY scope DESC - LIMIT {top_k} - ) - SELECT * - FROM t - WHERE scope > 0.1; - """ - # Convert vector to string format for PostgreSQL vector type - # PostgreSQL vector type expects a string format like '[1,2,3]' - vector_str = convert_to_vector(vector) - # Use string format directly in query instead of parameterized query - # Replace %s with the vector string, but need to quote it properly - # PostgreSQL 
vector type needs the string to be quoted - query = query.replace("%s::vector(1024)", f"'{vector_str}'::vector(1024)") - params = [] - - # Split query by lines and wrap long lines to prevent terminal truncation - query_lines = query.strip().split("\n") - for line in query_lines: - # Wrap lines longer than 200 characters to prevent terminal truncation - if len(line) > 200: - wrapped_lines = textwrap.wrap( - line, width=200, break_long_words=False, break_on_hyphens=False - ) - for _wrapped_line in wrapped_lines: - pass - else: - pass - - logger.info(f"[search_by_embedding] query: {query}, params: {params}") - - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - try: - # If params is empty, execute query directly without parameters - if params: - cursor.execute(query, params) - else: - cursor.execute(query) - except Exception as e: - logger.error(f"[search_by_embedding] Error executing query: {e}") - logger.error(f"[search_by_embedding] Query length: {len(query)}") - logger.error( - f"[search_by_embedding] Params type: {type(params)}, length: {len(params)}" - ) - logger.error(f"[search_by_embedding] Query contains %s: {'%s' in query}") - raise - results = cursor.fetchall() - output = [] - for row in results: - """ - polarId = row[0] # id - properties = row[1] # properties - # embedding = row[3] # embedding - """ - if len(row) < 5: - logger.warning(f"Row has {len(row)} columns, expected 5. Row: {row}") - continue - oldid = row[3] # old_id - score = row[4] # scope - id_val = str(oldid) - score_val = float(score) - score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score - if threshold is None or score_val >= threshold: - output.append({"id": id_val, "score": score_val}) - return output[:top_k] - finally: - self._return_connection(conn) - - @timed - def get_by_metadata( - self, - filters: list[dict[str, Any]], - user_name: str | None = None, - filter: dict | None = None, - knowledgebase_ids: list | None = None, - user_name_flag: bool = True, - ) -> list[str]: - """ - Retrieve node IDs that match given metadata filters. - Supports exact match. - - Args: - filters: List of filter dicts like: - [ - {"field": "key", "op": "in", "value": ["A", "B"]}, - {"field": "confidence", "op": ">=", "value": 80}, - {"field": "tags", "op": "contains", "value": "AI"}, - ... - ] - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[str]: Node IDs whose metadata match the filter conditions. (AND logic). 
- """ - logger.info(f"[get_by_metadata] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}") - - user_name = user_name if user_name else self._get_config_value("user_name") - - # Build WHERE conditions for cypher query - where_conditions = [] - - for f in filters: - field = f["field"] - op = f.get("op", "=") - value = f["value"] - - # Format value - if isinstance(value, str): - # Escape single quotes using backslash when inside $$ dollar-quoted strings - # In $$ delimiters, Cypher string literals can use \' to escape single quotes - escaped_str = value.replace("'", "\\'") - escaped_value = f"'{escaped_str}'" - elif isinstance(value, list): - # Handle list values - use double quotes for Cypher arrays - list_items = [] - for v in value: - if isinstance(v, str): - # Escape double quotes in string values for Cypher - escaped_str = v.replace('"', '\\"') - list_items.append(f'"{escaped_str}"') - else: - list_items.append(str(v)) - escaped_value = f"[{', '.join(list_items)}]" - else: - escaped_value = f"'{value}'" if isinstance(value, str) else str(value) - # Build WHERE conditions - if op == "=": - where_conditions.append(f"n.{field} = {escaped_value}") - elif op == "in": - where_conditions.append(f"n.{field} IN {escaped_value}") - """ - # where_conditions.append(f"{escaped_value} IN n.{field}") - """ - elif op == "contains": - where_conditions.append(f"{escaped_value} IN n.{field}") - """ - # where_conditions.append(f"size(filter(n.{field}, t -> t IN {escaped_value})) > 0") - """ - elif op == "starts_with": - where_conditions.append(f"n.{field} STARTS WITH {escaped_value}") - elif op == "ends_with": - where_conditions.append(f"n.{field} ENDS WITH {escaped_value}") - elif op == "like": - where_conditions.append(f"n.{field} CONTAINS {escaped_value}") - elif op in [">", ">=", "<", "<="]: - where_conditions.append(f"n.{field} {op} {escaped_value}") - else: - raise ValueError(f"Unsupported operator: {op}") - - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( - user_name=user_name, - knowledgebase_ids=knowledgebase_ids, - default_user_name=self._get_config_value("user_name"), - ) - logger.info(f"[get_by_metadata] user_name_conditions: {user_name_conditions}") - - # Add user_name WHERE clause - if user_name_conditions: - if len(user_name_conditions) == 1: - where_conditions.append(user_name_conditions[0]) - else: - where_conditions.append(f"({' OR '.join(user_name_conditions)})") - - # Build filter conditions using common method - filter_where_clause = self._build_filter_conditions_cypher(filter) - logger.info(f"[get_by_metadata] filter_where_clause: {filter_where_clause}") - - where_str = " AND ".join(where_conditions) + filter_where_clause - - # Use cypher query - cypher_query = f""" - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (n:Memory) - WHERE {where_str} - RETURN n.id AS id - $$) AS (id agtype) - """ - - ids = [] - conn = None - logger.info(f"[get_by_metadata] cypher_query: {cypher_query}") - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(cypher_query) - results = cursor.fetchall() - ids = [str(item[0]).strip('"') for item in results] - except Exception as e: - logger.error(f"Failed to get metadata: {e}, query is {cypher_query}") 
- finally: - self._return_connection(conn) - - return ids - - @timed - def get_grouped_counts1( - self, - group_fields: list[str], - where_clause: str = "", - params: dict[str, Any] | None = None, - user_name: str | None = None, - ) -> list[dict[str, Any]]: - """ - Count nodes grouped by any fields. - - Args: - group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"] - where_clause (str, optional): Extra WHERE condition. E.g., - "WHERE n.status = 'activated'" - params (dict, optional): Parameters for WHERE clause. - - Returns: - list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...] - """ - user_name = user_name if user_name else self.config.user_name - if not group_fields: - raise ValueError("group_fields cannot be empty") - - final_params = params.copy() if params else {} - if not self.config.use_multi_db and (self.config.user_name or user_name): - user_clause = "n.user_name = $user_name" - final_params["user_name"] = user_name - if where_clause: - where_clause = where_clause.strip() - if where_clause.upper().startswith("WHERE"): - where_clause += f" AND {user_clause}" - else: - where_clause = f"WHERE {where_clause} AND {user_clause}" - else: - where_clause = f"WHERE {user_clause}" - # Force RETURN field AS field to guarantee key match - group_fields_cypher = ", ".join([f"n.{field} AS {field}" for field in group_fields]) - """ - # group_fields_cypher_polardb = "agtype, ".join([f"{field}" for field in group_fields]) - """ - group_fields_cypher_polardb = ", ".join([f"{field} agtype" for field in group_fields]) - query = f""" - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (n:Memory) - {where_clause} - RETURN {group_fields_cypher}, COUNT(n) AS count1 - $$ ) as ({group_fields_cypher_polardb}, count1 agtype); - """ - try: - with self.connection.cursor() as cursor: - # Handle parameterized query - if params and isinstance(params, list): - cursor.execute(query, final_params) - else: - cursor.execute(query) - results = cursor.fetchall() - - output = [] - for row in results: - group_values = {} - for i, field in enumerate(group_fields): - value = row[i] - if hasattr(value, "value"): - group_values[field] = value.value - else: - group_values[field] = str(value) - count_value = row[-1] # Last column is count - output.append({**group_values, "count": count_value}) - - return output - - except Exception as e: - logger.error(f"Failed to get grouped counts: {e}", exc_info=True) - return [] - - @timed - def get_grouped_counts( - self, - group_fields: list[str], - where_clause: str = "", - params: dict[str, Any] | None = None, - user_name: str | None = None, - ) -> list[dict[str, Any]]: - """ - Count nodes grouped by any fields. - - Args: - group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"] - where_clause (str, optional): Extra WHERE condition. E.g., - "WHERE n.status = 'activated'" - params (dict, optional): Parameters for WHERE clause. - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...] 
- """ - if not group_fields: - raise ValueError("group_fields cannot be empty") - - user_name = user_name if user_name else self._get_config_value("user_name") - - # Build user clause - user_clause = f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" - if where_clause: - where_clause = where_clause.strip() - if where_clause.upper().startswith("WHERE"): - where_clause += f" AND {user_clause}" - else: - where_clause = f"WHERE {where_clause} AND {user_clause}" - else: - where_clause = f"WHERE {user_clause}" - - # Inline parameters if provided - if params and isinstance(params, dict): - for key, value in params.items(): - # Handle different value types appropriately - if isinstance(value, str): - value = f"'{value}'" - where_clause = where_clause.replace(f"${key}", str(value)) - - # Handle user_name parameter in where_clause - if "user_name = %s" in where_clause: - where_clause = where_clause.replace( - "user_name = %s", - f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype", - ) - - # Build return fields and group by fields - return_fields = [] - group_by_fields = [] - - for field in group_fields: - alias = field.replace(".", "_") - return_fields.append( - f"ag_catalog.agtype_access_operator(properties, '\"{field}\"'::agtype)::text AS {alias}" - ) - group_by_fields.append( - f"ag_catalog.agtype_access_operator(properties, '\"{field}\"'::agtype)::text" - ) - - # Full SQL query construction - query = f""" - SELECT {", ".join(return_fields)}, COUNT(*) AS count - FROM "{self.db_name}_graph"."Memory" - {where_clause} - GROUP BY {", ".join(group_by_fields)} - """ - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Handle parameterized query - if params and isinstance(params, list): - cursor.execute(query, params) - else: - cursor.execute(query) - results = cursor.fetchall() - - output = [] - for row in results: - group_values = {} - for i, field in enumerate(group_fields): - value = row[i] - if hasattr(value, "value"): - group_values[field] = value.value - else: - group_values[field] = str(value) - count_value = row[-1] # Last column is count - output.append({**group_values, "count": int(count_value)}) - - return output - - except Exception as e: - logger.error(f"Failed to get grouped counts: {e}", exc_info=True) - return [] - finally: - self._return_connection(conn) - - def deduplicate_nodes(self) -> None: - """Deduplicate redundant or semantically similar nodes.""" - raise NotImplementedError - - def detect_conflicts(self) -> list[tuple[str, str]]: - """Detect conflicting nodes based on logical or semantic inconsistency.""" - raise NotImplementedError - - def merge_nodes(self, id1: str, id2: str) -> str: - """Merge two similar or duplicate nodes into one.""" - raise NotImplementedError - - @timed - def clear(self, user_name: str | None = None) -> None: - """ - Clear the entire graph if the target database exists. 
- - Args: - user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self._get_config_value("user_name") - - try: - query = f""" - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (n:Memory) - WHERE n.user_name = '{user_name}' - DETACH DELETE n - $$) AS (result agtype) - """ - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - logger.info("Cleared all nodes from database.") - finally: - self._return_connection(conn) - - except Exception as e: - logger.error(f"[ERROR] Failed to clear database: {e}") - - @timed - def export_graph( - self, - include_embedding: bool = False, - user_name: str | None = None, - user_id: str | None = None, - page: int | None = None, - page_size: int | None = None, - filter: dict | None = None, - **kwargs, - ) -> dict[str, Any]: - """ - Export all graph nodes and edges in a structured form. - Args: - include_embedding (bool): Whether to include the large embedding field. - user_name (str, optional): User name for filtering in non-multi-db mode - user_id (str, optional): User ID for filtering - page (int, optional): Page number (starts from 1). If None, exports all data without pagination. - page_size (int, optional): Number of items per page. If None, exports all data without pagination. - filter (dict, optional): Filter dictionary for metadata filtering. Supports "and", "or" logic and operators: - - "=": equality - - "in": value in list - - "contains": array contains value - - "gt", "lt", "gte", "lte": comparison operators - - "like": fuzzy matching - Example: {"and": [{"created_at": {"gte": "2025-01-01"}}, {"tags": {"contains": "AI"}}]} - - Returns: - { - "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ], - "edges": [ { "source": ..., "target": ..., "type": ... }, ... 
], - "total_nodes": int, # Total number of nodes matching the filter criteria - "total_edges": int, # Total number of edges matching the filter criteria - } - """ - logger.info( - f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}" - ) - user_id = user_id if user_id else self._get_config_value("user_id") - - # Initialize total counts - total_nodes = 0 - total_edges = 0 - - # Determine if pagination is needed - use_pagination = page is not None and page_size is not None - - # Validate pagination parameters if pagination is enabled - if use_pagination: - if page < 1: - page = 1 - if page_size < 1: - page_size = 10 - offset = (page - 1) * page_size - else: - offset = None - - conn = None - try: - conn = self._get_connection() - # Build WHERE conditions - where_conditions = [] - if user_name: - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" - ) - if user_id: - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype" - ) - - # Build filter conditions using common method - filter_conditions = self._build_filter_conditions_sql(filter) - logger.info(f"[export_graph] filter_conditions: {filter_conditions}") - if filter_conditions: - where_conditions.extend(filter_conditions) - - where_clause = "" - if where_conditions: - where_clause = f"WHERE {' AND '.join(where_conditions)}" - - # Get total count of nodes before pagination - count_node_query = f""" - SELECT COUNT(*) - FROM "{self.db_name}_graph"."Memory" - {where_clause} - """ - logger.info(f"[export_graph nodes count] Query: {count_node_query}") - with conn.cursor() as cursor: - cursor.execute(count_node_query) - total_nodes = cursor.fetchone()[0] - - # Export nodes - # Build pagination clause if needed - pagination_clause = "" - if use_pagination: - pagination_clause = f"LIMIT {page_size} OFFSET {offset}" - - if include_embedding: - node_query = f""" - SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" - {where_clause} - ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, - id DESC - {pagination_clause} - """ - else: - node_query = f""" - SELECT id, properties - FROM "{self.db_name}_graph"."Memory" - {where_clause} - ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, - id DESC - {pagination_clause} - """ - logger.info(f"[export_graph nodes] Query: {node_query}") - with conn.cursor() as cursor: - cursor.execute(node_query) - node_results = cursor.fetchall() - nodes = [] - - for row in node_results: - if include_embedding: - """row is (id, properties, embedding)""" - _, properties_json, embedding_json = row - else: - """row is (id, properties)""" - _, properties_json = row - embedding_json = None - - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except json.JSONDecodeError: - properties = {} - else: - properties = properties_json if properties_json else {} - - # Remove embedding field if include_embedding is False - if not include_embedding: - properties.pop("embedding", None) - elif include_embedding and embedding_json is not None: - properties["embedding"] = embedding_json - - nodes.append(self._parse_node(properties)) - - except Exception as e: - logger.error(f"[EXPORT GRAPH - NODES] Exception: 
{e}", exc_info=True) - raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e - finally: - self._return_connection(conn) - - conn = None - try: - conn = self._get_connection() - # Build Cypher WHERE conditions for edges - cypher_where_conditions = [] - if user_name: - cypher_where_conditions.append(f"a.user_name = '{user_name}'") - cypher_where_conditions.append(f"b.user_name = '{user_name}'") - if user_id: - cypher_where_conditions.append(f"a.user_id = '{user_id}'") - cypher_where_conditions.append(f"b.user_id = '{user_id}'") - - # Build filter conditions for edges (apply to both source and target nodes) - filter_where_clause = self._build_filter_conditions_cypher(filter) - logger.info(f"[export_graph edges] filter_where_clause: {filter_where_clause}") - if filter_where_clause: - # _build_filter_conditions_cypher returns a string that starts with " AND " if filter exists - # Remove the leading " AND " and replace n. with a. for source node and b. for target node - filter_clause = filter_where_clause.strip() - if filter_clause.startswith("AND "): - filter_clause = filter_clause[4:].strip() - # Replace n. with a. for source node and create a copy for target node - source_filter = filter_clause.replace("n.", "a.") - target_filter = filter_clause.replace("n.", "b.") - # Combine source and target filters with AND - combined_filter = f"({source_filter}) AND ({target_filter})" - cypher_where_conditions.append(combined_filter) - - cypher_where_clause = "" - if cypher_where_conditions: - cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}" - - # Get total count of edges before pagination - count_edge_query = f""" - SELECT COUNT(*) - FROM ( - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (a:Memory)-[r]->(b:Memory) - {cypher_where_clause} - RETURN a.id AS source, b.id AS target, type(r) as edge - $$) AS (source agtype, target agtype, edge agtype) - ) AS edges - """ - logger.info(f"[export_graph edges count] Query: {count_edge_query}") - with conn.cursor() as cursor: - cursor.execute(count_edge_query) - total_edges = cursor.fetchone()[0] - - # Export edges using cypher query - # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery - # Build pagination clause if needed - edge_pagination_clause = "" - if use_pagination: - edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}" - - edge_query = f""" - SELECT source, target, edge FROM ( - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (a:Memory)-[r]->(b:Memory) - {cypher_where_clause} - RETURN a.id AS source, b.id AS target, type(r) as edge - ORDER BY COALESCE(a.created_at, '1970-01-01T00:00:00') DESC, - COALESCE(b.created_at, '1970-01-01T00:00:00') DESC, - a.id DESC, b.id DESC - $$) AS (source agtype, target agtype, edge agtype) - ) AS edges - {edge_pagination_clause} - """ - logger.info(f"[export_graph edges] Query: {edge_query}") - with conn.cursor() as cursor: - cursor.execute(edge_query) - edge_results = cursor.fetchall() - edges = [] - - for row in edge_results: - source_agtype, target_agtype, edge_agtype = row - - # Extract and clean source - source_raw = ( - source_agtype.value - if hasattr(source_agtype, "value") - else str(source_agtype) - ) - if ( - isinstance(source_raw, str) - and source_raw.startswith('"') - and source_raw.endswith('"') - ): - source = source_raw[1:-1] - else: - source = str(source_raw) - - # Extract and clean target - target_raw = ( - target_agtype.value - if hasattr(target_agtype, "value") - else str(target_agtype) - ) - if ( - 
isinstance(target_raw, str) - and target_raw.startswith('"') - and target_raw.endswith('"') - ): - target = target_raw[1:-1] - else: - target = str(target_raw) - - # Extract and clean edge type - type_raw = ( - edge_agtype.value if hasattr(edge_agtype, "value") else str(edge_agtype) - ) - if ( - isinstance(type_raw, str) - and type_raw.startswith('"') - and type_raw.endswith('"') - ): - edge_type = type_raw[1:-1] - else: - edge_type = str(type_raw) - - edges.append( - { - "source": source, - "target": target, - "type": edge_type, - } - ) - - except Exception as e: - logger.error(f"[EXPORT GRAPH - EDGES] Exception: {e}", exc_info=True) - raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e - finally: - self._return_connection(conn) - - return { - "nodes": nodes, - "edges": edges, - "total_nodes": total_nodes, - "total_edges": total_edges, - } - - @timed - def count_nodes(self, scope: str, user_name: str | None = None) -> int: - user_name = user_name if user_name else self.config.user_name - - query = f""" - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (n:Memory) - WHERE n.memory_type = '{scope}' - AND n.user_name = '{user_name}' - RETURN count(n) - $$) AS (count agtype) - """ - conn = None - try: - conn = self._get_connection() - result = self.execute_query(query, conn) - return int(result.one_or_none()["count"].value) - finally: - self._return_connection(conn) - - @timed - def get_all_memory_items( - self, - scope: str, - include_embedding: bool = False, - user_name: str | None = None, - filter: dict | None = None, - knowledgebase_ids: list | None = None, - status: str | None = None, - ) -> list[dict]: - """ - Retrieve all memory items of a specific memory_type. - - Args: - scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results. - knowledgebase_ids (list, optional): List of knowledgebase IDs to filter by. - status (str, optional): Filter by status (e.g., 'activated', 'archived'). - If None, no status filter is applied. - - Returns: - list[dict]: Full list of memory items under this scope. 
- """ - logger.info( - f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}, status: {status}" - ) - - user_name = user_name if user_name else self._get_config_value("user_name") - if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: - raise ValueError(f"Unsupported memory type scope: {scope}") - - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( - user_name=user_name, - knowledgebase_ids=knowledgebase_ids, - default_user_name=self._get_config_value("user_name"), - ) - - # Build user_name WHERE clause - if user_name_conditions: - if len(user_name_conditions) == 1: - user_name_where = user_name_conditions[0] - else: - user_name_where = f"({' OR '.join(user_name_conditions)})" - else: - user_name_where = "" - - # Build filter conditions using common method - filter_where_clause = self._build_filter_conditions_cypher(filter) - logger.info(f"[get_all_memory_items] filter_where_clause: {filter_where_clause}") - - # Use cypher query to retrieve memory items - if include_embedding: - # Build WHERE clause with user_name/knowledgebase_ids and filter - where_parts = [f"n.memory_type = '{scope}'"] - if status: - where_parts.append(f"n.status = '{status}'") - if user_name_where: - # user_name_where already contains parentheses if it's an OR condition - where_parts.append(user_name_where) - if filter_where_clause: - # filter_where_clause already contains " AND " prefix, so we just append it - where_clause = " AND ".join(where_parts) + filter_where_clause - else: - where_clause = " AND ".join(where_parts) - - cypher_query = f""" - WITH t as ( - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (n:Memory) - WHERE {where_clause} - RETURN id(n) as id1,n - LIMIT 100 - $$) AS (id1 agtype,n agtype) - ) - SELECT - m.embedding, - t.n - FROM t, - {self.db_name}_graph."Memory" m - WHERE t.id1 = m.id; - """ - nodes = [] - node_ids = set() - conn = None - logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}") - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(cypher_query) - results = cursor.fetchall() - - for row in results: - """ - if isinstance(row, (list, tuple)) and len(row) >= 2: - """ - if isinstance(row, list | tuple) and len(row) >= 2: - embedding_val, node_val = row[0], row[1] - else: - embedding_val, node_val = None, row[0] - - node = self._build_node_from_agtype(node_val, embedding_val) - if node: - node_id = node["id"] - if node_id not in node_ids: - nodes.append(node) - node_ids.add(node_id) - - except Exception as e: - logger.error(f"Failed to get memories: {e}", exc_info=True) - finally: - self._return_connection(conn) - - return nodes - else: - # Build WHERE clause with user_name/knowledgebase_ids and filter - where_parts = [f"n.memory_type = '{scope}'"] - if status: - where_parts.append(f"n.status = '{status}'") - if user_name_where: - # user_name_where already contains parentheses if it's an OR condition - where_parts.append(user_name_where) - if filter_where_clause: - # filter_where_clause already contains " AND " prefix, so we just append it - where_clause = " AND ".join(where_parts) + filter_where_clause - else: - where_clause = " AND ".join(where_parts) - - cypher_query = f""" - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (n:Memory) - WHERE {where_clause} - RETURN properties(n) as props - LIMIT 100 - $$) AS (nprops agtype) - """ - - nodes = [] - 
conn = None - logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}") - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(cypher_query) - results = cursor.fetchall() - - for row in results: - """ - if isinstance(row[0], str): - memory_data = json.loads(row[0]) - else: - memory_data = row[0] # 如果已经是字典,直接使用 - nodes.append(self._parse_node(memory_data)) - """ - memory_data = json.loads(row[0]) if isinstance(row[0], str) else row[0] - nodes.append(self._parse_node(memory_data)) - - except Exception as e: - logger.error(f"Failed to get memories: {e}", exc_info=True) - finally: - self._return_connection(conn) - - return nodes - - def get_all_memory_items_old( - self, scope: str, include_embedding: bool = False, user_name: str | None = None - ) -> list[dict]: - """ - Retrieve all memory items of a specific memory_type. - - Args: - scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[dict]: Full list of memory items under this scope. - """ - user_name = user_name if user_name else self._get_config_value("user_name") - if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: - raise ValueError(f"Unsupported memory type scope: {scope}") - - # Use cypher query to retrieve memory items - if include_embedding: - cypher_query = f""" - WITH t as ( - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (n:Memory) - WHERE n.memory_type = '{scope}' AND n.user_name = '{user_name}' - RETURN id(n) as id1,n - LIMIT 100 - $$) AS (id1 agtype,n agtype) - ) - SELECT - m.embedding, - t.n - FROM t, - {self.db_name}_graph."Memory" m - WHERE t.id1 = m.id; - """ - else: - cypher_query = f""" - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (n:Memory) - WHERE n.memory_type = '{scope}' AND n.user_name = '{user_name}' - RETURN properties(n) as props - LIMIT 100 - $$) AS (nprops agtype) - """ - - nodes = [] - try: - with self.connection.cursor() as cursor: - cursor.execute(cypher_query) - results = cursor.fetchall() - - for row in results: - node_agtype = row[0] - - # Handle string-formatted data - if isinstance(node_agtype, str): - try: - # Remove ::vertex suffix - json_str = node_agtype.replace("::vertex", "") - node_data = json.loads(json_str) - - if isinstance(node_data, dict) and "properties" in node_data: - properties = node_data["properties"] - # Build node data - parsed_node_data = { - "id": properties.get("id", ""), - "memory": properties.get("memory", ""), - "metadata": properties, - } - - if include_embedding and "embedding" in properties: - parsed_node_data["embedding"] = properties["embedding"] - - nodes.append(self._parse_node(parsed_node_data)) - logger.debug( - f"[get_all_memory_items] Parsed node successfully: {properties.get('id', '')}" - ) - else: - logger.warning(f"Invalid node data format: {node_data}") - - except (json.JSONDecodeError, TypeError) as e: - logger.error(f"JSON parsing failed: {e}") - elif node_agtype and hasattr(node_agtype, "value"): - # Handle agtype object - node_props = node_agtype.value - if isinstance(node_props, dict): - # Parse node properties - node_data = { - "id": node_props.get("id", ""), - "memory": node_props.get("memory", ""), - "metadata": node_props, - } - - if include_embedding and "embedding" in node_props: - node_data["embedding"] = node_props["embedding"] - - nodes.append(self._parse_node(node_data)) - else: - 
logger.warning(f"Unknown data format: {type(node_agtype)}") - - except Exception as e: - logger.error(f"Failed to get memories: {e}", exc_info=True) - - return nodes - - @timed - def get_structure_optimization_candidates( - self, scope: str, include_embedding: bool = False, user_name: str | None = None - ) -> list[dict]: - """ - Find nodes that are likely candidates for structure optimization: - - Isolated nodes, nodes with empty background, or nodes with exactly one child. - - Plus: the child of any parent node that has exactly one child. - """ - user_name = user_name if user_name else self._get_config_value("user_name") - - # Build return fields based on include_embedding flag - if include_embedding: - return_fields = "id(n) as id1,n" - return_fields_agtype = " id1 agtype,n agtype" - else: - # Build field list without embedding - return_fields = ",".join( - [ - "n.id AS id", - "n.memory AS memory", - "n.user_name AS user_name", - "n.user_id AS user_id", - "n.session_id AS session_id", - "n.status AS status", - "n.key AS key", - "n.confidence AS confidence", - "n.tags AS tags", - "n.created_at AS created_at", - "n.updated_at AS updated_at", - "n.memory_type AS memory_type", - "n.sources AS sources", - "n.source AS source", - "n.node_type AS node_type", - "n.visibility AS visibility", - "n.usage AS usage", - "n.background AS background", - "n.graph_id as graph_id", - ] - ) - fields = [ - "id", - "memory", - "user_name", - "user_id", - "session_id", - "status", - "key", - "confidence", - "tags", - "created_at", - "updated_at", - "memory_type", - "sources", - "source", - "node_type", - "visibility", - "usage", - "background", - "graph_id", - ] - return_fields_agtype = ", ".join([f"{field} agtype" for field in fields]) - - # Use OPTIONAL MATCH to find isolated nodes (no parents or children) - cypher_query = f""" - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (n:Memory) - WHERE n.memory_type = '{scope}' - AND n.status = 'activated' - AND n.user_name = '{user_name}' - OPTIONAL MATCH (n)-[:PARENT]->(c:Memory) - OPTIONAL MATCH (p:Memory)-[:PARENT]->(n) - WITH n, c, p - WHERE c IS NULL AND p IS NULL - RETURN {return_fields} - $$) AS ({return_fields_agtype}) - """ - if include_embedding: - cypher_query = f""" - WITH t as ( - {cypher_query} - ) - SELECT - m.embedding, - t.n - FROM t, - {self.db_name}_graph."Memory" m - WHERE t.id1 = m.id - """ - logger.info(f"[get_structure_optimization_candidates] query: {cypher_query}") - - candidates = [] - node_ids = set() - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(cypher_query) - results = cursor.fetchall() - logger.info(f"Found {len(results)} structure optimization candidates") - for row in results: - if include_embedding: - # When include_embedding=True, return full node object - """ - if isinstance(row, (list, tuple)) and len(row) >= 2: - """ - if isinstance(row, list | tuple) and len(row) >= 2: - embedding_val, node_val = row[0], row[1] - else: - embedding_val, node_val = None, row[0] - - node = self._build_node_from_agtype(node_val, embedding_val) - if node: - node_id = node["id"] - if node_id not in node_ids: - candidates.append(node) - node_ids.add(node_id) - else: - # When include_embedding=False, return field dictionary - # Define field names matching the RETURN clause - field_names = [ - "id", - "memory", - "user_name", - "user_id", - "session_id", - "status", - "key", - "confidence", - "tags", - "created_at", - "updated_at", - "memory_type", - "sources", - "source", - "node_type", 
- "visibility", - "usage", - "background", - "graph_id", - ] - - # Convert row to dictionary - node_data = {} - for i, field_name in enumerate(field_names): - if i < len(row): - value = row[i] - # Handle special fields - if field_name in ["tags", "sources", "usage"] and isinstance( - value, str - ): - try: - # Try parsing JSON string - node_data[field_name] = json.loads(value) - except (json.JSONDecodeError, TypeError): - node_data[field_name] = value - else: - node_data[field_name] = value - - # Parse node using _parse_node_new - try: - node = self._parse_node_new(node_data) - node_id = node["id"] - - if node_id not in node_ids: - candidates.append(node) - node_ids.add(node_id) - logger.debug(f"Parsed node successfully: {node_id}") - except Exception as e: - logger.error(f"Failed to parse node: {e}") - - except Exception as e: - logger.error(f"Failed to get structure optimization candidates: {e}", exc_info=True) - finally: - self._return_connection(conn) - - return candidates - - def drop_database(self) -> None: - """Permanently delete the entire graph this instance is using.""" - return - if self._get_config_value("use_multi_db", True): - with self.connection.cursor() as cursor: - cursor.execute(f"SELECT drop_graph('{self.db_name}_graph', true)") - logger.info(f"Graph '{self.db_name}_graph' has been dropped.") - else: - raise ValueError( - f"Refusing to drop graph '{self.db_name}_graph' in " - f"Shared Database Multi-Tenant mode" - ) - - def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: - """Parse node data from database format to standard format.""" - node = node_data.copy() - - # Convert datetime to string - for time_field in ("created_at", "updated_at"): - if time_field in node and hasattr(node[time_field], "isoformat"): - node[time_field] = node[time_field].isoformat() - - # Deserialize sources from JSON strings back to dict objects - if "sources" in node and node.get("sources"): - sources = node["sources"] - if isinstance(sources, list): - deserialized_sources = [] - for source_item in sources: - if isinstance(source_item, str): - # Try to parse JSON string - try: - parsed = json.loads(source_item) - deserialized_sources.append(parsed) - except (json.JSONDecodeError, TypeError): - # If parsing fails, keep as string or create a simple dict - deserialized_sources.append({"type": "doc", "content": source_item}) - elif isinstance(source_item, dict): - # Already a dict, keep as is - deserialized_sources.append(source_item) - else: - # Unknown type, create a simple dict - deserialized_sources.append({"type": "doc", "content": str(source_item)}) - node["sources"] = deserialized_sources - - return {"id": node.get("id"), "memory": node.get("memory", ""), "metadata": node} - - def _parse_node_new(self, node_data: dict[str, Any]) -> dict[str, Any]: - """Parse node data from database format to standard format.""" - node = node_data.copy() - - # Normalize string values that may arrive as quoted literals (e.g., '"abc"') - def _strip_wrapping_quotes(value: Any) -> Any: - """ - if isinstance(value, str) and len(value) >= 2: - if value[0] == value[-1] and value[0] in ("'", '"'): - return value[1:-1] - return value - """ - if ( - isinstance(value, str) - and len(value) >= 2 - and value[0] == value[-1] - and value[0] in ("'", '"') - ): - return value[1:-1] - return value - - for k, v in list(node.items()): - if isinstance(v, str): - node[k] = _strip_wrapping_quotes(v) - - # Convert datetime to string - for time_field in ("created_at", "updated_at"): - if time_field in node and 
hasattr(node[time_field], "isoformat"): - node[time_field] = node[time_field].isoformat() - - # Deserialize sources from JSON strings back to dict objects - if "sources" in node and node.get("sources"): - sources = node["sources"] - if isinstance(sources, list): - deserialized_sources = [] - for source_item in sources: - if isinstance(source_item, str): - # Try to parse JSON string - try: - parsed = json.loads(source_item) - deserialized_sources.append(parsed) - except (json.JSONDecodeError, TypeError): - # If parsing fails, keep as string or create a simple dict - deserialized_sources.append({"type": "doc", "content": source_item}) - elif isinstance(source_item, dict): - # Already a dict, keep as is - deserialized_sources.append(source_item) - else: - # Unknown type, create a simple dict - deserialized_sources.append({"type": "doc", "content": str(source_item)}) - node["sources"] = deserialized_sources - - # Do not remove user_name; keep all fields - - return {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node} - - def __del__(self): - """Close database connection when object is destroyed.""" - if hasattr(self, "connection") and self.connection: - self.connection.close() - - @timed - def add_node( - self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None - ) -> None: - """Add a memory node to the graph.""" - logger.info(f"[add_node] id: {id}, memory: {memory}, metadata: {metadata}") - - # user_name comes from metadata; fallback to config if missing - metadata["user_name"] = user_name if user_name else self.config.user_name - - metadata = _prepare_node_metadata(metadata) - - # Merge node and set metadata - created_at = metadata.pop("created_at", datetime.utcnow().isoformat()) - updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat()) - - # Prepare properties - properties = { - "id": id, - "memory": memory, - "created_at": created_at, - "updated_at": updated_at, - "delete_time": "", - "delete_record_id": "", - **metadata, - } - - # Generate embedding if not provided - if "embedding" not in properties or not properties["embedding"]: - properties["embedding"] = generate_vector( - self._get_config_value("embedding_dimension", 1024) - ) - - # serialization - JSON-serialize sources and usage fields - for field_name in ["sources", "usage"]: - if properties.get(field_name): - if isinstance(properties[field_name], list): - for idx in range(len(properties[field_name])): - # Serialize only when element is not a string - if not isinstance(properties[field_name][idx], str): - properties[field_name][idx] = json.dumps(properties[field_name][idx]) - elif isinstance(properties[field_name], str): - # If already a string, leave as-is - pass - - # Extract embedding for separate column - embedding_vector = properties.pop("embedding", []) - if not isinstance(embedding_vector, list): - embedding_vector = [] - - # Select column name based on embedding dimension - embedding_column = "embedding" # default column - if len(embedding_vector) == 3072: - embedding_column = "embedding_3072" - elif len(embedding_vector) == 1024: - embedding_column = "embedding" - elif len(embedding_vector) == 768: - embedding_column = "embedding_768" - - conn = None - insert_query = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Delete existing record first (if any) - delete_query = f""" - DELETE FROM {self.db_name}_graph."Memory" - WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) - """ - 
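# add_node upserts by deleting any existing vertex row and re-inserting it, deriving
# the vertex id with ag_catalog._make_graph_id and writing the embedding into a
# separate vector column. A minimal standalone sketch of that delete-then-insert
# pattern; the graph schema name is a placeholder, psycopg2 is assumed, and the
# JSON-encoded parameters mirror the statements in the surrounding code.
import json

import psycopg2


def upsert_memory_vertex(dsn, graph_schema, node_id, properties, embedding):
    """Replace a Memory vertex row keyed by its AGE graph id."""
    delete_sql = f"""
        DELETE FROM {graph_schema}."Memory"
        WHERE id = ag_catalog._make_graph_id('{graph_schema}'::name, 'Memory'::name, %s::text::cstring)
    """
    insert_sql = f"""
        INSERT INTO {graph_schema}."Memory"(id, properties, embedding)
        VALUES (
            ag_catalog._make_graph_id('{graph_schema}'::name, 'Memory'::name, %s::text::cstring),
            %s,
            %s
        )
    """
    with psycopg2.connect(dsn) as conn, conn.cursor() as cur:
        cur.execute(delete_sql, (node_id,))
        cur.execute(insert_sql, (node_id, json.dumps(properties), json.dumps(embedding)))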
cursor.execute(delete_query, (id,)) - # - get_graph_id_query = f""" - SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) - """ - cursor.execute(get_graph_id_query, (id,)) - graph_id = cursor.fetchone()[0] - properties["graph_id"] = str(graph_id) - - # Then insert new record - if embedding_vector: - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s, - %s - ) - """ - cursor.execute( - insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) - ) - logger.info( - f"[add_node] [embedding_vector-true] insert_query: {insert_query}, properties: {json.dumps(properties)}" - ) - else: - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s - ) - """ - cursor.execute(insert_query, (id, json.dumps(properties))) - logger.info( - f"[add_node] [embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" - ) - except Exception as e: - logger.error(f"[add_node] Failed to add node: {e}", exc_info=True) - raise - finally: - if insert_query: - logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}") - self._return_connection(conn) - - @timed - def add_nodes_batch( - self, - nodes: list[dict[str, Any]], - user_name: str | None = None, - ) -> None: - """ - Batch add multiple memory nodes to the graph. - - Args: - nodes: List of node dictionaries, each containing: - - id: str - Node ID - - memory: str - Memory content - - metadata: dict[str, Any] - Node metadata - user_name: Optional user name (will use config default if not provided) - """ - batch_start_time = time.time() - if not nodes: - logger.warning("[add_nodes_batch] Empty nodes list, skipping") - return - - logger.info(f"[add_nodes_batch] Processing only first node (total nodes: {len(nodes)})") - - # user_name comes from parameter; fallback to config if missing - effective_user_name = user_name if user_name else self.config.user_name - - # Prepare all nodes - prepared_nodes = [] - for node_data in nodes: - try: - id = node_data["id"] - memory = node_data["memory"] - metadata = node_data.get("metadata", {}) - - logger.debug(f"[add_nodes_batch] Processing node id: {id}") - - # Set user_name in metadata - metadata["user_name"] = effective_user_name - - metadata = _prepare_node_metadata(metadata) - - # Merge node and set metadata - created_at = metadata.pop("created_at", datetime.utcnow().isoformat()) - updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat()) - - # Prepare properties - properties = { - "id": id, - "memory": memory, - "created_at": created_at, - "updated_at": updated_at, - "delete_time": "", - "delete_record_id": "", - **metadata, - } - - # Generate embedding if not provided - if "embedding" not in properties or not properties["embedding"]: - properties["embedding"] = generate_vector( - self._get_config_value("embedding_dimension", 1024) - ) - - # Serialization - JSON-serialize sources and usage fields - for field_name in ["sources", "usage"]: - if properties.get(field_name): - if isinstance(properties[field_name], list): - for idx in range(len(properties[field_name])): - # Serialize only when element is not a string - if not isinstance(properties[field_name][idx], str): - properties[field_name][idx] = json.dumps( - 
properties[field_name][idx] - ) - elif isinstance(properties[field_name], str): - # If already a string, leave as-is - pass - - # Extract embedding for separate column - embedding_vector = properties.pop("embedding", []) - if not isinstance(embedding_vector, list): - embedding_vector = [] - - # Select column name based on embedding dimension - embedding_column = "embedding" # default column - if len(embedding_vector) == 3072: - embedding_column = "embedding_3072" - elif len(embedding_vector) == 1024: - embedding_column = "embedding" - elif len(embedding_vector) == 768: - embedding_column = "embedding_768" - - prepared_nodes.append( - { - "id": id, - "memory": memory, - "properties": properties, - "embedding_vector": embedding_vector, - "embedding_column": embedding_column, - } - ) - except Exception as e: - logger.error( - f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}", - exc_info=True, - ) - # Continue with other nodes - continue - - if not prepared_nodes: - logger.warning("[add_nodes_batch] No valid nodes to insert after preparation") - return - - # Group nodes by embedding column to optimize batch inserts - nodes_by_embedding_column = {} - for node in prepared_nodes: - col = node["embedding_column"] - if col not in nodes_by_embedding_column: - nodes_by_embedding_column[col] = [] - nodes_by_embedding_column[col].append(node) - - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Process each group separately - for embedding_column, nodes_group in nodes_by_embedding_column.items(): - # Batch delete existing records using IN clause - ids_to_delete = [node["id"] for node in nodes_group] - if ids_to_delete: - delete_query = f""" - DELETE FROM {self.db_name}_graph."Memory" - WHERE id IN ( - SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, unnest(%s::text[])::cstring) - ) - """ - cursor.execute(delete_query, (ids_to_delete,)) - - # Batch get graph_ids for all nodes - get_graph_ids_query = f""" - SELECT - id_val, - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, id_val::text::cstring) as graph_id - FROM unnest(%s::text[]) as id_val - """ - cursor.execute(get_graph_ids_query, (ids_to_delete,)) - graph_id_map = {row[0]: row[1] for row in cursor.fetchall()} - - # Add graph_id to properties - for node in nodes_group: - graph_id = graph_id_map.get(node["id"]) - if graph_id: - node["properties"]["graph_id"] = str(graph_id) - - # Use PREPARE/EXECUTE for efficient batch insert - # Generate unique prepare statement name to avoid conflicts - prepare_name = f"insert_mem_{embedding_column or 'no_embedding'}_{int(time.time() * 1000000)}" - - try: - if embedding_column and any( - node["embedding_vector"] for node in nodes_group - ): - # PREPARE statement for insert with embedding - prepare_query = f""" - PREPARE {prepare_name} AS - INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring), - $2::text::agtype, - $3::vector - ) - """ - logger.info( - f"[add_nodes_batch] embedding Preparing prepare_name: {prepare_name}" - ) - logger.info( - f"[add_nodes_batch] embedding Preparing prepare_query: {prepare_query}" - ) - - cursor.execute(prepare_query) - - # Execute prepared statement for each node - for node in nodes_group: - properties_json = json.dumps(node["properties"]) - embedding_json = ( - json.dumps(node["embedding_vector"]) - if node["embedding_vector"] - else 
None - ) - - cursor.execute( - f"EXECUTE {prepare_name}(%s, %s, %s)", - (node["id"], properties_json, embedding_json), - ) - else: - # PREPARE statement for insert without embedding - prepare_query = f""" - PREPARE {prepare_name} AS - INSERT INTO {self.db_name}_graph."Memory"(id, properties) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring), - $2::text::agtype - ) - """ - logger.info( - f"[add_nodes_batch] without embedding Preparing prepare_name: {prepare_name}" - ) - logger.info( - f"[add_nodes_batch] without embedding Preparing prepare_query: {prepare_query}" - ) - cursor.execute(prepare_query) - - # Execute prepared statement for each node - for node in nodes_group: - properties_json = json.dumps(node["properties"]) - - cursor.execute( - f"EXECUTE {prepare_name}(%s, %s)", (node["id"], properties_json) - ) - finally: - # DEALLOCATE prepared statement (always execute, even on error) - try: - cursor.execute(f"DEALLOCATE {prepare_name}") - logger.info( - f"[add_nodes_batch] Deallocated prepared statement: {prepare_name}" - ) - except Exception as dealloc_error: - logger.warning( - f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}" - ) - - logger.info( - f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}" - ) - elapsed_time = time.time() - batch_start_time - logger.info( - f"[add_nodes_batch] PREPARE/EXECUTE batch insert completed successfully in {elapsed_time:.2f}s" - ) - - except Exception as e: - logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True) - raise - finally: - self._return_connection(conn) - - def _build_node_from_agtype(self, node_agtype, embedding=None): - """ - Parse the cypher-returned column `n` (agtype or JSON string) - into a standard node and merge embedding into properties. - """ - try: - # String case: '{"id":...,"label":[...],"properties":{...}}::vertex' - if isinstance(node_agtype, str): - json_str = node_agtype.replace("::vertex", "") - obj = json.loads(json_str) - if not (isinstance(obj, dict) and "properties" in obj): - return None - props = obj["properties"] - # agtype case: has `value` attribute - elif node_agtype and hasattr(node_agtype, "value"): - val = node_agtype.value - if not (isinstance(val, dict) and "properties" in val): - return None - props = val["properties"] - else: - return None - - if embedding is not None: - if isinstance(embedding, str): - try: - embedding = json.loads(embedding) - except (json.JSONDecodeError, TypeError): - logger.warning("Failed to parse embedding for node") - props["embedding"] = embedding - - # Return standard format directly - return {"id": props.get("id", ""), "memory": props.get("memory", ""), "metadata": props} - except Exception: - return None - - @timed - def get_neighbors_by_tag( - self, - tags: list[str], - exclude_ids: list[str], - top_k: int = 5, - min_overlap: int = 1, - include_embedding: bool = False, - user_name: str | None = None, - ) -> list[dict[str, Any]]: - """ - Find top-K neighbor nodes with maximum tag overlap. - - Args: - tags: The list of tags to match. - exclude_ids: Node IDs to exclude (e.g., local cluster). - top_k: Max number of neighbors to return. - min_overlap: Minimum number of overlapping tags required. - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - List of dicts with node details and overlap count. 
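        Example (illustrative only; assumes an already-initialized graph store
        instance `db` and stored, activated nodes whose `tags` overlap the query
        tags; the literal values below are hypothetical):

            neighbors = db.get_neighbors_by_tag(
                tags=["python", "postgres"],
                exclude_ids=["node-123"],
                top_k=3,
                min_overlap=1,
                user_name="alice",
            )
            # Each item follows the standard parsed-node shape:
            # {"id": ..., "memory": ..., "metadata": {...}}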
- """ - if not tags: - return [] - - user_name = user_name if user_name else self._get_config_value("user_name") - - # Build query conditions - more relaxed filters - where_clauses = [] - params = [] - - # Exclude specified IDs - use id in properties - if exclude_ids: - exclude_conditions = [] - for exclude_id in exclude_ids: - exclude_conditions.append( - "ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) != %s::agtype" - ) - params.append(self.format_param_value(exclude_id)) - where_clauses.append(f"({' AND '.join(exclude_conditions)})") - - # Status filter - keep only 'activated' - where_clauses.append( - "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" - ) - - # Type filter - exclude 'reasoning' type - where_clauses.append( - "ag_catalog.agtype_access_operator(properties, '\"node_type\"'::agtype) != '\"reasoning\"'::agtype" - ) - - # User filter - where_clauses.append( - "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - ) - params.append(self.format_param_value(user_name)) - - # Testing showed no data; annotate. - where_clauses.append( - "ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) != '\"WorkingMemory\"'::agtype" - ) - - where_clause = " AND ".join(where_clauses) - - # Fetch all candidate nodes - query = f""" - SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - - logger.debug(f"[get_neighbors_by_tag] query: {query}, params: {params}") - - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - - nodes_with_overlap = [] - for row in results: - node_id, properties_json, embedding_json = row - properties = properties_json if properties_json else {} - - # Parse embedding - if include_embedding and embedding_json is not None: - try: - embedding = ( - json.loads(embedding_json) - if isinstance(embedding_json, str) - else embedding_json - ) - properties["embedding"] = embedding - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse embedding for node {node_id}") - - # Compute tag overlap - node_tags = properties.get("tags", []) - if isinstance(node_tags, str): - try: - node_tags = json.loads(node_tags) - except (json.JSONDecodeError, TypeError): - node_tags = [] - - overlap_tags = [tag for tag in tags if tag in node_tags] - overlap_count = len(overlap_tags) - - if overlap_count >= min_overlap: - node_data = self._parse_node( - { - "id": properties.get("id", node_id), - "memory": properties.get("memory", ""), - "metadata": properties, - } - ) - nodes_with_overlap.append((node_data, overlap_count)) - - # Sort by overlap count and return top_k items - nodes_with_overlap.sort(key=lambda x: x[1], reverse=True) - return [node for node, _ in nodes_with_overlap[:top_k]] - - except Exception as e: - logger.error(f"Failed to get neighbors by tag: {e}", exc_info=True) - return [] - finally: - self._return_connection(conn) - - def get_neighbors_by_tag_ccl( - self, - tags: list[str], - exclude_ids: list[str], - top_k: int = 5, - min_overlap: int = 1, - include_embedding: bool = False, - user_name: str | None = None, - ) -> list[dict[str, Any]]: - """ - Find top-K neighbor nodes with maximum tag overlap. - - Args: - tags: The list of tags to match. - exclude_ids: Node IDs to exclude (e.g., local cluster). - top_k: Max number of neighbors to return. - min_overlap: Minimum number of overlapping tags required. 
- include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - List of dicts with node details and overlap count. - """ - if not tags: - return [] - - user_name = user_name if user_name else self._get_config_value("user_name") - - # Build query conditions; keep consistent with nebular.py - where_clauses = [ - 'n.status = "activated"', - 'NOT (n.node_type = "reasoning")', - 'NOT (n.memory_type = "WorkingMemory")', - ] - where_clauses = [ - 'n.status = "activated"', - 'NOT (n.memory_type = "WorkingMemory")', - ] - - if exclude_ids: - exclude_ids_str = "[" + ", ".join(f'"{id}"' for id in exclude_ids) + "]" - where_clauses.append(f"NOT (n.id IN {exclude_ids_str})") - - where_clauses.append(f'n.user_name = "{user_name}"') - - where_clause = " AND ".join(where_clauses) - tag_list_literal = "[" + ", ".join(f'"{t}"' for t in tags) + "]" - - return_fields = [ - "n.id AS id", - "n.memory AS memory", - "n.user_name AS user_name", - "n.user_id AS user_id", - "n.session_id AS session_id", - "n.status AS status", - "n.key AS key", - "n.confidence AS confidence", - "n.tags AS tags", - "n.created_at AS created_at", - "n.updated_at AS updated_at", - "n.memory_type AS memory_type", - "n.sources AS sources", - "n.source AS source", - "n.node_type AS node_type", - "n.visibility AS visibility", - "n.background AS background", - ] - - if include_embedding: - return_fields.append("n.embedding AS embedding") - - return_fields_str = ", ".join(return_fields) - result_fields = [] - for field in return_fields: - # Extract field name 'id' from 'n.id AS id' - field_name = field.split(" AS ")[-1] - result_fields.append(f"{field_name} agtype") - - # Add overlap_count - result_fields.append("overlap_count agtype") - result_fields_str = ", ".join(result_fields) - # Use Cypher query; keep consistent with nebular.py - query = f""" - SELECT * FROM ( - SELECT * FROM cypher('{self.db_name}_graph', $$ - WITH {tag_list_literal} AS tag_list - MATCH (n:Memory) - WHERE {where_clause} - RETURN {return_fields_str}, - size([tag IN n.tags WHERE tag IN tag_list]) AS overlap_count - $$) AS ({result_fields_str}) - ) AS subquery - ORDER BY (overlap_count::integer) DESC - LIMIT {top_k} - """ - logger.debug(f"get_neighbors_by_tag: {query}") - try: - with self.connection.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() - - neighbors = [] - for row in results: - # Parse results - props = {} - overlap_count = None - - # Manually parse each field - field_names = [ - "id", - "memory", - "user_name", - "user_id", - "session_id", - "status", - "key", - "confidence", - "tags", - "created_at", - "updated_at", - "memory_type", - "sources", - "source", - "node_type", - "visibility", - "background", - ] - - if include_embedding: - field_names.append("embedding") - field_names.append("overlap_count") - - for i, field in enumerate(field_names): - if field == "overlap_count": - overlap_count = row[i].value if hasattr(row[i], "value") else row[i] - else: - props[field] = row[i].value if hasattr(row[i], "value") else row[i] - overlap_int = int(overlap_count) - if overlap_count is not None and overlap_int >= min_overlap: - parsed = self._parse_node(props) - parsed["overlap_count"] = overlap_int - neighbors.append(parsed) - - # Sort by overlap count - neighbors.sort(key=lambda x: x["overlap_count"], reverse=True) - neighbors = neighbors[:top_k] - - # Remove overlap_count field - result = [] - for neighbor in neighbors: - neighbor.pop("overlap_count", None) - 
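# Minimal, self-contained sketch of the post-processing performed just above:
# keep candidates whose tag overlap reaches min_overlap, rank by overlap count
# descending, and return only the top_k nodes. The function and variable names
# here are hypothetical; only the ranking logic mirrors the surrounding code.
def _rank_by_overlap_sketch(
    candidates: list[tuple[dict, int]], min_overlap: int, top_k: int
) -> list[dict]:
    """candidates: (node, overlap_count) pairs as produced by the Cypher query."""
    kept = [(node, count) for node, count in candidates if count >= min_overlap]
    kept.sort(key=lambda pair: pair[1], reverse=True)
    return [node for node, _ in kept[:top_k]]


# Usage sketch:
ranked = _rank_by_overlap_sketch(
    [({"id": "a"}, 2), ({"id": "b"}, 0), ({"id": "c"}, 3)], min_overlap=1, top_k=2
)
assert [n["id"] for n in ranked] == ["c", "a"]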
result.append(neighbor) - - return result - - except Exception as e: - logger.error(f"Failed to get neighbors by tag: {e}", exc_info=True) - return [] - - @timed - def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: - """ - Import the entire graph from a serialized dictionary. - - Args: - data: A dictionary containing all nodes and edges to be loaded. - user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self._get_config_value("user_name") - - # Import nodes - for node in data.get("nodes", []): - try: - id, memory, metadata = _compose_node(node) - metadata["user_name"] = user_name - metadata = _prepare_node_metadata(metadata) - metadata.update({"id": id, "memory": memory}) - - # Use add_node to insert node - self.add_node(id, memory, metadata) - - except Exception as e: - logger.error(f"Fail to load node: {node}, error: {e}") - - # Import edges - for edge in data.get("edges", []): - try: - source_id, target_id = edge["source"], edge["target"] - edge_type = edge["type"] - - # Use add_edge to insert edge - self.add_edge(source_id, target_id, edge_type, user_name) - - except Exception as e: - logger.error(f"Fail to load edge: {edge}, error: {e}") - - @timed - def get_edges( - self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None - ) -> list[dict[str, str]]: - """ - Get edges connected to a node, with optional type and direction filter. - - Args: - id: Node ID to retrieve edges for. - type: Relationship type to match, or 'ANY' to match all. - direction: 'OUTGOING', 'INCOMING', or 'ANY'. - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - List of edges: - [ - {"from": "source_id", "to": "target_id", "type": "RELATE"}, - ... - ] - """ - user_name = user_name if user_name else self._get_config_value("user_name") - - if direction == "OUTGOING": - pattern = "(a:Memory)-[r]->(b:Memory)" - where_clause = f"a.id = '{id}'" - elif direction == "INCOMING": - pattern = "(a:Memory)<-[r]-(b:Memory)" - where_clause = f"a.id = '{id}'" - elif direction == "ANY": - pattern = "(a:Memory)-[r]-(b:Memory)" - where_clause = f"a.id = '{id}' OR b.id = '{id}'" - else: - raise ValueError("Invalid direction. 
Must be 'OUTGOING', 'INCOMING', or 'ANY'.") - - # Add type filter - if type != "ANY": - where_clause += f" AND type(r) = '{type}'" - - # Add user filter - where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'" - - query = f""" - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH {pattern} - WHERE {where_clause} - RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type - $$) AS (from_id agtype, to_id agtype, edge_type agtype) - """ - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() - - edges = [] - for row in results: - # Extract and clean from_id - from_id_raw = row[0].value if hasattr(row[0], "value") else row[0] - if ( - isinstance(from_id_raw, str) - and from_id_raw.startswith('"') - and from_id_raw.endswith('"') - ): - from_id = from_id_raw[1:-1] - else: - from_id = str(from_id_raw) - - # Extract and clean to_id - to_id_raw = row[1].value if hasattr(row[1], "value") else row[1] - if ( - isinstance(to_id_raw, str) - and to_id_raw.startswith('"') - and to_id_raw.endswith('"') - ): - to_id = to_id_raw[1:-1] - else: - to_id = str(to_id_raw) - - # Extract and clean edge_type - edge_type_raw = row[2].value if hasattr(row[2], "value") else row[2] - if ( - isinstance(edge_type_raw, str) - and edge_type_raw.startswith('"') - and edge_type_raw.endswith('"') - ): - edge_type = edge_type_raw[1:-1] - else: - edge_type = str(edge_type_raw) - - edges.append({"from": from_id, "to": to_id, "type": edge_type}) - return edges - - except Exception as e: - logger.error(f"Failed to get edges: {e}", exc_info=True) - return [] - finally: - self._return_connection(conn) - - def _convert_graph_edges(self, core_node: dict) -> dict: - import copy - - data = copy.deepcopy(core_node) - id_map = {} - core_node = data.get("core_node", {}) - if not core_node: - return { - "core_node": None, - "neighbors": data.get("neighbors", []), - "edges": data.get("edges", []), - } - core_meta = core_node.get("metadata", {}) - if "graph_id" in core_meta and "id" in core_node: - id_map[core_meta["graph_id"]] = core_node["id"] - for neighbor in data.get("neighbors", []): - n_meta = neighbor.get("metadata", {}) - if "graph_id" in n_meta and "id" in neighbor: - id_map[n_meta["graph_id"]] = neighbor["id"] - for edge in data.get("edges", []): - src = edge.get("source") - tgt = edge.get("target") - if src in id_map: - edge["source"] = id_map[src] - if tgt in id_map: - edge["target"] = id_map[tgt] - return data - - def format_param_value(self, value: str | None) -> str: - """Format parameter value to handle both quoted and unquoted formats""" - # Handle None value - if value is None: - logger.warning("format_param_value: value is None") - return "null" - - # Remove outer quotes if they exist - if value.startswith('"') and value.endswith('"'): - # Already has double quotes, return as is - return value - else: - # Add double quotes - return f'"{value}"' - - def _build_user_name_and_kb_ids_conditions_cypher( - self, - user_name: str | None, - knowledgebase_ids: list | None, - default_user_name: str | None = None, - ) -> list[str]: - """ - Build user_name and knowledgebase_ids conditions for Cypher queries. 
- - Args: - user_name: User name for filtering - knowledgebase_ids: List of knowledgebase IDs - default_user_name: Default user name from config if user_name is None - - Returns: - List of condition strings (will be joined with OR) - """ - user_name_conditions = [] - effective_user_name = user_name if user_name else default_user_name - - if effective_user_name: - escaped_user_name = effective_user_name.replace("'", "''") - user_name_conditions.append(f"n.user_name = '{escaped_user_name}'") - - # Add knowledgebase_ids conditions (checking user_name field in the data) - if knowledgebase_ids and isinstance(knowledgebase_ids, list) and len(knowledgebase_ids) > 0: - for kb_id in knowledgebase_ids: - if isinstance(kb_id, str): - escaped_kb_id = kb_id.replace("'", "''") - user_name_conditions.append(f"n.user_name = '{escaped_kb_id}'") - - return user_name_conditions - - def _build_user_name_and_kb_ids_conditions_sql( - self, - user_name: str | None, - knowledgebase_ids: list | None, - default_user_name: str | None = None, - ) -> list[str]: - """ - Build user_name and knowledgebase_ids conditions for SQL queries. - - Args: - user_name: User name for filtering - knowledgebase_ids: List of knowledgebase IDs - default_user_name: Default user name from config if user_name is None - - Returns: - List of condition strings (will be joined with OR) - """ - user_name_conditions = [] - effective_user_name = user_name if user_name else default_user_name - - if effective_user_name: - user_name_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{effective_user_name}\"'::agtype" - ) - - # Add knowledgebase_ids conditions (checking user_name field in the data) - if knowledgebase_ids and isinstance(knowledgebase_ids, list) and len(knowledgebase_ids) > 0: - for kb_id in knowledgebase_ids: - if isinstance(kb_id, str): - user_name_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{kb_id}\"'::agtype" - ) - - return user_name_conditions - - def _build_filter_conditions_cypher( - self, - filter: dict | None, - ) -> str: - """ - Build filter conditions for Cypher queries. - - Args: - filter: Filter dictionary with "or" or "and" logic - - Returns: - Filter WHERE clause string (empty string if no filter) - """ - filter_where_clause = "" - filter = self.parse_filter(filter) - if filter: - - def escape_cypher_string(value: str) -> str: - """ - Escape single quotes in Cypher string literals. - - In Cypher, single quotes in string literals are escaped by doubling them: ' -> '' - However, when inside PostgreSQL's $$ dollar-quoted string, we need to be careful. - - The issue: In $$ delimiters, Cypher still needs to parse string literals correctly. - The solution: Use backslash escape \' instead of doubling '' when inside $$. 
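        Example (illustrative): escape_cypher_string("O'Brien") returns O\'Brien,
        so a generated Cypher literal such as 'O\'Brien' still parses correctly
        inside the $$-quoted query body.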
- """ - # Use backslash escape for single quotes inside $$ dollar-quoted strings - # This works because $$ protects the backslash from PostgreSQL interpretation - return value.replace("'", "\\'") - - def build_cypher_filter_condition(condition_dict: dict) -> str: - """Build a Cypher WHERE condition for a single filter item.""" - condition_parts = [] - for key, value in condition_dict.items(): - # Check if value is a dict with comparison operators (gt, lt, gte, lte, =, contains, in, like) - if isinstance(value, dict): - # Handle comparison operators: gt, lt, gte, lte, =, contains, in, like - # Supports multiple operators for the same field, e.g.: - # will generate: n.created_at >= '2025-09-19' AND n.created_at <= '2025-12-31' - for op, op_value in value.items(): - if op in ("gt", "lt", "gte", "lte"): - # Map operator to Cypher operator - cypher_op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="} - cypher_op = cypher_op_map[op] - - # Check if key is a datetime field - is_datetime = key in ("created_at", "updated_at") or key.endswith( - "_at" - ) - - # Check if key starts with "info." prefix (for nested fields like info.A, info.B) - if key.startswith("info."): - # Nested field access: n.info.field_name - info_field = key[5:] # Remove "info." prefix - is_info_datetime = info_field in ( - "created_at", - "updated_at", - ) or info_field.endswith("_at") - if isinstance(op_value, str): - escaped_value = escape_cypher_string(op_value) - if is_info_datetime: - condition_parts.append( - f"n.info.{info_field}::timestamp {cypher_op} '{escaped_value}'::timestamp" - ) - else: - condition_parts.append( - f"n.info.{info_field} {cypher_op} '{escaped_value}'" - ) - else: - condition_parts.append( - f"n.info.{info_field} {cypher_op} {op_value}" - ) - else: - # Direct property access (e.g., "created_at" is directly in n, not in n.info) - if isinstance(op_value, str): - escaped_value = escape_cypher_string(op_value) - if is_datetime: - condition_parts.append( - f"n.{key}::timestamp {cypher_op} '{escaped_value}'::timestamp" - ) - else: - condition_parts.append( - f"n.{key} {cypher_op} '{escaped_value}'" - ) - else: - condition_parts.append(f"n.{key} {cypher_op} {op_value}") - elif op == "=": - # Handle equality operator - # For array fields, = means exact match of the entire array (e.g., tags = ['test:zdy'] or tags = ['mode:fast', 'test:zdy']) - # For scalar fields, = means equality - # Check if key starts with "info." prefix - if key.startswith("info."): - info_field = key[5:] # Remove "info." 
prefix - if isinstance(op_value, str): - escaped_value = escape_cypher_string(op_value) - # For array fields, check if array exactly equals [value] - # For scalar fields, use = - if info_field in ("tags", "sources"): - condition_parts.append( - f"n.info.{info_field} = ['{escaped_value}']" - ) - else: - condition_parts.append( - f"n.info.{info_field} = '{escaped_value}'" - ) - elif isinstance(op_value, list): - # For array fields, format list as Cypher array - if info_field in ("tags", "sources"): - escaped_items = [ - f"'{escape_cypher_string(str(item))}'" - for item in op_value - ] - array_str = "[" + ", ".join(escaped_items) + "]" - condition_parts.append( - f"n.info.{info_field} = {array_str}" - ) - else: - condition_parts.append( - f"n.info.{info_field} = {op_value}" - ) - else: - if info_field in ("tags", "sources"): - condition_parts.append( - f"n.info.{info_field} = [{op_value}]" - ) - else: - condition_parts.append( - f"n.info.{info_field} = {op_value}" - ) - else: - # Direct property access - if isinstance(op_value, str): - escaped_value = escape_cypher_string(op_value) - # For array fields, check if array exactly equals [value] - # For scalar fields, use = - if key in ("tags", "sources"): - condition_parts.append(f"n.{key} = ['{escaped_value}']") - else: - condition_parts.append(f"n.{key} = '{escaped_value}'") - elif isinstance(op_value, list): - # For array fields, format list as Cypher array - if key in ("tags", "sources"): - escaped_items = [ - f"'{escape_cypher_string(str(item))}'" - for item in op_value - ] - array_str = "[" + ", ".join(escaped_items) + "]" - condition_parts.append(f"n.{key} = {array_str}") - else: - condition_parts.append(f"n.{key} = {op_value}") - else: - if key in ("tags", "sources"): - condition_parts.append(f"n.{key} = [{op_value}]") - else: - condition_parts.append(f"n.{key} = {op_value}") - elif op == "contains": - # Handle contains operator (for array fields) - # Check if key starts with "info." prefix - if key.startswith("info."): - info_field = key[5:] # Remove "info." prefix - if isinstance(op_value, str): - escaped_value = escape_cypher_string(op_value) - condition_parts.append( - f"'{escaped_value}' IN n.info.{info_field}" - ) - else: - condition_parts.append(f"{op_value} IN n.info.{info_field}") - else: - # Direct property access - if isinstance(op_value, str): - escaped_value = escape_cypher_string(op_value) - condition_parts.append(f"'{escaped_value}' IN n.{key}") - else: - condition_parts.append(f"{op_value} IN n.{key}") - elif op == "in": - # Handle in operator (for checking if field value is in a list) - # Supports array format: {"field": {"in": ["value1", "value2"]}} - # For array fields (like file_ids, tags, sources), uses CONTAINS logic - # For scalar fields, uses equality or IN clause - if not isinstance(op_value, list): - raise ValueError( - f"in operator only supports array format. " - f"Use {{'{key}': {{'in': ['{op_value}']}}}} instead of {{'{key}': {{'in': '{op_value}'}}}}" - ) - # Check if key is an array field - is_array_field = key in ("file_ids", "tags", "sources") - - # Check if key starts with "info." prefix - if key.startswith("info."): - info_field = key[5:] # Remove "info." 
prefix - # Check if info field is an array field - is_info_array = info_field in ("tags", "sources", "file_ids") - - if len(op_value) == 0: - # Empty list means no match - condition_parts.append("false") - elif len(op_value) == 1: - # Single value - item = op_value[0] - if is_info_array: - # For array fields, use CONTAINS (value IN array_field) - if isinstance(item, str): - escaped_value = escape_cypher_string(item) - condition_parts.append( - f"'{escaped_value}' IN n.info.{info_field}" - ) - else: - condition_parts.append( - f"{item} IN n.info.{info_field}" - ) - else: - # For scalar fields, use equality - if isinstance(item, str): - escaped_value = escape_cypher_string(item) - condition_parts.append( - f"n.info.{info_field} = '{escaped_value}'" - ) - else: - condition_parts.append( - f"n.info.{info_field} = {item}" - ) - else: - # Multiple values, use OR conditions - or_conditions = [] - for item in op_value: - if is_info_array: - # For array fields, use CONTAINS (value IN array_field) - if isinstance(item, str): - escaped_value = escape_cypher_string(item) - or_conditions.append( - f"'{escaped_value}' IN n.info.{info_field}" - ) - else: - or_conditions.append( - f"{item} IN n.info.{info_field}" - ) - else: - # For scalar fields, use equality - if isinstance(item, str): - escaped_value = escape_cypher_string(item) - or_conditions.append( - f"n.info.{info_field} = '{escaped_value}'" - ) - else: - or_conditions.append( - f"n.info.{info_field} = {item}" - ) - if or_conditions: - condition_parts.append( - f"({' OR '.join(or_conditions)})" - ) - else: - # Direct property access - if len(op_value) == 0: - # Empty list means no match - condition_parts.append("false") - elif len(op_value) == 1: - # Single value - item = op_value[0] - if is_array_field: - # For array fields, use CONTAINS (value IN array_field) - if isinstance(item, str): - escaped_value = escape_cypher_string(item) - condition_parts.append( - f"'{escaped_value}' IN n.{key}" - ) - else: - condition_parts.append(f"{item} IN n.{key}") - else: - # For scalar fields, use equality - if isinstance(item, str): - escaped_value = escape_cypher_string(item) - condition_parts.append( - f"n.{key} = '{escaped_value}'" - ) - else: - condition_parts.append(f"n.{key} = {item}") - else: - # Multiple values - if is_array_field: - # For array fields, use OR conditions with CONTAINS - or_conditions = [] - for item in op_value: - if isinstance(item, str): - escaped_value = escape_cypher_string(item) - or_conditions.append( - f"'{escaped_value}' IN n.{key}" - ) - else: - or_conditions.append(f"{item} IN n.{key}") - if or_conditions: - condition_parts.append( - f"({' OR '.join(or_conditions)})" - ) - else: - # For scalar fields, use IN clause - escaped_items = [ - f"'{escape_cypher_string(str(item))}'" - if isinstance(item, str) - else str(item) - for item in op_value - ] - array_str = "[" + ", ".join(escaped_items) + "]" - condition_parts.append(f"n.{key} IN {array_str}") - elif op == "like": - # Handle like operator (for fuzzy matching, similar to SQL LIKE '%value%') - # Check if key starts with "info." prefix - if key.startswith("info."): - info_field = key[5:] # Remove "info." 
prefix - if isinstance(op_value, str): - escaped_value = escape_cypher_string(op_value) - condition_parts.append( - f"n.info.{info_field} CONTAINS '{escaped_value}'" - ) - else: - condition_parts.append( - f"n.info.{info_field} CONTAINS {op_value}" - ) - else: - # Direct property access - if isinstance(op_value, str): - escaped_value = escape_cypher_string(op_value) - condition_parts.append( - f"n.{key} CONTAINS '{escaped_value}'" - ) - else: - condition_parts.append(f"n.{key} CONTAINS {op_value}") - # Check if key starts with "info." prefix (for simple equality) - elif key.startswith("info."): - info_field = key[5:] - if isinstance(value, str): - escaped_value = escape_cypher_string(value) - condition_parts.append(f"n.info.{info_field} = '{escaped_value}'") - else: - condition_parts.append(f"n.info.{info_field} = {value}") - else: - # Direct property access (simple equality) - if isinstance(value, str): - escaped_value = escape_cypher_string(value) - condition_parts.append(f"n.{key} = '{escaped_value}'") - else: - condition_parts.append(f"n.{key} = {value}") - return " AND ".join(condition_parts) - - if isinstance(filter, dict): - if "or" in filter: - or_conditions = [] - for condition in filter["or"]: - if isinstance(condition, dict): - condition_str = build_cypher_filter_condition(condition) - if condition_str: - or_conditions.append(f"({condition_str})") - if or_conditions: - filter_where_clause = " AND " + f"({' OR '.join(or_conditions)})" - - elif "and" in filter: - and_conditions = [] - for condition in filter["and"]: - if isinstance(condition, dict): - condition_str = build_cypher_filter_condition(condition) - if condition_str: - and_conditions.append(f"({condition_str})") - if and_conditions: - filter_where_clause = " AND " + " AND ".join(and_conditions) - else: - # Handle simple dict without "and" or "or" (e.g., {"id": "xxx"}) - condition_str = build_cypher_filter_condition(filter) - if condition_str: - filter_where_clause = " AND " + condition_str - - return filter_where_clause - - def _build_filter_conditions_sql( - self, - filter: dict | None, - ) -> list[str]: - """ - Build filter conditions for SQL queries. - - Args: - filter: Filter dictionary with "or" or "and" logic - - Returns: - List of filter WHERE clause strings (empty list if no filter) - """ - filter_conditions = [] - filter = self.parse_filter(filter) - if filter: - # Helper function to escape string value for SQL - def escape_sql_string(value: str) -> str: - """Escape single quotes in SQL string.""" - return value.replace("'", "''") - - # Helper function to build a single filter condition - def build_filter_condition(condition_dict: dict) -> str: - """Build a WHERE condition for a single filter item.""" - condition_parts = [] - for key, value in condition_dict.items(): - # Check if value is a dict with comparison operators (gt, lt, gte, lte, =, contains) - if isinstance(value, dict): - # Handle comparison operators: gt, lt, gte, lte, =, contains - for op, op_value in value.items(): - if op in ("gt", "lt", "gte", "lte"): - # Map operator to SQL operator - sql_op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="} - sql_op = sql_op_map[op] - - # Check if key is a datetime field - is_datetime = key in ("created_at", "updated_at") or key.endswith( - "_at" - ) - - # Check if key starts with "info." prefix (for nested fields like info.A, info.B) - if key.startswith("info."): - # Nested field access: properties->'info'->'field_name' - info_field = key[5:] # Remove "info." 
prefix - is_info_datetime = info_field in ( - "created_at", - "updated_at", - ) or info_field.endswith("_at") - if isinstance(op_value, str): - escaped_value = escape_sql_string(op_value) - if is_info_datetime: - condition_parts.append( - f"TRIM(BOTH '\"' FROM ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype)::text)::timestamp {sql_op} '{escaped_value}'::timestamp" - ) - else: - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) {sql_op} '\"{escaped_value}\"'::agtype" - ) - else: - # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype - value_json = json.dumps(op_value) - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) {sql_op} ag_catalog.agtype_in('{value_json}')" - ) - else: - # Direct property access (e.g., "created_at" is directly in properties, not in properties.info) - if isinstance(op_value, str): - escaped_value = escape_sql_string(op_value) - if is_datetime: - condition_parts.append( - f"TRIM(BOTH '\"' FROM ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype)::text)::timestamp {sql_op} '{escaped_value}'::timestamp" - ) - else: - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) {sql_op} '\"{escaped_value}\"'::agtype" - ) - else: - # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype - value_json = json.dumps(op_value) - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) {sql_op} ag_catalog.agtype_in('{value_json}')" - ) - elif op == "=": - # Handle equality operator - # For array fields, = means exact match of the entire array (e.g., tags = ['test:zdy'] or tags = ['mode:fast', 'test:zdy']) - # For scalar fields, = means equality - # Check if key starts with "info." prefix - if key.startswith("info."): - info_field = key[5:] # Remove "info." 
prefix - if isinstance(op_value, str): - escaped_value = escape_sql_string(op_value) - # For array fields, check if array exactly equals [value] - # For scalar fields, use = - if info_field in ("tags", "sources"): - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '[\"{escaped_value}\"]'::agtype" - ) - else: - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype" - ) - elif isinstance(op_value, list): - # For array fields, format list as JSON array string - if info_field in ("tags", "sources"): - escaped_items = [ - escape_sql_string(str(item)) for item in op_value - ] - json_array = json.dumps(escaped_items) - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '{json_array}'::agtype" - ) - else: - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {op_value}::agtype" - ) - else: - if info_field in ("tags", "sources"): - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '[{op_value}]'::agtype" - ) - else: - # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype - value_json = json.dumps(op_value) - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = ag_catalog.agtype_in('{value_json}')" - ) - else: - # Direct property access - if isinstance(op_value, str): - escaped_value = escape_sql_string(op_value) - # For array fields, check if array exactly equals [value] - # For scalar fields, use = - if key in ("tags", "sources"): - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '[\"{escaped_value}\"]'::agtype" - ) - else: - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype" - ) - elif isinstance(op_value, list): - # For array fields, format list as JSON array string - if key in ("tags", "sources"): - escaped_items = [ - escape_sql_string(str(item)) for item in op_value - ] - json_array = json.dumps(escaped_items) - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '{json_array}'::agtype" - ) - else: - # For non-string list values, convert to JSON string and then to agtype - value_json = json.dumps(op_value) - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = ag_catalog.agtype_in('{value_json}')" - ) - else: - if key in ("tags", "sources"): - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '[{op_value}]'::agtype" - ) - else: - # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype - value_json = json.dumps(op_value) - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = ag_catalog.agtype_in('{value_json}')" - ) - elif op == "contains": - # Handle contains operator - # For array fields: check if array contains the value using @> 
operator - # For string fields: check if string contains the value using @> operator - # Check if key starts with "info." prefix - if key.startswith("info."): - info_field = key[5:] # Remove "info." prefix - escaped_value = escape_sql_string(str(op_value)) - # For array fields, use @> with array format: '["value"]'::agtype - # For string fields, use @> with string format: '"value"'::agtype - # We'll use array format for contains to check if array contains the value - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '[\"{escaped_value}\"]'::agtype" - ) - else: - # Direct property access - escaped_value = escape_sql_string(str(op_value)) - # For array fields, use @> with array format - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> '[\"{escaped_value}\"]'::agtype" - ) - elif op == "in": - # Handle in operator (for checking if field value is in a list) - # Supports array format: {"field": {"in": ["value1", "value2"]}} - # For array fields (like file_ids, tags, sources), uses @> operator (contains) - # For scalar fields, uses = operator (equality) - if not isinstance(op_value, list): - raise ValueError( - f"in operator only supports array format. " - f"Use {{'{key}': {{'in': ['{op_value}']}}}} instead of {{'{key}': {{'in': '{op_value}'}}}}" - ) - # Check if key is an array field - is_array_field = key in ("file_ids", "tags", "sources") - - # Check if key starts with "info." prefix - if key.startswith("info."): - info_field = key[5:] # Remove "info." prefix - # Check if info field is an array field - is_info_array = info_field in ("tags", "sources", "file_ids") - - if len(op_value) == 0: - # Empty list means no match - condition_parts.append("false") - elif len(op_value) == 1: - # Single value - item = op_value[0] - if is_info_array: - # For array fields, use @> operator (contains) - escaped_value = escape_sql_string(str(item)) - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '[\"{escaped_value}\"]'::agtype" - ) - else: - # For scalar fields, use equality - if isinstance(item, str): - escaped_value = escape_sql_string(item) - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype" - ) - else: - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {item}::agtype" - ) - else: - # Multiple values, use OR conditions - or_conditions = [] - for item in op_value: - if is_info_array: - # For array fields, use @> operator (contains) to check if array contains the value - escaped_value = escape_sql_string(str(item)) - or_conditions.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '[\"{escaped_value}\"]'::agtype" - ) - else: - # For scalar fields, use equality - if isinstance(item, str): - escaped_value = escape_sql_string(item) - or_conditions.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype" - ) - else: - or_conditions.append( - 
f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {item}::agtype" - ) - if or_conditions: - condition_parts.append( - f"({' OR '.join(or_conditions)})" - ) - else: - # Direct property access - if len(op_value) == 0: - # Empty list means no match - condition_parts.append("false") - elif len(op_value) == 1: - # Single value - item = op_value[0] - if is_array_field: - # For array fields, use @> operator (contains) - escaped_value = escape_sql_string(str(item)) - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> '[\"{escaped_value}\"]'::agtype" - ) - else: - # For scalar fields, use equality - if isinstance(item, str): - escaped_value = escape_sql_string(item) - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype" - ) - else: - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {item}::agtype" - ) - else: - # Multiple values, use OR conditions - or_conditions = [] - for item in op_value: - if is_array_field: - # For array fields, use @> operator (contains) to check if array contains the value - escaped_value = escape_sql_string(str(item)) - or_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> '[\"{escaped_value}\"]'::agtype" - ) - else: - # For scalar fields, use equality - if isinstance(item, str): - escaped_value = escape_sql_string(item) - or_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype" - ) - else: - or_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {item}::agtype" - ) - if or_conditions: - condition_parts.append( - f"({' OR '.join(or_conditions)})" - ) - elif op == "like": - # Handle like operator (for fuzzy matching, similar to SQL LIKE '%value%') - # Check if key starts with "info." prefix - if key.startswith("info."): - info_field = key[5:] # Remove "info." prefix - if isinstance(op_value, str): - # Escape SQL special characters for LIKE: % and _ need to be escaped - escaped_value = ( - escape_sql_string(op_value) - .replace("%", "\\%") - .replace("_", "\\_") - ) - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype])::text LIKE '%{escaped_value}%'" - ) - else: - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype])::text LIKE '%{op_value}%'" - ) - else: - # Direct property access - if isinstance(op_value, str): - # Escape SQL special characters for LIKE: % and _ need to be escaped - escaped_value = ( - escape_sql_string(op_value) - .replace("%", "\\%") - .replace("_", "\\_") - ) - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype)::text LIKE '%{escaped_value}%'" - ) - else: - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype)::text LIKE '%{op_value}%'" - ) - # Check if key starts with "info." prefix (for simple equality) - elif key.startswith("info."): - # Extract the field name after "info." - info_field = key[5:] # Remove "info." 
prefix (5 characters) - if isinstance(value, str): - escaped_value = escape_sql_string(value) - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype" - ) - else: - # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype - value_json = json.dumps(value) - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = ag_catalog.agtype_in('{value_json}')" - ) - else: - # Direct property access (simple equality) - if isinstance(value, str): - escaped_value = escape_sql_string(value) - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype" - ) - else: - # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype - value_json = json.dumps(value) - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = ag_catalog.agtype_in('{value_json}')" - ) - return " AND ".join(condition_parts) - - # Process filter structure - if isinstance(filter, dict): - if "or" in filter: - # OR logic: at least one condition must match - or_conditions = [] - for condition in filter["or"]: - if isinstance(condition, dict): - condition_str = build_filter_condition(condition) - if condition_str: - or_conditions.append(f"({condition_str})") - if or_conditions: - filter_conditions.append(f"({' OR '.join(or_conditions)})") - - elif "and" in filter: - # AND logic: all conditions must match - for condition in filter["and"]: - if isinstance(condition, dict): - condition_str = build_filter_condition(condition) - if condition_str: - filter_conditions.append(f"({condition_str})") - else: - # Handle simple dict without "and" or "or" (e.g., {"id": "xxx"}) - condition_str = build_filter_condition(filter) - if condition_str: - filter_conditions.append(condition_str) - - return filter_conditions - - def parse_filter( - self, - filter_dict: dict | None = None, - ): - if filter_dict is None: - return None - full_fields = { - "id", - "key", - "tags", - "type", - "usage", - "memory", - "status", - "sources", - "user_id", - "graph_id", - "user_name", - "background", - "confidence", - "created_at", - "session_id", - "updated_at", - "memory_type", - "node_type", - "info", - "source", - "file_ids", - } - - def process_condition(condition): - if not isinstance(condition, dict): - return condition - - new_condition = {} - - for key, value in condition.items(): - if key.lower() in ["or", "and"]: - if isinstance(value, list): - processed_items = [] - for item in value: - if isinstance(item, dict): - processed_item = {} - for item_key, item_value in item.items(): - if item_key not in full_fields and not item_key.startswith( - "info." 
- ): - new_item_key = f"info.{item_key}" - else: - new_item_key = item_key - processed_item[new_item_key] = item_value - processed_items.append(processed_item) - else: - processed_items.append(item) - new_condition[key] = processed_items - else: - new_condition[key] = value - else: - if key not in full_fields and not key.startswith("info."): - new_key = f"info.{key}" - else: - new_key = key - - new_condition[new_key] = value - - return new_condition - - return process_condition(filter_dict) - - @timed - def delete_node_by_prams( - self, - writable_cube_ids: list[str] | None = None, - memory_ids: list[str] | None = None, - file_ids: list[str] | None = None, - filter: dict | None = None, - ) -> int: - """ - Delete nodes by memory_ids, file_ids, or filter. - - Args: - writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes. - If not provided, no user_name filter will be applied. - memory_ids (list[str], optional): List of memory node IDs to delete. - file_ids (list[str], optional): List of file node IDs to delete. - filter (dict, optional): Filter dictionary for metadata filtering. - Filter conditions are directly used in DELETE WHERE clause without pre-querying. - - Returns: - int: Number of nodes deleted. - """ - batch_start_time = time.time() - logger.info( - f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}" - ) - - # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id) - # Only add user_name filter if writable_cube_ids is provided - user_name_conditions = [] - if writable_cube_ids and len(writable_cube_ids) > 0: - for cube_id in writable_cube_ids: - # Use agtype_access_operator with VARIADIC ARRAY format for consistency - user_name_conditions.append( - f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype" - ) - - # Build filter conditions using common method (no query, direct use in WHERE clause) - filter_conditions = [] - if filter: - filter_conditions = self._build_filter_conditions_sql(filter) - logger.info(f"[delete_node_by_prams] filter_conditions: {filter_conditions}") - - # If no conditions to delete, return 0 - if not memory_ids and not file_ids and not filter_conditions: - logger.warning( - "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)" - ) - return 0 - - conn = None - total_deleted_count = 0 - try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Build WHERE conditions list - where_conditions = [] - - # Add memory_ids conditions - if memory_ids: - logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids") - id_conditions = [] - for node_id in memory_ids: - id_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" - ) - where_conditions.append(f"({' OR '.join(id_conditions)})") - - # Add file_ids conditions - if file_ids: - logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids") - file_id_conditions = [] - for file_id in file_ids: - file_id_conditions.append( - f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)" - ) - where_conditions.append(f"({' OR '.join(file_id_conditions)})") - - # Add filter conditions - if filter_conditions: - logger.info("[delete_node_by_prams] Processing filter conditions") - where_conditions.extend(filter_conditions) - - # Add 
user_name filter if provided - if user_name_conditions: - user_name_where = " OR ".join(user_name_conditions) - where_conditions.append(f"({user_name_where})") - - # Build final WHERE clause - if not where_conditions: - logger.warning("[delete_node_by_prams] No WHERE conditions to delete") - return 0 - - where_clause = " AND ".join(where_conditions) - - # Delete directly without counting - delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") - - cursor.execute(delete_query) - deleted_count = cursor.rowcount - total_deleted_count = deleted_count - - logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes") - - elapsed_time = time.time() - batch_start_time - logger.info( - f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes" - ) - except Exception as e: - logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True) - raise - finally: - self._return_connection(conn) - - logger.info(f"[delete_node_by_prams] Successfully deleted {total_deleted_count} nodes") - return total_deleted_count - - @timed - def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]: - """Get user names by memory ids. - - Args: - memory_ids: List of memory node IDs to query. - - Returns: - dict[str, str | None]: Dictionary mapping memory_id to user_name. - - Key: memory_id - - Value: user_name if exists, None if memory_id does not exist - Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None} - """ - logger.info(f"[get_user_names_by_memory_ids] Querying memory_ids {memory_ids}") - if not memory_ids: - return {} - - # Validate and normalize memory_ids - # Ensure all items are strings - normalized_memory_ids = [] - for mid in memory_ids: - if not isinstance(mid, str): - mid = str(mid) - # Remove any whitespace - mid = mid.strip() - if mid: - normalized_memory_ids.append(mid) - - if not normalized_memory_ids: - return {} - - # Escape special characters for JSON string format in agtype - def escape_memory_id(mid: str) -> str: - """Escape special characters in memory_id for JSON string format.""" - # Escape backslashes first, then double quotes - mid_str = mid.replace("\\", "\\\\") - mid_str = mid_str.replace('"', '\\"') - return mid_str - - # Build OR conditions for each memory_id - id_conditions = [] - for mid in normalized_memory_ids: - # Escape special characters - escaped_mid = escape_memory_id(mid) - id_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{escaped_mid}\"'::agtype" - ) - - where_clause = f"({' OR '.join(id_conditions)})" - - # Query to get memory_id and user_name pairs - query = f""" - SELECT - ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype)::text AS memory_id, - ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype)::text AS user_name - FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - - logger.info(f"[get_user_names_by_memory_ids] query: {query}") - conn = None - result_dict = {} - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() - - # Build result dictionary from query results - for row in results: - memory_id_raw = row[0] - user_name_raw = row[1] - - # Remove quotes if present - if isinstance(memory_id_raw, str): - memory_id = memory_id_raw.strip('"').strip("'") - else: - 
memory_id = str(memory_id_raw).strip('"').strip("'") - - if isinstance(user_name_raw, str): - user_name = user_name_raw.strip('"').strip("'") - else: - user_name = ( - str(user_name_raw).strip('"').strip("'") if user_name_raw else None - ) - - result_dict[memory_id] = user_name if user_name else None - - # Set None for memory_ids that were not found - for mid in normalized_memory_ids: - if mid not in result_dict: - result_dict[mid] = None - - logger.info( - f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, " - f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names" - ) - - return result_dict - except Exception as e: - logger.error( - f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True - ) - raise - finally: - self._return_connection(conn) - - def exist_user_name(self, user_name: str) -> dict[str, bool]: - """Check if user name exists in the graph. - - Args: - user_name: User name to check. - - Returns: - dict[str, bool]: Dictionary with user_name as key and bool as value indicating existence. - """ - logger.info(f"[exist_user_name] Querying user_name {user_name}") - if not user_name: - return {user_name: False} - - # Escape special characters for JSON string format in agtype - def escape_user_name(un: str) -> str: - """Escape special characters in user_name for JSON string format.""" - # Escape backslashes first, then double quotes - un_str = un.replace("\\", "\\\\") - un_str = un_str.replace('"', '\\"') - return un_str - - # Escape special characters - escaped_un = escape_user_name(user_name) - - # Query to check if user_name exists - query = f""" - SELECT COUNT(*) - FROM "{self.db_name}_graph"."Memory" - WHERE ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{escaped_un}\"'::agtype - """ - logger.info(f"[exist_user_name] query: {query}") - result_dict = {} - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - count = cursor.fetchone()[0] - result = count > 0 - result_dict[user_name] = result - return result_dict - except Exception as e: - logger.error( - f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True - ) - raise - finally: - self._return_connection(conn) diff --git a/src/memos/graph_dbs/polardb/__init__.py b/src/memos/graph_dbs/polardb/__init__.py new file mode 100644 index 000000000..98cab53eb --- /dev/null +++ b/src/memos/graph_dbs/polardb/__init__.py @@ -0,0 +1,29 @@ +"""PolarDB graph database package using Apache AGE extension.""" + +from memos.graph_dbs.polardb.connection import ConnectionMixin +from memos.graph_dbs.polardb.edges import EdgeMixin +from memos.graph_dbs.polardb.filters import FilterMixin +from memos.graph_dbs.polardb.maintenance import MaintenanceMixin +from memos.graph_dbs.polardb.nodes import NodeMixin +from memos.graph_dbs.polardb.queries import QueryMixin +from memos.graph_dbs.polardb.schema import SchemaMixin +from memos.graph_dbs.polardb.search import SearchMixin +from memos.graph_dbs.polardb.traversal import TraversalMixin +from memos.graph_dbs.base import BaseGraphDB + + +class PolarDBGraphDB( + ConnectionMixin, + SchemaMixin, + NodeMixin, + EdgeMixin, + TraversalMixin, + SearchMixin, + FilterMixin, + QueryMixin, + MaintenanceMixin, + BaseGraphDB, +): + """PolarDB-based graph database using Apache AGE extension.""" + + pass diff --git a/src/memos/graph_dbs/polardb/connection.py 
b/src/memos/graph_dbs/polardb/connection.py new file mode 100644 index 000000000..42e5f082a --- /dev/null +++ b/src/memos/graph_dbs/polardb/connection.py @@ -0,0 +1,333 @@ +import time + +from contextlib import suppress + +from memos.configs.graph_db import PolarDBGraphDBConfig +from memos.dependency import require_python_package +from memos.log import get_logger + + +logger = get_logger(__name__) + + +class ConnectionMixin: + """Mixin class providing PolarDB connection pool management.""" + + @require_python_package( + import_name="psycopg2", + install_command="pip install psycopg2-binary", + install_link="https://pypi.org/project/psycopg2-binary/", + ) + def __init__(self, config: PolarDBGraphDBConfig): + """PolarDB-based implementation using Apache AGE. + + Tenant Modes: + - use_multi_db = True: + Dedicated Database Mode (Multi-Database Multi-Tenant). + Each tenant or logical scope uses a separate PolarDB database. + `db_name` is the specific tenant database. + `user_name` can be None (optional). + + - use_multi_db = False: + Shared Database Multi-Tenant Mode. + All tenants share a single PolarDB database. + `db_name` is the shared database. + `user_name` is required to isolate each tenant's data at the node level. + All node queries will enforce `user_name` in WHERE conditions and store it in metadata, + but it will be removed automatically before returning to external consumers. + """ + import psycopg2 + import psycopg2.pool + + self.config = config + + # Handle both dict and object config + if isinstance(config, dict): + self.db_name = config.get("db_name") + self.user_name = config.get("user_name") + host = config.get("host") + port = config.get("port") + user = config.get("user") + password = config.get("password") + maxconn = config.get("maxconn", 100) # De + else: + self.db_name = config.db_name + self.user_name = config.user_name + host = config.host + port = config.port + user = config.user + password = config.password + maxconn = config.maxconn if hasattr(config, "maxconn") else 100 + """ + # Create connection + self.connection = psycopg2.connect( + host=host, port=port, user=user, password=password, dbname=self.db_name,minconn=10, maxconn=2000 + ) + """ + logger.info(f" db_name: {self.db_name} current maxconn is:'{maxconn}'") + + # Create connection pool + self.connection_pool = psycopg2.pool.ThreadedConnectionPool( + minconn=5, + maxconn=maxconn, + host=host, + port=port, + user=user, + password=password, + dbname=self.db_name, + connect_timeout=60, # Connection timeout in seconds + keepalives_idle=40, # Seconds of inactivity before sending keepalive (should be < server idle timeout) + keepalives_interval=15, # Seconds between keepalive retries + keepalives_count=5, # Number of keepalive retries before considering connection dead + ) + + # Keep a reference to the pool for cleanup + self._pool_closed = False + + """ + # Handle auto_create + # auto_create = config.get("auto_create", False) if isinstance(config, dict) else config.auto_create + # if auto_create: + # self._ensure_database_exists() + + # Create graph and tables + # self.create_graph() + # self.create_edge() + # self._create_graph() + + # Handle embedding_dimension + # embedding_dim = config.get("embedding_dimension", 1024) if isinstance(config,dict) else config.embedding_dimension + # self.create_index(dimensions=embedding_dim) + """ + + def _get_config_value(self, key: str, default=None): + """Safely get config value from either dict or object.""" + if isinstance(self.config, dict): + return self.config.get(key, 
default) + else: + return getattr(self.config, key, default) + + def _get_connection(self): + """ + Get a connection from the pool. + + This function: + 1. Gets a connection from ThreadedConnectionPool + 2. Checks if connection is closed or unhealthy + 3. Returns healthy connection or retries (max 3 times) + 4. Handles connection pool exhaustion gracefully + + Returns: + psycopg2 connection object + + Raises: + RuntimeError: If connection pool is closed or exhausted after retries + """ + logger.info(f" db_name: {self.db_name} pool maxconn is:'{self.connection_pool.maxconn}'") + if self._pool_closed: + raise RuntimeError("Connection pool has been closed") + + max_retries = 500 + import psycopg2.pool + + for attempt in range(max_retries): + conn = None + try: + # Try to get connection from pool + # This may raise PoolError if pool is exhausted + conn = self.connection_pool.getconn() + + # Check if connection is closed + if conn.closed != 0: + # Connection is closed, return it to pool with close flag and try again + logger.warning( + f"[_get_connection] Got closed connection, attempt {attempt + 1}/{max_retries}" + ) + try: + self.connection_pool.putconn(conn, close=True) + except Exception as e: + logger.warning( + f"[_get_connection] Failed to return closed connection to pool: {e}" + ) + with suppress(Exception): + conn.close() + + conn = None + if attempt < max_retries - 1: + # Exponential backoff: 0.1s, 0.2s, 0.4s + """time.sleep(0.1 * (2**attempt))""" + time.sleep(0.003) + continue + else: + raise RuntimeError("Pool returned a closed connection after all retries") + + # Set autocommit for PolarDB compatibility + conn.autocommit = True + + # Test connection health with SELECT 1 + try: + cursor = conn.cursor() + cursor.execute("SELECT 1") + cursor.fetchone() + cursor.close() + except Exception as health_check_error: + # Connection is not usable, return it to pool with close flag and try again + logger.warning( + f"[_get_connection] Connection health check failed (attempt {attempt + 1}/{max_retries}): {health_check_error}" + ) + try: + self.connection_pool.putconn(conn, close=True) + except Exception as putconn_error: + logger.warning( + f"[_get_connection] Failed to return unhealthy connection to pool: {putconn_error}" + ) + with suppress(Exception): + conn.close() + + conn = None + if attempt < max_retries - 1: + # Exponential backoff: 0.1s, 0.2s, 0.4s + """time.sleep(0.1 * (2**attempt))""" + time.sleep(0.003) + continue + else: + raise RuntimeError( + f"Failed to get a healthy connection from pool after {max_retries} attempts: {health_check_error}" + ) from health_check_error + + # Connection is healthy, return it + return conn + + except psycopg2.pool.PoolError as pool_error: + # Pool exhausted or other pool-related error + # Don't retry immediately for pool exhaustion - it's unlikely to resolve quickly + error_msg = str(pool_error).lower() + if "exhausted" in error_msg or "pool" in error_msg: + # Log pool status for debugging + try: + # Try to get pool stats if available + pool_info = f"Pool config: minconn={self.connection_pool.minconn}, maxconn={self.connection_pool.maxconn}" + logger.error( + f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries}). 
{pool_info}" + ) + except Exception: + logger.error( + f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries})" + ) + + # For pool exhaustion, wait longer before retry (connections may be returned) + if attempt < max_retries - 1: + # Longer backoff for pool exhaustion: 0.5s, 1.0s, 2.0s + wait_time = 0.5 * (2**attempt) + logger.info(f"[_get_connection] Waiting {wait_time}s before retry...") + """time.sleep(wait_time)""" + time.sleep(0.003) + continue + else: + raise RuntimeError( + f"Connection pool exhausted after {max_retries} attempts. " + f"This usually means connections are not being returned to the pool. " + f"Check for connection leaks in your code." + ) from pool_error + else: + # Other pool errors - retry with normal backoff + if attempt < max_retries - 1: + """time.sleep(0.1 * (2**attempt))""" + time.sleep(0.003) + continue + else: + raise RuntimeError( + f"Failed to get connection from pool: {pool_error}" + ) from pool_error + + except Exception as e: + # Other exceptions (not pool-related) + # Only try to return connection if we actually got one + # If getconn() failed (e.g., pool exhausted), conn will be None + if conn is not None: + try: + # Return connection to pool if it's valid + self.connection_pool.putconn(conn, close=True) + except Exception as putconn_error: + logger.warning( + f"[_get_connection] Failed to return connection after error: {putconn_error}" + ) + with suppress(Exception): + conn.close() + + if attempt >= max_retries - 1: + raise RuntimeError(f"Failed to get a valid connection from pool: {e}") from e + else: + # Exponential backoff: 0.1s, 0.2s, 0.4s + """time.sleep(0.1 * (2**attempt))""" + time.sleep(0.003) + continue + + # Should never reach here, but just in case + raise RuntimeError("Failed to get connection after all retries") + + def _return_connection(self, connection): + """ + Return a connection to the pool. 
+ + This function safely returns a connection to the pool, handling: + - Closed connections (close them instead of returning) + - Pool closed state (close connection directly) + - None connections (no-op) + - putconn() failures (close connection as fallback) + + Args: + connection: psycopg2 connection object or None + """ + if self._pool_closed: + # Pool is closed, just close the connection if it exists + if connection: + try: + connection.close() + logger.debug("[_return_connection] Closed connection (pool is closed)") + except Exception as e: + logger.warning( + f"[_return_connection] Failed to close connection after pool closed: {e}" + ) + return + + if not connection: + # No connection to return - this is normal if _get_connection() failed + return + + try: + # Check if connection is closed + if hasattr(connection, "closed") and connection.closed != 0: + # Connection is closed, just close it explicitly and don't return to pool + logger.debug( + "[_return_connection] Connection is closed, closing it instead of returning to pool" + ) + try: + connection.close() + except Exception as e: + logger.warning(f"[_return_connection] Failed to close closed connection: {e}") + return + + # Connection is valid, return to pool + self.connection_pool.putconn(connection) + logger.debug("[_return_connection] Successfully returned connection to pool") + except Exception as e: + # If putconn fails, try to close the connection + # This prevents connection leaks if putconn() fails + logger.error( + f"[_return_connection] Failed to return connection to pool: {e}", exc_info=True + ) + try: + connection.close() + logger.debug( + "[_return_connection] Closed connection as fallback after putconn failure" + ) + except Exception as close_error: + logger.warning( + f"[_return_connection] Failed to close connection after putconn error: {close_error}" + ) + + def __del__(self): + """Close database connection when object is destroyed.""" + if hasattr(self, "connection") and self.connection: + self.connection.close() diff --git a/src/memos/graph_dbs/polardb/edges.py b/src/memos/graph_dbs/polardb/edges.py new file mode 100644 index 000000000..62170c480 --- /dev/null +++ b/src/memos/graph_dbs/polardb/edges.py @@ -0,0 +1,266 @@ +import json +import time + +from memos.log import get_logger +from memos.utils import timed + +logger = get_logger(__name__) + + +class EdgeMixin: + """Mixin for edge (relationship) operations.""" + + @timed + def create_edge(self): + """Create all valid edge types if they do not exist""" + + valid_rel_types = {"AGGREGATE_TO", "FOLLOWS", "INFERS", "MERGED_TO", "RELATE_TO", "PARENT"} + + for label_name in valid_rel_types: + conn = None + logger.info(f"Creating elabel: {label_name}") + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") + logger.info(f"Successfully created elabel: {label_name}") + except Exception as e: + if "already exists" in str(e): + logger.info(f"Label '{label_name}' already exists, skipping.") + else: + logger.warning(f"Failed to create label {label_name}: {e}") + logger.error(f"Failed to create elabel '{label_name}': {e}", exc_info=True) + finally: + self._return_connection(conn) + + @timed + def add_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: + logger.info( + f"polardb [add_edge] source_id: {source_id}, target_id: {target_id}, type: {type},user_name:{user_name}" + ) + + start_time = time.time() + if not source_id or not 
target_id: + logger.warning(f"Edge '{source_id}' and '{target_id}' are both None") + raise ValueError("[add_edge] source_id and target_id must be provided") + + source_exists = self.get_node(source_id) is not None + target_exists = self.get_node(target_id) is not None + + if not source_exists or not target_exists: + logger.warning( + "[add_edge] Source %s or target %s does not exist.", source_exists, target_exists + ) + raise ValueError("[add_edge] source_id and target_id must be provided") + + properties = {} + if user_name is not None: + properties["user_name"] = user_name + query = f""" + INSERT INTO {self.db_name}_graph."Edges"(source_id, target_id, edge_type, properties) + SELECT + '{source_id}', + '{target_id}', + '{type}', + jsonb_build_object('user_name', '{user_name}') + WHERE NOT EXISTS ( + SELECT 1 FROM {self.db_name}_graph."Edges" + WHERE source_id = '{source_id}' + AND target_id = '{target_id}' + AND edge_type = '{type}' + ); + """ + logger.info(f"polardb [add_edge] query: {query}, properties: {json.dumps(properties)}") + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, (source_id, target_id, type, json.dumps(properties))) + logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}") + + elapsed_time = time.time() - start_time + logger.info(f" polardb [add_edge] insert completed time in {elapsed_time:.2f}s") + except Exception as e: + logger.error(f"Failed to insert edge: {e}", exc_info=True) + raise + finally: + self._return_connection(conn) + + @timed + def delete_edge(self, source_id: str, target_id: str, type: str) -> None: + """ + Delete a specific edge between two nodes. + Args: + source_id: ID of the source node. + target_id: ID of the target node. + type: Relationship type to remove. + """ + query = f""" + DELETE FROM "{self.db_name}_graph"."Edges" + WHERE source_id = %s AND target_id = %s AND edge_type = %s + """ + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, (source_id, target_id, type)) + logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") + finally: + self._return_connection(conn) + + @timed + def edge_exists( + self, + source_id: str, + target_id: str, + type: str = "ANY", + direction: str = "OUTGOING", + user_name: str | None = None, + ) -> bool: + """ + Check if an edge exists between two nodes. + Args: + source_id: ID of the source node. + target_id: ID of the target node. + type: Relationship type. Use "ANY" to match any relationship type. + direction: Direction of the edge. + Use "OUTGOING" (default), "INCOMING", or "ANY". + user_name (str, optional): User name for filtering in non-multi-db mode + Returns: + True if the edge exists, otherwise False. + """ + + # Prepare the relationship pattern + user_name = user_name if user_name else self.config.user_name + + # Prepare the match pattern with direction + if direction == "OUTGOING": + pattern = "(a:Memory)-[r]->(b:Memory)" + elif direction == "INCOMING": + pattern = "(a:Memory)<-[r]-(b:Memory)" + elif direction == "ANY": + pattern = "(a:Memory)-[r]-(b:Memory)" + else: + raise ValueError( + f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." 
+ ) + query = f"SELECT * FROM cypher('{self.db_name}_graph', $$" + query += f"\nMATCH {pattern}" + query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" + query += f"\nAND a.id = '{source_id}' AND b.id = '{target_id}'" + if type != "ANY": + query += f"\n AND type(r) = '{type}'" + + query += "\nRETURN r" + query += "\n$$) AS (r agtype)" + + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + result = cursor.fetchone() + return result is not None and result[0] is not None + finally: + self._return_connection(conn) + + @timed + def get_edges( + self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None + ) -> list[dict[str, str]]: + """ + Get edges connected to a node, with optional type and direction filter. + + Args: + id: Node ID to retrieve edges for. + type: Relationship type to match, or 'ANY' to match all. + direction: 'OUTGOING', 'INCOMING', or 'ANY'. + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + List of edges: + [ + {"from": "source_id", "to": "target_id", "type": "RELATE"}, + ... + ] + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + if direction == "OUTGOING": + pattern = "(a:Memory)-[r]->(b:Memory)" + where_clause = f"a.id = '{id}'" + elif direction == "INCOMING": + pattern = "(a:Memory)<-[r]-(b:Memory)" + where_clause = f"a.id = '{id}'" + elif direction == "ANY": + pattern = "(a:Memory)-[r]-(b:Memory)" + where_clause = f"a.id = '{id}' OR b.id = '{id}'" + else: + raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.") + + # Add type filter + if type != "ANY": + where_clause += f" AND type(r) = '{type}'" + + # Add user filter + where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'" + + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH {pattern} + WHERE {where_clause} + RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type + $$) AS (from_id agtype, to_id agtype, edge_type agtype) + """ + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() + + edges = [] + for row in results: + # Extract and clean from_id + from_id_raw = row[0].value if hasattr(row[0], "value") else row[0] + if ( + isinstance(from_id_raw, str) + and from_id_raw.startswith('"') + and from_id_raw.endswith('"') + ): + from_id = from_id_raw[1:-1] + else: + from_id = str(from_id_raw) + + # Extract and clean to_id + to_id_raw = row[1].value if hasattr(row[1], "value") else row[1] + if ( + isinstance(to_id_raw, str) + and to_id_raw.startswith('"') + and to_id_raw.endswith('"') + ): + to_id = to_id_raw[1:-1] + else: + to_id = str(to_id_raw) + + # Extract and clean edge_type + edge_type_raw = row[2].value if hasattr(row[2], "value") else row[2] + if ( + isinstance(edge_type_raw, str) + and edge_type_raw.startswith('"') + and edge_type_raw.endswith('"') + ): + edge_type = edge_type_raw[1:-1] + else: + edge_type = str(edge_type_raw) + + edges.append({"from": from_id, "to": to_id, "type": edge_type}) + return edges + + except Exception as e: + logger.error(f"Failed to get edges: {e}", exc_info=True) + return [] + finally: + self._return_connection(conn) diff --git a/src/memos/graph_dbs/polardb/filters.py b/src/memos/graph_dbs/polardb/filters.py new file mode 100644 index 000000000..b119d3fbb --- /dev/null +++ b/src/memos/graph_dbs/polardb/filters.py @@ -0,0 +1,581 @@ +import 
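# --- Illustrative usage sketch (not part of this patch) ---------------------------
# Exercising the EdgeMixin API defined above. Assumptions: `graph` is an initialized
# PolarDBGraphDB, both node ids already exist (add_edge validates them via
# get_node), and "RELATE_TO" is one of the types registered by create_edge().
def link_and_inspect(graph, source_id: str, target_id: str, user_name: str) -> None:
    graph.add_edge(source_id, target_id, type="RELATE_TO", user_name=user_name)

    if graph.edge_exists(source_id, target_id, type="RELATE_TO",
                         direction="OUTGOING", user_name=user_name):
        # get_edges() returns dicts shaped like {"from": ..., "to": ..., "type": ...}
        for edge in graph.get_edges(source_id, type="ANY", direction="ANY",
                                    user_name=user_name):
            print(f"{edge['from']} -[{edge['type']}]-> {edge['to']}")

    # Remove the specific edge again; delete_edge matches on the exact triple.
    graph.delete_edge(source_id, target_id, type="RELATE_TO")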
json +from typing import Any, Literal + +from memos.log import get_logger + +logger = get_logger(__name__) + + +class FilterMixin: + """Mixin for filter condition building (WHERE clause builders).""" + + def _build_user_name_and_kb_ids_conditions( + self, + user_name: str | None, + knowledgebase_ids: list | None, + default_user_name: str | None = None, + mode: Literal["cypher", "sql"] = "sql", + ) -> list[str]: + """ + Build user_name and knowledgebase_ids conditions. + + Args: + user_name: User name for filtering + knowledgebase_ids: List of knowledgebase IDs + default_user_name: Default user name from config if user_name is None + mode: 'cypher' for Cypher property access, 'sql' for AgType SQL access + + Returns: + List of condition strings (will be joined with OR) + """ + user_name_conditions = [] + effective_user_name = user_name if user_name else default_user_name + + def _fmt(value: str) -> str: + if mode == "cypher": + escaped = value.replace("'", "''") + return f"n.user_name = '{escaped}'" + return ( + f"ag_catalog.agtype_access_operator(properties::text::agtype, " + f"'\"user_name\"'::agtype) = '\"{value}\"'::agtype" + ) + + if effective_user_name: + user_name_conditions.append(_fmt(effective_user_name)) + + if knowledgebase_ids and isinstance(knowledgebase_ids, list) and len(knowledgebase_ids) > 0: + for kb_id in knowledgebase_ids: + if isinstance(kb_id, str): + user_name_conditions.append(_fmt(kb_id)) + + return user_name_conditions + + def _build_user_name_and_kb_ids_conditions_cypher(self, user_name, knowledgebase_ids, default_user_name=None): + return self._build_user_name_and_kb_ids_conditions(user_name, knowledgebase_ids, default_user_name, mode="cypher") + + def _build_user_name_and_kb_ids_conditions_sql(self, user_name, knowledgebase_ids, default_user_name=None): + return self._build_user_name_and_kb_ids_conditions(user_name, knowledgebase_ids, default_user_name, mode="sql") + + def _build_filter_conditions( + self, + filter: dict | None, + mode: Literal["cypher", "sql"] = "sql", + ) -> str | list[str]: + """ + Build filter conditions for Cypher or SQL queries. 
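# --- Illustrative sketch (not part of this patch) ---------------------------------
# What the tenant helper above produces for the two dialects. Assumptions: `graph`
# is an initialized PolarDBGraphDB; "alice" and "kb-1" are placeholder identifiers.
# In Cypher mode the conditions look roughly like
#   ["n.user_name = 'alice'", "n.user_name = 'kb-1'"]
# while SQL mode emits agtype_access_operator(...) comparisons against the
# properties column. Callers OR the entries together to scope a query to the
# tenant plus any readable knowledge bases.
cypher_conditions = graph._build_user_name_and_kb_ids_conditions_cypher(
    "alice", ["kb-1"]
)
sql_conditions = graph._build_user_name_and_kb_ids_conditions_sql("alice", ["kb-1"])
tenant_where = f"({' OR '.join(sql_conditions)})" if sql_conditions else ""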
+ + Args: + filter: Filter dictionary with "or" or "and" logic + mode: "cypher" for Cypher queries, "sql" for SQL queries + + Returns: + For mode="cypher": Filter WHERE clause string with " AND " prefix (empty string if no filter) + For mode="sql": List of filter WHERE clause strings (empty list if no filter) + """ + is_cypher = mode == "cypher" + filter = self.parse_filter(filter) + + if not filter: + return "" if is_cypher else [] + + # --- Dialect helpers --- + + def escape_string(value: str) -> str: + if is_cypher: + # Backslash escape for single quotes inside $$ dollar-quoted strings + return value.replace("'", "\\'") + else: + return value.replace("'", "''") + + def prop_direct(key: str) -> str: + """Property access expression for a direct (top-level) key.""" + if is_cypher: + return f"n.{key}" + else: + return f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"{key}\"'::agtype)" + + def prop_nested(info_field: str) -> str: + """Property access expression for a nested info.field key.""" + if is_cypher: + return f"n.info.{info_field}" + else: + return f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties::text::agtype, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype])" + + def prop_ref(key: str) -> str: + """Return the appropriate property access expression for a key (direct or nested).""" + if key.startswith("info."): + return prop_nested(key[5:]) + return prop_direct(key) + + def fmt_str_val(escaped_value: str) -> str: + """Format an escaped string value as a literal.""" + if is_cypher: + return f"'{escaped_value}'" + else: + return f"'\"{escaped_value}\"'::agtype" + + def fmt_non_str_val(value: Any) -> str: + """Format a non-string value as a literal.""" + if is_cypher: + return str(value) + else: + value_json = json.dumps(value) + return f"ag_catalog.agtype_in('{value_json}')" + + def fmt_array_eq_single_str(escaped_value: str) -> str: + """Format an array-equality check for a single string value: field = ['val'].""" + if is_cypher: + return f"['{escaped_value}']" + else: + return f"'[\"{escaped_value}\"]'::agtype" + + def fmt_array_eq_list(items: list, escape_fn) -> str: + """Format an array-equality check for a list of values.""" + if is_cypher: + escaped_items = [f"'{escape_fn(str(item))}'" for item in items] + return "[" + ", ".join(escaped_items) + "]" + else: + escaped_items = [escape_fn(str(item)) for item in items] + json_array = json.dumps(escaped_items) + return f"'{json_array}'::agtype" + + def fmt_array_eq_non_str(value: Any) -> str: + """Format an array-equality check for a single non-string value: field = [val].""" + if is_cypher: + return f"[{value}]" + else: + return f"'[{value}]'::agtype" + + def fmt_contains_str(escaped_value: str, prop_expr: str) -> str: + """Format a 'contains' check: array field contains a string value.""" + if is_cypher: + return f"'{escaped_value}' IN {prop_expr}" + else: + return f"{prop_expr} @> '[\"{escaped_value}\"]'::agtype" + + def fmt_contains_non_str(value: Any, prop_expr: str) -> str: + """Format a 'contains' check: array field contains a non-string value.""" + if is_cypher: + return f"{value} IN {prop_expr}" + else: + escaped_value = str(value).replace("'", "''") + return f"{prop_expr} @> '[\"{escaped_value}\"]'::agtype" + + def fmt_like(escaped_value: str, prop_expr: str) -> str: + """Format a 'like' (fuzzy match) check.""" + if is_cypher: + return f"{prop_expr} CONTAINS '{escaped_value}'" + else: + return f"{prop_expr}::text LIKE '%{escaped_value}%'" + + def fmt_datetime_cmp(prop_expr: 
str, cmp_op: str, escaped_value: str) -> str: + """Format a datetime comparison.""" + if is_cypher: + return f"{prop_expr}::timestamp {cmp_op} '{escaped_value}'::timestamp" + else: + return f"TRIM(BOTH '\"' FROM {prop_expr}::text)::timestamp {cmp_op} '{escaped_value}'::timestamp" + + def fmt_in_scalar_eq_str(escaped_value: str, prop_expr: str) -> str: + """Format scalar equality for 'in' operator with a string item.""" + return f"{prop_expr} = {fmt_str_val(escaped_value)}" + + def fmt_in_scalar_eq_non_str(item: Any, prop_expr: str) -> str: + """Format scalar equality for 'in' operator with a non-string item.""" + if is_cypher: + return f"{prop_expr} = {item}" + else: + return f"{prop_expr} = {item}::agtype" + + def fmt_in_array_contains_str(escaped_value: str, prop_expr: str) -> str: + """Format array-contains for 'in' operator with a string item.""" + if is_cypher: + return f"'{escaped_value}' IN {prop_expr}" + else: + return f"{prop_expr} @> '[\"{escaped_value}\"]'::agtype" + + def fmt_in_array_contains_non_str(item: Any, prop_expr: str) -> str: + """Format array-contains for 'in' operator with a non-string item.""" + if is_cypher: + return f"{item} IN {prop_expr}" + else: + escaped_value = str(item).replace("'", "''") + return f"{prop_expr} @> '[\"{escaped_value}\"]'::agtype" + + def escape_like_value(value: str) -> str: + """Escape a value for use in like/CONTAINS. SQL needs extra LIKE-char escaping.""" + escaped = escape_string(value) + if not is_cypher: + escaped = escaped.replace("%", "\\%").replace("_", "\\_") + return escaped + + def fmt_scalar_in_clause(items: list, prop_expr: str) -> str: + """Format a scalar IN clause for multiple values (cypher only has this path).""" + if is_cypher: + escaped_items = [ + f"'{escape_string(str(item))}'" if isinstance(item, str) else str(item) + for item in items + ] + array_str = "[" + ", ".join(escaped_items) + "]" + return f"{prop_expr} IN {array_str}" + else: + # SQL mode: use OR equality conditions + or_parts = [] + for item in items: + if isinstance(item, str): + escaped_value = escape_string(item) + or_parts.append(f"{prop_expr} = {fmt_str_val(escaped_value)}") + else: + or_parts.append(f"{prop_expr} = {item}::agtype") + return f"({' OR '.join(or_parts)})" + + # --- Main condition builder --- + + def build_filter_condition(condition_dict: dict) -> str: + """Build a WHERE condition for a single filter item.""" + condition_parts = [] + for key, value in condition_dict.items(): + is_info = key.startswith("info.") + info_field = key[5:] if is_info else None + prop_expr = prop_ref(key) + + # Check if value is a dict with comparison operators + if isinstance(value, dict): + for op, op_value in value.items(): + if op in ("gt", "lt", "gte", "lte"): + cmp_op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="} + cmp_op = cmp_op_map[op] + + # Determine if this is a datetime field + field_name = info_field if is_info else key + is_dt = field_name in ("created_at", "updated_at") or field_name.endswith("_at") + + if isinstance(op_value, str): + escaped_value = escape_string(op_value) + if is_dt: + condition_parts.append( + fmt_datetime_cmp(prop_expr, cmp_op, escaped_value) + ) + else: + condition_parts.append( + f"{prop_expr} {cmp_op} {fmt_str_val(escaped_value)}" + ) + else: + condition_parts.append( + f"{prop_expr} {cmp_op} {fmt_non_str_val(op_value)}" + ) + + elif op == "=": + # Equality operator + field_name = info_field if is_info else key + is_array_field = field_name in ("tags", "sources") + + if isinstance(op_value, str): + escaped_value = 
escape_string(op_value) + if is_array_field: + condition_parts.append( + f"{prop_expr} = {fmt_array_eq_single_str(escaped_value)}" + ) + else: + condition_parts.append( + f"{prop_expr} = {fmt_str_val(escaped_value)}" + ) + elif isinstance(op_value, list): + if is_array_field: + condition_parts.append( + f"{prop_expr} = {fmt_array_eq_list(op_value, escape_string)}" + ) + else: + if is_cypher: + condition_parts.append( + f"{prop_expr} = {op_value}" + ) + elif is_info: + # Info nested field: use ::agtype cast + condition_parts.append( + f"{prop_expr} = {op_value}::agtype" + ) + else: + # Direct field: convert to JSON string and then to agtype + value_json = json.dumps(op_value) + condition_parts.append( + f"{prop_expr} = ag_catalog.agtype_in('{value_json}')" + ) + else: + if is_array_field: + condition_parts.append( + f"{prop_expr} = {fmt_array_eq_non_str(op_value)}" + ) + else: + condition_parts.append( + f"{prop_expr} = {fmt_non_str_val(op_value)}" + ) + + elif op == "contains": + if isinstance(op_value, str): + escaped_value = escape_string(str(op_value)) + condition_parts.append( + fmt_contains_str(escaped_value, prop_expr) + ) + else: + condition_parts.append( + fmt_contains_non_str(op_value, prop_expr) + ) + + elif op == "in": + if not isinstance(op_value, list): + raise ValueError( + f"in operator only supports array format. " + f"Use {{'{key}': {{'in': ['{op_value}']}}}} instead of {{'{key}': {{'in': '{op_value}'}}}}" + ) + + field_name = info_field if is_info else key + is_arr = field_name in ("file_ids", "tags", "sources") + + if len(op_value) == 0: + condition_parts.append("false") + elif len(op_value) == 1: + item = op_value[0] + if is_arr: + if isinstance(item, str): + escaped_value = escape_string(str(item)) + condition_parts.append( + fmt_in_array_contains_str(escaped_value, prop_expr) + ) + else: + condition_parts.append( + fmt_in_array_contains_non_str(item, prop_expr) + ) + else: + if isinstance(item, str): + escaped_value = escape_string(item) + condition_parts.append( + fmt_in_scalar_eq_str(escaped_value, prop_expr) + ) + else: + condition_parts.append( + fmt_in_scalar_eq_non_str(item, prop_expr) + ) + else: + if is_arr: + # For array fields, use OR conditions with contains + or_conditions = [] + for item in op_value: + if isinstance(item, str): + escaped_value = escape_string(str(item)) + or_conditions.append( + fmt_in_array_contains_str(escaped_value, prop_expr) + ) + else: + or_conditions.append( + fmt_in_array_contains_non_str(item, prop_expr) + ) + if or_conditions: + condition_parts.append( + f"({' OR '.join(or_conditions)})" + ) + else: + # For scalar fields + if is_cypher: + # Cypher uses IN clause with array literal + condition_parts.append( + fmt_scalar_in_clause(op_value, prop_expr) + ) + else: + # SQL uses OR equality conditions + or_conditions = [] + for item in op_value: + if isinstance(item, str): + escaped_value = escape_string(item) + or_conditions.append( + fmt_in_scalar_eq_str(escaped_value, prop_expr) + ) + else: + or_conditions.append( + fmt_in_scalar_eq_non_str(item, prop_expr) + ) + if or_conditions: + condition_parts.append( + f"({' OR '.join(or_conditions)})" + ) + + elif op == "like": + if isinstance(op_value, str): + escaped_value = escape_like_value(op_value) + condition_parts.append( + fmt_like(escaped_value, prop_expr) + ) + else: + if is_cypher: + condition_parts.append( + f"{prop_expr} CONTAINS {op_value}" + ) + else: + condition_parts.append( + f"{prop_expr}::text LIKE '%{op_value}%'" + ) + + # Simple equality (value is not a dict) + elif 
is_info: + if isinstance(value, str): + escaped_value = escape_string(value) + condition_parts.append(f"{prop_expr} = {fmt_str_val(escaped_value)}") + else: + condition_parts.append(f"{prop_expr} = {fmt_non_str_val(value)}") + else: + if isinstance(value, str): + escaped_value = escape_string(value) + condition_parts.append(f"{prop_expr} = {fmt_str_val(escaped_value)}") + else: + condition_parts.append(f"{prop_expr} = {fmt_non_str_val(value)}") + return " AND ".join(condition_parts) + + # --- Assemble final result based on filter structure and mode --- + + if is_cypher: + filter_where_clause = "" + if isinstance(filter, dict): + if "or" in filter: + or_conditions = [] + for condition in filter["or"]: + if isinstance(condition, dict): + condition_str = build_filter_condition(condition) + if condition_str: + or_conditions.append(f"({condition_str})") + if or_conditions: + filter_where_clause = " AND " + f"({' OR '.join(or_conditions)})" + elif "and" in filter: + and_conditions = [] + for condition in filter["and"]: + if isinstance(condition, dict): + condition_str = build_filter_condition(condition) + if condition_str: + and_conditions.append(f"({condition_str})") + if and_conditions: + filter_where_clause = " AND " + " AND ".join(and_conditions) + else: + condition_str = build_filter_condition(filter) + if condition_str: + filter_where_clause = " AND " + condition_str + return filter_where_clause + else: + filter_conditions: list[str] = [] + if isinstance(filter, dict): + if "or" in filter: + or_conditions = [] + for condition in filter["or"]: + if isinstance(condition, dict): + condition_str = build_filter_condition(condition) + if condition_str: + or_conditions.append(f"({condition_str})") + if or_conditions: + filter_conditions.append(f"({' OR '.join(or_conditions)})") + elif "and" in filter: + for condition in filter["and"]: + if isinstance(condition, dict): + condition_str = build_filter_condition(condition) + if condition_str: + filter_conditions.append(f"({condition_str})") + else: + condition_str = build_filter_condition(filter) + if condition_str: + filter_conditions.append(condition_str) + return filter_conditions + + def _build_filter_conditions_cypher( + self, + filter: dict | None, + ) -> str: + """ + Build filter conditions for Cypher queries. + + Args: + filter: Filter dictionary with "or" or "and" logic + + Returns: + Filter WHERE clause string (empty string if no filter) + """ + return self._build_filter_conditions(filter, mode="cypher") + + def _build_filter_conditions_sql( + self, + filter: dict | None, + ) -> list[str]: + """ + Build filter conditions for SQL queries. 
+ + Args: + filter: Filter dictionary with "or" or "and" logic + + Returns: + List of filter WHERE clause strings (empty list if no filter) + """ + return self._build_filter_conditions(filter, mode="sql") + + def parse_filter( + self, + filter_dict: dict | None = None, + ): + if filter_dict is None: + return None + full_fields = { + "id", + "key", + "tags", + "type", + "usage", + "memory", + "status", + "sources", + "user_id", + "graph_id", + "user_name", + "background", + "confidence", + "created_at", + "session_id", + "updated_at", + "memory_type", + "node_type", + "info", + "source", + "file_ids", + } + + def process_condition(condition): + if not isinstance(condition, dict): + return condition + + new_condition = {} + + for key, value in condition.items(): + if key.lower() in ["or", "and"]: + if isinstance(value, list): + processed_items = [] + for item in value: + if isinstance(item, dict): + processed_item = {} + for item_key, item_value in item.items(): + if item_key not in full_fields and not item_key.startswith( + "info." + ): + new_item_key = f"info.{item_key}" + else: + new_item_key = item_key + processed_item[new_item_key] = item_value + processed_items.append(processed_item) + else: + processed_items.append(item) + new_condition[key] = processed_items + else: + new_condition[key] = value + else: + if key not in full_fields and not key.startswith("info."): + new_key = f"info.{key}" + else: + new_key = key + + new_condition[new_key] = value + + return new_condition + + return process_condition(filter_dict) diff --git a/src/memos/graph_dbs/polardb/helpers.py b/src/memos/graph_dbs/polardb/helpers.py new file mode 100644 index 000000000..c8dd2b844 --- /dev/null +++ b/src/memos/graph_dbs/polardb/helpers.py @@ -0,0 +1,13 @@ +"""Module-level utility functions for PolarDB graph database.""" + +import random + + +def generate_vector(dim=1024, low=-0.2, high=0.2): + """Generate a random vector for testing purposes.""" + return [round(random.uniform(low, high), 6) for _ in range(dim)] + + +def escape_sql_string(value: str) -> str: + """Escape single quotes in SQL string.""" + return value.replace("'", "''") diff --git a/src/memos/graph_dbs/polardb/maintenance.py b/src/memos/graph_dbs/polardb/maintenance.py new file mode 100644 index 000000000..13505c046 --- /dev/null +++ b/src/memos/graph_dbs/polardb/maintenance.py @@ -0,0 +1,768 @@ +import copy +import json +import time +from typing import Any + +from memos.graph_dbs.utils import compose_node as _compose_node, prepare_node_metadata as _prepare_node_metadata +from memos.log import get_logger +from memos.utils import timed + +logger = get_logger(__name__) + + +class MaintenanceMixin: + """Mixin for maintenance operations (import/export, clear, cleanup).""" + + @timed + def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: + """ + Import the entire graph from a serialized dictionary. + + Args: + data: A dictionary containing all nodes and edges to be loaded. 
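# --- Illustrative sketch (not part of this patch) ---------------------------------
# How FilterMixin.parse_filter() normalizes keys before the condition builders run.
# Keys outside the known top-level field set (and not already prefixed with
# "info.") are rewritten to "info.<key>", while known fields such as "tags",
# "memory_type" or "created_at" pass through unchanged. The filter below combines
# the documented operators; all values are placeholders and `graph` is an assumed
# PolarDBGraphDB instance.
raw_filter = {
    "and": [
        {"project": "demo"},                               # unknown key -> "info.project"
        {"tags": {"contains": "AI"}},                      # array field membership
        {"created_at": {"gte": "2025-01-01"}},             # datetime comparison
        {"memory_type": {"in": ["LongTermMemory", "WorkingMemory"]}},
    ]
}

normalized = graph.parse_filter(raw_filter)
assert normalized["and"][0] == {"info.project": "demo"}

# The same dict feeds both dialects: a list of SQL fragments for SELECT/DELETE
# statements, or a single Cypher clause prefixed with " AND " for MATCH queries.
sql_fragments = graph._build_filter_conditions_sql(raw_filter)       # -> list[str]
cypher_clause = graph._build_filter_conditions_cypher(raw_filter)    # -> str ("" if empty)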
+ user_name (str, optional): User name for filtering in non-multi-db mode + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + # Import nodes + for node in data.get("nodes", []): + try: + id, memory, metadata = _compose_node(node) + metadata["user_name"] = user_name + metadata = _prepare_node_metadata(metadata) + metadata.update({"id": id, "memory": memory}) + + # Use add_node to insert node + self.add_node(id, memory, metadata) + + except Exception as e: + logger.error(f"Fail to load node: {node}, error: {e}") + + # Import edges + for edge in data.get("edges", []): + try: + source_id, target_id = edge["source"], edge["target"] + edge_type = edge["type"] + + # Use add_edge to insert edge + self.add_edge(source_id, target_id, edge_type, user_name) + + except Exception as e: + logger.error(f"Fail to load edge: {edge}, error: {e}") + + @timed + def export_graph( + self, + include_embedding: bool = False, + user_name: str | None = None, + user_id: str | None = None, + page: int | None = None, + page_size: int | None = None, + filter: dict | None = None, + **kwargs, + ) -> dict[str, Any]: + """ + Export all graph nodes and edges in a structured form. + Args: + include_embedding (bool): Whether to include the large embedding field. + user_name (str, optional): User name for filtering in non-multi-db mode + user_id (str, optional): User ID for filtering + page (int, optional): Page number (starts from 1). If None, exports all data without pagination. + page_size (int, optional): Number of items per page. If None, exports all data without pagination. + filter (dict, optional): Filter dictionary for metadata filtering. Supports "and", "or" logic and operators: + - "=": equality + - "in": value in list + - "contains": array contains value + - "gt", "lt", "gte", "lte": comparison operators + - "like": fuzzy matching + Example: {"and": [{"created_at": {"gte": "2025-01-01"}}, {"tags": {"contains": "AI"}}]} + + Returns: + { + "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ], + "edges": [ { "source": ..., "target": ..., "type": ... }, ... 
], + "total_nodes": int, # Total number of nodes matching the filter criteria + "total_edges": int, # Total number of edges matching the filter criteria + } + """ + logger.info( + f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}" + ) + user_id = user_id if user_id else self._get_config_value("user_id") + + # Initialize total counts + total_nodes = 0 + total_edges = 0 + + # Determine if pagination is needed + use_pagination = page is not None and page_size is not None + + # Validate pagination parameters if pagination is enabled + if use_pagination: + if page < 1: + page = 1 + if page_size < 1: + page_size = 10 + offset = (page - 1) * page_size + else: + offset = None + + conn = None + try: + conn = self._get_connection() + # Build WHERE conditions + where_conditions = [] + if user_name: + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" + ) + if user_id: + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype" + ) + + # Build filter conditions using common method + filter_conditions = self._build_filter_conditions_sql(filter) + logger.info(f"[export_graph] filter_conditions: {filter_conditions}") + if filter_conditions: + where_conditions.extend(filter_conditions) + + where_clause = "" + if where_conditions: + where_clause = f"WHERE {' AND '.join(where_conditions)}" + + # Get total count of nodes before pagination + count_node_query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + {where_clause} + """ + logger.info(f"[export_graph nodes count] Query: {count_node_query}") + with conn.cursor() as cursor: + cursor.execute(count_node_query) + total_nodes = cursor.fetchone()[0] + + # Export nodes + # Build pagination clause if needed + pagination_clause = "" + if use_pagination: + pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + + if include_embedding: + node_query = f""" + SELECT id, properties, embedding + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY ag_catalog.agtype_access_operator(properties::text::agtype, '"created_at"'::agtype) DESC NULLS LAST, + id DESC + {pagination_clause} + """ + else: + node_query = f""" + SELECT id, properties + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY ag_catalog.agtype_access_operator(properties::text::agtype, '"created_at"'::agtype) DESC NULLS LAST, + id DESC + {pagination_clause} + """ + logger.info(f"[export_graph nodes] Query: {node_query}") + with conn.cursor() as cursor: + cursor.execute(node_query) + node_results = cursor.fetchall() + nodes = [] + + for row in node_results: + if include_embedding: + """row is (id, properties, embedding)""" + _, properties_json, embedding_json = row + else: + """row is (id, properties)""" + _, properties_json = row + embedding_json = None + + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except json.JSONDecodeError: + properties = {} + else: + properties = properties_json if properties_json else {} + + # Remove embedding field if include_embedding is False + if not include_embedding: + properties.pop("embedding", None) + elif include_embedding and embedding_json is not None: + properties["embedding"] = embedding_json + + nodes.append(self._parse_node(properties)) + + except Exception 
as e: + logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True) + raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e + finally: + self._return_connection(conn) + + conn = None + try: + conn = self._get_connection() + # Build Cypher WHERE conditions for edges + cypher_where_conditions = [] + if user_name: + cypher_where_conditions.append(f"a.user_name = '{user_name}'") + cypher_where_conditions.append(f"b.user_name = '{user_name}'") + if user_id: + cypher_where_conditions.append(f"a.user_id = '{user_id}'") + cypher_where_conditions.append(f"b.user_id = '{user_id}'") + + # Build filter conditions for edges (apply to both source and target nodes) + filter_where_clause = self._build_filter_conditions_cypher(filter) + logger.info(f"[export_graph edges] filter_where_clause: {filter_where_clause}") + if filter_where_clause: + # _build_filter_conditions_cypher returns a string that starts with " AND " if filter exists + # Remove the leading " AND " and replace n. with a. for source node and b. for target node + filter_clause = filter_where_clause.strip() + if filter_clause.startswith("AND "): + filter_clause = filter_clause[4:].strip() + # Replace n. with a. for source node and create a copy for target node + source_filter = filter_clause.replace("n.", "a.") + target_filter = filter_clause.replace("n.", "b.") + # Combine source and target filters with AND + combined_filter = f"({source_filter}) AND ({target_filter})" + cypher_where_conditions.append(combined_filter) + + cypher_where_clause = "" + if cypher_where_conditions: + cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}" + + # Get total count of edges before pagination + count_edge_query = f""" + SELECT COUNT(*) + FROM ( + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (a:Memory)-[r]->(b:Memory) + {cypher_where_clause} + RETURN a.id AS source, b.id AS target, type(r) as edge + $$) AS (source agtype, target agtype, edge agtype) + ) AS edges + """ + logger.info(f"[export_graph edges count] Query: {count_edge_query}") + with conn.cursor() as cursor: + cursor.execute(count_edge_query) + total_edges = cursor.fetchone()[0] + + # Export edges using cypher query + # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery + # Build pagination clause if needed + edge_pagination_clause = "" + if use_pagination: + edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + + edge_query = f""" + SELECT source, target, edge FROM ( + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (a:Memory)-[r]->(b:Memory) + {cypher_where_clause} + RETURN a.id AS source, b.id AS target, type(r) as edge + ORDER BY COALESCE(a.created_at, '1970-01-01T00:00:00') DESC, + COALESCE(b.created_at, '1970-01-01T00:00:00') DESC, + a.id DESC, b.id DESC + $$) AS (source agtype, target agtype, edge agtype) + ) AS edges + {edge_pagination_clause} + """ + logger.info(f"[export_graph edges] Query: {edge_query}") + with conn.cursor() as cursor: + cursor.execute(edge_query) + edge_results = cursor.fetchall() + edges = [] + + for row in edge_results: + source_agtype, target_agtype, edge_agtype = row + + # Extract and clean source + source_raw = ( + source_agtype.value + if hasattr(source_agtype, "value") + else str(source_agtype) + ) + if ( + isinstance(source_raw, str) + and source_raw.startswith('"') + and source_raw.endswith('"') + ): + source = source_raw[1:-1] + else: + source = str(source_raw) + + # Extract and clean target + target_raw = ( + target_agtype.value + if 
hasattr(target_agtype, "value") + else str(target_agtype) + ) + if ( + isinstance(target_raw, str) + and target_raw.startswith('"') + and target_raw.endswith('"') + ): + target = target_raw[1:-1] + else: + target = str(target_raw) + + # Extract and clean edge type + type_raw = ( + edge_agtype.value if hasattr(edge_agtype, "value") else str(edge_agtype) + ) + if ( + isinstance(type_raw, str) + and type_raw.startswith('"') + and type_raw.endswith('"') + ): + edge_type = type_raw[1:-1] + else: + edge_type = str(type_raw) + + edges.append( + { + "source": source, + "target": target, + "type": edge_type, + } + ) + + except Exception as e: + logger.error(f"[EXPORT GRAPH - EDGES] Exception: {e}", exc_info=True) + raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e + finally: + self._return_connection(conn) + + return { + "nodes": nodes, + "edges": edges, + "total_nodes": total_nodes, + "total_edges": total_edges, + } + + @timed + def clear(self, user_name: str | None = None) -> None: + """ + Clear the entire graph if the target database exists. + + Args: + user_name (str, optional): User name for filtering in non-multi-db mode + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + try: + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE n.user_name = '{user_name}' + DETACH DELETE n + $$) AS (result agtype) + """ + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + logger.info("Cleared all nodes from database.") + finally: + self._return_connection(conn) + + except Exception as e: + logger.error(f"[ERROR] Failed to clear database: {e}") + + def drop_database(self) -> None: + """Permanently delete the entire graph this instance is using.""" + return + if self._get_config_value("use_multi_db", True): + with self.connection.cursor() as cursor: + cursor.execute(f"SELECT drop_graph('{self.db_name}_graph', true)") + logger.info(f"Graph '{self.db_name}_graph' has been dropped.") + else: + raise ValueError( + f"Refusing to drop graph '{self.db_name}_graph' in " + f"Shared Database Multi-Tenant mode" + ) + + @timed + def remove_oldest_memory( + self, memory_type: str, keep_latest: int, user_name: str | None = None + ) -> None: + """ + Remove all WorkingMemory nodes except the latest `keep_latest` entries. + + Args: + memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). + keep_latest (int): Number of latest WorkingMemory entries to keep. 
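# --- Illustrative usage sketch (not part of this patch) ---------------------------
# Paginated export via export_graph() defined above. total_nodes / total_edges are
# counted before pagination, so the caller can derive the page count. The filter is
# the one shown in the docstring; `graph`, the user name and page size are assumed.
page, page_size = 1, 50
result = graph.export_graph(
    include_embedding=False,
    user_name="alice",
    page=page,
    page_size=page_size,
    filter={"and": [{"created_at": {"gte": "2025-01-01"}}, {"tags": {"contains": "AI"}}]},
)
nodes, edges = result["nodes"], result["edges"]
total_pages = -(-result["total_nodes"] // page_size)   # ceiling division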
+ user_name (str, optional): User name for filtering in non-multi-db mode + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + # Use actual OFFSET logic, consistent with nebular.py + # First find IDs to delete, then delete them + select_query = f""" + SELECT id FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '"memory_type"'::agtype) = %s::agtype + AND ag_catalog.agtype_access_operator(properties::text::agtype, '"user_name"'::agtype) = %s::agtype + ORDER BY ag_catalog.agtype_access_operator(properties::text::agtype, '"updated_at"'::agtype) DESC + OFFSET %s + """ + select_params = [ + self.format_param_value(memory_type), + self.format_param_value(user_name), + keep_latest, + ] + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Execute query to get IDs to delete + cursor.execute(select_query, select_params) + ids_to_delete = [row[0] for row in cursor.fetchall()] + + if not ids_to_delete: + logger.info(f"No {memory_type} memories to remove for user {user_name}") + return + + # Build delete query + placeholders = ",".join(["%s"] * len(ids_to_delete)) + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE id IN ({placeholders}) + """ + delete_params = ids_to_delete + + # Execute deletion + cursor.execute(delete_query, delete_params) + deleted_count = cursor.rowcount + logger.info( + f"Removed {deleted_count} oldest {memory_type} memories, " + f"keeping {keep_latest} latest for user {user_name}, " + f"removed ids: {ids_to_delete}" + ) + except Exception as e: + logger.error(f"[remove_oldest_memory] Failed: {e}", exc_info=True) + raise + finally: + self._return_connection(conn) + + def merge_nodes(self, id1: str, id2: str) -> str: + """Merge two similar or duplicate nodes into one.""" + raise NotImplementedError + + def deduplicate_nodes(self) -> None: + """Deduplicate redundant or semantically similar nodes.""" + raise NotImplementedError + + def detect_conflicts(self) -> list[tuple[str, str]]: + """Detect conflicting nodes based on logical or semantic inconsistency.""" + raise NotImplementedError + + def _convert_graph_edges(self, core_node: dict) -> dict: + import copy + + data = copy.deepcopy(core_node) + id_map = {} + core_node = data.get("core_node", {}) + if not core_node: + return { + "core_node": None, + "neighbors": data.get("neighbors", []), + "edges": data.get("edges", []), + } + core_meta = core_node.get("metadata", {}) + if "graph_id" in core_meta and "id" in core_node: + id_map[core_meta["graph_id"]] = core_node["id"] + for neighbor in data.get("neighbors", []): + n_meta = neighbor.get("metadata", {}) + if "graph_id" in n_meta and "id" in neighbor: + id_map[n_meta["graph_id"]] = neighbor["id"] + for edge in data.get("edges", []): + src = edge.get("source") + tgt = edge.get("target") + if src in id_map: + edge["source"] = id_map[src] + if tgt in id_map: + edge["target"] = id_map[tgt] + return data + + @timed + def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]: + """Get user names by memory ids. + + Args: + memory_ids: List of memory node IDs to query. + + Returns: + dict[str, str | None]: Dictionary mapping memory_id to user_name. 
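# --- Illustrative usage sketch (not part of this patch) ---------------------------
# Trimming a tenant's working memory with remove_oldest_memory() above: everything
# beyond the 20 most recently updated WorkingMemory nodes is deleted. The memory
# type, count and user name are placeholders; `graph` is an assumed instance.
graph.remove_oldest_memory("WorkingMemory", keep_latest=20, user_name="alice")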
+ - Key: memory_id + - Value: user_name if exists, None if memory_id does not exist + Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None} + """ + logger.info(f"[get_user_names_by_memory_ids] Querying memory_ids {memory_ids}") + if not memory_ids: + return {} + + # Validate and normalize memory_ids + # Ensure all items are strings + normalized_memory_ids = [] + for mid in memory_ids: + if not isinstance(mid, str): + mid = str(mid) + # Remove any whitespace + mid = mid.strip() + if mid: + normalized_memory_ids.append(mid) + + if not normalized_memory_ids: + return {} + + # Escape special characters for JSON string format in agtype + def escape_memory_id(mid: str) -> str: + """Escape special characters in memory_id for JSON string format.""" + # Escape backslashes first, then double quotes + mid_str = mid.replace("\\", "\\\\") + mid_str = mid_str.replace('"', '\\"') + return mid_str + + # Build OR conditions for each memory_id + id_conditions = [] + for mid in normalized_memory_ids: + # Escape special characters + escaped_mid = escape_memory_id(mid) + id_conditions.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"id\"'::agtype) = '\"{escaped_mid}\"'::agtype" + ) + + where_clause = f"({' OR '.join(id_conditions)})" + + # Query to get memory_id and user_name pairs + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties::text::agtype, '\"id\"'::agtype)::text AS memory_id, + ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype)::text AS user_name + FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + + logger.info(f"[get_user_names_by_memory_ids] query: {query}") + conn = None + result_dict = {} + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() + + # Build result dictionary from query results + for row in results: + memory_id_raw = row[0] + user_name_raw = row[1] + + # Remove quotes if present + if isinstance(memory_id_raw, str): + memory_id = memory_id_raw.strip('"').strip("'") + else: + memory_id = str(memory_id_raw).strip('"').strip("'") + + if isinstance(user_name_raw, str): + user_name = user_name_raw.strip('"').strip("'") + else: + user_name = ( + str(user_name_raw).strip('"').strip("'") if user_name_raw else None + ) + + result_dict[memory_id] = user_name if user_name else None + + # Set None for memory_ids that were not found + for mid in normalized_memory_ids: + if mid not in result_dict: + result_dict[mid] = None + + logger.info( + f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, " + f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names" + ) + + return result_dict + except Exception as e: + logger.error( + f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True + ) + raise + finally: + self._return_connection(conn) + + def exist_user_name(self, user_name: str) -> dict[str, bool]: + """Check if user name exists in the graph. + + Args: + user_name: User name to check. + + Returns: + dict[str, bool]: Dictionary with user_name as key and bool as value indicating existence. 
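# --- Illustrative usage sketch (not part of this patch) ---------------------------
# Resolving owners for a batch of memory ids and checking whether a tenant exists.
# Ids and names are placeholders; `graph` is an assumed PolarDBGraphDB instance.
owners = graph.get_user_names_by_memory_ids(
    ["4918d700-6f01-4f4c-a076-75cc7b0e1a7c", "2222222"]
)
# owners maps every requested id to its user_name, or to None when the id is unknown.

if graph.exist_user_name("zhangsan").get("zhangsan"):
    print("tenant 'zhangsan' already has memories in this graph")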
+ """ + logger.info(f"[exist_user_name] Querying user_name {user_name}") + if not user_name: + return {user_name: False} + + # Escape special characters for JSON string format in agtype + def escape_user_name(un: str) -> str: + """Escape special characters in user_name for JSON string format.""" + # Escape backslashes first, then double quotes + un_str = un.replace("\\", "\\\\") + un_str = un_str.replace('"', '\\"') + return un_str + + # Escape special characters + escaped_un = escape_user_name(user_name) + + # Query to check if user_name exists + query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = '\"{escaped_un}\"'::agtype + """ + logger.info(f"[exist_user_name] query: {query}") + result_dict = {} + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + count = cursor.fetchone()[0] + result = count > 0 + result_dict[user_name] = result + return result_dict + except Exception as e: + logger.error( + f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True + ) + raise + finally: + self._return_connection(conn) + + @timed + def delete_node_by_prams( + self, + writable_cube_ids: list[str] | None = None, + memory_ids: list[str] | None = None, + file_ids: list[str] | None = None, + filter: dict | None = None, + ) -> int: + """ + Delete nodes by memory_ids, file_ids, or filter. + + Args: + writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes. + If not provided, no user_name filter will be applied. + memory_ids (list[str], optional): List of memory node IDs to delete. + file_ids (list[str], optional): List of file node IDs to delete. + filter (dict, optional): Filter dictionary for metadata filtering. + Filter conditions are directly used in DELETE WHERE clause without pre-querying. + + Returns: + int: Number of nodes deleted. 
+ """ + batch_start_time = time.time() + logger.info( + f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}" + ) + + # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id) + # Only add user_name filter if writable_cube_ids is provided + user_name_conditions = [] + if writable_cube_ids and len(writable_cube_ids) > 0: + for cube_id in writable_cube_ids: + # Use agtype_access_operator with VARIADIC ARRAY format for consistency + user_name_conditions.append( + f"agtype_access_operator(VARIADIC ARRAY[properties::text::agtype, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype" + ) + + # Build filter conditions using common method (no query, direct use in WHERE clause) + filter_conditions = [] + if filter: + filter_conditions = self._build_filter_conditions_sql(filter) + logger.info(f"[delete_node_by_prams] filter_conditions: {filter_conditions}") + + # If no conditions to delete, return 0 + if not memory_ids and not file_ids and not filter_conditions: + logger.warning( + "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)" + ) + return 0 + + conn = None + total_deleted_count = 0 + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Build WHERE conditions list + where_conditions = [] + + # Add memory_ids conditions + if memory_ids: + logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids") + id_conditions = [] + for node_id in memory_ids: + id_conditions.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" + ) + where_conditions.append(f"({' OR '.join(id_conditions)})") + + # Add file_ids conditions + if file_ids: + logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids") + file_id_conditions = [] + for file_id in file_ids: + file_id_conditions.append( + f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties::text::agtype, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)" + ) + where_conditions.append(f"({' OR '.join(file_id_conditions)})") + + # Add filter conditions + if filter_conditions: + logger.info("[delete_node_by_prams] Processing filter conditions") + where_conditions.extend(filter_conditions) + + # Add user_name filter if provided + if user_name_conditions: + user_name_where = " OR ".join(user_name_conditions) + where_conditions.append(f"({user_name_where})") + + # Build final WHERE clause + if not where_conditions: + logger.warning("[delete_node_by_prams] No WHERE conditions to delete") + return 0 + + where_clause = " AND ".join(where_conditions) + + # Delete directly without counting + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") + + cursor.execute(delete_query) + deleted_count = cursor.rowcount + total_deleted_count = deleted_count + + logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes") + + elapsed_time = time.time() - batch_start_time + logger.info( + f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes" + ) + except Exception as e: + logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True) + raise + finally: + self._return_connection(conn) + + logger.info(f"[delete_node_by_prams] Successfully deleted {total_deleted_count} nodes") + return total_deleted_count diff 
--git a/src/memos/graph_dbs/polardb/nodes.py b/src/memos/graph_dbs/polardb/nodes.py new file mode 100644 index 000000000..9a24fbdfe --- /dev/null +++ b/src/memos/graph_dbs/polardb/nodes.py @@ -0,0 +1,714 @@ +import json +import time +from datetime import datetime +from typing import Any + +from memos.graph_dbs.polardb.helpers import generate_vector +from memos.graph_dbs.utils import prepare_node_metadata as _prepare_node_metadata +from memos.log import get_logger +from memos.utils import timed + +logger = get_logger(__name__) + + +class NodeMixin: + """Mixin for node (memory) CRUD operations.""" + + @timed + def add_node( + self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None + ) -> None: + """Add a memory node to the graph.""" + logger.info(f"[add_node] id: {id}, memory: {memory}, metadata: {metadata}") + + # user_name comes from metadata; fallback to config if missing + metadata["user_name"] = user_name if user_name else self.config.user_name + + metadata = _prepare_node_metadata(metadata) + + # Merge node and set metadata + created_at = metadata.pop("created_at", datetime.utcnow().isoformat()) + updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat()) + + # Prepare properties + properties = { + "id": id, + "memory": memory, + "created_at": created_at, + "updated_at": updated_at, + "delete_time": "", + "delete_record_id": "", + **metadata, + } + + # Generate embedding if not provided + if "embedding" not in properties or not properties["embedding"]: + properties["embedding"] = generate_vector( + self._get_config_value("embedding_dimension", 1024) + ) + + # serialization - JSON-serialize sources and usage fields + for field_name in ["sources", "usage"]: + if properties.get(field_name): + if isinstance(properties[field_name], list): + for idx in range(len(properties[field_name])): + # Serialize only when element is not a string + if not isinstance(properties[field_name][idx], str): + properties[field_name][idx] = json.dumps(properties[field_name][idx]) + elif isinstance(properties[field_name], str): + # If already a string, leave as-is + pass + + # Extract embedding for separate column + embedding_vector = properties.pop("embedding", []) + if not isinstance(embedding_vector, list): + embedding_vector = [] + + # Select column name based on embedding dimension + embedding_column = "embedding" # default column + if len(embedding_vector) == 3072: + embedding_column = "embedding_3072" + elif len(embedding_vector) == 1024: + embedding_column = "embedding" + elif len(embedding_vector) == 768: + embedding_column = "embedding_768" + + conn = None + insert_query = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Delete existing record first (if any) + delete_query = f""" + DELETE FROM {self.db_name}_graph."Memory" + WHERE id = %s + """ + cursor.execute(delete_query, (id,)) + properties["graph_id"] = str(id) + + # Then insert new record + if embedding_vector: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) + VALUES ( + %s, + %s, + %s + ) + """ + cursor.execute( + insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) + ) + logger.info( + f"[add_node] [embedding_vector-true] insert_query: {insert_query}, properties: {json.dumps(properties)}" + ) + else: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties) + VALUES ( + %s, + %s + ) + """ + cursor.execute(insert_query, (id, json.dumps(properties))) + logger.info( + f"[add_node] 
[embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" + ) + except Exception as e: + logger.error(f"[add_node] Failed to add node: {e}", exc_info=True) + raise + finally: + if insert_query: + logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}") + self._return_connection(conn) + + @timed + def add_nodes_batch( + self, + nodes: list[dict[str, Any]], + user_name: str | None = None, + ) -> None: + """ + Batch add multiple memory nodes to the graph. + + Args: + nodes: List of node dictionaries, each containing: + - id: str - Node ID + - memory: str - Memory content + - metadata: dict[str, Any] - Node metadata + user_name: Optional user name (will use config default if not provided) + """ + batch_start_time = time.time() + if not nodes: + logger.warning("[add_nodes_batch] Empty nodes list, skipping") + return + + logger.info(f"[add_nodes_batch] Processing only first node (total nodes: {len(nodes)})") + + # user_name comes from parameter; fallback to config if missing + effective_user_name = user_name if user_name else self.config.user_name + + # Prepare all nodes + prepared_nodes = [] + for node_data in nodes: + try: + id = node_data["id"] + memory = node_data["memory"] + metadata = node_data.get("metadata", {}) + + logger.debug(f"[add_nodes_batch] Processing node id: {id}") + + # Set user_name in metadata + metadata["user_name"] = effective_user_name + + metadata = _prepare_node_metadata(metadata) + + # Merge node and set metadata + created_at = metadata.pop("created_at", datetime.utcnow().isoformat()) + updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat()) + + # Prepare properties + properties = { + "id": id, + "memory": memory, + "created_at": created_at, + "updated_at": updated_at, + "delete_time": "", + "delete_record_id": "", + **metadata, + } + + # Generate embedding if not provided + if "embedding" not in properties or not properties["embedding"]: + properties["embedding"] = generate_vector( + self._get_config_value("embedding_dimension", 1024) + ) + + # Serialization - JSON-serialize sources and usage fields + for field_name in ["sources", "usage"]: + if properties.get(field_name): + if isinstance(properties[field_name], list): + for idx in range(len(properties[field_name])): + # Serialize only when element is not a string + if not isinstance(properties[field_name][idx], str): + properties[field_name][idx] = json.dumps( + properties[field_name][idx] + ) + elif isinstance(properties[field_name], str): + # If already a string, leave as-is + pass + + # Extract embedding for separate column + embedding_vector = properties.pop("embedding", []) + if not isinstance(embedding_vector, list): + embedding_vector = [] + + # Select column name based on embedding dimension + embedding_column = "embedding" # default column + if len(embedding_vector) == 3072: + embedding_column = "embedding_3072" + elif len(embedding_vector) == 1024: + embedding_column = "embedding" + elif len(embedding_vector) == 768: + embedding_column = "embedding_768" + + prepared_nodes.append( + { + "id": id, + "memory": memory, + "properties": properties, + "embedding_vector": embedding_vector, + "embedding_column": embedding_column, + } + ) + except Exception as e: + logger.error( + f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}", + exc_info=True, + ) + # Continue with other nodes + continue + + if not prepared_nodes: + logger.warning("[add_nodes_batch] No valid nodes to insert after preparation") + return + + # 
Group nodes by embedding column to optimize batch inserts + nodes_by_embedding_column = {} + for node in prepared_nodes: + col = node["embedding_column"] + if col not in nodes_by_embedding_column: + nodes_by_embedding_column[col] = [] + nodes_by_embedding_column[col].append(node) + + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Process each group separately + for embedding_column, nodes_group in nodes_by_embedding_column.items(): + # Batch delete existing records using IN clause + ids_to_delete = [node["id"] for node in nodes_group] + if ids_to_delete: + delete_query = f""" + DELETE FROM {self.db_name}_graph."Memory" + WHERE id = ANY(%s::text[]) + """ + cursor.execute(delete_query, (ids_to_delete,)) + + # Set graph_id in properties (using text ID directly) + for node in nodes_group: + node["properties"]["graph_id"] = str(node["id"]) + + # Use PREPARE/EXECUTE for efficient batch insert + # Generate unique prepare statement name to avoid conflicts + prepare_name = f"insert_mem_{embedding_column or 'no_embedding'}_{int(time.time() * 1000000)}" + + try: + if embedding_column and any( + node["embedding_vector"] for node in nodes_group + ): + # PREPARE statement for insert with embedding + prepare_query = f""" + PREPARE {prepare_name} AS + INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) + VALUES ( + $1, + $2::jsonb, + $3::vector + ) + """ + logger.info( + f"[add_nodes_batch] embedding Preparing prepare_name: {prepare_name}" + ) + logger.info( + f"[add_nodes_batch] embedding Preparing prepare_query: {prepare_query}" + ) + + cursor.execute(prepare_query) + + # Execute prepared statement for each node + for node in nodes_group: + properties_json = json.dumps(node["properties"]) + embedding_json = ( + json.dumps(node["embedding_vector"]) + if node["embedding_vector"] + else None + ) + + cursor.execute( + f"EXECUTE {prepare_name}(%s, %s, %s)", + (node["id"], properties_json, embedding_json), + ) + else: + # PREPARE statement for insert without embedding + prepare_query = f""" + PREPARE {prepare_name} AS + INSERT INTO {self.db_name}_graph."Memory"(id, properties) + VALUES ( + $1, + $2::jsonb + ) + """ + logger.info( + f"[add_nodes_batch] without embedding Preparing prepare_name: {prepare_name}" + ) + logger.info( + f"[add_nodes_batch] without embedding Preparing prepare_query: {prepare_query}" + ) + cursor.execute(prepare_query) + + # Execute prepared statement for each node + for node in nodes_group: + properties_json = json.dumps(node["properties"]) + + cursor.execute( + f"EXECUTE {prepare_name}(%s, %s)", (node["id"], properties_json) + ) + finally: + # DEALLOCATE prepared statement (always execute, even on error) + try: + cursor.execute(f"DEALLOCATE {prepare_name}") + logger.info( + f"[add_nodes_batch] Deallocated prepared statement: {prepare_name}" + ) + except Exception as dealloc_error: + logger.warning( + f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}" + ) + + logger.info( + f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}" + ) + elapsed_time = time.time() - batch_start_time + logger.info( + f"[add_nodes_batch] PREPARE/EXECUTE batch insert completed successfully in {elapsed_time:.2f}s" + ) + + except Exception as e: + logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True) + raise + finally: + self._return_connection(conn) + + @timed + def get_node( + self, id: str, include_embedding: bool = False, user_name: str | None = None + ) -> 
dict[str, Any] | None: + """ + Retrieve a Memory node by its unique ID. + + Args: + id (str): Node ID (Memory.id) + include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + dict: Node properties as key-value pairs, or None if not found. + """ + logger.info( + f"polardb [get_node] id: {id}, include_embedding: {include_embedding}, user_name: {user_name}" + ) + start_time = time.time() + select_fields = "id, properties, embedding" if include_embedding else "id, properties" + + query = f""" + SELECT {select_fields} + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) = %s::agtype + """ + params = [self.format_param_value(id)] + + # Only add user filter when user_name is provided + if user_name is not None: + query += "\nAND ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = %s::agtype" + params.append(self.format_param_value(user_name)) + + logger.info(f"polardb [get_node] query: {query},params: {params}") + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + + if result: + if include_embedding: + _, properties_json, embedding_json = result + else: + _, properties_json = result + embedding_json = None + + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse properties for node {id}") + properties = {} + else: + properties = properties_json if properties_json else {} + + # Parse embedding from JSONB if it exists and include_embedding is True + if include_embedding and embedding_json is not None: + try: + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json + ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {id}") + + elapsed_time = time.time() - start_time + logger.info( + f" polardb [get_node] get_node completed time in {elapsed_time:.2f}s" + ) + return self._parse_node( + { + "id": id, + "memory": properties.get("memory", ""), + **properties, + } + ) + return None + + except Exception as e: + logger.error(f"[get_node] Failed to retrieve node '{id}': {e}", exc_info=True) + return None + finally: + self._return_connection(conn) + + @timed + def get_nodes( + self, ids: list[str], user_name: str | None = None, **kwargs + ) -> list[dict[str, Any]]: + """ + Retrieve the metadata and memory of a list of nodes. + Args: + ids: List of Node identifier. + Returns: + list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'. + + Notes: + - Assumes all provided IDs are valid and exist. + - Returns empty list if input is empty. 
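+            - IDs with no matching node are silently omitted from the result rather than raising an error.
+            - Embeddings are attached only when include_embedding=True is passed via kwargs.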
+ """ + logger.info(f"get_nodes ids:{ids},user_name:{user_name}") + if not ids: + return [] + + # Build WHERE clause using IN operator with agtype array + # Use ANY operator with array for better performance + placeholders = ",".join(["%s"] * len(ids)) + params = [self.format_param_value(id_val) for id_val in ids] + + query = f""" + SELECT id, properties, embedding + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '\"id\"'::agtype) = ANY(ARRAY[{placeholders}]::agtype[]) + """ + + # Only add user_name filter if provided + if user_name is not None: + query += " AND ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = %s::agtype" + params.append(self.format_param_value(user_name)) + + logger.info(f"get_nodes query:{query},params:{params}") + + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + + nodes = [] + for row in results: + node_id, properties_json, embedding_json = row + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse properties for node {node_id}") + properties = {} + else: + properties = properties_json if properties_json else {} + + # Parse embedding from JSONB if it exists + if embedding_json is not None and kwargs.get("include_embedding"): + try: + # remove embedding + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json + ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") + nodes.append( + self._parse_node( + { + "id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } + ) + ) + return nodes + finally: + self._return_connection(conn) + + @timed + def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: + """ + Update node fields in PolarDB, auto-converting `created_at` and `updated_at` to datetime type if present. 
+ """ + if not fields: + return + + user_name = user_name if user_name else self.config.user_name + + # Get the current node + current_node = self.get_node(id, user_name=user_name) + if not current_node: + return + + # Update properties but keep original id and memory fields + properties = current_node["metadata"].copy() + original_id = properties.get("id", id) # Preserve original ID + original_memory = current_node.get("memory", "") # Preserve original memory + + # If fields include memory, use it; otherwise keep original memory + if "memory" in fields: + original_memory = fields.pop("memory") + + properties.update(fields) + properties["id"] = original_id # Ensure ID is not overwritten + properties["memory"] = original_memory # Ensure memory is not overwritten + + # Handle embedding field + embedding_vector = None + if "embedding" in fields: + embedding_vector = fields.pop("embedding") + if not isinstance(embedding_vector, list): + embedding_vector = None + + # Build update query + if embedding_vector is not None: + query = f""" + UPDATE "{self.db_name}_graph"."Memory" + SET properties = %s, embedding = %s + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) = %s::agtype + """ + params = [ + json.dumps(properties), + json.dumps(embedding_vector), + self.format_param_value(id), + ] + else: + query = f""" + UPDATE "{self.db_name}_graph"."Memory" + SET properties = %s + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) = %s::agtype + """ + params = [json.dumps(properties), self.format_param_value(id)] + + # Only add user filter when user_name is provided + if user_name is not None: + query += "\nAND ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = %s::agtype" + params.append(self.format_param_value(user_name)) + + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + except Exception as e: + logger.error(f"[update_node] Failed to update node '{id}': {e}", exc_info=True) + raise + finally: + self._return_connection(conn) + + @timed + def delete_node(self, id: str, user_name: str | None = None) -> None: + """ + Delete a node from the graph. + Args: + id: Node identifier to delete. 
+ user_name (str, optional): User name for filtering in non-multi-db mode + """ + query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) = %s::agtype + """ + params = [self.format_param_value(id)] + + # Only add user filter when user_name is provided + if user_name is not None: + query += "\nAND ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = %s::agtype" + params.append(self.format_param_value(user_name)) + + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + except Exception as e: + logger.error(f"[delete_node] Failed to delete node '{id}': {e}", exc_info=True) + raise + finally: + self._return_connection(conn) + + def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: + """Parse node data from database format to standard format.""" + node = node_data.copy() + + # Strip wrapping quotes from agtype string values (idempotent) + for k, v in list(node.items()): + if ( + isinstance(v, str) + and len(v) >= 2 + and v[0] == v[-1] + and v[0] in ("'", '"') + ): + node[k] = v[1:-1] + + # Convert datetime to string + for time_field in ("created_at", "updated_at"): + if time_field in node and hasattr(node[time_field], "isoformat"): + node[time_field] = node[time_field].isoformat() + + # Deserialize sources from JSON strings back to dict objects + if "sources" in node and node.get("sources"): + sources = node["sources"] + if isinstance(sources, list): + deserialized_sources = [] + for source_item in sources: + if isinstance(source_item, str): + try: + parsed = json.loads(source_item) + deserialized_sources.append(parsed) + except (json.JSONDecodeError, TypeError): + deserialized_sources.append({"type": "doc", "content": source_item}) + elif isinstance(source_item, dict): + deserialized_sources.append(source_item) + else: + deserialized_sources.append({"type": "doc", "content": str(source_item)}) + node["sources"] = deserialized_sources + + return {"id": node.pop("id", None), "memory": node.pop("memory", ""), "metadata": node} + + def _build_node_from_agtype(self, node_agtype, embedding=None): + """ + Parse the cypher-returned column `n` (agtype or JSON string) + into a standard node and merge embedding into properties. 
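+        Returns None when the value cannot be parsed into a vertex with a properties map.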
+ """ + try: + # String case: '{"id":...,"label":[...],"properties":{...}}::vertex' + if isinstance(node_agtype, str): + json_str = node_agtype.replace("::vertex", "") + obj = json.loads(json_str) + if not (isinstance(obj, dict) and "properties" in obj): + return None + props = obj["properties"] + # agtype case: has `value` attribute + elif node_agtype and hasattr(node_agtype, "value"): + val = node_agtype.value + if not (isinstance(val, dict) and "properties" in val): + return None + props = val["properties"] + else: + return None + + if embedding is not None: + if isinstance(embedding, str): + try: + embedding = json.loads(embedding) + except (json.JSONDecodeError, TypeError): + logger.warning("Failed to parse embedding for node") + props["embedding"] = embedding + + return self._parse_node(props) + except Exception: + return None + + def format_param_value(self, value: str | None) -> str: + """Format parameter value to handle both quoted and unquoted formats""" + # Handle None value + if value is None: + logger.warning("format_param_value: value is None") + return "null" + + # Remove outer quotes if they exist + if value.startswith('"') and value.endswith('"'): + # Already has double quotes, return as is + return value + else: + # Add double quotes + return f'"{value}"' diff --git a/src/memos/graph_dbs/polardb/queries.py b/src/memos/graph_dbs/polardb/queries.py new file mode 100644 index 000000000..6404774f9 --- /dev/null +++ b/src/memos/graph_dbs/polardb/queries.py @@ -0,0 +1,657 @@ +import json +from typing import Any + +from memos.log import get_logger +from memos.utils import timed + +logger = get_logger(__name__) + + +class QueryMixin: + """Mixin for query operations (metadata, counts, grouped queries).""" + + @timed + def get_by_metadata( + self, + filters: list[dict[str, Any]], + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list | None = None, + user_name_flag: bool = True, + ) -> list[str]: + """ + Retrieve node IDs that match given metadata filters. + Supports exact match. + + Args: + filters: List of filter dicts like: + [ + {"field": "key", "op": "in", "value": ["A", "B"]}, + {"field": "confidence", "op": ">=", "value": 80}, + {"field": "tags", "op": "contains", "value": "AI"}, + ... + ] + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + list[str]: Node IDs whose metadata match the filter conditions. (AND logic). 
+ """ + logger.info(f"[get_by_metadata] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}") + + user_name = user_name if user_name else self._get_config_value("user_name") + + # Build WHERE conditions for cypher query + where_conditions = [] + + for f in filters: + field = f["field"] + op = f.get("op", "=") + value = f["value"] + + # Format value + if isinstance(value, str): + # Escape single quotes using backslash when inside $$ dollar-quoted strings + # In $$ delimiters, Cypher string literals can use \' to escape single quotes + escaped_str = value.replace("'", "\\'") + escaped_value = f"'{escaped_str}'" + elif isinstance(value, list): + # Handle list values - use double quotes for Cypher arrays + list_items = [] + for v in value: + if isinstance(v, str): + # Escape double quotes in string values for Cypher + escaped_str = v.replace('"', '\\"') + list_items.append(f'"{escaped_str}"') + else: + list_items.append(str(v)) + escaped_value = f"[{', '.join(list_items)}]" + else: + escaped_value = f"'{value}'" if isinstance(value, str) else str(value) + # Build WHERE conditions + if op == "=": + where_conditions.append(f"n.{field} = {escaped_value}") + elif op == "in": + where_conditions.append(f"n.{field} IN {escaped_value}") + """ + # where_conditions.append(f"{escaped_value} IN n.{field}") + """ + elif op == "contains": + where_conditions.append(f"{escaped_value} IN n.{field}") + """ + # where_conditions.append(f"size(filter(n.{field}, t -> t IN {escaped_value})) > 0") + """ + elif op == "starts_with": + where_conditions.append(f"n.{field} STARTS WITH {escaped_value}") + elif op == "ends_with": + where_conditions.append(f"n.{field} ENDS WITH {escaped_value}") + elif op == "like": + where_conditions.append(f"n.{field} CONTAINS {escaped_value}") + elif op in [">", ">=", "<", "<="]: + where_conditions.append(f"n.{field} {op} {escaped_value}") + else: + raise ValueError(f"Unsupported operator: {op}") + + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self._get_config_value("user_name"), + ) + logger.info(f"[get_by_metadata] user_name_conditions: {user_name_conditions}") + + # Add user_name WHERE clause + if user_name_conditions: + if len(user_name_conditions) == 1: + where_conditions.append(user_name_conditions[0]) + else: + where_conditions.append(f"({' OR '.join(user_name_conditions)})") + + # Build filter conditions using common method + filter_where_clause = self._build_filter_conditions_cypher(filter) + logger.info(f"[get_by_metadata] filter_where_clause: {filter_where_clause}") + + where_str = " AND ".join(where_conditions) + filter_where_clause + + # Use cypher query + cypher_query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE {where_str} + RETURN n.id AS id + $$) AS (id agtype) + """ + + ids = [] + conn = None + logger.info(f"[get_by_metadata] cypher_query: {cypher_query}") + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + ids = [str(item[0]).strip('"') for item in results] + except Exception as e: + logger.error(f"Failed to get metadata: {e}, query is {cypher_query}") 
+ finally: + self._return_connection(conn) + + return ids + + @timed + def get_grouped_counts( + self, + group_fields: list[str], + where_clause: str = "", + params: dict[str, Any] | None = None, + user_name: str | None = None, + ) -> list[dict[str, Any]]: + """ + Count nodes grouped by any fields. + + Args: + group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"] + where_clause (str, optional): Extra WHERE condition. E.g., + "WHERE n.status = 'activated'" + params (dict, optional): Parameters for WHERE clause. + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...] + """ + if not group_fields: + raise ValueError("group_fields cannot be empty") + + user_name = user_name if user_name else self._get_config_value("user_name") + + # Build user clause + user_clause = f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" + if where_clause: + where_clause = where_clause.strip() + if where_clause.upper().startswith("WHERE"): + where_clause += f" AND {user_clause}" + else: + where_clause = f"WHERE {where_clause} AND {user_clause}" + else: + where_clause = f"WHERE {user_clause}" + + # Inline parameters if provided + if params and isinstance(params, dict): + for key, value in params.items(): + # Handle different value types appropriately + if isinstance(value, str): + value = f"'{value}'" + where_clause = where_clause.replace(f"${key}", str(value)) + + # Handle user_name parameter in where_clause + if "user_name = %s" in where_clause: + where_clause = where_clause.replace( + "user_name = %s", + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype", + ) + + # Build return fields and group by fields + return_fields = [] + group_by_fields = [] + + for field in group_fields: + alias = field.replace(".", "_") + return_fields.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"{field}\"'::agtype)::text AS {alias}" + ) + group_by_fields.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"{field}\"'::agtype)::text" + ) + + # Full SQL query construction + query = f""" + SELECT {", ".join(return_fields)}, COUNT(*) AS count + FROM "{self.db_name}_graph"."Memory" + {where_clause} + GROUP BY {", ".join(group_by_fields)} + """ + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Handle parameterized query + if params and isinstance(params, list): + cursor.execute(query, params) + else: + cursor.execute(query) + results = cursor.fetchall() + + output = [] + for row in results: + group_values = {} + for i, field in enumerate(group_fields): + value = row[i] + if hasattr(value, "value"): + group_values[field] = value.value + else: + group_values[field] = str(value) + count_value = row[-1] # Last column is count + output.append({**group_values, "count": int(count_value)}) + + return output + + except Exception as e: + logger.error(f"Failed to get grouped counts: {e}", exc_info=True) + return [] + finally: + self._return_connection(conn) + + def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: + """Get count of memory nodes by type.""" + user_name = user_name if user_name else self._get_config_value("user_name") + query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + WHERE 
ag_catalog.agtype_access_operator(properties::text::agtype, '"memory_type"'::agtype) = %s::agtype + """ + query += "\nAND ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = %s::agtype" + params = [self.format_param_value(memory_type), self.format_param_value(user_name)] + + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return result[0] if result else 0 + except Exception as e: + logger.error(f"[get_memory_count] Failed: {e}") + return -1 + finally: + self._return_connection(conn) + + @timed + def node_not_exist(self, scope: str, user_name: str | None = None) -> int: + """Check if a node with given scope exists.""" + user_name = user_name if user_name else self._get_config_value("user_name") + query = f""" + SELECT id + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties::text::agtype, '"memory_type"'::agtype) = %s::agtype + """ + query += "\nAND ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = %s::agtype" + query += "\nLIMIT 1" + params = [self.format_param_value(scope), self.format_param_value(user_name)] + + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return 1 if result else 0 + except Exception as e: + logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) + raise + finally: + self._return_connection(conn) + + @timed + def count_nodes(self, scope: str, user_name: str | None = None) -> int: + user_name = user_name if user_name else self.config.user_name + + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE n.memory_type = '{scope}' + AND n.user_name = '{user_name}' + RETURN count(n) + $$) AS (count agtype) + """ + conn = None + try: + conn = self._get_connection() + cursor = conn.cursor() + cursor.execute(query) + row = cursor.fetchone() + cursor.close() + conn.commit() + return int(row[0]) if row else 0 + except Exception: + if conn: + conn.rollback() + raise + finally: + self._return_connection(conn) + + @timed + def get_all_memory_items( + self, + scope: str, + include_embedding: bool = False, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list | None = None, + status: str | None = None, + ) -> list[dict]: + """ + Retrieve all memory items of a specific memory_type. + + Args: + scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. + include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode + filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results. + knowledgebase_ids (list, optional): List of knowledgebase IDs to filter by. + status (str, optional): Filter by status (e.g., 'activated', 'archived'). + If None, no status filter is applied. + + Returns: + list[dict]: Full list of memory items under this scope. 
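+
+        Notes:
+            - Both query paths cap the result set at 100 nodes (LIMIT 100 in the underlying Cypher query).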
+ """ + logger.info( + f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}, status: {status}" + ) + + user_name = user_name if user_name else self._get_config_value("user_name") + if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: + raise ValueError(f"Unsupported memory type scope: {scope}") + + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self._get_config_value("user_name"), + ) + + # Build user_name WHERE clause + if user_name_conditions: + if len(user_name_conditions) == 1: + user_name_where = user_name_conditions[0] + else: + user_name_where = f"({' OR '.join(user_name_conditions)})" + else: + user_name_where = "" + + # Build filter conditions using common method + filter_where_clause = self._build_filter_conditions_cypher(filter) + logger.info(f"[get_all_memory_items] filter_where_clause: {filter_where_clause}") + + # Use cypher query to retrieve memory items + if include_embedding: + # Build WHERE clause with user_name/knowledgebase_ids and filter + where_parts = [f"n.memory_type = '{scope}'"] + if status: + where_parts.append(f"n.status = '{status}'") + if user_name_where: + # user_name_where already contains parentheses if it's an OR condition + where_parts.append(user_name_where) + if filter_where_clause: + # filter_where_clause already contains " AND " prefix, so we just append it + where_clause = " AND ".join(where_parts) + filter_where_clause + else: + where_clause = " AND ".join(where_parts) + + cypher_query = f""" + WITH t as ( + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE {where_clause} + RETURN id(n) as id1,n + LIMIT 100 + $$) AS (id1 agtype,n agtype) + ) + SELECT + m.embedding, + t.n + FROM t, + {self.db_name}_graph."Memory" m + WHERE t.id1 = m.id; + """ + nodes = [] + node_ids = set() + conn = None + logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}") + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + + for row in results: + if isinstance(row, list | tuple) and len(row) >= 2: + embedding_val, node_val = row[0], row[1] + else: + embedding_val, node_val = None, row[0] + + node = self._build_node_from_agtype(node_val, embedding_val) + if node: + node_id = node["id"] + if node_id not in node_ids: + nodes.append(node) + node_ids.add(node_id) + + except Exception as e: + logger.error(f"Failed to get memories: {e}", exc_info=True) + finally: + self._return_connection(conn) + + return nodes + else: + # Build WHERE clause with user_name/knowledgebase_ids and filter + where_parts = [f"n.memory_type = '{scope}'"] + if status: + where_parts.append(f"n.status = '{status}'") + if user_name_where: + # user_name_where already contains parentheses if it's an OR condition + where_parts.append(user_name_where) + if filter_where_clause: + # filter_where_clause already contains " AND " prefix, so we just append it + where_clause = " AND ".join(where_parts) + filter_where_clause + else: + where_clause = " AND ".join(where_parts) + + cypher_query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE {where_clause} + RETURN properties(n) as props + LIMIT 100 + $$) AS (nprops agtype) + """ + + nodes = [] + conn = None + logger.info(f"[get_all_memory_items] cypher_query: 
{cypher_query}") + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + + for row in results: + memory_data = json.loads(row[0]) if isinstance(row[0], str) else row[0] + nodes.append(self._parse_node(memory_data)) + + except Exception as e: + logger.error(f"Failed to get memories: {e}", exc_info=True) + finally: + self._return_connection(conn) + + return nodes + + @timed + def get_structure_optimization_candidates( + self, scope: str, include_embedding: bool = False, user_name: str | None = None + ) -> list[dict]: + """ + Find nodes that are likely candidates for structure optimization: + - Isolated nodes, nodes with empty background, or nodes with exactly one child. + - Plus: the child of any parent node that has exactly one child. + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + # Build return fields based on include_embedding flag + if include_embedding: + return_fields = "id(n) as id1,n" + return_fields_agtype = " id1 agtype,n agtype" + else: + # Build field list without embedding + return_fields = ",".join( + [ + "n.id AS id", + "n.memory AS memory", + "n.user_name AS user_name", + "n.user_id AS user_id", + "n.session_id AS session_id", + "n.status AS status", + "n.key AS key", + "n.confidence AS confidence", + "n.tags AS tags", + "n.created_at AS created_at", + "n.updated_at AS updated_at", + "n.memory_type AS memory_type", + "n.sources AS sources", + "n.source AS source", + "n.node_type AS node_type", + "n.visibility AS visibility", + "n.usage AS usage", + "n.background AS background", + "n.graph_id as graph_id", + ] + ) + fields = [ + "id", + "memory", + "user_name", + "user_id", + "session_id", + "status", + "key", + "confidence", + "tags", + "created_at", + "updated_at", + "memory_type", + "sources", + "source", + "node_type", + "visibility", + "usage", + "background", + "graph_id", + ] + return_fields_agtype = ", ".join([f"{field} agtype" for field in fields]) + + # Use OPTIONAL MATCH to find isolated nodes (no parents or children) + cypher_query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE n.memory_type = '{scope}' + AND n.status = 'activated' + AND n.user_name = '{user_name}' + OPTIONAL MATCH (n)-[:PARENT]->(c:Memory) + OPTIONAL MATCH (p:Memory)-[:PARENT]->(n) + WITH n, c, p + WHERE c IS NULL AND p IS NULL + RETURN {return_fields} + $$) AS ({return_fields_agtype}) + """ + if include_embedding: + cypher_query = f""" + WITH t as ( + {cypher_query} + ) + SELECT + m.embedding, + t.n + FROM t, + {self.db_name}_graph."Memory" m + WHERE t.id1 = m.id + """ + logger.info(f"[get_structure_optimization_candidates] query: {cypher_query}") + + candidates = [] + node_ids = set() + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + logger.info(f"Found {len(results)} structure optimization candidates") + for row in results: + if include_embedding: + # When include_embedding=True, return full node object + """ + if isinstance(row, (list, tuple)) and len(row) >= 2: + """ + if isinstance(row, list | tuple) and len(row) >= 2: + embedding_val, node_val = row[0], row[1] + else: + embedding_val, node_val = None, row[0] + + node = self._build_node_from_agtype(node_val, embedding_val) + if node: + node_id = node["id"] + if node_id not in node_ids: + candidates.append(node) + node_ids.add(node_id) + else: + # When include_embedding=False, return field 
dictionary + # Define field names matching the RETURN clause + field_names = [ + "id", + "memory", + "user_name", + "user_id", + "session_id", + "status", + "key", + "confidence", + "tags", + "created_at", + "updated_at", + "memory_type", + "sources", + "source", + "node_type", + "visibility", + "usage", + "background", + "graph_id", + ] + + # Convert row to dictionary + node_data = {} + for i, field_name in enumerate(field_names): + if i < len(row): + value = row[i] + # Handle special fields + if field_name in ["tags", "sources", "usage"] and isinstance( + value, str + ): + try: + # Try parsing JSON string + node_data[field_name] = json.loads(value) + except (json.JSONDecodeError, TypeError): + node_data[field_name] = value + else: + node_data[field_name] = value + + # Parse node using _parse_node_new + try: + node = self._parse_node(node_data) + node_id = node["id"] + + if node_id not in node_ids: + candidates.append(node) + node_ids.add(node_id) + logger.debug(f"Parsed node successfully: {node_id}") + except Exception as e: + logger.error(f"Failed to parse node: {e}") + + except Exception as e: + logger.error(f"Failed to get structure optimization candidates: {e}", exc_info=True) + finally: + self._return_connection(conn) + + return candidates diff --git a/src/memos/graph_dbs/polardb/schema.py b/src/memos/graph_dbs/polardb/schema.py new file mode 100644 index 000000000..c1339046a --- /dev/null +++ b/src/memos/graph_dbs/polardb/schema.py @@ -0,0 +1,171 @@ +from memos.log import get_logger +from memos.utils import timed + +logger = get_logger(__name__) + + +class SchemaMixin: + """Mixin for schema and extension management.""" + + def _ensure_database_exists(self): + """Create database if it doesn't exist.""" + try: + # For PostgreSQL/PolarDB, we need to connect to a default database first + # This is a simplified implementation - in production you might want to handle this differently + logger.info(f"Using database '{self.db_name}'") + except Exception as e: + logger.error(f"Failed to access database '{self.db_name}': {e}") + raise + + @timed + def _create_graph(self): + """Create PostgreSQL schema and table for graph storage.""" + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Create schema if it doesn't exist + cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.db_name}_graph";') + logger.info(f"Schema '{self.db_name}_graph' ensured.") + + # Create Memory table if it doesn't exist + cursor.execute(f""" + CREATE TABLE IF NOT EXISTS "{self.db_name}_graph"."Memory" ( + id TEXT PRIMARY KEY, + properties JSONB NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + """) + logger.info(f"Memory table created in schema '{self.db_name}_graph'.") + + # Add embedding column if it doesn't exist (using JSONB for compatibility) + try: + cursor.execute(f""" + ALTER TABLE "{self.db_name}_graph"."Memory" + ADD COLUMN IF NOT EXISTS embedding JSONB; + """) + logger.info("Embedding column added to Memory table.") + except Exception as e: + logger.warning(f"Failed to add embedding column: {e}") + + # Create indexes + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_properties + ON "{self.db_name}_graph"."Memory" USING GIN (properties); + """) + + # Create vector index for embedding field + try: + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_embedding + ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops) + WITH (lists = 
100); + """) + logger.info("Vector index created for Memory table.") + except Exception as e: + logger.warning(f"Vector index creation failed (might not be supported): {e}") + + logger.info("Indexes created for Memory table.") + + except Exception as e: + logger.error(f"Failed to create graph schema: {e}") + raise e + finally: + self._return_connection(conn) + + def create_index( + self, + label: str = "Memory", + vector_property: str = "embedding", + dimensions: int = 1024, + index_name: str = "memory_vector_index", + ) -> None: + """ + Create indexes for embedding and other fields. + Note: This creates PostgreSQL indexes on the underlying tables. + """ + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Create indexes on the underlying PostgreSQL tables + # Apache AGE stores data in regular PostgreSQL tables + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_properties + ON "{self.db_name}_graph"."Memory" USING GIN (properties); + """) + + # Try to create vector index, but don't fail if it doesn't work + try: + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_embedding + ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops); + """) + except Exception as ve: + logger.warning(f"Vector index creation failed (might not be supported): {ve}") + + logger.debug("Indexes created successfully.") + except Exception as e: + logger.warning(f"Failed to create indexes: {e}") + finally: + self._return_connection(conn) + + @timed + def create_extension(self): + extensions = [("polar_age", "Graph engine"), ("vector", "Vector engine")] + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Ensure in the correct database context + cursor.execute("SELECT current_database();") + current_db = cursor.fetchone()[0] + logger.info(f"Current database context: {current_db}") + + for ext_name, ext_desc in extensions: + try: + cursor.execute(f"create extension if not exists {ext_name};") + logger.info(f"Extension '{ext_name}' ({ext_desc}) ensured.") + except Exception as e: + if "already exists" in str(e): + logger.info(f"Extension '{ext_name}' ({ext_desc}) already exists.") + else: + logger.warning( + f"Failed to create extension '{ext_name}' ({ext_desc}): {e}" + ) + logger.error( + f"Failed to create extension '{ext_name}': {e}", exc_info=True + ) + except Exception as e: + logger.warning(f"Failed to access database context: {e}") + logger.error(f"Failed to access database context: {e}", exc_info=True) + finally: + self._return_connection(conn) + + @timed + def create_graph(self): + # Get a connection from the pool + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(f""" + SELECT COUNT(*) FROM ag_catalog.ag_graph + WHERE name = '{self.db_name}_graph'; + """) + graph_exists = cursor.fetchone()[0] > 0 + + if graph_exists: + logger.info(f"Graph '{self.db_name}_graph' already exists.") + else: + cursor.execute(f"select create_graph('{self.db_name}_graph');") + logger.info(f"Graph database '{self.db_name}_graph' created.") + except Exception as e: + logger.warning(f"Failed to create graph '{self.db_name}_graph': {e}") + logger.error(f"Failed to create graph '{self.db_name}_graph': {e}", exc_info=True) + finally: + self._return_connection(conn) diff --git a/src/memos/graph_dbs/polardb/search.py b/src/memos/graph_dbs/polardb/search.py new file mode 100644 index 000000000..d8ef084fd --- 
/dev/null +++ b/src/memos/graph_dbs/polardb/search.py @@ -0,0 +1,360 @@ +import time + +from memos.graph_dbs.utils import convert_to_vector +from memos.log import get_logger +from memos.utils import timed + +logger = get_logger(__name__) + + +class SearchMixin: + """Mixin for search operations (keyword, fulltext, embedding).""" + + def _build_search_where_clauses_sql( + self, + scope: str | None = None, + status: str | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + ) -> list[str]: + """Build common WHERE clauses for SQL-based search methods.""" + where_clauses = [] + + if scope: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype" + ) + if status: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"status\"'::agtype) = '\"{status}\"'::agtype" + ) + else: + where_clauses.append( + "ag_catalog.agtype_access_operator(properties::text::agtype, '\"status\"'::agtype) = '\"activated\"'::agtype" + ) + + # Build user_name filter with knowledgebase_ids support (OR relationship) + user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self.config.user_name, + ) + if user_name_conditions: + if len(user_name_conditions) == 1: + where_clauses.append(user_name_conditions[0]) + else: + where_clauses.append(f"({' OR '.join(user_name_conditions)})") + + # Add search_filter conditions + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"{key}\"'::agtype) = '\"{value}\"'::agtype" + ) + else: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties::text::agtype, '\"{key}\"'::agtype) = {value}::agtype" + ) + + # Build filter conditions + filter_conditions = self._build_filter_conditions_sql(filter) + where_clauses.extend(filter_conditions) + + return where_clauses + + @timed + def search_by_keywords_like( + self, + query_word: str, + scope: str | None = None, + status: str | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: + where_clauses = self._build_search_where_clauses_sql( + scope=scope, status=status, search_filter=search_filter, + user_name=user_name, filter=filter, knowledgebase_ids=knowledgebase_ids, + ) + + # Method-specific: LIKE pattern match + where_clauses.append("""(properties -> '"memory"')::text LIKE %s""") + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) AS old_id, + agtype_object_field_text(properties, 'memory') as memory_text + FROM "{self.db_name}_graph"."Memory" + {where_clause} + """ + + params = (query_word,) + logger.info( + f"[search_by_keywords_LIKE start:] user_name: {user_name}, query: {query}, params: {params}" + ) + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] + id_val = str(oldid).strip('"') + output.append({"id": id_val}) + logger.info( + f"[search_by_keywords_LIKE end:] user_name: 
{user_name}, query: {query}, params: {params} recalled: {output}" + ) + return output + finally: + self._return_connection(conn) + + @timed + def search_by_keywords_tfidf( + self, + query_words: list[str], + scope: str | None = None, + status: str | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + tsvector_field: str = "properties_tsvector_zh", + tsquery_config: str = "jiebaqry", + **kwargs, + ) -> list[dict]: + where_clauses = self._build_search_where_clauses_sql( + scope=scope, status=status, search_filter=search_filter, + user_name=user_name, filter=filter, knowledgebase_ids=knowledgebase_ids, + ) + + # Method-specific: TF-IDF fulltext search condition + tsquery_string = " | ".join(query_words) + where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") + + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) AS old_id, + agtype_object_field_text(properties, 'memory') as memory_text + FROM "{self.db_name}_graph"."Memory" + {where_clause} + """ + + params = (tsquery_string,) + logger.info( + f"[search_by_keywords_TFIDF start:] user_name: {user_name}, query: {query}, params: {params}" + ) + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] + id_val = str(oldid).strip('"') + output.append({"id": id_val}) + + logger.info( + f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + ) + return output + finally: + self._return_connection(conn) + + @timed + def search_by_fulltext( + self, + query_words: list[str], + top_k: int = 10, + scope: str | None = None, + status: str | None = None, + threshold: float | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + tsvector_field: str = "properties_tsvector_zh", + tsquery_config: str = "jiebacfg", + **kwargs, + ) -> list[dict]: + """ + Full-text search functionality using PostgreSQL's full-text search capabilities. + + Args: + query_text: query text + top_k: maximum number of results to return + scope: memory type filter (memory_type) + status: status filter, defaults to "activated" + threshold: similarity threshold filter + search_filter: additional property filter conditions + user_name: username filter + knowledgebase_ids: knowledgebase ids filter + filter: filter conditions with 'and' or 'or' logic for search results. + tsvector_field: full-text index field name, defaults to properties_tsvector_zh_1 + tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation) + **kwargs: other parameters (e.g. 
cube_name) + + Returns: + list[dict]: result list containing id and score + """ + logger.info( + f"[search_by_fulltext] query_words: {query_words}, top_k: {top_k}, scope: {scope}, filter: {filter}" + ) + start_time = time.time() + where_clauses = self._build_search_where_clauses_sql( + scope=scope, status=status, search_filter=search_filter, + user_name=user_name, filter=filter, knowledgebase_ids=knowledgebase_ids, + ) + + # Method-specific: fulltext search condition + tsquery_string = " | ".join(query_words) + + where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") + + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + logger.info(f"[search_by_fulltext] where_clause: {where_clause}") + + # Build fulltext search query + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) AS old_id, + agtype_object_field_text(properties, 'memory') as memory_text, + ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY rank DESC + LIMIT {top_k}; + """ + + params = [tsquery_string, tsquery_string] + logger.info(f"[search_by_fulltext] query: {query}, params: {params}") + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] # old_id + rank = row[2] # rank score + + id_val = str(oldid).strip('"') + score_val = float(rank) + + # Apply threshold filter if specified + if threshold is None or score_val >= threshold: + output.append({"id": id_val, "score": score_val}) + elapsed_time = time.time() - start_time + logger.info( + f" polardb [search_by_fulltext] query completed time in {elapsed_time:.2f}s" + ) + return output[:top_k] + finally: + self._return_connection(conn) + + @timed + def search_by_embedding( + self, + vector: list[float], + top_k: int = 5, + scope: str | None = None, + status: str | None = None, + threshold: float | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: + """ + Retrieve node IDs based on vector similarity using PostgreSQL vector operations. 
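+
+        Returns a list of {"id": ..., "score": ...} dicts, with the cosine score
+        rescaled to [0, 1] to match the Neo4j backend.
+
+        Example (illustrative sketch; the `graph_store`/`embedder` handles and the
+        values below are assumptions, not part of this module):
+
+            vector = embedder.embed(["my query"])[0]
+            hits = graph_store.search_by_embedding(
+                vector, top_k=5, scope="LongTermMemory", user_name="memos_user"
+            )
+            ids = [hit["id"] for hit in hits]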
+ """ + logger.info( + f"[search_by_embedding] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}" + ) + where_clauses = self._build_search_where_clauses_sql( + scope=scope, status=status, search_filter=search_filter, + user_name=user_name, filter=filter, knowledgebase_ids=knowledgebase_ids, + ) + # Method-specific: require embedding column + where_clauses.append("embedding is not null") + + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + # Keep original simple query structure but add dynamic WHERE clause + query = f""" + WITH t AS ( + SELECT id, + properties, + timeline, + ag_catalog.agtype_access_operator(properties::text::agtype, '"id"'::agtype) AS old_id, + (1 - (embedding <=> %s::vector(1024))) AS scope + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY scope DESC + LIMIT {top_k} + ) + SELECT * + FROM t + WHERE scope > 0.1; + """ + # Convert vector to string format for PostgreSQL vector type + # PostgreSQL vector type expects a string format like '[1,2,3]' + vector_str = convert_to_vector(vector) + # Use string format directly in query instead of parameterized query + # Replace %s with the vector string, but need to quote it properly + # PostgreSQL vector type needs the string to be quoted + query = query.replace("%s::vector(1024)", f"'{vector_str}'::vector(1024)") + params = [] + + logger.info(f"[search_by_embedding] query: {query}, params: {params}") + + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + try: + # If params is empty, execute query directly without parameters + if params: + cursor.execute(query, params) + else: + cursor.execute(query) + except Exception as e: + logger.error(f"[search_by_embedding] Error executing query: {e}") + raise + results = cursor.fetchall() + output = [] + for row in results: + if len(row) < 5: + logger.warning(f"Row has {len(row)} columns, expected 5. 
Row: {row}") + continue + oldid = row[3] # old_id + score = row[4] # scope + id_val = str(oldid).strip('"') + score_val = float(score) + score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score + if threshold is None or score_val >= threshold: + output.append({"id": id_val, "score": score_val}) + return output[:top_k] + except Exception as e: + logger.error(f"[search_by_embedding] Error: {type(e).__name__}: {e}", exc_info=True) + return [] + finally: + self._return_connection(conn) + + # Backward-compatible aliases for renamed methods (typo -> correct) + seach_by_keywords_like = search_by_keywords_like + seach_by_keywords_tfidf = search_by_keywords_tfidf diff --git a/src/memos/graph_dbs/polardb/traversal.py b/src/memos/graph_dbs/polardb/traversal.py new file mode 100644 index 000000000..d9eb1612a --- /dev/null +++ b/src/memos/graph_dbs/polardb/traversal.py @@ -0,0 +1,431 @@ +import json +from typing import Any, Literal + +from memos.log import get_logger +from memos.utils import timed + +logger = get_logger(__name__) + + +class TraversalMixin: + """Mixin for graph traversal operations.""" + + def get_neighbors( + self, id: str, type: str, direction: Literal["in", "out", "both"] = "out" + ) -> list[str]: + """Get connected node IDs in a specific direction and relationship type.""" + raise NotImplementedError + + @timed + def get_children_with_embeddings( + self, id: str, user_name: str | None = None + ) -> list[dict[str, Any]]: + """Get children nodes with their embeddings.""" + user_name = user_name if user_name else self._get_config_value("user_name") + where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" + + query = f""" + WITH t as ( + SELECT * + FROM cypher('{self.db_name}_graph', $$ + MATCH (p:Memory)-[r:PARENT]->(c:Memory) + WHERE p.id = '{id}' {where_user} + RETURN id(c) as cid, c.id AS id, c.memory AS memory + $$) as (cid agtype, id agtype, memory agtype) + ) + SELECT t.id, m.embedding, t.memory FROM t, + "{self.db_name}_graph"."Memory" m + WHERE t.cid::graphid = m.id; + """ + + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() + + children = [] + for row in results: + # Handle child_id - remove possible quotes + child_id_raw = row[0].value if hasattr(row[0], "value") else str(row[0]) + if isinstance(child_id_raw, str): + # If string starts and ends with quotes, remove quotes + if child_id_raw.startswith('"') and child_id_raw.endswith('"'): + child_id = child_id_raw[1:-1] + else: + child_id = child_id_raw + else: + child_id = str(child_id_raw) + + # Handle embedding - get from database embedding column + embedding_raw = row[1] + embedding = [] + if embedding_raw is not None: + try: + if isinstance(embedding_raw, str): + # If it is a JSON string, parse it + embedding = json.loads(embedding_raw) + elif isinstance(embedding_raw, list): + # If already a list, use directly + embedding = embedding_raw + else: + # Try converting to list + embedding = list(embedding_raw) + except (json.JSONDecodeError, TypeError, ValueError) as e: + logger.warning( + f"Failed to parse embedding for child node {child_id}: {e}" + ) + embedding = [] + + # Handle memory - remove possible quotes + memory_raw = row[2].value if hasattr(row[2], "value") else str(row[2]) + if isinstance(memory_raw, str): + # If string starts and ends with quotes, remove quotes + if memory_raw.startswith('"') and memory_raw.endswith('"'): + memory = memory_raw[1:-1] + else: + memory = memory_raw + else: 
+ memory = str(memory_raw) + + children.append({"id": child_id, "embedding": embedding, "memory": memory}) + + return children + + except Exception as e: + logger.error(f"[get_children_with_embeddings] Failed: {e}", exc_info=True) + return [] + finally: + self._return_connection(conn) + + def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: + """Get the path of nodes from source to target within a limited depth.""" + raise NotImplementedError + + @timed + def get_subgraph( + self, + center_id: str, + depth: int = 2, + center_status: str = "activated", + user_name: str | None = None, + ) -> dict[str, Any]: + """ + Retrieve a local subgraph centered at a given node. + Args: + center_id: The ID of the center node. + depth: The hop distance for neighbors. + center_status: Required status for center node. + user_name (str, optional): User name for filtering in non-multi-db mode + Returns: + { + "core_node": {...}, + "neighbors": [...], + "edges": [...] + } + """ + logger.info(f"[get_subgraph] center_id: {center_id}") + if not 1 <= depth <= 5: + raise ValueError("depth must be 1-5") + + user_name = user_name if user_name else self._get_config_value("user_name") + + if center_id.startswith('"') and center_id.endswith('"'): + center_id = center_id[1:-1] + # Use UNION ALL for better performance: separate queries for depth 1 and depth 2 + if depth == 1: + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH(center: Memory)-[r]->(neighbor:Memory) + WHERE + center.id = '{center_id}' + AND center.status = '{center_status}' + AND center.user_name = '{user_name}' + RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r) + $$ ) as (centers agtype, neighbors agtype, rels agtype); + """ + else: + # For depth >= 2, use UNION ALL to combine depth 1 and depth 2 queries + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH(center: Memory)-[r]->(neighbor:Memory) + WHERE + center.id = '{center_id}' + AND center.status = '{center_status}' + AND center.user_name = '{user_name}' + RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r) + UNION ALL + MATCH(center: Memory)-[r]->(n:Memory)-[r1]->(neighbor:Memory) + WHERE + center.id = '{center_id}' + AND center.status = '{center_status}' + AND center.user_name = '{user_name}' + RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r1) + $$ ) as (centers agtype, neighbors agtype, rels agtype); + """ + conn = None + logger.info(f"[get_subgraph] Query: {query}") + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() + + if not results: + return {"core_node": None, "neighbors": [], "edges": []} + + # Merge results from all UNION ALL rows + all_centers_list = [] + all_neighbors_list = [] + all_edges_list = [] + + for result in results: + if not result or not result[0]: + continue + + centers_data = result[0] if result[0] else "[]" + neighbors_data = result[1] if result[1] else "[]" + edges_data = result[2] if result[2] else "[]" + + # Parse JSON data + try: + # Clean ::vertex and ::edge suffixes in data + if isinstance(centers_data, str): + centers_data = centers_data.replace("::vertex", "") + if isinstance(neighbors_data, str): + neighbors_data = neighbors_data.replace("::vertex", "") + if isinstance(edges_data, str): + edges_data = edges_data.replace("::edge", "") + + centers_list = ( + json.loads(centers_data) + if isinstance(centers_data, str) + else 
centers_data + ) + neighbors_list = ( + json.loads(neighbors_data) + if isinstance(neighbors_data, str) + else neighbors_data + ) + edges_list = ( + json.loads(edges_data) if isinstance(edges_data, str) else edges_data + ) + + # Collect data from this row + if isinstance(centers_list, list): + all_centers_list.extend(centers_list) + if isinstance(neighbors_list, list): + all_neighbors_list.extend(neighbors_list) + if isinstance(edges_list, list): + all_edges_list.extend(edges_list) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse JSON data: {e}") + continue + + # Deduplicate centers by ID + centers_dict = {} + for center_data in all_centers_list: + if isinstance(center_data, dict) and "properties" in center_data: + center_id_key = center_data["properties"].get("id") + if center_id_key and center_id_key not in centers_dict: + centers_dict[center_id_key] = center_data + + # Parse center node (use first center) + core_node = None + if centers_dict: + center_data = next(iter(centers_dict.values())) + if isinstance(center_data, dict) and "properties" in center_data: + core_node = self._parse_node(center_data["properties"]) + + # Deduplicate neighbors by ID + neighbors_dict = {} + for neighbor_data in all_neighbors_list: + if isinstance(neighbor_data, dict) and "properties" in neighbor_data: + neighbor_id = neighbor_data["properties"].get("id") + if neighbor_id and neighbor_id not in neighbors_dict: + neighbors_dict[neighbor_id] = neighbor_data + + # Parse neighbor nodes + neighbors = [] + for neighbor_data in neighbors_dict.values(): + if isinstance(neighbor_data, dict) and "properties" in neighbor_data: + neighbor_parsed = self._parse_node(neighbor_data["properties"]) + neighbors.append(neighbor_parsed) + + # Deduplicate edges by (source, target, type) + edges_dict = {} + for edge_group in all_edges_list: + if isinstance(edge_group, list): + for edge_data in edge_group: + if isinstance(edge_data, dict): + edge_key = ( + edge_data.get("start_id", ""), + edge_data.get("end_id", ""), + edge_data.get("label", ""), + ) + if edge_key not in edges_dict: + edges_dict[edge_key] = { + "type": edge_data.get("label", ""), + "source": edge_data.get("start_id", ""), + "target": edge_data.get("end_id", ""), + } + elif isinstance(edge_group, dict): + # Handle single edge (not in a list) + edge_key = ( + edge_group.get("start_id", ""), + edge_group.get("end_id", ""), + edge_group.get("label", ""), + ) + if edge_key not in edges_dict: + edges_dict[edge_key] = { + "type": edge_group.get("label", ""), + "source": edge_group.get("start_id", ""), + "target": edge_group.get("end_id", ""), + } + + edges = list(edges_dict.values()) + + return self._convert_graph_edges( + {"core_node": core_node, "neighbors": neighbors, "edges": edges} + ) + + except Exception as e: + logger.error(f"Failed to get subgraph: {e}", exc_info=True) + return {"core_node": None, "neighbors": [], "edges": []} + finally: + self._return_connection(conn) + + def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: + """Get the ordered context chain starting from a node.""" + raise NotImplementedError + + @timed + def get_neighbors_by_tag( + self, + tags: list[str], + exclude_ids: list[str], + top_k: int = 5, + min_overlap: int = 1, + include_embedding: bool = False, + user_name: str | None = None, + ) -> list[dict[str, Any]]: + """ + Find top-K neighbor nodes with maximum tag overlap. + + Args: + tags: The list of tags to match. + exclude_ids: Node IDs to exclude (e.g., local cluster). 
+ top_k: Max number of neighbors to return. + min_overlap: Minimum number of overlapping tags required. + include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + List of dicts with node details and overlap count. + """ + if not tags: + return [] + + user_name = user_name if user_name else self._get_config_value("user_name") + + # Build query conditions - more relaxed filters + where_clauses = [] + params = [] + + # Exclude specified IDs - use id in properties + if exclude_ids: + exclude_conditions = [] + for exclude_id in exclude_ids: + exclude_conditions.append( + "ag_catalog.agtype_access_operator(properties::text::agtype, '\"id\"'::agtype) != %s::agtype" + ) + params.append(self.format_param_value(exclude_id)) + where_clauses.append(f"({' AND '.join(exclude_conditions)})") + + # Status filter - keep only 'activated' + where_clauses.append( + "ag_catalog.agtype_access_operator(properties::text::agtype, '\"status\"'::agtype) = '\"activated\"'::agtype" + ) + + # Type filter - exclude 'reasoning' type + where_clauses.append( + "ag_catalog.agtype_access_operator(properties::text::agtype, '\"node_type\"'::agtype) != '\"reasoning\"'::agtype" + ) + + # User filter + where_clauses.append( + "ag_catalog.agtype_access_operator(properties::text::agtype, '\"user_name\"'::agtype) = %s::agtype" + ) + params.append(self.format_param_value(user_name)) + + # Testing showed no data; annotate. + where_clauses.append( + "ag_catalog.agtype_access_operator(properties::text::agtype, '\"memory_type\"'::agtype) != '\"WorkingMemory\"'::agtype" + ) + + where_clause = " AND ".join(where_clauses) + + # Fetch all candidate nodes + query = f""" + SELECT id, properties, embedding + FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + + logger.debug(f"[get_neighbors_by_tag] query: {query}, params: {params}") + + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + + nodes_with_overlap = [] + for row in results: + node_id, properties_json, embedding_json = row + properties = properties_json if properties_json else {} + + # Parse embedding + if include_embedding and embedding_json is not None: + try: + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json + ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") + + # Compute tag overlap + node_tags = properties.get("tags", []) + if isinstance(node_tags, str): + try: + node_tags = json.loads(node_tags) + except (json.JSONDecodeError, TypeError): + node_tags = [] + + overlap_tags = [tag for tag in tags if tag in node_tags] + overlap_count = len(overlap_tags) + + if overlap_count >= min_overlap: + node_data = self._parse_node( + { + "id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } + ) + nodes_with_overlap.append((node_data, overlap_count)) + + # Sort by overlap count and return top_k items + nodes_with_overlap.sort(key=lambda x: x[1], reverse=True) + return [node for node, _ in nodes_with_overlap[:top_k]] + + except Exception as e: + logger.error(f"Failed to get neighbors by tag: {e}", exc_info=True) + return [] + finally: + self._return_connection(conn) diff --git a/src/memos/graph_dbs/postgres.py b/src/memos/graph_dbs/postgres.py new file mode 100644 index 000000000..09d1d0844 --- /dev/null 
+++ b/src/memos/graph_dbs/postgres.py @@ -0,0 +1,885 @@ +""" +PostgreSQL + pgvector backend for MemOS. + +Simple implementation using standard PostgreSQL with pgvector extension. +No Apache AGE or other graph extensions required. + +Tables: +- {schema}.memories: Memory nodes with JSONB properties and vector embeddings +- {schema}.edges: Relationships between memory nodes +""" + +import json +import time + +from contextlib import suppress +from datetime import datetime +from typing import Any, Literal + +from memos.configs.graph_db import PostgresGraphDBConfig +from memos.dependency import require_python_package +from memos.graph_dbs.base import BaseGraphDB +from memos.log import get_logger + + +logger = get_logger(__name__) + + +def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: + """Ensure metadata has proper datetime fields and normalized types.""" + now = datetime.utcnow().isoformat() + metadata.setdefault("created_at", now) + metadata.setdefault("updated_at", now) + + # Normalize embedding type + embedding = metadata.get("embedding") + if embedding and isinstance(embedding, list): + metadata["embedding"] = [float(x) for x in embedding] + + return metadata + + +class PostgresGraphDB(BaseGraphDB): + """PostgreSQL + pgvector implementation of a graph memory store.""" + + @require_python_package( + import_name="psycopg2", + install_command="pip install psycopg2-binary", + install_link="https://pypi.org/project/psycopg2-binary/", + ) + def __init__(self, config: PostgresGraphDBConfig): + """Initialize PostgreSQL connection pool.""" + import psycopg2 + import psycopg2.pool + + self.config = config + self.schema = config.schema_name + self.user_name = config.user_name + self._pool_closed = False + + logger.info(f"Connecting to PostgreSQL: {config.host}:{config.port}/{config.db_name}") + + # Create connection pool + self.pool = psycopg2.pool.ThreadedConnectionPool( + minconn=2, + maxconn=config.maxconn, + host=config.host, + port=config.port, + user=config.user, + password=config.password, + dbname=config.db_name, + connect_timeout=30, + keepalives_idle=30, + keepalives_interval=10, + keepalives_count=5, + ) + + # Initialize schema and tables + self._init_schema() + + def _get_conn(self): + """Get connection from pool with health check.""" + if self._pool_closed: + raise RuntimeError("Connection pool is closed") + + for attempt in range(3): + conn = None + try: + conn = self.pool.getconn() + if conn.closed != 0: + self.pool.putconn(conn, close=True) + continue + conn.autocommit = True + # Health check + with conn.cursor() as cur: + cur.execute("SELECT 1") + return conn + except Exception as e: + if conn: + with suppress(Exception): + self.pool.putconn(conn, close=True) + if attempt == 2: + raise RuntimeError(f"Failed to get connection: {e}") from e + time.sleep(0.1) + raise RuntimeError("Failed to get healthy connection") + + def _put_conn(self, conn): + """Return connection to pool.""" + if conn and not self._pool_closed: + try: + self.pool.putconn(conn) + except Exception: + with suppress(Exception): + conn.close() + + def _init_schema(self): + """Create schema and tables if they don't exist.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Create schema + cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.schema}") + + # Enable pgvector + cur.execute("CREATE EXTENSION IF NOT EXISTS vector") + + # Create memories table + dim = self.config.embedding_dimension + cur.execute(f""" + CREATE TABLE IF NOT EXISTS {self.schema}.memories ( + id TEXT PRIMARY KEY, + 
memory TEXT NOT NULL DEFAULT '', + properties JSONB NOT NULL DEFAULT '{{}}', + embedding vector({dim}), + user_name TEXT, + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() + ) + """) + + # Create edges table + cur.execute(f""" + CREATE TABLE IF NOT EXISTS {self.schema}.edges ( + id SERIAL PRIMARY KEY, + source_id TEXT NOT NULL, + target_id TEXT NOT NULL, + edge_type TEXT NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW(), + UNIQUE(source_id, target_id, edge_type) + ) + """) + + # Create indexes + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memories_user + ON {self.schema}.memories(user_name) + """) + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memories_props + ON {self.schema}.memories USING GIN(properties) + """) + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memories_embedding + ON {self.schema}.memories USING ivfflat(embedding vector_cosine_ops) + WITH (lists = 100) + """) + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_edges_source + ON {self.schema}.edges(source_id) + """) + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_edges_target + ON {self.schema}.edges(target_id) + """) + + logger.info(f"Schema {self.schema} initialized successfully") + except Exception as e: + logger.error(f"Failed to init schema: {e}") + raise + finally: + self._put_conn(conn) + + # ========================================================================= + # Node Management + # ========================================================================= + + def remove_oldest_memory( + self, memory_type: str, keep_latest: int, user_name: str | None = None + ) -> None: + """ + Remove all memories of a given type except the latest `keep_latest` entries. + + Args: + memory_type: Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). + keep_latest: Number of latest entries to keep. + user_name: User to filter by. 
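+
+        Example (illustrative; the `db` handle and values are assumptions):
+            # Keep only the 20 most recent WorkingMemory items for this user;
+            # older ones are deleted together with their edges.
+            db.remove_oldest_memory("WorkingMemory", keep_latest=20, user_name="alice")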
+ """ + user_name = user_name or self.user_name + keep_latest = int(keep_latest) + + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Find IDs to delete (older than the keep_latest entries) + cur.execute(f""" + WITH ranked AS ( + SELECT id, ROW_NUMBER() OVER (ORDER BY updated_at DESC) as rn + FROM {self.schema}.memories + WHERE user_name = %s + AND properties->>'memory_type' = %s + ) + SELECT id FROM ranked WHERE rn > %s + """, (user_name, memory_type, keep_latest)) + + ids_to_delete = [row[0] for row in cur.fetchall()] + + if ids_to_delete: + # Delete edges first + cur.execute(f""" + DELETE FROM {self.schema}.edges + WHERE source_id = ANY(%s) OR target_id = ANY(%s) + """, (ids_to_delete, ids_to_delete)) + + # Delete nodes + cur.execute(f""" + DELETE FROM {self.schema}.memories + WHERE id = ANY(%s) + """, (ids_to_delete,)) + + logger.info(f"Removed {len(ids_to_delete)} oldest {memory_type} memories for user {user_name}") + finally: + self._put_conn(conn) + + def add_node( + self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None + ) -> None: + """Add a memory node.""" + user_name = user_name or self.user_name + metadata = _prepare_node_metadata(metadata.copy()) + + # Extract embedding + embedding = metadata.pop("embedding", None) + created_at = metadata.pop("created_at", datetime.utcnow().isoformat()) + updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat()) + + # Serialize sources if present + if metadata.get("sources"): + metadata["sources"] = [ + json.dumps(s) if not isinstance(s, str) else s + for s in metadata["sources"] + ] + + conn = self._get_conn() + try: + with conn.cursor() as cur: + if embedding: + cur.execute(f""" + INSERT INTO {self.schema}.memories + (id, memory, properties, embedding, user_name, created_at, updated_at) + VALUES (%s, %s, %s, %s::vector, %s, %s, %s) + ON CONFLICT (id) DO UPDATE SET + memory = EXCLUDED.memory, + properties = EXCLUDED.properties, + embedding = EXCLUDED.embedding, + updated_at = EXCLUDED.updated_at + """, (id, memory, json.dumps(metadata), embedding, user_name, created_at, updated_at)) + else: + cur.execute(f""" + INSERT INTO {self.schema}.memories + (id, memory, properties, user_name, created_at, updated_at) + VALUES (%s, %s, %s, %s, %s, %s) + ON CONFLICT (id) DO UPDATE SET + memory = EXCLUDED.memory, + properties = EXCLUDED.properties, + updated_at = EXCLUDED.updated_at + """, (id, memory, json.dumps(metadata), user_name, created_at, updated_at)) + finally: + self._put_conn(conn) + + def add_nodes_batch( + self, nodes: list[dict[str, Any]], user_name: str | None = None + ) -> None: + """Batch add memory nodes.""" + for node in nodes: + self.add_node( + id=node["id"], + memory=node["memory"], + metadata=node.get("metadata", {}), + user_name=user_name, + ) + + def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: + """Update node fields.""" + user_name = user_name or self.user_name + if not fields: + return + + # Get current node + current = self.get_node(id, user_name=user_name) + if not current: + return + + # Merge properties + props = current.get("metadata", {}).copy() + embedding = fields.pop("embedding", None) + memory = fields.pop("memory", current.get("memory", "")) + props.update(fields) + props["updated_at"] = datetime.utcnow().isoformat() + + conn = self._get_conn() + try: + with conn.cursor() as cur: + if embedding: + cur.execute(f""" + UPDATE {self.schema}.memories + SET memory = %s, properties = %s, embedding = %s::vector, updated_at = 
NOW() + WHERE id = %s AND user_name = %s + """, (memory, json.dumps(props), embedding, id, user_name)) + else: + cur.execute(f""" + UPDATE {self.schema}.memories + SET memory = %s, properties = %s, updated_at = NOW() + WHERE id = %s AND user_name = %s + """, (memory, json.dumps(props), id, user_name)) + finally: + self._put_conn(conn) + + def delete_node(self, id: str, user_name: str | None = None) -> None: + """Delete a node and its edges.""" + user_name = user_name or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Delete edges + cur.execute(f""" + DELETE FROM {self.schema}.edges + WHERE source_id = %s OR target_id = %s + """, (id, id)) + # Delete node + cur.execute(f""" + DELETE FROM {self.schema}.memories + WHERE id = %s AND user_name = %s + """, (id, user_name)) + finally: + self._put_conn(conn) + + def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None: + """Get a single node by ID.""" + user_name = kwargs.get("user_name") or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + cur.execute(f""" + SELECT {cols} FROM {self.schema}.memories + WHERE id = %s AND user_name = %s + """, (id, user_name)) + row = cur.fetchone() + if not row: + return None + return self._parse_row(row, include_embedding) + finally: + self._put_conn(conn) + + def get_nodes( + self, ids: list, include_embedding: bool = False, **kwargs + ) -> list[dict[str, Any]]: + """Get multiple nodes by IDs.""" + if not ids: + return [] + user_name = kwargs.get("user_name") or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + cur.execute(f""" + SELECT {cols} FROM {self.schema}.memories + WHERE id = ANY(%s) AND user_name = %s + """, (ids, user_name)) + return [self._parse_row(row, include_embedding) for row in cur.fetchall()] + finally: + self._put_conn(conn) + + def _parse_row(self, row, include_embedding: bool = False) -> dict[str, Any]: + """Parse database row to node dict.""" + props = row[2] if isinstance(row[2], dict) else json.loads(row[2] or "{}") + props["created_at"] = row[3].isoformat() if row[3] else None + props["updated_at"] = row[4].isoformat() if row[4] else None + result = { + "id": row[0], + "memory": row[1] or "", + "metadata": props, + } + if include_embedding and len(row) > 5: + result["metadata"]["embedding"] = row[5] + return result + + # ========================================================================= + # Edge Management + # ========================================================================= + + def add_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: + """Create an edge between nodes.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + INSERT INTO {self.schema}.edges (source_id, target_id, edge_type) + VALUES (%s, %s, %s) + ON CONFLICT (source_id, target_id, edge_type) DO NOTHING + """, (source_id, target_id, type)) + finally: + self._put_conn(conn) + + def delete_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: + """Delete an edge.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + DELETE FROM {self.schema}.edges + WHERE source_id = %s AND target_id = %s AND edge_type = %s + """, 
(source_id, target_id, type)) + finally: + self._put_conn(conn) + + def edge_exists(self, source_id: str, target_id: str, type: str) -> bool: + """Check if edge exists.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + SELECT 1 FROM {self.schema}.edges + WHERE source_id = %s AND target_id = %s AND edge_type = %s + LIMIT 1 + """, (source_id, target_id, type)) + return cur.fetchone() is not None + finally: + self._put_conn(conn) + + # ========================================================================= + # Graph Queries + # ========================================================================= + + def get_neighbors( + self, id: str, type: str, direction: Literal["in", "out", "both"] = "out" + ) -> list[str]: + """Get neighboring node IDs.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + if direction == "out": + cur.execute(f""" + SELECT target_id FROM {self.schema}.edges + WHERE source_id = %s AND edge_type = %s + """, (id, type)) + elif direction == "in": + cur.execute(f""" + SELECT source_id FROM {self.schema}.edges + WHERE target_id = %s AND edge_type = %s + """, (id, type)) + else: # both + cur.execute(f""" + SELECT target_id FROM {self.schema}.edges WHERE source_id = %s AND edge_type = %s + UNION + SELECT source_id FROM {self.schema}.edges WHERE target_id = %s AND edge_type = %s + """, (id, type, id, type)) + return [row[0] for row in cur.fetchall()] + finally: + self._put_conn(conn) + + def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: + """Get path between nodes using recursive CTE.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + WITH RECURSIVE path AS ( + SELECT source_id, target_id, ARRAY[source_id] as nodes, 1 as depth + FROM {self.schema}.edges + WHERE source_id = %s + UNION ALL + SELECT e.source_id, e.target_id, p.nodes || e.source_id, p.depth + 1 + FROM {self.schema}.edges e + JOIN path p ON e.source_id = p.target_id + WHERE p.depth < %s AND NOT e.source_id = ANY(p.nodes) + ) + SELECT nodes || target_id as full_path + FROM path + WHERE target_id = %s + ORDER BY depth + LIMIT 1 + """, (source_id, max_depth, target_id)) + row = cur.fetchone() + return row[0] if row else [] + finally: + self._put_conn(conn) + + def get_subgraph(self, center_id: str, depth: int = 2) -> list[str]: + """Get subgraph around center node.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + WITH RECURSIVE subgraph AS ( + SELECT %s::text as node_id, 0 as level + UNION + SELECT CASE WHEN e.source_id = s.node_id THEN e.target_id ELSE e.source_id END, + s.level + 1 + FROM {self.schema}.edges e + JOIN subgraph s ON (e.source_id = s.node_id OR e.target_id = s.node_id) + WHERE s.level < %s + ) + SELECT DISTINCT node_id FROM subgraph + """, (center_id, depth)) + return [row[0] for row in cur.fetchall()] + finally: + self._put_conn(conn) + + def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: + """Get ordered chain following relationship type.""" + return self.get_neighbors(id, type, "out") + + # ========================================================================= + # Search Operations + # ========================================================================= + + def search_by_embedding( + self, + vector: list[float], + top_k: int = 5, + scope: str | None = None, + status: str | None = None, + threshold: float | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = 
None, + knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: + """Search nodes by vector similarity using pgvector.""" + user_name = user_name or self.user_name + + # Build WHERE clause + conditions = ["embedding IS NOT NULL"] + params = [] + + if user_name: + conditions.append("user_name = %s") + params.append(user_name) + + if scope: + conditions.append("properties->>'memory_type' = %s") + params.append(scope) + + if status: + conditions.append("properties->>'status' = %s") + params.append(status) + else: + conditions.append("(properties->>'status' = 'activated' OR properties->>'status' IS NULL)") + + if search_filter: + for k, v in search_filter.items(): + conditions.append(f"properties->>'{k}' = %s") + params.append(str(v)) + + where_clause = " AND ".join(conditions) + + # pgvector cosine distance: 1 - (a <=> b) gives similarity score + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + SELECT id, 1 - (embedding <=> %s::vector) as score + FROM {self.schema}.memories + WHERE {where_clause} + ORDER BY embedding <=> %s::vector + LIMIT %s + """, (vector, *params, vector, top_k)) + + results = [] + for row in cur.fetchall(): + score = float(row[1]) + if threshold is None or score >= threshold: + results.append({"id": row[0], "score": score}) + return results + finally: + self._put_conn(conn) + + def get_by_metadata( + self, + filters: list[dict[str, Any]], + status: str | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + user_name_flag: bool = True, + ) -> list[str]: + """Get node IDs matching metadata filters.""" + user_name = user_name or self.user_name + + conditions = [] + params = [] + + if user_name_flag and user_name: + conditions.append("user_name = %s") + params.append(user_name) + + if status: + conditions.append("properties->>'status' = %s") + params.append(status) + + for f in filters: + field = f["field"] + op = f.get("op", "=") + value = f["value"] + + if op == "=": + conditions.append(f"properties->>'{field}' = %s") + params.append(str(value)) + elif op == "in": + placeholders = ",".join(["%s"] * len(value)) + conditions.append(f"properties->>'{field}' IN ({placeholders})") + params.extend([str(v) for v in value]) + elif op in (">", ">=", "<", "<="): + conditions.append(f"(properties->>'{field}')::numeric {op} %s") + params.append(value) + elif op == "contains": + conditions.append(f"properties->'{field}' @> %s::jsonb") + params.append(json.dumps([value])) + + where_clause = " AND ".join(conditions) if conditions else "TRUE" + + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + SELECT id FROM {self.schema}.memories + WHERE {where_clause} + """, params) + return [row[0] for row in cur.fetchall()] + finally: + self._put_conn(conn) + + def get_all_memory_items( + self, + scope: str, + include_embedding: bool = False, + status: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: + """Get all memory items of a specific type.""" + user_name = kwargs.get("user_name") or self.user_name + + conditions = ["properties->>'memory_type' = %s", "user_name = %s"] + params = [scope, user_name] + + if status: + conditions.append("properties->>'status' = %s") + params.append(status) + + where_clause = " AND ".join(conditions) + + conn = self._get_conn() + try: + with conn.cursor() as cur: + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: 
+ cols += ", embedding" + cur.execute(f""" + SELECT {cols} FROM {self.schema}.memories + WHERE {where_clause} + """, params) + return [self._parse_row(row, include_embedding) for row in cur.fetchall()] + finally: + self._put_conn(conn) + + def get_structure_optimization_candidates( + self, scope: str, include_embedding: bool = False + ) -> list[dict]: + """Find isolated nodes (no edges).""" + user_name = self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + cols = "m.id, m.memory, m.properties, m.created_at, m.updated_at" + cur.execute(f""" + SELECT {cols} + FROM {self.schema}.memories m + LEFT JOIN {self.schema}.edges e1 ON m.id = e1.source_id + LEFT JOIN {self.schema}.edges e2 ON m.id = e2.target_id + WHERE m.properties->>'memory_type' = %s + AND m.user_name = %s + AND m.properties->>'status' = 'activated' + AND e1.id IS NULL + AND e2.id IS NULL + """, (scope, user_name)) + return [self._parse_row(row, False) for row in cur.fetchall()] + finally: + self._put_conn(conn) + + # ========================================================================= + # Maintenance + # ========================================================================= + + def deduplicate_nodes(self) -> None: + """Not implemented - handled at application level.""" + + def get_grouped_counts( + self, + group_fields: list[str], + where_clause: str = "", + params: dict[str, Any] | None = None, + user_name: str | None = None, + ) -> list[dict[str, Any]]: + """ + Count nodes grouped by specified fields. + + Args: + group_fields: Fields to group by, e.g., ["memory_type", "status"] + where_clause: Extra WHERE condition + params: Parameters for WHERE clause + user_name: User to filter by + + Returns: + list[dict]: e.g., [{'memory_type': 'WorkingMemory', 'count': 10}, ...] 
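+
+        Example (illustrative; the `db` handle and values are assumptions):
+            counts = db.get_grouped_counts(
+                group_fields=["memory_type", "status"],
+                where_clause="WHERE properties->>'memory_type' = %s",
+                params={"memory_type": "LongTermMemory"},
+            )
+            # -> e.g. [{"memory_type": "LongTermMemory", "status": "activated", "count": 42}]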
+ """ + user_name = user_name or self.user_name + if not group_fields: + raise ValueError("group_fields cannot be empty") + + # Build SELECT and GROUP BY clauses + # Fields come from JSONB properties column + select_fields = ", ".join([ + f"properties->>'{field}' AS {field}" for field in group_fields + ]) + group_by = ", ".join([f"properties->>'{field}'" for field in group_fields]) + + # Build WHERE clause + conditions = ["user_name = %s"] + query_params = [user_name] + + if where_clause: + # Parse simple where clause format + where_clause = where_clause.strip() + if where_clause.upper().startswith("WHERE"): + where_clause = where_clause[5:].strip() + if where_clause: + conditions.append(where_clause) + if params: + query_params.extend(params.values()) + + where_sql = " AND ".join(conditions) + + query = f""" + SELECT {select_fields}, COUNT(*) AS count + FROM {self.schema}.memories + WHERE {where_sql} + GROUP BY {group_by} + """ + + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(query, query_params) + results = [] + for row in cur.fetchall(): + result = {} + for i, field in enumerate(group_fields): + result[field] = row[i] + result["count"] = row[len(group_fields)] + results.append(result) + return results + finally: + self._put_conn(conn) + + def detect_conflicts(self) -> list[tuple[str, str]]: + """Not implemented.""" + return [] + + def merge_nodes(self, id1: str, id2: str) -> str: + """Not implemented.""" + raise NotImplementedError + + def clear(self, user_name: str | None = None) -> None: + """Clear all data for user.""" + user_name = user_name or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Get all node IDs for user + cur.execute(f""" + SELECT id FROM {self.schema}.memories WHERE user_name = %s + """, (user_name,)) + ids = [row[0] for row in cur.fetchall()] + + if ids: + # Delete edges + cur.execute(f""" + DELETE FROM {self.schema}.edges + WHERE source_id = ANY(%s) OR target_id = ANY(%s) + """, (ids, ids)) + + # Delete nodes + cur.execute(f""" + DELETE FROM {self.schema}.memories WHERE user_name = %s + """, (user_name,)) + logger.info(f"Cleared all data for user {user_name}") + finally: + self._put_conn(conn) + + def export_graph(self, include_embedding: bool = False, **kwargs) -> dict[str, Any]: + """Export all data.""" + user_name = kwargs.get("user_name") or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Get nodes + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + cur.execute(f""" + SELECT {cols} FROM {self.schema}.memories + WHERE user_name = %s + ORDER BY created_at DESC + """, (user_name,)) + nodes = [self._parse_row(row, include_embedding) for row in cur.fetchall()] + + # Get edges + node_ids = [n["id"] for n in nodes] + if node_ids: + cur.execute(f""" + SELECT source_id, target_id, edge_type + FROM {self.schema}.edges + WHERE source_id = ANY(%s) OR target_id = ANY(%s) + """, (node_ids, node_ids)) + edges = [ + {"source": row[0], "target": row[1], "type": row[2]} + for row in cur.fetchall() + ] + else: + edges = [] + + return { + "nodes": nodes, + "edges": edges, + "total_nodes": len(nodes), + "total_edges": len(edges), + } + finally: + self._put_conn(conn) + + def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: + """Import graph data.""" + user_name = user_name or self.user_name + + for node in data.get("nodes", []): + self.add_node( + id=node["id"], + memory=node.get("memory", ""), + 
metadata=node.get("metadata", {}), + user_name=user_name, + ) + + for edge in data.get("edges", []): + self.add_edge( + source_id=edge["source"], + target_id=edge["target"], + type=edge["type"], + ) + + def close(self): + """Close connection pool.""" + if not self._pool_closed: + self._pool_closed = True + self.pool.closeall() diff --git a/src/memos/graph_dbs/utils.py b/src/memos/graph_dbs/utils.py new file mode 100644 index 000000000..d4975075c --- /dev/null +++ b/src/memos/graph_dbs/utils.py @@ -0,0 +1,62 @@ +"""Shared utilities for graph database backends.""" + +from datetime import datetime +from typing import Any + +import numpy as np + + +def compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: + """Extract id, memory, and metadata from a node dict.""" + node_id = item["id"] + memory = item["memory"] + metadata = item.get("metadata", {}) + return node_id, memory, metadata + + +def prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: + """ + Ensure metadata has proper datetime fields and normalized types. + + - Fill `created_at` and `updated_at` if missing (in ISO 8601 format). + - Convert embedding to list of float if present. + """ + now = datetime.utcnow().isoformat() + + # Fill timestamps if missing + metadata.setdefault("created_at", now) + metadata.setdefault("updated_at", now) + + # Normalize embedding type + embedding = metadata.get("embedding") + if embedding and isinstance(embedding, list): + metadata["embedding"] = [float(x) for x in embedding] + + return metadata + + +def convert_to_vector(embedding_list): + """Convert an embedding list to PostgreSQL vector string format.""" + if not embedding_list: + return None + if isinstance(embedding_list, np.ndarray): + embedding_list = embedding_list.tolist() + return "[" + ",".join(str(float(x)) for x in embedding_list) + "]" + + +def detect_embedding_field(embedding_list): + """Detect the embedding field name based on vector dimension.""" + if not embedding_list: + return None + dim = len(embedding_list) + if dim == 1024: + return "embedding" + return None + + +def clean_properties(props): + """Remove vector fields from properties dict.""" + vector_keys = {"embedding", "embedding_1024", "embedding_3072", "embedding_768"} + if not isinstance(props, dict): + return {} + return {k: v for k, v in props.items() if k not in vector_keys} diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index e38318a64..45aa0a4da 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -756,7 +756,7 @@ def filter_fault_update(self, operations: list[dict]): for judge in all_judge: valid_update = None if judge["judgement"] == "UPDATE_APPROVED": - valid_update = id2op.get(judge["id"], None) + valid_update = id2op.get(judge["id"]) if valid_update: valid_updates.append(valid_update) @@ -924,7 +924,7 @@ def process_keyword_replace( ) must_part = f"{' & '.join(queries)}" if len(queries) > 1 else queries[0] - retrieved_ids = self.graph_store.seach_by_keywords_tfidf( + retrieved_ids = self.graph_store.search_by_keywords_tfidf( [must_part], user_name=user_name, filter=filter_dict ) if len(retrieved_ids) < 1: @@ -932,7 +932,7 @@ def process_keyword_replace( queries, top_k=100, user_name=user_name, filter=filter_dict ) else: - retrieved_ids = self.graph_store.seach_by_keywords_like( + retrieved_ids = self.graph_store.search_by_keywords_like( f"%{original_word}%", user_name=user_name, filter=filter_dict ) diff --git a/src/memos/mem_os/core.py 
b/src/memos/mem_os/core.py index 22cd0e9cb..0397411f0 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -629,6 +629,8 @@ def search_textual_memory(cube_id, cube): search_filter=search_filter, ) search_time_end = time.time() + print(f"🔍 [SEARCH_DEBUG] cube_id={cube_id}, found {len(memories)} memories", flush=True) + logger.warning(f"[SEARCH_DEBUG] cube_id={cube_id}, memories_count={len(memories)}, first_3_ids={[m.id for m in memories[:3]]}") logger.info( f"🧠 [Memory] Searched memories from {cube_id}:\n{self._str_memories(memories)}\n" ) diff --git a/src/memos/mem_os/utils/default_config.py b/src/memos/mem_os/utils/default_config.py index edb7875d4..8a063db9f 100644 --- a/src/memos/mem_os/utils/default_config.py +++ b/src/memos/mem_os/utils/default_config.py @@ -3,12 +3,15 @@ Provides simplified configuration generation for users. """ +import logging from typing import Literal from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig from memos.mem_cube.general import GeneralMemCube +logger = logging.getLogger(__name__) + def get_default_config( openai_api_key: str, @@ -116,20 +119,9 @@ def get_default_config( }, } - # Add activation memory if enabled - if config_dict.get("enable_activation_memory", False): - config_dict["act_mem"] = { - "backend": "kv_cache", - "config": { - "memory_filename": kwargs.get( - "activation_memory_filename", "activation_memory.pickle" - ), - "extractor_llm": { - "backend": "openai", - "config": openai_config, - }, - }, - } + # Note: act_mem configuration belongs in MemCube config (get_default_cube_config), + # not in MOSConfig which doesn't have an act_mem field. + # The enable_activation_memory flag above is sufficient for MOSConfig. return MOSConfig(**config_dict) @@ -180,38 +172,53 @@ def get_default_cube_config( # Configure text memory based on type if text_mem_type == "tree_text": - # Tree text memory requires Neo4j configuration - # NOTE: Neo4j Community Edition does NOT support multiple databases. - # It only has one default database named 'neo4j'. - # If you are using Community Edition: - # 1. Set 'use_multi_db' to False (default) - # 2. Set 'db_name' to 'neo4j' (default) - # 3. Set 'auto_create' to False to avoid 'CREATE DATABASE' permission errors. 
- db_name = f"memos{user_id.replace('-', '').replace('_', '')}" - if not kwargs.get("use_multi_db", False): - db_name = kwargs.get("neo4j_db_name", "neo4j") - - neo4j_config = { - "uri": kwargs.get("neo4j_uri", "bolt://localhost:7687"), - "user": kwargs.get("neo4j_user", "neo4j"), - "db_name": db_name, - "password": kwargs.get("neo4j_password", "12345678"), - "auto_create": kwargs.get("neo4j_auto_create", True), - "use_multi_db": kwargs.get("use_multi_db", False), - "embedding_dimension": kwargs.get("embedding_dimension", 3072), - } - if not kwargs.get("use_multi_db", False): - neo4j_config["user_name"] = f"memos{user_id.replace('-', '').replace('_', '')}" + graph_db_backend = kwargs.get("graph_db_backend", "neo4j").lower() + + if graph_db_backend in ("polardb", "postgres"): + # PolarDB (Postgres + Apache AGE) configuration + user_name = f"memos{user_id.replace('-', '').replace('_', '')}" + graph_db_config = { + "backend": "polardb", + "config": { + "host": kwargs.get("polar_db_host", "localhost"), + "port": int(kwargs.get("polar_db_port", 5432)), + "user": kwargs.get("polar_db_user", "postgres"), + "password": kwargs.get("polar_db_password", ""), + "db_name": kwargs.get("polar_db_name", "memos"), + "user_name": user_name, + "use_multi_db": kwargs.get("use_multi_db", False), + "auto_create": kwargs.get("neo4j_auto_create", True), + "embedding_dimension": int(kwargs.get("embedding_dimension", 1024)), + }, + } + else: + # Neo4j configuration (default) + db_name = f"memos{user_id.replace('-', '').replace('_', '')}" + if not kwargs.get("use_multi_db", False): + db_name = kwargs.get("neo4j_db_name", "neo4j") + + neo4j_config = { + "uri": kwargs.get("neo4j_uri", "bolt://localhost:7687"), + "user": kwargs.get("neo4j_user", "neo4j"), + "db_name": db_name, + "password": kwargs.get("neo4j_password", "12345678"), + "auto_create": kwargs.get("neo4j_auto_create", True), + "use_multi_db": kwargs.get("use_multi_db", False), + "embedding_dimension": int(kwargs.get("embedding_dimension", 3072)), + } + if not kwargs.get("use_multi_db", False): + neo4j_config["user_name"] = f"memos{user_id.replace('-', '').replace('_', '')}" + graph_db_config = { + "backend": "neo4j", + "config": neo4j_config, + } text_mem_config = { "backend": "tree_text", "config": { "extractor_llm": {"backend": "openai", "config": openai_config}, "dispatcher_llm": {"backend": "openai", "config": openai_config}, - "graph_db": { - "backend": "neo4j", - "config": neo4j_config, - }, + "graph_db": graph_db_config, "embedder": embedder_config, "reorganize": kwargs.get("enable_reorganize", False), }, @@ -231,27 +238,41 @@ def get_default_cube_config( "collection_name": kwargs.get("collection_name", f"{user_id}_collection"), "vector_dimension": kwargs.get("vector_dimension", 3072), "distance_metric": "cosine", + **({"host": kwargs["qdrant_host"]} if "qdrant_host" in kwargs else {}), + **({"port": kwargs["qdrant_port"]} if "qdrant_port" in kwargs else {}), }, }, "embedder": embedder_config, }, } - # Configure activation memory if enabled + # Configure activation memory if enabled. + # KV cache activation memory requires a local HuggingFace/vLLM model (it + # extracts internal attention KV tensors via build_kv_cache), so it cannot + # work with remote API backends like OpenAI or Gemini. + # Only create act_mem when activation_memory_backend is explicitly provided. 
act_mem_config = {} if kwargs.get("enable_activation_memory", False): - act_mem_config = { - "backend": "kv_cache", - "config": { - "memory_filename": kwargs.get( - "activation_memory_filename", "activation_memory.pickle" - ), - "extractor_llm": { - "backend": "openai", - "config": openai_config, + extractor_backend = kwargs.get("activation_memory_backend") + if extractor_backend in ("huggingface", "huggingface_singleton", "vllm"): + act_mem_config = { + "backend": "kv_cache", + "config": { + "memory_filename": kwargs.get( + "activation_memory_filename", "activation_memory.pickle" + ), + "extractor_llm": { + "backend": extractor_backend, + "config": kwargs.get("activation_memory_llm_config", {}), + }, }, - }, - } + } + else: + logger.info( + "Activation memory (kv_cache) requires a local model backend " + "(huggingface/vllm) via activation_memory_backend kwarg. " + "Skipping act_mem in MemCube config." + ) # Create MemCube configuration cube_config_dict = { diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index fbc704d0b..88fc500a2 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -110,7 +110,7 @@ def _handle_url(self, url_str: str, filename: str) -> tuple[str, str | None, boo return "", temp_file.name, False except Exception as e: logger.error(f"[FileContentParser] URL processing error: {e}") - return f"[File URL download failed: {url_str}]", None + return f"[File URL download failed: {url_str}]", None, False def _is_base64(self, data: str) -> bool: """Quick heuristic to check base64-like string.""" @@ -412,7 +412,6 @@ def parse_fast( # Extract file parameters (all are optional) file_data = file_info.get("file_data", "") file_id = file_info.get("file_id", "") - filename = file_info.get("filename", "") file_url_flag = False # Build content string based on available information content_parts = [] @@ -433,25 +432,12 @@ def parse_fast( # Check if it looks like a URL elif file_data.startswith(("http://", "https://", "file://")): file_url_flag = True - content_parts.append(f"[File URL: {file_data}]") else: # TODO: split into multiple memory items content_parts.append(file_data) else: content_parts.append(f"[File Data: {type(file_data).__name__}]") - # Priority 2: If file_id is provided, reference it - if file_id: - content_parts.append(f"[File ID: {file_id}]") - - # Priority 3: If filename is provided, include it - if filename: - content_parts.append(f"[Filename: {filename}]") - - # If no content can be extracted, create a placeholder - if not content_parts: - content_parts.append("[File: unknown]") - # Combine content parts content = " ".join(content_parts) diff --git a/src/memos/mem_reader/read_multi_modal/utils.py b/src/memos/mem_reader/read_multi_modal/utils.py index a6d910e54..40e725308 100644 --- a/src/memos/mem_reader/read_multi_modal/utils.py +++ b/src/memos/mem_reader/read_multi_modal/utils.py @@ -346,6 +346,8 @@ def detect_lang(text): r"\b(user|assistant|query|answer)\s*:", "", cleaned_text, flags=re.IGNORECASE ) cleaned_text = re.sub(r"\[[\d\-:\s]+\]", "", cleaned_text) + # remove URLs to prevent dilution of Chinese character ratio + cleaned_text = re.sub(r'https?://[^\s<>"{}|\\^`\[\]]+', "", cleaned_text) # extract chinese characters chinese_pattern = r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\uf900-\ufaff]" diff --git 
a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index 903088a4c..b103acf3a 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -61,9 +61,11 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: "neo4j": APIConfig.get_neo4j_config(user_id=user_id), "nebular": APIConfig.get_nebular_config(user_id=user_id), "polardb": APIConfig.get_polardb_config(user_id=user_id), + "postgres": APIConfig.get_postgres_config(user_id=user_id), } - graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() + # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars + graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower() return GraphDBConfigFactory.model_validate( { "backend": graph_db_backend, diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 46770758d..63476c7cc 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -45,6 +45,43 @@ class SourceMessage(BaseModel): model_config = ConfigDict(extra="allow") +class ArchivedTextualMemory(BaseModel): + """ + This is a light-weighted class for storing archived versions of memories. + + When an existing memory item needs to be updated due to conflict/duplicate with new memory contents, + its previous contents will be preserved, in 2 places: + 1. ArchivedTextualMemory, which only contains minimal information, like memory content and create time, + stored in the 'history' field of the original node. + 2. A new memory node, storing full original information including sources and embedding, + and referenced by 'archived_memory_id'. + """ + + version: int = Field( + default=1, + description="The version of the archived memory content. Will be compared to the version of the active memory item(in Metadata)", + ) + is_fast: bool = Field( + default=False, + description="Whether this archived memory was created in fast mode, thus raw.", + ) + memory: str | None = Field( + default_factory=lambda: "", description="The content of the archived version of the memory." + ) + update_type: Literal["conflict", "duplicate", "extract", "unrelated"] = Field( + default="unrelated", + description="The type of the memory (e.g., `conflict`, `duplicate`, `extract`, `unrelated`).", + ) + archived_memory_id: str | None = Field( + default=None, + description="Link to a memory node with status='archived', storing full original information, including sources and embedding.", + ) + created_at: str | None = Field( + default_factory=lambda: datetime.now().isoformat(), + description="The time the memory was created.", + ) + + class TextualMemoryMetadata(BaseModel): """Metadata for a memory item. @@ -60,9 +97,29 @@ class TextualMemoryMetadata(BaseModel): default=None, description="The ID of the session during which the memory was created. 
Useful for tracking context in conversations.", ) - status: Literal["activated", "archived", "deleted"] | None = Field( + status: Literal["activated", "resolving", "archived", "deleted"] | None = Field( default="activated", - description="The status of the memory, e.g., 'activated', 'archived', 'deleted'.", + description="The status of the memory, e.g., 'activated', 'resolving' (being updated with conflicting/duplicating new memories), 'archived', 'deleted'.", + ) + is_fast: bool | None = Field( + default=None, + description="Whether the memory was created in fast mode, i.e., it carries raw memory contents that have not yet been refined by an LLM.", + ) + evolve_to: list[str] | None = Field( + default_factory=list, + description="Only valid if the node was once a (raw) fast node; records which new memory nodes it evolves into after LLM extraction.", + ) + version: int | None = Field( + default=None, + description="The version of the memory. Will be incremented when the memory is updated.", + ) + history: list[ArchivedTextualMemory] | None = Field( + default_factory=list, + description="The archived versions of the memory, preserving only the core information of each version.", + ) + working_binding: str | None = Field( + default=None, + description="The working-memory ID binding of the (fast) memory.", ) type: str | None = Field(default=None) key: str | None = Field(default=None, description="Memory key or title.") diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 2df819f3a..a2ee15003 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -15,7 +15,6 @@ if TYPE_CHECKING: from memos.embedders.factory import OllamaEmbedder - from memos.graph_dbs.factory import Neo4jGraphDB from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM @@ -47,7 +46,7 @@ def __init__( self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = llm self.dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM = llm self.embedder: OllamaEmbedder = embedder - self.graph_store: Neo4jGraphDB = graph_db + self.graph_store: BaseGraphDB = graph_db self.search_strategy = config.search_strategy self.bm25_retriever = ( EnhancedBM25() diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index b556db5d7..f64058259 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -10,7 +10,8 @@ from memos.configs.memory import TreeTextMemoryConfig from memos.configs.reranker import RerankerConfigFactory from memos.embedders.factory import EmbedderFactory, OllamaEmbedder -from memos.graph_dbs.factory import GraphStoreFactory, Neo4jGraphDB +from memos.graph_dbs.base import BaseGraphDB +from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM from memos.log import get_logger from memos.memories.textual.base import BaseTextMemory @@ -47,7 +48,7 @@ def __init__(self, config: TreeTextMemoryConfig): config.dispatcher_llm ) self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder) - self.graph_store: Neo4jGraphDB = GraphStoreFactory.from_config(config.graph_db) + self.graph_store: BaseGraphDB = GraphStoreFactory.from_config(config.graph_db) self.search_strategy = config.search_strategy self.bm25_retriever = ( @@ -166,6 +167,7 @@ def search( dedup: str | None = None, **kwargs, ) -> list[TextualMemoryItem]: + print(f"🌲 [TREE.SEARCH] query='{query}', mode={mode}, user_name={user_name}, kwargs={kwargs}",
flush=True) """Search for memories based on a query. User query -> TaskGoalParser -> MemoryPathResolver -> GraphMemoryRetriever -> MemoryReranker -> MemoryReasoner -> Final output diff --git a/src/memos/memories/textual/tree_text_memory/organize/handler.py b/src/memos/memories/textual/tree_text_memory/organize/handler.py index 595cf099c..42f06c084 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/handler.py +++ b/src/memos/memories/textual/tree_text_memory/organize/handler.py @@ -6,7 +6,7 @@ from dateutil import parser from memos.embedders.base import BaseEmbedder -from memos.graph_dbs.neo4j import Neo4jGraphDB +from memos.graph_dbs.base import BaseGraphDB from memos.llms.base import BaseLLM from memos.log import get_logger from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata @@ -22,7 +22,7 @@ class NodeHandler: EMBEDDING_THRESHOLD: float = 0.8 # Threshold for embedding similarity to consider conflict - def __init__(self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: BaseEmbedder): + def __init__(self, graph_store: BaseGraphDB, llm: BaseLLM, embedder: BaseEmbedder): self.graph_store = graph_store self.llm = llm self.embedder = embedder diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py new file mode 100644 index 000000000..1afdc9281 --- /dev/null +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -0,0 +1,166 @@ +import logging + +from typing import Literal + +from memos.context.context import ContextThreadPoolExecutor +from memos.extras.nli_model.client import NLIClient +from memos.extras.nli_model.types import NLIResult +from memos.graph_dbs.base import BaseGraphDB +from memos.memories.textual.item import ArchivedTextualMemory, TextualMemoryItem + + +logger = logging.getLogger(__name__) + +CONFLICT_MEMORY_TITLE = "[possibly conflicting memories]" +DUPLICATE_MEMORY_TITLE = "[possibly duplicate memories]" + + +def _append_related_content( + new_item: TextualMemoryItem, duplicates: list[str], conflicts: list[str] +) -> None: + """ + Append duplicate and conflict memory contents to the new item's memory text, + truncated to avoid excessive length. + """ + max_per_item_len = 200 + max_section_len = 1000 + + def _format_section(title: str, items: list[str]) -> str: + if not items: + return "" + + section_content = "" + for mem in items: + # Truncate individual item + snippet = mem[:max_per_item_len] + "..." if len(mem) > max_per_item_len else mem + # Check total section length + if len(section_content) + len(snippet) + 5 > max_section_len: + section_content += "\n- ... (more items truncated)" + break + section_content += f"\n- {snippet}" + + return f"\n\n{title}:{section_content}" + + append_text = "" + append_text += _format_section(CONFLICT_MEMORY_TITLE, conflicts) + append_text += _format_section(DUPLICATE_MEMORY_TITLE, duplicates) + + if append_text: + new_item.memory += append_text + + +def _detach_related_content(new_item: TextualMemoryItem) -> None: + """ + Detach duplicate and conflict memory contents from the new item's memory text. 
+ """ + markers = [f"\n\n{CONFLICT_MEMORY_TITLE}:", f"\n\n{DUPLICATE_MEMORY_TITLE}:"] + + cut_index = -1 + for marker in markers: + idx = new_item.memory.find(marker) + if idx != -1 and (cut_index == -1 or idx < cut_index): + cut_index = idx + + if cut_index != -1: + new_item.memory = new_item.memory[:cut_index] + + return + + +class MemoryHistoryManager: + def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None: + """ + Initialize the MemoryHistoryManager. + + Args: + nli_client: NLIClient for conflict/duplicate detection. + graph_db: GraphDB instance for marking operations during history management. + """ + self.nli_client = nli_client + self.graph_db = graph_db + + def resolve_history_via_nli( + self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem] + ) -> list[TextualMemoryItem]: + """ + Detect relationships (Duplicate/Conflict) between the new item and related items using NLI, + and attach them as history to the new fast item. + + Args: + new_item: The new memory item being added. + related_items: Existing memory items that might be related. + + Returns: + List of duplicate or conflicting memory items judged by the NLI service. + """ + if not related_items: + return [] + + # 1. Call NLI + nli_results = self.nli_client.compare_one_to_many( + new_item.memory, [r.memory for r in related_items] + ) + + # 2. Process results and attach to history + duplicate_memories = [] + conflict_memories = [] + + for r_item, nli_res in zip(related_items, nli_results, strict=False): + if nli_res == NLIResult.DUPLICATE: + update_type = "duplicate" + duplicate_memories.append(r_item.memory) + elif nli_res == NLIResult.CONTRADICTION: + update_type = "conflict" + conflict_memories.append(r_item.memory) + else: + update_type = "unrelated" + + # Safely get created_at, fallback to updated_at + created_at = getattr(r_item.metadata, "created_at", None) or r_item.metadata.updated_at + + archived = ArchivedTextualMemory( + version=r_item.metadata.version or 1, + is_fast=r_item.metadata.is_fast or False, + memory=r_item.memory, + update_type=update_type, + archived_memory_id=r_item.id, + created_at=created_at, + ) + new_item.metadata.history.append(archived) + logger.info( + f"[MemoryHistoryManager] Archived related memory {r_item.id} as {update_type} for new item {new_item.id}" + ) + + # 3. Concat duplicate/conflict memories to new_item.memory + # We will mark those old memories as invisible during fine processing, this op helps to avoid information loss. + _append_related_content(new_item, duplicate_memories, conflict_memories) + + return duplicate_memories + conflict_memories + + def mark_memory_status( + self, + memory_items: list[TextualMemoryItem], + status: Literal["activated", "resolving", "archived", "deleted"], + ) -> None: + """ + Support status marking operations during history management. Common usages are: + 1. Mark conflict/duplicate old memories' status as "resolving", + to make them invisible to /search api, but still visible for PreUpdateRetriever. + 2. Mark resolved memories' status as "activated", to restore their visibility. + """ + # Execute the actual marking operation - in db. 
+ with ContextThreadPoolExecutor() as executor: + futures = [] + for mem in memory_items: + futures.append( + executor.submit( + self.graph_db.update_node, + id=mem.id, + fields={"status": status}, + ) + ) + + # Wait for all tasks to complete and raise any exceptions + for future in futures: + future.result() + return diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 5e9c74f61..80d4bb6f9 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -7,7 +7,7 @@ from memos.context.context import ContextThreadPoolExecutor from memos.embedders.factory import OllamaEmbedder -from memos.graph_dbs.neo4j import Neo4jGraphDB +from memos.graph_dbs.base import BaseGraphDB from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.log import get_logger from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata @@ -54,7 +54,7 @@ def extract_working_binding_ids(mem_items: list[TextualMemoryItem]) -> set[str]: class MemoryManager: def __init__( self, - graph_store: Neo4jGraphDB, + graph_store: BaseGraphDB, embedder: OllamaEmbedder, llm: OpenAILLM | OllamaLLM | AzureLLM, memory_size: dict | None = None, diff --git a/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py b/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py index ad9dcb2b8..d19f26bd4 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +++ b/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py @@ -3,7 +3,7 @@ from memos.embedders.factory import OllamaEmbedder from memos.graph_dbs.item import GraphDBNode -from memos.graph_dbs.neo4j import Neo4jGraphDB +from memos.graph_dbs.base import BaseGraphDB from memos.llms.base import BaseLLM from memos.log import get_logger from memos.memories.textual.item import TreeNodeTextualMemoryMetadata @@ -18,7 +18,7 @@ class RelationAndReasoningDetector: - def __init__(self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: OllamaEmbedder): + def __init__(self, graph_store: BaseGraphDB, llm: BaseLLM, embedder: OllamaEmbedder): self.graph_store = graph_store self.llm = llm self.embedder = embedder diff --git a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py index ea06a7c60..656c6d5e4 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py +++ b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py @@ -13,7 +13,7 @@ from memos.dependency import require_python_package from memos.embedders.factory import OllamaEmbedder from memos.graph_dbs.item import GraphDBEdge, GraphDBNode -from memos.graph_dbs.neo4j import Neo4jGraphDB +from memos.graph_dbs.base import BaseGraphDB from memos.llms.base import BaseLLM from memos.log import get_logger from memos.memories.textual.item import SourceMessage, TreeNodeTextualMemoryMetadata @@ -78,7 +78,7 @@ def extract_first_to_last_brace(text: str): class GraphStructureReorganizer: def __init__( - self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: OllamaEmbedder, is_reorganize: bool + self, graph_store: BaseGraphDB, llm: BaseLLM, embedder: OllamaEmbedder, is_reorganize: bool ): self.queue = PriorityQueue() # Min-heap self.graph_store = graph_store diff --git 
a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py index e58ebcdd1..de89a909c 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py @@ -4,7 +4,7 @@ from typing import Any from memos.embedders.factory import OllamaEmbedder -from memos.graph_dbs.factory import Neo4jGraphDB +from memos.graph_dbs.base import BaseGraphDB from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.log import get_logger from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata @@ -26,7 +26,7 @@ class AdvancedSearcher(Searcher): def __init__( self, dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM, - graph_store: Neo4jGraphDB, + graph_store: BaseGraphDB, embedder: OllamaEmbedder, reranker: BaseReranker, bm25_retriever: EnhancedBM25 | None = None, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py index a5fc7e049..cb77d2243 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py @@ -163,14 +163,14 @@ def keyword_search( results = [] - # 2. Try seach_by_keywords_tfidf (PolarDB specific) - if hasattr(self.graph_db, "seach_by_keywords_tfidf"): + # 2. Try search_by_keywords_tfidf (PolarDB specific) + if hasattr(self.graph_db, "search_by_keywords_tfidf"): try: - results = self.graph_db.seach_by_keywords_tfidf( + results = self.graph_db.search_by_keywords_tfidf( query_words=keywords, user_name=user_name, filter=search_filter ) except Exception as e: - logger.warning(f"[PreUpdateRetriever] seach_by_keywords_tfidf failed: {e}") + logger.warning(f"[PreUpdateRetriever] search_by_keywords_tfidf failed: {e}") # 3. Fallback to search_by_fulltext if not results and hasattr(self.graph_db, "search_by_fulltext"): diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index c9f2ec156..f19dc192b 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -2,7 +2,7 @@ from memos.context.context import ContextThreadPoolExecutor from memos.embedders.factory import OllamaEmbedder -from memos.graph_dbs.neo4j import Neo4jGraphDB +from memos.graph_dbs.base import BaseGraphDB from memos.log import get_logger from memos.memories.textual.item import TextualMemoryItem from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 @@ -19,7 +19,7 @@ class GraphMemoryRetriever: def __init__( self, - graph_store: Neo4jGraphDB, + graph_store: BaseGraphDB, embedder: OllamaEmbedder, bm25_retriever: EnhancedBM25 | None = None, include_embedding: bool = False, @@ -335,7 +335,9 @@ def _vector_recall( Perform vector-based similarity retrieval using query embedding. # TODO: tackle with post-filter and pre-filter(5.18+) better. 
""" + logger.warning(f"[_vector_recall_DEBUG] Called with {len(query_embedding) if query_embedding else 0} embeddings, memory_scope: {memory_scope}, top_k: {top_k}") if not query_embedding: + logger.warning(f"[_vector_recall_DEBUG] Empty query_embedding, returning empty list") return [] def search_single(vec, search_priority=None, search_filter=None): @@ -385,24 +387,52 @@ def search_path_b(): path_a_future = executor.submit(search_path_a) path_b_future = executor.submit(search_path_b) - all_hits.extend(path_a_future.result()) - all_hits.extend(path_b_future.result()) + path_a_results = path_a_future.result() + path_b_results = path_b_future.result() + logger.warning(f"[_vector_recall_DEBUG] Path A returned {len(path_a_results)} hits") + logger.warning(f"[_vector_recall_DEBUG] Path B returned {len(path_b_results)} hits") + all_hits.extend(path_a_results) + all_hits.extend(path_b_results) if not all_hits: + logger.warning(f"[_vector_recall_DEBUG] No hits found, returning empty list") return [] - # merge and deduplicate - unique_ids = {r["id"] for r in all_hits if r.get("id")} + # merge and deduplicate, keeping highest score per ID + id_to_score = {} + for r in all_hits: + rid = r.get("id") + if rid: + score = r.get("score", 0.0) + if rid not in id_to_score or score > id_to_score[rid]: + id_to_score[rid] = score + + # Sort IDs by score (descending) to preserve ranking + sorted_ids = sorted(id_to_score.keys(), key=lambda x: id_to_score[x], reverse=True) + node_dicts = ( self.graph_store.get_nodes( - list(unique_ids), + sorted_ids, include_embedding=self.include_embedding, cube_name=cube_name, user_name=user_name, ) or [] ) - return [TextualMemoryItem.from_dict(n) for n in node_dicts] + + # Restore score-based order and inject scores into metadata + id_to_node = {n.get("id"): n for n in node_dicts} + ordered_nodes = [] + for rid in sorted_ids: + if rid in id_to_node: + node = id_to_node[rid] + # Inject similarity score as relativity + if "metadata" not in node: + node["metadata"] = {} + node["metadata"]["relativity"] = id_to_score.get(rid, 0.0) + ordered_nodes.append(node) + + return [TextualMemoryItem.from_dict(n) for n in ordered_nodes] def _bm25_recall( self, @@ -484,15 +514,38 @@ def _fulltext_recall( if not all_hits: return [] - # merge and deduplicate - unique_ids = {r["id"] for r in all_hits if r.get("id")} + # merge and deduplicate, keeping highest score per ID + id_to_score = {} + for r in all_hits: + rid = r.get("id") + if rid: + score = r.get("score", 0.0) + if rid not in id_to_score or score > id_to_score[rid]: + id_to_score[rid] = score + + # Sort IDs by score (descending) to preserve ranking + sorted_ids = sorted(id_to_score.keys(), key=lambda x: id_to_score[x], reverse=True) + node_dicts = ( self.graph_store.get_nodes( - list(unique_ids), + sorted_ids, include_embedding=self.include_embedding, cube_name=cube_name, user_name=user_name, ) or [] ) - return [TextualMemoryItem.from_dict(n) for n in node_dicts] + + # Restore score-based order and inject scores into metadata + id_to_node = {n.get("id"): n for n in node_dicts} + ordered_nodes = [] + for rid in sorted_ids: + if rid in id_to_node: + node = id_to_node[rid] + # Inject similarity score as relativity + if "metadata" not in node: + node["metadata"] = {} + node["metadata"]["relativity"] = id_to_score.get(rid, 0.0) + ordered_nodes.append(node) + + return [TextualMemoryItem.from_dict(n) for n in ordered_nodes] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py 
b/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py index 861343e20..b8ab813dc 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py @@ -78,7 +78,11 @@ def rerank( embeddings = [item.metadata.embedding for item in items_with_embeddings] if not embeddings: - return [(item, 0.5) for item in graph_results[:top_k]] + # Use relativity from recall stage if available, otherwise default to 0.5 + return [ + (item, getattr(item.metadata, "relativity", None) or 0.5) + for item in graph_results[:top_k] + ] # Step 2: Compute cosine similarities similarity_scores = batch_cosine_similarity(query_embedding, embeddings) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index dcd4e1fba..7f4dcd43a 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -2,7 +2,7 @@ from memos.context.context import ContextThreadPoolExecutor from memos.embedders.factory import OllamaEmbedder -from memos.graph_dbs.factory import Neo4jGraphDB +from memos.graph_dbs.base import BaseGraphDB from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.log import get_logger from memos.memories.textual.item import SearchedTreeNodeTextualMemoryMetadata, TextualMemoryItem @@ -39,7 +39,7 @@ class Searcher: def __init__( self, dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM, - graph_store: Neo4jGraphDB, + graph_store: BaseGraphDB, embedder: OllamaEmbedder, reranker: BaseReranker, bm25_retriever: EnhancedBM25 | None = None, @@ -85,6 +85,7 @@ def retrieve( skill_mem_top_k: int = 3, **kwargs, ) -> list[tuple[TextualMemoryItem, float]]: + print(f"🔍 [SEARCHER.RETRIEVE] query='{query}', mode={mode}, kwargs={kwargs}", flush=True) logger.info( f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}" ) @@ -301,10 +302,18 @@ def _parse_task( **kwargs, ) + # DIAGNOSTIC: Log parsed goal + print(f"🔍 [PARSE_TASK] memories={parsed_goal.memories}, rephrased={parsed_goal.rephrased_query}", flush=True) + logger.warning(f"[_parse_task_DEBUG] Parsed goal: memories={parsed_goal.memories}, rephrased_query={parsed_goal.rephrased_query}") + logger.warning(f"[_parse_task_DEBUG] Parsed goal keys={parsed_goal.keys}, tags={parsed_goal.tags}, internet_search={parsed_goal.internet_search}") + query = parsed_goal.rephrased_query or query # if goal has extra memories, embed them too if parsed_goal.memories: query_embedding = self.embedder.embed(list({query, *parsed_goal.memories})) + logger.warning(f"[_parse_task_DEBUG] Generated {len(query_embedding)} embeddings from parsed_goal.memories") + else: + logger.warning(f"[_parse_task_DEBUG] parsed_goal.memories is EMPTY - query_embedding will be None!") return parsed_goal, query_embedding, context, query @@ -751,16 +760,32 @@ def _retrieve_simple( query_words = list(set(query_words))[: top_k * 3] query_words = [query, *query_words] logger.info(f"[SIMPLESEARCH] Query words: {query_words}") - query_embeddings = self.embedder.embed(query_words) - items = self.graph_retriever.retrieve_from_mixed( - top_k=top_k * 2, - memory_scope=None, - query_embedding=query_embeddings, - search_filter=search_filter, - user_name=user_name, - ) - logger.info(f"[SIMPLESEARCH] Items count: {len(items)}") + # DIAGNOSTIC: Log embedder config + logger.warning(f"[SIMPLESEARCH_DEBUG] Embedder type: 
{type(self.embedder).__name__}") + logger.warning(f"[SIMPLESEARCH_DEBUG] Embedder config: {getattr(self.embedder, 'config', 'No config attr')}") + + try: + query_embeddings = self.embedder.embed(query_words) + logger.warning(f"[SIMPLESEARCH_DEBUG] Successfully generated {len(query_embeddings)} embeddings, dims: {len(query_embeddings[0]) if query_embeddings else 'N/A'}") + except Exception as e: + logger.error(f"[SIMPLESEARCH_DEBUG] EMBEDDER FAILED: {type(e).__name__}: {e}", exc_info=True) + return [] + + logger.warning(f"[SIMPLESEARCH_DEBUG] Calling retrieve_from_mixed with {len(query_embeddings)} embeddings") + try: + items = self.graph_retriever.retrieve_from_mixed( + top_k=top_k * 2, + memory_scope=None, + query_embedding=query_embeddings, + search_filter=search_filter, + user_name=user_name, + ) + logger.info(f"[SIMPLESEARCH] Items count: {len(items)}") + logger.warning(f"[SIMPLESEARCH_DEBUG] Retrieved items: {[item.id for item in items] if items else 'NONE'}") + except Exception as e: + logger.error(f"[SIMPLESEARCH_DEBUG] retrieve_from_mixed FAILED: {type(e).__name__}: {e}", exc_info=True) + return [] documents = [getattr(item, "memory", "") for item in items] if not documents: return [] diff --git a/src/memos/types/openai_chat_completion_types/__init__.py b/src/memos/types/openai_chat_completion_types/__init__.py index 4a08a9f24..025e75360 100644 --- a/src/memos/types/openai_chat_completion_types/__init__.py +++ b/src/memos/types/openai_chat_completion_types/__init__.py @@ -1,4 +1,4 @@ -# ruff: noqa: F403, F401 +# ruff: noqa: F403 from .chat_completion_assistant_message_param import * from .chat_completion_content_part_image_param import * diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py index 3c5638788..f28796c2d 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py @@ -1,4 +1,4 @@ -# ruff: noqa: TC001, TC003 +# ruff: noqa: TC001 from __future__ import annotations diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py index ea2101229..13a9a89af 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py @@ -1,4 +1,4 @@ -# ruff: noqa: TC001, TC003 +# ruff: noqa: TC001 from __future__ import annotations diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py index 99c845d11..f76f2b862 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py @@ -1,4 +1,4 @@ -# ruff: noqa: TC001, TC003 +# ruff: noqa: TC001 from __future__ import annotations diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py index 8c004f340..b5bee9842 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py +++ 
b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py @@ -1,4 +1,4 @@ -# ruff: noqa: TC001, TC003 +# ruff: noqa: TC001 from __future__ import annotations diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py new file mode 100644 index 000000000..46cf3a1f6 --- /dev/null +++ b/tests/memories/textual/test_history_manager.py @@ -0,0 +1,137 @@ +import uuid + +from unittest.mock import MagicMock + +import pytest + +from memos.extras.nli_model.client import NLIClient +from memos.extras.nli_model.types import NLIResult +from memos.graph_dbs.base import BaseGraphDB +from memos.memories.textual.item import ( + TextualMemoryItem, + TextualMemoryMetadata, +) +from memos.memories.textual.tree_text_memory.organize.history_manager import ( + MemoryHistoryManager, + _append_related_content, + _detach_related_content, +) + + +@pytest.fixture +def mock_nli_client(): + client = MagicMock(spec=NLIClient) + return client + + +@pytest.fixture +def mock_graph_db(): + return MagicMock(spec=BaseGraphDB) + + +@pytest.fixture +def history_manager(mock_nli_client, mock_graph_db): + return MemoryHistoryManager(nli_client=mock_nli_client, graph_db=mock_graph_db) + + +def test_detach_related_content(): + original_memory = "This is the original memory content." + item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) + + duplicates = ["Duplicate 1", "Duplicate 2"] + conflicts = ["Conflict 1", "Conflict 2"] + + # 1. Append content + _append_related_content(item, duplicates, conflicts) + + # Verify content was appended + assert item.memory != original_memory + assert "[possibly conflicting memories]" in item.memory + assert "[possibly duplicate memories]" in item.memory + assert "Duplicate 1" in item.memory + assert "Conflict 1" in item.memory + + # 2. Detach content + _detach_related_content(item) + + # 3. Verify content is restored + assert item.memory == original_memory + + +def test_detach_only_conflicts(): + original_memory = "Original memory." + item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) + + duplicates = [] + conflicts = ["Conflict A"] + + _append_related_content(item, duplicates, conflicts) + assert "Conflict A" in item.memory + assert "Duplicate" not in item.memory + + _detach_related_content(item) + assert item.memory == original_memory + + +def test_detach_only_duplicates(): + original_memory = "Original memory." + item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) + + duplicates = ["Duplicate A"] + conflicts = [] + + _append_related_content(item, duplicates, conflicts) + assert "Duplicate A" in item.memory + assert "Conflict" not in item.memory + + _detach_related_content(item) + assert item.memory == original_memory + + +def test_truncation(history_manager, mock_nli_client): + # Setup + new_item = TextualMemoryItem(memory="Test") + long_memory = "A" * 300 + related_item = TextualMemoryItem(memory=long_memory) + + mock_nli_client.compare_one_to_many.return_value = [NLIResult.DUPLICATE] + + # Action + history_manager.resolve_history_via_nli(new_item, [related_item]) + + # Assert + assert "possibly duplicate memories" in new_item.memory + assert "..." 
in new_item.memory # Should be truncated + assert len(new_item.memory) < 1000 # Ensure reasonable length + + +def test_empty_related_items(history_manager, mock_nli_client): + new_item = TextualMemoryItem(memory="Test") + history_manager.resolve_history_via_nli(new_item, []) + + mock_nli_client.compare_one_to_many.assert_not_called() + assert new_item.metadata.history is None or len(new_item.metadata.history) == 0 + + +def test_mark_memory_status(history_manager, mock_graph_db): + # Setup + id1 = uuid.uuid4().hex + id2 = uuid.uuid4().hex + id3 = uuid.uuid4().hex + items = [ + TextualMemoryItem(memory="M1", id=id1), + TextualMemoryItem(memory="M2", id=id2), + TextualMemoryItem(memory="M3", id=id3), + ] + status = "resolving" + + # Action + history_manager.mark_memory_status(items, status) + + # Assert + assert mock_graph_db.update_node.call_count == 3 + + # Verify we called it correctly + mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}) + mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}) + mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status})
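
For reference, a minimal usage sketch of the new `MemoryHistoryManager` (not part of the patch itself). It mirrors the mock fixtures from `tests/memories/textual/test_history_manager.py`: the NLI client and graph DB are `MagicMock`s, the memory texts and the `NLIResult.CONTRADICTION` outcome are illustrative assumptions, and it relies on the default `TextualMemoryMetadata` fields exercised by those tests. It only demonstrates the intended call order: archive via NLI, mark the old node as `resolving`, then detach the appended section after fine processing.

```python
from unittest.mock import MagicMock

from memos.extras.nli_model.client import NLIClient
from memos.extras.nli_model.types import NLIResult
from memos.graph_dbs.base import BaseGraphDB
from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata
from memos.memories.textual.tree_text_memory.organize.history_manager import (
    MemoryHistoryManager,
    _detach_related_content,
)

# Mock the NLI service and graph DB, exactly as the new test fixtures do.
nli_client = MagicMock(spec=NLIClient)
graph_db = MagicMock(spec=BaseGraphDB)
manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db)

# Hypothetical memory contents, chosen only for illustration.
new_item = TextualMemoryItem(memory="User now lives in Berlin.", metadata=TextualMemoryMetadata())
old_item = TextualMemoryItem(memory="User lives in Munich.", metadata=TextualMemoryMetadata())

# Pretend the NLI model flags the old memory as contradicting the new one.
nli_client.compare_one_to_many.return_value = [NLIResult.CONTRADICTION]

# 1) Archive the old version into new_item.metadata.history and append a
#    "[possibly conflicting memories]" section to new_item.memory.
related = manager.resolve_history_via_nli(new_item, [old_item])

# 2) Hide the old node from /search while it is being resolved
#    (graph_db.update_node is submitted once per item).
manager.mark_memory_status([old_item], "resolving")

# 3) After fine processing, strip the appended section again.
_detach_related_content(new_item)

assert related == ["User lives in Munich."]
assert new_item.metadata.history[0].update_type == "conflict"
assert new_item.memory == "User now lives in Berlin."
```

Marking the superseded node as `resolving` keeps it out of `/search` results while the PreUpdateRetriever can still see it, which is the stated rationale for the new status value.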