diff --git a/overlays/krolik/api/middleware/auth.py b/overlays/krolik/api/middleware/auth.py index 30349c9c4..15b217651 100644 --- a/overlays/krolik/api/middleware/auth.py +++ b/overlays/krolik/api/middleware/auth.py @@ -8,6 +8,7 @@ import hashlib import os import time + from typing import Any from fastapi import Depends, HTTPException, Request, Security @@ -15,6 +16,7 @@ import memos.log + logger = memos.log.get_logger(__name__) # API key header configuration @@ -149,10 +151,7 @@ def is_internal_request(request: Request) -> bool: # Check internal header (for container-to-container) internal_header = request.headers.get("X-Internal-Service") - if internal_header == os.getenv("INTERNAL_SERVICE_SECRET"): - return True - - return False + return internal_header == os.getenv("INTERNAL_SERVICE_SECRET") async def verify_api_key( @@ -245,8 +244,9 @@ def require_scope(required_scope: str): Usage: @router.post("/admin/keys", dependencies=[Depends(require_scope("admin"))]) """ + async def scope_checker( - auth: dict[str, Any] = Depends(verify_api_key), + auth: dict[str, Any] = Depends(verify_api_key), # noqa: B008 ) -> dict[str, Any]: scopes = auth.get("scopes", []) diff --git a/overlays/krolik/api/middleware/rate_limit.py b/overlays/krolik/api/middleware/rate_limit.py index 12ee84ef4..c547378ca 100644 --- a/overlays/krolik/api/middleware/rate_limit.py +++ b/overlays/krolik/api/middleware/rate_limit.py @@ -7,8 +7,10 @@ import os import time + from collections import defaultdict -from typing import Callable +from collections.abc import Callable +from typing import ClassVar from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request @@ -16,6 +18,7 @@ import memos.log + logger = memos.log.get_logger(__name__) # Configuration from environment @@ -131,7 +134,11 @@ def _check_rate_limit_memory(key: str) -> tuple[bool, int, int]: current_count = len(_memory_store[key]) if current_count >= RATE_LIMIT: - reset_time = int(min(_memory_store[key]) + RATE_WINDOW) if _memory_store[key] else int(now + RATE_WINDOW) + reset_time = ( + int(min(_memory_store[key]) + RATE_WINDOW) + if _memory_store[key] + else int(now + RATE_WINDOW) + ) return False, 0, reset_time # Add current request @@ -156,7 +163,7 @@ class RateLimitMiddleware(BaseHTTPMiddleware): """ # Paths exempt from rate limiting - EXEMPT_PATHS = {"/health", "/openapi.json", "/docs", "/redoc"} + EXEMPT_PATHS: ClassVar[set[str]] = {"/health", "/openapi.json", "/docs", "/redoc"} async def dispatch(self, request: Request, call_next: Callable) -> Response: # Skip rate limiting for exempt paths diff --git a/overlays/krolik/api/routers/admin_router.py b/overlays/krolik/api/routers/admin_router.py index 939e5101f..238643ba9 100644 --- a/overlays/krolik/api/routers/admin_router.py +++ b/overlays/krolik/api/routers/admin_router.py @@ -5,12 +5,14 @@ """ import os + from typing import Any from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel, Field import memos.log + from memos.api.middleware.auth import require_scope, verify_api_key from memos.api.utils.api_keys import ( create_api_key_in_db, @@ -19,6 +21,7 @@ revoke_api_key, ) + logger = memos.log.get_logger(__name__) router = APIRouter(prefix="/admin", tags=["Admin"]) @@ -75,7 +78,7 @@ def _get_db_connection(): ) def create_key( request: CreateKeyRequest, - auth: dict = Depends(verify_api_key), + auth: dict = Depends(verify_api_key), # noqa: B008 ): """ Create a new API key for a user. @@ -111,7 +114,7 @@ def create_key( conn.close() except Exception as e: logger.error(f"Failed to create API key: {e}") - raise HTTPException(status_code=500, detail="Failed to create API key") + raise HTTPException(status_code=500, detail="Failed to create API key") from e @router.get( @@ -122,7 +125,7 @@ def create_key( ) def list_keys( user_name: str | None = None, - auth: dict = Depends(verify_api_key), + auth: dict = Depends(verify_api_key), # noqa: B008 ): """ List all API keys (admin) or keys for a specific user. @@ -141,7 +144,7 @@ def list_keys( conn.close() except Exception as e: logger.error(f"Failed to list API keys: {e}") - raise HTTPException(status_code=500, detail="Failed to list API keys") + raise HTTPException(status_code=500, detail="Failed to list API keys") from e @router.delete( @@ -152,7 +155,7 @@ def list_keys( ) def revoke_key( key_id: str, - auth: dict = Depends(verify_api_key), + auth: dict = Depends(verify_api_key), # noqa: B008 ): """ Revoke an API key by ID. @@ -174,7 +177,7 @@ def revoke_key( raise except Exception as e: logger.error(f"Failed to revoke API key: {e}") - raise HTTPException(status_code=500, detail="Failed to revoke API key") + raise HTTPException(status_code=500, detail="Failed to revoke API key") from e @router.post( @@ -184,7 +187,7 @@ def revoke_key( dependencies=[Depends(require_scope("admin"))], ) def generate_new_master_key( - auth: dict = Depends(verify_api_key), + auth: dict = Depends(verify_api_key), # noqa: B008 ): """ Generate a new master key. diff --git a/overlays/krolik/api/server_api_ext.py b/overlays/krolik/api/server_api_ext.py index 85b9411af..8c457e362 100644 --- a/overlays/krolik/api/server_api_ext.py +++ b/overlays/krolik/api/server_api_ext.py @@ -10,7 +10,7 @@ Usage in Dockerfile: # Copy overlays after base installation COPY overlays/krolik/ /app/src/memos/ - + # Use this as entrypoint instead of server_api CMD ["gunicorn", "memos.api.server_api_ext:app", ...] """ @@ -25,16 +25,18 @@ from starlette.requests import Request from starlette.responses import Response -# Import base routers from MemOS -from memos.api.routers.server_router import router as server_router - # Import Krolik extensions from memos.api.middleware.rate_limit import RateLimitMiddleware from memos.api.routers.admin_router import router as admin_router +# Import base routers from MemOS +from memos.api.routers.server_router import router as server_router + + # Try to import exception handlers (may vary between MemOS versions) try: from memos.api.exceptions import APIExceptionHandler + HAS_EXCEPTION_HANDLER = True except ImportError: HAS_EXCEPTION_HANDLER = False @@ -98,6 +100,7 @@ async def dispatch(self, request: Request, call_next) -> Response: # Exception handlers if HAS_EXCEPTION_HANDLER: from fastapi import HTTPException + app.exception_handler(RequestValidationError)(APIExceptionHandler.validation_error_handler) app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler) app.exception_handler(HTTPException)(APIExceptionHandler.http_error_handler) @@ -117,4 +120,5 @@ async def health_check(): if __name__ == "__main__": import uvicorn + uvicorn.run("memos.api.server_api_ext:app", host="0.0.0.0", port=8000, workers=1) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index bed1d6899..3c1ad959b 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -970,7 +970,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene "postgres": postgres_config, } # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars - graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")).lower() + graph_db_backend = os.getenv( + "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community") + ).lower() if graph_db_backend in graph_db_backend_map: # Create MemCube config @@ -1052,7 +1054,9 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None": else None ) # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars - graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")).lower() + graph_db_backend = os.getenv( + "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community") + ).lower() if graph_db_backend in graph_db_backend_map: return GeneralMemCubeConfig.model_validate( { diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 5ce9faad1..9b1ce7f9d 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -247,7 +247,9 @@ class PostgresGraphDBConfig(BaseConfig): default=False, description="If False: use single database with logical isolation by user_name", ) - embedding_dimension: int = Field(default=768, description="Dimension of vector embedding (768 for all-mpnet-base-v2)") + embedding_dimension: int = Field( + default=768, description="Dimension of vector embedding (768 for all-mpnet-base-v2)" + ) maxconn: int = Field( default=20, description="Maximum number of connections in the connection pool", diff --git a/src/memos/graph_dbs/postgres.py b/src/memos/graph_dbs/postgres.py index f9065d718..1c1cae378 100644 --- a/src/memos/graph_dbs/postgres.py +++ b/src/memos/graph_dbs/postgres.py @@ -11,6 +11,7 @@ import json import time + from contextlib import suppress from datetime import datetime from typing import Any, Literal @@ -20,6 +21,7 @@ from memos.graph_dbs.base import BaseGraphDB from memos.log import get_logger + logger = get_logger(__name__) @@ -199,7 +201,8 @@ def remove_oldest_memory( try: with conn.cursor() as cur: # Find IDs to delete (older than the keep_latest entries) - cur.execute(f""" + cur.execute( + f""" WITH ranked AS ( SELECT id, ROW_NUMBER() OVER (ORDER BY updated_at DESC) as rn FROM {self.schema}.memories @@ -207,24 +210,34 @@ def remove_oldest_memory( AND properties->>'memory_type' = %s ) SELECT id FROM ranked WHERE rn > %s - """, (user_name, memory_type, keep_latest)) + """, + (user_name, memory_type, keep_latest), + ) ids_to_delete = [row[0] for row in cur.fetchall()] if ids_to_delete: # Delete edges first - cur.execute(f""" + cur.execute( + f""" DELETE FROM {self.schema}.edges WHERE source_id = ANY(%s) OR target_id = ANY(%s) - """, (ids_to_delete, ids_to_delete)) + """, + (ids_to_delete, ids_to_delete), + ) # Delete nodes - cur.execute(f""" + cur.execute( + f""" DELETE FROM {self.schema}.memories WHERE id = ANY(%s) - """, (ids_to_delete,)) + """, + (ids_to_delete,), + ) - logger.info(f"Removed {len(ids_to_delete)} oldest {memory_type} memories for user {user_name}") + logger.info( + f"Removed {len(ids_to_delete)} oldest {memory_type} memories for user {user_name}" + ) finally: self._put_conn(conn) @@ -243,15 +256,15 @@ def add_node( # Serialize sources if present if metadata.get("sources"): metadata["sources"] = [ - json.dumps(s) if not isinstance(s, str) else s - for s in metadata["sources"] + json.dumps(s) if not isinstance(s, str) else s for s in metadata["sources"] ] conn = self._get_conn() try: with conn.cursor() as cur: if embedding: - cur.execute(f""" + cur.execute( + f""" INSERT INTO {self.schema}.memories (id, memory, properties, embedding, user_name, created_at, updated_at) VALUES (%s, %s, %s, %s::vector, %s, %s, %s) @@ -260,9 +273,20 @@ def add_node( properties = EXCLUDED.properties, embedding = EXCLUDED.embedding, updated_at = EXCLUDED.updated_at - """, (id, memory, json.dumps(metadata), embedding, user_name, created_at, updated_at)) + """, + ( + id, + memory, + json.dumps(metadata), + embedding, + user_name, + created_at, + updated_at, + ), + ) else: - cur.execute(f""" + cur.execute( + f""" INSERT INTO {self.schema}.memories (id, memory, properties, user_name, created_at, updated_at) VALUES (%s, %s, %s, %s, %s, %s) @@ -270,13 +294,13 @@ def add_node( memory = EXCLUDED.memory, properties = EXCLUDED.properties, updated_at = EXCLUDED.updated_at - """, (id, memory, json.dumps(metadata), user_name, created_at, updated_at)) + """, + (id, memory, json.dumps(metadata), user_name, created_at, updated_at), + ) finally: self._put_conn(conn) - def add_nodes_batch( - self, nodes: list[dict[str, Any]], user_name: str | None = None - ) -> None: + def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = None) -> None: """Batch add memory nodes.""" for node in nodes: self.add_node( @@ -308,17 +332,23 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N try: with conn.cursor() as cur: if embedding: - cur.execute(f""" + cur.execute( + f""" UPDATE {self.schema}.memories SET memory = %s, properties = %s, embedding = %s::vector, updated_at = NOW() WHERE id = %s AND user_name = %s - """, (memory, json.dumps(props), embedding, id, user_name)) + """, + (memory, json.dumps(props), embedding, id, user_name), + ) else: - cur.execute(f""" + cur.execute( + f""" UPDATE {self.schema}.memories SET memory = %s, properties = %s, updated_at = NOW() WHERE id = %s AND user_name = %s - """, (memory, json.dumps(props), id, user_name)) + """, + (memory, json.dumps(props), id, user_name), + ) finally: self._put_conn(conn) @@ -329,15 +359,21 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: try: with conn.cursor() as cur: # Delete edges - cur.execute(f""" + cur.execute( + f""" DELETE FROM {self.schema}.edges WHERE source_id = %s OR target_id = %s - """, (id, id)) + """, + (id, id), + ) # Delete node - cur.execute(f""" + cur.execute( + f""" DELETE FROM {self.schema}.memories WHERE id = %s AND user_name = %s - """, (id, user_name)) + """, + (id, user_name), + ) finally: self._put_conn(conn) @@ -350,10 +386,13 @@ def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[s cols = "id, memory, properties, created_at, updated_at" if include_embedding: cols += ", embedding" - cur.execute(f""" + cur.execute( + f""" SELECT {cols} FROM {self.schema}.memories WHERE id = %s AND user_name = %s - """, (id, user_name)) + """, + (id, user_name), + ) row = cur.fetchone() if not row: return None @@ -374,10 +413,13 @@ def get_nodes( cols = "id, memory, properties, created_at, updated_at" if include_embedding: cols += ", embedding" - cur.execute(f""" + cur.execute( + f""" SELECT {cols} FROM {self.schema}.memories WHERE id = ANY(%s) AND user_name = %s - """, (ids, user_name)) + """, + (ids, user_name), + ) return [self._parse_row(row, include_embedding) for row in cur.fetchall()] finally: self._put_conn(conn) @@ -407,11 +449,14 @@ def add_edge( conn = self._get_conn() try: with conn.cursor() as cur: - cur.execute(f""" + cur.execute( + f""" INSERT INTO {self.schema}.edges (source_id, target_id, edge_type) VALUES (%s, %s, %s) ON CONFLICT (source_id, target_id, edge_type) DO NOTHING - """, (source_id, target_id, type)) + """, + (source_id, target_id, type), + ) finally: self._put_conn(conn) @@ -422,10 +467,13 @@ def delete_edge( conn = self._get_conn() try: with conn.cursor() as cur: - cur.execute(f""" + cur.execute( + f""" DELETE FROM {self.schema}.edges WHERE source_id = %s AND target_id = %s AND edge_type = %s - """, (source_id, target_id, type)) + """, + (source_id, target_id, type), + ) finally: self._put_conn(conn) @@ -434,11 +482,14 @@ def edge_exists(self, source_id: str, target_id: str, type: str) -> bool: conn = self._get_conn() try: with conn.cursor() as cur: - cur.execute(f""" + cur.execute( + f""" SELECT 1 FROM {self.schema}.edges WHERE source_id = %s AND target_id = %s AND edge_type = %s LIMIT 1 - """, (source_id, target_id, type)) + """, + (source_id, target_id, type), + ) return cur.fetchone() is not None finally: self._put_conn(conn) @@ -455,21 +506,30 @@ def get_neighbors( try: with conn.cursor() as cur: if direction == "out": - cur.execute(f""" + cur.execute( + f""" SELECT target_id FROM {self.schema}.edges WHERE source_id = %s AND edge_type = %s - """, (id, type)) + """, + (id, type), + ) elif direction == "in": - cur.execute(f""" + cur.execute( + f""" SELECT source_id FROM {self.schema}.edges WHERE target_id = %s AND edge_type = %s - """, (id, type)) + """, + (id, type), + ) else: # both - cur.execute(f""" + cur.execute( + f""" SELECT target_id FROM {self.schema}.edges WHERE source_id = %s AND edge_type = %s UNION SELECT source_id FROM {self.schema}.edges WHERE target_id = %s AND edge_type = %s - """, (id, type, id, type)) + """, + (id, type, id, type), + ) return [row[0] for row in cur.fetchall()] finally: self._put_conn(conn) @@ -479,7 +539,8 @@ def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[s conn = self._get_conn() try: with conn.cursor() as cur: - cur.execute(f""" + cur.execute( + f""" WITH RECURSIVE path AS ( SELECT source_id, target_id, ARRAY[source_id] as nodes, 1 as depth FROM {self.schema}.edges @@ -495,7 +556,9 @@ def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[s WHERE target_id = %s ORDER BY depth LIMIT 1 - """, (source_id, max_depth, target_id)) + """, + (source_id, max_depth, target_id), + ) row = cur.fetchone() return row[0] if row else [] finally: @@ -506,7 +569,8 @@ def get_subgraph(self, center_id: str, depth: int = 2) -> list[str]: conn = self._get_conn() try: with conn.cursor() as cur: - cur.execute(f""" + cur.execute( + f""" WITH RECURSIVE subgraph AS ( SELECT %s::text as node_id, 0 as level UNION @@ -517,7 +581,9 @@ def get_subgraph(self, center_id: str, depth: int = 2) -> list[str]: WHERE s.level < %s ) SELECT DISTINCT node_id FROM subgraph - """, (center_id, depth)) + """, + (center_id, depth), + ) return [row[0] for row in cur.fetchall()] finally: self._put_conn(conn) @@ -562,7 +628,9 @@ def search_by_embedding( conditions.append("properties->>'status' = %s") params.append(status) else: - conditions.append("(properties->>'status' = 'activated' OR properties->>'status' IS NULL)") + conditions.append( + "(properties->>'status' = 'activated' OR properties->>'status' IS NULL)" + ) if search_filter: for k, v in search_filter.items(): @@ -575,13 +643,16 @@ def search_by_embedding( conn = self._get_conn() try: with conn.cursor() as cur: - cur.execute(f""" + cur.execute( + f""" SELECT id, 1 - (embedding <=> %s::vector) as score FROM {self.schema}.memories WHERE {where_clause} ORDER BY embedding <=> %s::vector LIMIT %s - """, (vector, *params, vector, top_k)) + """, + (vector, *params, vector, top_k), + ) results = [] for row in cur.fetchall(): @@ -639,10 +710,13 @@ def get_by_metadata( conn = self._get_conn() try: with conn.cursor() as cur: - cur.execute(f""" + cur.execute( + f""" SELECT id FROM {self.schema}.memories WHERE {where_clause} - """, params) + """, + params, + ) return [row[0] for row in cur.fetchall()] finally: self._put_conn(conn) @@ -674,10 +748,13 @@ def get_all_memory_items( cols = "id, memory, properties, created_at, updated_at" if include_embedding: cols += ", embedding" - cur.execute(f""" + cur.execute( + f""" SELECT {cols} FROM {self.schema}.memories WHERE {where_clause} - """, params) + """, + params, + ) return [self._parse_row(row, include_embedding) for row in cur.fetchall()] finally: self._put_conn(conn) @@ -691,7 +768,8 @@ def get_structure_optimization_candidates( try: with conn.cursor() as cur: cols = "m.id, m.memory, m.properties, m.created_at, m.updated_at" - cur.execute(f""" + cur.execute( + f""" SELECT {cols} FROM {self.schema}.memories m LEFT JOIN {self.schema}.edges e1 ON m.id = e1.source_id @@ -701,7 +779,9 @@ def get_structure_optimization_candidates( AND m.properties->>'status' = 'activated' AND e1.id IS NULL AND e2.id IS NULL - """, (scope, user_name)) + """, + (scope, user_name), + ) return [self._parse_row(row, False) for row in cur.fetchall()] finally: self._put_conn(conn) @@ -712,7 +792,6 @@ def get_structure_optimization_candidates( def deduplicate_nodes(self) -> None: """Not implemented - handled at application level.""" - pass def get_grouped_counts( self, @@ -739,13 +818,11 @@ def get_grouped_counts( # Build SELECT and GROUP BY clauses # Fields come from JSONB properties column - select_fields = ", ".join([ - f"properties->>'{field}' AS {field}" for field in group_fields - ]) + select_fields = ", ".join([f"properties->>'{field}' AS {field}" for field in group_fields]) group_by = ", ".join([f"properties->>'{field}'" for field in group_fields]) # Build WHERE clause - conditions = [f"user_name = %s"] + conditions = ["user_name = %s"] query_params = [user_name] if where_clause: @@ -797,22 +874,31 @@ def clear(self, user_name: str | None = None) -> None: try: with conn.cursor() as cur: # Get all node IDs for user - cur.execute(f""" + cur.execute( + f""" SELECT id FROM {self.schema}.memories WHERE user_name = %s - """, (user_name,)) + """, + (user_name,), + ) ids = [row[0] for row in cur.fetchall()] if ids: # Delete edges - cur.execute(f""" + cur.execute( + f""" DELETE FROM {self.schema}.edges WHERE source_id = ANY(%s) OR target_id = ANY(%s) - """, (ids, ids)) + """, + (ids, ids), + ) # Delete nodes - cur.execute(f""" + cur.execute( + f""" DELETE FROM {self.schema}.memories WHERE user_name = %s - """, (user_name,)) + """, + (user_name,), + ) logger.info(f"Cleared all data for user {user_name}") finally: self._put_conn(conn) @@ -827,21 +913,27 @@ def export_graph(self, include_embedding: bool = False, **kwargs) -> dict[str, A cols = "id, memory, properties, created_at, updated_at" if include_embedding: cols += ", embedding" - cur.execute(f""" + cur.execute( + f""" SELECT {cols} FROM {self.schema}.memories WHERE user_name = %s ORDER BY created_at DESC - """, (user_name,)) + """, + (user_name,), + ) nodes = [self._parse_row(row, include_embedding) for row in cur.fetchall()] # Get edges node_ids = [n["id"] for n in nodes] if node_ids: - cur.execute(f""" + cur.execute( + f""" SELECT source_id, target_id, edge_type FROM {self.schema}.edges WHERE source_id = ANY(%s) OR target_id = ANY(%s) - """, (node_ids, node_ids)) + """, + (node_ids, node_ids), + ) edges = [ {"source": row[0], "target": row[1], "type": row[2]} for row in cur.fetchall()