Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions overlays/krolik/api/middleware/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import hashlib
import os
import time

from typing import Any

from fastapi import Depends, HTTPException, Request, Security
from fastapi.security import APIKeyHeader

import memos.log


logger = memos.log.get_logger(__name__)

# API key header configuration
Expand Down Expand Up @@ -149,10 +151,7 @@ def is_internal_request(request: Request) -> bool:

# Check internal header (for container-to-container)
internal_header = request.headers.get("X-Internal-Service")
if internal_header == os.getenv("INTERNAL_SERVICE_SECRET"):
return True

return False
return internal_header == os.getenv("INTERNAL_SERVICE_SECRET")


async def verify_api_key(
Expand Down Expand Up @@ -245,8 +244,9 @@ def require_scope(required_scope: str):
Usage:
@router.post("/admin/keys", dependencies=[Depends(require_scope("admin"))])
"""

async def scope_checker(
auth: dict[str, Any] = Depends(verify_api_key),
auth: dict[str, Any] = Depends(verify_api_key), # noqa: B008
) -> dict[str, Any]:
scopes = auth.get("scopes", [])

Expand Down
13 changes: 10 additions & 3 deletions overlays/krolik/api/middleware/rate_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@

import os
import time

from collections import defaultdict
from typing import Callable
from collections.abc import Callable
from typing import ClassVar

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

import memos.log


logger = memos.log.get_logger(__name__)

# Configuration from environment
Expand Down Expand Up @@ -131,7 +134,11 @@ def _check_rate_limit_memory(key: str) -> tuple[bool, int, int]:
current_count = len(_memory_store[key])

if current_count >= RATE_LIMIT:
reset_time = int(min(_memory_store[key]) + RATE_WINDOW) if _memory_store[key] else int(now + RATE_WINDOW)
reset_time = (
int(min(_memory_store[key]) + RATE_WINDOW)
if _memory_store[key]
else int(now + RATE_WINDOW)
)
return False, 0, reset_time

# Add current request
Expand All @@ -156,7 +163,7 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
"""

# Paths exempt from rate limiting
EXEMPT_PATHS = {"/health", "/openapi.json", "/docs", "/redoc"}
EXEMPT_PATHS: ClassVar[set[str]] = {"/health", "/openapi.json", "/docs", "/redoc"}

async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Skip rate limiting for exempt paths
Expand Down
17 changes: 10 additions & 7 deletions overlays/krolik/api/routers/admin_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
"""

import os

from typing import Any

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field

import memos.log

from memos.api.middleware.auth import require_scope, verify_api_key
from memos.api.utils.api_keys import (
create_api_key_in_db,
Expand All @@ -19,6 +21,7 @@
revoke_api_key,
)


logger = memos.log.get_logger(__name__)

router = APIRouter(prefix="/admin", tags=["Admin"])
Expand Down Expand Up @@ -75,7 +78,7 @@ def _get_db_connection():
)
def create_key(
request: CreateKeyRequest,
auth: dict = Depends(verify_api_key),
auth: dict = Depends(verify_api_key), # noqa: B008
):
"""
Create a new API key for a user.
Expand Down Expand Up @@ -111,7 +114,7 @@ def create_key(
conn.close()
except Exception as e:
logger.error(f"Failed to create API key: {e}")
raise HTTPException(status_code=500, detail="Failed to create API key")
raise HTTPException(status_code=500, detail="Failed to create API key") from e


@router.get(
Expand All @@ -122,7 +125,7 @@ def create_key(
)
def list_keys(
user_name: str | None = None,
auth: dict = Depends(verify_api_key),
auth: dict = Depends(verify_api_key), # noqa: B008
):
"""
List all API keys (admin) or keys for a specific user.
Expand All @@ -141,7 +144,7 @@ def list_keys(
conn.close()
except Exception as e:
logger.error(f"Failed to list API keys: {e}")
raise HTTPException(status_code=500, detail="Failed to list API keys")
raise HTTPException(status_code=500, detail="Failed to list API keys") from e


@router.delete(
Expand All @@ -152,7 +155,7 @@ def list_keys(
)
def revoke_key(
key_id: str,
auth: dict = Depends(verify_api_key),
auth: dict = Depends(verify_api_key), # noqa: B008
):
"""
Revoke an API key by ID.
Expand All @@ -174,7 +177,7 @@ def revoke_key(
raise
except Exception as e:
logger.error(f"Failed to revoke API key: {e}")
raise HTTPException(status_code=500, detail="Failed to revoke API key")
raise HTTPException(status_code=500, detail="Failed to revoke API key") from e


@router.post(
Expand All @@ -184,7 +187,7 @@ def revoke_key(
dependencies=[Depends(require_scope("admin"))],
)
def generate_new_master_key(
auth: dict = Depends(verify_api_key),
auth: dict = Depends(verify_api_key), # noqa: B008
):
"""
Generate a new master key.
Expand Down
12 changes: 8 additions & 4 deletions overlays/krolik/api/server_api_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
Usage in Dockerfile:
# Copy overlays after base installation
COPY overlays/krolik/ /app/src/memos/

# Use this as entrypoint instead of server_api
CMD ["gunicorn", "memos.api.server_api_ext:app", ...]
"""
Expand All @@ -25,16 +25,18 @@
from starlette.requests import Request
from starlette.responses import Response

# Import base routers from MemOS
from memos.api.routers.server_router import router as server_router

# Import Krolik extensions
from memos.api.middleware.rate_limit import RateLimitMiddleware
from memos.api.routers.admin_router import router as admin_router

# Import base routers from MemOS
from memos.api.routers.server_router import router as server_router


# Try to import exception handlers (may vary between MemOS versions)
try:
from memos.api.exceptions import APIExceptionHandler

HAS_EXCEPTION_HANDLER = True
except ImportError:
HAS_EXCEPTION_HANDLER = False
Expand Down Expand Up @@ -98,6 +100,7 @@ async def dispatch(self, request: Request, call_next) -> Response:
# Exception handlers
if HAS_EXCEPTION_HANDLER:
from fastapi import HTTPException

app.exception_handler(RequestValidationError)(APIExceptionHandler.validation_error_handler)
app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler)
app.exception_handler(HTTPException)(APIExceptionHandler.http_error_handler)
Expand All @@ -117,4 +120,5 @@ async def health_check():

if __name__ == "__main__":
import uvicorn

uvicorn.run("memos.api.server_api_ext:app", host="0.0.0.0", port=8000, workers=1)
8 changes: 6 additions & 2 deletions src/memos/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
"postgres": postgres_config,
}
# Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")).lower()
graph_db_backend = os.getenv(
"GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")
).lower()
if graph_db_backend in graph_db_backend_map:
# Create MemCube config

Expand Down Expand Up @@ -1052,7 +1054,9 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None":
else None
)
# Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")).lower()
graph_db_backend = os.getenv(
"GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")
).lower()
if graph_db_backend in graph_db_backend_map:
return GeneralMemCubeConfig.model_validate(
{
Expand Down
4 changes: 3 additions & 1 deletion src/memos/configs/graph_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ class PostgresGraphDBConfig(BaseConfig):
default=False,
description="If False: use single database with logical isolation by user_name",
)
embedding_dimension: int = Field(default=768, description="Dimension of vector embedding (768 for all-mpnet-base-v2)")
embedding_dimension: int = Field(
default=768, description="Dimension of vector embedding (768 for all-mpnet-base-v2)"
)
maxconn: int = Field(
default=20,
description="Maximum number of connections in the connection pool",
Expand Down
Loading