diff --git a/README.md b/README.md
index fe9d04ea1..75e673dbd 100644
--- a/README.md
+++ b/README.md
@@ -269,6 +269,9 @@ Get Free API: [Try API](https://memos-dashboard.openmem.net/quickstart/?source=g
- **Awesome-AI-Memory**
This is a curated repository dedicated to resources on memory and memory systems for large language models. It systematically collects relevant research papers, frameworks, tools, and practical insights. The repository aims to organize and present the rapidly evolving research landscape of LLM memory, bridging multiple research directions including natural language processing, information retrieval, agentic systems, and cognitive science.
- **Get started** 👉 [IAAR-Shanghai/Awesome-AI-Memory](https://github.com/IAAR-Shanghai/Awesome-AI-Memory)
+- **MemOS Cloud OpenClaw Plugin**
+ Official OpenClaw lifecycle plugin for MemOS Cloud. It automatically recalls context from MemOS before the agent starts and saves the conversation back to MemOS after the agent finishes.
+- **Get started** 👉 [MemTensor/MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/MemOS-Cloud-OpenClaw-Plugin)
diff --git a/docker/Dockerfile.krolik b/docker/Dockerfile.krolik
new file mode 100644
index 000000000..c475a6d30
--- /dev/null
+++ b/docker/Dockerfile.krolik
@@ -0,0 +1,65 @@
+# MemOS with Krolik Security Extensions
+#
+# This Dockerfile builds MemOS with authentication, rate limiting, and an admin API.
+# It uses the overlay pattern to keep customizations separate from base code.
+
+FROM python:3.11-slim
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ gcc \
+ g++ \
+ build-essential \
+ libffi-dev \
+ python3-dev \
+ curl \
+ libpq-dev \
+ && rm -rf /var/lib/apt/lists/*
+
+# Create non-root user
+RUN groupadd -r memos && useradd -r -g memos -u 1000 memos
+
+WORKDIR /app
+
+# Use the official Hugging Face endpoint
+ENV HF_ENDPOINT=https://huggingface.co
+
+# Copy base MemOS source
+COPY src/ ./src/
+COPY pyproject.toml ./
+
+# Install base dependencies
+RUN pip install --upgrade pip && \
+ pip install --no-cache-dir poetry && \
+ poetry config virtualenvs.create false && \
+    poetry install --only main --extras "tree-mem mem-scheduler"
+
+# Install additional dependencies for Krolik
+RUN pip install --no-cache-dir \
+ sentence-transformers \
+ torch \
+ transformers \
+ psycopg2-binary \
+ redis
+
+# Apply Krolik overlay (AFTER base install to allow easy updates)
+COPY overlays/krolik/ ./src/memos/
+
+# Create data directory
+RUN mkdir -p /data/memos && chown -R memos:memos /data/memos
+RUN chown -R memos:memos /app
+
+# Set Python path
+ENV PYTHONPATH=/app/src
+
+# Switch to non-root user
+USER memos
+
+EXPOSE 8000
+
+# Healthcheck
+HEALTHCHECK --interval=30s --timeout=10s --retries=3 --start-period=60s \
+ CMD curl -f http://localhost:8000/health || exit 1
+
+# Use extended entry point with security features
+CMD ["gunicorn", "memos.api.server_api_ext:app", "--preload", "-w", "2", "-k", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8000", "--timeout", "120"]
diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py
index 6a9405456..6dbe202c2 100644
--- a/examples/mem_agent/deepsearch_example.py
+++ b/examples/mem_agent/deepsearch_example.py
@@ -179,7 +179,7 @@ def factory_initialization() -> tuple[DeepSearchMemAgent, dict[str, Any]]:
def main():
- agent_factory, components_factory = factory_initialization()
+ agent_factory, _components_factory = factory_initialization()
results = agent_factory.run(
"Caroline met up with friends, family, and mentors in early July 2023.",
user_id="locomo_exp_user_0_speaker_b_ct-1118",
diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py
index b7250a677..f3c1c7f87 100644
--- a/examples/mem_scheduler/memos_w_scheduler.py
+++ b/examples/mem_scheduler/memos_w_scheduler.py
@@ -84,7 +84,9 @@ def init_task():
return conversations, questions
-working_memories = []
+default_mem_update_handler = mem_scheduler.handlers.get(MEM_UPDATE_TASK_LABEL)
+if default_mem_update_handler is None:
+ logger.warning("Default MEM_UPDATE handler not found; custom handler will be a no-op.")
# Define custom query handler function
@@ -100,24 +102,11 @@ def custom_query_handler(messages: list[ScheduleMessageItem]):
# Define custom memory update handler function
def custom_mem_update_handler(messages: list[ScheduleMessageItem]):
- global working_memories
- search_args = {}
- top_k = 2
- for msg in messages:
- # Search for memories relevant to the current content in text memory (return top_k=2)
- results = mem_scheduler.retriever.search(
- query=msg.content,
- user_id=msg.user_id,
- mem_cube_id=msg.mem_cube_id,
- mem_cube=mem_scheduler.current_mem_cube,
- top_k=top_k,
- method=mem_scheduler.search_method,
- search_args=search_args,
- )
- working_memories.extend(results)
- working_memories = working_memories[-5:]
- for mem in results:
- print(f"\n[scheduler] Retrieved memory: {mem.memory}")
+ if default_mem_update_handler is None:
+ logger.error("Default MEM_UPDATE handler missing; cannot process messages.")
+ return
+    # Delegate to the built-in handler to keep behavior aligned with the scheduler refactor.
+ default_mem_update_handler(messages)
async def run_with_scheduler():
diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py
index d942aad4e..51d87b1aa 100644
--- a/examples/mem_scheduler/try_schedule_modules.py
+++ b/examples/mem_scheduler/try_schedule_modules.py
@@ -167,7 +167,8 @@ def add_msgs(
label=MEM_UPDATE_TASK_LABEL,
content=query,
)
- # Run one session turn manually to get search candidates
- mem_scheduler._memory_update_consumer(
- messages=[message],
- )
+    # Run one session turn manually via the registered handler (public surface)
+ handler = mem_scheduler.handlers.get(MEM_UPDATE_TASK_LABEL)
+ if handler is None:
+ raise RuntimeError("MEM_UPDATE handler not registered on mem_scheduler.")
+ handler([message])
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/memos/api/README_api.md b/src/memos/api/README_api.md
new file mode 100644
index 000000000..ba63ed996
--- /dev/null
+++ b/src/memos/api/README_api.md
@@ -0,0 +1,13 @@
+# MemOS API
+
+## Default entry and deployment
+
+- Use **`server_api.py`** as the API service entry for **public open-source usage**.
+- You can deploy via **`docker/Dockerfile`**.
+
+The above is the default, general-purpose way to run and deploy the API.
+
+## Extensions and reference implementations
+
+- **`server_api_ext.py`** and **`Dockerfile.krolik`** are one developer’s extended API and deployment setup, **for reference only**. They are not yet integrated with cloud services and are still in testing.
+- If you need extensions or custom behavior, you can refer to these and use or adapt them as you like.
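For the default path, a minimal launcher sketch (assuming `server_api` exposes a FastAPI instance named `app`, as `server_api_ext` does in `Dockerfile.krolik`; adjust module path and port to your deployment):

```python
import uvicorn

if __name__ == "__main__":
    # Serves the default open-source API entry point.
    uvicorn.run("memos.api.server_api:app", host="0.0.0.0", port=8000)
```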
diff --git a/src/memos/api/__init__.py b/src/memos/api/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/memos/api/client.py b/src/memos/api/client.py
index 91bc86829..818ce5e0d 100644
--- a/src/memos/api/client.py
+++ b/src/memos/api/client.py
@@ -31,10 +31,27 @@
class MemOSClient:
"""MemOS API client"""
- def __init__(self, api_key: str | None = None, base_url: str | None = None):
- self.base_url = (
- base_url or os.getenv("MEMOS_BASE_URL") or "https://memos.memtensor.cn/api/openmem/v1"
+ def __init__(
+ self,
+ api_key: str | None = None,
+ base_url: str | None = None,
+ is_global: str | bool = "false",
+ ):
+ # Priority:
+ # 1. base_url argument
+ # 2. MEMOS_BASE_URL environment variable (direct URL)
+ # 3. MEMOS_IS_GLOBAL environment variable (True/False toggle)
+ arg_is_global = str(is_global).lower() in ("true", "1", "yes")
+ memos_is_global = os.getenv("MEMOS_IS_GLOBAL", "false").lower() in ("true", "1", "yes")
+ final_is_global = arg_is_global or memos_is_global
+ default_url = (
+ "https://api.memt.ai/platform/api/openmem/v1"
+ if final_is_global
+ else "https://memos.memtensor.cn/api/openmem/v1"
)
+
+ self.base_url = base_url or os.getenv("MEMOS_BASE_URL") or default_url
+
api_key = api_key or os.getenv("MEMOS_API_KEY")
if not api_key:
@@ -56,7 +73,7 @@ def get_message(
message_limit_number: int = 6,
source: str | None = None,
) -> MemOSGetMessagesResponse | None:
- """Get messages"""
+ """Get message"""
# Validate required parameters
self._validate_required_params(user_id=user_id)
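The resolution order implemented above can be exercised directly; a minimal sketch, with a placeholder key and hypothetical override URLs marked as such:

```python
import os

from memos.api.client import MemOSClient

os.environ["MEMOS_API_KEY"] = "placeholder-key"  # hypothetical key, for illustration only

# 1. An explicit base_url argument wins over everything.
client = MemOSClient(base_url="https://memos.internal.example/api/openmem/v1")

# 2. Without an argument, MEMOS_BASE_URL (if set) is used as-is.
os.environ["MEMOS_BASE_URL"] = "https://memos.staging.example/api/openmem/v1"
client = MemOSClient()

# 3. Otherwise MEMOS_IS_GLOBAL=true (or is_global=True) selects the global
#    endpoint, and anything else falls back to the default endpoint.
del os.environ["MEMOS_BASE_URL"]
os.environ["MEMOS_IS_GLOBAL"] = "true"
client = MemOSClient()  # -> https://api.memt.ai/platform/api/openmem/v1
```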
diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index d27c391ab..70d9366e3 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -538,6 +538,10 @@ def get_internet_config() -> dict[str, Any]:
"chunker": {
"backend": "sentence",
"config": {
+ "save_rawfile": os.getenv(
+ "MEM_READER_SAVE_RAWFILENODE", "true"
+ ).lower()
+ == "true",
"tokenizer_or_token_counter": "gpt2",
"chunk_size": 512,
"chunk_overlap": 128,
@@ -676,6 +680,30 @@ def get_polardb_config(user_id: str | None = None) -> dict[str, Any]:
"embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)),
}
+ @staticmethod
+ def get_postgres_config(user_id: str | None = None) -> dict[str, Any]:
+ """Get PostgreSQL + pgvector configuration for MemOS graph storage.
+
+ Uses standard PostgreSQL with pgvector extension.
+ Schema: memos.memories, memos.edges
+ """
+ user_name = os.getenv("MEMOS_USER_NAME", "default")
+ if user_id:
+ user_name = f"memos_{user_id.replace('-', '')}"
+
+ return {
+ "host": os.getenv("POSTGRES_HOST", "postgres"),
+ "port": int(os.getenv("POSTGRES_PORT", "5432")),
+ "user": os.getenv("POSTGRES_USER", "n8n"),
+ "password": os.getenv("POSTGRES_PASSWORD", ""),
+ "db_name": os.getenv("POSTGRES_DB", "n8n"),
+ "schema_name": os.getenv("MEMOS_SCHEMA", "memos"),
+ "user_name": user_name,
+ "use_multi_db": False,
+ "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", "384")),
+ "maxconn": int(os.getenv("POSTGRES_MAX_CONN", "20")),
+ }
+
@staticmethod
def get_mysql_config() -> dict[str, Any]:
"""Get MySQL configuration."""
@@ -780,6 +808,8 @@ def get_product_default_config() -> dict[str, Any]:
"chunker": {
"backend": "sentence",
"config": {
+ "save_rawfile": os.getenv("MEM_READER_SAVE_RAWFILENODE", "true").lower()
+ == "true",
"tokenizer_or_token_counter": "gpt2",
"chunk_size": 512,
"chunk_overlap": 128,
@@ -797,7 +827,12 @@ def get_product_default_config() -> dict[str, Any]:
"oss_config": APIConfig.get_oss_config(),
"skills_dir_config": {
"skills_oss_dir": os.getenv("SKILLS_OSS_DIR", "skill_memory/"),
- "skills_local_dir": os.getenv("SKILLS_LOCAL_DIR", "/tmp/skill_memory/"),
+ "skills_local_tmp_dir": os.getenv(
+ "SKILLS_LOCAL_TMP_DIR", "/tmp/skill_memory/"
+ ),
+ "skills_local_dir": os.getenv(
+ "SKILLS_LOCAL_DIR", "/tmp/upload_skill_memory/"
+ ),
},
},
},
@@ -895,6 +930,8 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
"chunker": {
"backend": "sentence",
"config": {
+ "save_rawfile": os.getenv("MEM_READER_SAVE_RAWFILENODE", "true").lower()
+ == "true",
"tokenizer_or_token_counter": "gpt2",
"chunk_size": 512,
"chunk_overlap": 128,
@@ -937,13 +974,18 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
if os.getenv("ENABLE_INTERNET", "false").lower() == "true"
else None
)
+ postgres_config = APIConfig.get_postgres_config(user_id=user_id)
graph_db_backend_map = {
"neo4j-community": neo4j_community_config,
"neo4j": neo4j_config,
"nebular": nebular_config,
"polardb": polardb_config,
+ "postgres": postgres_config,
}
- graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower()
+ # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
+ graph_db_backend = os.getenv(
+ "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")
+ ).lower()
if graph_db_backend in graph_db_backend_map:
# Create MemCube config
@@ -1011,18 +1053,23 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None":
neo4j_config = APIConfig.get_neo4j_config(user_id="default")
nebular_config = APIConfig.get_nebular_config(user_id="default")
polardb_config = APIConfig.get_polardb_config(user_id="default")
+ postgres_config = APIConfig.get_postgres_config(user_id="default")
graph_db_backend_map = {
"neo4j-community": neo4j_community_config,
"neo4j": neo4j_config,
"nebular": nebular_config,
"polardb": polardb_config,
+ "postgres": postgres_config,
}
internet_config = (
APIConfig.get_internet_config()
if os.getenv("ENABLE_INTERNET", "false").lower() == "true"
else None
)
- graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower()
+ # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
+ graph_db_backend = os.getenv(
+ "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")
+ ).lower()
if graph_db_backend in graph_db_backend_map:
return GeneralMemCubeConfig.model_validate(
{
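The same nested-`getenv` fallback appears in several places in this diff; in isolation it behaves as in this stdlib-only sketch:

```python
import os


def resolve_graph_db_backend() -> str:
    # GRAPH_DB_BACKEND takes precedence; the legacy NEO4J_BACKEND is honored
    # when the new variable is absent; "neo4j-community" is the final default.
    return os.getenv(
        "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")
    ).lower()


os.environ["NEO4J_BACKEND"] = "polardb"
assert resolve_graph_db_backend() == "polardb"  # legacy var still works

os.environ["GRAPH_DB_BACKEND"] = "postgres"
assert resolve_graph_db_backend() == "postgres"  # new var wins when both are set
```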
diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py
index 8292e027b..cd33a7aeb 100644
--- a/src/memos/api/handlers/chat_handler.py
+++ b/src/memos/api/handlers/chat_handler.py
@@ -7,6 +7,7 @@
import asyncio
import json
+import os
import re
import time
import traceback
@@ -23,6 +24,7 @@
APIADDRequest,
APIChatCompleteRequest,
APISearchRequest,
+ ChatBusinessRequest,
ChatPlaygroundRequest,
ChatRequest,
)
@@ -208,6 +210,8 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
query=chat_req.query,
full_response=response,
async_mode="async",
+ manager_user_id=chat_req.manager_user_id,
+ project_id=chat_req.project_id,
)
end = time.time()
self.logger.info(f"[Cloud Service] Chat Add Time: {end - start} seconds")
@@ -380,6 +384,8 @@ def generate_chat_response() -> Generator[str, None, None]:
query=chat_req.query,
full_response=full_response,
async_mode="async",
+ manager_user_id=chat_req.manager_user_id,
+ project_id=chat_req.project_id,
)
end = time.time()
self.logger.info(
@@ -561,6 +567,8 @@ def generate_chat_response() -> Generator[str, None, None]:
query=chat_req.query,
full_response=None,
async_mode="sync",
+ manager_user_id=chat_req.manager_user_id,
+ project_id=chat_req.project_id,
)
# Extract memories from search results (second search)
@@ -729,6 +737,8 @@ def generate_chat_response() -> Generator[str, None, None]:
query=chat_req.query,
full_response=full_response,
async_mode="sync",
+ manager_user_id=chat_req.manager_user_id,
+ project_id=chat_req.project_id,
)
except Exception as e:
@@ -759,6 +769,197 @@ def generate_chat_response() -> Generator[str, None, None]:
)
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
+ def handle_chat_stream_for_business_user(
+ self, chat_req: ChatBusinessRequest
+ ) -> StreamingResponse:
+ """Chat API for business user."""
+ self.logger.info(f"[ChatBusinessHandler] Chat Req is: {chat_req}")
+
+ # Validate business_key permission
+ business_chat_keys = os.environ.get("BUSINESS_CHAT_KEYS", "[]")
+ allowed_keys = json.loads(business_chat_keys)
+
+ if not allowed_keys or chat_req.business_key not in allowed_keys:
+ self.logger.warning(
+ f"[ChatBusinessHandler] Unauthorized access attempt with business_key: {chat_req.business_key}"
+ )
+ raise HTTPException(
+ status_code=403,
+ detail="Access denied: Invalid business_key. You do not have permission to use this service.",
+ )
+
+ try:
+
+ def generate_chat_response() -> Generator[str, None, None]:
+ """Generate chat stream response as SSE stream."""
+ try:
+ if chat_req.need_search:
+ # Resolve readable cube IDs (for search)
+ readable_cube_ids = chat_req.readable_cube_ids or (
+ [chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id]
+ )
+
+ search_req = APISearchRequest(
+ query=chat_req.query,
+ user_id=chat_req.user_id,
+ readable_cube_ids=readable_cube_ids,
+ mode=chat_req.mode,
+ internet_search=chat_req.internet_search,
+ top_k=chat_req.top_k,
+ chat_history=chat_req.history,
+ session_id=chat_req.session_id,
+ include_preference=chat_req.include_preference,
+ pref_top_k=chat_req.pref_top_k,
+ filter=chat_req.filter,
+ )
+
+ search_response = self.search_handler.handle_search_memories(search_req)
+
+ # Extract memories from search results
+ memories_list = []
+ if search_response.data and search_response.data.get("text_mem"):
+ text_mem_results = search_response.data["text_mem"]
+ if text_mem_results and text_mem_results[0].get("memories"):
+ memories_list = text_mem_results[0]["memories"]
+
+                        # Always drop internet (OuterMemory) results
+ memories_list = [
+ mem
+ for mem in memories_list
+ if mem.get("metadata", {}).get("memory_type") != "OuterMemory"
+ ]
+
+ # Filter memories by threshold
+ filtered_memories = self._filter_memories_by_threshold(memories_list)
+
+ # Step 2: Build system prompt with memories
+ system_prompt = self._build_system_prompt(
+ query=chat_req.query,
+ memories=filtered_memories,
+ pref_string=search_response.data.get("pref_string", ""),
+ base_prompt=chat_req.system_prompt,
+ )
+
+ self.logger.info(
+ f"[ChatBusinessHandler] chat stream user_id: {chat_req.user_id}, readable_cube_ids: {readable_cube_ids}, "
+ f"current_system_prompt: {system_prompt}"
+ )
+ else:
+ system_prompt = self._build_system_prompt(
+ query=chat_req.query,
+ memories=None,
+ pref_string=None,
+ base_prompt=chat_req.system_prompt,
+ )
+
+ # Prepare messages
+ history_info = chat_req.history[-20:] if chat_req.history else []
+ current_messages = [
+ {"role": "system", "content": system_prompt},
+ *history_info,
+ {"role": "user", "content": chat_req.query},
+ ]
+
+ # Step 3: Generate streaming response from LLM
+ if (
+ chat_req.model_name_or_path
+ and chat_req.model_name_or_path not in self.chat_llms
+ ):
+ raise HTTPException(
+ status_code=400,
+ detail=f"Model {chat_req.model_name_or_path} not suport, choose from {list(self.chat_llms.keys())}",
+ )
+
+ model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys()))
+ self.logger.info(f"[ChatBusinessHandler] Chat Stream Model: {model}")
+
+ start = time.time()
+ response_stream = self.chat_llms[model].generate_stream(
+ current_messages, model_name_or_path=model
+ )
+
+ # Stream the response
+ buffer = ""
+ full_response = ""
+ in_think = False
+
+ for chunk in response_stream:
+                        if chunk == "<think>":
+                            in_think = True
+                            continue
+                        if chunk == "</think>":
+                            in_think = False
+                            continue
+
+ if in_think:
+ chunk_data = f"data: {json.dumps({'type': 'reasoning', 'data': chunk}, ensure_ascii=False)}\n\n"
+ yield chunk_data
+ continue
+
+ buffer += chunk
+ full_response += chunk
+
+ chunk_data = f"data: {json.dumps({'type': 'text', 'data': chunk}, ensure_ascii=False)}\n\n"
+ yield chunk_data
+
+ end = time.time()
+ self.logger.info(
+ f"[ChatBusinessHandler] Chat Stream Time: {end - start} seconds"
+ )
+
+ self.logger.info(
+ f"[ChatBusinessHandler] Chat Stream LLM Input: {json.dumps(current_messages, ensure_ascii=False)} Chat Stream LLM Response: {full_response}"
+ )
+
+ current_messages.append({"role": "assistant", "content": full_response})
+ if chat_req.add_message_on_answer:
+ # Resolve writable cube IDs (for add)
+ writable_cube_ids = chat_req.writable_cube_ids or (
+ [chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id]
+ )
+ start = time.time()
+ self._start_add_to_memory(
+ user_id=chat_req.user_id,
+ writable_cube_ids=writable_cube_ids,
+ session_id=chat_req.session_id or "default_session",
+ query=chat_req.query,
+ full_response=full_response,
+ async_mode="async",
+ manager_user_id=chat_req.manager_user_id,
+ project_id=chat_req.project_id,
+ )
+ end = time.time()
+ self.logger.info(
+ f"[ChatBusinessHandler] Chat Stream Add Time: {end - start} seconds"
+ )
+ except Exception as e:
+ self.logger.error(
+ f"[ChatBusinessHandler] Error in chat stream: {e}", exc_info=True
+ )
+ error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n"
+ yield error_data
+
+ return StreamingResponse(
+ generate_chat_response(),
+ media_type="text/event-stream",
+ headers={
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ "Content-Type": "text/event-stream",
+ "Access-Control-Allow-Origin": "*",
+ "Access-Control-Allow-Headers": "*",
+ "Access-Control-Allow-Methods": "*",
+ },
+ )
+
+ except ValueError as err:
+ raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
+ except Exception as err:
+ self.logger.error(
+ f"[ChatBusinessHandler] Failed to start chat stream: {traceback.format_exc()}"
+ )
+ raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
+
def _dedup_and_supplement_memories(
self, first_filtered_memories: list, second_filtered_memories: list
) -> list:
@@ -1118,6 +1319,8 @@ async def _add_conversation_to_memory(
writable_cube_ids: list[str],
session_id: str,
query: str,
+ manager_user_id: str | None = None,
+ project_id: str | None = None,
clean_response: str | None = None,
async_mode: Literal["async", "sync"] = "sync",
) -> None:
@@ -1142,6 +1345,8 @@ async def _add_conversation_to_memory(
session_id=session_id,
messages=messages,
async_mode=async_mode,
+ manager_user_id=manager_user_id,
+ project_id=project_id,
)
self.add_handler.handle_add_memories(add_req)
@@ -1323,12 +1528,14 @@ def run_async_in_thread():
)
# Add exception handling for the background task
task.add_done_callback(
- lambda t: self.logger.error(
- f"Error in background post-chat processing for user {user_id}: {t.exception()}",
- exc_info=True,
+ lambda t: (
+ self.logger.error(
+ f"Error in background post-chat processing for user {user_id}: {t.exception()}",
+ exc_info=True,
+ )
+ if t.exception()
+ else None
)
- if t.exception()
- else None
)
except RuntimeError:
# No event loop, run in a new thread with context propagation
@@ -1347,7 +1554,13 @@ def _start_add_to_memory(
query: str,
full_response: str | None = None,
async_mode: Literal["async", "sync"] = "sync",
+ manager_user_id: str | None = None,
+ project_id: str | None = None,
) -> None:
+ self.logger.info(
+ f"Start add to memory for user {user_id}, writable_cube_ids: {writable_cube_ids}, session_id: {session_id}, query: {query}, full_response: {full_response}, async_mode: {async_mode}, manager_user_id: {manager_user_id}, project_id: {project_id}"
+ )
+
def run_async_in_thread():
try:
loop = asyncio.new_event_loop()
@@ -1364,6 +1577,8 @@ def run_async_in_thread():
query=query,
clean_response=clean_response,
async_mode=async_mode,
+ manager_user_id=manager_user_id,
+ project_id=project_id,
)
)
finally:
@@ -1387,15 +1602,19 @@ def run_async_in_thread():
query=query,
clean_response=clean_response,
async_mode=async_mode,
+ manager_user_id=manager_user_id,
+ project_id=project_id,
)
)
task.add_done_callback(
- lambda t: self.logger.error(
- f"Error in background add to memory for user {user_id}: {t.exception()}",
- exc_info=True,
+ lambda t: (
+ self.logger.error(
+ f"Error in background add to memory for user {user_id}: {t.exception()}",
+ exc_info=True,
+ )
+ if t.exception()
+ else None
)
- if t.exception()
- else None
)
except RuntimeError:
thread = ContextThread(
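On the consumer side, the SSE frames emitted by these handlers carry a `type` of `text`, `reasoning`, or `error`; the first two put the chunk in a `data` field, while errors use `content`. A minimal client sketch follows; the route path and payload are illustrative, and only the frame shapes come from the handler above:

```python
import json

import requests

# Hypothetical endpoint and payload; the actual route is defined in the
# router wiring, not in this diff.
resp = requests.post(
    "http://localhost:8000/product/chat/business",
    json={"business_key": "demo-key", "user_id": "u1", "query": "hello"},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue  # skip SSE keep-alive blank lines
    event = json.loads(line[len("data: "):])
    if event["type"] == "text":
        print(event["data"], end="", flush=True)
    elif event["type"] == "reasoning":
        pass  # chunks streamed between <think> and </think>
    elif event["type"] == "error":
        raise RuntimeError(event["content"])
```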
diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py
index 13dd92189..ba527d602 100644
--- a/src/memos/api/handlers/component_init.py
+++ b/src/memos/api/handlers/component_init.py
@@ -43,6 +43,7 @@
)
from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
from memos.memories.textual.simple_tree import SimpleTreeTextMemory
+from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
@@ -190,6 +191,7 @@ def init_server() -> dict[str, Any]:
)
embedder = EmbedderFactory.from_config(embedder_config)
nli_client = NLIClient(base_url=nli_client_config["base_url"])
+ memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db)
# Pass graph_db to mem_reader for recall operations (deduplication, conflict detection)
mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db)
reranker = RerankerFactory.from_config(reranker_config)
@@ -393,4 +395,5 @@ def init_server() -> dict[str, Any]:
"redis_client": redis_client,
"deepsearch_agent": deepsearch_agent,
"nli_client": nli_client,
+ "memory_history_manager": memory_history_manager,
}
diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py
index ed673977a..2b3fbdd35 100644
--- a/src/memos/api/handlers/config_builders.py
+++ b/src/memos/api/handlers/config_builders.py
@@ -41,9 +41,11 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
"neo4j": APIConfig.get_neo4j_config(user_id=user_id),
"nebular": APIConfig.get_nebular_config(user_id=user_id),
"polardb": APIConfig.get_polardb_config(user_id=user_id),
+ "postgres": APIConfig.get_postgres_config(user_id=user_id),
}
- graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower()
+ # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
+ graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower()
return GraphDBConfigFactory.model_validate(
{
"backend": graph_db_backend,
diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py
index cecc42c6c..06c4fd223 100644
--- a/src/memos/api/handlers/formatters_handler.py
+++ b/src/memos/api/handlers/formatters_handler.py
@@ -113,7 +113,7 @@ def post_process_textual_mem(
mem
for mem in text_formatted_mem
if mem["metadata"]["memory_type"]
- in ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"]
+ in ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory", "RawFileMemory"]
]
tool_mem = [
mem
@@ -157,12 +157,13 @@ def separate_knowledge_and_conversation_mem(memories: list[dict[str, Any]]):
for item in memories:
sources = item.get("metadata", {}).get("sources", [])
if (
- len(sources) > 0
+ item["metadata"]["memory_type"] != "RawFileMemory"
+ and len(sources) > 0
and "type" in sources[0]
and sources[0]["type"] == "file"
and "content" in sources[0]
and sources[0]["content"] != ""
- ): # TODO change to memory_type
+ ):
knowledge_mem.append(item)
else:
conversation_mem.append(item)
@@ -203,8 +204,7 @@ def rerank_knowledge_mem(
key=lambda item: item.get("metadata", {}).get("relativity", 0.0),
reverse=True,
)
-
- # TODO revoke sources replace memory value
+    # Replace the memory value with sources[0]["content"] for LongTermMemory, WorkingMemory, or UserMemory items
for item in reranked_knowledge_mem:
item["memory"] = item["metadata"]["sources"][0]["content"]
item["metadata"]["sources"] = []
diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py
index e8bc5b640..a3430d475 100644
--- a/src/memos/api/handlers/memory_handler.py
+++ b/src/memos/api/handlers/memory_handler.py
@@ -9,11 +9,11 @@
from memos.api.handlers.formatters_handler import (
format_memory_item,
post_process_pref_mem,
- post_process_textual_mem,
)
from memos.api.product_models import (
DeleteMemoryRequest,
DeleteMemoryResponse,
+ GetMemoryDashboardRequest,
GetMemoryRequest,
GetMemoryResponse,
MemoryResponse,
@@ -109,6 +109,7 @@ def handle_get_subgraph(
query: str,
top_k: int,
naive_mem_cube: Any,
+ search_type: Literal["embedding", "fulltext"],
) -> MemoryResponse:
"""
Main handler for getting memory subgraph based on query.
@@ -128,7 +129,7 @@ def handle_get_subgraph(
try:
# Get relevant subgraph from text memory
memories = naive_mem_cube.text_mem.get_relevant_subgraph(
- query, top_k=top_k, user_name=mem_cube_id
+ query, top_k=top_k, user_name=mem_cube_id, search_type=search_type
)
# Format and convert to tree structure
@@ -139,7 +140,7 @@ def handle_get_subgraph(
"UserMemory": 0.40,
}
tree_result, node_type_count = convert_graph_to_tree_forworkmem(
- memories_cleaned, target_node_count=150, type_ratios=custom_type_ratios
+ memories_cleaned, target_node_count=200, type_ratios=custom_type_ratios
)
# Ensure all node IDs are unique in the tree structure
tree_result = ensure_unique_tree_ids(tree_result)
@@ -249,22 +250,68 @@ def handle_get_memories(
get_mem_req: GetMemoryRequest, naive_mem_cube: NaiveMemCube
) -> GetMemoryResponse:
results: dict[str, Any] = {"text_mem": [], "pref_mem": [], "tool_mem": [], "skill_mem": []}
- memories = naive_mem_cube.text_mem.get_all(
+ text_memory_type = ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"]
+ text_memories_info = naive_mem_cube.text_mem.get_all(
user_name=get_mem_req.mem_cube_id,
user_id=get_mem_req.user_id,
page=get_mem_req.page,
page_size=get_mem_req.page_size,
filter=get_mem_req.filter,
- )["nodes"]
+ memory_type=text_memory_type,
+ )
+ text_memories, total_text_nodes = text_memories_info["nodes"], text_memories_info["total_nodes"]
+ results["text_mem"] = [
+ {
+ "cube_id": get_mem_req.mem_cube_id,
+ "memories": text_memories,
+ "total_nodes": total_text_nodes,
+ }
+ ]
- results = post_process_textual_mem(results, memories, get_mem_req.mem_cube_id)
+ if get_mem_req.include_tool_memory:
+ tool_memories_info = naive_mem_cube.text_mem.get_all(
+ user_name=get_mem_req.mem_cube_id,
+ user_id=get_mem_req.user_id,
+ page=get_mem_req.page,
+ page_size=get_mem_req.page_size,
+ filter=get_mem_req.filter,
+ memory_type=["ToolSchemaMemory", "ToolTrajectoryMemory"],
+ )
+ tool_memories, total_tool_nodes = (
+ tool_memories_info["nodes"],
+ tool_memories_info["total_nodes"],
+ )
- if not get_mem_req.include_tool_memory:
- results["tool_mem"] = []
- if not get_mem_req.include_skill_memory:
- results["skill_mem"] = []
+ results["tool_mem"] = [
+ {
+ "cube_id": get_mem_req.mem_cube_id,
+ "memories": tool_memories,
+ "total_nodes": total_tool_nodes,
+ }
+ ]
+ if get_mem_req.include_skill_memory:
+ skill_memories_info = naive_mem_cube.text_mem.get_all(
+ user_name=get_mem_req.mem_cube_id,
+ user_id=get_mem_req.user_id,
+ page=get_mem_req.page,
+ page_size=get_mem_req.page_size,
+ filter=get_mem_req.filter,
+ memory_type=["SkillMemory"],
+ )
+ skill_memories, total_skill_nodes = (
+ skill_memories_info["nodes"],
+ skill_memories_info["total_nodes"],
+ )
+ results["skill_mem"] = [
+ {
+ "cube_id": get_mem_req.mem_cube_id,
+ "memories": skill_memories,
+ "total_nodes": total_skill_nodes,
+ }
+ ]
preferences: list[TextualMemoryItem] = []
+ total_preference_nodes = 0
format_preferences = []
if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None:
@@ -293,7 +340,7 @@ def handle_get_memories(
filter_params.update(filter_copy)
- preferences, _ = naive_mem_cube.pref_mem.get_memory_by_filter(
+ preferences, total_preference_nodes = naive_mem_cube.pref_mem.get_memory_by_filter(
filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size
)
format_preferences = [format_memory_item(item, save_sources=False) for item in preferences]
@@ -301,6 +348,8 @@ def handle_get_memories(
results = post_process_pref_mem(
results, format_preferences, get_mem_req.mem_cube_id, get_mem_req.include_preference
)
+ if total_preference_nodes > 0 and results.get("pref_mem", []):
+ results["pref_mem"][0]["total_nodes"] = total_preference_nodes
# Filter to only keep text_mem, pref_mem, tool_mem
filtered_results = {
@@ -352,3 +401,181 @@ def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube:
message="Memories deleted successfully",
data={"status": "success"},
)
+
+
+# =============================================================================
+# Other handler endpoints (for internal use)
+# =============================================================================
+
+
+def handle_get_memories_dashboard(
+ get_mem_req: GetMemoryDashboardRequest, naive_mem_cube: NaiveMemCube
+) -> GetMemoryResponse:
+ results: dict[str, Any] = {"text_mem": [], "pref_mem": [], "tool_mem": [], "skill_mem": []}
+ text_memory_type = ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"]
+ text_memories_info = naive_mem_cube.text_mem.get_all(
+ user_name=get_mem_req.mem_cube_id,
+ user_id=get_mem_req.user_id,
+ page=get_mem_req.page,
+ page_size=get_mem_req.page_size,
+ filter=get_mem_req.filter,
+ memory_type=text_memory_type,
+ )
+ text_memories, _ = text_memories_info["nodes"], text_memories_info["total_nodes"]
+
+ # Group text memories by cube_id from metadata.user_name
+ text_mem_by_cube: dict[str, list] = {}
+ for memory in text_memories:
+ cube_id = memory.get("metadata", {}).get("user_name", get_mem_req.mem_cube_id)
+ if cube_id not in text_mem_by_cube:
+ text_mem_by_cube[cube_id] = []
+ text_mem_by_cube[cube_id].append(memory)
+
+ # If no memories found, create a default entry with the requested cube_id
+ if not text_mem_by_cube and get_mem_req.mem_cube_id:
+ text_mem_by_cube[get_mem_req.mem_cube_id] = []
+
+ results["text_mem"] = [
+ {
+ "cube_id": cube_id,
+ "memories": memories,
+ "total_nodes": len(memories),
+ }
+ for cube_id, memories in text_mem_by_cube.items()
+ ]
+
+ if get_mem_req.include_tool_memory:
+ tool_memories_info = naive_mem_cube.text_mem.get_all(
+ user_name=get_mem_req.mem_cube_id,
+ user_id=get_mem_req.user_id,
+ page=get_mem_req.page,
+ page_size=get_mem_req.page_size,
+ filter=get_mem_req.filter,
+ memory_type=["ToolSchemaMemory", "ToolTrajectoryMemory"],
+ )
+ tool_memories, _ = (
+ tool_memories_info["nodes"],
+ tool_memories_info["total_nodes"],
+ )
+
+ # Group tool memories by cube_id from metadata.user_name
+ tool_mem_by_cube: dict[str, list] = {}
+ for memory in tool_memories:
+ cube_id = memory.get("metadata", {}).get("user_name", get_mem_req.mem_cube_id)
+ if cube_id not in tool_mem_by_cube:
+ tool_mem_by_cube[cube_id] = []
+ tool_mem_by_cube[cube_id].append(memory)
+
+ # If no memories found, create a default entry with the requested cube_id
+ if not tool_mem_by_cube and get_mem_req.mem_cube_id:
+ tool_mem_by_cube[get_mem_req.mem_cube_id] = []
+
+ results["tool_mem"] = [
+ {
+ "cube_id": cube_id,
+ "memories": memories,
+ "total_nodes": len(memories),
+ }
+ for cube_id, memories in tool_mem_by_cube.items()
+ ]
+
+ if get_mem_req.include_skill_memory:
+ skill_memories_info = naive_mem_cube.text_mem.get_all(
+ user_name=get_mem_req.mem_cube_id,
+ user_id=get_mem_req.user_id,
+ page=get_mem_req.page,
+ page_size=get_mem_req.page_size,
+ filter=get_mem_req.filter,
+ memory_type=["SkillMemory"],
+ )
+ skill_memories, _ = (
+ skill_memories_info["nodes"],
+ skill_memories_info["total_nodes"],
+ )
+
+ # Group skill memories by cube_id from metadata.user_name
+ skill_mem_by_cube: dict[str, list] = {}
+ for memory in skill_memories:
+ cube_id = memory.get("metadata", {}).get("user_name", get_mem_req.mem_cube_id)
+ if cube_id not in skill_mem_by_cube:
+ skill_mem_by_cube[cube_id] = []
+ skill_mem_by_cube[cube_id].append(memory)
+
+ # If no memories found, create a default entry with the requested cube_id
+ if not skill_mem_by_cube and get_mem_req.mem_cube_id:
+ skill_mem_by_cube[get_mem_req.mem_cube_id] = []
+
+ results["skill_mem"] = [
+ {
+ "cube_id": cube_id,
+ "memories": memories,
+ "total_nodes": len(memories),
+ }
+ for cube_id, memories in skill_mem_by_cube.items()
+ ]
+
+ preferences: list[TextualMemoryItem] = []
+ total_preference_nodes = 0
+
+ format_preferences = []
+ if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None:
+ filter_params: dict[str, Any] = {}
+ if get_mem_req.user_id is not None:
+ filter_params["user_id"] = get_mem_req.user_id
+ if get_mem_req.mem_cube_id is not None:
+ filter_params["mem_cube_id"] = get_mem_req.mem_cube_id
+ if get_mem_req.filter is not None:
+ # Check and remove user_id/mem_cube_id from filter if present
+ filter_copy = get_mem_req.filter.copy()
+ removed_fields = []
+
+ if "user_id" in filter_copy:
+ filter_copy.pop("user_id")
+ removed_fields.append("user_id")
+ if "mem_cube_id" in filter_copy:
+ filter_copy.pop("mem_cube_id")
+ removed_fields.append("mem_cube_id")
+
+ if removed_fields:
+ logger.warning(
+ f"Fields {removed_fields} found in filter will be ignored. "
+ f"Use request-level user_id/mem_cube_id parameters instead."
+ )
+
+ filter_params.update(filter_copy)
+
+ preferences, total_preference_nodes = naive_mem_cube.pref_mem.get_memory_by_filter(
+ filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size
+ )
+ format_preferences = [format_memory_item(item, save_sources=False) for item in preferences]
+
+ # Group preferences by cube_id from metadata.mem_cube_id
+ pref_mem_by_cube: dict[str, list] = {}
+ for pref in format_preferences:
+ cube_id = pref.get("metadata", {}).get("mem_cube_id", get_mem_req.mem_cube_id)
+ if cube_id not in pref_mem_by_cube:
+ pref_mem_by_cube[cube_id] = []
+ pref_mem_by_cube[cube_id].append(pref)
+
+ # If no preferences found, create a default entry with the requested cube_id
+ if not pref_mem_by_cube and get_mem_req.mem_cube_id:
+ pref_mem_by_cube[get_mem_req.mem_cube_id] = []
+
+ results["pref_mem"] = [
+ {
+ "cube_id": cube_id,
+ "memories": memories,
+ "total_nodes": len(memories),
+ }
+ for cube_id, memories in pref_mem_by_cube.items()
+ ]
+
+ # Filter to only keep text_mem, pref_mem, tool_mem, skill_mem
+ filtered_results = {
+ "text_mem": results.get("text_mem", []),
+ "pref_mem": results.get("pref_mem", []),
+ "tool_mem": results.get("tool_mem", []),
+ "skill_mem": results.get("skill_mem", []),
+ }
+
+ return GetMemoryResponse(message="Memories retrieved successfully", data=filtered_results)
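The dashboard handler repeats one grouping idiom per memory kind: bucket items by `metadata.user_name` (falling back to the requested cube), then emit one `{cube_id, memories, total_nodes}` entry per bucket. An equivalent standalone sketch of that idiom:

```python
from collections import defaultdict
from typing import Any


def group_by_cube(
    memories: list[dict[str, Any]], default_cube_id: str
) -> list[dict[str, Any]]:
    by_cube: dict[str, list[dict[str, Any]]] = defaultdict(list)
    for memory in memories:
        cube_id = memory.get("metadata", {}).get("user_name", default_cube_id)
        by_cube[cube_id].append(memory)
    if not by_cube and default_cube_id:
        by_cube[default_cube_id] = []  # keep an empty entry for the requested cube
    return [
        {"cube_id": cube_id, "memories": mems, "total_nodes": len(mems)}
        for cube_id, mems in by_cube.items()
    ]
```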
diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py
index 6eda1e2aa..91980bdeb 100644
--- a/src/memos/api/handlers/search_handler.py
+++ b/src/memos/api/handlers/search_handler.py
@@ -11,6 +11,7 @@
from typing import Any
from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
+from memos.api.handlers.formatters_handler import rerank_knowledge_mem
from memos.api.product_models import APISearchRequest, SearchResponse
from memos.log import get_logger
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
@@ -60,39 +61,36 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
# Use deepcopy to avoid modifying the original request object
search_req_local = copy.deepcopy(search_req)
- original_top_k = search_req_local.top_k
# Expand top_k for deduplication (5x to ensure enough candidates)
if search_req_local.dedup in ("sim", "mmr"):
- search_req_local.top_k = original_top_k * 5
-
- # Create new searcher with include_embedding for MMR deduplication
- searcher_to_use = self.searcher
- if search_req_local.dedup == "mmr":
- text_mem = getattr(self.naive_mem_cube, "text_mem", None)
- if text_mem is not None:
- # Create new searcher instance with include_embedding=True
- searcher_to_use = text_mem.get_searcher(
- manual_close_internet=not getattr(self.searcher, "internet_retriever", None),
- moscube=False,
- process_llm=getattr(self.mem_reader, "llm", None),
- )
- # Override include_embedding for this searcher
- if hasattr(searcher_to_use, "graph_retriever"):
- searcher_to_use.graph_retriever.include_embedding = True
+ search_req_local.top_k = search_req_local.top_k * 5
# Search and deduplicate
- cube_view = self._build_cube_view(search_req_local, searcher_to_use)
+ cube_view = self._build_cube_view(search_req_local)
results = cube_view.search_memories(search_req_local)
+ if not search_req_local.relativity:
+ search_req_local.relativity = 0
+ self.logger.info(f"[SearchHandler] Relativity filter: {search_req_local.relativity}")
+ results = self._apply_relativity_threshold(results, search_req_local.relativity)
if search_req_local.dedup == "sim":
- results = self._dedup_text_memories(results, original_top_k)
+ results = self._dedup_text_memories(results, search_req.top_k)
self._strip_embeddings(results)
elif search_req_local.dedup == "mmr":
pref_top_k = getattr(search_req_local, "pref_top_k", 6)
- results = self._mmr_dedup_text_memories(results, original_top_k, pref_top_k)
+ results = self._mmr_dedup_text_memories(results, search_req.top_k, pref_top_k)
self._strip_embeddings(results)
+ text_mem = results["text_mem"]
+ results["text_mem"] = rerank_knowledge_mem(
+ self.reranker,
+ query=search_req.query,
+ text_mem=text_mem,
+ top_k=search_req_local.top_k,
+ file_mem_proportion=0.5,
+ )
+
self.logger.info(
f"[SearchHandler] Final search results: count={len(results)} results={results}"
)
@@ -102,6 +100,40 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
data=results,
)
+ @staticmethod
+ def _apply_relativity_threshold(results: dict[str, Any], relativity: float) -> dict[str, Any]:
+ if relativity <= 0:
+ return results
+
+ for key in ("text_mem", "pref_mem"):
+ buckets = results.get(key)
+ if not isinstance(buckets, list):
+ continue
+
+ for bucket in buckets:
+ memories = bucket.get("memories")
+ if not isinstance(memories, list):
+ continue
+
+ filtered: list[dict[str, Any]] = []
+ for mem in memories:
+ if not isinstance(mem, dict):
+ continue
+ meta = mem.get("metadata", {})
+ score = meta.get("relativity", 0.0) if isinstance(meta, dict) else 0.0
+ try:
+ score_val = float(score) if score is not None else 0.0
+ except (TypeError, ValueError):
+ score_val = 0.0
+ if score_val >= relativity:
+ filtered.append(mem)
+
+ bucket["memories"] = filtered
+ if "total_nodes" in bucket:
+ bucket["total_nodes"] = len(filtered)
+
+ return results
+
def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> dict[str, Any]:
buckets = results.get("text_mem", [])
if not buckets:
@@ -169,7 +201,7 @@ def _mmr_dedup_text_memories(
3. Re-sort by original relevance for better generation quality
"""
text_buckets = results.get("text_mem", [])
- pref_buckets = results.get("preference", [])
+ pref_buckets = results.get("pref_mem", [])
# Early return if no memories to deduplicate
if not text_buckets and not pref_buckets:
@@ -238,7 +270,7 @@ def _mmr_dedup_text_memories(
# Skip if highly similar (Dice + TF-IDF + 2-gram combined, with embedding filter)
if SearchHandler._is_text_highly_similar_optimized(
- idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.9
+ idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.92
):
continue
@@ -281,7 +313,7 @@ def _mmr_dedup_text_memories(
# Skip if highly similar (Dice + TF-IDF + 2-gram combined, with embedding filter)
if SearchHandler._is_text_highly_similar_optimized(
- idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.9
+ idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.92
):
continue # Skip highly similar text, don't participate in MMR competition
@@ -547,6 +579,168 @@ def _is_text_highly_similar_optimized(
return combined_score >= threshold
+ @staticmethod
+ def _dice_similarity(text1: str, text2: str) -> float:
+ """
+ Calculate Dice coefficient (character-level, fastest).
+
+ Dice = 2 * |A ∩ B| / (|A| + |B|)
+ Speed: O(n + m), ~0.05-0.1ms per comparison
+
+ Args:
+ text1: First text string
+ text2: Second text string
+
+ Returns:
+ Dice similarity score between 0.0 and 1.0
+ """
+ if not text1 or not text2:
+ return 0.0
+
+ chars1 = set(text1)
+ chars2 = set(text2)
+
+ intersection = len(chars1 & chars2)
+ return 2 * intersection / (len(chars1) + len(chars2))
+
+ @staticmethod
+ def _bigram_similarity(text1: str, text2: str) -> float:
+ """
+ Calculate character-level 2-gram Jaccard similarity.
+
+ Speed: O(n + m), ~0.1-0.2ms per comparison
+ Considers local order (more strict than Dice).
+
+ Args:
+ text1: First text string
+ text2: Second text string
+
+ Returns:
+ Jaccard similarity score between 0.0 and 1.0
+ """
+ if not text1 or not text2:
+ return 0.0
+
+ # Generate 2-grams
+ bigrams1 = {text1[i : i + 2] for i in range(len(text1) - 1)} if len(text1) >= 2 else {text1}
+ bigrams2 = {text2[i : i + 2] for i in range(len(text2) - 1)} if len(text2) >= 2 else {text2}
+
+ intersection = len(bigrams1 & bigrams2)
+ union = len(bigrams1 | bigrams2)
+
+ return intersection / union if union > 0 else 0.0
+
+ @staticmethod
+ def _tfidf_similarity(text1: str, text2: str) -> float:
+ """
+ Calculate TF-IDF cosine similarity (character-level, no sklearn).
+
+ Speed: O(n + m), ~0.3-0.5ms per comparison
+ Considers character frequency weighting.
+
+ Args:
+ text1: First text string
+ text2: Second text string
+
+ Returns:
+ Cosine similarity score between 0.0 and 1.0
+ """
+ if not text1 or not text2:
+ return 0.0
+
+ from collections import Counter
+
+ # Character frequency (TF)
+ tf1 = Counter(text1)
+ tf2 = Counter(text2)
+
+ # All unique characters (vocabulary)
+ vocab = set(tf1.keys()) | set(tf2.keys())
+
+        # Simplified IDF: characters shared by both texts get weight 1.0,
+        # while characters unique to one text get 1.5, so distinctive
+        # characters contribute more to the cosine similarity
+ idf = {char: (1.0 if char in tf1 and char in tf2 else 1.5) for char in vocab}
+
+ # TF-IDF vectors
+ vec1 = {char: tf1.get(char, 0) * idf[char] for char in vocab}
+ vec2 = {char: tf2.get(char, 0) * idf[char] for char in vocab}
+
+ # Cosine similarity
+ dot_product = sum(vec1[char] * vec2[char] for char in vocab)
+ norm1 = math.sqrt(sum(v * v for v in vec1.values()))
+ norm2 = math.sqrt(sum(v * v for v in vec2.values()))
+
+ if norm1 == 0 or norm2 == 0:
+ return 0.0
+
+ return dot_product / (norm1 * norm2)
+
+ @staticmethod
+ def _is_text_highly_similar_optimized(
+ candidate_idx: int,
+ candidate_text: str,
+ selected_global: list[int],
+ similarity_matrix,
+ flat: list,
+ threshold: float = 0.92,
+ ) -> bool:
+ """
+ Multi-algorithm text similarity check with embedding pre-filtering.
+
+ Strategy:
+        1. Only compare with the single highest-embedding-similarity item (not every selected one)
+        2. Only perform text comparison if embedding similarity > 0.9
+ 3. Use weighted combination of three algorithms:
+ - Dice (40%): Fastest, character-level set similarity
+ - TF-IDF (35%): Considers character frequency weighting
+ - 2-gram (25%): Considers local character order
+
+ Combined formula:
+ combined_score = 0.40 * dice + 0.35 * tfidf + 0.25 * bigram
+
+ This reduces comparisons from O(N) to O(1) per candidate, with embedding pre-filtering.
+ Expected speedup: 100-200x compared to LCS approach.
+
+ Args:
+ candidate_idx: Index of candidate memory in flat list
+ candidate_text: Text content of candidate memory
+ selected_global: List of already selected memory indices
+ similarity_matrix: Precomputed embedding similarity matrix
+ flat: Flat list of all memories
+            threshold: Combined similarity threshold (default 0.92)
+
+ Returns:
+ True if candidate is highly similar to any selected memory
+ """
+ if not selected_global:
+ return False
+
+ # Find the already-selected memory with highest embedding similarity
+ max_sim_idx = max(selected_global, key=lambda j: similarity_matrix[candidate_idx][j])
+ max_sim = similarity_matrix[candidate_idx][max_sim_idx]
+
+        # If the highest embedding similarity is at most 0.9, skip text comparison entirely
+ if max_sim <= 0.9:
+ return False
+
+ # Get text of most similar memory
+ most_similar_mem = flat[max_sim_idx][2]
+ most_similar_text = most_similar_mem.get("memory", "").strip()
+
+ # Calculate three similarity scores
+ dice_sim = SearchHandler._dice_similarity(candidate_text, most_similar_text)
+ tfidf_sim = SearchHandler._tfidf_similarity(candidate_text, most_similar_text)
+ bigram_sim = SearchHandler._bigram_similarity(candidate_text, most_similar_text)
+
+ # Weighted combination: Dice (40%) + TF-IDF (35%) + 2-gram (25%)
+ # Dice has highest weight (fastest and most reliable)
+ # TF-IDF considers frequency (handles repeated characters well)
+ # 2-gram considers order (catches local pattern similarity)
+ combined_score = 0.40 * dice_sim + 0.35 * tfidf_sim + 0.25 * bigram_sim
+
+ return combined_score >= threshold
+
def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:
"""
Normalize target cube ids from search_req.
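A quick standalone check of the weighted combination above (a sketch that calls the static helpers directly; it assumes the module's own imports, including `math`, resolve):

```python
from memos.api.handlers.search_handler import SearchHandler

t1 = "Caroline met up with friends in early July 2023."
t2 = "Caroline met up with her friends in early July, 2023."

dice = SearchHandler._dice_similarity(t1, t2)
tfidf = SearchHandler._tfidf_similarity(t1, t2)
bigram = SearchHandler._bigram_similarity(t1, t2)

# Same weights as the handler: Dice 40%, TF-IDF 35%, 2-gram 25%.
combined = 0.40 * dice + 0.35 * tfidf + 0.25 * bigram
print(combined)  # near-duplicates land close to 1.0; the cutoff is 0.92
```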
diff --git a/src/memos/api/mcp_serve.py b/src/memos/api/mcp_serve.py
index ce2e41390..8f8e70311 100644
--- a/src/memos/api/mcp_serve.py
+++ b/src/memos/api/mcp_serve.py
@@ -122,15 +122,6 @@ def load_default_config(user_id="default_user"):
return config, cube
-class MOSMCPStdioServer:
- def __init__(self):
- self.mcp = FastMCP("MOS Memory System")
- config, cube = load_default_config()
- self.mos_core = MOS(config=config)
- self.mos_core.register_mem_cube(cube)
- self._setup_tools()
-
-
class MOSMCPServer:
"""MCP Server that accepts an existing MOS instance."""
@@ -584,7 +575,6 @@ def _run_mcp(self, transport: str = "stdio", **kwargs):
raise ValueError(f"Unsupported transport: {transport}")
-MOSMCPStdioServer.run = _run_mcp
MOSMCPServer.run = _run_mcp
@@ -610,5 +600,5 @@ def _run_mcp(self, transport: str = "stdio", **kwargs):
args = parser.parse_args()
# Create and run MCP server
- server = MOSMCPStdioServer()
+ server = MOSMCPServer()
server.run(transport=args.transport, host=args.host, port=args.port)
diff --git a/src/memos/api/middleware/__init__.py b/src/memos/api/middleware/__init__.py
new file mode 100644
index 000000000..64cbc5c60
--- /dev/null
+++ b/src/memos/api/middleware/__init__.py
@@ -0,0 +1,13 @@
+"""Krolik middleware extensions for MemOS."""
+
+from .auth import verify_api_key, require_scope, require_admin, require_read, require_write
+from .rate_limit import RateLimitMiddleware
+
+__all__ = [
+ "verify_api_key",
+ "require_scope",
+ "require_admin",
+ "require_read",
+ "require_write",
+ "RateLimitMiddleware",
+]
diff --git a/src/memos/api/middleware/auth.py b/src/memos/api/middleware/auth.py
new file mode 100644
index 000000000..15b217651
--- /dev/null
+++ b/src/memos/api/middleware/auth.py
@@ -0,0 +1,268 @@
+"""
+API Key Authentication Middleware for MemOS.
+
+Validates API keys and extracts user context for downstream handlers.
+Keys are validated against SHA-256 hashes stored in PostgreSQL.
+"""
+
+import hashlib
+import os
+
+from datetime import datetime, timezone
+from typing import Any
+
+from fastapi import Depends, HTTPException, Request, Security
+from fastapi.security import APIKeyHeader
+
+import memos.log
+
+
+logger = memos.log.get_logger(__name__)
+
+# API key header configuration
+API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=False)
+
+# Environment configuration
+AUTH_ENABLED = os.getenv("AUTH_ENABLED", "false").lower() == "true"
+MASTER_KEY_HASH = os.getenv("MASTER_KEY_HASH") # SHA-256 hash of master key
+INTERNAL_SERVICE_IPS = {"127.0.0.1", "::1", "memos-mcp", "moltbot", "clawdbot"}
+
+# Connection pool for auth queries (lazy init)
+_auth_pool = None
+
+
+def _get_auth_pool():
+ """Get or create auth database connection pool."""
+ global _auth_pool
+ if _auth_pool is not None:
+ return _auth_pool
+
+ try:
+ import psycopg2.pool
+
+ _auth_pool = psycopg2.pool.ThreadedConnectionPool(
+ minconn=1,
+ maxconn=5,
+ host=os.getenv("POSTGRES_HOST", "postgres"),
+ port=int(os.getenv("POSTGRES_PORT", "5432")),
+ user=os.getenv("POSTGRES_USER", "memos"),
+ password=os.getenv("POSTGRES_PASSWORD", ""),
+ dbname=os.getenv("POSTGRES_DB", "memos"),
+ connect_timeout=10,
+ )
+ logger.info("Auth database pool initialized")
+ return _auth_pool
+ except Exception as e:
+ logger.error(f"Failed to initialize auth pool: {e}")
+ return None
+
+
+def hash_api_key(key: str) -> str:
+ """Hash an API key using SHA-256."""
+ return hashlib.sha256(key.encode()).hexdigest()
+
+
+def validate_key_format(key: str) -> bool:
+ """Validate API key format: krlk_<64-hex>."""
+ if not key or not key.startswith("krlk_"):
+ return False
+ hex_part = key[5:] # Remove 'krlk_' prefix
+ if len(hex_part) != 64:
+ return False
+ try:
+ int(hex_part, 16)
+ return True
+ except ValueError:
+ return False
+
+
+def get_key_prefix(key: str) -> str:
+ """Extract prefix for key identification (first 12 chars)."""
+ return key[:12] if len(key) >= 12 else key
+
+
+async def lookup_api_key(key_hash: str) -> dict[str, Any] | None:
+ """
+ Look up API key in database.
+
+ Returns dict with user_name, scopes, etc. or None if not found.
+ """
+ pool = _get_auth_pool()
+ if not pool:
+ logger.warning("Auth pool not available, cannot validate key")
+ return None
+
+ conn = None
+ try:
+ conn = pool.getconn()
+ with conn.cursor() as cur:
+ cur.execute(
+ """
+ SELECT id, user_name, scopes, expires_at, is_active
+ FROM api_keys
+ WHERE key_hash = %s
+ """,
+ (key_hash,),
+ )
+ row = cur.fetchone()
+
+ if not row:
+ return None
+
+ key_id, user_name, scopes, expires_at, is_active = row
+
+ # Check if key is active
+ if not is_active:
+ logger.warning(f"Inactive API key used: {key_hash[:16]}...")
+ return None
+
+            # Check expiration (psycopg2 returns timestamptz columns as
+            # timezone-aware datetimes, so compare against an aware "now")
+            if expires_at and expires_at < datetime.now(timezone.utc):
+ logger.warning(f"Expired API key used: {key_hash[:16]}...")
+ return None
+
+ # Update last_used_at
+ cur.execute(
+ "UPDATE api_keys SET last_used_at = NOW() WHERE id = %s",
+ (key_id,),
+ )
+ conn.commit()
+
+ return {
+ "id": str(key_id),
+ "user_name": user_name,
+ "scopes": scopes or ["read"],
+ }
+ except Exception as e:
+ logger.error(f"Database error during key lookup: {e}")
+ return None
+ finally:
+ if conn and pool:
+ pool.putconn(conn)
+
+
+def is_internal_request(request: Request) -> bool:
+ """Check if request is from internal service."""
+ client_host = request.client.host if request.client else None
+
+ # Check internal IPs
+ if client_host in INTERNAL_SERVICE_IPS:
+ return True
+
+ # Check internal header (for container-to-container)
+ internal_header = request.headers.get("X-Internal-Service")
+ return internal_header == os.getenv("INTERNAL_SERVICE_SECRET")
+
+
+async def verify_api_key(
+ request: Request,
+ api_key: str | None = Security(API_KEY_HEADER),
+) -> dict[str, Any]:
+ """
+ Verify API key and return user context.
+
+ This is the main dependency for protected endpoints.
+
+ Returns:
+ dict with user_name, scopes, and is_master_key flag
+
+ Raises:
+ HTTPException 401 if authentication fails
+ """
+ # Skip auth if disabled
+ if not AUTH_ENABLED:
+ return {
+ "user_name": request.headers.get("X-User-Name", "default"),
+ "scopes": ["all"],
+ "is_master_key": False,
+ "auth_bypassed": True,
+ }
+
+ # Allow internal services
+ if is_internal_request(request):
+ logger.debug(f"Internal request from {request.client.host}")
+ return {
+ "user_name": "internal",
+ "scopes": ["all"],
+ "is_master_key": False,
+ "is_internal": True,
+ }
+
+ # Require API key
+ if not api_key:
+ raise HTTPException(
+ status_code=401,
+ detail="Missing API key",
+ headers={"WWW-Authenticate": "ApiKey"},
+ )
+
+ # Handle "Bearer" or "Token" prefix
+ if api_key.lower().startswith("bearer "):
+ api_key = api_key[7:]
+ elif api_key.lower().startswith("token "):
+ api_key = api_key[6:]
+
+ # Check against master key first (has different format: mk_*)
+ key_hash = hash_api_key(api_key)
+ if MASTER_KEY_HASH and key_hash == MASTER_KEY_HASH:
+ logger.info("Master key authentication")
+ return {
+ "user_name": "admin",
+ "scopes": ["all"],
+ "is_master_key": True,
+ }
+
+ # Validate format for regular API keys (krlk_*)
+ if not validate_key_format(api_key):
+ raise HTTPException(
+ status_code=401,
+ detail="Invalid API key format",
+ )
+
+ # Look up in database
+ key_data = await lookup_api_key(key_hash)
+ if not key_data:
+ logger.warning(f"Invalid API key attempt: {get_key_prefix(api_key)}...")
+ raise HTTPException(
+ status_code=401,
+ detail="Invalid or expired API key",
+ )
+
+ logger.debug(f"Authenticated user: {key_data['user_name']}")
+ return {
+ "user_name": key_data["user_name"],
+ "scopes": key_data["scopes"],
+ "is_master_key": False,
+ "api_key_id": key_data["id"],
+ }
+
+
+def require_scope(required_scope: str):
+ """
+ Dependency factory to require a specific scope.
+
+ Usage:
+ @router.post("/admin/keys", dependencies=[Depends(require_scope("admin"))])
+ """
+
+ async def scope_checker(
+ auth: dict[str, Any] = Depends(verify_api_key), # noqa: B008
+ ) -> dict[str, Any]:
+ scopes = auth.get("scopes", [])
+
+ # "all" scope grants everything
+ if "all" in scopes or required_scope in scopes:
+ return auth
+
+ raise HTTPException(
+ status_code=403,
+ detail=f"Insufficient permissions. Required scope: {required_scope}",
+ )
+
+ return scope_checker
+
+
+# Convenience dependencies
+require_read = require_scope("read")
+require_write = require_scope("write")
+require_admin = require_scope("admin")
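Wiring these dependencies into routes is the intended usage; a minimal sketch with hypothetical endpoints:

```python
from typing import Any

from fastapi import Depends, FastAPI

from memos.api.middleware.auth import require_write, verify_api_key

app = FastAPI()


@app.get("/memories")
async def list_memories(auth: dict[str, Any] = Depends(verify_api_key)):
    # The validated key's user context flows into the handler.
    return {"user": auth["user_name"], "scopes": auth["scopes"]}


@app.post("/memories", dependencies=[Depends(require_write)])
async def add_memory(payload: dict[str, Any]):
    # Reached only when the key carries the "write" (or "all") scope.
    return {"status": "ok"}
```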
diff --git a/src/memos/api/middleware/rate_limit.py b/src/memos/api/middleware/rate_limit.py
new file mode 100644
index 000000000..c547378ca
--- /dev/null
+++ b/src/memos/api/middleware/rate_limit.py
@@ -0,0 +1,207 @@
+"""
+Redis-based Rate Limiting Middleware.
+
+Implements sliding window rate limiting with Redis.
+Falls back to in-memory limiting if Redis is unavailable.
+"""
+
+import os
+import time
+
+from collections import defaultdict
+from collections.abc import Callable
+from typing import ClassVar
+
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.requests import Request
+from starlette.responses import JSONResponse, Response
+
+import memos.log
+
+
+logger = memos.log.get_logger(__name__)
+
+# Configuration from environment
+RATE_LIMIT = int(os.getenv("RATE_LIMIT", "100")) # Requests per window
+RATE_WINDOW = int(os.getenv("RATE_WINDOW_SEC", "60")) # Window in seconds
+REDIS_URL = os.getenv("REDIS_URL", "redis://redis:6379")
+
+# Redis client (lazy initialization)
+_redis_client = None
+
+# In-memory fallback (per process)
+_memory_store: dict[str, list[float]] = defaultdict(list)
+
+
+def _get_redis():
+ """Get or create Redis client."""
+ global _redis_client
+ if _redis_client is not None:
+ return _redis_client
+
+ try:
+ import redis
+
+ _redis_client = redis.from_url(REDIS_URL, decode_responses=True)
+ _redis_client.ping() # Test connection
+ logger.info("Rate limiter connected to Redis")
+ return _redis_client
+ except Exception as e:
+ logger.warning(f"Redis not available for rate limiting: {e}")
+ return None
+
+
+def _get_client_key(request: Request) -> str:
+ """
+ Generate a unique key for rate limiting.
+
+ Uses API key if available, otherwise falls back to IP.
+ """
+    # Try to get the API key from the Authorization header (optionally
+    # prefixed with "Bearer " or "Token ", as accepted by the auth layer)
+    auth_header = request.headers.get("Authorization", "")
+    if auth_header.lower().startswith(("bearer ", "token ")):
+        auth_header = auth_header.split(" ", 1)[1]
+    if auth_header.startswith("krlk_"):
+        # Use first 20 chars of key as identifier
+        return f"ratelimit:key:{auth_header[:20]}"
+
+ # Fall back to IP address
+ client_ip = request.client.host if request.client else "unknown"
+
+ # Check for forwarded IP (behind proxy)
+ forwarded = request.headers.get("X-Forwarded-For")
+ if forwarded:
+ client_ip = forwarded.split(",")[0].strip()
+
+ return f"ratelimit:ip:{client_ip}"
+
+
+def _check_rate_limit_redis(key: str) -> tuple[bool, int, int]:
+ """
+ Check rate limit using Redis sliding window.
+
+ Returns:
+ (allowed, remaining, reset_time)
+ """
+ redis_client = _get_redis()
+ if not redis_client:
+ return _check_rate_limit_memory(key)
+
+ try:
+ now = time.time()
+ window_start = now - RATE_WINDOW
+
+ pipe = redis_client.pipeline()
+
+ # Remove old entries
+ pipe.zremrangebyscore(key, 0, window_start)
+
+ # Count current entries
+ pipe.zcard(key)
+
+ # Add current request
+ pipe.zadd(key, {str(now): now})
+
+ # Set expiry
+ pipe.expire(key, RATE_WINDOW + 1)
+
+ results = pipe.execute()
+ current_count = results[1]
+
+ remaining = max(0, RATE_LIMIT - current_count - 1)
+ reset_time = int(now + RATE_WINDOW)
+
+ if current_count >= RATE_LIMIT:
+ return False, 0, reset_time
+
+ return True, remaining, reset_time
+
+ except Exception as e:
+ logger.warning(f"Redis rate limit error: {e}")
+ return _check_rate_limit_memory(key)
+
+
+def _check_rate_limit_memory(key: str) -> tuple[bool, int, int]:
+ """
+ Fallback in-memory rate limiting.
+
+ Note: This is per-process and not distributed!
+ """
+ now = time.time()
+ window_start = now - RATE_WINDOW
+
+ # Clean old entries
+ _memory_store[key] = [t for t in _memory_store[key] if t > window_start]
+
+ current_count = len(_memory_store[key])
+
+ if current_count >= RATE_LIMIT:
+ reset_time = (
+ int(min(_memory_store[key]) + RATE_WINDOW)
+ if _memory_store[key]
+ else int(now + RATE_WINDOW)
+ )
+ return False, 0, reset_time
+
+ # Add current request
+ _memory_store[key].append(now)
+
+ remaining = RATE_LIMIT - current_count - 1
+ reset_time = int(now + RATE_WINDOW)
+
+ return True, remaining, reset_time
+
+
+class RateLimitMiddleware(BaseHTTPMiddleware):
+ """
+ Rate limiting middleware using sliding window algorithm.
+
+ Adds headers:
+ - X-RateLimit-Limit: Maximum requests per window
+ - X-RateLimit-Remaining: Remaining requests
+ - X-RateLimit-Reset: Unix timestamp when the window resets
+
+ Returns 429 Too Many Requests when limit is exceeded.
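+
+    Usage (illustrative sketch):
+        app.add_middleware(RateLimitMiddleware)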
+ """
+
+ # Paths exempt from rate limiting
+ EXEMPT_PATHS: ClassVar[set[str]] = {"/health", "/openapi.json", "/docs", "/redoc"}
+
+ async def dispatch(self, request: Request, call_next: Callable) -> Response:
+ # Skip rate limiting for exempt paths
+ if request.url.path in self.EXEMPT_PATHS:
+ return await call_next(request)
+
+ # Skip OPTIONS requests (CORS preflight)
+ if request.method == "OPTIONS":
+ return await call_next(request)
+
+ # Get rate limit key
+ key = _get_client_key(request)
+
+ # Check rate limit
+ allowed, remaining, reset_time = _check_rate_limit_redis(key)
+
+ if not allowed:
+ logger.warning(f"Rate limit exceeded for {key}")
+ return JSONResponse(
+ status_code=429,
+ content={
+ "detail": "Too many requests. Please slow down.",
+ "retry_after": reset_time - int(time.time()),
+ },
+ headers={
+ "X-RateLimit-Limit": str(RATE_LIMIT),
+ "X-RateLimit-Remaining": "0",
+ "X-RateLimit-Reset": str(reset_time),
+ "Retry-After": str(reset_time - int(time.time())),
+ },
+ )
+
+ # Process request
+ response = await call_next(request)
+
+ # Add rate limit headers
+ response.headers["X-RateLimit-Limit"] = str(RATE_LIMIT)
+ response.headers["X-RateLimit-Remaining"] = str(remaining)
+ response.headers["X-RateLimit-Reset"] = str(reset_time)
+
+ return response
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index d8fa784a3..c41526e33 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -46,6 +46,7 @@ class GetMemoryPlaygroundRequest(BaseRequest):
)
mem_cube_ids: list[str] | None = Field(None, description="Cube IDs")
search_query: str | None = Field(None, description="Search query")
+ search_type: Literal["embedding", "fulltext"] = Field("fulltext", description="Search type")
# Start API Models
@@ -95,6 +96,8 @@ class ChatRequest(BaseRequest):
temperature: float | None = Field(None, description="Temperature for sampling")
top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter")
add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat")
+ manager_user_id: str | None = Field(None, description="Manager User ID")
+ project_id: str | None = Field(None, description="Project ID")
# ==== Filter conditions ====
filter: dict[str, Any] | None = Field(
@@ -167,6 +170,13 @@ class ChatPlaygroundRequest(ChatRequest):
)
+class ChatBusinessRequest(ChatRequest):
+ """Request model for chat operations for business user."""
+
+ business_key: str = Field(..., description="Business User Key")
+    need_search: bool = Field(False, description="Whether to search memory before chat")
+
+
class ChatCompleteRequest(BaseRequest):
"""Request model for chat operations. will (Deprecated), instead use APIChatCompleteRequest."""
@@ -319,6 +329,16 @@ class APISearchRequest(BaseRequest):
description="Number of textual memories to retrieve (top-K). Default: 10.",
)
+ relativity: float = Field(
+ 0.0,
+ ge=0,
+ description=(
+ "Relevance threshold for recalled memories. "
+ "Only memories with metadata.relativity >= relativity will be returned. "
+ "Use 0 to disable threshold filtering. Default: 0.3."
+ ),
+ )
+
dedup: Literal["no", "sim", "mmr"] | None = Field(
"mmr",
description=(
@@ -405,7 +425,7 @@ class APISearchRequest(BaseRequest):
# Internal field for search memory type
search_memory_type: str = Field(
"All",
- description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory, SkillMemory",
+ description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory, RawFileMemory, AllSummaryMemory, SkillMemory",
)
# ==== Context ====
@@ -443,6 +463,13 @@ class APISearchRequest(BaseRequest):
description="Source of the search query [plugin will router diff search]",
)
+ neighbor_discovery: bool = Field(
+ False,
+ description="Whether to enable neighbor discovery. "
+ "If enabled, the system will automatically recall neighbor chunks "
+ "relevant to the query. Default: False.",
+ )
+
@model_validator(mode="after")
def _convert_deprecated_fields(self) -> "APISearchRequest":
"""
@@ -489,6 +516,8 @@ class APIADDRequest(BaseRequest):
description="Session ID. If not provided, a default session will be used.",
)
task_id: str | None = Field(None, description="Task ID for monitering async tasks")
+ manager_user_id: str | None = Field(None, description="Manager User ID")
+ project_id: str | None = Field(None, description="Project ID")
# ==== Multi-cube writing ====
writable_cube_ids: list[str] | None = Field(
@@ -744,6 +773,8 @@ class APIChatCompleteRequest(BaseRequest):
temperature: float | None = Field(None, description="Temperature for sampling")
top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter")
add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat")
+ manager_user_id: str | None = Field(None, description="Manager User ID")
+ project_id: str | None = Field(None, description="Project ID")
# ==== Filter conditions ====
filter: dict[str, Any] | None = Field(
@@ -796,6 +827,12 @@ class GetMemoryRequest(BaseRequest):
)
+class GetMemoryDashboardRequest(GetMemoryRequest):
+ """Request model for getting memories for dashboard."""
+
+ mem_cube_id: str | None = Field(None, description="Cube ID")
+
+
class DeleteMemoryRequest(BaseRequest):
"""Request model for deleting memories."""
@@ -1217,3 +1254,26 @@ class ExistMemCubeIdRequest(BaseRequest):
class ExistMemCubeIdResponse(BaseResponse[dict[str, bool]]):
"""Response model for checking if mem cube id exists."""
+
+
+class DeleteMemoryByRecordIdRequest(BaseRequest):
+ """Request model for deleting memory by record id."""
+
+ mem_cube_id: str = Field(..., description="Mem cube ID")
+ record_id: str = Field(..., description="Record ID")
+ hard_delete: bool = Field(False, description="Hard delete")
+
+
+class DeleteMemoryByRecordIdResponse(BaseResponse[dict]):
+ """Response model for deleting memory by record id."""
+
+
+class RecoverMemoryByRecordIdRequest(BaseRequest):
+ """Request model for recovering memory by record id."""
+
+ mem_cube_id: str = Field(..., description="Mem cube ID")
+ delete_record_id: str = Field(..., description="Delete record ID")
+
+
+class RecoverMemoryByRecordIdResponse(BaseResponse[dict]):
+ """Response model for recovering memory by record id."""
diff --git a/src/memos/api/routers/admin_router.py b/src/memos/api/routers/admin_router.py
new file mode 100644
index 000000000..238643ba9
--- /dev/null
+++ b/src/memos/api/routers/admin_router.py
@@ -0,0 +1,228 @@
+"""
+Admin Router for API Key Management.
+
+Protected by master key or admin scope.
+"""
+
+import os
+
+from typing import Any
+
+from fastapi import APIRouter, Depends, HTTPException
+from pydantic import BaseModel, Field
+
+import memos.log
+
+from memos.api.middleware.auth import require_scope, verify_api_key
+from memos.api.utils.api_keys import (
+ create_api_key_in_db,
+ generate_master_key,
+ list_api_keys,
+ revoke_api_key,
+)
+
+
+logger = memos.log.get_logger(__name__)
+
+router = APIRouter(prefix="/admin", tags=["Admin"])
+
+
+# Request/Response models
+class CreateKeyRequest(BaseModel):
+ user_name: str = Field(..., min_length=1, max_length=255)
+ scopes: list[str] = Field(default=["read"])
+ description: str | None = Field(default=None, max_length=500)
+ expires_in_days: int | None = Field(default=None, ge=1, le=365)
+
+
+class CreateKeyResponse(BaseModel):
+ message: str
+ key: str # Only returned once!
+ key_prefix: str
+ user_name: str
+ scopes: list[str]
+
+
+class KeyListResponse(BaseModel):
+ message: str
+ keys: list[dict[str, Any]]
+
+
+class RevokeKeyRequest(BaseModel):
+ key_id: str
+
+
+class SimpleResponse(BaseModel):
+ message: str
+ success: bool = True
+
+
+def _get_db_connection():
+ """Get database connection for admin operations."""
+ import psycopg2
+
+ return psycopg2.connect(
+ host=os.getenv("POSTGRES_HOST", "postgres"),
+ port=int(os.getenv("POSTGRES_PORT", "5432")),
+ user=os.getenv("POSTGRES_USER", "memos"),
+ password=os.getenv("POSTGRES_PASSWORD", ""),
+ dbname=os.getenv("POSTGRES_DB", "memos"),
+ )
+
+
+@router.post(
+ "/keys",
+ response_model=CreateKeyResponse,
+ summary="Create a new API key",
+ dependencies=[Depends(require_scope("admin"))],
+)
+def create_key(
+ request: CreateKeyRequest,
+ auth: dict = Depends(verify_api_key), # noqa: B008
+):
+ """
+ Create a new API key for a user.
+
+ Requires admin scope or master key.
+
+ **WARNING**: The API key is only returned once. Store it securely!
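+
+    Example (illustrative; assumes the master key is sent via the Authorization header):
+        curl -X POST http://localhost:8000/admin/keys \
+            -H "Authorization: Bearer mk_..." \
+            -H "Content-Type: application/json" \
+            -d '{"user_name": "alice", "scopes": ["read", "write"]}'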
+ """
+ try:
+ conn = _get_db_connection()
+ try:
+ api_key = create_api_key_in_db(
+ conn=conn,
+ user_name=request.user_name,
+ scopes=request.scopes,
+ description=request.description,
+ expires_in_days=request.expires_in_days,
+ created_by=auth.get("user_name", "unknown"),
+ )
+
+ logger.info(
+ f"API key created for user '{request.user_name}' by '{auth.get('user_name')}'"
+ )
+
+ return CreateKeyResponse(
+ message="API key created successfully. Store this key securely - it won't be shown again!",
+ key=api_key.key,
+ key_prefix=api_key.key_prefix,
+ user_name=request.user_name,
+ scopes=request.scopes,
+ )
+ finally:
+ conn.close()
+ except Exception as e:
+ logger.error(f"Failed to create API key: {e}")
+ raise HTTPException(status_code=500, detail="Failed to create API key") from e
+
+
+@router.get(
+ "/keys",
+ response_model=KeyListResponse,
+ summary="List API keys",
+ dependencies=[Depends(require_scope("admin"))],
+)
+def list_keys(
+ user_name: str | None = None,
+ auth: dict = Depends(verify_api_key), # noqa: B008
+):
+ """
+ List all API keys (admin) or keys for a specific user.
+
+ Note: Actual key values are never returned, only prefixes.
+ """
+ try:
+ conn = _get_db_connection()
+ try:
+ keys = list_api_keys(conn, user_name=user_name)
+ return KeyListResponse(
+ message=f"Found {len(keys)} key(s)",
+ keys=keys,
+ )
+ finally:
+ conn.close()
+ except Exception as e:
+ logger.error(f"Failed to list API keys: {e}")
+ raise HTTPException(status_code=500, detail="Failed to list API keys") from e
+
+
+@router.delete(
+ "/keys/{key_id}",
+ response_model=SimpleResponse,
+ summary="Revoke an API key",
+ dependencies=[Depends(require_scope("admin"))],
+)
+def revoke_key(
+ key_id: str,
+ auth: dict = Depends(verify_api_key), # noqa: B008
+):
+ """
+ Revoke an API key by ID.
+
+ The key will be deactivated but not deleted (for audit purposes).
+ """
+ try:
+ conn = _get_db_connection()
+ try:
+ success = revoke_api_key(conn, key_id)
+ if success:
+ logger.info(f"API key {key_id} revoked by '{auth.get('user_name')}'")
+ return SimpleResponse(message="API key revoked successfully")
+ else:
+ raise HTTPException(status_code=404, detail="API key not found or already revoked")
+ finally:
+ conn.close()
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Failed to revoke API key: {e}")
+ raise HTTPException(status_code=500, detail="Failed to revoke API key") from e
+
+
+@router.post(
+ "/generate-master-key",
+ response_model=dict,
+ summary="Generate a new master key",
+ dependencies=[Depends(require_scope("admin"))],
+)
+def generate_new_master_key(
+ auth: dict = Depends(verify_api_key), # noqa: B008
+):
+ """
+ Generate a new master key.
+
+ **WARNING**: Store the key securely! Add MASTER_KEY_HASH to your .env file.
+ """
+ if not auth.get("is_master_key"):
+ raise HTTPException(
+ status_code=403,
+ detail="Only master key can generate new master keys",
+ )
+
+ key, key_hash = generate_master_key()
+
+ logger.warning("New master key generated - update MASTER_KEY_HASH in .env")
+
+ return {
+ "message": "Master key generated. Add MASTER_KEY_HASH to your .env file!",
+ "key": key,
+ "key_hash": key_hash,
+ "env_line": f"MASTER_KEY_HASH={key_hash}",
+ }
+
+
+@router.get(
+ "/health",
+ summary="Admin health check",
+)
+def admin_health():
+ """Health check for admin endpoints."""
+ auth_enabled = os.getenv("AUTH_ENABLED", "false").lower() == "true"
+ master_key_configured = bool(os.getenv("MASTER_KEY_HASH"))
+
+ return {
+ "status": "ok",
+ "auth_enabled": auth_enabled,
+ "master_key_configured": master_key_configured,
+ }
diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py
index 736c328ac..83079239f 100644
--- a/src/memos/api/routers/server_router.py
+++ b/src/memos/api/routers/server_router.py
@@ -29,18 +29,24 @@
APIChatCompleteRequest,
APIFeedbackRequest,
APISearchRequest,
+ ChatBusinessRequest,
ChatPlaygroundRequest,
ChatRequest,
+ DeleteMemoryByRecordIdRequest,
+ DeleteMemoryByRecordIdResponse,
DeleteMemoryRequest,
DeleteMemoryResponse,
ExistMemCubeIdRequest,
ExistMemCubeIdResponse,
+ GetMemoryDashboardRequest,
GetMemoryPlaygroundRequest,
GetMemoryRequest,
GetMemoryResponse,
GetUserNamesByMemoryIdsRequest,
GetUserNamesByMemoryIdsResponse,
MemoryResponse,
+ RecoverMemoryByRecordIdRequest,
+ RecoverMemoryByRecordIdResponse,
SearchResponse,
StatusResponse,
SuggestionRequest,
@@ -290,8 +296,9 @@ def get_all_memories(memory_req: GetMemoryPlaygroundRequest):
memory_req.mem_cube_ids[0] if memory_req.mem_cube_ids else memory_req.user_id
),
query=memory_req.search_query,
- top_k=20,
+ top_k=200,
naive_mem_cube=naive_mem_cube,
+ search_type=memory_req.search_type,
)
else:
return handlers.memory_handler.handle_get_all_memories(
@@ -394,9 +401,69 @@ def get_user_names_by_memory_ids(request: GetUserNamesByMemoryIdsRequest):
response_model=ExistMemCubeIdResponse,
)
def exist_mem_cube_id(request: ExistMemCubeIdRequest):
- """Check if mem cube id exists."""
+ """(inner) Check if mem cube id exists."""
return ExistMemCubeIdResponse(
code=200,
message="Successfully",
data=graph_db.exist_user_name(user_name=request.mem_cube_id),
)
+
+
+@router.post("/chat/stream/business_user", summary="Chat with MemOS for business user")
+def chat_stream_business_user(chat_req: ChatBusinessRequest):
+ """(inner) Chat with MemOS for a specific business user. Returns SSE stream."""
+ if chat_handler is None:
+ raise HTTPException(
+ status_code=503, detail="Chat service is not available. Chat handler not initialized."
+ )
+
+ return chat_handler.handle_chat_stream_for_business_user(chat_req)
+
+
+@router.post(
+ "/delete_memory_by_record_id",
+ summary="Delete memory by record id",
+ response_model=DeleteMemoryByRecordIdResponse,
+)
+def delete_memory_by_record_id(memory_req: DeleteMemoryByRecordIdRequest):
+ """(inner) Delete memory nodes by mem_cube_id (user_name) and delete_record_id. Record id is inner field, just for delete and recover memory, not for user to set."""
+ graph_db.delete_node_by_mem_cube_id(
+ mem_cube_id=memory_req.mem_cube_id,
+ delete_record_id=memory_req.record_id,
+ hard_delete=memory_req.hard_delete,
+ )
+
+ return DeleteMemoryByRecordIdResponse(
+ code=200,
+ message="Called Successfully",
+ data={"status": "success"},
+ )
+
+
+@router.post(
+ "/recover_memory_by_record_id",
+ summary="Recover memory by record id",
+ response_model=RecoverMemoryByRecordIdResponse,
+)
+def recover_memory_by_record_id(memory_req: RecoverMemoryByRecordIdRequest):
+ """(inner) Recover memory nodes by mem_cube_id (user_name) and delete_record_id. Record id is inner field, just for delete and recover memory, not for user to set."""
+ graph_db.recover_memory_by_mem_cube_id(
+ mem_cube_id=memory_req.mem_cube_id,
+ delete_record_id=memory_req.delete_record_id,
+ )
+
+ return RecoverMemoryByRecordIdResponse(
+ code=200,
+ message="Called Successfully",
+ data={"status": "success"},
+ )
+
+
+@router.post(
+ "/get_memory_dashboard", summary="Get memories for dashboard", response_model=GetMemoryResponse
+)
+def get_memories_dashboard(memory_req: GetMemoryDashboardRequest):
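+    """(inner) Get memories for dashboard, optionally scoped to a single mem cube."""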
+ return handlers.memory_handler.handle_get_memories_dashboard(
+ get_mem_req=memory_req,
+ naive_mem_cube=naive_mem_cube,
+ )
diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py
index ac9ed8d88..529a709a4 100644
--- a/src/memos/api/server_api.py
+++ b/src/memos/api/server_api.py
@@ -1,13 +1,18 @@
import logging
+import os
+from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
+from starlette.staticfiles import StaticFiles
from memos.api.exceptions import APIExceptionHandler
from memos.api.middleware.request_context import RequestContextMiddleware
from memos.api.routers.server_router import router as server_router
+load_dotenv()
+
# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
@@ -18,6 +23,8 @@
version="1.0.1",
)
+app.mount("/download", StaticFiles(directory=os.getenv("FILE_LOCAL_PATH")), name="static_mapping")
+
app.add_middleware(RequestContextMiddleware, source="server_api")
# Include routers
app.include_router(server_router)
diff --git a/src/memos/api/server_api_ext.py b/src/memos/api/server_api_ext.py
new file mode 100644
index 000000000..8c457e362
--- /dev/null
+++ b/src/memos/api/server_api_ext.py
@@ -0,0 +1,124 @@
+"""
+Extended Server API for Krolik deployment.
+
+This module extends the base MemOS server_api with:
+- API Key Authentication (PostgreSQL-backed)
+- Redis Rate Limiting
+- Admin API for key management
+- Security Headers
+
+Usage in Dockerfile:
+ # Copy overlays after base installation
+ COPY overlays/krolik/ /app/src/memos/
+
+ # Use this as entrypoint instead of server_api
+ CMD ["gunicorn", "memos.api.server_api_ext:app", ...]
+"""
+
+import logging
+import os
+
+from fastapi import FastAPI
+from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.requests import Request
+from starlette.responses import Response
+
+# Import Krolik extensions
+from memos.api.middleware.rate_limit import RateLimitMiddleware
+from memos.api.routers.admin_router import router as admin_router
+
+# Import base routers from MemOS
+from memos.api.routers.server_router import router as server_router
+
+
+# Try to import exception handlers (may vary between MemOS versions)
+try:
+ from memos.api.exceptions import APIExceptionHandler
+
+ HAS_EXCEPTION_HANDLER = True
+except ImportError:
+ HAS_EXCEPTION_HANDLER = False
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+class SecurityHeadersMiddleware(BaseHTTPMiddleware):
+ """Add security headers to all responses."""
+
+ async def dispatch(self, request: Request, call_next) -> Response:
+ response = await call_next(request)
+ response.headers["X-Content-Type-Options"] = "nosniff"
+ response.headers["X-Frame-Options"] = "DENY"
+ response.headers["X-XSS-Protection"] = "1; mode=block"
+ response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
+ response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()"
+ return response
+
+
+# Create FastAPI app
+app = FastAPI(
+ title="MemOS Server REST APIs (Krolik Extended)",
+ description="MemOS API with authentication, rate limiting, and admin endpoints.",
+ version="2.0.3-krolik",
+)
+
+# CORS configuration
+CORS_ORIGINS = os.getenv("CORS_ORIGINS", "").split(",")
+CORS_ORIGINS = [origin.strip() for origin in CORS_ORIGINS if origin.strip()]
+
+if not CORS_ORIGINS:
+ CORS_ORIGINS = [
+ "https://krolik.hully.one",
+ "https://memos.hully.one",
+ "http://localhost:3000",
+ ]
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=CORS_ORIGINS,
+ allow_credentials=True,
+ allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
+ allow_headers=["Authorization", "Content-Type", "X-API-Key", "X-User-Name"],
+)
+
+# Security headers
+app.add_middleware(SecurityHeadersMiddleware)
+
+# Rate limiting (before auth to protect against brute force)
+RATE_LIMIT_ENABLED = os.getenv("RATE_LIMIT_ENABLED", "true").lower() == "true"
+if RATE_LIMIT_ENABLED:
+ app.add_middleware(RateLimitMiddleware)
+ logger.info("Rate limiting enabled")
+
+# Include routers
+app.include_router(server_router)
+app.include_router(admin_router)
+
+# Exception handlers
+if HAS_EXCEPTION_HANDLER:
+ from fastapi import HTTPException
+
+ app.exception_handler(RequestValidationError)(APIExceptionHandler.validation_error_handler)
+ app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler)
+ app.exception_handler(HTTPException)(APIExceptionHandler.http_error_handler)
+ app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler)
+
+
+@app.get("/health")
+async def health_check():
+ """Health check endpoint."""
+ return {
+ "status": "healthy",
+ "version": "2.0.3-krolik",
+ "auth_enabled": os.getenv("AUTH_ENABLED", "false").lower() == "true",
+ "rate_limit_enabled": RATE_LIMIT_ENABLED,
+ }
+
+
+if __name__ == "__main__":
+ import uvicorn
+
+ uvicorn.run("memos.api.server_api_ext:app", host="0.0.0.0", port=8000, workers=1)
diff --git a/src/memos/api/utils/__init__.py b/src/memos/api/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/memos/api/utils/api_keys.py b/src/memos/api/utils/api_keys.py
new file mode 100644
index 000000000..559ddd355
--- /dev/null
+++ b/src/memos/api/utils/api_keys.py
@@ -0,0 +1,197 @@
+"""
+API Key Management Utilities.
+
+Provides functions for generating, validating, and managing API keys.
+"""
+
+import hashlib
+import os
+import secrets
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+
+
+@dataclass
+class APIKey:
+ """Represents a generated API key."""
+
+ key: str # Full key (only available at creation time)
+ key_hash: str # SHA-256 hash (stored in database)
+ key_prefix: str # First 12 chars for identification
+
+
+def generate_api_key() -> APIKey:
+ """
+ Generate a new API key.
+
+ Format: krlk_<64-hex-chars>
+
+ Returns:
+ APIKey with key, hash, and prefix
+ """
+ # Generate 32 random bytes = 64 hex chars
+ random_bytes = secrets.token_bytes(32)
+ hex_part = random_bytes.hex()
+
+ key = f"krlk_{hex_part}"
+ key_hash = hashlib.sha256(key.encode()).hexdigest()
+ key_prefix = key[:12]
+
+ return APIKey(key=key, key_hash=key_hash, key_prefix=key_prefix)
+
+
+def hash_key(key: str) -> str:
+ """Hash an API key using SHA-256."""
+ return hashlib.sha256(key.encode()).hexdigest()
+
+
+def validate_key_format(key: str) -> bool:
+ """
+ Validate API key format.
+
+ Valid format: krlk_<64-hex-chars>
+ """
+ if not key or not isinstance(key, str):
+ return False
+
+ if not key.startswith("krlk_"):
+ return False
+
+ hex_part = key[5:]
+ if len(hex_part) != 64:
+ return False
+
+ try:
+ int(hex_part, 16)
+ return True
+ except ValueError:
+ return False
+
+
+def generate_master_key() -> tuple[str, str]:
+ """
+ Generate a master key for admin operations.
+
+ Returns:
+ Tuple of (key, hash)
+ """
+ random_bytes = secrets.token_bytes(32)
+ key = f"mk_{random_bytes.hex()}"
+ key_hash = hashlib.sha256(key.encode()).hexdigest()
+ return key, key_hash
+
+
+def create_api_key_in_db(
+ conn,
+ user_name: str,
+ scopes: list[str] | None = None,
+ description: str | None = None,
+ expires_in_days: int | None = None,
+ created_by: str | None = None,
+) -> APIKey:
+ """
+ Create a new API key and store in database.
+
+ Args:
+ conn: Database connection
+ user_name: Owner of the key
+ scopes: List of scopes (default: ["read"])
+ description: Human-readable description
+ expires_in_days: Days until expiration (None = never)
+ created_by: Who created this key
+
+ Returns:
+ APIKey with the generated key (only time it's available!)
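+
+    Example (illustrative, assuming an open psycopg2 connection):
+        api_key = create_api_key_in_db(conn, "alice", scopes=["read"], expires_in_days=90)
+        print(api_key.key)  # show once; only key_hash is persisted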
+ """
+ api_key = generate_api_key()
+
+ expires_at = None
+ if expires_in_days:
+ expires_at = datetime.utcnow() + timedelta(days=expires_in_days)
+
+ with conn.cursor() as cur:
+ cur.execute(
+ """
+ INSERT INTO api_keys (key_hash, key_prefix, user_name, scopes, description, expires_at, created_by)
+ VALUES (%s, %s, %s, %s, %s, %s, %s)
+ RETURNING id
+ """,
+ (
+ api_key.key_hash,
+ api_key.key_prefix,
+ user_name,
+ scopes or ["read"],
+ description,
+ expires_at,
+ created_by,
+ ),
+ )
+ conn.commit()
+
+ return api_key
+
+
+def revoke_api_key(conn, key_id: str) -> bool:
+ """
+ Revoke an API key by ID.
+
+ Returns:
+ True if key was revoked, False if not found
+ """
+ with conn.cursor() as cur:
+ cur.execute(
+ "UPDATE api_keys SET is_active = false WHERE id = %s AND is_active = true",
+ (key_id,),
+ )
+ conn.commit()
+ return cur.rowcount > 0
+
+
+def list_api_keys(conn, user_name: str | None = None) -> list[dict]:
+ """
+ List API keys (without exposing the actual keys).
+
+ Args:
+ conn: Database connection
+ user_name: Filter by user (None = all users)
+
+ Returns:
+ List of key metadata dicts
+ """
+ with conn.cursor() as cur:
+ if user_name:
+ cur.execute(
+ """
+ SELECT id, key_prefix, user_name, scopes, description,
+ created_at, last_used_at, expires_at, is_active
+ FROM api_keys
+ WHERE user_name = %s
+ ORDER BY created_at DESC
+ """,
+ (user_name,),
+ )
+ else:
+ cur.execute(
+ """
+ SELECT id, key_prefix, user_name, scopes, description,
+ created_at, last_used_at, expires_at, is_active
+ FROM api_keys
+ ORDER BY created_at DESC
+ """
+ )
+
+ rows = cur.fetchall()
+ return [
+ {
+ "id": str(row[0]),
+ "key_prefix": row[1],
+ "user_name": row[2],
+ "scopes": row[3],
+ "description": row[4],
+ "created_at": row[5].isoformat() if row[5] else None,
+ "last_used_at": row[6].isoformat() if row[6] else None,
+ "expires_at": row[7].isoformat() if row[7] else None,
+ "is_active": row[8],
+ }
+ for row in rows
+ ]
diff --git a/src/memos/configs/chunker.py b/src/memos/configs/chunker.py
index c2af012f0..f9a738415 100644
--- a/src/memos/configs/chunker.py
+++ b/src/memos/configs/chunker.py
@@ -14,6 +14,7 @@ class BaseChunkerConfig(BaseConfig):
chunk_size: int = Field(default=512, description="Maximum tokens per chunk")
chunk_overlap: int = Field(default=128, description="Overlap between chunks")
min_sentences_per_chunk: int = Field(default=1, description="Minimum sentences in each chunk")
+    save_rawfile: bool = Field(default=True, description="Whether to save the raw file")  # TODO
class SentenceChunkerConfig(BaseChunkerConfig):
diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py
index 3b4bace0e..9b1ce7f9d 100644
--- a/src/memos/configs/graph_db.py
+++ b/src/memos/configs/graph_db.py
@@ -211,6 +211,60 @@ def validate_config(self):
return self
+class PostgresGraphDBConfig(BaseConfig):
+ """
+ PostgreSQL + pgvector configuration for MemOS.
+
+ Uses standard PostgreSQL with pgvector extension for vector search.
+ Does NOT require Apache AGE or other graph extensions.
+
+ Schema:
+ - memos_memories: Main table for memory nodes (id, memory, properties JSONB, embedding vector)
+ - memos_edges: Edge table for relationships (source_id, target_id, type)
+
+ Example:
+ ---
+ host = "postgres"
+ port = 5432
+ user = "n8n"
+ password = "secret"
+ db_name = "n8n"
+ schema_name = "memos"
+ user_name = "default"
+ """
+
+ host: str = Field(..., description="Database host")
+ port: int = Field(default=5432, description="Database port")
+ user: str = Field(..., description="Database user")
+ password: str = Field(..., description="Database password")
+ db_name: str = Field(..., description="Database name")
+ schema_name: str = Field(default="memos", description="Schema name for MemOS tables")
+ user_name: str | None = Field(
+ default=None,
+ description="Logical user/tenant ID for data isolation",
+ )
+ use_multi_db: bool = Field(
+ default=False,
+ description="If False: use single database with logical isolation by user_name",
+ )
+ embedding_dimension: int = Field(
+ default=768, description="Dimension of vector embedding (768 for all-mpnet-base-v2)"
+ )
+ maxconn: int = Field(
+ default=20,
+ description="Maximum number of connections in the connection pool",
+ )
+
+ @model_validator(mode="after")
+ def validate_config(self):
+ """Validate config."""
+ if not self.db_name:
+ raise ValueError("`db_name` must be provided")
+ if not self.use_multi_db and not self.user_name:
+ raise ValueError("In single-database mode, `user_name` must be provided")
+ return self
+
+
class GraphDBConfigFactory(BaseModel):
backend: str = Field(..., description="Backend for graph database")
config: dict[str, Any] = Field(..., description="Configuration for the graph database backend")
@@ -220,6 +274,7 @@ class GraphDBConfigFactory(BaseModel):
"neo4j-community": Neo4jCommunityGraphDBConfig,
"nebular": NebulaGraphDBConfig,
"polardb": PolarDBGraphDBConfig,
+ "postgres": PostgresGraphDBConfig,
}
@field_validator("backend")
diff --git a/src/memos/graph_dbs/factory.py b/src/memos/graph_dbs/factory.py
index ec9cbcda0..c207e3190 100644
--- a/src/memos/graph_dbs/factory.py
+++ b/src/memos/graph_dbs/factory.py
@@ -6,6 +6,7 @@
from memos.graph_dbs.neo4j import Neo4jGraphDB
from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB
from memos.graph_dbs.polardb import PolarDBGraphDB
+from memos.graph_dbs.postgres import PostgresGraphDB
class GraphStoreFactory(BaseGraphDB):
@@ -16,6 +17,7 @@ class GraphStoreFactory(BaseGraphDB):
"neo4j-community": Neo4jCommunityGraphDB,
"nebular": NebulaGraphDB,
"polardb": PolarDBGraphDB,
+ "postgres": PostgresGraphDB,
}
@classmethod
diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py
index 70d40f13c..23ce2408b 100644
--- a/src/memos/graph_dbs/neo4j.py
+++ b/src/memos/graph_dbs/neo4j.py
@@ -502,7 +502,7 @@ def edge_exists(
return result.single() is not None
# Graph Query & Reasoning
- def get_node(self, id: str, **kwargs) -> dict[str, Any] | None:
+ def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None:
"""
Retrieve the metadata and memory of a node.
Args:
@@ -510,18 +510,28 @@ def get_node(self, id: str, **kwargs) -> dict[str, Any] | None:
Returns:
Dictionary of node fields, or None if not found.
"""
- user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name
+ logger.info(f"[get_node] id: {id}")
+ user_name = kwargs.get("user_name")
where_user = ""
params = {"id": id}
- if not self.config.use_multi_db and (self.config.user_name or user_name):
+ if user_name is not None:
where_user = " AND n.user_name = $user_name"
params["user_name"] = user_name
query = f"MATCH (n:Memory) WHERE n.id = $id {where_user} RETURN n"
+ logger.info(f"[get_node] query: {query}")
with self.driver.session(database=self.db_name) as session:
record = session.run(query, params).single()
- return self._parse_node(dict(record["n"])) if record else None
+ if not record:
+ return None
+
+ node_dict = dict(record["n"])
+ if include_embedding is False:
+ for key in ("embedding", "embedding_1024", "embedding_3072", "embedding_768"):
+ node_dict.pop(key, None)
+
+ return self._parse_node(node_dict)
def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]:
"""
@@ -1153,6 +1163,10 @@ def export_graph(
self,
page: int | None = None,
page_size: int | None = None,
+ memory_type: list[str] | None = None,
+ status: list[str] | None = None,
+ filter: dict | None = None,
+ include_embedding: bool = False,
**kwargs,
) -> dict[str, Any]:
"""
@@ -1161,6 +1175,13 @@ def export_graph(
Args:
page (int, optional): Page number (starts from 1). If None, exports all data without pagination.
page_size (int, optional): Number of items per page. If None, exports all data without pagination.
+ memory_type (list[str], optional): List of memory_type values to filter by. If provided, only nodes/edges
+ with memory_type in this list will be exported.
+ status (list[str], optional): If not provided, only nodes/edges with status != 'deleted' are exported.
+ If provided (non-empty list), only nodes/edges with status in this list are exported.
+ filter (dict, optional): Filter conditions with 'and' or 'or' logic. Same as get_all_memory_items.
+ Example: {"and": [{"id": "xxx"}, {"A": "yyy"}]} or {"or": [{"id": "xxx"}, {"A": "yyy"}]}
+ include_embedding (bool): Whether to include embedding fields in node metadata. Default False (same as get_node).
**kwargs: Additional keyword arguments, including:
- user_name (str, optional): User name for filtering in non-multi-db mode
@@ -1172,6 +1193,9 @@ def export_graph(
"total_edges": int, # Total number of edges matching the filter criteria
}
"""
+    logger.info(
+        f"[export_graph] include_embedding: {include_embedding}, kwargs: {kwargs}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}, status: {status}"
+    )
user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name
# Initialize total counts
@@ -1190,16 +1214,64 @@ def export_graph(
skip = (page - 1) * page_size
with self.driver.session(database=self.db_name) as session:
- # Build base queries
- node_base_query = "MATCH (n:Memory)"
- edge_base_query = "MATCH (a:Memory)-[r]->(b:Memory)"
- params = {}
+ # Build WHERE conditions for nodes
+ node_where_clauses = []
+ params: dict[str, Any] = {}
if not self.config.use_multi_db and (self.config.user_name or user_name):
- node_base_query += " WHERE n.user_name = $user_name"
- edge_base_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name"
+ node_where_clauses.append("n.user_name = $user_name")
params["user_name"] = user_name
+ if memory_type and isinstance(memory_type, list) and len(memory_type) > 0:
+ node_where_clauses.append("n.memory_type IN $memory_type")
+ params["memory_type"] = memory_type
+
+ if status is None:
+ node_where_clauses.append("n.status <> 'deleted'")
+ elif isinstance(status, list) and len(status) > 0:
+ node_where_clauses.append("n.status IN $status")
+ params["status"] = status
+
+ # Build filter conditions using common method (same as get_all_memory_items)
+ filter_conditions, filter_params = self._build_filter_conditions_cypher(
+ filter=filter,
+ param_counter_start=0,
+ node_alias="n",
+ )
+ logger.info(f"export_graph filter_conditions: {filter_conditions}")
+ node_where_clauses.extend(filter_conditions)
+ if filter_params:
+ params.update(filter_params)
+
+ node_base_query = "MATCH (n:Memory)"
+ if node_where_clauses:
+ node_base_query += " WHERE " + " AND ".join(node_where_clauses)
+ logger.info(f"export_graph node_base_query: {node_base_query}")
+
+ # Build WHERE conditions for edges (a and b must match same filters)
+ edge_where_clauses = []
+ if not self.config.use_multi_db and (self.config.user_name or user_name):
+ edge_where_clauses.append("a.user_name = $user_name AND b.user_name = $user_name")
+ if memory_type and isinstance(memory_type, list) and len(memory_type) > 0:
+ edge_where_clauses.append(
+ "a.memory_type IN $memory_type AND b.memory_type IN $memory_type"
+ )
+ if status is None:
+ edge_where_clauses.append("a.status <> 'deleted' AND b.status <> 'deleted'")
+ elif isinstance(status, list) and len(status) > 0:
+ edge_where_clauses.append("a.status IN $status AND b.status IN $status")
+ # Apply same filter to both endpoints of the edge
+ if filter_conditions:
+ filter_a = [c.replace("n.", "a.") for c in filter_conditions]
+ filter_b = [c.replace("n.", "b.") for c in filter_conditions]
+ edge_where_clauses.append(
+ f"({' AND '.join(filter_a)}) AND ({' AND '.join(filter_b)})"
+ )
+
+ edge_base_query = "MATCH (a:Memory)-[r]->(b:Memory)"
+ if edge_where_clauses:
+ edge_base_query += " WHERE " + " AND ".join(edge_where_clauses)
+
# Get total count of nodes before pagination
count_node_query = node_base_query + " RETURN COUNT(n) AS count"
count_node_result = session.run(count_node_query, params)
@@ -1211,7 +1283,13 @@ def export_graph(
node_query += f" SKIP {skip} LIMIT {page_size}"
node_result = session.run(node_query, params)
- nodes = [self._parse_node(dict(record["n"])) for record in node_result]
+ nodes = []
+ for record in node_result:
+ node_dict = dict(record["n"])
+ if not include_embedding:
+ for key in ("embedding", "embedding_1024", "embedding_3072", "embedding_768"):
+ node_dict.pop(key, None)
+ nodes.append(self._parse_node(node_dict))
# Get total count of edges before pagination
count_edge_query = edge_base_query + " RETURN COUNT(r) AS count"
@@ -1225,7 +1303,7 @@ def export_graph(
)
if use_pagination:
edge_query += f" SKIP {skip} LIMIT {page_size}"
-
+ logger.info(f"export_graph edge_query: {edge_query},params:{params}")
edge_result = session.run(edge_query, params)
edges = [
{"source": record["source"], "target": record["target"], "type": record["type"]}
@@ -1313,10 +1391,6 @@ def get_all_memory_items(
logger.info(
f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids},status: {status}"
)
- print(
- f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids},status: {status}"
- )
-
user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name
if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}:
raise ValueError(f"Unsupported memory type scope: {scope}")
@@ -1367,11 +1441,17 @@ def get_all_memory_items(
RETURN n
"""
logger.info(f"[get_all_memory_items] query: {query},params: {params}")
- print(f"[get_all_memory_items] query: {query},params: {params}")
with self.driver.session(database=self.db_name) as session:
results = session.run(query, params)
- return [self._parse_node(dict(record["n"])) for record in results]
+ nodes = []
+ for record in results:
+ node_dict = dict(record["n"])
+ if not include_embedding:
+ for key in ("embedding", "embedding_1024", "embedding_3072", "embedding_768"):
+ node_dict.pop(key, None)
+ nodes.append(self._parse_node(node_dict))
+ return nodes
def get_structure_optimization_candidates(self, scope: str, **kwargs) -> list[dict]:
"""
@@ -1940,3 +2020,136 @@ def exist_user_name(self, user_name: str) -> dict[str, bool]:
f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True
)
raise
+
+ def delete_node_by_mem_cube_id(
+ self,
+ mem_cube_id: str | None = None,
+ delete_record_id: str | None = None,
+ hard_delete: bool = False,
+ ) -> int:
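+        """
+        Delete Memory nodes of a mem cube by delete_record_id.
+
+        Soft delete (default) tags every not-yet-deleted node of the cube with
+        delete_record_id and sets status='deleted'; hard delete detaches and
+        removes the nodes previously tagged with that delete_record_id.
+
+        Returns:
+            Number of nodes deleted (hard) or updated (soft).
+        """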
+ logger.info(
+ f"delete_node_by_mem_cube_id mem_cube_id:{mem_cube_id}, "
+ f"delete_record_id:{delete_record_id}, hard_delete:{hard_delete}"
+ )
+
+ if not mem_cube_id:
+ logger.warning("[delete_node_by_mem_cube_id] mem_cube_id is required but not provided")
+ return 0
+
+ if not delete_record_id:
+ logger.warning(
+ "[delete_node_by_mem_cube_id] delete_record_id is required but not provided"
+ )
+ return 0
+
+ try:
+ with self.driver.session(database=self.db_name) as session:
+ if hard_delete:
+ query = """
+ MATCH (n:Memory)
+ WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id
+ DETACH DELETE n
+ """
+ logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {query}")
+
+ result = session.run(
+ query, mem_cube_id=mem_cube_id, delete_record_id=delete_record_id
+ )
+ summary = result.consume()
+ deleted_count = summary.counters.nodes_deleted if summary.counters else 0
+
+ logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes")
+ return deleted_count
+ else:
+ current_time = datetime.utcnow().isoformat()
+
+ query = """
+ MATCH (n:Memory)
+ WHERE n.user_name = $mem_cube_id
+ AND (n.delete_time IS NULL OR n.delete_time = "")
+ AND (n.delete_record_id IS NULL OR n.delete_record_id = "")
+ SET n.status = $status,
+ n.delete_record_id = $delete_record_id,
+ n.delete_time = $delete_time
+ RETURN count(n) AS updated_count
+ """
+ logger.info(f"[delete_node_by_mem_cube_id] Soft delete query: {query}")
+
+ result = session.run(
+ query,
+ mem_cube_id=mem_cube_id,
+ status="deleted",
+ delete_record_id=delete_record_id,
+ delete_time=current_time,
+ )
+ record = result.single()
+ updated_count = record["updated_count"] if record else 0
+
+ logger.info(
+ f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes"
+ )
+ return updated_count
+
+ except Exception as e:
+ logger.error(
+ f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True
+ )
+ raise
+
+ def recover_memory_by_mem_cube_id(
+ self,
+ mem_cube_id: str | None = None,
+ delete_record_id: str | None = None,
+ ) -> int:
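+        """
+        Recover soft-deleted Memory nodes tagged with delete_record_id:
+        resets status to 'activated' and clears the delete markers.
+
+        Returns:
+            Number of nodes updated.
+        """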
+ logger.info(
+ f"recover_memory_by_mem_cube_id mem_cube_id:{mem_cube_id},delete_record_id:{delete_record_id}"
+ )
+ # Validate required parameters
+ if not mem_cube_id:
+ logger.warning("recover_memory_by_mem_cube_id mem_cube_id is required but not provided")
+ return 0
+
+ if not delete_record_id:
+ logger.warning(
+ "recover_memory_by_mem_cube_id delete_record_id is required but not provided"
+ )
+ return 0
+
+ logger.info(
+ f"recover_memory_by_mem_cube_id mem_cube_id={mem_cube_id}, "
+ f"delete_record_id={delete_record_id}"
+ )
+
+ try:
+ with self.driver.session(database=self.db_name) as session:
+ query = """
+ MATCH (n:Memory)
+ WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id
+ SET n.status = $status,
+ n.delete_record_id = $delete_record_id_empty,
+ n.delete_time = $delete_time_empty
+ RETURN count(n) AS updated_count
+ """
+ logger.info(f"[recover_memory_by_mem_cube_id] Update query: {query}")
+
+ result = session.run(
+ query,
+ mem_cube_id=mem_cube_id,
+ delete_record_id=delete_record_id,
+ status="activated",
+ delete_record_id_empty="",
+ delete_time_empty="",
+ )
+ record = result.single()
+ updated_count = record["updated_count"] if record else 0
+
+ logger.info(
+ f"[recover_memory_by_mem_cube_id] Recovered (updated) {updated_count} nodes"
+ )
+ return updated_count
+
+ except Exception as e:
+ logger.error(
+ f"[recover_memory_by_mem_cube_id] Failed to recover nodes: {e}", exc_info=True
+ )
+ raise
diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py
index f2182f6cd..e34313fa2 100644
--- a/src/memos/graph_dbs/neo4j_community.py
+++ b/src/memos/graph_dbs/neo4j_community.py
@@ -620,7 +620,9 @@ def get_all_memory_items(
with self.driver.session(database=self.db_name) as session:
results = session.run(query, params)
- return [self._parse_node(dict(record["n"])) for record in results]
+ nodes_data = [dict(record["n"]) for record in results]
+ # Use batch parsing to fetch all embeddings at once
+ return self._parse_nodes(nodes_data)
def get_by_metadata(
self,
@@ -1056,3 +1058,261 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
logger.warning(f"Failed to fetch vector for node {new_node['id']}: {e}")
new_node["metadata"]["embedding"] = None
return new_node
+
+ def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ """Parse multiple Neo4j nodes and batch fetch embeddings from vector DB."""
+ if not nodes_data:
+ return []
+
+ # First, parse all nodes without embeddings
+ parsed_nodes = []
+ node_ids = []
+ for node_data in nodes_data:
+ node = node_data.copy()
+
+ # Convert Neo4j datetime to string
+ for time_field in ("created_at", "updated_at"):
+ if time_field in node and hasattr(node[time_field], "isoformat"):
+ node[time_field] = node[time_field].isoformat()
+ node.pop("user_name", None)
+            # Deserialize JSON-encoded source entries; stop at the first entry
+            # that is not a JSON object string.
+            if node.get("sources"):
+                for idx in range(len(node["sources"])):
+                    if not (
+                        isinstance(node["sources"][idx], str)
+                        and node["sources"][idx].startswith("{")
+                        and node["sources"][idx].endswith("}")
+                    ):
+                        break
+                    node["sources"][idx] = json.loads(node["sources"][idx])
+
+ node_id = node.pop("id")
+ node_ids.append(node_id)
+ parsed_nodes.append({"id": node_id, "memory": node.pop("memory", ""), "metadata": node})
+
+ # Batch fetch all embeddings at once
+ vec_items_map = {}
+ if node_ids:
+ try:
+ vec_items = self.vec_db.get_by_ids(node_ids)
+ vec_items_map = {v.id: v.vector for v in vec_items if v and v.vector}
+ except Exception as e:
+ logger.warning(f"Failed to batch fetch vectors for {len(node_ids)} nodes: {e}")
+
+ # Merge embeddings into parsed nodes
+ for parsed_node in parsed_nodes:
+ node_id = parsed_node["id"]
+ parsed_node["metadata"]["embedding"] = vec_items_map.get(node_id, None)
+
+ return parsed_nodes
+
+ def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]:
+ """Get user names by memory ids.
+
+ Args:
+ memory_ids: List of memory node IDs to query.
+
+ Returns:
+ dict[str, str | None]: Dictionary mapping memory_id to user_name.
+ - Key: memory_id
+ - Value: user_name if exists, None if memory_id does not exist
+ Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None}
+ """
+ if not memory_ids:
+ return {}
+
+        logger.info(
+            f"[neo4j_community get_user_names_by_memory_ids] Querying memory_ids: {memory_ids}"
+        )
+
+ try:
+ with self.driver.session(database=self.db_name) as session:
+ # Query to get memory_id and user_name pairs
+ query = """
+ MATCH (n:Memory)
+ WHERE n.id IN $memory_ids
+ RETURN n.id AS memory_id, n.user_name AS user_name
+ """
+ logger.info(f"[get_user_names_by_memory_ids] query: {query}")
+
+ result = session.run(query, memory_ids=memory_ids)
+ result_dict = {}
+
+ # Build result dictionary from query results
+ for record in result:
+ memory_id = record["memory_id"]
+ user_name = record["user_name"]
+ result_dict[memory_id] = user_name if user_name else None
+
+ # Set None for memory_ids that were not found
+ for mid in memory_ids:
+ if mid not in result_dict:
+ result_dict[mid] = None
+
+ logger.info(
+ f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, "
+ f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names"
+ )
+
+ return result_dict
+ except Exception as e:
+ logger.error(
+ f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
+ )
+ raise
+
+ def delete_node_by_mem_cube_id(
+ self,
+ mem_cube_id: str | None = None,
+ delete_record_id: str | None = None,
+ hard_delete: bool = False,
+ ) -> int:
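+        """
+        Delete Memory nodes of a mem cube by delete_record_id.
+
+        Hard delete also removes the corresponding vectors from the vector DB;
+        soft delete (default) tags every not-yet-deleted node of the cube with
+        delete_record_id and sets status='deleted'.
+
+        Returns:
+            Number of nodes deleted (hard) or updated (soft).
+        """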
+ logger.info(
+ f"delete_node_by_mem_cube_id mem_cube_id:{mem_cube_id}, "
+ f"delete_record_id:{delete_record_id}, hard_delete:{hard_delete}"
+ )
+
+ if not mem_cube_id:
+ logger.warning("[delete_node_by_mem_cube_id] mem_cube_id is required but not provided")
+ return 0
+
+ if not delete_record_id:
+ logger.warning(
+ "[delete_node_by_mem_cube_id] delete_record_id is required but not provided"
+ )
+ return 0
+
+ try:
+ with self.driver.session(database=self.db_name) as session:
+ if hard_delete:
+ query_get_ids = """
+ MATCH (n:Memory)
+ WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id
+ RETURN n.id AS id
+ """
+ result = session.run(
+ query_get_ids, mem_cube_id=mem_cube_id, delete_record_id=delete_record_id
+ )
+ node_ids = [record["id"] for record in result]
+
+ # Delete from Neo4j
+ query = """
+ MATCH (n:Memory)
+ WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id
+ DETACH DELETE n
+ """
+ logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {query}")
+
+ result = session.run(
+ query, mem_cube_id=mem_cube_id, delete_record_id=delete_record_id
+ )
+ summary = result.consume()
+ deleted_count = summary.counters.nodes_deleted if summary.counters else 0
+
+ # Delete from vector DB
+ if node_ids and self.vec_db:
+ try:
+ self.vec_db.delete(node_ids)
+ logger.info(
+ f"[delete_node_by_mem_cube_id] Deleted {len(node_ids)} vectors from VecDB"
+ )
+ except Exception as e:
+ logger.warning(
+ f"[delete_node_by_mem_cube_id] Failed to delete vectors from VecDB: {e}"
+ )
+
+ logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes")
+ return deleted_count
+ else:
+ current_time = datetime.utcnow().isoformat()
+
+ query = """
+ MATCH (n:Memory)
+ WHERE n.user_name = $mem_cube_id
+ AND (n.delete_time IS NULL OR n.delete_time = "")
+ AND (n.delete_record_id IS NULL OR n.delete_record_id = "")
+ SET n.status = $status,
+ n.delete_record_id = $delete_record_id,
+ n.delete_time = $delete_time
+ RETURN count(n) AS updated_count
+ """
+ logger.info(f"[delete_node_by_mem_cube_id] Soft delete query: {query}")
+
+ result = session.run(
+ query,
+ mem_cube_id=mem_cube_id,
+ status="deleted",
+ delete_record_id=delete_record_id,
+ delete_time=current_time,
+ )
+ record = result.single()
+ updated_count = record["updated_count"] if record else 0
+
+ logger.info(
+ f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes"
+ )
+ return updated_count
+
+ except Exception as e:
+ logger.error(
+ f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True
+ )
+ raise
+
+ def recover_memory_by_mem_cube_id(
+ self,
+ mem_cube_id: str | None = None,
+ delete_record_id: str | None = None,
+ ) -> int:
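+        """
+        Recover soft-deleted Memory nodes tagged with delete_record_id:
+        resets status to 'activated' and clears the delete markers.
+
+        Returns:
+            Number of nodes updated.
+        """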
+ logger.info(
+ f"recover_memory_by_mem_cube_id mem_cube_id:{mem_cube_id},delete_record_id:{delete_record_id}"
+ )
+ # Validate required parameters
+ if not mem_cube_id:
+ logger.warning("recover_memory_by_mem_cube_id mem_cube_id is required but not provided")
+ return 0
+
+ if not delete_record_id:
+ logger.warning(
+ "recover_memory_by_mem_cube_id delete_record_id is required but not provided"
+ )
+ return 0
+
+ logger.info(
+ f"recover_memory_by_mem_cube_id mem_cube_id={mem_cube_id}, "
+ f"delete_record_id={delete_record_id}"
+ )
+
+ try:
+ with self.driver.session(database=self.db_name) as session:
+ query = """
+ MATCH (n:Memory)
+ WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id
+ SET n.status = $status,
+ n.delete_record_id = $delete_record_id_empty,
+ n.delete_time = $delete_time_empty
+ RETURN count(n) AS updated_count
+ """
+ logger.info(f"[recover_memory_by_mem_cube_id] Update query: {query}")
+
+ result = session.run(
+ query,
+ mem_cube_id=mem_cube_id,
+ delete_record_id=delete_record_id,
+ status="activated",
+ delete_record_id_empty="",
+ delete_time_empty="",
+ )
+ record = result.single()
+ updated_count = record["updated_count"] if record else 0
+
+ logger.info(
+ f"[recover_memory_by_mem_cube_id] Recovered (updated) {updated_count} nodes"
+ )
+ return updated_count
+
+ except Exception as e:
+ logger.error(
+ f"[recover_memory_by_mem_cube_id] Failed to recover nodes: {e}", exc_info=True
+ )
+ raise
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index b9c8ca84b..f0a23e39b 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -1691,7 +1691,7 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
raise NotImplementedError
@timed
- def seach_by_keywords_like(
+ def search_by_keywords_like(
self,
query_word: str,
scope: str | None = None,
@@ -1761,7 +1761,7 @@ def seach_by_keywords_like(
params = (query_word,)
logger.info(
- f"[seach_by_keywords_LIKE start:] user_name: {user_name}, query: {query}, params: {params}"
+ f"[search_by_keywords_LIKE start:] user_name: {user_name}, query: {query}, params: {params}"
)
conn = None
try:
@@ -1773,16 +1773,18 @@ def seach_by_keywords_like(
for row in results:
oldid = row[0]
id_val = str(oldid)
+ if id_val.startswith('"') and id_val.endswith('"'):
+ id_val = id_val[1:-1]
output.append({"id": id_val})
logger.info(
- f"[seach_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}"
+ f"[search_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}"
)
return output
finally:
self._return_connection(conn)
@timed
- def seach_by_keywords_tfidf(
+ def search_by_keywords_tfidf(
self,
query_words: list[str],
scope: str | None = None,
@@ -1858,7 +1860,7 @@ def seach_by_keywords_tfidf(
params = (tsquery_string,)
logger.info(
- f"[seach_by_keywords_TFIDF start:] user_name: {user_name}, query: {query}, params: {params}"
+ f"[search_by_keywords_TFIDF start:] user_name: {user_name}, query: {query}, params: {params}"
)
conn = None
try:
@@ -1870,10 +1872,12 @@ def seach_by_keywords_tfidf(
for row in results:
oldid = row[0]
id_val = str(oldid)
+ if id_val.startswith('"') and id_val.endswith('"'):
+ id_val = id_val[1:-1]
output.append({"id": id_val})
logger.info(
- f"[seach_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}"
+ f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}"
)
return output
finally:
@@ -2003,6 +2007,8 @@ def search_by_fulltext(
rank = row[2] # rank score
id_val = str(oldid)
+ if id_val.startswith('"') and id_val.endswith('"'):
+ id_val = id_val[1:-1]
score_val = float(rank)
# Apply threshold filter if specified
@@ -2167,6 +2173,8 @@ def search_by_embedding(
oldid = row[3] # old_id
score = row[4] # scope
id_val = str(oldid)
+ if id_val.startswith('"') and id_val.endswith('"'):
+ id_val = id_val[1:-1]
score_val = float(score)
score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score
if threshold is None or score_val >= threshold:
@@ -2534,6 +2542,8 @@ def export_graph(
page: int | None = None,
page_size: int | None = None,
filter: dict | None = None,
+ memory_type: list[str] | None = None,
+ status: list[str] | None = None,
**kwargs,
) -> dict[str, Any]:
"""
@@ -2551,6 +2561,11 @@ def export_graph(
- "gt", "lt", "gte", "lte": comparison operators
- "like": fuzzy matching
Example: {"and": [{"created_at": {"gte": "2025-01-01"}}, {"tags": {"contains": "AI"}}]}
+ memory_type (list[str], optional): List of memory_type values to filter by. If provided, only nodes/edges with
+ memory_type in this list will be exported. Example: ["LongTermMemory", "WorkingMemory"]
+ status (list[str], optional): List of status values to filter by. If not provided, only nodes/edges with
+ status != 'deleted' are exported. If provided, only nodes/edges with status in this list are exported.
+ Example: ["activated"] or ["activated", "archived"]
Returns:
{
@@ -2561,7 +2576,7 @@ def export_graph(
}
"""
logger.info(
- f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}"
+ f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}, status: {status}"
)
user_id = user_id if user_id else self._get_config_value("user_id")
@@ -2596,6 +2611,36 @@ def export_graph(
f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype"
)
+ # Add memory_type filter condition
+ if memory_type and isinstance(memory_type, list) and len(memory_type) > 0:
+ # Escape memory_type values and build IN clause
+ memory_type_values = []
+ for mt in memory_type:
+ # Escape single quotes in memory_type value
+ escaped_memory_type = str(mt).replace("'", "''")
+ memory_type_values.append(f"'\"{escaped_memory_type}\"'::agtype")
+ memory_type_in_clause = ", ".join(memory_type_values)
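+ # e.g. ["LongTermMemory", "WorkingMemory"] becomes
+ # IN ('"LongTermMemory"'::agtype, '"WorkingMemory"'::agtype)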
+ where_conditions.append(
+ f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) IN ({memory_type_in_clause})"
+ )
+
+ # Add status filter condition: if not passed, exclude deleted; otherwise filter by IN list
+ if status is None:
+ # Default behavior: exclude deleted entries
+ where_conditions.append(
+ "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) <> '\"deleted\"'::agtype"
+ )
+ elif isinstance(status, list) and len(status) > 0:
+ # status IN (list)
+ status_values = []
+ for st in status:
+ escaped_status = str(st).replace("'", "''")
+ status_values.append(f"'\"{escaped_status}\"'::agtype")
+ status_in_clause = ", ".join(status_values)
+ where_conditions.append(
+ f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) IN ({status_in_clause})"
+ )
+
# Build filter conditions using common method
filter_conditions = self._build_filter_conditions_sql(filter)
logger.info(f"[export_graph] filter_conditions: {filter_conditions}")
@@ -2691,6 +2736,25 @@ def export_graph(
cypher_where_conditions.append(f"a.user_id = '{user_id}'")
cypher_where_conditions.append(f"b.user_id = '{user_id}'")
+ # Add memory_type filter condition for edges (apply to both source and target nodes)
+ if memory_type and isinstance(memory_type, list) and len(memory_type) > 0:
+ # Escape single quotes in memory_type values for Cypher
+ escaped_memory_types = [mt.replace("'", "\\'") for mt in memory_type]
+ memory_type_list_str = ", ".join([f"'{mt}'" for mt in escaped_memory_types])
+ # Cypher IN syntax: a.memory_type IN ['LongTermMemory', 'WorkingMemory']
+ cypher_where_conditions.append(f"a.memory_type IN [{memory_type_list_str}]")
+ cypher_where_conditions.append(f"b.memory_type IN [{memory_type_list_str}]")
+
+ # Add status filter for edges: if not passed, exclude deleted; otherwise filter by IN list
+ if status is None:
+ # Default behavior: exclude deleted entries
+ cypher_where_conditions.append("a.status <> 'deleted' AND b.status <> 'deleted'")
+ elif isinstance(status, list) and len(status) > 0:
+ escaped_statuses = [st.replace("'", "\\'") for st in status]
+ status_list_str = ", ".join([f"'{st}'" for st in escaped_statuses])
+ cypher_where_conditions.append(f"a.status IN [{status_list_str}]")
+ cypher_where_conditions.append(f"b.status IN [{status_list_str}]")
+
# Build filter conditions for edges (apply to both source and target nodes)
filter_where_clause = self._build_filter_conditions_cypher(filter)
logger.info(f"[export_graph edges] filter_where_clause: {filter_where_clause}")
@@ -4310,7 +4374,7 @@ def _build_user_name_and_kb_ids_conditions_sql(
user_name_conditions = []
effective_user_name = user_name if user_name else default_user_name
- if effective_user_name:
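+ # Only filter by user_name when the caller passes one explicitly; the
+ # config-level default no longer forces a user_name condition here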
+ if user_name:
user_name_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{effective_user_name}\"'::agtype"
)
@@ -5124,6 +5188,9 @@ def parse_filter(
"info",
"source",
"file_ids",
+ "project_id",
+ "manager_user_id",
+ "delete_time",
}
def process_condition(condition):
@@ -5441,3 +5508,170 @@ def escape_user_name(un: str) -> str:
raise
finally:
self._return_connection(conn)
+
+ @timed
+ def delete_node_by_mem_cube_id(
+ self,
+ mem_cube_id: str | None = None,
+ delete_record_id: str | None = None,
+ hard_delete: bool = False,
+ ) -> int:
+ logger.info(
+ f"[delete_node_by_mem_cube_id] mem_cube_id: {mem_cube_id}, "
+ f"delete_record_id: {delete_record_id}, hard_delete: {hard_delete}"
+ )
+
+ if not mem_cube_id:
+ logger.warning("[delete_node_by_mem_cube_id] mem_cube_id is required but not provided")
+ return 0
+
+ if not delete_record_id:
+ logger.warning(
+ "[delete_node_by_mem_cube_id] delete_record_id is required but not provided"
+ )
+ return 0
+
+ conn = None
+ try:
+ conn = self._get_connection()
+ with conn.cursor() as cursor:
+ user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
+
+ user_name_param = self.format_param_value(mem_cube_id)
+
+ if hard_delete:
+ delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype"
+ where_clause = f"{user_name_condition} AND {delete_record_id_condition}"
+
+ where_params = [user_name_param, self.format_param_value(delete_record_id)]
+
+ delete_query = f"""
+ DELETE FROM "{self.db_name}_graph"."Memory"
+ WHERE {where_clause}
+ """
+ logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {delete_query}")
+
+ cursor.execute(delete_query, where_params)
+ deleted_count = cursor.rowcount
+
+ logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes")
+ return deleted_count
+ else:
+ delete_time_empty_condition = (
+ "(ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) IS NULL "
+ "OR ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) = '\"\"'::agtype)"
+ )
+ delete_record_id_empty_condition = (
+ "(ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) IS NULL "
+ "OR ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = '\"\"'::agtype)"
+ )
+ where_clause = f"{user_name_condition} AND {delete_time_empty_condition} AND {delete_record_id_empty_condition}"
+
+ current_time = datetime.utcnow().isoformat()
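+ # jsonb || merges objects: keys in the right-hand payload (status,
+ # delete_time, delete_record_id) overwrite the stored values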
+ update_query = f"""
+ UPDATE "{self.db_name}_graph"."Memory"
+ SET properties = (
+ properties::jsonb || %s::jsonb
+ )::text::agtype,
+ deletetime = %s
+ WHERE {where_clause}
+ """
+ update_properties = {
+ "status": "deleted",
+ "delete_time": current_time,
+ "delete_record_id": delete_record_id,
+ }
+ logger.info(
+ f"[delete_node_by_mem_cube_id] Soft delete update_query: {update_query}, "
+ f"update_properties: {update_properties}, deletetime: {current_time}"
+ )
+ update_params = [json.dumps(update_properties), current_time, user_name_param]
+ cursor.execute(update_query, update_params)
+ updated_count = cursor.rowcount
+
+ logger.info(
+ f"[delete_node_by_mem_cube_id] Soft deleted (updated) {updated_count} nodes"
+ )
+ return updated_count
+
+ except Exception as e:
+ logger.error(
+ f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True
+ )
+ raise
+ finally:
+ self._return_connection(conn)
+
+ @timed
+ def recover_memory_by_mem_cube_id(
+ self,
+ mem_cube_id: str | None = None,
+ delete_record_id: str | None = None,
+ ) -> int:
+ logger.info(
+ f"[recover_memory_by_mem_cube_id] mem_cube_id: {mem_cube_id}, "
+ f"delete_record_id: {delete_record_id}"
+ )
+ # Validate required parameters
+ if not mem_cube_id:
+ logger.warning(
+ "[recover_memory_by_mem_cube_id] mem_cube_id is required but not provided"
+ )
+ return 0
+
+ if not delete_record_id:
+ logger.warning(
+ "[recover_memory_by_mem_cube_id] delete_record_id is required but not provided"
+ )
+ return 0
+
+ conn = None
+ try:
+ conn = self._get_connection()
+ with conn.cursor() as cursor:
+ user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
+ delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype"
+ where_clause = f"{user_name_condition} AND {delete_record_id_condition}"
+
+ where_params = [
+ self.format_param_value(mem_cube_id),
+ self.format_param_value(delete_record_id),
+ ]
+
+ update_properties = {
+ "status": "activated",
+ "delete_record_id": "",
+ "delete_time": "",
+ }
+
+ update_query = f"""
+ UPDATE "{self.db_name}_graph"."Memory"
+ SET properties = (
+ properties::jsonb || %s::jsonb
+ )::text::agtype,
+ deletetime = NULL
+ WHERE {where_clause}
+ """
+
+ logger.info(f"[recover_memory_by_mem_cube_id] Update query: {update_query}")
+ logger.info(
+ f"[recover_memory_by_mem_cube_id] update_properties: {update_properties}"
+ )
+
+ update_params = [json.dumps(update_properties), *where_params]
+ cursor.execute(update_query, update_params)
+ updated_count = cursor.rowcount
+
+ logger.info(
+ f"[recover_memory_by_mem_cube_id] Recovered (updated) {updated_count} nodes"
+ )
+ return updated_count
+
+ except Exception as e:
+ logger.error(
+ f"[recover_memory_by_mem_cube_id] Failed to recover nodes: {e}", exc_info=True
+ )
+ raise
+ finally:
+ self._return_connection(conn)
diff --git a/src/memos/graph_dbs/postgres.py b/src/memos/graph_dbs/postgres.py
new file mode 100644
index 000000000..1c1cae378
--- /dev/null
+++ b/src/memos/graph_dbs/postgres.py
@@ -0,0 +1,976 @@
+"""
+PostgreSQL + pgvector backend for MemOS.
+
+Simple implementation using standard PostgreSQL with pgvector extension.
+No Apache AGE or other graph extensions required.
+
+Tables:
+- {schema}.memories: Memory nodes with JSONB properties and vector embeddings
+- {schema}.edges: Relationships between memory nodes
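+
+Example (a minimal usage sketch; the config fields below are the ones this
+module reads, shown with placeholder values):
+
+    from memos.configs.graph_db import PostgresGraphDBConfig
+    from memos.graph_dbs.postgres import PostgresGraphDB
+
+    config = PostgresGraphDBConfig(
+        host="localhost", port=5432, user="memos", password="secret",
+        db_name="memos", schema_name="memos_schema", user_name="alice",
+        embedding_dimension=768, maxconn=8,
+    )
+    db = PostgresGraphDB(config)
+    db.add_node(
+        id="m1",
+        memory="Caroline met her mentors in early July 2023.",
+        metadata={"memory_type": "LongTermMemory", "status": "activated"},
+    )
+    print(db.get_node("m1"))
+    db.close()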
+"""
+
+import json
+import time
+
+from contextlib import suppress
+from datetime import datetime
+from typing import Any, Literal
+
+from memos.configs.graph_db import PostgresGraphDBConfig
+from memos.dependency import require_python_package
+from memos.graph_dbs.base import BaseGraphDB
+from memos.log import get_logger
+
+
+logger = get_logger(__name__)
+
+
+def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
+ """Ensure metadata has proper datetime fields and normalized types."""
+ now = datetime.utcnow().isoformat()
+ metadata.setdefault("created_at", now)
+ metadata.setdefault("updated_at", now)
+
+ # Normalize embedding type
+ embedding = metadata.get("embedding")
+ if embedding and isinstance(embedding, list):
+ metadata["embedding"] = [float(x) for x in embedding]
+
+ return metadata
+
+
+class PostgresGraphDB(BaseGraphDB):
+ """PostgreSQL + pgvector implementation of a graph memory store."""
+
+ @require_python_package(
+ import_name="psycopg2",
+ install_command="pip install psycopg2-binary",
+ install_link="https://pypi.org/project/psycopg2-binary/",
+ )
+ def __init__(self, config: PostgresGraphDBConfig):
+ """Initialize PostgreSQL connection pool."""
+ import psycopg2
+ import psycopg2.pool
+
+ self.config = config
+ self.schema = config.schema_name
+ self.user_name = config.user_name
+ self._pool_closed = False
+
+ logger.info(f"Connecting to PostgreSQL: {config.host}:{config.port}/{config.db_name}")
+
+ # Create connection pool
+ self.pool = psycopg2.pool.ThreadedConnectionPool(
+ minconn=2,
+ maxconn=config.maxconn,
+ host=config.host,
+ port=config.port,
+ user=config.user,
+ password=config.password,
+ dbname=config.db_name,
+ connect_timeout=30,
+ keepalives_idle=30,
+ keepalives_interval=10,
+ keepalives_count=5,
+ )
+
+ # Initialize schema and tables
+ self._init_schema()
+
+ def _get_conn(self):
+ """Get connection from pool with health check."""
+ if self._pool_closed:
+ raise RuntimeError("Connection pool is closed")
+
+ for attempt in range(3):
+ conn = None
+ try:
+ conn = self.pool.getconn()
+ if conn.closed != 0:
+ self.pool.putconn(conn, close=True)
+ continue
+ conn.autocommit = True
+ # Health check
+ with conn.cursor() as cur:
+ cur.execute("SELECT 1")
+ return conn
+ except Exception as e:
+ if conn:
+ with suppress(Exception):
+ self.pool.putconn(conn, close=True)
+ if attempt == 2:
+ raise RuntimeError(f"Failed to get connection: {e}") from e
+ time.sleep(0.1)
+ raise RuntimeError("Failed to get healthy connection")
+
+ def _put_conn(self, conn):
+ """Return connection to pool."""
+ if conn and not self._pool_closed:
+ try:
+ self.pool.putconn(conn)
+ except Exception:
+ with suppress(Exception):
+ conn.close()
+
+ def _init_schema(self):
+ """Create schema and tables if they don't exist."""
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ # Create schema
+ cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.schema}")
+
+ # Enable pgvector
+ cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+
+ # Create memories table
+ dim = self.config.embedding_dimension
+ cur.execute(f"""
+ CREATE TABLE IF NOT EXISTS {self.schema}.memories (
+ id TEXT PRIMARY KEY,
+ memory TEXT NOT NULL DEFAULT '',
+ properties JSONB NOT NULL DEFAULT '{{}}',
+ embedding vector({dim}),
+ user_name TEXT,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW()
+ )
+ """)
+
+ # Create edges table
+ cur.execute(f"""
+ CREATE TABLE IF NOT EXISTS {self.schema}.edges (
+ id SERIAL PRIMARY KEY,
+ source_id TEXT NOT NULL,
+ target_id TEXT NOT NULL,
+ edge_type TEXT NOT NULL,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ UNIQUE(source_id, target_id, edge_type)
+ )
+ """)
+
+ # Create indexes
+ cur.execute(f"""
+ CREATE INDEX IF NOT EXISTS idx_memories_user
+ ON {self.schema}.memories(user_name)
+ """)
+ cur.execute(f"""
+ CREATE INDEX IF NOT EXISTS idx_memories_props
+ ON {self.schema}.memories USING GIN(properties)
+ """)
+ cur.execute(f"""
+ CREATE INDEX IF NOT EXISTS idx_memories_embedding
+ ON {self.schema}.memories USING ivfflat(embedding vector_cosine_ops)
+ WITH (lists = 100)
+ """)
+ cur.execute(f"""
+ CREATE INDEX IF NOT EXISTS idx_edges_source
+ ON {self.schema}.edges(source_id)
+ """)
+ cur.execute(f"""
+ CREATE INDEX IF NOT EXISTS idx_edges_target
+ ON {self.schema}.edges(target_id)
+ """)
+
+ logger.info(f"Schema {self.schema} initialized successfully")
+ except Exception as e:
+ logger.error(f"Failed to init schema: {e}")
+ raise
+ finally:
+ self._put_conn(conn)
+
+ # =========================================================================
+ # Node Management
+ # =========================================================================
+
+ def remove_oldest_memory(
+ self, memory_type: str, keep_latest: int, user_name: str | None = None
+ ) -> None:
+ """
+ Remove all memories of a given type except the latest `keep_latest` entries.
+
+ Args:
+ memory_type: Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
+ keep_latest: Number of latest entries to keep.
+ user_name: User to filter by.
+ """
+ user_name = user_name or self.user_name
+ keep_latest = int(keep_latest)
+
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ # Find IDs to delete (older than the keep_latest entries)
+ cur.execute(
+ f"""
+ WITH ranked AS (
+ SELECT id, ROW_NUMBER() OVER (ORDER BY updated_at DESC) as rn
+ FROM {self.schema}.memories
+ WHERE user_name = %s
+ AND properties->>'memory_type' = %s
+ )
+ SELECT id FROM ranked WHERE rn > %s
+ """,
+ (user_name, memory_type, keep_latest),
+ )
+
+ ids_to_delete = [row[0] for row in cur.fetchall()]
+
+ if ids_to_delete:
+ # Delete edges first
+ cur.execute(
+ f"""
+ DELETE FROM {self.schema}.edges
+ WHERE source_id = ANY(%s) OR target_id = ANY(%s)
+ """,
+ (ids_to_delete, ids_to_delete),
+ )
+
+ # Delete nodes
+ cur.execute(
+ f"""
+ DELETE FROM {self.schema}.memories
+ WHERE id = ANY(%s)
+ """,
+ (ids_to_delete,),
+ )
+
+ logger.info(
+ f"Removed {len(ids_to_delete)} oldest {memory_type} memories for user {user_name}"
+ )
+ finally:
+ self._put_conn(conn)
+
+ def add_node(
+ self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None
+ ) -> None:
+ """Add a memory node."""
+ user_name = user_name or self.user_name
+ metadata = _prepare_node_metadata(metadata.copy())
+
+ # Extract embedding
+ embedding = metadata.pop("embedding", None)
+ created_at = metadata.pop("created_at", datetime.utcnow().isoformat())
+ updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat())
+
+ # Serialize sources if present
+ if metadata.get("sources"):
+ metadata["sources"] = [
+ json.dumps(s) if not isinstance(s, str) else s for s in metadata["sources"]
+ ]
+
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ if embedding:
+ cur.execute(
+ f"""
+ INSERT INTO {self.schema}.memories
+ (id, memory, properties, embedding, user_name, created_at, updated_at)
+ VALUES (%s, %s, %s, %s::vector, %s, %s, %s)
+ ON CONFLICT (id) DO UPDATE SET
+ memory = EXCLUDED.memory,
+ properties = EXCLUDED.properties,
+ embedding = EXCLUDED.embedding,
+ updated_at = EXCLUDED.updated_at
+ """,
+ (
+ id,
+ memory,
+ json.dumps(metadata),
+ embedding,
+ user_name,
+ created_at,
+ updated_at,
+ ),
+ )
+ else:
+ cur.execute(
+ f"""
+ INSERT INTO {self.schema}.memories
+ (id, memory, properties, user_name, created_at, updated_at)
+ VALUES (%s, %s, %s, %s, %s, %s)
+ ON CONFLICT (id) DO UPDATE SET
+ memory = EXCLUDED.memory,
+ properties = EXCLUDED.properties,
+ updated_at = EXCLUDED.updated_at
+ """,
+ (id, memory, json.dumps(metadata), user_name, created_at, updated_at),
+ )
+ finally:
+ self._put_conn(conn)
+
+ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = None) -> None:
+ """Batch add memory nodes."""
+ for node in nodes:
+ self.add_node(
+ id=node["id"],
+ memory=node["memory"],
+ metadata=node.get("metadata", {}),
+ user_name=user_name,
+ )
+
+ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None:
+ """Update node fields."""
+ user_name = user_name or self.user_name
+ if not fields:
+ return
+
+ # Get current node
+ current = self.get_node(id, user_name=user_name)
+ if not current:
+ return
+
+ # Merge properties
+ props = current.get("metadata", {}).copy()
+ embedding = fields.pop("embedding", None)
+ memory = fields.pop("memory", current.get("memory", ""))
+ props.update(fields)
+ props["updated_at"] = datetime.utcnow().isoformat()
+
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ if embedding:
+ cur.execute(
+ f"""
+ UPDATE {self.schema}.memories
+ SET memory = %s, properties = %s, embedding = %s::vector, updated_at = NOW()
+ WHERE id = %s AND user_name = %s
+ """,
+ (memory, json.dumps(props), embedding, id, user_name),
+ )
+ else:
+ cur.execute(
+ f"""
+ UPDATE {self.schema}.memories
+ SET memory = %s, properties = %s, updated_at = NOW()
+ WHERE id = %s AND user_name = %s
+ """,
+ (memory, json.dumps(props), id, user_name),
+ )
+ finally:
+ self._put_conn(conn)
+
+ def delete_node(self, id: str, user_name: str | None = None) -> None:
+ """Delete a node and its edges."""
+ user_name = user_name or self.user_name
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ # Delete edges
+ cur.execute(
+ f"""
+ DELETE FROM {self.schema}.edges
+ WHERE source_id = %s OR target_id = %s
+ """,
+ (id, id),
+ )
+ # Delete node
+ cur.execute(
+ f"""
+ DELETE FROM {self.schema}.memories
+ WHERE id = %s AND user_name = %s
+ """,
+ (id, user_name),
+ )
+ finally:
+ self._put_conn(conn)
+
+ def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None:
+ """Get a single node by ID."""
+ user_name = kwargs.get("user_name") or self.user_name
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ cols = "id, memory, properties, created_at, updated_at"
+ if include_embedding:
+ cols += ", embedding"
+ cur.execute(
+ f"""
+ SELECT {cols} FROM {self.schema}.memories
+ WHERE id = %s AND user_name = %s
+ """,
+ (id, user_name),
+ )
+ row = cur.fetchone()
+ if not row:
+ return None
+ return self._parse_row(row, include_embedding)
+ finally:
+ self._put_conn(conn)
+
+ def get_nodes(
+ self, ids: list, include_embedding: bool = False, **kwargs
+ ) -> list[dict[str, Any]]:
+ """Get multiple nodes by IDs."""
+ if not ids:
+ return []
+ user_name = kwargs.get("user_name") or self.user_name
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ cols = "id, memory, properties, created_at, updated_at"
+ if include_embedding:
+ cols += ", embedding"
+ cur.execute(
+ f"""
+ SELECT {cols} FROM {self.schema}.memories
+ WHERE id = ANY(%s) AND user_name = %s
+ """,
+ (ids, user_name),
+ )
+ return [self._parse_row(row, include_embedding) for row in cur.fetchall()]
+ finally:
+ self._put_conn(conn)
+
+ def _parse_row(self, row, include_embedding: bool = False) -> dict[str, Any]:
+ """Parse database row to node dict."""
+ props = row[2] if isinstance(row[2], dict) else json.loads(row[2] or "{}")
+ props["created_at"] = row[3].isoformat() if row[3] else None
+ props["updated_at"] = row[4].isoformat() if row[4] else None
+ result = {
+ "id": row[0],
+ "memory": row[1] or "",
+ "metadata": props,
+ }
+ if include_embedding and len(row) > 5:
+ result["metadata"]["embedding"] = row[5]
+ return result
+
+ # =========================================================================
+ # Edge Management
+ # =========================================================================
+
+ def add_edge(
+ self, source_id: str, target_id: str, type: str, user_name: str | None = None
+ ) -> None:
+ """Create an edge between nodes."""
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ cur.execute(
+ f"""
+ INSERT INTO {self.schema}.edges (source_id, target_id, edge_type)
+ VALUES (%s, %s, %s)
+ ON CONFLICT (source_id, target_id, edge_type) DO NOTHING
+ """,
+ (source_id, target_id, type),
+ )
+ finally:
+ self._put_conn(conn)
+
+ def delete_edge(
+ self, source_id: str, target_id: str, type: str, user_name: str | None = None
+ ) -> None:
+ """Delete an edge."""
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ cur.execute(
+ f"""
+ DELETE FROM {self.schema}.edges
+ WHERE source_id = %s AND target_id = %s AND edge_type = %s
+ """,
+ (source_id, target_id, type),
+ )
+ finally:
+ self._put_conn(conn)
+
+ def edge_exists(self, source_id: str, target_id: str, type: str) -> bool:
+ """Check if edge exists."""
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ cur.execute(
+ f"""
+ SELECT 1 FROM {self.schema}.edges
+ WHERE source_id = %s AND target_id = %s AND edge_type = %s
+ LIMIT 1
+ """,
+ (source_id, target_id, type),
+ )
+ return cur.fetchone() is not None
+ finally:
+ self._put_conn(conn)
+
+ # =========================================================================
+ # Graph Queries
+ # =========================================================================
+
+ def get_neighbors(
+ self, id: str, type: str, direction: Literal["in", "out", "both"] = "out"
+ ) -> list[str]:
+ """Get neighboring node IDs."""
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ if direction == "out":
+ cur.execute(
+ f"""
+ SELECT target_id FROM {self.schema}.edges
+ WHERE source_id = %s AND edge_type = %s
+ """,
+ (id, type),
+ )
+ elif direction == "in":
+ cur.execute(
+ f"""
+ SELECT source_id FROM {self.schema}.edges
+ WHERE target_id = %s AND edge_type = %s
+ """,
+ (id, type),
+ )
+ else: # both
+ cur.execute(
+ f"""
+ SELECT target_id FROM {self.schema}.edges WHERE source_id = %s AND edge_type = %s
+ UNION
+ SELECT source_id FROM {self.schema}.edges WHERE target_id = %s AND edge_type = %s
+ """,
+ (id, type, id, type),
+ )
+ return [row[0] for row in cur.fetchall()]
+ finally:
+ self._put_conn(conn)
+
+ def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]:
+ """Get path between nodes using recursive CTE."""
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ cur.execute(
+ f"""
+ WITH RECURSIVE path AS (
+ SELECT source_id, target_id, ARRAY[source_id] as nodes, 1 as depth
+ FROM {self.schema}.edges
+ WHERE source_id = %s
+ UNION ALL
+ SELECT e.source_id, e.target_id, p.nodes || e.source_id, p.depth + 1
+ FROM {self.schema}.edges e
+ JOIN path p ON e.source_id = p.target_id
+ WHERE p.depth < %s AND NOT e.source_id = ANY(p.nodes)
+ )
+ SELECT nodes || target_id as full_path
+ FROM path
+ WHERE target_id = %s
+ ORDER BY depth
+ LIMIT 1
+ """,
+ (source_id, max_depth, target_id),
+ )
+ row = cur.fetchone()
+ return row[0] if row else []
+ finally:
+ self._put_conn(conn)
+
+ def get_subgraph(self, center_id: str, depth: int = 2) -> list[str]:
+ """Get subgraph around center node."""
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ cur.execute(
+ f"""
+ WITH RECURSIVE subgraph AS (
+ SELECT %s::text as node_id, 0 as level
+ UNION
+ SELECT CASE WHEN e.source_id = s.node_id THEN e.target_id ELSE e.source_id END,
+ s.level + 1
+ FROM {self.schema}.edges e
+ JOIN subgraph s ON (e.source_id = s.node_id OR e.target_id = s.node_id)
+ WHERE s.level < %s
+ )
+ SELECT DISTINCT node_id FROM subgraph
+ """,
+ (center_id, depth),
+ )
+ return [row[0] for row in cur.fetchall()]
+ finally:
+ self._put_conn(conn)
+
+ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
+ """Get ordered chain following relationship type."""
+ return self.get_neighbors(id, type, "out")
+
+ # =========================================================================
+ # Search Operations
+ # =========================================================================
+
+ def search_by_embedding(
+ self,
+ vector: list[float],
+ top_k: int = 5,
+ scope: str | None = None,
+ status: str | None = None,
+ threshold: float | None = None,
+ search_filter: dict | None = None,
+ user_name: str | None = None,
+ filter: dict | None = None,
+ knowledgebase_ids: list[str] | None = None,
+ **kwargs,
+ ) -> list[dict]:
+ """Search nodes by vector similarity using pgvector."""
+ user_name = user_name or self.user_name
+
+ # Build WHERE clause
+ conditions = ["embedding IS NOT NULL"]
+ params = []
+
+ if user_name:
+ conditions.append("user_name = %s")
+ params.append(user_name)
+
+ if scope:
+ conditions.append("properties->>'memory_type' = %s")
+ params.append(scope)
+
+ if status:
+ conditions.append("properties->>'status' = %s")
+ params.append(status)
+ else:
+ conditions.append(
+ "(properties->>'status' = 'activated' OR properties->>'status' IS NULL)"
+ )
+
+ if search_filter:
+ for k, v in search_filter.items():
+ conditions.append(f"properties->>'{k}' = %s")
+ params.append(str(v))
+
+ where_clause = " AND ".join(conditions)
+
+ # pgvector cosine distance: 1 - (a <=> b) gives similarity score
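+ # <=> returns cosine distance, so the resulting score lies in [-1, 1];
+ # ORDER BY the raw distance yields nearest neighbors first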
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ cur.execute(
+ f"""
+ SELECT id, 1 - (embedding <=> %s::vector) as score
+ FROM {self.schema}.memories
+ WHERE {where_clause}
+ ORDER BY embedding <=> %s::vector
+ LIMIT %s
+ """,
+ (vector, *params, vector, top_k),
+ )
+
+ results = []
+ for row in cur.fetchall():
+ score = float(row[1])
+ if threshold is None or score >= threshold:
+ results.append({"id": row[0], "score": score})
+ return results
+ finally:
+ self._put_conn(conn)
+
+ def get_by_metadata(
+ self,
+ filters: list[dict[str, Any]],
+ status: str | None = None,
+ user_name: str | None = None,
+ filter: dict | None = None,
+ knowledgebase_ids: list[str] | None = None,
+ user_name_flag: bool = True,
+ ) -> list[str]:
+ """Get node IDs matching metadata filters."""
+ user_name = user_name or self.user_name
+
+ conditions = []
+ params = []
+
+ if user_name_flag and user_name:
+ conditions.append("user_name = %s")
+ params.append(user_name)
+
+ if status:
+ conditions.append("properties->>'status' = %s")
+ params.append(status)
+
+ for f in filters:
+ field = f["field"]
+ op = f.get("op", "=")
+ value = f["value"]
+
+ if op == "=":
+ conditions.append(f"properties->>'{field}' = %s")
+ params.append(str(value))
+ elif op == "in":
+ placeholders = ",".join(["%s"] * len(value))
+ conditions.append(f"properties->>'{field}' IN ({placeholders})")
+ params.extend([str(v) for v in value])
+ elif op in (">", ">=", "<", "<="):
+ conditions.append(f"(properties->>'{field}')::numeric {op} %s")
+ params.append(value)
+ elif op == "contains":
+ conditions.append(f"properties->'{field}' @> %s::jsonb")
+ params.append(json.dumps([value]))
+
+ where_clause = " AND ".join(conditions) if conditions else "TRUE"
+
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ cur.execute(
+ f"""
+ SELECT id FROM {self.schema}.memories
+ WHERE {where_clause}
+ """,
+ params,
+ )
+ return [row[0] for row in cur.fetchall()]
+ finally:
+ self._put_conn(conn)
+
+ def get_all_memory_items(
+ self,
+ scope: str,
+ include_embedding: bool = False,
+ status: str | None = None,
+ filter: dict | None = None,
+ knowledgebase_ids: list[str] | None = None,
+ **kwargs,
+ ) -> list[dict]:
+ """Get all memory items of a specific type."""
+ user_name = kwargs.get("user_name") or self.user_name
+
+ conditions = ["properties->>'memory_type' = %s", "user_name = %s"]
+ params = [scope, user_name]
+
+ if status:
+ conditions.append("properties->>'status' = %s")
+ params.append(status)
+
+ where_clause = " AND ".join(conditions)
+
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ cols = "id, memory, properties, created_at, updated_at"
+ if include_embedding:
+ cols += ", embedding"
+ cur.execute(
+ f"""
+ SELECT {cols} FROM {self.schema}.memories
+ WHERE {where_clause}
+ """,
+ params,
+ )
+ return [self._parse_row(row, include_embedding) for row in cur.fetchall()]
+ finally:
+ self._put_conn(conn)
+
+ def get_structure_optimization_candidates(
+ self, scope: str, include_embedding: bool = False
+ ) -> list[dict]:
+ """Find isolated nodes (no edges)."""
+ user_name = self.user_name
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ cols = "m.id, m.memory, m.properties, m.created_at, m.updated_at"
+ cur.execute(
+ f"""
+ SELECT {cols}
+ FROM {self.schema}.memories m
+ LEFT JOIN {self.schema}.edges e1 ON m.id = e1.source_id
+ LEFT JOIN {self.schema}.edges e2 ON m.id = e2.target_id
+ WHERE m.properties->>'memory_type' = %s
+ AND m.user_name = %s
+ AND m.properties->>'status' = 'activated'
+ AND e1.id IS NULL
+ AND e2.id IS NULL
+ """,
+ (scope, user_name),
+ )
+ return [self._parse_row(row, False) for row in cur.fetchall()]
+ finally:
+ self._put_conn(conn)
+
+ # =========================================================================
+ # Maintenance
+ # =========================================================================
+
+ def deduplicate_nodes(self) -> None:
+ """Not implemented - handled at application level."""
+
+ def get_grouped_counts(
+ self,
+ group_fields: list[str],
+ where_clause: str = "",
+ params: dict[str, Any] | None = None,
+ user_name: str | None = None,
+ ) -> list[dict[str, Any]]:
+ """
+ Count nodes grouped by specified fields.
+
+ Args:
+ group_fields: Fields to group by, e.g., ["memory_type", "status"]
+ where_clause: Extra WHERE condition
+ params: Parameters for WHERE clause
+ user_name: User to filter by
+
+ Returns:
+ list[dict]: e.g., [{'memory_type': 'WorkingMemory', 'count': 10}, ...]
+ """
+ user_name = user_name or self.user_name
+ if not group_fields:
+ raise ValueError("group_fields cannot be empty")
+
+ # Build SELECT and GROUP BY clauses
+ # Fields come from JSONB properties column
+ select_fields = ", ".join([f"properties->>'{field}' AS {field}" for field in group_fields])
+ group_by = ", ".join([f"properties->>'{field}'" for field in group_fields])
+
+ # Build WHERE clause
+ conditions = ["user_name = %s"]
+ query_params = [user_name]
+
+ if where_clause:
+ # Parse simple where clause format
+ where_clause = where_clause.strip()
+ if where_clause.upper().startswith("WHERE"):
+ where_clause = where_clause[5:].strip()
+ if where_clause:
+ conditions.append(where_clause)
+ if params:
+ query_params.extend(params.values())
+
+ where_sql = " AND ".join(conditions)
+
+ query = f"""
+ SELECT {select_fields}, COUNT(*) AS count
+ FROM {self.schema}.memories
+ WHERE {where_sql}
+ GROUP BY {group_by}
+ """
+
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ cur.execute(query, query_params)
+ results = []
+ for row in cur.fetchall():
+ result = {}
+ for i, field in enumerate(group_fields):
+ result[field] = row[i]
+ result["count"] = row[len(group_fields)]
+ results.append(result)
+ return results
+ finally:
+ self._put_conn(conn)
+
+ def detect_conflicts(self) -> list[tuple[str, str]]:
+ """Not implemented."""
+ return []
+
+ def merge_nodes(self, id1: str, id2: str) -> str:
+ """Not implemented."""
+ raise NotImplementedError
+
+ def clear(self, user_name: str | None = None) -> None:
+ """Clear all data for user."""
+ user_name = user_name or self.user_name
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ # Get all node IDs for user
+ cur.execute(
+ f"""
+ SELECT id FROM {self.schema}.memories WHERE user_name = %s
+ """,
+ (user_name,),
+ )
+ ids = [row[0] for row in cur.fetchall()]
+
+ if ids:
+ # Delete edges
+ cur.execute(
+ f"""
+ DELETE FROM {self.schema}.edges
+ WHERE source_id = ANY(%s) OR target_id = ANY(%s)
+ """,
+ (ids, ids),
+ )
+
+ # Delete nodes
+ cur.execute(
+ f"""
+ DELETE FROM {self.schema}.memories WHERE user_name = %s
+ """,
+ (user_name,),
+ )
+ logger.info(f"Cleared all data for user {user_name}")
+ finally:
+ self._put_conn(conn)
+
+ def export_graph(self, include_embedding: bool = False, **kwargs) -> dict[str, Any]:
+ """Export all data."""
+ user_name = kwargs.get("user_name") or self.user_name
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ # Get nodes
+ cols = "id, memory, properties, created_at, updated_at"
+ if include_embedding:
+ cols += ", embedding"
+ cur.execute(
+ f"""
+ SELECT {cols} FROM {self.schema}.memories
+ WHERE user_name = %s
+ ORDER BY created_at DESC
+ """,
+ (user_name,),
+ )
+ nodes = [self._parse_row(row, include_embedding) for row in cur.fetchall()]
+
+ # Get edges
+ node_ids = [n["id"] for n in nodes]
+ if node_ids:
+ cur.execute(
+ f"""
+ SELECT source_id, target_id, edge_type
+ FROM {self.schema}.edges
+ WHERE source_id = ANY(%s) OR target_id = ANY(%s)
+ """,
+ (node_ids, node_ids),
+ )
+ edges = [
+ {"source": row[0], "target": row[1], "type": row[2]}
+ for row in cur.fetchall()
+ ]
+ else:
+ edges = []
+
+ return {
+ "nodes": nodes,
+ "edges": edges,
+ "total_nodes": len(nodes),
+ "total_edges": len(edges),
+ }
+ finally:
+ self._put_conn(conn)
+
+ def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None:
+ """Import graph data."""
+ user_name = user_name or self.user_name
+
+ for node in data.get("nodes", []):
+ self.add_node(
+ id=node["id"],
+ memory=node.get("memory", ""),
+ metadata=node.get("metadata", {}),
+ user_name=user_name,
+ )
+
+ for edge in data.get("edges", []):
+ self.add_edge(
+ source_id=edge["source"],
+ target_id=edge["target"],
+ type=edge["type"],
+ )
+
+ def close(self):
+ """Close connection pool."""
+ if not self._pool_closed:
+ self._pool_closed = True
+ self.pool.closeall()
diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py
index e38318a64..6c6d1821f 100644
--- a/src/memos/mem_feedback/feedback.py
+++ b/src/memos/mem_feedback/feedback.py
@@ -235,20 +235,16 @@ def _single_add_operation(
to_add_memory.metadata.tags = new_memory_item.metadata.tags
to_add_memory.memory = new_memory_item.memory
to_add_memory.metadata.embedding = new_memory_item.metadata.embedding
-
to_add_memory.metadata.user_id = new_memory_item.metadata.user_id
- to_add_memory.metadata.created_at = to_add_memory.metadata.updated_at = (
- datetime.now().isoformat()
- )
- to_add_memory.metadata.background = new_memory_item.metadata.background
else:
to_add_memory = new_memory_item.model_copy(deep=True)
- to_add_memory.metadata.created_at = to_add_memory.metadata.updated_at = (
- datetime.now().isoformat()
- )
- to_add_memory.metadata.background = new_memory_item.metadata.background
- to_add_memory.id = ""
+ to_add_memory.metadata.created_at = to_add_memory.metadata.updated_at = (
+ datetime.now().isoformat()
+ )
+ to_add_memory.metadata.background = new_memory_item.metadata.background
+ to_add_memory.metadata.sources = []
+
added_ids = self._retry_db_operation(
lambda: self.memory_manager.add([to_add_memory], user_name=user_name, use_batch=False)
)
@@ -626,10 +622,39 @@ def _info_comparison(self, memory: TextualMemoryItem, _info: dict, include_keys:
def _retrieve(self, query: str, info=None, top_k=20, user_name=None):
"""Retrieve memory items"""
- retrieved_mems = self.searcher.search(
- query, info=info, user_name=user_name, top_k=top_k, full_recall=True
+
+ def check_has_edges(mem_item: TextualMemoryItem) -> tuple[TextualMemoryItem, bool]:
+ """Check if a memory item has edges."""
+ edges = self.searcher.graph_store.get_edges(mem_item.id, user_name=user_name)
+ return (mem_item, len(edges) == 0)
+
+ text_mems = self.searcher.search(
+ query,
+ info=info,
+ memory_type="AllSummaryMemory",
+ user_name=user_name,
+ top_k=top_k,
+ full_recall=True,
)
- retrieved_mems = [item[0] for item in retrieved_mems if float(item[1]) > 0.01]
+ text_mems = [item[0] for item in text_mems if float(item[1]) > 0.01]
+
+ # Memories with edges are not modified by feedback
+ retrieved_mems = []
+ with ContextThreadPoolExecutor(max_workers=10) as executor:
+ futures = {executor.submit(check_has_edges, item): item for item in text_mems}
+ for future in concurrent.futures.as_completed(futures):
+ try:
+ mem_item, has_no_edges = future.result()
+ if has_no_edges:
+ retrieved_mems.append(mem_item)
+ except Exception as e:
+ logger.error(f"[0107 Feedback Core: _retrieve] Error checking edges: {e}")
+
+ if len(retrieved_mems) < len(text_mems):
+ logger.info(
+ f"[0107 Feedback Core: _retrieve] {len(text_mems) - len(retrieved_mems)} "
+ f"text memories are not modified by feedback due to edges."
+ )
if self.pref_feedback:
pref_info = {}
@@ -924,7 +949,7 @@ def process_keyword_replace(
)
must_part = f"{' & '.join(queries)}" if len(queries) > 1 else queries[0]
- retrieved_ids = self.graph_store.seach_by_keywords_tfidf(
+ retrieved_ids = self.graph_store.search_by_keywords_tfidf(
[must_part], user_name=user_name, filter=filter_dict
)
if len(retrieved_ids) < 1:
@@ -932,7 +957,7 @@ def process_keyword_replace(
queries, top_k=100, user_name=user_name, filter=filter_dict
)
else:
- retrieved_ids = self.graph_store.seach_by_keywords_like(
+ retrieved_ids = self.graph_store.search_by_keywords_like(
f"%{original_word}%", user_name=user_name, filter=filter_dict
)
@@ -1165,7 +1190,7 @@ def process_feedback(
info,
**kwargs,
)
- done, pending = concurrent.futures.wait([answer_future, core_future], timeout=30)
+ _done, pending = concurrent.futures.wait([answer_future, core_future], timeout=30)
for fut in pending:
fut.cancel()
try:
diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py
index 77a5e70c9..b2c74c384 100644
--- a/src/memos/mem_os/product.py
+++ b/src/memos/mem_os/product.py
@@ -798,12 +798,14 @@ def run_async_in_thread():
)
# Add exception handling for the background task
task.add_done_callback(
- lambda t: logger.error(
- f"Error in background post-chat processing for user {user_id}: {t.exception()}",
- exc_info=True,
+ lambda t: (
+ logger.error(
+ f"Error in background post-chat processing for user {user_id}: {t.exception()}",
+ exc_info=True,
+ )
+ if t.exception()
+ else None
)
- if t.exception()
- else None
)
except RuntimeError:
# No event loop, run in a new thread with context propagation
diff --git a/src/memos/mem_os/product_server.py b/src/memos/mem_os/product_server.py
index 758f2794d..80aefea85 100644
--- a/src/memos/mem_os/product_server.py
+++ b/src/memos/mem_os/product_server.py
@@ -437,12 +437,14 @@ def run_async_in_thread():
)
# Add exception handling for the background task
task.add_done_callback(
- lambda t: logger.error(
- f"Error in background post-chat processing for user {user_id}: {t.exception()}",
- exc_info=True,
+ lambda t: (
+ logger.error(
+ f"Error in background post-chat processing for user {user_id}: {t.exception()}",
+ exc_info=True,
+ )
+ if t.exception()
+ else None
)
- if t.exception()
- else None
)
except RuntimeError:
# No event loop, run in a new thread with context propagation
diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index f6a016556..8b0968ca1 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -3,7 +3,7 @@
import re
import traceback
-from typing import Any
+from typing import TYPE_CHECKING, Any
from memos import log
from memos.configs.mem_reader import MultiModalStructMemReaderConfig
@@ -20,6 +20,10 @@
from memos.utils import timed
+if TYPE_CHECKING:
+ from memos.types.general_types import UserContext
+
+
logger = log.get_logger(__name__)
@@ -108,10 +112,10 @@ def _create_chunk_item(chunk):
)
return split_item
- # Use thread pool to parallel process chunks
+ # Use a thread pool to process chunks in parallel while keeping the original order
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
futures = [executor.submit(_create_chunk_item, chunk) for chunk in chunks]
- for future in concurrent.futures.as_completed(futures):
+ for future in futures:
split_item = future.result()
if split_item is not None:
split_items.append(split_item)
@@ -146,26 +150,33 @@ def _concat_multi_modal_memories(
parallel_chunking = True
if parallel_chunking:
- # parallel chunk large memory items
+ # Chunk large memory items in parallel while keeping the original order
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
- future_to_item = {
- executor.submit(self._split_large_memory_item, item, max_tokens): item
- for item in all_memory_items
- if (item.memory or "") and self._count_tokens(item.memory) > max_tokens
- }
- processed_items.extend(
- [
- item
- for item in all_memory_items
- if not (
- (item.memory or "") and self._count_tokens(item.memory) > max_tokens
- )
- ]
- )
- # collect split items from futures
- for future in concurrent.futures.as_completed(future_to_item):
- split_items = future.result()
- processed_items.extend(split_items)
+ # Create a list to hold futures with their original index
+ futures = []
+ for idx, item in enumerate(all_memory_items):
+ if (item.memory or "") and self._count_tokens(item.memory) > max_tokens:
+ future = executor.submit(self._split_large_memory_item, item, max_tokens)
+ futures.append(
+ (idx, future, True)
+ ) # True indicates this item needs splitting
+ else:
+ futures.append((idx, item, False)) # False indicates no splitting needed
+
+ # Process results in original order
+ temp_results = [None] * len(all_memory_items)
+ for idx, future_or_item, needs_splitting in futures:
+ if needs_splitting:
+ # Wait for the future to complete and get the split items
+ split_items = future_or_item.result()
+ temp_results[idx] = split_items
+ else:
+ # No splitting needed, use the original item
+ temp_results[idx] = [future_or_item]
+
+ # Flatten the results while preserving order
+ for items in temp_results:
+ processed_items.extend(items)
else:
# serial chunk large memory items
for item in all_memory_items:
@@ -277,6 +288,7 @@ def _build_window_from_items(
# Collect all memory texts and sources
memory_texts = []
all_sources = []
+ seen_content = set() # Track seen source content to avoid duplicates
roles = set()
aggregated_file_ids: list[str] = []
@@ -290,8 +302,18 @@ def _build_window_from_items(
item_sources = [item_sources]
for source in item_sources:
- # Add source to all_sources
- all_sources.append(source)
+ # Get content from source for deduplication
+ source_content = None
+ if isinstance(source, dict):
+ source_content = source.get("content", "")
+ else:
+ source_content = getattr(source, "content", "") or ""
+
+ # Only add if content is new; sources with empty content cannot be
+ # deduplicated, so they are always kept
+ if not source_content:
+ all_sources.append(source)
+ elif source_content not in seen_content:
+ seen_content.add(source_content)
+ all_sources.append(source)
# Extract role from source
if hasattr(source, "role") and source.role:
@@ -453,7 +475,10 @@ def _determine_prompt_type(self, sources: list) -> str:
source_role = source.get("role")
if source_role in {"user", "assistant", "system", "tool"}:
prompt_type = "chat"
-
+ if hasattr(source, "type"):
+ source_type = source.type
+ if source_type == "file":
+ prompt_type = "doc"
return prompt_type
def _get_maybe_merged_memory(
@@ -630,11 +655,14 @@ def _process_string_fine(
) -> list[TextualMemoryItem]:
"""
Process fast mode memory items through LLM to generate fine mode memories.
+ Note that fast_memory_items are raw chunk-level memory items, not the final memories.
"""
if not fast_memory_items:
return []
- def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]:
+ def _process_one_item(
+ fast_item: TextualMemoryItem, chunk_idx: int, total_chunks: int
+ ) -> list[TextualMemoryItem]:
"""Process a single fast memory item and return a list of fine items."""
fine_items: list[TextualMemoryItem] = []
@@ -660,6 +688,12 @@ def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]:
if file_ids:
extra_kwargs["file_ids"] = file_ids
+ # Extract manager_user_id and project_id from user_context
+ user_context: UserContext | None = kwargs.get("user_context")
+ if user_context:
+ extra_kwargs["manager_user_id"] = user_context.manager_user_id
+ extra_kwargs["project_id"] = user_context.project_id
+
# Determine prompt type based on sources
prompt_type = self._determine_prompt_type(sources)
@@ -732,12 +766,40 @@ def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]:
except Exception as e:
logger.error(f"[MultiModalFine] parse error: {e}")
+ # save rawfile node
+ if self.save_rawfile and prompt_type == "doc" and len(fine_items) > 0:
+ rawfile_chunk = mem_str
+ file_info = fine_items[0].metadata.sources[0].file_info
+ source = self.multi_modal_parser.file_content_parser.create_source(
+ message={"file": file_info},
+ info=info_per_item,
+ chunk_index=chunk_idx,
+ chunk_total=total_chunks,
+ chunk_content="",
+ )
+ rawfile_node = self._make_memory_item(
+ value=rawfile_chunk,
+ info=info_per_item,
+ memory_type="RawFileMemory",
+ tags=[
+ "mode:fine",
+ "multimodal:file",
+ f"chunk:{chunk_idx + 1}/{total_chunks}",
+ ],
+ sources=[source],
+ )
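+ # Link the raw file chunk to the fine memories summarized from it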
+ rawfile_node.metadata.summary_ids = [mem_node.id for mem_node in fine_items]
+ fine_items.append(rawfile_node)
return fine_items
fine_memory_items: list[TextualMemoryItem] = []
+ total_chunks_len = len(fast_memory_items)
with ContextThreadPoolExecutor(max_workers=30) as executor:
- futures = [executor.submit(_process_one_item, item) for item in fast_memory_items]
+ futures = [
+ executor.submit(_process_one_item, item, idx, total_chunks_len)
+ for idx, item in enumerate(fast_memory_items)
+ ]
for future in concurrent.futures.as_completed(futures):
try:
@@ -747,6 +809,63 @@ def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]:
except Exception as e:
logger.error(f"[MultiModalFine] worker error: {e}")
+ # related preceding and following rawfilememories
+ fine_memory_items = self._relate_preceding_following_rawfile_memories(fine_memory_items)
+ return fine_memory_items
+
+ def _relate_preceding_following_rawfile_memories(
+ self, fine_memory_items: list[TextualMemoryItem]
+ ) -> list[TextualMemoryItem]:
+ """
+ Relate RawFileMemory items to each other by setting preceding_id and following_id.
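+ The result is a doubly-linked chain ordered by chunk_index, so adjacent
+ file chunks can be traversed in either direction.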
+ """
+ # Filter RawFileMemory items and track their original positions
+ rawfile_items_with_pos = []
+ for idx, item in enumerate(fine_memory_items):
+ if (
+ hasattr(item.metadata, "memory_type")
+ and item.metadata.memory_type == "RawFileMemory"
+ ):
+ rawfile_items_with_pos.append((idx, item))
+
+ if len(rawfile_items_with_pos) <= 1:
+ return fine_memory_items
+
+ def get_chunk_idx(item_with_pos) -> float:
+ """Extract chunk_idx from item's source metadata."""
+ _, item = item_with_pos
+ if item.metadata.sources and len(item.metadata.sources) > 0:
+ source = item.metadata.sources[0]
+ # Handle both SourceMessage object and dict
+ if isinstance(source, dict):
+ file_info = source.get("file_info")
+ if file_info and isinstance(file_info, dict):
+ chunk_idx = file_info.get("chunk_index")
+ if chunk_idx is not None:
+ return chunk_idx
+ else:
+ # SourceMessage object
+ file_info = getattr(source, "file_info", None)
+ if file_info and isinstance(file_info, dict):
+ chunk_idx = file_info.get("chunk_index")
+ if chunk_idx is not None:
+ return chunk_idx
+ return float("inf")
+
+ # Sort items by chunk_index
+ sorted_rawfile_items_with_pos = sorted(rawfile_items_with_pos, key=get_chunk_idx)
+
+ # Relate adjacent items
+ for i in range(len(sorted_rawfile_items_with_pos) - 1):
+ _, current_item = sorted_rawfile_items_with_pos[i]
+ _, next_item = sorted_rawfile_items_with_pos[i + 1]
+ current_item.metadata.following_id = next_item.id
+ next_item.metadata.preceding_id = current_item.id
+
+ # Replace sorted items back to original positions in fine_memory_items
+ for orig_idx, item in sorted_rawfile_items_with_pos:
+ fine_memory_items[orig_idx] = item
+
return fine_memory_items
def _get_llm_tool_trajectory_response(self, mem_str: str) -> dict:
@@ -775,6 +894,11 @@ def _process_tool_trajectory_fine(
fine_memory_items = []
+ # Extract manager_user_id and project_id from user_context
+ user_context: UserContext | None = kwargs.get("user_context")
+ manager_user_id = user_context.manager_user_id if user_context else None
+ project_id = user_context.project_id if user_context else None
+
for fast_item in fast_memory_items:
# Extract memory text (string content)
mem_str = fast_item.memory or ""
@@ -801,6 +925,8 @@ def _process_tool_trajectory_fine(
correctness=m.get("correctness", ""),
experience=m.get("experience", ""),
tool_used_status=m.get("tool_used_status", []),
+ manager_user_id=manager_user_id,
+ project_id=project_id,
)
fine_memory_items.append(node)
except Exception as e:
@@ -831,8 +957,9 @@ def _process_multi_modal_data(
if isinstance(scene_data_info, list):
# Parse each message in the list
all_memory_items = []
- # Use thread pool to parse each message in parallel
+ # Use a thread pool to parse messages in parallel while keeping the original order
with ContextThreadPoolExecutor(max_workers=30) as executor:
+ # submit tasks and keep the original order
futures = [
executor.submit(
self.multi_modal_parser.parse,
@@ -844,7 +971,8 @@ def _process_multi_modal_data(
)
for msg in scene_data_info
]
- for future in concurrent.futures.as_completed(futures):
+ # collect results in original order
+ for future in futures:
try:
items = future.result()
all_memory_items.extend(items)
diff --git a/src/memos/mem_reader/read_multi_modal/assistant_parser.py b/src/memos/mem_reader/read_multi_modal/assistant_parser.py
index 89d4fec7f..bac9deaad 100644
--- a/src/memos/mem_reader/read_multi_modal/assistant_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/assistant_parser.py
@@ -2,7 +2,7 @@
import json
-from typing import Any
+from typing import TYPE_CHECKING, Any
from memos.embedders.base import BaseEmbedder
from memos.llms.base import BaseLLM
@@ -18,6 +18,10 @@
from .utils import detect_lang
+if TYPE_CHECKING:
+ from memos.types.general_types import UserContext
+
+
logger = get_logger(__name__)
@@ -281,6 +285,11 @@ def parse_fast(
user_id = info_.pop("user_id", "")
session_id = info_.pop("session_id", "")
+ # Extract manager_user_id and project_id from user_context
+ user_context: UserContext | None = kwargs.get("user_context")
+ manager_user_id = user_context.manager_user_id if user_context else None
+ project_id = user_context.project_id if user_context else None
+
# Create memory item (equivalent to _make_memory_item)
memory_item = TextualMemoryItem(
memory=line,
@@ -298,6 +307,8 @@ def parse_fast(
confidence=0.99,
type="fact",
info=info_,
+ manager_user_id=manager_user_id,
+ project_id=project_id,
),
)
diff --git a/src/memos/mem_reader/read_multi_modal/base.py b/src/memos/mem_reader/read_multi_modal/base.py
index 95d427864..737a3fe1e 100644
--- a/src/memos/mem_reader/read_multi_modal/base.py
+++ b/src/memos/mem_reader/read_multi_modal/base.py
@@ -15,6 +15,7 @@
TextualMemoryItem,
TreeNodeTextualMemoryMetadata,
)
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
from memos.utils import timed
from .utils import detect_lang, get_text_splitter
@@ -90,6 +91,7 @@ def __init__(self, embedder, llm=None):
"""
self.embedder = embedder
self.llm = llm
+ self.tokenizer = FastTokenizer(use_jieba=True, use_stopwords=True)
@abstractmethod
def create_source(
diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py
index fbc704d0b..2b49d63ba 100644
--- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py
@@ -5,7 +5,7 @@
import re
import tempfile
-from typing import Any
+from typing import TYPE_CHECKING, Any
from tqdm import tqdm
@@ -34,6 +34,10 @@
from memos.types.openai_chat_completion_types import File
+if TYPE_CHECKING:
+ from memos.types.general_types import UserContext
+
+
logger = get_logger(__name__)
# Prompt dictionary for doc processing (shared by simple_struct and file_content_parser)
@@ -465,6 +469,11 @@ def parse_fast(
user_id = info_.pop("user_id", "")
session_id = info_.pop("session_id", "")
+ # Extract manager_user_id and project_id from user_context
+ user_context: UserContext | None = kwargs.get("user_context")
+ manager_user_id = user_context.manager_user_id if user_context else None
+ project_id = user_context.project_id if user_context else None
+
# For file content parts, default to LongTermMemory
# (since we don't have role information at this level)
memory_type = "LongTermMemory"
@@ -509,6 +518,8 @@ def parse_fast(
type="fact",
info=info_,
file_ids=file_ids,
+ manager_user_id=manager_user_id,
+ project_id=project_id,
),
)
memory_items.append(memory_item)
@@ -541,6 +552,8 @@ def parse_fast(
type="fact",
info=info_,
file_ids=file_ids,
+ manager_user_id=manager_user_id,
+ project_id=project_id,
),
)
memory_items.append(memory_item)
@@ -658,6 +671,12 @@ def parse_fine(
info_ = info.copy()
user_id = info_.pop("user_id", "")
session_id = info_.pop("session_id", "")
+
+ # Extract manager_user_id and project_id from user_context
+ user_context: UserContext | None = kwargs.get("user_context")
+ manager_user_id = user_context.manager_user_id if user_context else None
+ project_id = user_context.project_id if user_context else None
+
if file_id:
info_["file_id"] = file_id
file_ids = [file_id] if file_id else []
@@ -716,6 +735,8 @@ def _make_memory_item(
type="fact",
info=info_,
file_ids=file_ids,
+ manager_user_id=manager_user_id,
+ project_id=project_id,
),
)
@@ -724,7 +745,7 @@ def _make_fallback(
chunk_idx: int, chunk_text: str, reason: str = "raw"
) -> TextualMemoryItem:
"""Create fallback memory item with raw chunk text."""
- return _make_memory_item(
+ raw_chunk_mem = _make_memory_item(
value=chunk_text,
tags=[
"mode:fine",
@@ -735,6 +756,11 @@ def _make_fallback(
chunk_idx=chunk_idx,
chunk_content=chunk_text,
)
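+ # Tag the fallback with its five longest multi-character key tokens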
+ tags_list = self.tokenizer.tokenize_mixed(raw_chunk_mem.metadata.key)
+ tags_list = [tag for tag in tags_list if len(tag) > 1]
+ tags_list = sorted(tags_list, key=len, reverse=True)
+ raw_chunk_mem.metadata.tags.extend(tags_list[:5])
+ return raw_chunk_mem
# Handle empty chunks case
if not valid_chunks:
@@ -781,10 +807,36 @@ def _process_chunk(chunk_idx: int, chunk_text: str) -> TextualMemoryItem:
logger.warning(f"[FileContentParser] Fallback to raw for chunk {chunk_idx}")
return _make_fallback(chunk_idx, chunk_text)
+ def _relate_chunks(items: list[TextualMemoryItem]) -> list[TextualMemoryItem]:
+ """
+ Link adjacent chunks via preceding_id/following_id, ordered by chunk_index.
+ """
+ if len(items) <= 1:
+ return items
+
+ def get_chunk_idx(item: TextualMemoryItem) -> float:
+ """Extract chunk_idx from item's source metadata."""
+ if item.metadata.sources and len(item.metadata.sources) > 0:
+ source = item.metadata.sources[0]
+ if source.file_info and isinstance(source.file_info, dict):
+ chunk_idx = source.file_info.get("chunk_index")
+ if chunk_idx is not None:
+ return chunk_idx
+ return float("inf")
+
+ sorted_items = sorted(items, key=get_chunk_idx)
+
+ # Relate adjacent items
+ for i in range(len(sorted_items) - 1):
+ sorted_items[i].metadata.following_id = sorted_items[i + 1].id
+ sorted_items[i + 1].metadata.preceding_id = sorted_items[i].id
+ return sorted_items
+
# Process chunks concurrently with progress bar
memory_items = []
chunk_map = dict(valid_chunks)
total_chunks = len(valid_chunks)
+ fallback_count = 0
logger.info(f"[FileContentParser] Processing {total_chunks} chunks with LLM...")
@@ -802,20 +854,53 @@ def _process_chunk(chunk_idx: int, chunk_text: str) -> TextualMemoryItem:
chunk_idx = futures[future]
try:
node = future.result()
- if node:
- memory_items.append(node)
+ memory_items.append(node)
+
+                    # Detect fallback nodes via their "fallback:" tags
+ is_fallback = any(tag.startswith("fallback:") for tag in node.metadata.tags)
+ if is_fallback:
+ fallback_count += 1
+
+                    # Save the raw chunk as a separate RawFileMemory node when it differs from the summary
+ node_id = node.id
+ if node.memory != node.metadata.sources[0].content:
+ chunk_node = _make_memory_item(
+ value=node.metadata.sources[0].content,
+ mem_type="RawFileMemory",
+ tags=[
+ "mode:fine",
+ "multimodal:file",
+ f"chunk:{chunk_idx + 1}/{total_chunks}",
+ ],
+ chunk_idx=chunk_idx,
+ chunk_content="",
+ )
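+                        # Link the raw chunk node back to the summarized node it backs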
+ chunk_node.metadata.summary_ids = [node_id]
+ memory_items.append(chunk_node)
+
except Exception as e:
tqdm.write(f"[ERROR] Chunk {chunk_idx} failed: {e}")
logger.error(f"[FileContentParser] Future failed for chunk {chunk_idx}: {e}")
# Create fallback for failed future
if chunk_idx in chunk_map:
+ fallback_count += 1
memory_items.append(
_make_fallback(chunk_idx, chunk_map[chunk_idx], "error")
)
+ fallback_percentage = (fallback_count / total_chunks * 100) if total_chunks > 0 else 0.0
logger.info(
- f"[FileContentParser] Completed processing {len(memory_items)}/{total_chunks} chunks"
+ f"[FileContentParser] Completed processing {len(memory_items)}/{total_chunks} chunks, "
+ f"fallback count: {fallback_count}/{total_chunks} ({fallback_percentage:.1f}%)"
)
+ rawfile_items = [
+ memory for memory in memory_items if memory.metadata.memory_type == "RawFileMemory"
+ ]
+ mem_items = [
+ memory for memory in memory_items if memory.metadata.memory_type != "RawFileMemory"
+ ]
+ related_rawfile_items = _relate_chunks(rawfile_items)
+ memory_items = mem_items + related_rawfile_items
return memory_items or [
_make_memory_item(
diff --git a/src/memos/mem_reader/read_multi_modal/image_parser.py b/src/memos/mem_reader/read_multi_modal/image_parser.py
index 9322b9bc9..97400ca26 100644
--- a/src/memos/mem_reader/read_multi_modal/image_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/image_parser.py
@@ -3,7 +3,7 @@
import json
import re
-from typing import Any
+from typing import TYPE_CHECKING, Any
from memos.embedders.base import BaseEmbedder
from memos.llms.base import BaseLLM
@@ -20,6 +20,10 @@
from .utils import detect_lang
+if TYPE_CHECKING:
+ from memos.types.general_types import UserContext
+
+
logger = get_logger(__name__)
@@ -212,6 +216,7 @@ def parse_fine(
key=_derive_key(summary),
sources=[source],
background=summary,
+ **kwargs,
)
)
return memory_items
@@ -252,6 +257,7 @@ def parse_fine(
key=key if key else _derive_key(value),
sources=[source],
background=background,
+ **kwargs,
)
memory_items.append(memory_item)
except Exception as e:
@@ -273,6 +279,7 @@ def parse_fine(
key=_derive_key(fallback_value),
sources=[source],
background="Image processing encountered an error.",
+ **kwargs,
)
]
@@ -333,12 +340,18 @@ def _create_memory_item(
key: str,
sources: list[SourceMessage],
background: str = "",
+ **kwargs,
) -> TextualMemoryItem:
"""Create a TextualMemoryItem with the given parameters."""
info_ = info.copy()
user_id = info_.pop("user_id", "")
session_id = info_.pop("session_id", "")
+ # Extract manager_user_id and project_id from user_context
+ user_context: UserContext | None = kwargs.get("user_context")
+ manager_user_id = user_context.manager_user_id if user_context else None
+ project_id = user_context.project_id if user_context else None
+
return TextualMemoryItem(
memory=value,
metadata=TreeNodeTextualMemoryMetadata(
@@ -355,5 +368,7 @@ def _create_memory_item(
confidence=0.99,
type="fact",
info=info_,
+ manager_user_id=manager_user_id,
+ project_id=project_id,
),
)
diff --git a/src/memos/mem_reader/read_multi_modal/string_parser.py b/src/memos/mem_reader/read_multi_modal/string_parser.py
index b6e18fda3..220cf6e58 100644
--- a/src/memos/mem_reader/read_multi_modal/string_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/string_parser.py
@@ -3,7 +3,7 @@
Handles simple string messages that need to be converted to memory items.
"""
-from typing import Any
+from typing import TYPE_CHECKING, Any
from memos.embedders.base import BaseEmbedder
from memos.llms.base import BaseLLM
@@ -17,6 +17,10 @@
from .base import BaseMessageParser, _add_lang_to_source, _derive_key
+if TYPE_CHECKING:
+ from memos.types.general_types import UserContext
+
+
logger = get_logger(__name__)
@@ -92,6 +96,11 @@ def parse_fast(
user_id = info_.pop("user_id", "")
session_id = info_.pop("session_id", "")
+ # Extract manager_user_id and project_id from user_context
+ user_context: UserContext | None = kwargs.get("user_context")
+ manager_user_id = user_context.manager_user_id if user_context else None
+ project_id = user_context.project_id if user_context else None
+
# For string messages, default to LongTermMemory
memory_type = "LongTermMemory"
@@ -120,6 +129,8 @@ def parse_fast(
confidence=0.99,
type="fact",
info=info_,
+ manager_user_id=manager_user_id,
+ project_id=project_id,
),
)
memory_items.append(memory_item)
diff --git a/src/memos/mem_reader/read_multi_modal/system_parser.py b/src/memos/mem_reader/read_multi_modal/system_parser.py
index 03a49afd8..74545ceee 100644
--- a/src/memos/mem_reader/read_multi_modal/system_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/system_parser.py
@@ -6,7 +6,7 @@
import re
import uuid
-from typing import Any
+from typing import TYPE_CHECKING, Any
from memos.embedders.base import BaseEmbedder
from memos.llms.base import BaseLLM
@@ -21,6 +21,10 @@
from .base import BaseMessageParser, _add_lang_to_source
+if TYPE_CHECKING:
+ from memos.types.general_types import UserContext
+
+
logger = get_logger(__name__)
@@ -242,6 +246,11 @@ def format_tool_schema_readable(tool_schema):
user_id = info_.pop("user_id", "")
session_id = info_.pop("session_id", "")
+ # Extract manager_user_id and project_id from user_context
+ user_context: UserContext | None = kwargs.get("user_context")
+ manager_user_id = user_context.manager_user_id if user_context else None
+ project_id = user_context.project_id if user_context else None
+
# Split parsed text into chunks
content_chunks = self._split_text(msg_line)
@@ -260,6 +269,8 @@ def format_tool_schema_readable(tool_schema):
tags=["mode:fast"],
sources=[source],
info=info_,
+ manager_user_id=manager_user_id,
+ project_id=project_id,
),
)
memory_items.append(memory_item)
@@ -294,6 +305,11 @@ def parse_fine(
user_id = info_.pop("user_id", "")
session_id = info_.pop("session_id", "")
+ # Extract manager_user_id and project_id from user_context
+ user_context: UserContext | None = kwargs.get("user_context")
+ manager_user_id = user_context.manager_user_id if user_context else None
+ project_id = user_context.project_id if user_context else None
+
# Deduplicate tool schemas based on memory content
# Use hash as key for efficiency, but store original string to handle collisions
seen_memories = {} # hash -> memory_str mapping
@@ -321,6 +337,8 @@ def parse_fine(
status="activated",
embedding=self.embedder.embed([json.dumps(schema, ensure_ascii=False)])[0],
info=info_,
+ manager_user_id=manager_user_id,
+ project_id=project_id,
),
)
for schema in unique_schemas
diff --git a/src/memos/mem_reader/read_multi_modal/text_content_parser.py b/src/memos/mem_reader/read_multi_modal/text_content_parser.py
index 549f74852..9fdcf8c58 100644
--- a/src/memos/mem_reader/read_multi_modal/text_content_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/text_content_parser.py
@@ -4,7 +4,7 @@
Text content parts are typically used in user/assistant messages with multimodal content.
"""
-from typing import Any
+from typing import TYPE_CHECKING, Any
from memos.embedders.base import BaseEmbedder
from memos.llms.base import BaseLLM
@@ -19,6 +19,10 @@
from .base import BaseMessageParser, _add_lang_to_source, _derive_key
+if TYPE_CHECKING:
+ from memos.types.general_types import UserContext
+
+
logger = get_logger(__name__)
@@ -92,6 +96,11 @@ def parse_fast(
user_id = info_.pop("user_id", "")
session_id = info_.pop("session_id", "")
+ # Extract manager_user_id and project_id from user_context
+ user_context: UserContext | None = kwargs.get("user_context")
+ manager_user_id = user_context.manager_user_id if user_context else None
+ project_id = user_context.project_id if user_context else None
+
# For text content parts, default to LongTermMemory
# (since we don't have role information at this level)
memory_type = "LongTermMemory"
@@ -113,6 +122,8 @@ def parse_fast(
confidence=0.99,
type="fact",
info=info_,
+ manager_user_id=manager_user_id,
+ project_id=project_id,
),
)
diff --git a/src/memos/mem_reader/read_multi_modal/tool_parser.py b/src/memos/mem_reader/read_multi_modal/tool_parser.py
index caf5ffaa6..4718f87ba 100644
--- a/src/memos/mem_reader/read_multi_modal/tool_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/tool_parser.py
@@ -2,7 +2,7 @@
import json
-from typing import Any
+from typing import TYPE_CHECKING, Any
from memos.embedders.base import BaseEmbedder
from memos.llms.base import BaseLLM
@@ -18,6 +18,10 @@
from .utils import detect_lang
+if TYPE_CHECKING:
+ from memos.types.general_types import UserContext
+
+
logger = get_logger(__name__)
@@ -179,6 +183,11 @@ def parse_fast(
user_id = info_.pop("user_id", "")
session_id = info_.pop("session_id", "")
+ # Extract manager_user_id and project_id from user_context
+ user_context: UserContext | None = kwargs.get("user_context")
+ manager_user_id = user_context.manager_user_id if user_context else None
+ project_id = user_context.project_id if user_context else None
+
content_chunks = self._split_text(line)
memory_items = []
for _chunk_idx, chunk_text in enumerate(content_chunks):
@@ -195,6 +204,8 @@ def parse_fast(
tags=["mode:fast"],
sources=sources,
info=info_,
+ manager_user_id=manager_user_id,
+ project_id=project_id,
),
)
memory_items.append(memory_item)
diff --git a/src/memos/mem_reader/read_multi_modal/user_parser.py b/src/memos/mem_reader/read_multi_modal/user_parser.py
index 1ab48c82e..2e5ea6eae 100644
--- a/src/memos/mem_reader/read_multi_modal/user_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/user_parser.py
@@ -1,6 +1,6 @@
"""Parser for user messages."""
-from typing import Any
+from typing import TYPE_CHECKING, Any
from memos.embedders.base import BaseEmbedder
from memos.llms.base import BaseLLM
@@ -16,6 +16,10 @@
from .utils import detect_lang
+if TYPE_CHECKING:
+ from memos.types.general_types import UserContext
+
+
logger = get_logger(__name__)
@@ -64,6 +68,10 @@ def create_source(
part_type = part.get("type", "")
if part_type == "text":
text_contents.append(part.get("text", ""))
+ if part_type == "file":
+ file_info = part.get("file", {})
+ file_data = file_info.get("file_data", "")
+ text_contents.append(file_data)
# Detect overall language from all text content
overall_lang = "en"
@@ -183,6 +191,11 @@ def parse_fast(
user_id = info_.pop("user_id", "")
session_id = info_.pop("session_id", "")
+ # Extract manager_user_id and project_id from user_context
+ user_context: UserContext | None = kwargs.get("user_context")
+ manager_user_id = user_context.manager_user_id if user_context else None
+ project_id = user_context.project_id if user_context else None
+
# Create memory item (equivalent to _make_memory_item)
memory_item = TextualMemoryItem(
memory=line,
@@ -200,6 +213,8 @@ def parse_fast(
confidence=0.99,
type="fact",
info=info_,
+ manager_user_id=manager_user_id,
+ project_id=project_id,
),
)
diff --git a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py
index 6bd18808d..fa799e759 100644
--- a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py
+++ b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py
@@ -1,3 +1,4 @@
+import copy
import json
import os
import shutil
@@ -7,7 +8,9 @@
from concurrent.futures import as_completed
from datetime import datetime
from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any
+
+from dotenv import load_dotenv
from memos.context.context import ContextThreadPoolExecutor
from memos.dependency import require_python_package
@@ -19,19 +22,225 @@
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
from memos.templates.skill_mem_prompt import (
+ OTHERS_GENERATION_PROMPT,
+ OTHERS_GENERATION_PROMPT_ZH,
+ SCRIPT_GENERATION_PROMPT,
SKILL_MEMORY_EXTRACTION_PROMPT,
+ SKILL_MEMORY_EXTRACTION_PROMPT_MD,
+ SKILL_MEMORY_EXTRACTION_PROMPT_MD_ZH,
SKILL_MEMORY_EXTRACTION_PROMPT_ZH,
TASK_CHUNKING_PROMPT,
TASK_CHUNKING_PROMPT_ZH,
TASK_QUERY_REWRITE_PROMPT,
TASK_QUERY_REWRITE_PROMPT_ZH,
+ TOOL_GENERATION_PROMPT,
)
from memos.types import MessageList
+load_dotenv()
+
+if TYPE_CHECKING:
+ from memos.types.general_types import UserContext
+
+
logger = get_logger(__name__)
+def _generate_content_by_llm(llm: BaseLLM, prompt_template: str, **kwargs) -> Any:
+ """Generate content using LLM."""
+ try:
+ prompt = prompt_template.format(**kwargs)
+ response = llm.generate([{"role": "user", "content": prompt}])
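+        # Templates that mention "json" are expected to return JSON: strip Markdown fences, then parse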
+ if "json" in prompt_template.lower():
+ response = response.replace("```json", "").replace("```", "").strip()
+ return json.loads(response)
+ return response.strip()
+ except Exception as e:
+ logger.warning(f"[PROCESS_SKILLS] LLM generation failed: {e}")
+ return {} if "json" in prompt_template.lower() else ""
+
+
+def _batch_extract_skills(
+ task_chunks: dict[str, MessageList],
+ related_memories_map: dict[str, list[TextualMemoryItem]],
+ llm: BaseLLM,
+ chat_history: MessageList,
+) -> list[tuple[dict[str, Any], str, MessageList]]:
+ """Phase 1: Batch extract base skill structures from all task chunks in parallel."""
+ results = []
+ with ContextThreadPoolExecutor(max_workers=min(5, len(task_chunks))) as executor:
+ futures = {
+ executor.submit(
+ _extract_skill_memory_by_llm_md,
+ messages=messages,
+ old_memories=related_memories_map.get(task_type, []),
+ llm=llm,
+ chat_history=chat_history,
+ ): task_type
+ for task_type, messages in task_chunks.items()
+ }
+
+ for future in as_completed(futures):
+ task_type = futures[future]
+ try:
+ skill_memory = future.result()
+ if skill_memory:
+ results.append((skill_memory, task_type, task_chunks.get(task_type, [])))
+ except Exception as e:
+ logger.warning(
+ f"[PROCESS_SKILLS] Error extracting skill memory for task '{task_type}': {e}"
+ )
+ return results
+
+
+def _batch_generate_skill_details(
+ raw_skills_data: list[tuple[dict[str, Any], str, MessageList]],
+ related_skill_memories_map: dict[str, list[TextualMemoryItem]],
+ llm: BaseLLM,
+) -> list[dict[str, Any]]:
+ """Phase 2: Batch generate details (scripts, tools, others, examples) for all skills in parallel."""
+ generation_tasks = []
+
+ # Helper to create task objects
+ def create_task(skill_mem, gen_type, prompt, requirements, context, **kwargs):
+ return {
+ "type": gen_type,
+ "skill_memory": skill_mem,
+ "func": _generate_content_by_llm,
+ "args": (llm, prompt),
+ "kwargs": {"requirements": requirements, "context": context, **kwargs},
+ }
+
+ # 1. Collect all generation tasks from all skills
+ for skill_memory, task_type, messages in raw_skills_data:
+ messages_context = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
+
+ # Script
+ script_req = copy.deepcopy(skill_memory.get("scripts"))
+ if script_req:
+ generation_tasks.append(
+ create_task(
+ skill_memory, "scripts", SCRIPT_GENERATION_PROMPT, script_req, messages_context
+ )
+ )
+            # TODO: add a verification loop after code generation to ensure the script meets the requirements
+ else:
+ skill_memory["scripts"] = {}
+
+ # Tool
+ tool_req = skill_memory.get("tool")
+ if tool_req:
+ # Extract available tool schemas from related memories
+ tool_memories = [
+ memory
+ for memory in related_skill_memories_map.get(task_type, [])
+ if memory.metadata.memory_type == "ToolSchemaMemory"
+ ]
+ tool_schemas_list = [memory.memory for memory in tool_memories]
+
+ tool_schemas_str = (
+ "\n\n".join(
+ [
+ f"Tool Schema {i + 1}:\n{schema}"
+ for i, schema in enumerate(tool_schemas_list)
+ ]
+ )
+ if tool_schemas_list
+ else "No specific tool schemas available."
+ )
+
+ generation_tasks.append(
+ create_task(
+ skill_memory,
+ "tool",
+ TOOL_GENERATION_PROMPT,
+ tool_req,
+ messages_context,
+ tool_schemas=tool_schemas_str,
+ )
+ )
+ else:
+ skill_memory["tool"] = {}
+
+ lang = detect_lang(messages_context)
+ others_req = skill_memory.get("others")
+ if others_req and isinstance(others_req, dict):
+ for filename, summary in others_req.items():
+ generation_tasks.append(
+ {
+ "type": "others",
+ "skill_memory": skill_memory,
+ "key": filename,
+ "func": _generate_content_by_llm,
+ "args": (
+ llm,
+ OTHERS_GENERATION_PROMPT_ZH
+ if lang == "zh"
+ else OTHERS_GENERATION_PROMPT,
+ ),
+ "kwargs": {
+ "filename": filename,
+ "summary": summary,
+ "context": messages_context,
+ },
+ }
+ )
+ else:
+ skill_memory["others"] = {}
+
+ if not generation_tasks:
+ return [item[0] for item in raw_skills_data]
+
+ # 2. Execute all tasks in parallel
+ with ContextThreadPoolExecutor(max_workers=min(len(generation_tasks), 5)) as executor:
+ futures = {
+ executor.submit(t["func"], *t["args"], **t["kwargs"]): t for t in generation_tasks
+ }
+
+ for future in as_completed(futures):
+ task_info = futures[future]
+ try:
+ result = future.result()
+ if not result:
+ continue
+
+ skill_mem = task_info["skill_memory"]
+
+ if task_info["type"] == "scripts":
+ if isinstance(result, dict):
+                        # Pair each generated file with this task's own requirements;
+                        # the `script_req` loop variable from the collection phase is stale here
+                        script_req = task_info["kwargs"]["requirements"]
+                        try:
+ skill_mem["scripts"] = {
+ filename: f"# {abstract}:\n{code}"
+ for abstract, (filename, code) in zip(
+ script_req, result.items(), strict=False
+ )
+ }
+ except ValueError:
+ logger.warning(
+ f"[PROCESS_SKILLS] Invalid script generation result: {result}"
+ )
+ skill_mem["scripts"] = {}
+
+ elif task_info["type"] == "tool":
+ skill_mem["tool"] = result
+
+ elif task_info["type"] == "others":
+ if "others" not in skill_mem or not isinstance(skill_mem["others"], dict):
+ skill_mem["others"] = {}
+ skill_mem["others"][task_info["key"]] = (
+ f"# {task_info['kwargs']['summary']}\n{result}"
+ )
+
+ except Exception as e:
+ logger.warning(
+ f"[PROCESS_SKILLS] Error in generation task {task_info['type']}: {e}"
+ )
+
+ return [item[0] for item in raw_skills_data]
+
+
def add_id_to_mysql(memory_id: str, mem_cube_id: str):
"""Add id to mysql, will deprecate this function in the future"""
# TODO: tmp function, deprecate soon
@@ -263,6 +472,135 @@ def _extract_skill_memory_by_llm(
return None
+def _extract_skill_memory_by_llm_md(
+ messages: MessageList,
+ old_memories: list[TextualMemoryItem],
+ llm: BaseLLM,
+ chat_history: MessageList,
+ chat_history_max_length: int = 5000,
+) -> dict[str, Any] | None:
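+    """Extract a skill-memory dict from the conversation via the Markdown-style prompt.
+
+    Retries the LLM call up to 3 times; returns None when nothing usable is extracted.
+    """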
+ old_memories_dict = [memory.model_dump() for memory in old_memories]
+ old_memories_context = {}
+ old_skill_content = []
+ seen_messages = set()
+
+ for mem in old_memories_dict:
+ if mem["metadata"]["memory_type"] == "SkillMemory":
+ old_skill_content.append(
+ {
+ "id": mem["id"],
+ "name": mem["metadata"]["name"],
+ "description": mem["metadata"]["description"],
+ "procedure": mem["metadata"]["procedure"],
+ "experience": mem["metadata"]["experience"],
+ "preference": mem["metadata"]["preference"],
+ "examples": mem["metadata"]["examples"],
+ "others": mem["metadata"].get("others"), # TODO: maybe remove, too long
+ }
+ )
+ else:
+ # Filter and deduplicate messages
+ unique_messages = []
+ for item in mem["metadata"]["sources"]:
+ message_content = f"{item['role']}: {item['content']}"
+ if message_content not in seen_messages:
+ seen_messages.add(message_content)
+ unique_messages.append(message_content)
+
+ if unique_messages:
+ old_memories_context.setdefault(mem["metadata"]["memory_type"], []).extend(
+ unique_messages
+ )
+
+ # Prepare current conversation context
+ messages_context = "\n".join(
+ [f"{message['role']}: {message['content']}" for message in messages]
+ )
+
+ # Prepare history context
+ chat_history_context = "\n".join(
+ [f"{history['role']}: {history['content']}" for history in chat_history]
+ )
+ chat_history_context = chat_history_context[-chat_history_max_length:]
+
+ # Prepare prompt
+ lang = detect_lang(messages_context)
+
+ # Prepare old memories context
+    old_skill_content = (
+        ("已有技能列表: \n" if lang == "zh" else "Exist Skill Schemas: \n")
+        + json.dumps(old_skill_content, ensure_ascii=False, indent=2)
+        if old_skill_content
+        else ""
+    )
+
+    old_memories_context = (
+        ("相关历史对话:\n" if lang == "zh" else "Relevant Context:\n")
+        + "\n".join([f"{k}:\n{v}" for k, v in old_memories_context.items()])
+    )
+
+ template = (
+ SKILL_MEMORY_EXTRACTION_PROMPT_MD_ZH if lang == "zh" else SKILL_MEMORY_EXTRACTION_PROMPT_MD
+ )
+ prompt_content = (
+ template.replace("{old_memories}", old_memories_context + old_skill_content)
+ .replace("{messages}", messages_context)
+ .replace("{chat_history}", chat_history_context)
+ )
+
+ prompt = [{"role": "user", "content": prompt_content}]
+ logger.info(f"[Skill Memory]: _extract_skill_memory_by_llm_md: Prompt {prompt_content}")
+
+ # Call LLM to extract skill memory with retry logic
+ for attempt in range(3):
+ try:
+ # Only pass model_name_or_path if SKILLS_LLM is set
+ skills_llm = os.getenv("SKILLS_LLM", None)
+ llm_kwargs = {"model_name_or_path": skills_llm} if skills_llm else {}
+ response_text = llm.generate(prompt, **llm_kwargs)
+ # Clean up response (remove Markdown code blocks if present)
+ logger.info(f"[Skill Memory]: response_text {response_text}")
+ response_text = response_text.strip()
+ response_text = response_text.replace("```json", "").replace("```", "").strip()
+
+ # Parse JSON response
+ skill_memory = json.loads(response_text)
+
+ # If LLM returns null (parsed as None), log and return None
+ if skill_memory is None:
+ logger.info(
+ "[PROCESS_SKILLS] No skill memory extracted from conversation (LLM returned null)"
+ )
+ return None
+            # If there is no prior skill content, force update to False to guard against LLM hallucination
+ if not old_skill_content:
+ skill_memory["old_memory_id"] = ""
+ skill_memory["update"] = False
+
+ return skill_memory
+
+ except json.JSONDecodeError as e:
+ logger.warning(f"[PROCESS_SKILLS] JSON decode failed (attempt {attempt + 1}): {e}")
+ logger.debug(f"[PROCESS_SKILLS] Response text: {response_text}")
+ if attempt == 2:
+ logger.warning("[PROCESS_SKILLS] Failed to parse skill memory after 3 retries")
+ return None
+ except Exception as e:
+ logger.warning(
+ f"[PROCESS_SKILLS] LLM skill memory extraction failed (attempt {attempt + 1}): {e}"
+ )
+ if attempt == 2:
+ logger.warning(
+ "[PROCESS_SKILLS] LLM skill memory extraction failed after 3 retries"
+ )
+ return None
+
+ return None
+
+
def _recall_related_skill_memories(
task_type: str,
messages: MessageList,
@@ -276,7 +614,7 @@ def _recall_related_skill_memories(
related_skill_memories = searcher.search(
query,
top_k=5,
- memory_type="SkillMemory",
+ memory_type="All",
info=info,
include_skill_memory=True,
user_name=mem_cube_id,
@@ -329,42 +667,88 @@ def _rewrite_query(task_type: str, messages: MessageList, llm: BaseLLM, rewrite_
import_name="alibabacloud_oss_v2",
install_command="pip install alibabacloud-oss-v2",
)
-def _upload_skills_to_oss(local_file_path: str, oss_file_path: str, client: Any) -> str:
- import alibabacloud_oss_v2 as oss
-
- result = client.put_object_from_file(
- request=oss.PutObjectRequest(
- bucket=os.getenv("OSS_BUCKET_NAME"),
- key=oss_file_path,
- ),
- filepath=local_file_path,
- )
+def _upload_skills(
+ skills_repo_backend: str,
+    skills_oss_dir: str | None,
+ local_tmp_file_path: str,
+ local_save_file_path: str,
+ client: Any,
+ user_id: str,
+) -> str:
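+    """Publish the packed skill zip via the configured backend and return its download URL."""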
+ if skills_repo_backend == "OSS":
+ zip_filename = Path(local_tmp_file_path).name
+ oss_path = (Path(skills_oss_dir) / user_id / zip_filename).as_posix()
+
+ import alibabacloud_oss_v2 as oss
+
+ result = client.put_object_from_file(
+ request=oss.PutObjectRequest(
+ bucket=os.getenv("OSS_BUCKET_NAME"),
+ key=oss_path,
+ ),
+ filepath=local_tmp_file_path,
+ )
- if result.status_code != 200:
- logger.warning("[PROCESS_SKILLS] Failed to upload skill to OSS")
- return ""
+ if result.status_code != 200:
+ logger.warning("[PROCESS_SKILLS] Failed to upload skill to OSS")
+ return ""
+
+ # Construct and return the URL
+ bucket_name = os.getenv("OSS_BUCKET_NAME")
+        endpoint = os.getenv("OSS_ENDPOINT", "").replace("https://", "").replace("http://", "")
+ url = f"https://{bucket_name}.{endpoint}/{oss_path}"
+ return url
+ else:
+ import sys
+
+        # Best-effort recovery of the serving port from the CLI args; default to 8000
+        args = sys.argv
+        port = (
+            int(args[args.index("--port") + 1])
+            if "--port" in args and args.index("--port") + 1 < len(args)
+            else 8000
+        )
- # Construct and return the URL
- bucket_name = os.getenv("OSS_BUCKET_NAME")
- endpoint = os.getenv("OSS_ENDPOINT").replace("https://", "").replace("http://", "")
- url = f"https://{bucket_name}.{endpoint}/{oss_file_path}"
- return url
+ zip_path = str(local_tmp_file_path)
+ os.makedirs(local_save_file_path, exist_ok=True)
+ file_name = os.path.basename(zip_path)
+ target_full_path = os.path.join(local_save_file_path, file_name)
+ shutil.copy2(zip_path, target_full_path)
+ return f"http://localhost:{port}/download/{file_name}"
@require_python_package(
import_name="alibabacloud_oss_v2",
install_command="pip install alibabacloud-oss-v2",
)
-def _delete_skills_from_oss(oss_file_path: str, client: Any) -> Any:
- import alibabacloud_oss_v2 as oss
-
- result = client.delete_object(
- oss.DeleteObjectRequest(
- bucket=os.getenv("OSS_BUCKET_NAME"),
- key=oss_file_path,
+def _delete_skills(
+ skills_repo_backend: str,
+ zip_filename: str,
+ client: Any,
+    skills_oss_dir: str | None,
+ local_save_file_path: str,
+ user_id: str,
+) -> Any:
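+    """Remove a previously stored skill zip from OSS or from local storage."""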
+ if skills_repo_backend == "OSS":
+ old_path = (Path(skills_oss_dir) / user_id / zip_filename).as_posix()
+ import alibabacloud_oss_v2 as oss
+
+ return client.delete_object(
+ oss.DeleteObjectRequest(
+ bucket=os.getenv("OSS_BUCKET_NAME"),
+ key=old_path,
+ )
)
- )
- return result
+ else:
+ target_full_path = os.path.join(local_save_file_path, zip_filename)
+ target_path = Path(target_full_path)
+ try:
+ if target_path.is_file():
+ target_path.unlink()
+ logger.info(f"Local file {target_path} successfully deleted")
+ else:
+ logger.info(f"Local file {target_path} does not exist, no need to delete")
+ except Exception as e:
+ logger.warning(f"Error deleting local file: {e}")
def _write_skills_to_file(
@@ -374,7 +758,7 @@ def _write_skills_to_file(
skill_name = skill_memory.get("name", "unnamed_skill").replace(" ", "_").lower()
# Create tmp directory for user if it doesn't exist
- tmp_dir = Path(skills_dir_config["skills_local_dir"]) / user_id
+ tmp_dir = Path(skills_dir_config["skills_local_tmp_dir"]) / user_id
tmp_dir.mkdir(parents=True, exist_ok=True)
# Create skill directory directly in tmp_dir
@@ -388,6 +772,11 @@ def _write_skills_to_file(
---
"""
+ # Add trigger
+ trigger = skill_memory.get("trigger", "")
+ if trigger:
+ skill_md_content += f"\n## Trigger\n{trigger}\n"
+
# Add Procedure section only if present
procedure = skill_memory.get("procedure", "")
if procedure and procedure.strip():
@@ -422,6 +811,10 @@ def _write_skills_to_file(
for script_name in scripts:
skill_md_content += f"- `./scripts/{script_name}`\n"
+ tool_usage = skill_memory.get("tool", "")
+ if tool_usage:
+ skill_md_content += f"\n## Tool Usage\n{tool_usage}\n"
+
# Add others - handle both inline content and separate markdown files
others = skill_memory.get("others")
if others and isinstance(others, dict):
@@ -447,7 +840,7 @@ def _write_skills_to_file(
skill_md_content += "\n## Additional Information\n"
skill_md_content += "\nSee also:\n"
for md_filename in md_files:
- skill_md_content += f"- [{md_filename}](./{md_filename})\n"
+ skill_md_content += f"- [{md_filename}](./reference/{md_filename})\n"
# Write SKILL.md file
skill_md_path = skill_dir / "SKILL.md"
@@ -458,7 +851,9 @@ def _write_skills_to_file(
if others and isinstance(others, dict):
for key, value in others.items():
if key.endswith(".md"):
- md_file_path = skill_dir / key
+ md_file_dir = skill_dir / "reference"
+ md_file_dir.mkdir(parents=True, exist_ok=True)
+ md_file_path = md_file_dir / key
with open(md_file_path, "w", encoding="utf-8") as f:
f.write(value)
@@ -494,12 +889,20 @@ def _write_skills_to_file(
def create_skill_memory_item(
- skill_memory: dict[str, Any], info: dict[str, Any], embedder: BaseEmbedder | None = None
+ skill_memory: dict[str, Any],
+ info: dict[str, Any],
+ embedder: BaseEmbedder | None = None,
+ **kwargs: Any,
) -> TextualMemoryItem:
info_ = info.copy()
user_id = info_.pop("user_id", "")
session_id = info_.pop("session_id", "")
+ # Extract manager_user_id and project_id from user_context
+ user_context: UserContext | None = kwargs.get("user_context")
+ manager_user_id = user_context.manager_user_id if user_context else None
+ project_id = user_context.project_id if user_context else None
+
# Use description as the memory content
memory_content = skill_memory.get("description", "")
@@ -509,7 +912,7 @@ def create_skill_memory_item(
session_id=session_id,
memory_type="SkillMemory",
status="activated",
- tags=skill_memory.get("tags", []),
+        tags=skill_memory.get("tags") or skill_memory.get("trigger") or [],
key=skill_memory.get("name", ""),
sources=[],
usage=[],
@@ -530,6 +933,8 @@ def create_skill_memory_item(
scripts=skill_memory.get("scripts"),
others=skill_memory.get("others"),
url=skill_memory.get("url", ""),
+ manager_user_id=manager_user_id,
+ project_id=project_id,
)
# If this is an update, use the old memory ID
@@ -544,6 +949,52 @@ def create_skill_memory_item(
return TextualMemoryItem(id=item_id, memory=memory_content, metadata=metadata)
+def _skill_init(skills_repo_backend, oss_config, skills_dir_config):
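+    """Validate backend configuration and create an OSS client when the backend is OSS.
+
+    Returns (oss_client, missing_keys, ok); ok is False when configuration is incomplete.
+    """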
+ if skills_repo_backend == "OSS":
+ # Validate required configurations
+ if not oss_config:
+ logger.warning(
+ "[PROCESS_SKILLS] OSS configuration is required for skill memory processing"
+ )
+ return None, None, False
+
+ if not skills_dir_config:
+ logger.warning(
+ "[PROCESS_SKILLS] Skills directory configuration is required for skill memory processing"
+ )
+ return None, None, False
+
+ # Validate skills_dir has required keys
+ required_keys = ["skills_local_tmp_dir", "skills_local_dir", "skills_oss_dir"]
+ missing_keys = [key for key in required_keys if key not in skills_dir_config]
+ if missing_keys:
+ logger.warning(
+ f"[PROCESS_SKILLS] Skills directory configuration missing required keys: {', '.join(missing_keys)}"
+ )
+ return None, None, False
+
+ oss_client = create_oss_client(oss_config)
+ if not oss_client:
+ logger.warning("[PROCESS_SKILLS] Failed to create OSS client")
+ return None, None, False
+ return oss_client, missing_keys, True
+ else:
+ return None, None, True
+
+
+def _get_skill_file_storage_location() -> str:
+    # SKILLS_REPO_BACKEND selects where skill files are stored: "OSS" or "LOCAL"
+ allowed_backends = {"OSS", "LOCAL"}
+ raw_backend = os.getenv("SKILLS_REPO_BACKEND")
+ if raw_backend in allowed_backends:
+ return raw_backend
+ else:
+ logger.warning(
+ "Environment variable [SKILLS_REPO_BACKEND] is invalid, using LOCAL to store skill",
+ )
+ return "LOCAL"
+
+
def process_skill_memory_fine(
fast_memory_items: list[TextualMemoryItem],
info: dict[str, Any],
@@ -554,17 +1005,12 @@ def process_skill_memory_fine(
rewrite_query: bool = True,
oss_config: dict[str, Any] | None = None,
skills_dir_config: dict[str, Any] | None = None,
+ complete_skill_memory: bool = True,
**kwargs,
) -> list[TextualMemoryItem]:
- # Validate required configurations
- if not oss_config:
- logger.warning("[PROCESS_SKILLS] OSS configuration is required for skill memory processing")
- return []
-
- if not skills_dir_config:
- logger.warning(
- "[PROCESS_SKILLS] Skills directory configuration is required for skill memory processing"
- )
+ skills_repo_backend = _get_skill_file_storage_location()
+ oss_client, missing_keys, flag = _skill_init(skills_repo_backend, oss_config, skills_dir_config)
+ if not flag:
return []
chat_history = kwargs.get("chat_history")
@@ -572,20 +1018,6 @@ def process_skill_memory_fine(
chat_history = []
logger.warning("[PROCESS_SKILLS] History is None in Skills")
- # Validate skills_dir has required keys
- required_keys = ["skills_local_dir", "skills_oss_dir"]
- missing_keys = [key for key in required_keys if key not in skills_dir_config]
- if missing_keys:
- logger.warning(
- f"[PROCESS_SKILLS] Skills directory configuration missing required keys: {', '.join(missing_keys)}"
- )
- return []
-
- oss_client = create_oss_client(oss_config)
- if not oss_client:
- logger.warning("[PROCESS_SKILLS] Failed to create OSS client")
- return []
-
messages = _reconstruct_messages_from_memory_items(fast_memory_items)
chat_history, messages = _preprocess_extract_messages(chat_history, messages)
@@ -627,26 +1059,56 @@ def process_skill_memory_fine(
)
related_skill_memories_by_task[task_name] = []
- skill_memories = []
- with ContextThreadPoolExecutor(max_workers=5) as executor:
- futures = {
- executor.submit(
- _extract_skill_memory_by_llm,
- messages,
- related_skill_memories_by_task.get(task_type, []),
- llm,
- chat_history,
- ): task_type
- for task_type, messages in task_chunks.items()
- }
- for future in as_completed(futures):
- try:
- skill_memory = future.result()
- if skill_memory: # Only add non-None results
- skill_memories.append(skill_memory)
- except Exception as e:
- logger.warning(f"[PROCESS_SKILLS] Error extracting skill memory: {e}")
- continue
+ def _simple_extract():
+        # Single-stage extraction: one LLM call per task chunk
+ memories = []
+ with ContextThreadPoolExecutor(max_workers=min(5, len(task_chunks))) as executor:
+ futures = {
+ executor.submit(
+ _extract_skill_memory_by_llm,
+ messages=chunk_messages,
+ # Filter only SkillMemory types
+ old_memories=[
+ item
+ for item in related_skill_memories_by_task.get(task_type, [])
+ if item and getattr(item.metadata, "memory_type", "") == "SkillMemory"
+ ],
+ llm=llm,
+ chat_history=chat_history,
+ ): task_type
+ for task_type, chunk_messages in task_chunks.items()
+ }
+
+ for future in as_completed(futures):
+ task_type = futures[future]
+ try:
+ skill_memory = future.result()
+ if skill_memory:
+ memories.append(skill_memory)
+ except Exception as e:
+ logger.warning(
+ f"[PROCESS_SKILLS] _simple_extract: Error processing task '{task_type}': {e}"
+ )
+ return memories
+
+ def _full_extract():
+        # Two-stage extraction: extract base skill structures, then generate their details
+ raw_extraction_results = _batch_extract_skills(
+ task_chunks=task_chunks,
+ related_memories_map=related_skill_memories_by_task,
+ llm=llm,
+ chat_history=chat_history,
+ )
+ if not raw_extraction_results:
+ return []
+ return _batch_generate_skill_details(
+ raw_skills_data=raw_extraction_results,
+ related_skill_memories_map=related_skill_memories_by_task,
+ llm=llm,
+ )
+
+    # Choose the extraction strategy based on the complete_skill_memory flag
+    skill_memories = _full_extract() if complete_skill_memory else _simple_extract()
# write skills to file and get zip paths
skill_memory_with_paths = []
@@ -684,23 +1146,27 @@ def process_skill_memory_fine(
old_memory = old_memories_map.get(old_memory_id)
if old_memory:
- # Get old OSS path from the old memory's metadata
- old_oss_path = getattr(old_memory.metadata, "url", None)
+ # Get old path from the old memory's metadata
+ old_path = getattr(old_memory.metadata, "url", None)
- if old_oss_path:
+ if old_path:
try:
+                        # Delete the old skill from the configured backend
- zip_filename = Path(old_oss_path).name
- old_oss_path = (
- Path(skills_dir_config["skills_oss_dir"]) / user_id / zip_filename
- ).as_posix()
- _delete_skills_from_oss(old_oss_path, oss_client)
+ zip_filename = Path(old_path).name
+ _delete_skills(
+ skills_repo_backend=skills_repo_backend,
+ zip_filename=zip_filename,
+ client=oss_client,
+ skills_oss_dir=skills_dir_config["skills_oss_dir"],
+ local_save_file_path=skills_dir_config["skills_local_dir"],
+ user_id=user_id,
+ )
logger.info(
- f"[PROCESS_SKILLS] Deleted old skill from OSS: {old_oss_path}"
+ f"[PROCESS_SKILLS] Deleted old skill from {skills_repo_backend}: {old_path}"
)
except Exception as e:
logger.warning(
- f"[PROCESS_SKILLS] Failed to delete old skill from OSS: {e}"
+ f"[PROCESS_SKILLS] Failed to delete old skill from {skills_repo_backend}: {e}"
)
# delete old skill from graph db
@@ -710,24 +1176,23 @@ def process_skill_memory_fine(
f"[PROCESS_SKILLS] Deleted old skill from graph db: {old_memory_id}"
)
- # Upload new skill to OSS
+ # Upload new skill
# Use the same filename as the local zip file
- zip_filename = Path(zip_path).name
- oss_path = (
- Path(skills_dir_config["skills_oss_dir"]) / user_id / zip_filename
- ).as_posix()
-
- # _upload_skills_to_oss returns the URL
- url = _upload_skills_to_oss(
- local_file_path=str(zip_path), oss_file_path=oss_path, client=oss_client
+ url = _upload_skills(
+ skills_repo_backend=skills_repo_backend,
+ skills_oss_dir=skills_dir_config["skills_oss_dir"],
+ local_tmp_file_path=zip_path,
+ local_save_file_path=skills_dir_config["skills_local_dir"],
+ client=oss_client,
+ user_id=user_id,
)
# Set URL directly to skill_memory
skill_memory["url"] = url
- logger.info(f"[PROCESS_SKILLS] Uploaded skill to OSS: {url}")
+ logger.info(f"[PROCESS_SKILLS] Uploaded skill to {skills_repo_backend}: {url}")
except Exception as e:
- logger.warning(f"[PROCESS_SKILLS] Error uploading skill to OSS: {e}")
+ logger.warning(f"[PROCESS_SKILLS] Error uploading skill to {skills_repo_backend}: {e}")
skill_memory["url"] = "" # Set to empty string if upload fails
finally:
# Clean up local files after upload
@@ -748,7 +1213,7 @@ def process_skill_memory_fine(
skill_memory_items = []
for skill_memory in skill_memories:
try:
- memory_item = create_skill_memory_item(skill_memory, info, embedder)
+ memory_item = create_skill_memory_item(skill_memory, info, embedder, **kwargs)
skill_memory_items.append(memory_item)
except Exception as e:
logger.warning(f"[PROCESS_SKILLS] Error creating skill memory item: {e}")
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index 2c4fee853..ceaf28bfa 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -176,6 +176,7 @@ def __init__(self, config: SimpleStructMemReaderConfig):
self.llm = LLMFactory.from_config(config.llm)
self.embedder = EmbedderFactory.from_config(config.embedder)
self.chunker = ChunkerFactory.from_config(config.chunker)
+ self.save_rawfile = self.chunker.config.save_rawfile
self.memory_max_length = 8000
# Use token-based windowing; default to ~5000 tokens if not configured
self.chat_window_max_tokens = getattr(self.config, "chat_window_max_tokens", 1024)
diff --git a/src/memos/mem_scheduler/base_mixins/__init__.py b/src/memos/mem_scheduler/base_mixins/__init__.py
new file mode 100644
index 000000000..7e01cffc0
--- /dev/null
+++ b/src/memos/mem_scheduler/base_mixins/__init__.py
@@ -0,0 +1,10 @@
+from .memory_ops import BaseSchedulerMemoryMixin
+from .queue_ops import BaseSchedulerQueueMixin
+from .web_log_ops import BaseSchedulerWebLogMixin
+
+
+__all__ = [
+ "BaseSchedulerMemoryMixin",
+ "BaseSchedulerQueueMixin",
+ "BaseSchedulerWebLogMixin",
+]
diff --git a/src/memos/mem_scheduler/base_mixins/memory_ops.py b/src/memos/mem_scheduler/base_mixins/memory_ops.py
new file mode 100644
index 000000000..35e095422
--- /dev/null
+++ b/src/memos/mem_scheduler/base_mixins/memory_ops.py
@@ -0,0 +1,227 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
+from memos.mem_scheduler.utils.filter_utils import transform_name_to_key
+from memos.memories.textual.naive import NaiveTextMemory
+from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
+
+
+if TYPE_CHECKING:
+ from memos.types.general_types import MemCubeID, UserID
+
+
+logger = get_logger(__name__)
+
+
+class BaseSchedulerMemoryMixin:
+ def transform_working_memories_to_monitors(
+ self, query_keywords, memories: list[TextualMemoryItem]
+ ) -> list[MemoryMonitorItem]:
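+        """Wrap memories as MemoryMonitorItem objects, scored by keyword hits and original order."""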
+ result = []
+ mem_length = len(memories)
+ for idx, mem in enumerate(memories):
+ text_mem = mem.memory
+ mem_key = transform_name_to_key(name=text_mem)
+
+ keywords_score = 0
+ if query_keywords and text_mem:
+ for keyword, count in query_keywords.items():
+ keyword_count = text_mem.count(keyword)
+ if keyword_count > 0:
+ keywords_score += keyword_count * count
+ logger.debug(
+ "Matched keyword '%s' %s times, added %s to keywords_score",
+ keyword,
+ keyword_count,
+ keywords_score,
+ )
+
+ sorting_score = mem_length - idx
+
+ mem_monitor = MemoryMonitorItem(
+ memory_text=text_mem,
+ tree_memory_item=mem,
+ tree_memory_item_mapping_key=mem_key,
+ sorting_score=sorting_score,
+ keywords_score=keywords_score,
+ recording_count=1,
+ )
+ result.append(mem_monitor)
+
+ logger.info("Transformed %s memories to monitors", len(result))
+ return result
+
+ def replace_working_memory(
+ self,
+ user_id: UserID | str,
+ mem_cube_id: MemCubeID | str,
+ mem_cube,
+ original_memory: list[TextualMemoryItem],
+ new_memory: list[TextualMemoryItem],
+ ) -> None | list[TextualMemoryItem]:
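+        """Rerank and filter candidate memories against the query history and update
+        working memory. Returns the memories in their new order."""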
+ text_mem_base = mem_cube.text_mem
+ if isinstance(text_mem_base, TreeTextMemory):
+ query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id]
+ query_db_manager.sync_with_orm()
+
+ query_history = query_db_manager.obj.get_queries_with_timesort()
+
+ original_count = len(original_memory)
+ filtered_original_memory = []
+ for origin_mem in original_memory:
+ if "mode:fast" not in origin_mem.metadata.tags:
+ filtered_original_memory.append(origin_mem)
+ else:
+ logger.debug(
+ "Filtered out memory - ID: %s, Tags: %s",
+ getattr(origin_mem, "id", "unknown"),
+ origin_mem.metadata.tags,
+ )
+ filtered_count = original_count - len(filtered_original_memory)
+ remaining_count = len(filtered_original_memory)
+
+ logger.info(
+ "Filtering complete. Removed %s memories with tag 'mode:fast'. Remaining memories: %s",
+ filtered_count,
+ remaining_count,
+ )
+ original_memory = filtered_original_memory
+
+ memories_with_new_order, rerank_success_flag = (
+ self.retriever.process_and_rerank_memories(
+ queries=query_history,
+ original_memory=original_memory,
+ new_memory=new_memory,
+ top_k=self.top_k,
+ )
+ )
+
+ logger.info("Filtering memories based on query history: %s queries", len(query_history))
+ filtered_memories, filter_success_flag = self.retriever.filter_unrelated_memories(
+ query_history=query_history,
+ memories=memories_with_new_order,
+ )
+
+ if filter_success_flag:
+ logger.info(
+ "Memory filtering completed successfully. Filtered from %s to %s memories",
+ len(memories_with_new_order),
+ len(filtered_memories),
+ )
+ memories_with_new_order = filtered_memories
+ else:
+ logger.warning(
+ "Memory filtering failed - keeping all memories as fallback. Original count: %s",
+ len(memories_with_new_order),
+ )
+
+ query_keywords = query_db_manager.obj.get_keywords_collections()
+ logger.info(
+ "Processing %s memories with %s query keywords",
+ len(memories_with_new_order),
+ len(query_keywords),
+ )
+ new_working_memory_monitors = self.transform_working_memories_to_monitors(
+ query_keywords=query_keywords,
+ memories=memories_with_new_order,
+ )
+
+ if not rerank_success_flag:
+ for one in new_working_memory_monitors:
+ one.sorting_score = 0
+
+ logger.info("update %s working_memory_monitors", len(new_working_memory_monitors))
+ self.monitor.update_working_memory_monitors(
+ new_working_memory_monitors=new_working_memory_monitors,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ )
+
+ mem_monitors: list[MemoryMonitorItem] = self.monitor.working_memory_monitors[user_id][
+ mem_cube_id
+ ].obj.get_sorted_mem_monitors(reverse=True)
+ new_working_memories = [mem_monitor.tree_memory_item for mem_monitor in mem_monitors]
+
+ text_mem_base.replace_working_memory(memories=new_working_memories)
+
+ logger.info(
+ "The working memory has been replaced with %s new memories.",
+ len(memories_with_new_order),
+ )
+ self.log_working_memory_replacement(
+ original_memory=original_memory,
+ new_memory=new_working_memories,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ log_func_callback=self._submit_web_logs,
+ )
+ elif isinstance(text_mem_base, NaiveTextMemory):
+ logger.info(
+ "NaiveTextMemory: Updating working memory monitors with %s candidates.",
+ len(new_memory),
+ )
+
+ query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id]
+ query_db_manager.sync_with_orm()
+ query_keywords = query_db_manager.obj.get_keywords_collections()
+
+ new_working_memory_monitors = self.transform_working_memories_to_monitors(
+ query_keywords=query_keywords,
+ memories=new_memory,
+ )
+
+ self.monitor.update_working_memory_monitors(
+ new_working_memory_monitors=new_working_memory_monitors,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ )
+ memories_with_new_order = new_memory
+ else:
+ logger.error("memory_base is not supported")
+ memories_with_new_order = new_memory
+
+ return memories_with_new_order
+
+ def update_activation_memory(
+ self,
+ new_memories: list[str | TextualMemoryItem],
+ label: str,
+ user_id: UserID | str,
+ mem_cube_id: MemCubeID | str,
+ mem_cube,
+ ) -> None:
+ if hasattr(self, "activation_memory_manager") and self.activation_memory_manager:
+ self.activation_memory_manager.update_activation_memory(
+ new_memories=new_memories,
+ label=label,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ )
+ else:
+ logger.warning("Activation memory manager not initialized")
+
+ def update_activation_memory_periodically(
+ self,
+ interval_seconds: int,
+ label: str,
+ user_id: UserID | str,
+ mem_cube_id: MemCubeID | str,
+ mem_cube,
+ ):
+ if hasattr(self, "activation_memory_manager") and self.activation_memory_manager:
+ self.activation_memory_manager.update_activation_memory_periodically(
+ interval_seconds=interval_seconds,
+ label=label,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ )
+ else:
+ logger.warning("Activation memory manager not initialized")
diff --git a/src/memos/mem_scheduler/base_mixins/queue_ops.py b/src/memos/mem_scheduler/base_mixins/queue_ops.py
new file mode 100644
index 000000000..590189c24
--- /dev/null
+++ b/src/memos/mem_scheduler/base_mixins/queue_ops.py
@@ -0,0 +1,425 @@
+from __future__ import annotations
+
+import multiprocessing
+import time
+
+from contextlib import suppress
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING
+
+from memos.context.context import (
+ ContextThread,
+ RequestContext,
+ get_current_context,
+ get_current_trace_id,
+ set_request_context,
+)
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.general_schemas import STARTUP_BY_PROCESS
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import TaskPriorityLevel
+from memos.mem_scheduler.utils.db_utils import get_utc_now
+from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
+from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso
+
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
+
+class BaseSchedulerQueueMixin:
+ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]):
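+        """Route messages by priority: LEVEL_1 tasks execute immediately, the rest are queued."""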
+ if isinstance(messages, ScheduleMessageItem):
+ messages = [messages]
+
+ if not messages:
+ return
+
+ current_trace_id = get_current_trace_id()
+
+ immediate_msgs: list[ScheduleMessageItem] = []
+ queued_msgs: list[ScheduleMessageItem] = []
+
+ for msg in messages:
+ if current_trace_id:
+ msg.trace_id = current_trace_id
+
+ with suppress(Exception):
+ self.metrics.task_enqueued(user_id=msg.user_id, task_type=msg.label)
+
+ if getattr(msg, "timestamp", None) is None:
+ msg.timestamp = get_utc_now()
+
+ if self.status_tracker:
+ try:
+ self.status_tracker.task_submitted(
+ task_id=msg.item_id,
+ user_id=msg.user_id,
+ task_type=msg.label,
+ mem_cube_id=msg.mem_cube_id,
+ business_task_id=msg.task_id,
+ )
+ except Exception:
+ logger.warning("status_tracker.task_submitted failed", exc_info=True)
+
+ if self.disabled_handlers and msg.label in self.disabled_handlers:
+ logger.info("Skipping disabled handler: %s - %s", msg.label, msg.content)
+ continue
+
+ task_priority = self.orchestrator.get_task_priority(task_label=msg.label)
+ if task_priority == TaskPriorityLevel.LEVEL_1:
+ immediate_msgs.append(msg)
+ else:
+ queued_msgs.append(msg)
+
+ if immediate_msgs:
+ for m in immediate_msgs:
+ emit_monitor_event(
+ "enqueue",
+ m,
+ {
+ "enqueue_ts": to_iso(getattr(m, "timestamp", None)),
+ "event_duration_ms": 0,
+ "total_duration_ms": 0,
+ },
+ )
+
+ for m in immediate_msgs:
+ try:
+ now = time.time()
+ enqueue_ts_obj = getattr(m, "timestamp", None)
+ enqueue_epoch = None
+ if isinstance(enqueue_ts_obj, int | float):
+ enqueue_epoch = float(enqueue_ts_obj)
+ elif hasattr(enqueue_ts_obj, "timestamp"):
+ dt = enqueue_ts_obj
+ if dt.tzinfo is None:
+ dt = dt.replace(tzinfo=timezone.utc)
+ enqueue_epoch = dt.timestamp()
+
+ queue_wait_ms = None
+ if enqueue_epoch is not None:
+ queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000
+
+ object.__setattr__(m, "_dequeue_ts", now)
+ emit_monitor_event(
+ "dequeue",
+ m,
+ {
+ "enqueue_ts": to_iso(enqueue_ts_obj),
+ "dequeue_ts": datetime.fromtimestamp(now, tz=timezone.utc).isoformat(),
+ "queue_wait_ms": queue_wait_ms,
+ "event_duration_ms": queue_wait_ms,
+ "total_duration_ms": queue_wait_ms,
+ },
+ )
+ self.metrics.task_dequeued(user_id=m.user_id, task_type=m.label)
+ except Exception:
+ logger.debug("Failed to emit dequeue for immediate task", exc_info=True)
+
+ user_cube_groups = group_messages_by_user_and_mem_cube(immediate_msgs)
+ for user_id, cube_groups in user_cube_groups.items():
+ for mem_cube_id, user_cube_msgs in cube_groups.items():
+ label_groups: dict[str, list[ScheduleMessageItem]] = {}
+ for m in user_cube_msgs:
+ label_groups.setdefault(m.label, []).append(m)
+
+ for label, msgs_by_label in label_groups.items():
+ handler = self.dispatcher.handlers.get(
+ label, self.dispatcher._default_message_handler
+ )
+ self.dispatcher.execute_task(
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ task_label=label,
+ msgs=msgs_by_label,
+ handler_call_back=handler,
+ )
+
+ if queued_msgs:
+ self.memos_message_queue.submit_messages(messages=queued_msgs)
+
+ def _message_consumer(self) -> None:
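+        """Continuously drain the message queue in batches and hand them to the dispatcher."""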
+ while self._running:
+ try:
+ if self.enable_parallel_dispatch and self.dispatcher:
+ running_tasks = self.dispatcher.get_running_task_count()
+ if running_tasks >= self.dispatcher.max_workers:
+ time.sleep(self._consume_interval)
+ continue
+
+ messages = self.memos_message_queue.get_messages(batch_size=self.consume_batch)
+
+ if messages:
+ now = time.time()
+ for msg in messages:
+ prev_context = get_current_context()
+ try:
+ msg_context = RequestContext(
+ trace_id=msg.trace_id,
+ user_name=msg.user_name,
+ )
+ set_request_context(msg_context)
+
+ enqueue_ts_obj = getattr(msg, "timestamp", None)
+ enqueue_epoch = None
+ if isinstance(enqueue_ts_obj, int | float):
+ enqueue_epoch = float(enqueue_ts_obj)
+ elif hasattr(enqueue_ts_obj, "timestamp"):
+ dt = enqueue_ts_obj
+ if dt.tzinfo is None:
+ dt = dt.replace(tzinfo=timezone.utc)
+ enqueue_epoch = dt.timestamp()
+
+ queue_wait_ms = None
+ if enqueue_epoch is not None:
+ queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000
+
+ object.__setattr__(msg, "_dequeue_ts", now)
+ emit_monitor_event(
+ "dequeue",
+ msg,
+ {
+ "enqueue_ts": to_iso(enqueue_ts_obj),
+ "dequeue_ts": datetime.fromtimestamp(
+ now, tz=timezone.utc
+ ).isoformat(),
+ "queue_wait_ms": queue_wait_ms,
+ "event_duration_ms": queue_wait_ms,
+ "total_duration_ms": queue_wait_ms,
+ },
+ )
+ self.metrics.task_dequeued(user_id=msg.user_id, task_type=msg.label)
+ finally:
+ set_request_context(prev_context)
+ try:
+ with suppress(Exception):
+ if messages:
+ self.dispatcher.on_messages_enqueued(messages)
+
+ self.dispatcher.dispatch(messages)
+ except Exception as e:
+ logger.error("Error dispatching messages: %s", e)
+
+ time.sleep(self._consume_interval)
+
+ except Exception as e:
+ if "No messages available in Redis queue" not in str(e):
+ logger.error("Unexpected error in message consumer: %s", e, exc_info=True)
+ time.sleep(self._consume_interval)
+
+ def _monitor_loop(self):
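+        """Periodically sample queue sizes and report per-user queue lengths to metrics."""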
+ while self._running:
+ try:
+ q_sizes = self.memos_message_queue.qsize()
+
+                if not isinstance(q_sizes, dict):
+                    time.sleep(15)
+                    continue
+
+ for stream_key, queue_length in q_sizes.items():
+ if stream_key == "total_size":
+ continue
+
+ parts = stream_key.split(":")
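+                    # The user id is taken as the third-from-last ":"-separated segment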
+ if len(parts) >= 3:
+ user_id = parts[-3]
+ self.metrics.update_queue_length(queue_length, user_id)
+ else:
+ if ":" not in stream_key:
+ self.metrics.update_queue_length(queue_length, stream_key)
+
+ except Exception as e:
+ logger.error("Error in metrics monitor loop: %s", e, exc_info=True)
+
+ time.sleep(15)
+
+ def start(self) -> None:
+ if self.enable_parallel_dispatch:
+ logger.info(
+ "Initializing dispatcher thread pool with %s workers",
+ self.thread_pool_max_workers,
+ )
+
+ self.start_consumer()
+ self.start_background_monitor()
+
+ def start_background_monitor(self):
+ if self._monitor_thread and self._monitor_thread.is_alive():
+ return
+ self._monitor_thread = ContextThread(
+ target=self._monitor_loop, daemon=True, name="SchedulerMetricsMonitor"
+ )
+ self._monitor_thread.start()
+ logger.info("Scheduler metrics monitor thread started.")
+
+ def start_consumer(self) -> None:
+ if self._running:
+ logger.warning("Memory Scheduler consumer is already running")
+ return
+
+ self._running = True
+
+ if self.scheduler_startup_mode == STARTUP_BY_PROCESS:
+ self._consumer_process = multiprocessing.Process(
+ target=self._message_consumer,
+ daemon=True,
+ name="MessageConsumerProcess",
+ )
+ self._consumer_process.start()
+ logger.info("Message consumer process started")
+ else:
+ self._consumer_thread = ContextThread(
+ target=self._message_consumer,
+ daemon=True,
+ name="MessageConsumerThread",
+ )
+ self._consumer_thread.start()
+ logger.info("Message consumer thread started")
+
+ def stop_consumer(self) -> None:
+ if not self._running:
+ logger.warning("Memory Scheduler consumer is not running")
+ return
+
+ self._running = False
+
+ if self.scheduler_startup_mode == STARTUP_BY_PROCESS and self._consumer_process:
+ if self._consumer_process.is_alive():
+ self._consumer_process.join(timeout=5.0)
+ if self._consumer_process.is_alive():
+ logger.warning("Consumer process did not stop gracefully, terminating...")
+ self._consumer_process.terminate()
+ self._consumer_process.join(timeout=2.0)
+ if self._consumer_process.is_alive():
+ logger.error("Consumer process could not be terminated")
+ else:
+ logger.info("Consumer process terminated")
+ else:
+ logger.info("Consumer process stopped")
+ self._consumer_process = None
+ elif self._consumer_thread and self._consumer_thread.is_alive():
+ self._consumer_thread.join(timeout=5.0)
+ if self._consumer_thread.is_alive():
+ logger.warning("Consumer thread did not stop gracefully")
+ else:
+ logger.info("Consumer thread stopped")
+ self._consumer_thread = None
+
+ logger.info("Memory Scheduler consumer stopped")
+
+ def stop(self) -> None:
+ if not self._running:
+ logger.warning("Memory Scheduler is not running")
+ return
+
+ self.stop_consumer()
+
+ if self._monitor_thread:
+ self._monitor_thread.join(timeout=2.0)
+
+ if self.dispatcher:
+ logger.info("Shutting down dispatcher...")
+ self.dispatcher.shutdown()
+
+ if self.dispatcher_monitor:
+ logger.info("Shutting down monitor...")
+ self.dispatcher_monitor.stop()
+
+ @property
+ def handlers(self) -> dict[str, Callable]:
+ if not self.dispatcher:
+ logger.warning("Dispatcher is not initialized, returning empty handlers dict")
+ return {}
+
+ return self.dispatcher.handlers
+
+ def register_handlers(
+ self,
+ handlers: dict[
+ str,
+ Callable[[list[ScheduleMessageItem]], None]
+ | tuple[
+ Callable[[list[ScheduleMessageItem]], None], TaskPriorityLevel | None, int | None
+ ],
+ ],
+ ) -> None:
+ if not self.dispatcher:
+ logger.warning("Dispatcher is not initialized, cannot register handlers")
+ return
+
+ self.dispatcher.register_handlers(handlers)
+
+ def unregister_handlers(self, labels: list[str]) -> dict[str, bool]:
+ if not self.dispatcher:
+ logger.warning("Dispatcher is not initialized, cannot unregister handlers")
+ return dict.fromkeys(labels, False)
+
+ return self.dispatcher.unregister_handlers(labels)
+
+ def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, dict]:
+ if not self.dispatcher:
+ logger.warning("Dispatcher is not initialized, returning empty tasks dict")
+ return {}
+
+ running_tasks = self.dispatcher.get_running_tasks(filter_func=filter_func)
+
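+ # Convert RunningTaskItem objects to dictionaries for easier consumption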
+ result = {}
+ for task_id, task_item in running_tasks.items():
+ result[task_id] = {
+ "item_id": task_item.item_id,
+ "user_id": task_item.user_id,
+ "mem_cube_id": task_item.mem_cube_id,
+ "task_info": task_item.task_info,
+ "task_name": task_item.task_name,
+ "start_time": task_item.start_time,
+ "end_time": task_item.end_time,
+ "status": task_item.status,
+ "result": task_item.result,
+ "error_message": task_item.error_message,
+ "messages": task_item.messages,
+ }
+
+ return result
+
+ def get_tasks_status(self):
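+ """Delegate status collection to TaskScheduleMonitor."""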
+ return self.task_schedule_monitor.get_tasks_status()
+
+ def print_tasks_status(self, tasks_status: dict | None = None) -> None:
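+ """Delegate pretty printing to TaskScheduleMonitor."""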
+ self.task_schedule_monitor.print_tasks_status(tasks_status=tasks_status)
+
+ def _gather_queue_stats(self) -> dict:
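+ """Collect queue/dispatcher stats for reporting."""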
+ memos_message_queue = self.memos_message_queue.memos_message_queue
+ stats: dict[str, int | float | str] = {}
+ stats["use_redis_queue"] = bool(self.use_redis_queue)
+ if not self.use_redis_queue:
+ try:
+ stats["qsize"] = int(memos_message_queue.qsize())
+ except Exception:
+ stats["qsize"] = -1
+ try:
+ stats["unfinished_tasks"] = int(
+ getattr(memos_message_queue, "unfinished_tasks", 0) or 0
+ )
+ except Exception:
+ stats["unfinished_tasks"] = -1
+ stats["maxsize"] = int(self.max_internal_message_queue_size)
+ try:
+ maxsize = int(self.max_internal_message_queue_size) or 1
+ qsize = int(stats.get("qsize", 0))
+ stats["utilization"] = min(1.0, max(0.0, qsize / maxsize))
+ except Exception:
+ stats["utilization"] = 0.0
+ try:
+ d_stats = self.dispatcher.stats()
+ stats.update(
+ {
+ "running": int(d_stats.get("running", 0)),
+ "inflight": int(d_stats.get("inflight", 0)),
+ "handlers": int(d_stats.get("handlers", 0)),
+ }
+ )
+ except Exception:
+ stats.update({"running": 0, "inflight": 0, "handlers": 0})
+ return stats
diff --git a/src/memos/mem_scheduler/base_mixins/web_log_ops.py b/src/memos/mem_scheduler/base_mixins/web_log_ops.py
new file mode 100644
index 000000000..64b5348d3
--- /dev/null
+++ b/src/memos/mem_scheduler/base_mixins/web_log_ops.py
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem
+from memos.mem_scheduler.schemas.task_schemas import (
+ ADD_TASK_LABEL,
+ ANSWER_TASK_LABEL,
+ MEM_ARCHIVE_TASK_LABEL,
+ MEM_ORGANIZE_TASK_LABEL,
+ MEM_UPDATE_TASK_LABEL,
+ QUERY_TASK_LABEL,
+)
+
+
+logger = get_logger(__name__)
+
+
+class BaseSchedulerWebLogMixin:
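+ """Mixin providing web-log submission and normalization for BaseScheduler."""
+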
+ def _submit_web_logs(
+ self,
+ messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem],
+ additional_log_info: str | None = None,
+ ) -> None:
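+ """Submit log messages to the web log queue and, when configured, publish them to RabbitMQ."""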
+ if isinstance(messages, ScheduleLogForWebItem):
+ messages = [messages]
+
+ # RabbitMQ is optional; bail out once rather than re-checking per message.
+ if self.rabbitmq_config is None:
+ return
+
+ for message in messages:
+ try:
+ logger.info(
+ "[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish %s",
+ message.model_dump_json(indent=2),
+ )
+ self.rabbitmq_publish_message(message=message.to_dict())
+ logger.info(
+ "[DIAGNOSTIC] base_scheduler._submit_web_logs: publish dispatched item_id=%s task_id=%s label=%s",
+ message.item_id,
+ message.task_id,
+ message.label,
+ )
+ except Exception as e:
+ logger.error(
+ "[DIAGNOSTIC] base_scheduler._submit_web_logs failed: %s",
+ e,
+ exc_info=True,
+ )
+
+ logger.debug(
+ "%s submitted. %s in queue. additional_log_info: %s",
+ len(messages),
+ self._web_log_message_queue.qsize(),
+ additional_log_info,
+ )
+
+ def get_web_log_messages(self) -> list[dict]:
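+ """Retrieve structured log messages from the queue and return JSON-serializable dicts."""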
+ raw_items: list[ScheduleLogForWebItem] = []
+ while True:
+ try:
+ raw_items.append(self._web_log_message_queue.get_nowait())
+ except Exception:
+ break
+
+ def _map_label(label: str) -> str:
+ mapping = {
+ QUERY_TASK_LABEL: "addMessage",
+ ANSWER_TASK_LABEL: "addMessage",
+ ADD_TASK_LABEL: "addMemory",
+ MEM_UPDATE_TASK_LABEL: "updateMemory",
+ MEM_ORGANIZE_TASK_LABEL: "mergeMemory",
+ MEM_ARCHIVE_TASK_LABEL: "archiveMemory",
+ }
+ return mapping.get(label, label)
+
+ def _normalize_item(item: ScheduleLogForWebItem) -> dict:
+ data = item.to_dict()
+ data["label"] = _map_label(data.get("label"))
+ memcube_content = getattr(item, "memcube_log_content", None) or []
+ metadata = getattr(item, "metadata", None) or []
+
+ memcube_name = getattr(item, "memcube_name", None)
+ if not memcube_name and hasattr(self, "_map_memcube_name"):
+ memcube_name = self._map_memcube_name(item.mem_cube_id)
+ data["memcube_name"] = memcube_name
+
+ memory_len = getattr(item, "memory_len", None)
+ if memory_len is None:
+ if data["label"] == "mergeMemory":
+ memory_len = len([c for c in memcube_content if c.get("type") != "postMerge"])
+ elif memcube_content:
+ memory_len = len(memcube_content)
+ else:
+ memory_len = 1 if item.log_content else 0
+
+ data["memcube_log_content"] = memcube_content
+ data["memory_len"] = memory_len
+
+ def _with_memory_time(meta: dict) -> dict:
+ enriched = dict(meta)
+ if "memory_time" not in enriched:
+ enriched["memory_time"] = enriched.get("updated_at") or enriched.get(
+ "update_at"
+ )
+ return enriched
+
+ data["metadata"] = [_with_memory_time(m) for m in metadata]
+ data["log_title"] = ""
+ return data
+
+ return [_normalize_item(it) for it in raw_items]
diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 5ab524128..7c26336ed 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -1,33 +1,27 @@
-import multiprocessing
+from __future__ import annotations
+
import os
import threading
-import time
-from collections.abc import Callable
-from contextlib import suppress
-from datetime import datetime, timezone
from pathlib import Path
-from typing import TYPE_CHECKING, Union
-
-from sqlalchemy.engine import Engine
+from typing import TYPE_CHECKING
from memos.configs.mem_scheduler import AuthConfig, BaseSchedulerConfig
-from memos.context.context import (
- ContextThread,
- RequestContext,
- get_current_context,
- get_current_trace_id,
- set_request_context,
-)
-from memos.llms.base import BaseLLM
from memos.log import get_logger
-from memos.mem_cube.base import BaseMemCube
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_feedback.simple_feedback import SimpleMemFeedback
+from memos.mem_scheduler.base_mixins import (
+ BaseSchedulerMemoryMixin,
+ BaseSchedulerQueueMixin,
+ BaseSchedulerWebLogMixin,
+)
from memos.mem_scheduler.general_modules.init_components_for_scheduler import init_components
from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule
+from memos.mem_scheduler.memory_manage_modules.activation_memory_manager import (
+ ActivationMemoryManager,
+)
+from memos.mem_scheduler.memory_manage_modules.post_processor import MemoryPostProcessor
from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever
+from memos.mem_scheduler.memory_manage_modules.search_service import SchedulerSearchService
from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor
from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor
from memos.mem_scheduler.monitors.task_schedule_monitor import TaskScheduleMonitor
@@ -42,58 +36,44 @@
DEFAULT_THREAD_POOL_MAX_WORKERS,
DEFAULT_TOP_K,
DEFAULT_USE_REDIS_QUEUE,
- STARTUP_BY_PROCESS,
TreeTextMemory_SEARCH_METHOD,
)
-from memos.mem_scheduler.schemas.message_schemas import (
- ScheduleLogForWebItem,
- ScheduleMessageItem,
-)
-from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
-from memos.mem_scheduler.schemas.task_schemas import (
- ADD_TASK_LABEL,
- ANSWER_TASK_LABEL,
- MEM_ARCHIVE_TASK_LABEL,
- MEM_ORGANIZE_TASK_LABEL,
- MEM_UPDATE_TASK_LABEL,
- QUERY_TASK_LABEL,
- TaskPriorityLevel,
-)
from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher
from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue
from memos.mem_scheduler.utils import metrics
-from memos.mem_scheduler.utils.db_utils import get_utc_now
-from memos.mem_scheduler.utils.filter_utils import (
- transform_name_to_key,
-)
-from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
-from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
from memos.mem_scheduler.webservice_modules.rabbitmq_service import RabbitMQSchedulerModule
from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
-from memos.memories.activation.kv import KVCacheMemory
-from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory
-from memos.memories.textual.naive import NaiveTextMemory
-from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
-from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
-from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE
-from memos.types.general_types import (
- MemCubeID,
- UserID,
-)
if TYPE_CHECKING:
import redis
+ from sqlalchemy.engine import Engine
+
+ from memos.llms.base import BaseLLM
+ from memos.mem_cube.base import BaseMemCube
+ from memos.mem_feedback.simple_feedback import SimpleMemFeedback
+ from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem
+ from memos.memories.textual.item import TextualMemoryItem
+ from memos.memories.textual.tree import TreeTextMemory
+ from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
from memos.reranker.http_bge import HTTPBGEReranker
+ from memos.types.general_types import MemCubeID, UserID
logger = get_logger(__name__)
-class BaseScheduler(RabbitMQSchedulerModule, RedisSchedulerModule, SchedulerLoggerModule):
+class BaseScheduler(
+ RabbitMQSchedulerModule,
+ RedisSchedulerModule,
+ SchedulerLoggerModule,
+ BaseSchedulerWebLogMixin,
+ BaseSchedulerMemoryMixin,
+ BaseSchedulerQueueMixin,
+):
"""Base class for all mem_scheduler."""
def __init__(self, config: BaseSchedulerConfig):
@@ -144,6 +124,9 @@ def __init__(self, config: BaseSchedulerConfig):
self.orchestrator = SchedulerOrchestrator()
self.searcher: Searcher | None = None
+ self.search_service: SchedulerSearchService | None = None
+ self.post_processor: MemoryPostProcessor | None = None
+ self.activation_memory_manager: ActivationMemoryManager | None = None
self.retriever: SchedulerRetriever | None = None
self.db_engine: Engine | None = None
self.monitor: SchedulerGeneralMonitor | None = None
@@ -213,13 +196,16 @@ def init_mem_cube(
self.searcher = searcher
self.feedback_server = feedback_server
+ # Initialize search service with the searcher
+ self.search_service = SchedulerSearchService(searcher=self.searcher)
+
def initialize_modules(
self,
chat_llm: BaseLLM,
process_llm: BaseLLM | None = None,
db_engine: Engine | None = None,
mem_reader=None,
- redis_client: Union["redis.Redis", None] = None,
+ redis_client: redis.Redis | None = None,
):
if process_llm is None:
process_llm = chat_llm
@@ -243,6 +229,18 @@ def initialize_modules(
self.dispatcher_monitor = SchedulerDispatcherMonitor(config=self.config)
self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config)
+ # Initialize post-processor for memory enhancement and filtering
+ self.post_processor = MemoryPostProcessor(
+ process_llm=self.process_llm, config=self.config
+ )
+
+ self.activation_memory_manager = ActivationMemoryManager(
+ act_mem_dump_path=self.act_mem_dump_path,
+ monitor=self.monitor,
+ log_func_callback=self._submit_web_logs,
+ log_activation_memory_update_func=self.log_activation_memory_update,
+ )
+
if mem_reader:
self.mem_reader = mem_reader
@@ -391,188 +389,7 @@ def mem_cubes(self, value: dict[str, BaseMemCube]) -> None:
f"Failed to initialize current_mem_cube from mem_cubes: {e}", exc_info=True
)
- def transform_working_memories_to_monitors(
- self, query_keywords, memories: list[TextualMemoryItem]
- ) -> list[MemoryMonitorItem]:
- """
- Convert a list of TextualMemoryItem objects into MemoryMonitorItem objects
- with importance scores based on keyword matching.
-
- Args:
- memories: List of TextualMemoryItem objects to be transformed.
-
- Returns:
- List of MemoryMonitorItem objects with computed importance scores.
- """
-
- result = []
- mem_length = len(memories)
- for idx, mem in enumerate(memories):
- text_mem = mem.memory
- mem_key = transform_name_to_key(name=text_mem)
-
- # Calculate importance score based on keyword matches
- keywords_score = 0
- if query_keywords and text_mem:
- for keyword, count in query_keywords.items():
- keyword_count = text_mem.count(keyword)
- if keyword_count > 0:
- keywords_score += keyword_count * count
- logger.debug(
- f"Matched keyword '{keyword}' {keyword_count} times, added {keywords_score} to keywords_score"
- )
-
- # rank score
- sorting_score = mem_length - idx
-
- mem_monitor = MemoryMonitorItem(
- memory_text=text_mem,
- tree_memory_item=mem,
- tree_memory_item_mapping_key=mem_key,
- sorting_score=sorting_score,
- keywords_score=keywords_score,
- recording_count=1,
- )
- result.append(mem_monitor)
-
- logger.info(f"Transformed {len(result)} memories to monitors")
- return result
-
- def replace_working_memory(
- self,
- user_id: UserID | str,
- mem_cube_id: MemCubeID | str,
- mem_cube: GeneralMemCube,
- original_memory: list[TextualMemoryItem],
- new_memory: list[TextualMemoryItem],
- ) -> None | list[TextualMemoryItem]:
- """Replace working memory with new memories after reranking."""
- text_mem_base = mem_cube.text_mem
- if isinstance(text_mem_base, TreeTextMemory):
- text_mem_base: TreeTextMemory = text_mem_base
-
- # process rerank memories with llm
- query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id]
- # Sync with database to get latest query history
- query_db_manager.sync_with_orm()
-
- query_history = query_db_manager.obj.get_queries_with_timesort()
-
- original_count = len(original_memory)
- # Filter out memories tagged with "mode:fast"
- filtered_original_memory = []
- for origin_mem in original_memory:
- if "mode:fast" not in origin_mem.metadata.tags:
- filtered_original_memory.append(origin_mem)
- else:
- logger.debug(
- f"Filtered out memory - ID: {getattr(origin_mem, 'id', 'unknown')}, Tags: {origin_mem.metadata.tags}"
- )
- # Calculate statistics
- filtered_count = original_count - len(filtered_original_memory)
- remaining_count = len(filtered_original_memory)
-
- logger.info(
- f"Filtering complete. Removed {filtered_count} memories with tag 'mode:fast'. Remaining memories: {remaining_count}"
- )
- original_memory = filtered_original_memory
-
- memories_with_new_order, rerank_success_flag = (
- self.retriever.process_and_rerank_memories(
- queries=query_history,
- original_memory=original_memory,
- new_memory=new_memory,
- top_k=self.top_k,
- )
- )
-
- # Filter completely unrelated memories according to query_history
- logger.info(f"Filtering memories based on query history: {len(query_history)} queries")
- filtered_memories, filter_success_flag = self.retriever.filter_unrelated_memories(
- query_history=query_history,
- memories=memories_with_new_order,
- )
-
- if filter_success_flag:
- logger.info(
- f"Memory filtering completed successfully. "
- f"Filtered from {len(memories_with_new_order)} to {len(filtered_memories)} memories"
- )
- memories_with_new_order = filtered_memories
- else:
- logger.warning(
- "Memory filtering failed - keeping all memories as fallback. "
- f"Original count: {len(memories_with_new_order)}"
- )
-
- # Update working memory monitors
- query_keywords = query_db_manager.obj.get_keywords_collections()
- logger.info(
- f"Processing {len(memories_with_new_order)} memories with {len(query_keywords)} query keywords"
- )
- new_working_memory_monitors = self.transform_working_memories_to_monitors(
- query_keywords=query_keywords,
- memories=memories_with_new_order,
- )
-
- if not rerank_success_flag:
- for one in new_working_memory_monitors:
- one.sorting_score = 0
-
- logger.info(f"update {len(new_working_memory_monitors)} working_memory_monitors")
- self.monitor.update_working_memory_monitors(
- new_working_memory_monitors=new_working_memory_monitors,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- )
-
- mem_monitors: list[MemoryMonitorItem] = self.monitor.working_memory_monitors[user_id][
- mem_cube_id
- ].obj.get_sorted_mem_monitors(reverse=True)
- new_working_memories = [mem_monitor.tree_memory_item for mem_monitor in mem_monitors]
-
- text_mem_base.replace_working_memory(memories=new_working_memories)
-
- logger.info(
- f"The working memory has been replaced with {len(memories_with_new_order)} new memories."
- )
- self.log_working_memory_replacement(
- original_memory=original_memory,
- new_memory=new_working_memories,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- log_func_callback=self._submit_web_logs,
- )
- elif isinstance(text_mem_base, NaiveTextMemory):
- # For NaiveTextMemory, we populate the monitors with the new candidates so activation memory can pick them up
- logger.info(
- f"NaiveTextMemory: Updating working memory monitors with {len(new_memory)} candidates."
- )
-
- # Use query keywords if available, otherwise just basic monitoring
- query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id]
- query_db_manager.sync_with_orm()
- query_keywords = query_db_manager.obj.get_keywords_collections()
-
- new_working_memory_monitors = self.transform_working_memories_to_monitors(
- query_keywords=query_keywords,
- memories=new_memory,
- )
-
- self.monitor.update_working_memory_monitors(
- new_working_memory_monitors=new_working_memory_monitors,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- )
- memories_with_new_order = new_memory
- else:
- logger.error("memory_base is not supported")
- memories_with_new_order = new_memory
-
- return memories_with_new_order
+ # Methods moved to the BaseSchedulerMemoryMixin, BaseSchedulerQueueMixin, and
+ # BaseSchedulerWebLogMixin classes in mem_scheduler.base_mixins.
def update_activation_memory(
self,
@@ -580,80 +397,22 @@ def update_activation_memory(
label: str,
user_id: UserID | str,
mem_cube_id: MemCubeID | str,
- mem_cube: GeneralMemCube,
+ mem_cube: BaseMemCube,
) -> None:
"""
Update activation memory by extracting KVCacheItems from new_memory (list of str),
add them to a KVCacheMemory instance, and dump to disk.
"""
- if len(new_memories) == 0:
- logger.error("update_activation_memory: new_memory is empty.")
- return
- if isinstance(new_memories[0], TextualMemoryItem):
- new_text_memories = [mem.memory for mem in new_memories]
- elif isinstance(new_memories[0], str):
- new_text_memories = new_memories
- else:
- logger.error("Not Implemented.")
- return
-
- try:
- if isinstance(mem_cube.act_mem, VLLMKVCacheMemory):
- act_mem: VLLMKVCacheMemory = mem_cube.act_mem
- elif isinstance(mem_cube.act_mem, KVCacheMemory):
- act_mem: KVCacheMemory = mem_cube.act_mem
- else:
- logger.error("Not Implemented.")
- return
-
- new_text_memory = MEMORY_ASSEMBLY_TEMPLATE.format(
- memory_text="".join(
- [
- f"{i + 1}. {sentence.strip()}\n"
- for i, sentence in enumerate(new_text_memories)
- if sentence.strip() # Skip empty strings
- ]
- )
- )
-
- # huggingface or vllm kv cache
- original_cache_items: list[VLLMKVCacheItem] = act_mem.get_all()
- original_text_memories = []
- if len(original_cache_items) > 0:
- pre_cache_item: VLLMKVCacheItem = original_cache_items[-1]
- original_text_memories = pre_cache_item.records.text_memories
- original_composed_text_memory = pre_cache_item.records.composed_text_memory
- if original_composed_text_memory == new_text_memory:
- logger.warning(
- "Skipping memory update - new composition matches existing cache: %s",
- new_text_memory[:50] + "..."
- if len(new_text_memory) > 50
- else new_text_memory,
- )
- return
- act_mem.delete_all()
-
- cache_item = act_mem.extract(new_text_memory)
- cache_item.records.text_memories = new_text_memories
- cache_item.records.timestamp = get_utc_now()
-
- act_mem.add([cache_item])
- act_mem.dump(self.act_mem_dump_path)
-
- self.log_activation_memory_update(
- original_text_memories=original_text_memories,
- new_text_memories=new_text_memories,
+ if self.activation_memory_manager:
+ self.activation_memory_manager.update_activation_memory(
+ new_memories=new_memories,
label=label,
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
- log_func_callback=self._submit_web_logs,
)
-
- except Exception as e:
- logger.error(f"MOS-based activation memory update failed: {e}", exc_info=True)
- # Re-raise the exception if it's critical for the operation
- # For now, we'll continue execution but this should be reviewed
+ else:
+ logger.warning("Activation memory manager not initialized")
def update_activation_memory_periodically(
self,
@@ -661,659 +420,15 @@ def update_activation_memory_periodically(
label: str,
user_id: UserID | str,
mem_cube_id: MemCubeID | str,
- mem_cube: GeneralMemCube,
+ mem_cube: BaseMemCube,
):
- try:
- if (
- self.monitor.last_activation_mem_update_time == datetime.min
- or self.monitor.timed_trigger(
- last_time=self.monitor.last_activation_mem_update_time,
- interval_seconds=interval_seconds,
- )
- ):
- logger.info(
- f"Updating activation memory for user {user_id} and mem_cube {mem_cube_id}"
- )
-
- if (
- user_id not in self.monitor.working_memory_monitors
- or mem_cube_id not in self.monitor.working_memory_monitors[user_id]
- or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].obj.memories)
- == 0
- ):
- logger.warning(
- "No memories found in working_memory_monitors, activation memory update is skipped"
- )
- return
-
- self.monitor.update_activation_memory_monitors(
- user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube
- )
-
- # Sync with database to get latest activation memories
- activation_db_manager = self.monitor.activation_memory_monitors[user_id][
- mem_cube_id
- ]
- activation_db_manager.sync_with_orm()
- new_activation_memories = [
- m.memory_text for m in activation_db_manager.obj.memories
- ]
-
- logger.info(
- f"Collected {len(new_activation_memories)} new memory entries for processing"
- )
- # Print the content of each new activation memory
- for i, memory in enumerate(new_activation_memories[:5], 1):
- logger.info(
- f"Part of New Activation Memorires | {i}/{len(new_activation_memories)}: {memory[:20]}"
- )
-
- self.update_activation_memory(
- new_memories=new_activation_memories,
- label=label,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- )
-
- self.monitor.last_activation_mem_update_time = get_utc_now()
-
- logger.debug(
- f"Activation memory update completed at {self.monitor.last_activation_mem_update_time}"
- )
-
- else:
- logger.info(
- f"Skipping update - {interval_seconds} second interval not yet reached. "
- f"Last update time is {self.monitor.last_activation_mem_update_time} and now is "
- f"{get_utc_now()}"
- )
- except Exception as e:
- logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True)
-
- def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]):
- """Submit messages for processing, with priority-aware dispatch.
-
- - LEVEL_1 tasks dispatch immediately to the appropriate handler.
- - Lower-priority tasks are enqueued via the configured message queue.
- """
- if isinstance(messages, ScheduleMessageItem):
- messages = [messages]
-
- if not messages:
- return
-
- current_trace_id = get_current_trace_id()
-
- immediate_msgs: list[ScheduleMessageItem] = []
- queued_msgs: list[ScheduleMessageItem] = []
-
- for msg in messages:
- # propagate request trace_id when available so monitor logs align with request logs
- if current_trace_id:
- msg.trace_id = current_trace_id
-
- # basic metrics and status tracking
- with suppress(Exception):
- self.metrics.task_enqueued(user_id=msg.user_id, task_type=msg.label)
-
- # ensure timestamp exists for monitoring
- if getattr(msg, "timestamp", None) is None:
- msg.timestamp = get_utc_now()
-
- if self.status_tracker:
- try:
- self.status_tracker.task_submitted(
- task_id=msg.item_id,
- user_id=msg.user_id,
- task_type=msg.label,
- mem_cube_id=msg.mem_cube_id,
- business_task_id=msg.task_id,
- )
- except Exception:
- logger.warning("status_tracker.task_submitted failed", exc_info=True)
-
- # honor disabled handlers
- if self.disabled_handlers and msg.label in self.disabled_handlers:
- logger.info(f"Skipping disabled handler: {msg.label} - {msg.content}")
- continue
-
- # decide priority path
- task_priority = self.orchestrator.get_task_priority(task_label=msg.label)
- if task_priority == TaskPriorityLevel.LEVEL_1:
- immediate_msgs.append(msg)
- else:
- queued_msgs.append(msg)
-
- # Dispatch high-priority tasks immediately
- if immediate_msgs:
- # emit enqueue events for consistency
- for m in immediate_msgs:
- emit_monitor_event(
- "enqueue",
- m,
- {
- "enqueue_ts": to_iso(getattr(m, "timestamp", None)),
- "event_duration_ms": 0,
- "total_duration_ms": 0,
- },
- )
-
- # simulate dequeue for immediately dispatched messages so monitor logs stay complete
- for m in immediate_msgs:
- try:
- now = time.time()
- enqueue_ts_obj = getattr(m, "timestamp", None)
- enqueue_epoch = None
- if isinstance(enqueue_ts_obj, int | float):
- enqueue_epoch = float(enqueue_ts_obj)
- elif hasattr(enqueue_ts_obj, "timestamp"):
- dt = enqueue_ts_obj
- if dt.tzinfo is None:
- dt = dt.replace(tzinfo=timezone.utc)
- enqueue_epoch = dt.timestamp()
-
- queue_wait_ms = None
- if enqueue_epoch is not None:
- queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000
-
- object.__setattr__(m, "_dequeue_ts", now)
- emit_monitor_event(
- "dequeue",
- m,
- {
- "enqueue_ts": to_iso(enqueue_ts_obj),
- "dequeue_ts": datetime.fromtimestamp(now, tz=timezone.utc).isoformat(),
- "queue_wait_ms": queue_wait_ms,
- "event_duration_ms": queue_wait_ms,
- "total_duration_ms": queue_wait_ms,
- },
- )
- self.metrics.task_dequeued(user_id=m.user_id, task_type=m.label)
- except Exception:
- logger.debug("Failed to emit dequeue for immediate task", exc_info=True)
-
- user_cube_groups = group_messages_by_user_and_mem_cube(immediate_msgs)
- for user_id, cube_groups in user_cube_groups.items():
- for mem_cube_id, user_cube_msgs in cube_groups.items():
- label_groups: dict[str, list[ScheduleMessageItem]] = {}
- for m in user_cube_msgs:
- label_groups.setdefault(m.label, []).append(m)
-
- for label, msgs_by_label in label_groups.items():
- handler = self.dispatcher.handlers.get(
- label, self.dispatcher._default_message_handler
- )
- self.dispatcher.execute_task(
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- task_label=label,
- msgs=msgs_by_label,
- handler_call_back=handler,
- )
-
- # Enqueue lower-priority tasks
- if queued_msgs:
- self.memos_message_queue.submit_messages(messages=queued_msgs)
-
- def _submit_web_logs(
- self,
- messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem],
- additional_log_info: str | None = None,
- ) -> None:
- """Submit log messages to the web log queue and optionally to RabbitMQ.
-
- Args:
- messages: Single log message or list of log messages
- """
- if isinstance(messages, ScheduleLogForWebItem):
- messages = [messages] # transform single message to list
-
- for message in messages:
- if self.rabbitmq_config is None:
- return
- try:
- # Always call publish; the publisher now caches when offline and flushes after reconnect
- logger.info(
- f"[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish {message.model_dump_json(indent=2)}"
- )
- self.rabbitmq_publish_message(message=message.to_dict())
- logger.info(
- "[DIAGNOSTIC] base_scheduler._submit_web_logs: publish dispatched "
- "item_id=%s task_id=%s label=%s",
- message.item_id,
- message.task_id,
- message.label,
- )
- except Exception as e:
- logger.error(
- f"[DIAGNOSTIC] base_scheduler._submit_web_logs failed: {e}", exc_info=True
- )
-
- logger.debug(
- f"{len(messages)} submitted. {self._web_log_message_queue.qsize()} in queue. additional_log_info: {additional_log_info}"
- )
-
- def get_web_log_messages(self) -> list[dict]:
- """
- Retrieve structured log messages from the queue and return JSON-serializable dicts.
- """
- raw_items: list[ScheduleLogForWebItem] = []
- while True:
- try:
- raw_items.append(self._web_log_message_queue.get_nowait())
- except Exception:
- break
-
- def _map_label(label: str) -> str:
- mapping = {
- QUERY_TASK_LABEL: "addMessage",
- ANSWER_TASK_LABEL: "addMessage",
- ADD_TASK_LABEL: "addMemory",
- MEM_UPDATE_TASK_LABEL: "updateMemory",
- MEM_ORGANIZE_TASK_LABEL: "mergeMemory",
- MEM_ARCHIVE_TASK_LABEL: "archiveMemory",
- }
- return mapping.get(label, label)
-
- def _normalize_item(item: ScheduleLogForWebItem) -> dict:
- data = item.to_dict()
- data["label"] = _map_label(data.get("label"))
- memcube_content = getattr(item, "memcube_log_content", None) or []
- metadata = getattr(item, "metadata", None) or []
-
- memcube_name = getattr(item, "memcube_name", None)
- if not memcube_name and hasattr(self, "_map_memcube_name"):
- memcube_name = self._map_memcube_name(item.mem_cube_id)
- data["memcube_name"] = memcube_name
-
- memory_len = getattr(item, "memory_len", None)
- if memory_len is None:
- if data["label"] == "mergeMemory":
- memory_len = len([c for c in memcube_content if c.get("type") != "postMerge"])
- elif memcube_content:
- memory_len = len(memcube_content)
- else:
- memory_len = 1 if item.log_content else 0
-
- data["memcube_log_content"] = memcube_content
- data["memory_len"] = memory_len
-
- def _with_memory_time(meta: dict) -> dict:
- enriched = dict(meta)
- if "memory_time" not in enriched:
- enriched["memory_time"] = enriched.get("updated_at") or enriched.get(
- "update_at"
- )
- return enriched
-
- data["metadata"] = [_with_memory_time(m) for m in metadata]
- data["log_title"] = ""
- return data
-
- return [_normalize_item(it) for it in raw_items]
-
- def _message_consumer(self) -> None:
- """
- Continuously checks the queue for messages and dispatches them.
-
- Runs in a dedicated thread to process messages at regular intervals.
- For Redis queue, this method starts the Redis listener.
- """
-
- # Original local queue logic
- while self._running: # Use a running flag for graceful shutdown
- try:
- # Check dispatcher thread pool status to avoid overloading
- if self.enable_parallel_dispatch and self.dispatcher:
- running_tasks = self.dispatcher.get_running_task_count()
- if running_tasks >= self.dispatcher.max_workers:
- # Thread pool is full, wait and retry
- time.sleep(self._consume_interval)
- continue
-
- # Get messages in batches based on consume_batch setting
-
- messages = self.memos_message_queue.get_messages(batch_size=self.consume_batch)
-
- if messages:
- now = time.time()
- for msg in messages:
- prev_context = get_current_context()
- try:
- # Set context for this message
- msg_context = RequestContext(
- trace_id=msg.trace_id,
- user_name=msg.user_name,
- )
- set_request_context(msg_context)
-
- enqueue_ts_obj = getattr(msg, "timestamp", None)
- enqueue_epoch = None
- if isinstance(enqueue_ts_obj, int | float):
- enqueue_epoch = float(enqueue_ts_obj)
- elif hasattr(enqueue_ts_obj, "timestamp"):
- dt = enqueue_ts_obj
- if dt.tzinfo is None:
- dt = dt.replace(tzinfo=timezone.utc)
- enqueue_epoch = dt.timestamp()
-
- queue_wait_ms = None
- if enqueue_epoch is not None:
- queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000
-
- # Avoid pydantic field enforcement by using object.__setattr__
- object.__setattr__(msg, "_dequeue_ts", now)
- emit_monitor_event(
- "dequeue",
- msg,
- {
- "enqueue_ts": to_iso(enqueue_ts_obj),
- "dequeue_ts": datetime.fromtimestamp(
- now, tz=timezone.utc
- ).isoformat(),
- "queue_wait_ms": queue_wait_ms,
- "event_duration_ms": queue_wait_ms,
- "total_duration_ms": queue_wait_ms,
- },
- )
- self.metrics.task_dequeued(user_id=msg.user_id, task_type=msg.label)
- finally:
- # Restore the prior context of the consumer thread
- set_request_context(prev_context)
- try:
- import contextlib
-
- with contextlib.suppress(Exception):
- if messages:
- self.dispatcher.on_messages_enqueued(messages)
-
- self.dispatcher.dispatch(messages)
- except Exception as e:
- logger.error(f"Error dispatching messages: {e!s}")
-
- # Sleep briefly to prevent busy waiting
- time.sleep(self._consume_interval) # Adjust interval as needed
-
- except Exception as e:
- # Don't log error for "No messages available in Redis queue" as it's expected
- if "No messages available in Redis queue" not in str(e):
- logger.error(f"Unexpected error in message consumer: {e!s}", exc_info=True)
- time.sleep(self._consume_interval) # Prevent tight error loops
-
- def _monitor_loop(self):
- while self._running:
- try:
- q_sizes = self.memos_message_queue.qsize()
-
- if not isinstance(q_sizes, dict):
- continue
-
- for stream_key, queue_length in q_sizes.items():
- # Skip aggregate keys like 'total_size'
- if stream_key == "total_size":
- continue
-
- # Key format: ...:{user_id}:{mem_cube_id}:{task_label}
- # We want to extract user_id, which is the 3rd component from the end.
- parts = stream_key.split(":")
- if len(parts) >= 3:
- user_id = parts[-3]
- self.metrics.update_queue_length(queue_length, user_id)
- else:
- # Fallback for unexpected key formats (e.g. legacy or testing)
- # Try to use the key itself if it looks like a user_id (no colons)
- # or just log a warning?
- # For now, let's assume if it's not total_size and short, it might be a direct user_id key
- # (though that shouldn't happen with current queue implementations)
- if ":" not in stream_key:
- self.metrics.update_queue_length(queue_length, stream_key)
-
- except Exception as e:
- logger.error(f"Error in metrics monitor loop: {e}", exc_info=True)
-
- time.sleep(15) # 每 15 秒采样一次
-
- def start(self) -> None:
- """
- Start the message consumer thread/process and initialize dispatcher resources.
-
- Initializes and starts:
- 1. Message consumer thread or process (based on startup_mode)
- 2. Dispatcher thread pool (if parallel dispatch enabled)
- """
- # Initialize dispatcher resources
- if self.enable_parallel_dispatch:
- logger.info(
- f"Initializing dispatcher thread pool with {self.thread_pool_max_workers} workers"
- )
-
- self.start_consumer()
- self.start_background_monitor()
-
- def start_background_monitor(self):
- if self._monitor_thread and self._monitor_thread.is_alive():
- return
- self._monitor_thread = ContextThread(
- target=self._monitor_loop, daemon=True, name="SchedulerMetricsMonitor"
- )
- self._monitor_thread.start()
- logger.info("Scheduler metrics monitor thread started.")
-
- def start_consumer(self) -> None:
- """
- Start only the message consumer thread/process.
-
- This method can be used to restart the consumer after it has been stopped
- with stop_consumer(), without affecting other scheduler components.
- """
- if self._running:
- logger.warning("Memory Scheduler consumer is already running")
- return
-
- # Start consumer based on startup mode
- self._running = True
-
- if self.scheduler_startup_mode == STARTUP_BY_PROCESS:
- # Start consumer process
- self._consumer_process = multiprocessing.Process(
- target=self._message_consumer,
- daemon=True,
- name="MessageConsumerProcess",
+ if self.activation_memory_manager:
+ self.activation_memory_manager.update_activation_memory_periodically(
+ interval_seconds=interval_seconds,
+ label=label,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
)
- self._consumer_process.start()
- logger.info("Message consumer process started")
else:
- # Default to thread mode
- self._consumer_thread = ContextThread(
- target=self._message_consumer,
- daemon=True,
- name="MessageConsumerThread",
- )
- self._consumer_thread.start()
- logger.info("Message consumer thread started")
-
- def stop_consumer(self) -> None:
- """Stop only the message consumer thread/process gracefully.
-
- This method stops the consumer without affecting other components like
- dispatcher or monitors. Useful when you want to pause message processing
- while keeping other scheduler components running.
- """
- if not self._running:
- logger.warning("Memory Scheduler consumer is not running")
- return
-
- # Signal consumer thread/process to stop
- self._running = False
-
- # Wait for consumer thread or process
- if self.scheduler_startup_mode == STARTUP_BY_PROCESS and self._consumer_process:
- if self._consumer_process.is_alive():
- self._consumer_process.join(timeout=5.0)
- if self._consumer_process.is_alive():
- logger.warning("Consumer process did not stop gracefully, terminating...")
- self._consumer_process.terminate()
- self._consumer_process.join(timeout=2.0)
- if self._consumer_process.is_alive():
- logger.error("Consumer process could not be terminated")
- else:
- logger.info("Consumer process terminated")
- else:
- logger.info("Consumer process stopped")
- self._consumer_process = None
- elif self._consumer_thread and self._consumer_thread.is_alive():
- self._consumer_thread.join(timeout=5.0)
- if self._consumer_thread.is_alive():
- logger.warning("Consumer thread did not stop gracefully")
- else:
- logger.info("Consumer thread stopped")
- self._consumer_thread = None
-
- logger.info("Memory Scheduler consumer stopped")
-
- def stop(self) -> None:
- """Stop all scheduler components gracefully.
-
- 1. Stops message consumer thread/process
- 2. Shuts down dispatcher thread pool
- 3. Cleans up resources
- """
- if not self._running:
- logger.warning("Memory Scheduler is not running")
- return
-
- # Stop consumer first
- self.stop_consumer()
-
- if self._monitor_thread:
- self._monitor_thread.join(timeout=2.0)
-
- # Shutdown dispatcher
- if self.dispatcher:
- logger.info("Shutting down dispatcher...")
- self.dispatcher.shutdown()
-
- # Shutdown dispatcher_monitor
- if self.dispatcher_monitor:
- logger.info("Shutting down monitor...")
- self.dispatcher_monitor.stop()
-
- @property
- def handlers(self) -> dict[str, Callable]:
- """
- Access the dispatcher's handlers dictionary.
-
- Returns:
- dict[str, Callable]: Dictionary mapping labels to handler functions
- """
- if not self.dispatcher:
- logger.warning("Dispatcher is not initialized, returning empty handlers dict")
- return {}
-
- return self.dispatcher.handlers
-
- def register_handlers(
- self, handlers: dict[str, Callable[[list[ScheduleMessageItem]], None]]
- ) -> None:
- """
- Bulk register multiple handlers from a dictionary.
-
- Args:
- handlers: Dictionary mapping labels to handler functions
- Format: {label: handler_callable}
- """
- if not self.dispatcher:
- logger.warning("Dispatcher is not initialized, cannot register handlers")
- return
-
- self.dispatcher.register_handlers(handlers)
-
- def unregister_handlers(self, labels: list[str]) -> dict[str, bool]:
- """
- Unregister handlers from the dispatcher by their labels.
-
- Args:
- labels: List of labels to unregister handlers for
-
- Returns:
- dict[str, bool]: Dictionary mapping each label to whether it was successfully unregistered
- """
- if not self.dispatcher:
- logger.warning("Dispatcher is not initialized, cannot unregister handlers")
- return dict.fromkeys(labels, False)
-
- return self.dispatcher.unregister_handlers(labels)
-
- def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, dict]:
- if not self.dispatcher:
- logger.warning("Dispatcher is not initialized, returning empty tasks dict")
- return {}
-
- running_tasks = self.dispatcher.get_running_tasks(filter_func=filter_func)
-
- # Convert RunningTaskItem objects to dictionaries for easier consumption
- result = {}
- for task_id, task_item in running_tasks.items():
- result[task_id] = {
- "item_id": task_item.item_id,
- "user_id": task_item.user_id,
- "mem_cube_id": task_item.mem_cube_id,
- "task_info": task_item.task_info,
- "task_name": task_item.task_name,
- "start_time": task_item.start_time,
- "end_time": task_item.end_time,
- "status": task_item.status,
- "result": task_item.result,
- "error_message": task_item.error_message,
- "messages": task_item.messages,
- }
-
- return result
-
- def get_tasks_status(self):
- """Delegate status collection to TaskScheduleMonitor."""
- return self.task_schedule_monitor.get_tasks_status()
-
- def print_tasks_status(self, tasks_status: dict | None = None) -> None:
- """Delegate pretty printing to TaskScheduleMonitor."""
- self.task_schedule_monitor.print_tasks_status(tasks_status=tasks_status)
-
- def _gather_queue_stats(self) -> dict:
- """Collect queue/dispatcher stats for reporting."""
- memos_message_queue = self.memos_message_queue.memos_message_queue
- stats: dict[str, int | float | str] = {}
- stats["use_redis_queue"] = bool(self.use_redis_queue)
- # local queue metrics
- if not self.use_redis_queue:
- try:
- stats["qsize"] = int(memos_message_queue.qsize())
- except Exception:
- stats["qsize"] = -1
- # unfinished_tasks if available
- try:
- stats["unfinished_tasks"] = int(
- getattr(memos_message_queue, "unfinished_tasks", 0) or 0
- )
- except Exception:
- stats["unfinished_tasks"] = -1
- stats["maxsize"] = int(self.max_internal_message_queue_size)
- try:
- maxsize = int(self.max_internal_message_queue_size) or 1
- qsize = int(stats.get("qsize", 0))
- stats["utilization"] = min(1.0, max(0.0, qsize / maxsize))
- except Exception:
- stats["utilization"] = 0.0
- # dispatcher stats
- try:
- d_stats = self.dispatcher.stats()
- stats.update(
- {
- "running": int(d_stats.get("running", 0)),
- "inflight": int(d_stats.get("inflight", 0)),
- "handlers": int(d_stats.get("handlers", 0)),
- }
- )
- except Exception:
- stats.update({"running": 0, "inflight": 0, "handlers": 0})
- return stats
+ logger.warning("Activation memory manager not initialized")
diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
index 903088a4c..b103acf3a 100644
--- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
+++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
@@ -61,9 +61,11 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
"neo4j": APIConfig.get_neo4j_config(user_id=user_id),
"nebular": APIConfig.get_nebular_config(user_id=user_id),
"polardb": APIConfig.get_polardb_config(user_id=user_id),
+ "postgres": APIConfig.get_postgres_config(user_id=user_id),
}
- graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower()
+ # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
+ graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower()
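+ # e.g. GRAPH_DB_BACKEND=postgres; deployments that only set NEO4J_BACKEND keep working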
return GraphDBConfigFactory.model_validate(
{
"backend": graph_db_backend,
diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py
index 74e50a514..6d6f38d95 100644
--- a/src/memos/mem_scheduler/general_scheduler.py
+++ b/src/memos/mem_scheduler/general_scheduler.py
@@ -1,49 +1,16 @@
-import concurrent.futures
-import contextlib
-import json
-import traceback
+from __future__ import annotations
-from memos.configs.mem_scheduler import GeneralSchedulerConfig
-from memos.context.context import ContextThreadPoolExecutor
-from memos.log import get_logger
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_scheduler.base_scheduler import BaseScheduler
-from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
-from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem
-from memos.mem_scheduler.schemas.task_schemas import (
- ADD_TASK_LABEL,
- ANSWER_TASK_LABEL,
- DEFAULT_MAX_QUERY_KEY_WORDS,
- LONG_TERM_MEMORY_TYPE,
- MEM_FEEDBACK_TASK_LABEL,
- MEM_ORGANIZE_TASK_LABEL,
- MEM_READ_TASK_LABEL,
- MEM_UPDATE_TASK_LABEL,
- NOT_APPLICABLE_TYPE,
- PREF_ADD_TASK_LABEL,
- QUERY_TASK_LABEL,
- USER_INPUT_TYPE,
-)
-from memos.mem_scheduler.utils.filter_utils import (
- is_all_chinese,
- is_all_english,
- transform_name_to_key,
-)
-from memos.mem_scheduler.utils.misc_utils import (
- group_messages_by_user_and_mem_cube,
- is_cloud_env,
-)
-from memos.memories.textual.item import TextualMemoryItem
-from memos.memories.textual.naive import NaiveTextMemory
-from memos.memories.textual.preference import PreferenceTextMemory
-from memos.memories.textual.tree import TreeTextMemory
-from memos.types import (
- MemCubeID,
- UserID,
-)
+from typing import TYPE_CHECKING
-logger = get_logger(__name__)
+if TYPE_CHECKING:
+ from memos.configs.mem_scheduler import GeneralSchedulerConfig
+from memos.mem_scheduler.base_scheduler import BaseScheduler
+from memos.mem_scheduler.task_schedule_modules.handlers import (
+ SchedulerHandlerContext,
+ SchedulerHandlerRegistry,
+ SchedulerHandlerServices,
+)
class GeneralScheduler(BaseScheduler):
@@ -53,1447 +20,29 @@ def __init__(self, config: GeneralSchedulerConfig):
self.query_key_words_limit = self.config.get("query_key_words_limit", 20)
- # register handlers
- handlers = {
- QUERY_TASK_LABEL: self._query_message_consumer,
- ANSWER_TASK_LABEL: self._answer_message_consumer,
- MEM_UPDATE_TASK_LABEL: self._memory_update_consumer,
- ADD_TASK_LABEL: self._add_message_consumer,
- MEM_READ_TASK_LABEL: self._mem_read_message_consumer,
- MEM_ORGANIZE_TASK_LABEL: self._mem_reorganize_message_consumer,
- PREF_ADD_TASK_LABEL: self._pref_add_message_consumer,
- MEM_FEEDBACK_TASK_LABEL: self._mem_feedback_message_consumer,
- }
- self.dispatcher.register_handlers(handlers)
-
- def long_memory_update_process(
- self, user_id: str, mem_cube_id: str, messages: list[ScheduleMessageItem]
- ):
- mem_cube = self.mem_cube
-
- # update query monitors
- for msg in messages:
- self.monitor.register_query_monitor_if_not_exists(
- user_id=user_id, mem_cube_id=mem_cube_id
- )
-
- query = msg.content
- query_keywords = self.monitor.extract_query_keywords(query=query)
- logger.info(
- f'Extracted keywords "{query_keywords}" from query "{query}" for user_id={user_id}'
- )
-
- if len(query_keywords) == 0:
- stripped_query = query.strip()
- # Determine measurement method based on language
- if is_all_english(stripped_query):
- words = stripped_query.split() # Word count for English
- elif is_all_chinese(stripped_query):
- words = stripped_query # Character count for Chinese
- else:
- logger.debug(
- f"Mixed-language memory, using character count: {stripped_query[:50]}..."
- )
- words = stripped_query # Default to character count
-
- query_keywords = list(set(words[: self.query_key_words_limit]))
- logger.error(
- f"Keyword extraction failed for query '{query}' (user_id={user_id}). Using fallback keywords: {query_keywords[:10]}... (truncated)",
- exc_info=True,
- )
-
- item = QueryMonitorItem(
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- query_text=query,
- keywords=query_keywords,
- max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS,
- )
-
- query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id]
- query_db_manager.obj.put(item=item)
- # Sync with database after adding new item
- query_db_manager.sync_with_orm()
- logger.debug(
- f"Queries in monitor for user_id={user_id}, mem_cube_id={mem_cube_id}: {query_db_manager.obj.get_queries_with_timesort()}"
- )
-
- queries = [msg.content for msg in messages]
-
- # recall
- cur_working_memory, new_candidates = self.process_session_turn(
- queries=queries,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- top_k=self.top_k,
- )
- logger.info(
- # Build the candidate preview string outside the f-string to avoid backslashes in expression
- f"[long_memory_update_process] Processed {len(queries)} queries {queries} and retrieved {len(new_candidates)} "
- f"new candidate memories for user_id={user_id}: "
- + ("\n- " + "\n- ".join([f"{one.id}: {one.memory}" for one in new_candidates]))
- )
-
- # rerank
- new_order_working_memory = self.replace_working_memory(
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- original_memory=cur_working_memory,
- new_memory=new_candidates,
- )
- logger.debug(
- f"[long_memory_update_process] Final working memory size: {len(new_order_working_memory)} memories for user_id={user_id}"
- )
-
- old_memory_texts = "\n- " + "\n- ".join(
- [f"{one.id}: {one.memory}" for one in cur_working_memory]
- )
- new_memory_texts = "\n- " + "\n- ".join(
- [f"{one.id}: {one.memory}" for one in new_order_working_memory]
- )
-
- logger.info(
- f"[long_memory_update_process] For user_id='{user_id}', mem_cube_id='{mem_cube_id}': "
- f"Scheduler replaced working memory based on query history {queries}. "
- f"Old working memory ({len(cur_working_memory)} items): {old_memory_texts}. "
- f"New working memory ({len(new_order_working_memory)} items): {new_memory_texts}."
- )
-
- # update activation memories
- logger.debug(
- f"Activation memory update {'enabled' if self.enable_activation_memory else 'disabled'} "
- f"(interval: {self.monitor.act_mem_update_interval}s)"
- )
- if self.enable_activation_memory:
- self.update_activation_memory_periodically(
- interval_seconds=self.monitor.act_mem_update_interval,
- label=QUERY_TASK_LABEL,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=self.mem_cube,
- )
-
- def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- logger.info(f"Messages {messages} assigned to {ADD_TASK_LABEL} handler.")
- # Process the query in a session turn
- grouped_messages = group_messages_by_user_and_mem_cube(messages=messages)
-
- self.validate_schedule_messages(messages=messages, label=ADD_TASK_LABEL)
- try:
- for user_id in grouped_messages:
- for mem_cube_id in grouped_messages[user_id]:
- batch = grouped_messages[user_id][mem_cube_id]
- if not batch:
- continue
-
- # Process each message in the batch
- for msg in batch:
- prepared_add_items, prepared_update_items_with_original = (
- self.log_add_messages(msg=msg)
- )
- logger.info(
- f"prepared_add_items: {prepared_add_items};\n prepared_update_items_with_original: {prepared_update_items_with_original}"
- )
- # Conditional Logging: Knowledge Base (Cloud Service) vs. Playground/Default
- cloud_env = is_cloud_env()
-
- if cloud_env:
- self.send_add_log_messages_to_cloud_env(
- msg, prepared_add_items, prepared_update_items_with_original
- )
- else:
- self.send_add_log_messages_to_local_env(
- msg, prepared_add_items, prepared_update_items_with_original
- )
-
- except Exception as e:
- logger.error(f"Error: {e}", exc_info=True)
-
- def _memory_update_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- logger.info(f"Messages {messages} assigned to {MEM_UPDATE_TASK_LABEL} handler.")
-
- grouped_messages = group_messages_by_user_and_mem_cube(messages=messages)
-
- self.validate_schedule_messages(messages=messages, label=MEM_UPDATE_TASK_LABEL)
-
- for user_id in grouped_messages:
- for mem_cube_id in grouped_messages[user_id]:
- batch = grouped_messages[user_id][mem_cube_id]
- if not batch:
- continue
- # Process the whole batch once; no need to iterate per message
- self.long_memory_update_process(
- user_id=user_id, mem_cube_id=mem_cube_id, messages=batch
- )
-
- def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- """
- Process and handle query trigger messages from the queue.
-
- Args:
- messages: List of query messages to process
- """
- logger.info(f"Messages {messages} assigned to {QUERY_TASK_LABEL} handler.")
-
- grouped_messages = group_messages_by_user_and_mem_cube(messages=messages)
-
- self.validate_schedule_messages(messages=messages, label=QUERY_TASK_LABEL)
-
- mem_update_messages = []
- for user_id in grouped_messages:
- for mem_cube_id in grouped_messages[user_id]:
- batch = grouped_messages[user_id][mem_cube_id]
- if not batch:
- continue
-
- for msg in batch:
- try:
- event = self.create_event_log(
- label="addMessage",
- from_memory_type=USER_INPUT_TYPE,
- to_memory_type=NOT_APPLICABLE_TYPE,
- user_id=msg.user_id,
- mem_cube_id=msg.mem_cube_id,
- mem_cube=self.mem_cube,
- memcube_log_content=[
- {
- "content": f"[User] {msg.content}",
- "ref_id": msg.item_id,
- "role": "user",
- }
- ],
- metadata=[],
- memory_len=1,
- memcube_name=self._map_memcube_name(msg.mem_cube_id),
- )
- event.task_id = msg.task_id
- self._submit_web_logs([event])
- except Exception:
- logger.exception("Failed to record addMessage log for query")
- # Re-submit the message with label changed to mem_update
- update_msg = ScheduleMessageItem(
- user_id=msg.user_id,
- mem_cube_id=msg.mem_cube_id,
- label=MEM_UPDATE_TASK_LABEL,
- content=msg.content,
- session_id=msg.session_id,
- user_name=msg.user_name,
- info=msg.info,
- task_id=msg.task_id,
- )
- mem_update_messages.append(update_msg)
-
- self.submit_messages(messages=mem_update_messages)
-
- def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- """
- Process and handle answer trigger messages from the queue.
-
- Args:
- messages: List of answer messages to process
- """
- logger.info(f"Messages {messages} assigned to {ANSWER_TASK_LABEL} handler.")
- grouped_messages = group_messages_by_user_and_mem_cube(messages=messages)
-
- self.validate_schedule_messages(messages=messages, label=ANSWER_TASK_LABEL)
-
- for user_id in grouped_messages:
- for mem_cube_id in grouped_messages[user_id]:
- batch = grouped_messages[user_id][mem_cube_id]
- if not batch:
- continue
- try:
- for msg in batch:
- event = self.create_event_log(
- label="addMessage",
- from_memory_type=USER_INPUT_TYPE,
- to_memory_type=NOT_APPLICABLE_TYPE,
- user_id=msg.user_id,
- mem_cube_id=msg.mem_cube_id,
- mem_cube=self.mem_cube,
- memcube_log_content=[
- {
- "content": f"[Assistant] {msg.content}",
- "ref_id": msg.item_id,
- "role": "assistant",
- }
- ],
- metadata=[],
- memory_len=1,
- memcube_name=self._map_memcube_name(msg.mem_cube_id),
- )
- event.task_id = msg.task_id
- self._submit_web_logs([event])
- except Exception:
- logger.exception("Failed to record addMessage log for answer")
-
- def log_add_messages(self, msg: ScheduleMessageItem):
- try:
- userinput_memory_ids = json.loads(msg.content)
- except Exception as e:
- logger.error(f"Error: {e}. Content: {msg.content}", exc_info=True)
- userinput_memory_ids = []
-
- # Prepare data for both logging paths, fetching original content for updates
- prepared_add_items = []
- prepared_update_items_with_original = []
- missing_ids: list[str] = []
-
- for memory_id in userinput_memory_ids:
- try:
- # This mem_item represents the NEW content that was just added/processed
- mem_item: TextualMemoryItem | None = None
- mem_item = self.mem_cube.text_mem.get(
- memory_id=memory_id, user_name=msg.mem_cube_id
- )
- if mem_item is None:
- raise ValueError(f"Memory {memory_id} not found after retries")
- # Check if a memory with the same key already exists (determining if it's an update)
- key = getattr(mem_item.metadata, "key", None) or transform_name_to_key(
- name=mem_item.memory
- )
- exists = False
- original_content = None
- original_item_id = None
-
- # Only check graph_store if a key exists and the text_mem has a graph_store
- if key and hasattr(self.mem_cube.text_mem, "graph_store"):
- candidates = self.mem_cube.text_mem.graph_store.get_by_metadata(
- [
- {"field": "key", "op": "=", "value": key},
- {
- "field": "memory_type",
- "op": "=",
- "value": mem_item.metadata.memory_type,
- },
- ]
- )
- if candidates:
- exists = True
- original_item_id = candidates[0]
- # Crucial step: Fetch the original content for updates
- # This `get` is for the *existing* memory that will be updated
- original_mem_item = self.mem_cube.text_mem.get(
- memory_id=original_item_id, user_name=msg.mem_cube_id
- )
- original_content = original_mem_item.memory
-
- if exists:
- prepared_update_items_with_original.append(
- {
- "new_item": mem_item,
- "original_content": original_content,
- "original_item_id": original_item_id,
- }
- )
- else:
- prepared_add_items.append(mem_item)
-
- except Exception:
- missing_ids.append(memory_id)
- logger.debug(
- f"This MemoryItem {memory_id} has already been deleted or an error occurred during preparation."
- )
-
- if missing_ids:
- content_preview = (
- msg.content[:200] + "..."
- if isinstance(msg.content, str) and len(msg.content) > 200
- else msg.content
- )
- logger.warning(
- "Missing TextualMemoryItem(s) during add log preparation. "
- "memory_ids=%s user_id=%s mem_cube_id=%s task_id=%s item_id=%s redis_msg_id=%s label=%s stream_key=%s content_preview=%s",
- missing_ids,
- msg.user_id,
- msg.mem_cube_id,
- msg.task_id,
- msg.item_id,
- getattr(msg, "redis_message_id", ""),
- msg.label,
- getattr(msg, "stream_key", ""),
- content_preview,
- )
-
- if not prepared_add_items and not prepared_update_items_with_original:
- logger.warning(
- "No add/update items prepared; skipping addMemory/knowledgeBaseUpdate logs. "
- "user_id=%s mem_cube_id=%s task_id=%s item_id=%s redis_msg_id=%s label=%s stream_key=%s missing_ids=%s",
- msg.user_id,
- msg.mem_cube_id,
- msg.task_id,
- msg.item_id,
- getattr(msg, "redis_message_id", ""),
- msg.label,
- getattr(msg, "stream_key", ""),
- missing_ids,
- )
- return prepared_add_items, prepared_update_items_with_original
-
- def send_add_log_messages_to_local_env(
- self, msg: ScheduleMessageItem, prepared_add_items, prepared_update_items_with_original
- ):
- # Existing: Playground/Default Logging
- # Reconstruct add_content/add_meta/update_content/update_meta from prepared_items
- # This ensures existing logging path continues to work with pre-existing data structures
- add_content_legacy: list[dict] = []
- add_meta_legacy: list[dict] = []
- update_content_legacy: list[dict] = []
- update_meta_legacy: list[dict] = []
-
- for item in prepared_add_items:
- key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory)
- add_content_legacy.append({"content": f"{key}: {item.memory}", "ref_id": item.id})
- add_meta_legacy.append(
- {
- "ref_id": item.id,
- "id": item.id,
- "key": item.metadata.key,
- "memory": item.memory,
- "memory_type": item.metadata.memory_type,
- "status": item.metadata.status,
- "confidence": item.metadata.confidence,
- "tags": item.metadata.tags,
- "updated_at": getattr(item.metadata, "updated_at", None)
- or getattr(item.metadata, "update_at", None),
- }
- )
-
- for item_data in prepared_update_items_with_original:
- item = item_data["new_item"]
- key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory)
- update_content_legacy.append({"content": f"{key}: {item.memory}", "ref_id": item.id})
- update_meta_legacy.append(
- {
- "ref_id": item.id,
- "id": item.id,
- "key": item.metadata.key,
- "memory": item.memory,
- "memory_type": item.metadata.memory_type,
- "status": item.metadata.status,
- "confidence": item.metadata.confidence,
- "tags": item.metadata.tags,
- "updated_at": getattr(item.metadata, "updated_at", None)
- or getattr(item.metadata, "update_at", None),
- }
- )
-
- events = []
- if add_content_legacy:
- event = self.create_event_log(
- label="addMemory",
- from_memory_type=USER_INPUT_TYPE,
- to_memory_type=LONG_TERM_MEMORY_TYPE,
- user_id=msg.user_id,
- mem_cube_id=msg.mem_cube_id,
- mem_cube=self.mem_cube,
- memcube_log_content=add_content_legacy,
- metadata=add_meta_legacy,
- memory_len=len(add_content_legacy),
- memcube_name=self._map_memcube_name(msg.mem_cube_id),
- )
- event.task_id = msg.task_id
- events.append(event)
- if update_content_legacy:
- event = self.create_event_log(
- label="updateMemory",
- from_memory_type=LONG_TERM_MEMORY_TYPE,
- to_memory_type=LONG_TERM_MEMORY_TYPE,
- user_id=msg.user_id,
- mem_cube_id=msg.mem_cube_id,
- mem_cube=self.mem_cube,
- memcube_log_content=update_content_legacy,
- metadata=update_meta_legacy,
- memory_len=len(update_content_legacy),
- memcube_name=self._map_memcube_name(msg.mem_cube_id),
- )
- event.task_id = msg.task_id
- events.append(event)
- logger.info(f"send_add_log_messages_to_local_env: {len(events)}")
- if events:
- self._submit_web_logs(events, additional_log_info="send_add_log_messages_to_cloud_env")
-
- def send_add_log_messages_to_cloud_env(
- self, msg: ScheduleMessageItem, prepared_add_items, prepared_update_items_with_original
- ):
- """
- Cloud logging path for add/update events.
- """
- kb_log_content: list[dict] = []
- info = msg.info or {}
-
- # Process added items
- for item in prepared_add_items:
- metadata = getattr(item, "metadata", None)
- file_ids = getattr(metadata, "file_ids", None) if metadata else None
- source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None
- kb_log_content.append(
- {
- "log_source": "KNOWLEDGE_BASE_LOG",
- "trigger_source": info.get("trigger_source", "Messages"),
- "operation": "ADD",
- "memory_id": item.id,
- "content": item.memory,
- "original_content": None,
- "source_doc_id": source_doc_id,
- }
- )
-
- # Process updated items
- for item_data in prepared_update_items_with_original:
- item = item_data["new_item"]
- metadata = getattr(item, "metadata", None)
- file_ids = getattr(metadata, "file_ids", None) if metadata else None
- source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None
- kb_log_content.append(
- {
- "log_source": "KNOWLEDGE_BASE_LOG",
- "trigger_source": info.get("trigger_source", "Messages"),
- "operation": "UPDATE",
- "memory_id": item.id,
- "content": item.memory,
- "original_content": item_data.get("original_content"),
- "source_doc_id": source_doc_id,
- }
- )
-
- if kb_log_content:
- logger.info(
- f"[DIAGNOSTIC] general_scheduler.send_add_log_messages_to_cloud_env: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: {msg.user_id}, mem_cube_id: {msg.mem_cube_id}, task_id: {msg.task_id}. KB content: {json.dumps(kb_log_content, indent=2)}"
- )
- event = self.create_event_log(
- label="knowledgeBaseUpdate",
- from_memory_type=USER_INPUT_TYPE,
- to_memory_type=LONG_TERM_MEMORY_TYPE,
- user_id=msg.user_id,
- mem_cube_id=msg.mem_cube_id,
- mem_cube=self.mem_cube,
- memcube_log_content=kb_log_content,
- metadata=None,
- memory_len=len(kb_log_content),
- memcube_name=self._map_memcube_name(msg.mem_cube_id),
- )
- event.log_content = f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
- event.task_id = msg.task_id
- self._submit_web_logs([event])
-
- def _mem_feedback_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- try:
- if not messages:
- return
- message = messages[0]
- mem_cube = self.mem_cube
-
- user_id = message.user_id
- mem_cube_id = message.mem_cube_id
- content = message.content
-
- try:
- feedback_data = json.loads(content) if isinstance(content, str) else content
- if not isinstance(feedback_data, dict):
- logger.error(
- f"Failed to decode feedback_data or it is not a dict: {feedback_data}"
- )
- return
- except json.JSONDecodeError:
- logger.error(f"Invalid JSON content for feedback message: {content}", exc_info=True)
- return
-
- task_id = feedback_data.get("task_id") or message.task_id
- feedback_result = self.feedback_server.process_feedback(
- user_id=user_id,
- user_name=mem_cube_id,
- session_id=feedback_data.get("session_id"),
- chat_history=feedback_data.get("history", []),
- retrieved_memory_ids=feedback_data.get("retrieved_memory_ids", []),
- feedback_content=feedback_data.get("feedback_content"),
- feedback_time=feedback_data.get("feedback_time"),
- task_id=task_id,
- info=feedback_data.get("info", None),
- )
-
- logger.info(
- f"Successfully processed feedback for user_id={user_id}, mem_cube_id={mem_cube_id}"
- )
-
- cloud_env = is_cloud_env()
- if cloud_env:
- record = feedback_result.get("record") if isinstance(feedback_result, dict) else {}
- add_records = record.get("add") if isinstance(record, dict) else []
- update_records = record.get("update") if isinstance(record, dict) else []
-
- def _extract_fields(mem_item):
- mem_id = (
- getattr(mem_item, "id", None)
- if not isinstance(mem_item, dict)
- else mem_item.get("id")
- )
- mem_memory = (
- getattr(mem_item, "memory", None)
- if not isinstance(mem_item, dict)
- else mem_item.get("memory") or mem_item.get("text")
- )
- if mem_memory is None and isinstance(mem_item, dict):
- mem_memory = mem_item.get("text")
- original_content = (
- getattr(mem_item, "origin_memory", None)
- if not isinstance(mem_item, dict)
- else mem_item.get("origin_memory")
- or mem_item.get("old_memory")
- or mem_item.get("original_content")
- )
- source_doc_id = None
- if isinstance(mem_item, dict):
- source_doc_id = mem_item.get("source_doc_id", None)
-
- return mem_id, mem_memory, original_content, source_doc_id
-
- kb_log_content: list[dict] = []
-
- for mem_item in add_records or []:
- mem_id, mem_memory, _, source_doc_id = _extract_fields(mem_item)
- if mem_id and mem_memory:
- kb_log_content.append(
- {
- "log_source": "KNOWLEDGE_BASE_LOG",
- "trigger_source": "Feedback",
- "operation": "ADD",
- "memory_id": mem_id,
- "content": mem_memory,
- "original_content": None,
- "source_doc_id": source_doc_id,
- }
- )
- else:
- logger.warning(
- "Skipping malformed feedback add item. user_id=%s mem_cube_id=%s task_id=%s item=%s",
- user_id,
- mem_cube_id,
- task_id,
- mem_item,
- stack_info=True,
- )
-
- for mem_item in update_records or []:
- mem_id, mem_memory, original_content, source_doc_id = _extract_fields(mem_item)
- if mem_id and mem_memory:
- kb_log_content.append(
- {
- "log_source": "KNOWLEDGE_BASE_LOG",
- "trigger_source": "Feedback",
- "operation": "UPDATE",
- "memory_id": mem_id,
- "content": mem_memory,
- "original_content": original_content,
- "source_doc_id": source_doc_id,
- }
- )
- else:
- logger.warning(
- "Skipping malformed feedback update item. user_id=%s mem_cube_id=%s task_id=%s item=%s",
- user_id,
- mem_cube_id,
- task_id,
- mem_item,
- stack_info=True,
- )
-
- logger.info(f"[Feedback Scheduler] kb_log_content: {kb_log_content!s}")
- if kb_log_content:
- logger.info(
- "[DIAGNOSTIC] general_scheduler._mem_feedback_message_consumer: Creating knowledgeBaseUpdate event for feedback. user_id=%s mem_cube_id=%s task_id=%s items=%s",
- user_id,
- mem_cube_id,
- task_id,
- len(kb_log_content),
- )
- event = self.create_event_log(
- label="knowledgeBaseUpdate",
- from_memory_type=USER_INPUT_TYPE,
- to_memory_type=LONG_TERM_MEMORY_TYPE,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- memcube_log_content=kb_log_content,
- metadata=None,
- memory_len=len(kb_log_content),
- memcube_name=self._map_memcube_name(mem_cube_id),
- )
- event.log_content = (
- f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
- )
- event.task_id = task_id
- self._submit_web_logs([event])
- else:
- logger.warning(
- "No valid feedback content generated for web log. user_id=%s mem_cube_id=%s task_id=%s",
- user_id,
- mem_cube_id,
- task_id,
- stack_info=True,
- )
- else:
- logger.info(
- "Skipping web log for feedback. Not in a cloud environment (is_cloud_env=%s)",
- cloud_env,
- )
-
- except Exception as e:
- logger.error(f"Error processing feedbackMemory message: {e}", exc_info=True)
-
- def _mem_read_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- logger.info(
- f"[DIAGNOSTIC] general_scheduler._mem_read_message_consumer called. Received messages: {[msg.model_dump_json(indent=2) for msg in messages]}"
- )
- logger.info(f"Messages {messages} assigned to {MEM_READ_TASK_LABEL} handler.")
-
- def process_message(message: ScheduleMessageItem):
- try:
- user_id = message.user_id
- mem_cube_id = message.mem_cube_id
- mem_cube = self.mem_cube
- if mem_cube is None:
- logger.error(
- f"mem_cube is None for user_id={user_id}, mem_cube_id={mem_cube_id}, skipping processing",
- stack_info=True,
- )
- return
-
- content = message.content
- user_name = message.user_name
- info = message.info or {}
- chat_history = message.chat_history
-
- # Parse the memory IDs from content
- mem_ids = json.loads(content) if isinstance(content, str) else content
- if not mem_ids:
- return
-
- logger.info(
- f"Processing mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}, mem_ids={mem_ids}"
- )
-
- # Get the text memory from the mem_cube
- text_mem = mem_cube.text_mem
- if not isinstance(text_mem, TreeTextMemory):
- logger.error(f"Expected TreeTextMemory but got {type(text_mem).__name__}")
- return
-
- # Use mem_reader to process the memories
- self._process_memories_with_reader(
- mem_ids=mem_ids,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- text_mem=text_mem,
- user_name=user_name,
- custom_tags=info.get("custom_tags", None),
- task_id=message.task_id,
- info=info,
- chat_history=chat_history,
- )
-
- logger.info(
- f"Successfully processed mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}"
- )
-
- except Exception as e:
- logger.error(f"Error processing mem_read message: {e}", stack_info=True)
-
- with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor:
- futures = [executor.submit(process_message, msg) for msg in messages]
- for future in concurrent.futures.as_completed(futures):
- try:
- future.result()
- except Exception as e:
- logger.error(f"Thread task failed: {e}", stack_info=True)
-
- def _process_memories_with_reader(
- self,
- mem_ids: list[str],
- user_id: str,
- mem_cube_id: str,
- text_mem: TreeTextMemory,
- user_name: str,
- custom_tags: list[str] | None = None,
- task_id: str | None = None,
- info: dict | None = None,
- chat_history: list | None = None,
- ) -> None:
- logger.info(
- f"[DIAGNOSTIC] general_scheduler._process_memories_with_reader called. mem_ids: {mem_ids}, user_id: {user_id}, mem_cube_id: {mem_cube_id}, task_id: {task_id}"
- )
- """
- Process memories using mem_reader for enhanced memory processing.
-
- Args:
- mem_ids: List of memory IDs to process
- user_id: User ID
- mem_cube_id: Memory cube ID
- text_mem: Text memory instance
- custom_tags: Optional list of custom tags for memory processing
- """
- kb_log_content: list[dict] = []
- try:
- # Get the mem_reader from the parent MOSCore
- if not hasattr(self, "mem_reader") or self.mem_reader is None:
- logger.warning(
- "mem_reader not available in scheduler, skipping enhanced processing"
- )
- return
-
- # Get the original memory items
- memory_items = []
- for mem_id in mem_ids:
- try:
- memory_item = text_mem.get(mem_id, user_name=user_name)
- memory_items.append(memory_item)
- except Exception as e:
- logger.warning(
- f"[_process_memories_with_reader] Failed to get memory {mem_id}: {e}"
- )
- continue
-
- if not memory_items:
- logger.warning("No valid memory items found for processing")
- return
-
- # parse working_binding ids from the *original* memory_items (the raw items created in /add)
- # these still carry metadata.background with "[working_binding:...]" so we can know
- # which WorkingMemory clones should be cleaned up later.
- from memos.memories.textual.tree_text_memory.organize.manager import (
- extract_working_binding_ids,
- )
-
- bindings_to_delete = extract_working_binding_ids(memory_items)
- logger.info(
- f"Extracted {len(bindings_to_delete)} working_binding ids to cleanup: {list(bindings_to_delete)}"
- )
-
- # Use mem_reader to process the memories
- logger.info(f"Processing {len(memory_items)} memories with mem_reader")
-
- # Extract memories using mem_reader
- try:
- processed_memories = self.mem_reader.fine_transfer_simple_mem(
- memory_items,
- type="chat",
- custom_tags=custom_tags,
- user_name=user_name,
- chat_history=chat_history,
- )
- except Exception as e:
- logger.warning(f"{e}: Fail to transfer mem: {memory_items}")
- processed_memories = []
-
- if processed_memories and len(processed_memories) > 0:
- # Flatten the results (mem_reader returns list of lists)
- flattened_memories = []
- for memory_list in processed_memories:
- flattened_memories.extend(memory_list)
-
- logger.info(f"mem_reader processed {len(flattened_memories)} enhanced memories")
-
- # Add the enhanced memories back to the memory system
- if flattened_memories:
- enhanced_mem_ids = text_mem.add(flattened_memories, user_name=user_name)
- logger.info(
- f"Added {len(enhanced_mem_ids)} enhanced memories: {enhanced_mem_ids}"
- )
-
- # Mark merged_from memories as archived when provided in memory metadata
- if self.mem_reader.graph_db:
- for memory in flattened_memories:
- merged_from = (memory.metadata.info or {}).get("merged_from")
- if merged_from:
- old_ids = (
- merged_from
- if isinstance(merged_from, (list | tuple | set))
- else [merged_from]
- )
- for old_id in old_ids:
- try:
- self.mem_reader.graph_db.update_node(
- str(old_id), {"status": "archived"}, user_name=user_name
- )
- logger.info(
- f"[Scheduler] Archived merged_from memory: {old_id}"
- )
- except Exception as e:
- logger.warning(
- f"[Scheduler] Failed to archive merged_from memory {old_id}: {e}"
- )
- else:
- # Check if any memory has merged_from but graph_db is unavailable
- has_merged_from = any(
- (m.metadata.info or {}).get("merged_from") for m in flattened_memories
- )
- if has_merged_from:
- logger.warning(
- "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving."
- )
-
- # LOGGING BLOCK START
- # This block is replicated from _add_message_consumer to ensure consistent logging
- cloud_env = is_cloud_env()
- if cloud_env:
- # New: Knowledge Base Logging (Cloud Service)
- kb_log_content = []
- for item in flattened_memories:
- metadata = getattr(item, "metadata", None)
- file_ids = getattr(metadata, "file_ids", None) if metadata else None
- source_doc_id = (
- file_ids[0] if isinstance(file_ids, list) and file_ids else None
- )
- kb_log_content.append(
- {
- "log_source": "KNOWLEDGE_BASE_LOG",
- "trigger_source": info.get("trigger_source", "Messages")
- if info
- else "Messages",
- "operation": "ADD",
- "memory_id": item.id,
- "content": item.memory,
- "original_content": None,
- "source_doc_id": source_doc_id,
- }
- )
- if kb_log_content:
- logger.info(
- f"[DIAGNOSTIC] general_scheduler._process_memories_with_reader: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: {user_id}, mem_cube_id: {mem_cube_id}, task_id: {task_id}. KB content: {json.dumps(kb_log_content, indent=2)}"
- )
- event = self.create_event_log(
- label="knowledgeBaseUpdate",
- from_memory_type=USER_INPUT_TYPE,
- to_memory_type=LONG_TERM_MEMORY_TYPE,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=self.mem_cube,
- memcube_log_content=kb_log_content,
- metadata=None,
- memory_len=len(kb_log_content),
- memcube_name=self._map_memcube_name(mem_cube_id),
- )
- event.log_content = (
- f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
- )
- event.task_id = task_id
- self._submit_web_logs([event])
- else:
- # Existing: Playground/Default Logging
- add_content_legacy: list[dict] = []
- add_meta_legacy: list[dict] = []
- for item_id, item in zip(
- enhanced_mem_ids, flattened_memories, strict=False
- ):
- key = getattr(item.metadata, "key", None) or transform_name_to_key(
- name=item.memory
- )
- add_content_legacy.append(
- {"content": f"{key}: {item.memory}", "ref_id": item_id}
- )
- add_meta_legacy.append(
- {
- "ref_id": item_id,
- "id": item_id,
- "key": item.metadata.key,
- "memory": item.memory,
- "memory_type": item.metadata.memory_type,
- "status": item.metadata.status,
- "confidence": item.metadata.confidence,
- "tags": item.metadata.tags,
- "updated_at": getattr(item.metadata, "updated_at", None)
- or getattr(item.metadata, "update_at", None),
- }
- )
- if add_content_legacy:
- event = self.create_event_log(
- label="addMemory",
- from_memory_type=USER_INPUT_TYPE,
- to_memory_type=LONG_TERM_MEMORY_TYPE,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=self.mem_cube,
- memcube_log_content=add_content_legacy,
- metadata=add_meta_legacy,
- memory_len=len(add_content_legacy),
- memcube_name=self._map_memcube_name(mem_cube_id),
- )
- event.task_id = task_id
- self._submit_web_logs([event])
- # LOGGING BLOCK END
- else:
- logger.info("No enhanced memories generated by mem_reader")
- else:
- logger.info("mem_reader returned no processed memories")
-
- # build full delete list:
- # - original raw mem_ids (temporary fast memories)
- # - any bound working memories referenced by the enhanced memories
- delete_ids = list(mem_ids)
- if bindings_to_delete:
- delete_ids.extend(list(bindings_to_delete))
- # deduplicate
- delete_ids = list(dict.fromkeys(delete_ids))
- if delete_ids:
- try:
- text_mem.delete(delete_ids, user_name=user_name)
- logger.info(
- f"Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}"
- )
- except Exception as e:
- logger.warning(f"Failed to delete some mem_ids {delete_ids}: {e}")
- else:
- logger.info("No mem_ids to delete (nothing to cleanup)")
-
- text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name)
- logger.info("Remove and Refresh Memories")
- logger.debug(f"Finished add {user_id} memory: {mem_ids}")
-
- except Exception as exc:
- logger.error(
- f"Error in _process_memories_with_reader: {traceback.format_exc()}", exc_info=True
- )
- with contextlib.suppress(Exception):
- cloud_env = is_cloud_env()
- if cloud_env:
- if not kb_log_content:
- trigger_source = (
- info.get("trigger_source", "Messages") if info else "Messages"
- )
- kb_log_content = [
- {
- "log_source": "KNOWLEDGE_BASE_LOG",
- "trigger_source": trigger_source,
- "operation": "ADD",
- "memory_id": mem_id,
- "content": None,
- "original_content": None,
- "source_doc_id": None,
- }
- for mem_id in mem_ids
- ]
- event = self.create_event_log(
- label="knowledgeBaseUpdate",
- from_memory_type=USER_INPUT_TYPE,
- to_memory_type=LONG_TERM_MEMORY_TYPE,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=self.mem_cube,
- memcube_log_content=kb_log_content,
- metadata=None,
- memory_len=len(kb_log_content),
- memcube_name=self._map_memcube_name(mem_cube_id),
- )
- event.log_content = f"Knowledge Base Memory Update failed: {exc!s}"
- event.task_id = task_id
- event.status = "failed"
- self._submit_web_logs([event])
-
- def _mem_reorganize_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- logger.info(f"Messages {messages} assigned to {MEM_ORGANIZE_TASK_LABEL} handler.")
-
- def process_message(message: ScheduleMessageItem):
- try:
- user_id = message.user_id
- mem_cube_id = message.mem_cube_id
- mem_cube = self.mem_cube
- if mem_cube is None:
- logger.warning(
- f"mem_cube is None for user_id={user_id}, mem_cube_id={mem_cube_id}, skipping processing"
- )
- return
- content = message.content
- user_name = message.user_name
-
- # Parse the memory IDs from content
- mem_ids = json.loads(content) if isinstance(content, str) else content
- if not mem_ids:
- return
-
- logger.info(
- f"Processing mem_reorganize for user_id={user_id}, mem_cube_id={mem_cube_id}, mem_ids={mem_ids}"
- )
-
- # Get the text memory from the mem_cube
- text_mem = mem_cube.text_mem
- if not isinstance(text_mem, TreeTextMemory):
- logger.error(f"Expected TreeTextMemory but got {type(text_mem).__name__}")
- return
-
- # Use mem_reader to process the memories
- self._process_memories_with_reorganize(
- mem_ids=mem_ids,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- text_mem=text_mem,
- user_name=user_name,
- )
-
- with contextlib.suppress(Exception):
- mem_items: list[TextualMemoryItem] = []
- for mid in mem_ids:
- with contextlib.suppress(Exception):
- mem_items.append(text_mem.get(mid, user_name=user_name))
- if len(mem_items) > 1:
- keys: list[str] = []
- memcube_content: list[dict] = []
- meta: list[dict] = []
- merged_target_ids: set[str] = set()
- with contextlib.suppress(Exception):
- if hasattr(text_mem, "graph_store"):
- for mid in mem_ids:
- edges = text_mem.graph_store.get_edges(
- mid, type="MERGED_TO", direction="OUT"
- )
- for edge in edges:
- target = (
- edge.get("to") or edge.get("dst") or edge.get("target")
- )
- if target:
- merged_target_ids.add(target)
- for item in mem_items:
- key = getattr(
- getattr(item, "metadata", {}), "key", None
- ) or transform_name_to_key(getattr(item, "memory", ""))
- keys.append(key)
- memcube_content.append(
- {"content": key or "(no key)", "ref_id": item.id, "type": "merged"}
- )
- meta.append(
- {
- "ref_id": item.id,
- "id": item.id,
- "key": key,
- "memory": item.memory,
- "memory_type": item.metadata.memory_type,
- "status": item.metadata.status,
- "confidence": item.metadata.confidence,
- "tags": item.metadata.tags,
- "updated_at": getattr(item.metadata, "updated_at", None)
- or getattr(item.metadata, "update_at", None),
- }
- )
- combined_key = keys[0] if keys else ""
- post_ref_id = None
- post_meta = {
- "ref_id": None,
- "id": None,
- "key": None,
- "memory": None,
- "memory_type": None,
- "status": None,
- "confidence": None,
- "tags": None,
- "updated_at": None,
- }
- if merged_target_ids:
- post_ref_id = next(iter(merged_target_ids))
- with contextlib.suppress(Exception):
- merged_item = text_mem.get(post_ref_id, user_name=user_name)
- combined_key = (
- getattr(getattr(merged_item, "metadata", {}), "key", None)
- or combined_key
- )
- post_meta = {
- "ref_id": post_ref_id,
- "id": post_ref_id,
- "key": getattr(
- getattr(merged_item, "metadata", {}), "key", None
- ),
- "memory": getattr(merged_item, "memory", None),
- "memory_type": getattr(
- getattr(merged_item, "metadata", {}), "memory_type", None
- ),
- "status": getattr(
- getattr(merged_item, "metadata", {}), "status", None
- ),
- "confidence": getattr(
- getattr(merged_item, "metadata", {}), "confidence", None
- ),
- "tags": getattr(
- getattr(merged_item, "metadata", {}), "tags", None
- ),
- "updated_at": getattr(
- getattr(merged_item, "metadata", {}), "updated_at", None
- )
- or getattr(
- getattr(merged_item, "metadata", {}), "update_at", None
- ),
- }
- if not post_ref_id:
- import hashlib
-
- post_ref_id = f"merge-{hashlib.md5(''.join(sorted(mem_ids)).encode()).hexdigest()}"
- post_meta["ref_id"] = post_ref_id
- post_meta["id"] = post_ref_id
- if not post_meta.get("key"):
- post_meta["key"] = combined_key
- if not keys:
- keys = [item.id for item in mem_items]
- memcube_content.append(
- {
- "content": combined_key if combined_key else "(no key)",
- "ref_id": post_ref_id,
- "type": "postMerge",
- }
- )
- meta.append(post_meta)
- event = self.create_event_log(
- label="mergeMemory",
- from_memory_type=LONG_TERM_MEMORY_TYPE,
- to_memory_type=LONG_TERM_MEMORY_TYPE,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- memcube_log_content=memcube_content,
- metadata=meta,
- memory_len=len(keys),
- memcube_name=self._map_memcube_name(mem_cube_id),
- )
- self._submit_web_logs([event])
-
- logger.info(
- f"Successfully processed mem_reorganize for user_id={user_id}, mem_cube_id={mem_cube_id}"
- )
-
- except Exception as e:
- logger.error(f"Error processing mem_reorganize message: {e}", exc_info=True)
-
- with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor:
- futures = [executor.submit(process_message, msg) for msg in messages]
- for future in concurrent.futures.as_completed(futures):
- try:
- future.result()
- except Exception as e:
- logger.error(f"Thread task failed: {e}", exc_info=True)
-
- def _process_memories_with_reorganize(
- self,
- mem_ids: list[str],
- user_id: str,
- mem_cube_id: str,
- mem_cube: GeneralMemCube,
- text_mem: TreeTextMemory,
- user_name: str,
- ) -> None:
- """
- Process memories using mem_reorganize for enhanced memory processing.
-
- Args:
- mem_ids: List of memory IDs to process
- user_id: User ID
- mem_cube_id: Memory cube ID
- mem_cube: Memory cube instance
- text_mem: Text memory instance
- """
- try:
- # Get the mem_reader from the parent MOSCore
- if not hasattr(self, "mem_reader") or self.mem_reader is None:
- logger.warning(
- "mem_reader not available in scheduler, skipping enhanced processing"
- )
- return
-
- # Get the original memory items
- memory_items = []
- for mem_id in mem_ids:
- try:
- memory_item = text_mem.get(mem_id, user_name=user_name)
- memory_items.append(memory_item)
- except Exception as e:
- logger.warning(f"Failed to get memory {mem_id}: {e}|{traceback.format_exc()}")
- continue
-
- if not memory_items:
- logger.warning("No valid memory items found for processing")
- return
-
- # Use mem_reader to process the memories
- logger.info(f"Processing {len(memory_items)} memories with mem_reader")
- text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name)
- logger.info("Remove and Refresh Memories")
- logger.debug(f"Finished add {user_id} memory: {mem_ids}")
-
- except Exception:
- logger.error(
- f"Error in _process_memories_with_reorganize: {traceback.format_exc()}",
- exc_info=True,
- )
-
- def _pref_add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- logger.info(f"Messages {messages} assigned to {PREF_ADD_TASK_LABEL} handler.")
-
- def process_message(message: ScheduleMessageItem):
- try:
- mem_cube = self.mem_cube
- if mem_cube is None:
- logger.warning(
- f"mem_cube is None for user_id={message.user_id}, mem_cube_id={message.mem_cube_id}, skipping processing"
- )
- return
-
- user_id = message.user_id
- session_id = message.session_id
- mem_cube_id = message.mem_cube_id
- content = message.content
- messages_list = json.loads(content)
- info = message.info or {}
-
- logger.info(f"Processing pref_add for user_id={user_id}, mem_cube_id={mem_cube_id}")
-
- # Get the preference memory from the mem_cube
- pref_mem = mem_cube.pref_mem
- if pref_mem is None:
- logger.warning(
- f"Preference memory not initialized for mem_cube_id={mem_cube_id}, "
- f"skipping pref_add processing"
- )
- return
- if not isinstance(pref_mem, PreferenceTextMemory):
- logger.error(
- f"Expected PreferenceTextMemory but got {type(pref_mem).__name__} "
- f"for mem_cube_id={mem_cube_id}"
- )
- return
-
- # Use pref_mem.get_memory to process the memories
- pref_memories = pref_mem.get_memory(
- messages_list,
- type="chat",
- info={
- **info,
- "user_id": user_id,
- "session_id": session_id,
- "mem_cube_id": mem_cube_id,
- },
- )
- # Add pref_mem to vector db
- pref_ids = pref_mem.add(pref_memories)
-
- logger.info(
- f"Successfully processed and add preferences for user_id={user_id}, mem_cube_id={mem_cube_id}, pref_ids={pref_ids}"
- )
-
- except Exception as e:
- logger.error(f"Error processing pref_add message: {e}", exc_info=True)
-
- with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor:
- futures = [executor.submit(process_message, msg) for msg in messages]
- for future in concurrent.futures.as_completed(futures):
- try:
- future.result()
- except Exception as e:
- logger.error(f"Thread task failed: {e}", exc_info=True)
-
- def process_session_turn(
- self,
- queries: str | list[str],
- user_id: UserID | str,
- mem_cube_id: MemCubeID | str,
- mem_cube: GeneralMemCube,
- top_k: int = 10,
- ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]] | None:
- """
- Process a dialog turn:
- - If q_list reaches window size, trigger retrieval;
- - Immediately switch to the new memory if retrieval is triggered.
- """
-
- text_mem_base = mem_cube.text_mem
- if not isinstance(text_mem_base, TreeTextMemory):
- if isinstance(text_mem_base, NaiveTextMemory):
- logger.debug(
- f"NaiveTextMemory used for mem_cube_id={mem_cube_id}, processing session turn with simple search."
- )
- # Treat NaiveTextMemory similar to TreeTextMemory but with simpler logic
- # We will perform retrieval to get "working memory" candidates for activation memory
- # But we won't have a distinct "current working memory"
- cur_working_memory = []
- else:
- logger.warning(
- f"Not implemented! Expected TreeTextMemory but got {type(text_mem_base).__name__} "
- f"for mem_cube_id={mem_cube_id}, user_id={user_id}. "
- f"text_mem_base value: {text_mem_base}"
- )
- return [], []
- else:
- cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory(
- user_name=mem_cube_id
- )
- cur_working_memory = cur_working_memory[:top_k]
-
- logger.info(
- f"[process_session_turn] Processing {len(queries)} queries for user_id={user_id}, mem_cube_id={mem_cube_id}"
- )
-
- text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
- intent_result = self.monitor.detect_intent(
- q_list=queries, text_working_memory=text_working_memory
- )
-
- time_trigger_flag = False
- if self.monitor.timed_trigger(
- last_time=self.monitor.last_query_consume_time,
- interval_seconds=self.monitor.query_trigger_interval,
- ):
- time_trigger_flag = True
-
- if (not intent_result["trigger_retrieval"]) and (not time_trigger_flag):
- logger.info(
- f"[process_session_turn] Query schedule not triggered for user_id={user_id}, mem_cube_id={mem_cube_id}. Intent_result: {intent_result}"
- )
- return
- elif (not intent_result["trigger_retrieval"]) and time_trigger_flag:
- logger.info(
- f"[process_session_turn] Query schedule forced to trigger due to time ticker for user_id={user_id}, mem_cube_id={mem_cube_id}"
- )
- intent_result["trigger_retrieval"] = True
- intent_result["missing_evidences"] = queries
- else:
- logger.info(
- f"[process_session_turn] Query schedule triggered for user_id={user_id}, mem_cube_id={mem_cube_id}. "
- f"Missing evidences: {intent_result['missing_evidences']}"
- )
-
- missing_evidences = intent_result["missing_evidences"]
- num_evidence = len(missing_evidences)
- k_per_evidence = max(1, top_k // max(1, num_evidence))
- new_candidates = []
- for item in missing_evidences:
- logger.info(
- f"[process_session_turn] Searching for missing evidence: '{item}' with top_k={k_per_evidence} for user_id={user_id}"
- )
-
- search_args = {}
- if isinstance(text_mem_base, NaiveTextMemory):
- # NaiveTextMemory doesn't support complex search args usually, but let's see
- # self.retriever.search calls mem_cube.text_mem.search
- # NaiveTextMemory.search takes query and top_k
- # SchedulerRetriever.search handles method dispatch
- # For NaiveTextMemory, we might need to bypass retriever or extend it
- # But let's try calling naive memory directly if retriever fails or doesn't support it
- try:
- results = text_mem_base.search(query=item, top_k=k_per_evidence)
- except Exception as e:
- logger.warning(f"NaiveTextMemory search failed: {e}")
- results = []
- else:
- results: list[TextualMemoryItem] = self.retriever.search(
- query=item,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- top_k=k_per_evidence,
- method=self.search_method,
- search_args=search_args,
- )
-
- logger.info(
- f"[process_session_turn] Search results for missing evidence '{item}': "
- + ("\n- " + "\n- ".join([f"{one.id}: {one.memory}" for one in results]))
- )
- new_candidates.extend(results)
- return cur_working_memory, new_candidates
+ services = SchedulerHandlerServices(
+ validate_messages=self.validate_schedule_messages,
+ submit_messages=self.submit_messages,
+ create_event_log=self.create_event_log,
+ submit_web_logs=self._submit_web_logs,
+ map_memcube_name=self._map_memcube_name,
+ update_activation_memory_periodically=self.update_activation_memory_periodically,
+ replace_working_memory=self.replace_working_memory,
+ transform_working_memories_to_monitors=self.transform_working_memories_to_monitors,
+ log_working_memory_replacement=self.log_working_memory_replacement,
+ )
+ scheduler_context = SchedulerHandlerContext(
+ get_mem_cube=lambda: self.mem_cube,
+ get_monitor=lambda: self.monitor,
+ get_retriever=lambda: self.retriever,
+ get_mem_reader=lambda: self.mem_reader,
+ get_feedback_server=lambda: self.feedback_server,
+ get_search_method=lambda: self.search_method,
+ get_top_k=lambda: self.top_k,
+ get_enable_activation_memory=lambda: self.enable_activation_memory,
+ get_query_key_words_limit=lambda: self.query_key_words_limit,
+ services=services,
+ )
+
+ self._handler_registry = SchedulerHandlerRegistry(scheduler_context)
+ self.register_handlers(self._handler_registry.build_dispatch_map())
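+ # NOTE: the handler context wraps attribute access in lambdas so handlers
+ # always read the scheduler's current mem_cube/monitor/retriever rather
+ # than values captured once at construction time.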
diff --git a/src/memos/mem_scheduler/memory_manage_modules/activation_memory_manager.py b/src/memos/mem_scheduler/memory_manage_modules/activation_memory_manager.py
new file mode 100644
index 000000000..589d0e421
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/activation_memory_manager.py
@@ -0,0 +1,186 @@
+from collections.abc import Callable
+from datetime import datetime
+
+from memos.log import get_logger
+from memos.mem_cube.general import GeneralMemCube
+from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor
+from memos.mem_scheduler.utils.db_utils import get_utc_now
+from memos.memories.activation.kv import KVCacheMemory
+from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory
+from memos.memories.textual.tree import TextualMemoryItem
+from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE
+from memos.types.general_types import MemCubeID, UserID
+
+
+logger = get_logger(__name__)
+
+
+class ActivationMemoryManager:
+ def __init__(
+ self,
+ act_mem_dump_path: str,
+ monitor: SchedulerGeneralMonitor,
+ log_func_callback: Callable,
+ log_activation_memory_update_func: Callable,
+ ):
+ self.act_mem_dump_path = act_mem_dump_path
+ self.monitor = monitor
+ self.log_func_callback = log_func_callback
+ self.log_activation_memory_update_func = log_activation_memory_update_func
+
+ def update_activation_memory(
+ self,
+ new_memories: list[str | TextualMemoryItem],
+ label: str,
+ user_id: UserID | str,
+ mem_cube_id: MemCubeID | str,
+ mem_cube: GeneralMemCube,
+ ) -> None:
+ """
+ Update activation memory by extracting KVCacheItems from new_memory (list of str),
+ add them to a KVCacheMemory instance, and dump to disk.
+ """
+ if len(new_memories) == 0:
+ logger.error("update_activation_memory: new_memory is empty.")
+ return
+ if isinstance(new_memories[0], TextualMemoryItem):
+ new_text_memories = [mem.memory for mem in new_memories]
+ elif isinstance(new_memories[0], str):
+ new_text_memories = new_memories
+ else:
+ logger.error("Not Implemented.")
+ return
+
+ try:
+ if isinstance(mem_cube.act_mem, VLLMKVCacheMemory):
+ act_mem: VLLMKVCacheMemory = mem_cube.act_mem
+ elif isinstance(mem_cube.act_mem, KVCacheMemory):
+ act_mem: KVCacheMemory = mem_cube.act_mem
+ else:
+ logger.error("Not Implemented.")
+ return
+
+ new_text_memory = MEMORY_ASSEMBLY_TEMPLATE.format(
+ memory_text="".join(
+ [
+ f"{i + 1}. {sentence.strip()}\n"
+ for i, sentence in enumerate(new_text_memories)
+ if sentence.strip() # Skip empty strings
+ ]
+ )
+ )
+
+ # huggingface or vllm kv cache
+ original_cache_items: list[VLLMKVCacheItem] = act_mem.get_all()
+ original_text_memories = []
+ if len(original_cache_items) > 0:
+ pre_cache_item: VLLMKVCacheItem = original_cache_items[-1]
+ original_text_memories = pre_cache_item.records.text_memories
+ original_composed_text_memory = pre_cache_item.records.composed_text_memory
+ if original_composed_text_memory == new_text_memory:
+ logger.warning(
+ "Skipping memory update - new composition matches existing cache: %s",
+ new_text_memory[:50] + "..."
+ if len(new_text_memory) > 50
+ else new_text_memory,
+ )
+ return
+ act_mem.delete_all()
+
+ cache_item = act_mem.extract(new_text_memory)
+ cache_item.records.text_memories = new_text_memories
+ cache_item.records.timestamp = get_utc_now()
+
+ act_mem.add([cache_item])
+ act_mem.dump(self.act_mem_dump_path)
+
+ self.log_activation_memory_update_func(
+ original_text_memories=original_text_memories,
+ new_text_memories=new_text_memories,
+ label=label,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ log_func_callback=self.log_func_callback,
+ )
+
+ except Exception as e:
+ logger.error(f"MOS-based activation memory update failed: {e}", exc_info=True)
+ # Errors are deliberately swallowed so scheduling continues; revisit this
+ # if activation memory updates ever become critical to the calling flow.
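+
+ # Minimal usage sketch (the `scheduler`/`mem_cube` objects and the name of
+ # the `log_activation_memory_update` callback are assumptions for
+ # illustration, not part of this module):
+ #
+ # manager = ActivationMemoryManager(
+ # act_mem_dump_path="/tmp/act_mem",
+ # monitor=scheduler.monitor,
+ # log_func_callback=scheduler._submit_web_logs,
+ # log_activation_memory_update_func=scheduler.log_activation_memory_update,
+ # )
+ # manager.update_activation_memory(
+ # new_memories=["User prefers dark mode."],
+ # label="addMemory",
+ # user_id="u1",
+ # mem_cube_id="cube1",
+ # mem_cube=mem_cube,
+ # )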
+
+ def update_activation_memory_periodically(
+ self,
+ interval_seconds: int,
+ label: str,
+ user_id: UserID | str,
+ mem_cube_id: MemCubeID | str,
+ mem_cube: GeneralMemCube,
+ ):
+ try:
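+ # datetime.min is the "never updated" sentinel, so the very first call
+ # triggers an update regardless of the configured interval.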
+ if (
+ self.monitor.last_activation_mem_update_time == datetime.min
+ or self.monitor.timed_trigger(
+ last_time=self.monitor.last_activation_mem_update_time,
+ interval_seconds=interval_seconds,
+ )
+ ):
+ logger.info(
+ f"Updating activation memory for user {user_id} and mem_cube {mem_cube_id}"
+ )
+
+ if (
+ user_id not in self.monitor.working_memory_monitors
+ or mem_cube_id not in self.monitor.working_memory_monitors[user_id]
+ or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].obj.memories)
+ == 0
+ ):
+ logger.warning(
+ "No memories found in working_memory_monitors, activation memory update is skipped"
+ )
+ return
+
+ self.monitor.update_activation_memory_monitors(
+ user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube
+ )
+
+ # Sync with database to get latest activation memories
+ activation_db_manager = self.monitor.activation_memory_monitors[user_id][
+ mem_cube_id
+ ]
+ activation_db_manager.sync_with_orm()
+ new_activation_memories = [
+ m.memory_text for m in activation_db_manager.obj.memories
+ ]
+
+ logger.info(
+ f"Collected {len(new_activation_memories)} new memory entries for processing"
+ )
+ # Log a short preview of the first few new activation memories
+ for i, memory in enumerate(new_activation_memories[:5], 1):
+ logger.info(
+ f"Part of New Activation Memories | {i}/{len(new_activation_memories)}: {memory[:20]}"
+ )
+
+ self.update_activation_memory(
+ new_memories=new_activation_memories,
+ label=label,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ )
+
+ self.monitor.last_activation_mem_update_time = get_utc_now()
+
+ logger.debug(
+ f"Activation memory update completed at {self.monitor.last_activation_mem_update_time}"
+ )
+
+ else:
+ logger.info(
+ f"Skipping update - {interval_seconds} second interval not yet reached. "
+ f"Last update time is {self.monitor.last_activation_mem_update_time} and now is "
+ f"{get_utc_now()}"
+ )
+ except Exception as e:
+ logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True)
diff --git a/src/memos/mem_scheduler/memory_manage_modules/enhancement_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/enhancement_pipeline.py
new file mode 100644
index 000000000..98125c13b
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/enhancement_pipeline.py
@@ -0,0 +1,282 @@
+from __future__ import annotations
+
+import time
+
+from typing import TYPE_CHECKING
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.general_schemas import (
+ DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
+ DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
+)
+from memos.mem_scheduler.utils.misc_utils import extract_json_obj, extract_list_items_in_answer
+from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata
+from memos.types.general_types import FINE_STRATEGY, FineStrategy
+
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
+
+class EnhancementPipeline:
+ def __init__(self, process_llm, config, build_prompt: Callable[..., str]):
+ self.process_llm = process_llm
+ self.config = config
+ self.build_prompt = build_prompt
+ self.batch_size: int | None = getattr(
+ config, "scheduler_retriever_batch_size", DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE
+ )
+ self.retries: int = getattr(
+ config, "scheduler_retriever_enhance_retries", DEFAULT_SCHEDULER_RETRIEVER_RETRIES
+ )
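+
+ # Minimal usage sketch (the `llm` and `config` objects are assumptions for
+ # illustration; `build_prompt` is the template renderer passed in above):
+ #
+ # pipeline = EnhancementPipeline(process_llm=llm, config=config, build_prompt=build_prompt)
+ # enhanced, ok = pipeline.enhance_memories_with_query(
+ # query_history=["Where does the user work?"],
+ # memories=retrieved_items,
+ # )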
+
+ def evaluate_memory_answer_ability(
+ self, query: str, memory_texts: list[str], top_k: int | None = None
+ ) -> bool:
+ limited_memories = memory_texts[:top_k] if top_k is not None else memory_texts
+ prompt = self.build_prompt(
+ template_name="memory_answer_ability_evaluation",
+ query=query,
+ memory_list="\n".join([f"- {memory}" for memory in limited_memories])
+ if limited_memories
+ else "No memories available",
+ )
+
+ response = self.process_llm.generate([{"role": "user", "content": prompt}])
+
+ try:
+ result = extract_json_obj(response)
+
+ if "result" in result:
+ logger.info(
+ "Answerability: result=%s; reason=%s; evaluated=%s",
+ result["result"],
+ result.get("reason", "n/a"),
+ len(limited_memories),
+ )
+ return result["result"]
+ logger.warning("Answerability: invalid LLM JSON structure; payload=%s", result)
+ return False
+
+ except Exception as e:
+ logger.error("Answerability: parse failed; err=%s; raw=%s...", e, str(response)[:200])
+ return False
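+
+ # The LLM is expected to return JSON shaped like (values illustrative):
+ # {"result": true, "reason": "memories state the user's employer"}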
+
+ def _build_enhancement_prompt(self, query_history: list[str], batch_texts: list[str]) -> str:
+ if len(query_history) == 1:
+ query_history = query_history[0]
+ else:
+ query_history = [f"[{i}] {query}" for i, query in enumerate(query_history)]
+ if FINE_STRATEGY == FineStrategy.REWRITE:
+ text_memories = "\n".join([f"- [{i}] {mem}" for i, mem in enumerate(batch_texts)])
+ prompt_name = "memory_rewrite_enhancement"
+ else:
+ text_memories = "\n".join([f"- {mem}" for mem in batch_texts])
+ prompt_name = "memory_recreate_enhancement"
+ return self.build_prompt(
+ prompt_name,
+ query_history=query_history,
+ memories=text_memories,
+ )
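+
+ # For example, query_history=["Where do I work?", "Since when?"] becomes
+ # ["[0] Where do I work?", "[1] Since when?"], and under the REWRITE strategy
+ # memories are index-tagged as "- [0] ..." so rewritten lines can later be
+ # mapped back to their originals.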
+
+ def _process_enhancement_batch(
+ self,
+ batch_index: int,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ retries: int,
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ attempt = 0
+ text_memories = [one.memory for one in memories]
+
+ prompt = self._build_enhancement_prompt(
+ query_history=query_history, batch_texts=text_memories
+ )
+
+ llm_response = None
+ while attempt <= max(0, retries) + 1:
+ try:
+ llm_response = self.process_llm.generate([{"role": "user", "content": prompt}])
+ processed_text_memories = extract_list_items_in_answer(llm_response)
+ if len(processed_text_memories) > 0:
+ enhanced_memories = []
+ user_id = memories[0].metadata.user_id
+ if FINE_STRATEGY == FineStrategy.RECREATE:
+ for new_mem in processed_text_memories:
+ enhanced_memories.append(
+ TextualMemoryItem(
+ memory=new_mem,
+ metadata=TextualMemoryMetadata(
+ user_id=user_id, memory_type="LongTermMemory"
+ ),
+ )
+ )
+ elif FINE_STRATEGY == FineStrategy.REWRITE:
+
+ def _parse_index_and_text(s: str) -> tuple[int | None, str]:
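+ # Accepts "[3] text", "3: text", "3- text", or "3) text";
+ # falls back to (None, text) when no leading index is found.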
+ import re
+
+ s = (s or "").strip()
+ m = re.match(r"^\s*\[(\d+)\]\s*(.+)$", s)
+ if m:
+ return int(m.group(1)), m.group(2).strip()
+ m = re.match(r"^\s*(\d+)\s*[:\-\)]\s*(.+)$", s)
+ if m:
+ return int(m.group(1)), m.group(2).strip()
+ return None, s
+
+ idx_to_original = dict(enumerate(memories))
+ for j, item in enumerate(processed_text_memories):
+ idx, new_text = _parse_index_and_text(item)
+ if idx is not None and idx in idx_to_original:
+ orig = idx_to_original[idx]
+ else:
+ orig = memories[j] if j < len(memories) else None
+ if not orig:
+ continue
+ enhanced_memories.append(
+ TextualMemoryItem(
+ id=orig.id,
+ memory=new_text,
+ metadata=orig.metadata,
+ )
+ )
+ else:
+ logger.error("Fine search strategy %s not exists", FINE_STRATEGY)
+
+ logger.info(
+ "[enhance_memories_with_query] done | Strategy=%s | prompt=%s | llm_response=%s",
+ FINE_STRATEGY,
+ prompt,
+ llm_response,
+ )
+ return enhanced_memories, True
+ raise ValueError(
+ "Fail to run memory enhancement; retry "
+ f"{attempt}/{max(1, retries) + 1}; "
+ f"processed_text_memories: {processed_text_memories}"
+ )
+ except Exception as e:
+ attempt += 1
+ time.sleep(1)
+ logger.debug(
+ "[enhance_memories_with_query][batch=%s] retry %s/%s failed: %s",
+ batch_index,
+ attempt,
+ max(1, retries) + 1,
+ e,
+ )
+ logger.error(
+ "Fail to run memory enhancement; prompt: %s;\n llm_response: %s",
+ prompt,
+ llm_response,
+ exc_info=True,
+ )
+ return memories, False
+
+ @staticmethod
+ def _split_batches(
+ memories: list[TextualMemoryItem], batch_size: int
+ ) -> list[tuple[int, int, list[TextualMemoryItem]]]:
+ batches: list[tuple[int, int, list[TextualMemoryItem]]] = []
+ start = 0
+ n = len(memories)
+ while start < n:
+ end = min(start + batch_size, n)
+ batches.append((start, end, memories[start:end]))
+ start = end
+ return batches
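+
+ # e.g. 5 memories with batch_size=2 -> [(0, 2, m[0:2]), (2, 4, m[2:4]), (4, 5, m[4:5])]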
+
+ def recall_for_missing_memories(self, query: str, memories: list[str]) -> tuple[str, bool]:
+ text_memories = "\n".join([f"- {mem}" for mem in memories])
+
+ prompt = self.build_prompt(
+ template_name="enlarge_recall",
+ query=query,
+ memories_inline=text_memories,
+ )
+ llm_response = self.process_llm.generate([{"role": "user", "content": prompt}])
+
+ json_result: dict = extract_json_obj(llm_response)
+
+ logger.info(
+ "[recall_for_missing_memories] done | prompt=%s | llm_response=%s",
+ prompt,
+ llm_response,
+ )
+
+ hint = json_result.get("hint", "")
+ if len(hint) == 0:
+ return hint, False
+ return hint, json_result.get("trigger_recall", False)
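+
+ # The LLM is expected to return JSON shaped like (values illustrative):
+ # {"trigger_recall": true, "hint": "look for entries about the user's employer"}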
+
+ def enhance_memories_with_query(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ if not memories:
+ logger.warning("[Enhance] skipped (no memories to process)")
+ return memories, True
+
+ batch_size = self.batch_size
+ retries = self.retries
+ num_of_memories = len(memories)
+ try:
+ if batch_size is None or num_of_memories <= batch_size:
+ enhanced_memories, success_flag = self._process_enhancement_batch(
+ batch_index=0,
+ query_history=query_history,
+ memories=memories,
+ retries=retries,
+ )
+
+ all_success = success_flag
+ else:
+ batches = self._split_batches(memories=memories, batch_size=batch_size)
+
+ all_success = True
+ failed_batches = 0
+ from concurrent.futures import as_completed
+
+ from memos.context.context import ContextThreadPoolExecutor
+
+ with ContextThreadPoolExecutor(max_workers=len(batches)) as executor:
+ future_map = {
+ executor.submit(
+ self._process_enhancement_batch, bi, query_history, texts, retries
+ ): (bi, s, e)
+ for bi, (s, e, texts) in enumerate(batches)
+ }
+ enhanced_memories = []
+ for fut in as_completed(future_map):
+ _bi, _s, _e = future_map[fut]
+
+ batch_memories, ok = fut.result()
+ enhanced_memories.extend(batch_memories)
+ if not ok:
+ all_success = False
+ failed_batches += 1
+ logger.info(
+ "[Enhance] multi-batch done | batches=%s | enhanced=%s | failed_batches=%s | success=%s",
+ len(batches),
+ len(enhanced_memories),
+ failed_batches,
+ all_success,
+ )
+
+ except Exception as e:
+ logger.error("[Enhance] fatal error: %s", e, exc_info=True)
+ all_success = False
+ enhanced_memories = memories
+
+ if len(enhanced_memories) == 0:
+ logger.error("[Enhance] enhanced_memories is empty; returning an empty result")
+ return enhanced_memories, all_success
diff --git a/src/memos/mem_scheduler/memory_manage_modules/filter_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/filter_pipeline.py
new file mode 100644
index 000000000..315f821a9
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/filter_pipeline.py
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from memos.mem_scheduler.memory_manage_modules.memory_filter import MemoryFilter
+
+
+if TYPE_CHECKING:
+ from memos.memories.textual.tree import TextualMemoryItem
+
+
+class FilterPipeline:
+ def __init__(self, process_llm, config):
+ self.memory_filter = MemoryFilter(process_llm=process_llm, config=config)
+
+ def filter_unrelated_memories(
+ self, query_history: list[str], memories: list[TextualMemoryItem]
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ return self.memory_filter.filter_unrelated_memories(query_history, memories)
+
+ def filter_redundant_memories(
+ self, query_history: list[str], memories: list[TextualMemoryItem]
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ return self.memory_filter.filter_redundant_memories(query_history, memories)
+
+ def filter_unrelated_and_redundant_memories(
+ self, query_history: list[str], memories: list[TextualMemoryItem]
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ return self.memory_filter.filter_unrelated_and_redundant_memories(query_history, memories)
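+
+# Minimal usage sketch (the `llm` and `config` objects are assumptions for
+# illustration); FilterPipeline is a thin facade over MemoryFilter:
+#
+# pipeline = FilterPipeline(process_llm=llm, config=config)
+# related, ok = pipeline.filter_unrelated_memories(
+# query_history=["Where does the user work?"],
+# memories=retrieved_items,
+# )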
diff --git a/src/memos/mem_scheduler/memory_manage_modules/post_processor.py b/src/memos/mem_scheduler/memory_manage_modules/post_processor.py
new file mode 100644
index 000000000..28dc22925
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/post_processor.py
@@ -0,0 +1,307 @@
+"""
+Memory Post-Processor - Handles post-retrieval memory filtering and reranking.
+
+This module provides post-processing operations for retrieved memories,
+including filtering and reranking operations specific to the scheduler's needs.
+
+Note: Memory enhancement operations (enhance_memories_with_query, recall_for_missing_memories)
+have been moved to AdvancedSearcher for better architectural separation.
+"""
+
+from memos.configs.mem_scheduler import BaseSchedulerConfig
+from memos.llms.base import BaseLLM
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
+from memos.mem_scheduler.schemas.general_schemas import (
+ DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
+ DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
+)
+from memos.mem_scheduler.utils.filter_utils import (
+ filter_too_short_memories,
+ filter_vector_based_similar_memories,
+ transform_name_to_key,
+)
+from memos.mem_scheduler.utils.misc_utils import extract_json_obj
+from memos.memories.textual.item import TextualMemoryItem
+
+from .memory_filter import MemoryFilter
+
+
+logger = get_logger(__name__)
+
+
+class MemoryPostProcessor(BaseSchedulerModule):
+ """
+ Post-processor for retrieved memories.
+
+ This class handles scheduler-specific post-retrieval operations:
+ - Memory filtering: Remove unrelated or redundant memories
+ - Memory reranking: Reorder memories by relevance
+ - Memory evaluation: Assess memory's ability to answer queries
+
+ Design principles:
+ - Single Responsibility: Only handles filtering/reranking, not enhancement or retrieval
+ - Composable: Can be used independently or chained together
+ - Testable: Each operation can be tested in isolation
+
+ Note: Memory enhancement operations have been moved to AdvancedSearcher.
+
+ Usage:
+ processor = MemoryPostProcessor(process_llm=llm, config=config)
+
+ # Filter out unrelated memories
+ filtered, _ = processor.filter_unrelated_memories(
+ query_history=["What is Python?"],
+ memories=raw_memories
+ )
+
+ # Rerank memories by relevance
+ reranked, _ = processor.process_and_rerank_memories(
+ queries=["What is Python?"],
+ original_memory=filtered,
+ new_memory=[],
+ top_k=10
+ )
+ """
+
+ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
+ """
+ Initialize the post-processor.
+
+ Args:
+ process_llm: LLM instance for enhancement and filtering operations
+ config: Scheduler configuration containing batch sizes and retry settings
+ """
+ super().__init__()
+
+ # Core dependencies
+ self.process_llm = process_llm
+ self.config = config
+ self.memory_filter = MemoryFilter(process_llm=process_llm, config=config)
+
+ # Configuration
+ self.filter_similarity_threshold = 0.75
+ self.filter_min_length_threshold = 6
+
+ # NOTE: Config keys still use "scheduler_retriever_*" prefix for backward compatibility
+ # TODO: Consider renaming to "post_processor_*" in future config refactor
+ self.batch_size: int | None = getattr(
+ config, "scheduler_retriever_batch_size", DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE
+ )
+ self.retries: int = getattr(
+ config, "scheduler_retriever_enhance_retries", DEFAULT_SCHEDULER_RETRIEVER_RETRIES
+ )
+
+ def evaluate_memory_answer_ability(
+ self, query: str, memory_texts: list[str], top_k: int | None = None
+ ) -> bool:
+ """
+ Evaluate whether the given memories can answer the query.
+
+ This method uses LLM to assess if the provided memories contain
+ sufficient information to answer the given query.
+
+ Args:
+ query: The query to be answered
+ memory_texts: List of memory text strings
+ top_k: Optional limit on number of memories to consider
+
+ Returns:
+ Boolean indicating whether memories can answer the query
+ """
+ limited_memories = memory_texts[:top_k] if top_k is not None else memory_texts
+
+ # Build prompt using the template
+ prompt = self.build_prompt(
+ template_name="memory_answer_ability_evaluation",
+ query=query,
+ memory_list="\n".join([f"- {memory}" for memory in limited_memories])
+ if limited_memories
+ else "No memories available",
+ )
+
+ # Use the process LLM to generate response
+ response = self.process_llm.generate([{"role": "user", "content": prompt}])
+
+ try:
+ result = extract_json_obj(response)
+
+ # Validate response structure
+ if "result" in result:
+ logger.info(
+ f"[Answerability] result={result['result']}; "
+ f"reason={result.get('reason', 'n/a')}; "
+ f"evaluated={len(limited_memories)}"
+ )
+ return result["result"]
+ else:
+ logger.warning(f"[Answerability] invalid LLM JSON structure; payload={result}")
+ return False
+
+ except Exception as e:
+ logger.error(f"[Answerability] parse failed; err={e}; raw={str(response)[:200]}...")
+ return False
+
+ def rerank_memories(
+ self, queries: list[str], original_memories: list[str], top_k: int
+ ) -> tuple[list[str], bool]:
+ """
+ Rerank memories based on relevance to given queries using LLM.
+
+ Args:
+ queries: List of query strings to determine relevance
+ original_memories: List of memory strings to be reranked
+ top_k: Number of top memories to return after reranking
+
+ Returns:
+ Tuple of (reranked_memories, success_flag)
+ - reranked_memories: List of reranked memory strings (length <= top_k)
+ - success_flag: True if reranking succeeded
+
+ Note:
+ If LLM reranking fails, falls back to original order (truncated to top_k)
+ """
+ logger.info(f"Starting memory reranking for {len(original_memories)} memories")
+
+ # Build LLM prompt for memory reranking
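+ # NOTE: only the first query is embedded in the reranking prompt below.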
+ prompt = self.build_prompt(
+ "memory_reranking",
+ queries=[f"[0] {queries[0]}"],
+ current_order=[f"[{i}] {mem}" for i, mem in enumerate(original_memories)],
+ )
+ logger.debug(f"Generated reranking prompt: {prompt[:200]}...")
+
+ # Get LLM response
+ response = self.process_llm.generate([{"role": "user", "content": prompt}])
+ logger.debug(f"Received LLM response: {response[:200]}...")
+
+ try:
+ # Parse JSON response
+ response = extract_json_obj(response)
+ new_order = response["new_order"][:top_k]
+ text_memories_with_new_order = [original_memories[idx] for idx in new_order]
+ logger.info(
+ f"Successfully reranked memories. Returning top {len(text_memories_with_new_order)} items; "
+ f"Ranking reasoning: {response['reasoning']}"
+ )
+ success_flag = True
+ except Exception as e:
+ logger.error(
+ f"Failed to rerank memories with LLM. Exception: {e}. Raw response: {response} ",
+ exc_info=True,
+ )
+ text_memories_with_new_order = original_memories[:top_k]
+ success_flag = False
+
+ return text_memories_with_new_order, success_flag
+
+ def process_and_rerank_memories(
+ self,
+ queries: list[str],
+ original_memory: list[TextualMemoryItem],
+ new_memory: list[TextualMemoryItem],
+ top_k: int = 10,
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ """
+ Process and rerank memory items by combining, filtering, and reranking.
+
+ This is a higher-level method that combines multiple post-processing steps:
+ 1. Merge original and new memories
+ 2. Apply similarity filtering
+ 3. Apply length filtering
+ 4. Remove duplicates
+ 5. Rerank by relevance
+
+ Args:
+ queries: List of query strings to rerank memories against
+ original_memory: List of original TextualMemoryItem objects
+ new_memory: List of new TextualMemoryItem objects to merge
+ top_k: Maximum number of memories to return after reranking
+
+ Returns:
+ Tuple of (reranked_memories, success_flag)
+ - reranked_memories: List of reranked TextualMemoryItem objects
+ - success_flag: True if reranking succeeded
+ """
+ # Combine original and new memories
+ combined_memory = original_memory + new_memory
+
+ # Create mapping from normalized text to memory objects
+ memory_map = {
+ transform_name_to_key(name=mem_obj.memory): mem_obj for mem_obj in combined_memory
+ }
+
+ # Extract text representations
+ combined_text_memory = [m.memory for m in combined_memory]
+
+ # Apply similarity filter
+ filtered_combined_text_memory = filter_vector_based_similar_memories(
+ text_memories=combined_text_memory,
+ similarity_threshold=self.filter_similarity_threshold,
+ )
+
+ # Apply length filter
+ filtered_combined_text_memory = filter_too_short_memories(
+ text_memories=filtered_combined_text_memory,
+ min_length_threshold=self.filter_min_length_threshold,
+ )
+
+ # Remove duplicates (preserving order)
+ unique_memory = list(dict.fromkeys(filtered_combined_text_memory))
+
+ # Rerank memories
+ text_memories_with_new_order, success_flag = self.rerank_memories(
+ queries=queries,
+ original_memories=unique_memory,
+ top_k=top_k,
+ )
+
+ # Map reranked texts back to memory objects
+ memories_with_new_order = []
+ for text in text_memories_with_new_order:
+ normalized_text = transform_name_to_key(name=text)
+ if normalized_text in memory_map:
+ memories_with_new_order.append(memory_map[normalized_text])
+ else:
+ logger.warning(
+ f"Memory text not found in memory map. text: {text};\n"
+ f"Keys of memory_map: {memory_map.keys()}"
+ )
+
+ return memories_with_new_order, success_flag
+
+ def filter_unrelated_memories(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ """
+ Filter out memories unrelated to the query history.
+
+ Delegates to MemoryFilter for the actual filtering logic.
+ """
+ return self.memory_filter.filter_unrelated_memories(query_history, memories)
+
+ def filter_redundant_memories(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ """
+ Filter out redundant memories from the list.
+
+ Delegates to MemoryFilter for the actual filtering logic.
+ """
+ return self.memory_filter.filter_redundant_memories(query_history, memories)
+
+ def filter_unrelated_and_redundant_memories(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ """
+ Filter out both unrelated and redundant memories using LLM analysis.
+
+ Delegates to MemoryFilter for the actual filtering logic.
+ """
+ return self.memory_filter.filter_unrelated_and_redundant_memories(query_history, memories)
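For reference, the rerank path above only succeeds if the LLM returns a JSON object with a `new_order` index list and a `reasoning` string; anything else trips the fallback that keeps the original order truncated to `top_k`. A minimal sketch of that contract, using plain `json.loads` as a stand-in for `extract_json_obj` (an assumption about its behavior):

```python
# Hypothetical sketch of the reranking JSON contract and its fallback path.
import json

def rerank_with_contract(original_memories: list[str], llm_text: str, top_k: int):
    try:
        payload = json.loads(llm_text)  # stand-in for extract_json_obj()
        new_order = payload["new_order"][:top_k]  # indices into the input list
        return [original_memories[i] for i in new_order], True
    except Exception:
        # Same fallback as rerank_memories(): original order, truncated to top_k.
        return original_memories[:top_k], False

mems = ["Python is a language", "Caroline likes hiking", "Python has a GIL"]
good = '{"new_order": [2, 0, 1], "reasoning": "Python items first"}'
print(rerank_with_contract(mems, good, top_k=2))        # (['Python has a GIL', 'Python is a language'], True)
print(rerank_with_contract(mems, "not json", top_k=2))  # (['Python is a language', 'Caroline likes hiking'], False)
```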
diff --git a/src/memos/mem_scheduler/memory_manage_modules/rerank_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/rerank_pipeline.py
new file mode 100644
index 000000000..0e347df6a
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/rerank_pipeline.py
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from memos.log import get_logger
+from memos.mem_scheduler.utils.filter_utils import (
+ filter_too_short_memories,
+ filter_vector_based_similar_memories,
+ transform_name_to_key,
+)
+from memos.mem_scheduler.utils.misc_utils import extract_json_obj
+
+
+if TYPE_CHECKING:
+ from memos.memories.textual.item import TextualMemoryItem
+
+
+logger = get_logger(__name__)
+
+
+class RerankPipeline:
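+    """Filter, dedupe, and LLM-rerank memories; logic extracted from SchedulerRetriever."""
+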
+ def __init__(
+ self,
+ process_llm,
+ similarity_threshold: float,
+ min_length_threshold: int,
+ build_prompt,
+ ):
+ self.process_llm = process_llm
+ self.filter_similarity_threshold = similarity_threshold
+ self.filter_min_length_threshold = min_length_threshold
+ self.build_prompt = build_prompt
+
+ def rerank_memories(
+ self, queries: list[str], original_memories: list[str], top_k: int
+ ) -> tuple[list[str], bool]:
+ logger.info("Starting memory reranking for %s memories", len(original_memories))
+
+ prompt = self.build_prompt(
+ "memory_reranking",
+ queries=[f"[0] {queries[0]}"],
+ current_order=[f"[{i}] {mem}" for i, mem in enumerate(original_memories)],
+ )
+ logger.debug("Generated reranking prompt: %s...", prompt[:200])
+
+ response = self.process_llm.generate([{"role": "user", "content": prompt}])
+ logger.debug("Received LLM response: %s...", response[:200])
+
+ try:
+ response = extract_json_obj(response)
+ new_order = response["new_order"][:top_k]
+ text_memories_with_new_order = [original_memories[idx] for idx in new_order]
+ logger.info(
+ "Successfully reranked memories. Returning top %s items; Ranking reasoning: %s",
+ len(text_memories_with_new_order),
+ response["reasoning"],
+ )
+ success_flag = True
+ except Exception as e:
+ logger.error(
+ "Failed to rerank memories with LLM. Exception: %s. Raw response: %s ",
+ e,
+ response,
+ exc_info=True,
+ )
+ text_memories_with_new_order = original_memories[:top_k]
+ success_flag = False
+ return text_memories_with_new_order, success_flag
+
+ def process_and_rerank_memories(
+ self,
+ queries: list[str],
+ original_memory: list[TextualMemoryItem],
+ new_memory: list[TextualMemoryItem],
+ top_k: int = 10,
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ combined_memory = original_memory + new_memory
+
+ memory_map = {
+ transform_name_to_key(name=mem_obj.memory): mem_obj for mem_obj in combined_memory
+ }
+
+ combined_text_memory = [m.memory for m in combined_memory]
+
+ filtered_combined_text_memory = filter_vector_based_similar_memories(
+ text_memories=combined_text_memory,
+ similarity_threshold=self.filter_similarity_threshold,
+ )
+
+ filtered_combined_text_memory = filter_too_short_memories(
+ text_memories=filtered_combined_text_memory,
+ min_length_threshold=self.filter_min_length_threshold,
+ )
+
+ unique_memory = list(dict.fromkeys(filtered_combined_text_memory))
+
+ text_memories_with_new_order, success_flag = self.rerank_memories(
+ queries=queries,
+ original_memories=unique_memory,
+ top_k=top_k,
+ )
+
+ memories_with_new_order = []
+ for text in text_memories_with_new_order:
+ normalized_text = transform_name_to_key(name=text)
+ if normalized_text in memory_map:
+ memories_with_new_order.append(memory_map[normalized_text])
+ else:
+ logger.warning(
+ "Memory text not found in memory map. text: %s;\nKeys of memory_map: %s",
+ text,
+ memory_map.keys(),
+ )
+
+ return memories_with_new_order, success_flag
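One detail worth noting in `process_and_rerank_memories`: the dedupe step uses `dict.fromkeys`, which drops repeats while preserving first-seen order, so the filtered texts reach the reranker in their original order. A quick illustration:

```python
# dict.fromkeys() deduplicates while keeping first-seen order, unlike set().
texts = ["alpha", "beta", "alpha", "gamma", "beta"]
unique = list(dict.fromkeys(texts))
assert unique == ["alpha", "beta", "gamma"]
```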
diff --git a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py
index f205766f0..3e849f470 100644
--- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py
+++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py
@@ -1,457 +1,99 @@
-import time
+from __future__ import annotations
-from concurrent.futures import as_completed
+from typing import TYPE_CHECKING
-from memos.configs.mem_scheduler import BaseSchedulerConfig
-from memos.context.context import ContextThreadPoolExecutor
-from memos.llms.base import BaseLLM
from memos.log import get_logger
-from memos.mem_cube.general import GeneralMemCube
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
-from memos.mem_scheduler.schemas.general_schemas import (
- DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
- DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
- TreeTextMemory_FINE_SEARCH_METHOD,
- TreeTextMemory_SEARCH_METHOD,
-)
-from memos.mem_scheduler.utils.filter_utils import (
- filter_too_short_memories,
- filter_vector_based_similar_memories,
- transform_name_to_key,
-)
-from memos.mem_scheduler.utils.misc_utils import extract_json_obj, extract_list_items_in_answer
-from memos.memories.textual.item import TextualMemoryMetadata
-from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
-from memos.types.general_types import (
- FINE_STRATEGY,
- FineStrategy,
- SearchMode,
-)
+from memos.mem_scheduler.memory_manage_modules.enhancement_pipeline import EnhancementPipeline
+from memos.mem_scheduler.memory_manage_modules.filter_pipeline import FilterPipeline
+from memos.mem_scheduler.memory_manage_modules.rerank_pipeline import RerankPipeline
+from memos.mem_scheduler.memory_manage_modules.search_pipeline import SearchPipeline
-# Extract JSON response
-from .memory_filter import MemoryFilter
+
+if TYPE_CHECKING:
+ from memos.memories.textual.item import TextualMemoryItem
logger = get_logger(__name__)
class SchedulerRetriever(BaseSchedulerModule):
- def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
+ def __init__(self, process_llm, config):
super().__init__()
- # hyper-parameters
self.filter_similarity_threshold = 0.75
self.filter_min_length_threshold = 6
- self.memory_filter = MemoryFilter(process_llm=process_llm, config=config)
self.process_llm = process_llm
self.config = config
- # Configure enhancement batching & retries from config with safe defaults
- self.batch_size: int | None = getattr(
- config, "scheduler_retriever_batch_size", DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE
+ self.search_pipeline = SearchPipeline()
+ self.enhancement_pipeline = EnhancementPipeline(
+ process_llm=process_llm,
+ config=config,
+ build_prompt=self.build_prompt,
)
- self.retries: int = getattr(
- config, "scheduler_retriever_enhance_retries", DEFAULT_SCHEDULER_RETRIEVER_RETRIES
+ self.rerank_pipeline = RerankPipeline(
+ process_llm=process_llm,
+ similarity_threshold=self.filter_similarity_threshold,
+ min_length_threshold=self.filter_min_length_threshold,
+ build_prompt=self.build_prompt,
)
+ self.filter_pipeline = FilterPipeline(process_llm=process_llm, config=config)
+ self.memory_filter = self.filter_pipeline.memory_filter
def evaluate_memory_answer_ability(
self, query: str, memory_texts: list[str], top_k: int | None = None
) -> bool:
- limited_memories = memory_texts[:top_k] if top_k is not None else memory_texts
- # Build prompt using the template
- prompt = self.build_prompt(
- template_name="memory_answer_ability_evaluation",
- query=query,
- memory_list="\n".join([f"- {memory}" for memory in limited_memories])
- if limited_memories
- else "No memories available",
- )
-
- # Use the process LLM to generate response
- response = self.process_llm.generate([{"role": "user", "content": prompt}])
-
- try:
- result = extract_json_obj(response)
-
- # Validate response structure
- if "result" in result:
- logger.info(
- f"Answerability: result={result['result']}; reason={result.get('reason', 'n/a')}; evaluated={len(limited_memories)}"
- )
- return result["result"]
- else:
- logger.warning(f"Answerability: invalid LLM JSON structure; payload={result}")
- return False
-
- except Exception as e:
- logger.error(f"Answerability: parse failed; err={e}; raw={str(response)[:200]}...")
- # Fallback: return False if we can't determine answer ability
- return False
-
- # ---------------------- Enhancement helpers ----------------------
- def _build_enhancement_prompt(self, query_history: list[str], batch_texts: list[str]) -> str:
- if len(query_history) == 1:
- query_history = query_history[0]
- else:
- query_history = (
- [f"[{i}] {query}" for i, query in enumerate(query_history)]
- if len(query_history) > 1
- else query_history[0]
- )
- # Include numbering for rewrite mode to help LLM reference original memory IDs
- if FINE_STRATEGY == FineStrategy.REWRITE:
- text_memories = "\n".join([f"- [{i}] {mem}" for i, mem in enumerate(batch_texts)])
- prompt_name = "memory_rewrite_enhancement"
- else:
- text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(batch_texts)])
- prompt_name = "memory_recreate_enhancement"
- return self.build_prompt(
- prompt_name,
- query_history=query_history,
- memories=text_memories,
- )
-
- def _process_enhancement_batch(
- self,
- batch_index: int,
- query_history: list[str],
- memories: list[TextualMemoryItem],
- retries: int,
- ) -> tuple[list[TextualMemoryItem], bool]:
- attempt = 0
- text_memories = [one.memory for one in memories]
-
- prompt = self._build_enhancement_prompt(
- query_history=query_history, batch_texts=text_memories
- )
-
- llm_response = None
- while attempt <= max(0, retries) + 1:
- try:
- llm_response = self.process_llm.generate([{"role": "user", "content": prompt}])
- processed_text_memories = extract_list_items_in_answer(llm_response)
- if len(processed_text_memories) > 0:
- # create new
- enhanced_memories = []
- user_id = memories[0].metadata.user_id
- if FINE_STRATEGY == FineStrategy.RECREATE:
- for new_mem in processed_text_memories:
- enhanced_memories.append(
- TextualMemoryItem(
- memory=new_mem,
- metadata=TextualMemoryMetadata(
- user_id=user_id, memory_type="LongTermMemory"
- ), # TODO add memory_type
- )
- )
- elif FINE_STRATEGY == FineStrategy.REWRITE:
- # Parse index from each processed line and rewrite corresponding original memory
- def _parse_index_and_text(s: str) -> tuple[int | None, str]:
- import re
-
- s = (s or "").strip()
- # Preferred: [index] text
- m = re.match(r"^\s*\[(\d+)\]\s*(.+)$", s)
- if m:
- return int(m.group(1)), m.group(2).strip()
- # Fallback: index: text or index - text
- m = re.match(r"^\s*(\d+)\s*[:\-\)]\s*(.+)$", s)
- if m:
- return int(m.group(1)), m.group(2).strip()
- return None, s
-
- idx_to_original = dict(enumerate(memories))
- for j, item in enumerate(processed_text_memories):
- idx, new_text = _parse_index_and_text(item)
- if idx is not None and idx in idx_to_original:
- orig = idx_to_original[idx]
- else:
- # Fallback: align by order if index missing/invalid
- orig = memories[j] if j < len(memories) else None
- if not orig:
- continue
- enhanced_memories.append(
- TextualMemoryItem(
- id=orig.id,
- memory=new_text,
- metadata=orig.metadata,
- )
- )
- else:
- logger.error(f"Fine search strategy {FINE_STRATEGY} not exists")
-
- logger.info(
- f"[enhance_memories_with_query] ✅ done | Strategy={FINE_STRATEGY} | prompt={prompt} | llm_response={llm_response}"
- )
- return enhanced_memories, True
- else:
- raise ValueError(
- f"Fail to run memory enhancement; retry {attempt}/{max(1, retries) + 1}; processed_text_memories: {processed_text_memories}"
- )
- except Exception as e:
- attempt += 1
- time.sleep(1)
- logger.debug(
- f"[enhance_memories_with_query][batch={batch_index}] 🔁 retry {attempt}/{max(1, retries) + 1} failed: {e}"
- )
- logger.error(
- f"Fail to run memory enhancement; prompt: {prompt};\n llm_response: {llm_response}",
- exc_info=True,
- )
- return memories, False
-
- @staticmethod
- def _split_batches(
- memories: list[TextualMemoryItem], batch_size: int
- ) -> list[tuple[int, int, list[TextualMemoryItem]]]:
- batches: list[tuple[int, int, list[TextualMemoryItem]]] = []
- start = 0
- n = len(memories)
- while start < n:
- end = min(start + batch_size, n)
- batches.append((start, end, memories[start:end]))
- start = end
- return batches
-
- def recall_for_missing_memories(
- self,
- query: str,
- memories: list[str],
- ) -> tuple[str, bool]:
- text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(memories)])
-
- prompt = self.build_prompt(
- template_name="enlarge_recall",
+ return self.enhancement_pipeline.evaluate_memory_answer_ability(
query=query,
- memories_inline=text_memories,
- )
- llm_response = self.process_llm.generate([{"role": "user", "content": prompt}])
-
- json_result: dict = extract_json_obj(llm_response)
-
- logger.info(
- f"[recall_for_missing_memories] ✅ done | prompt={prompt} | llm_response={llm_response}"
+ memory_texts=memory_texts,
+ top_k=top_k,
)
- hint = json_result.get("hint", "")
- if len(hint) == 0:
- return hint, False
- return hint, json_result.get("trigger_recall", False)
-
def search(
self,
query: str,
user_id: str,
mem_cube_id: str,
- mem_cube: GeneralMemCube,
+ mem_cube,
top_k: int,
- method: str = TreeTextMemory_SEARCH_METHOD,
+ method: str,
search_args: dict | None = None,
) -> list[TextualMemoryItem]:
- """Search in text memory with the given query.
-
- Args:
- query: The search query string
- top_k: Number of top results to return
- method: Search method to use
-
- Returns:
- Search results or None if not implemented
- """
- text_mem_base = mem_cube.text_mem
- # Normalize default for mutable argument
- search_args = search_args or {}
- try:
- if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]:
- assert isinstance(text_mem_base, TreeTextMemory)
- session_id = search_args.get("session_id", "default_session")
- target_session_id = session_id
- search_priority = (
- {"session_id": target_session_id} if "session_id" in search_args else None
- )
- search_filter = search_args.get("filter")
- search_source = search_args.get("source")
- plugin = bool(search_source is not None and search_source == "plugin")
- user_name = search_args.get("user_name", mem_cube_id)
- internet_search = search_args.get("internet_search", False)
- chat_history = search_args.get("chat_history")
- search_tool_memory = search_args.get("search_tool_memory", False)
- tool_mem_top_k = search_args.get("tool_mem_top_k", 6)
- playground_search_goal_parser = search_args.get(
- "playground_search_goal_parser", False
- )
-
- info = search_args.get(
- "info",
- {
- "user_id": user_id,
- "session_id": target_session_id,
- "chat_history": chat_history,
- },
- )
-
- results_long_term = mem_cube.text_mem.search(
- query=query,
- user_name=user_name,
- top_k=top_k,
- mode=SearchMode.FAST,
- manual_close_internet=not internet_search,
- memory_type="LongTermMemory",
- search_filter=search_filter,
- search_priority=search_priority,
- info=info,
- plugin=plugin,
- search_tool_memory=search_tool_memory,
- tool_mem_top_k=tool_mem_top_k,
- playground_search_goal_parser=playground_search_goal_parser,
- )
-
- results_user = mem_cube.text_mem.search(
- query=query,
- user_name=user_name,
- top_k=top_k,
- mode=SearchMode.FAST,
- manual_close_internet=not internet_search,
- memory_type="UserMemory",
- search_filter=search_filter,
- search_priority=search_priority,
- info=info,
- plugin=plugin,
- search_tool_memory=search_tool_memory,
- tool_mem_top_k=tool_mem_top_k,
- playground_search_goal_parser=playground_search_goal_parser,
- )
- results = results_long_term + results_user
- else:
- raise NotImplementedError(str(type(text_mem_base)))
- except Exception as e:
- logger.error(f"Fail to search. The exeption is {e}.", exc_info=True)
- results = []
- return results
+ return self.search_pipeline.search(
+ query=query,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ top_k=top_k,
+ method=method,
+ search_args=search_args,
+ )
def enhance_memories_with_query(
self,
query_history: list[str],
memories: list[TextualMemoryItem],
- ) -> (list[TextualMemoryItem], bool):
- """
- Enhance memories by adding context and making connections to better answer queries.
-
- Args:
- query_history: List of user queries in chronological order
- memories: List of memory items to enhance
-
- Returns:
- Tuple of (enhanced_memories, success_flag)
- """
- if not memories:
- logger.warning("[Enhance] ⚠️ skipped (no memories to process)")
- return memories, True
-
- batch_size = self.batch_size
- retries = self.retries
- num_of_memories = len(memories)
- try:
- # no parallel
- if batch_size is None or num_of_memories <= batch_size:
- # Single batch path with retry
- enhanced_memories, success_flag = self._process_enhancement_batch(
- batch_index=0,
- query_history=query_history,
- memories=memories,
- retries=retries,
- )
-
- all_success = success_flag
- else:
- # parallel running batches
- # Split into batches preserving order
- batches = self._split_batches(memories=memories, batch_size=batch_size)
-
- # Process batches concurrently
- all_success = True
- failed_batches = 0
- with ContextThreadPoolExecutor(max_workers=len(batches)) as executor:
- future_map = {
- executor.submit(
- self._process_enhancement_batch, bi, query_history, texts, retries
- ): (bi, s, e)
- for bi, (s, e, texts) in enumerate(batches)
- }
- enhanced_memories = []
- for fut in as_completed(future_map):
- bi, s, e = future_map[fut]
-
- batch_memories, ok = fut.result()
- enhanced_memories.extend(batch_memories)
- if not ok:
- all_success = False
- failed_batches += 1
- logger.info(
- f"[Enhance] ✅ multi-batch done | batches={len(batches)} | enhanced={len(enhanced_memories)} |"
- f" failed_batches={failed_batches} | success={all_success}"
- )
-
- except Exception as e:
- logger.error(f"[Enhance] ❌ fatal error: {e}", exc_info=True)
- all_success = False
- enhanced_memories = memories
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ return self.enhancement_pipeline.enhance_memories_with_query(
+ query_history=query_history,
+ memories=memories,
+ )
- if len(enhanced_memories) == 0:
- enhanced_memories = []
- logger.error("[Enhance] ❌ fatal error: enhanced_memories is empty", exc_info=True)
- return enhanced_memories, all_success
+ def recall_for_missing_memories(self, query: str, memories: list[str]) -> tuple[str, bool]:
+ return self.enhancement_pipeline.recall_for_missing_memories(
+ query=query,
+ memories=memories,
+ )
def rerank_memories(
self, queries: list[str], original_memories: list[str], top_k: int
- ) -> (list[str], bool):
- """
- Rerank memories based on relevance to given queries using LLM.
-
- Args:
- queries: List of query strings to determine relevance
- original_memories: List of memory strings to be reranked
- top_k: Number of top memories to return after reranking
-
- Returns:
- List of reranked memory strings (length <= top_k)
-
- Note:
- If LLM reranking fails, falls back to original order (truncated to top_k)
- """
-
- logger.info(f"Starting memory reranking for {len(original_memories)} memories")
-
- # Build LLM prompt for memory reranking
- prompt = self.build_prompt(
- "memory_reranking",
- queries=[f"[0] {queries[0]}"],
- current_order=[f"[{i}] {mem}" for i, mem in enumerate(original_memories)],
+ ) -> tuple[list[str], bool]:
+ return self.rerank_pipeline.rerank_memories(
+ queries=queries,
+ original_memories=original_memories,
+ top_k=top_k,
)
- logger.debug(f"Generated reranking prompt: {prompt[:200]}...") # Log first 200 chars
-
- # Get LLM response
- response = self.process_llm.generate([{"role": "user", "content": prompt}])
- logger.debug(f"Received LLM response: {response[:200]}...") # Log first 200 chars
-
- try:
- # Parse JSON response
- response = extract_json_obj(response)
- new_order = response["new_order"][:top_k]
- text_memories_with_new_order = [original_memories[idx] for idx in new_order]
- logger.info(
- f"Successfully reranked memories. Returning top {len(text_memories_with_new_order)} items;"
- f"Ranking reasoning: {response['reasoning']}"
- )
- success_flag = True
- except Exception as e:
- logger.error(
- f"Failed to rerank memories with LLM. Exception: {e}. Raw response: {response} ",
- exc_info=True,
- )
- text_memories_with_new_order = original_memories[:top_k]
- success_flag = False
- return text_memories_with_new_order, success_flag
def process_and_rerank_memories(
self,
@@ -459,89 +101,40 @@ def process_and_rerank_memories(
original_memory: list[TextualMemoryItem],
new_memory: list[TextualMemoryItem],
top_k: int = 10,
- ) -> list[TextualMemoryItem] | None:
- """
- Process and rerank memory items by combining original and new memories,
- applying filters, and then reranking based on relevance to queries.
-
- Args:
- queries: List of query strings to rerank memories against
- original_memory: List of original TextualMemoryItem objects
- new_memory: List of new TextualMemoryItem objects to merge
- top_k: Maximum number of memories to return after reranking
-
- Returns:
- List of reranked TextualMemoryItem objects, or None if processing fails
- """
- # Combine original and new memories into a single list
- combined_memory = original_memory + new_memory
-
- # Create a mapping from normalized text to memory objects
- memory_map = {
- transform_name_to_key(name=mem_obj.memory): mem_obj for mem_obj in combined_memory
- }
-
- # Extract normalized text representations from all memory items
- combined_text_memory = [m.memory for m in combined_memory]
-
- # Apply similarity filter to remove overly similar memories
- filtered_combined_text_memory = filter_vector_based_similar_memories(
- text_memories=combined_text_memory,
- similarity_threshold=self.filter_similarity_threshold,
- )
-
- # Apply length filter to remove memories that are too short
- filtered_combined_text_memory = filter_too_short_memories(
- text_memories=filtered_combined_text_memory,
- min_length_threshold=self.filter_min_length_threshold,
- )
-
- # Ensure uniqueness of memory texts using dictionary keys (preserves order)
- unique_memory = list(dict.fromkeys(filtered_combined_text_memory))
-
- # Rerank the filtered memories based on relevance to the queries
- text_memories_with_new_order, success_flag = self.rerank_memories(
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ return self.rerank_pipeline.process_and_rerank_memories(
queries=queries,
- original_memories=unique_memory,
+ original_memory=original_memory,
+ new_memory=new_memory,
top_k=top_k,
)
- # Map reranked text entries back to their original memory objects
- memories_with_new_order = []
- for text in text_memories_with_new_order:
- normalized_text = transform_name_to_key(name=text)
- if normalized_text in memory_map: # Ensure correct key matching
- memories_with_new_order.append(memory_map[normalized_text])
- else:
- logger.warning(
- f"Memory text not found in memory map. text: {text};\n"
- f"Keys of memory_map: {memory_map.keys()}"
- )
-
- return memories_with_new_order, success_flag
-
def filter_unrelated_memories(
self,
query_history: list[str],
memories: list[TextualMemoryItem],
- ) -> (list[TextualMemoryItem], bool):
- return self.memory_filter.filter_unrelated_memories(query_history, memories)
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ return self.filter_pipeline.filter_unrelated_memories(
+ query_history=query_history,
+ memories=memories,
+ )
def filter_redundant_memories(
self,
query_history: list[str],
memories: list[TextualMemoryItem],
- ) -> (list[TextualMemoryItem], bool):
- return self.memory_filter.filter_redundant_memories(query_history, memories)
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ return self.filter_pipeline.filter_redundant_memories(
+ query_history=query_history,
+ memories=memories,
+ )
def filter_unrelated_and_redundant_memories(
self,
query_history: list[str],
memories: list[TextualMemoryItem],
- ) -> (list[TextualMemoryItem], bool):
- """
- Filter out both unrelated and redundant memories using LLM analysis.
-
- This method delegates to the MemoryFilter class.
- """
- return self.memory_filter.filter_unrelated_and_redundant_memories(query_history, memories)
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ return self.filter_pipeline.filter_unrelated_and_redundant_memories(
+ query_history=query_history,
+ memories=memories,
+ )
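The net effect of this retriever change is a facade: the public method surface is unchanged, each method now forwards to a focused pipeline object, and `self.memory_filter` is kept as a back-compat alias onto the filter pipeline. A dependency-free sketch of the shape (names hypothetical):

```python
# Facade sketch: callers keep using the old methods; the work lives in pipelines.
class RerankStub:
    def rerank_memories(self, queries, original_memories, top_k):
        return original_memories[:top_k], True

class RetrieverFacade:
    def __init__(self):
        self.rerank_pipeline = RerankStub()

    def rerank_memories(self, queries, original_memories, top_k):
        # Pure delegation, mirroring SchedulerRetriever after the refactor.
        return self.rerank_pipeline.rerank_memories(queries, original_memories, top_k)

facade = RetrieverFacade()
assert facade.rerank_memories(["q"], ["m1", "m2", "m3"], top_k=2) == (["m1", "m2"], True)
```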
diff --git a/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py
new file mode 100644
index 000000000..a346622c5
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py
@@ -0,0 +1,94 @@
+from __future__ import annotations
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.general_schemas import (
+ TreeTextMemory_FINE_SEARCH_METHOD,
+ TreeTextMemory_SEARCH_METHOD,
+)
+from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
+from memos.types.general_types import SearchMode
+
+
+logger = get_logger(__name__)
+
+
+class SearchPipeline:
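+    """Tree-text-memory search logic, extracted unchanged from SchedulerRetriever.search()."""
+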
+ def search(
+ self,
+ query: str,
+ user_id: str,
+ mem_cube_id: str,
+ mem_cube,
+ top_k: int,
+ method: str = TreeTextMemory_SEARCH_METHOD,
+ search_args: dict | None = None,
+ ) -> list[TextualMemoryItem]:
+ text_mem_base = mem_cube.text_mem
+ search_args = search_args or {}
+ try:
+ if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]:
+ assert isinstance(text_mem_base, TreeTextMemory)
+ session_id = search_args.get("session_id", "default_session")
+ target_session_id = session_id
+ search_priority = (
+ {"session_id": target_session_id} if "session_id" in search_args else None
+ )
+ search_filter = search_args.get("filter")
+ search_source = search_args.get("source")
+ plugin = bool(search_source is not None and search_source == "plugin")
+ user_name = search_args.get("user_name", mem_cube_id)
+ internet_search = search_args.get("internet_search", False)
+ chat_history = search_args.get("chat_history")
+ search_tool_memory = search_args.get("search_tool_memory", False)
+ tool_mem_top_k = search_args.get("tool_mem_top_k", 6)
+ playground_search_goal_parser = search_args.get(
+ "playground_search_goal_parser", False
+ )
+
+ info = search_args.get(
+ "info",
+ {
+ "user_id": user_id,
+ "session_id": target_session_id,
+ "chat_history": chat_history,
+ },
+ )
+
+ results_long_term = mem_cube.text_mem.search(
+ query=query,
+ user_name=user_name,
+ top_k=top_k,
+ mode=SearchMode.FAST,
+ manual_close_internet=not internet_search,
+ memory_type="LongTermMemory",
+ search_filter=search_filter,
+ search_priority=search_priority,
+ info=info,
+ plugin=plugin,
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
+ playground_search_goal_parser=playground_search_goal_parser,
+ )
+
+ results_user = mem_cube.text_mem.search(
+ query=query,
+ user_name=user_name,
+ top_k=top_k,
+ mode=SearchMode.FAST,
+ manual_close_internet=not internet_search,
+ memory_type="UserMemory",
+ search_filter=search_filter,
+ search_priority=search_priority,
+ info=info,
+ plugin=plugin,
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
+ playground_search_goal_parser=playground_search_goal_parser,
+ )
+ results = results_long_term + results_user
+ else:
+ raise NotImplementedError(str(type(text_mem_base)))
+ except Exception as e:
+ logger.error("Fail to search. The exception is %s.", e, exc_info=True)
+ results = []
+ return results
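Since `search_args` is an untyped dict, the defaults above are the de facto contract. A standalone sketch of that normalization, with each default mirroring a `.get(...)` call in `search()`:

```python
# Sketch of SearchPipeline's optional-argument handling; defaults mirror search().
def normalize_search_args(search_args: dict | None, mem_cube_id: str) -> dict:
    args = search_args or {}
    return {
        "session_id": args.get("session_id", "default_session"),
        "search_priority": {"session_id": args["session_id"]} if "session_id" in args else None,
        "search_filter": args.get("filter"),
        "plugin": args.get("source") == "plugin",
        "user_name": args.get("user_name", mem_cube_id),
        "internet_search": args.get("internet_search", False),
        "tool_mem_top_k": args.get("tool_mem_top_k", 6),
    }

print(normalize_search_args({"source": "plugin"}, mem_cube_id="cube_001"))
```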
diff --git a/src/memos/mem_scheduler/memory_manage_modules/search_service.py b/src/memos/mem_scheduler/memory_manage_modules/search_service.py
new file mode 100644
index 000000000..3f00372bd
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/search_service.py
@@ -0,0 +1,297 @@
+"""
+Scheduler Search Service - Unified search interface for the scheduler.
+
+This module provides a clean abstraction over the Searcher class,
+adapting it for scheduler-specific use cases while maintaining compatibility.
+"""
+
+from memos.log import get_logger
+from memos.mem_cube.general import GeneralMemCube
+from memos.memories.textual.item import TextualMemoryItem
+from memos.memories.textual.tree import TreeTextMemory
+from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
+from memos.types.general_types import SearchMode
+
+
+logger = get_logger(__name__)
+
+
+class SchedulerSearchService:
+ """
+ Unified search service for the scheduler.
+
+ This service provides a clean interface for memory search operations,
+ delegating to the Searcher class while handling scheduler-specific
+ parameter adaptations.
+
+ Design principles:
+ - Single Responsibility: Only handles search coordination
+ - Dependency Injection: Searcher is injected, not created
+ - Fail-safe: Falls back to direct text_mem.search() if Searcher unavailable
+
+ Usage:
+ service = SchedulerSearchService(searcher=searcher)
+ results = service.search(
+ query="user query",
+ user_id="user_123",
+ mem_cube=mem_cube,
+ top_k=10
+ )
+ """
+
+ def __init__(self, searcher: Searcher | None = None):
+ """
+ Initialize the search service.
+
+ Args:
+ searcher: Optional Searcher instance. If None, will fall back to
+ direct mem_cube.text_mem.search() calls.
+ """
+ self.searcher = searcher
+
+ def search(
+ self,
+ query: str,
+ user_id: str,
+ mem_cube: GeneralMemCube,
+ top_k: int,
+ mode: SearchMode = SearchMode.FAST,
+ search_filter: dict | None = None,
+ search_priority: dict | None = None,
+ session_id: str = "default_session",
+ internet_search: bool = False,
+ chat_history: list | None = None,
+ plugin: bool = False,
+ search_tool_memory: bool = False,
+ tool_mem_top_k: int = 6,
+ playground_search_goal_parser: bool = False,
+ mem_cube_id: str | None = None,
+ ) -> list[TextualMemoryItem]:
+ """
+ Search for memories across both LongTermMemory and UserMemory.
+
+ This method provides a unified interface for memory search, automatically
+ handling the search across different memory types and merging results.
+
+ Args:
+ query: The search query string
+ user_id: User identifier
+ mem_cube: Memory cube instance containing text memory
+ top_k: Number of top results to return per memory type
+ mode: Search mode (FAST or FINE)
+ search_filter: Optional metadata filters for search results
+ search_priority: Optional metadata priority for search results
+ session_id: Session identifier for session-scoped search
+ internet_search: Whether to enable internet search
+ chat_history: Chat history for context
+ plugin: Whether this is a plugin-initiated search
+ search_tool_memory: Whether to search tool memory
+ tool_mem_top_k: Top-k for tool memory search
+ playground_search_goal_parser: Whether to use playground goal parser
+ mem_cube_id: Memory cube identifier (defaults to user_id if not provided)
+
+ Returns:
+ List of TextualMemoryItem objects sorted by relevance
+
+ Raises:
+ Exception: Propagates exceptions from underlying search implementations
+ """
+ mem_cube_id = mem_cube_id or user_id
+ user_name = mem_cube_id
+ text_mem_base = mem_cube.text_mem
+
+ # Build info dict for tracking
+ info = {
+ "user_id": user_id,
+ "session_id": session_id,
+ "chat_history": chat_history,
+ }
+
+ try:
+ if self.searcher:
+ # Use injected Searcher (preferred path)
+ results = self._search_with_searcher(
+ query=query,
+ user_name=user_name,
+ top_k=top_k,
+ mode=mode,
+ search_filter=search_filter,
+ search_priority=search_priority,
+ info=info,
+ internet_search=internet_search,
+ plugin=plugin,
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
+ playground_search_goal_parser=playground_search_goal_parser,
+ )
+ logger.info(
+ f"[SchedulerSearchService] Searched via Searcher: "
+ f"query='{query}' results={len(results)}"
+ )
+ else:
+ # Fallback: Direct text_mem.search() call
+ results = self._search_with_text_mem(
+ text_mem_base=text_mem_base,
+ query=query,
+ user_name=user_name,
+ top_k=top_k,
+ mode=mode,
+ search_filter=search_filter,
+ search_priority=search_priority,
+ info=info,
+ internet_search=internet_search,
+ plugin=plugin,
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
+ playground_search_goal_parser=playground_search_goal_parser,
+ )
+ logger.info(
+ f"[SchedulerSearchService] Searched via text_mem (fallback): "
+ f"query='{query}' results={len(results)}"
+ )
+
+ return results
+
+ except Exception as e:
+ logger.error(
+ f"[SchedulerSearchService] Search failed for query='{query}': {e}",
+ exc_info=True,
+ )
+ return []
+
+ def _search_with_searcher(
+ self,
+ query: str,
+ user_name: str,
+ top_k: int,
+ mode: SearchMode,
+ search_filter: dict | None,
+ search_priority: dict | None,
+ info: dict,
+ internet_search: bool,
+ plugin: bool,
+ search_tool_memory: bool,
+ tool_mem_top_k: int,
+ playground_search_goal_parser: bool,
+ ) -> list[TextualMemoryItem]:
+ """
+ Search using the injected Searcher instance.
+
+        NOTE: LongTermMemory and UserMemory are searched in two separate calls
+        and the results are concatenated, so up to 2*top_k items may be returned;
+        Searcher.search() applies deduplication and the top_k limit per call, not
+        across the merged list.
+
+        The searcher's internet setting is saved and restored around the calls.
+ """
+ # Preserve original internet search setting
+ original_manual_close = getattr(self.searcher, "manual_close_internet", None)
+
+ try:
+ # Configure internet search
+ if original_manual_close is not None:
+ self.searcher.manual_close_internet = not internet_search
+
+ # Search LongTermMemory
+ results_long_term = self.searcher.search(
+ query=query,
+ user_name=user_name,
+ top_k=top_k,
+ mode=mode,
+ memory_type="LongTermMemory",
+ search_filter=search_filter,
+ search_priority=search_priority,
+ info=info,
+ plugin=plugin,
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
+ playground_search_goal_parser=playground_search_goal_parser,
+ )
+
+ # Search UserMemory
+ results_user = self.searcher.search(
+ query=query,
+ user_name=user_name,
+ top_k=top_k,
+ mode=mode,
+ memory_type="UserMemory",
+ search_filter=search_filter,
+ search_priority=search_priority,
+ info=info,
+ plugin=plugin,
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
+ playground_search_goal_parser=playground_search_goal_parser,
+ )
+
+ return results_long_term + results_user
+
+ finally:
+ # Restore original setting
+ if original_manual_close is not None:
+ self.searcher.manual_close_internet = original_manual_close
+
+ def _search_with_text_mem(
+ self,
+ text_mem_base: TreeTextMemory,
+ query: str,
+ user_name: str,
+ top_k: int,
+ mode: SearchMode,
+ search_filter: dict | None,
+ search_priority: dict | None,
+ info: dict,
+ internet_search: bool,
+ plugin: bool,
+ search_tool_memory: bool,
+ tool_mem_top_k: int,
+ playground_search_goal_parser: bool,
+ ) -> list[TextualMemoryItem]:
+ """
+ Fallback: Search using direct text_mem.search() calls.
+
+ This is used when no Searcher instance is available, providing
+ backward compatibility with the original implementation.
+
+        NOTE: LongTermMemory and UserMemory are queried via two separate
+        text_mem.search() calls and the results are concatenated.
+ """
+ assert isinstance(text_mem_base, TreeTextMemory), (
+ f"Fallback search requires TreeTextMemory, got {type(text_mem_base)}"
+ )
+
+ # Search LongTermMemory
+ results_long_term = text_mem_base.search(
+ query=query,
+ user_name=user_name,
+ top_k=top_k,
+ mode=mode,
+ manual_close_internet=not internet_search,
+ memory_type="LongTermMemory",
+ search_filter=search_filter,
+ search_priority=search_priority,
+ info=info,
+ plugin=plugin,
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
+ playground_search_goal_parser=playground_search_goal_parser,
+ )
+
+ # Search UserMemory
+ results_user = text_mem_base.search(
+ query=query,
+ user_name=user_name,
+ top_k=top_k,
+ mode=mode,
+ manual_close_internet=not internet_search,
+ memory_type="UserMemory",
+ search_filter=search_filter,
+ search_priority=search_priority,
+ info=info,
+ plugin=plugin,
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
+ playground_search_goal_parser=playground_search_goal_parser,
+ )
+
+ return results_long_term + results_user
diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py
index 497d19ac6..d6b566dfe 100644
--- a/src/memos/mem_scheduler/optimized_scheduler.py
+++ b/src/memos/mem_scheduler/optimized_scheduler.py
@@ -19,6 +19,7 @@
from memos.mem_scheduler.utils.db_utils import get_utc_now
from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
+from memos.search import build_search_context, search_text_memories
from memos.types import (
MemCubeID,
SearchMode,
@@ -104,29 +105,14 @@ def search_memories(
mem_cube: NaiveMemCube,
mode: SearchMode,
):
- """Fine search memories function copied from server_router to avoid circular import"""
- target_session_id = search_req.session_id
- if not target_session_id:
- target_session_id = "default_session"
- search_priority = {"session_id": search_req.session_id} if search_req.session_id else None
- search_filter = search_req.filter
-
- # Create MemCube and perform search
- search_results = mem_cube.text_mem.search(
- query=search_req.query,
- user_name=user_context.mem_cube_id,
- top_k=search_req.top_k,
+ """Shared text-memory search via centralized search service."""
+ return search_text_memories(
+ text_mem=mem_cube.text_mem,
+ search_req=search_req,
+ user_context=user_context,
mode=mode,
- manual_close_internet=not search_req.internet_search,
- search_filter=search_filter,
- search_priority=search_priority,
- info={
- "user_id": search_req.user_id,
- "session_id": target_session_id,
- "chat_history": search_req.chat_history,
- },
+ include_embedding=(search_req.dedup == "mmr"),
)
- return search_results
def mix_search_memories(
self,
@@ -157,19 +143,13 @@ def mix_search_memories(
]
# Get mem_cube for fast search
- target_session_id = search_req.session_id
- if not target_session_id:
- target_session_id = "default_session"
- search_priority = {"session_id": search_req.session_id} if search_req.session_id else None
- search_filter = search_req.filter
+ search_ctx = build_search_context(search_req=search_req)
+ search_priority = search_ctx.search_priority
+ search_filter = search_ctx.search_filter
# Rerank Memories - reranker expects TextualMemoryItem objects
- info = {
- "user_id": search_req.user_id,
- "session_id": target_session_id,
- "chat_history": search_req.chat_history,
- }
+ info = search_ctx.info
raw_retrieved_memories = self.searcher.retrieve(
query=search_req.query,
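`build_search_context` itself isn't part of this patch, but judging from the inline logic it replaces, it plausibly centralizes the session defaulting, priority/filter extraction, and `info` dict. An inferred sketch only (field and function names assumed):

```python
# Inferred sketch of build_search_context, reconstructed from the removed inline code.
from dataclasses import dataclass

@dataclass
class SearchContext:
    search_priority: dict | None
    search_filter: dict | None
    info: dict

def build_search_context_sketch(search_req) -> SearchContext:
    session_id = search_req.session_id or "default_session"
    return SearchContext(
        search_priority={"session_id": search_req.session_id} if search_req.session_id else None,
        search_filter=search_req.filter,
        info={
            "user_id": search_req.user_id,
            "session_id": session_id,
            "chat_history": search_req.chat_history,
        },
    )
```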
diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py
index c7f270f19..c11d30470 100644
--- a/src/memos/mem_scheduler/schemas/message_schemas.py
+++ b/src/memos/mem_scheduler/schemas/message_schemas.py
@@ -1,3 +1,5 @@
+import json
+
from datetime import datetime
from typing import Any
from uuid import uuid4
@@ -9,6 +11,7 @@
from memos.log import get_logger
from memos.mem_scheduler.general_modules.misc import DictConversionMixin
from memos.mem_scheduler.utils.db_utils import get_utc_now
+from memos.types.general_types import UserContext
from .general_schemas import NOT_INITIALIZED
@@ -55,6 +58,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin):
description="Optional business-level task ID. Multiple items can share the same task_id.",
)
chat_history: list | None = Field(default=None, description="user chat history")
+ user_context: UserContext | None = Field(default=None, description="user context")
# Pydantic V2 model configuration
model_config = ConfigDict(
@@ -79,7 +83,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin):
def to_dict(self) -> dict:
"""Convert model to dictionary suitable for Redis Stream"""
- return {
+ raw = {
"item_id": self.item_id,
"user_id": self.user_id,
"cube_id": self.mem_cube_id,
@@ -91,22 +95,65 @@ def to_dict(self) -> dict:
"user_name": self.user_name,
"task_id": self.task_id if self.task_id is not None else "",
"chat_history": self.chat_history if self.chat_history is not None else [],
+ "user_context": self.user_context.model_dump(exclude_none=True)
+ if self.user_context
+ else None,
}
+ return {key: self._serialize_redis_value(value) for key, value in raw.items()}
+
+ @staticmethod
+ def _serialize_redis_value(value: Any) -> Any:
+ if value is None:
+ return ""
+ if isinstance(value, list | dict):
+ return json.dumps(value, ensure_ascii=False)
+ return value
@classmethod
def from_dict(cls, data: dict) -> "ScheduleMessageItem":
"""Create model from Redis Stream dictionary"""
+
+ def _decode(val: Any) -> Any:
+ if isinstance(val, bytes | bytearray):
+ return val.decode("utf-8")
+ return val
+
+ raw_chat_history = _decode(data.get("chat_history"))
+ if isinstance(raw_chat_history, str):
+ if raw_chat_history:
+ try:
+ chat_history = json.loads(raw_chat_history)
+ except Exception:
+ chat_history = None
+ else:
+ chat_history = None
+ else:
+ chat_history = raw_chat_history
+
+ raw_user_context = _decode(data.get("user_context"))
+ if isinstance(raw_user_context, str):
+ if raw_user_context:
+ try:
+ raw_user_context = json.loads(raw_user_context)
+ except Exception:
+ raw_user_context = None
+ else:
+ raw_user_context = None
+
+ raw_timestamp = _decode(data.get("timestamp"))
+ timestamp = datetime.fromisoformat(raw_timestamp) if raw_timestamp else get_utc_now()
return cls(
- item_id=data.get("item_id", str(uuid4())),
- user_id=data["user_id"],
- mem_cube_id=data["cube_id"],
- trace_id=data.get("trace_id", generate_trace_id()),
- label=data["label"],
- content=data["content"],
- timestamp=datetime.fromisoformat(data["timestamp"]),
- user_name=data.get("user_name"),
- task_id=data.get("task_id"),
- chat_history=data.get("chat_history"),
+ item_id=_decode(data.get("item_id", str(uuid4()))),
+ user_id=_decode(data["user_id"]),
+ mem_cube_id=_decode(data["cube_id"]),
+ trace_id=_decode(data.get("trace_id", generate_trace_id())),
+ label=_decode(data["label"]),
+ content=_decode(data["content"]),
+ timestamp=timestamp,
+ user_name=_decode(data.get("user_name")),
+ task_id=_decode(data.get("task_id")),
+ chat_history=chat_history,
+ user_context=UserContext.model_validate(raw_user_context) if raw_user_context else None,
)
diff --git a/src/memos/mem_scheduler/task_schedule_modules/base_handler.py b/src/memos/mem_scheduler/task_schedule_modules/base_handler.py
new file mode 100644
index 000000000..603b038e1
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/base_handler.py
@@ -0,0 +1,68 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing import TYPE_CHECKING
+
+from memos.log import get_logger
+from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
+
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
+ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+ from memos.mem_scheduler.task_schedule_modules.context import SchedulerHandlerContext
+
+
+logger = get_logger(__name__)
+
+
+class BaseSchedulerHandler:
+ def __init__(self, scheduler_context: SchedulerHandlerContext) -> None:
+ self.scheduler_context = scheduler_context
+
+ @property
+ @abstractmethod
+ def expected_task_label(self) -> str:
+ """The expected task label for this handler."""
+ ...
+
+ def validate_and_log_messages(self, messages: list[ScheduleMessageItem], label: str) -> None:
+ logger.info(f"Messages {messages} assigned to {label} handler.")
+ self.scheduler_context.services.validate_messages(messages=messages, label=label)
+
+ def handle_exception(self, e: Exception, message: str = "Error processing messages") -> None:
+ logger.error(f"{message}: {e}", exc_info=True)
+
+ def process_grouped_messages(
+ self,
+ messages: list[ScheduleMessageItem],
+ message_handler: Callable[[str, str, list[ScheduleMessageItem]], None],
+ ) -> None:
+ grouped_messages = group_messages_by_user_and_mem_cube(messages=messages)
+ for user_id, user_batches in grouped_messages.items():
+ for mem_cube_id, batch in user_batches.items():
+ if not batch:
+ continue
+ try:
+ message_handler(user_id, mem_cube_id, batch)
+ except Exception as e:
+ self.handle_exception(
+ e, f"Error processing batch for user {user_id}, mem_cube {mem_cube_id}"
+ )
+
+ @abstractmethod
+ def batch_handler(
+ self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]
+ ) -> None: ...
+
+ def __call__(self, messages: list[ScheduleMessageItem]) -> None:
+ """
+ Process the messages.
+ """
+ self.validate_and_log_messages(messages=messages, label=self.expected_task_label)
+
+ self.process_grouped_messages(
+ messages=messages,
+ message_handler=self.batch_handler,
+ )
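The base class is a template method: `__call__` validates, groups by (user, mem cube), and hands batches to `batch_handler`, so a concrete handler only supplies the label and the batch logic. A hedged sketch, assuming the imports from this module (the label is hypothetical):

```python
# Minimal concrete handler: the base class's __call__ drives everything else.
class EchoMessageHandler(BaseSchedulerHandler):
    @property
    def expected_task_label(self) -> str:
        return "echo"  # hypothetical label, for illustration only

    def batch_handler(self, user_id, mem_cube_id, batch) -> None:
        for msg in batch:
            logger.info("echo %s/%s: %s", user_id, mem_cube_id, msg.content)
```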
diff --git a/src/memos/mem_scheduler/task_schedule_modules/context.py b/src/memos/mem_scheduler/task_schedule_modules/context.py
new file mode 100644
index 000000000..d5c1ea9af
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/context.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
+ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+ from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
+ from memos.memories.textual.item import TextualMemoryItem
+
+
+@dataclass(frozen=True)
+class SchedulerHandlerServices:
+ validate_messages: Callable[[list[ScheduleMessageItem], str], None]
+ submit_messages: Callable[[list[ScheduleMessageItem]], None]
+ create_event_log: Callable[..., Any]
+ submit_web_logs: Callable[..., None]
+ map_memcube_name: Callable[[str], str]
+ update_activation_memory_periodically: Callable[..., None]
+ replace_working_memory: Callable[
+ [str, str, Any, list[TextualMemoryItem], list[TextualMemoryItem]],
+ list[TextualMemoryItem] | None,
+ ]
+ transform_working_memories_to_monitors: Callable[..., list[MemoryMonitorItem]]
+ log_working_memory_replacement: Callable[..., None]
+
+
+@dataclass(frozen=True)
+class SchedulerHandlerContext:
+ get_mem_cube: Callable[[], Any]
+ get_monitor: Callable[[], Any]
+ get_retriever: Callable[[], Any]
+ get_mem_reader: Callable[[], Any]
+ get_feedback_server: Callable[[], Any]
+ get_search_method: Callable[[], str]
+ get_top_k: Callable[[], int]
+ get_enable_activation_memory: Callable[[], bool]
+ get_query_key_words_limit: Callable[[], int]
+ services: SchedulerHandlerServices
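Every field on these dataclasses is a callable rather than a value, so handlers always read the scheduler's current state instead of a snapshot taken at registration time. A hedged wiring sketch (all scheduler attribute and method names below are assumptions, not shown in this patch):

```python
# Hypothetical wiring: bound methods for services, lambdas for live state access.
services = SchedulerHandlerServices(
    validate_messages=scheduler.validate_schedule_messages,
    submit_messages=scheduler.submit_messages,
    create_event_log=scheduler.create_event_log,
    submit_web_logs=scheduler.submit_web_logs,
    map_memcube_name=scheduler.map_memcube_name,
    update_activation_memory_periodically=scheduler.update_activation_memory_periodically,
    replace_working_memory=scheduler.replace_working_memory,
    transform_working_memories_to_monitors=scheduler.transform_working_memories_to_monitors,
    log_working_memory_replacement=scheduler.log_working_memory_replacement,
)
context = SchedulerHandlerContext(
    get_mem_cube=lambda: scheduler.current_mem_cube,
    get_monitor=lambda: scheduler.monitor,
    get_retriever=lambda: scheduler.retriever,
    get_mem_reader=lambda: scheduler.mem_reader,
    get_feedback_server=lambda: scheduler.feedback_server,
    get_search_method=lambda: scheduler.search_method,
    get_top_k=lambda: scheduler.top_k,
    get_enable_activation_memory=lambda: scheduler.enable_activation_memory,
    get_query_key_words_limit=lambda: scheduler.query_key_words_limit,
    services=services,
)
```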
diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
index 2099da5a1..74ab15209 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
@@ -20,7 +20,7 @@
DEFAULT_STOP_WAIT,
)
from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem, ScheduleMessageItem
-from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem
+from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem, TaskPriorityLevel
from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue
@@ -428,34 +428,70 @@ def get_running_task_count(self) -> int:
with self._task_lock:
return len(self._running_tasks)
- def register_handler(self, label: str, handler: Callable[[list[ScheduleMessageItem]], None]):
+ def register_handler(
+ self,
+ label: str,
+ handler: Callable[[list[ScheduleMessageItem]], None],
+ priority: TaskPriorityLevel | None = None,
+ min_idle_ms: int | None = None,
+ ):
"""
Register a handler function for a specific message label.
Args:
label: Message label to handle
handler: Callable that processes messages of this label
+ priority: Optional priority level for the task
+ min_idle_ms: Optional minimum idle time for task claiming
"""
self.handlers[label] = handler
+ if self.orchestrator:
+ self.orchestrator.set_task_config(
+ task_label=label, priority=priority, min_idle_ms=min_idle_ms
+ )
def register_handlers(
- self, handlers: dict[str, Callable[[list[ScheduleMessageItem]], None]]
+ self,
+ handlers: dict[
+ str,
+ Callable[[list[ScheduleMessageItem]], None]
+ | tuple[
+ Callable[[list[ScheduleMessageItem]], None], TaskPriorityLevel | None, int | None
+ ],
+ ],
) -> None:
"""
Bulk register multiple handlers from a dictionary.
Args:
- handlers: Dictionary mapping labels to handler functions
- Format: {label: handler_callable}
+ handlers: Dictionary where key is label and value is either:
+ - handler_callable
+ - tuple(handler_callable, priority, min_idle_ms)
"""
- for label, handler in handlers.items():
+ for label, value in handlers.items():
if not isinstance(label, str):
logger.error(f"Invalid label type: {type(label)}. Expected str.")
continue
+
+ if isinstance(value, tuple):
+ if len(value) != 3:
+ logger.error(
+ f"Invalid handler tuple for label '{label}'. Expected (handler, priority, min_idle_ms)."
+ )
+ continue
+ handler, priority, min_idle_ms = value
+ else:
+ handler = value
+ priority = None
+ min_idle_ms = None
+
if not callable(handler):
logger.error(f"Handler for label '{label}' is not callable.")
continue
- self.register_handler(label=label, handler=handler)
+
+ self.register_handler(
+ label=label, handler=handler, priority=priority, min_idle_ms=min_idle_ms
+ )
logger.info(f"Registered {len(handlers)} handlers in bulk")
def unregister_handler(self, label: str) -> bool:
@@ -470,6 +506,8 @@ def unregister_handler(self, label: str) -> bool:
"""
if label in self.handlers:
del self.handlers[label]
+ if self.orchestrator:
+ self.orchestrator.remove_task_config(label)
logger.info(f"Unregistered handler for label: {label}")
return True
else:
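With this change, `register_handlers` accepts either the legacy bare callable or a `(handler, priority, min_idle_ms)` tuple per label. A usage sketch (the handler names and the `TaskPriorityLevel` member are assumptions):

```python
# Mixed-format registration: bare callables and 3-tuples can coexist.
dispatcher.register_handlers(
    {
        "query": handle_query_messages,                               # legacy form
        "add": (handle_add_messages, TaskPriorityLevel.HIGH, 5_000),  # enum member assumed
        "answer": (handle_answer_messages, None, None),               # tuple with defaults
    }
)
```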
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/__init__.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/__init__.py
new file mode 100644
index 000000000..e5700e641
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/__init__.py
@@ -0,0 +1,12 @@
+from memos.mem_scheduler.task_schedule_modules.context import (
+ SchedulerHandlerContext,
+ SchedulerHandlerServices,
+)
+from memos.mem_scheduler.task_schedule_modules.registry import SchedulerHandlerRegistry
+
+
+__all__ = [
+ "SchedulerHandlerContext",
+ "SchedulerHandlerRegistry",
+ "SchedulerHandlerServices",
+]
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py
new file mode 100644
index 000000000..63718fd92
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py
@@ -0,0 +1,299 @@
+from __future__ import annotations
+
+import json
+
+from typing import TYPE_CHECKING
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.task_schemas import (
+ ADD_TASK_LABEL,
+ LONG_TERM_MEMORY_TYPE,
+ USER_INPUT_TYPE,
+)
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler
+from memos.mem_scheduler.utils.filter_utils import transform_name_to_key
+from memos.mem_scheduler.utils.misc_utils import is_cloud_env
+
+
+if TYPE_CHECKING:
+ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+ from memos.memories.textual.item import TextualMemoryItem
+
+
+logger = get_logger(__name__)
+
+
+class AddMessageHandler(BaseSchedulerHandler):
+ @property
+ def expected_task_label(self) -> str:
+ return ADD_TASK_LABEL
+
+ def batch_handler(
+ self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]
+ ) -> None:
+ for msg in batch:
+ prepared_add_items, prepared_update_items_with_original = self.log_add_messages(msg=msg)
+ logger.info(
+ "prepared_add_items: %s;\n prepared_update_items_with_original: %s",
+ prepared_add_items,
+ prepared_update_items_with_original,
+ )
+ cloud_env = is_cloud_env()
+
+ if cloud_env:
+ self.send_add_log_messages_to_cloud_env(
+ msg, prepared_add_items, prepared_update_items_with_original
+ )
+ else:
+ self.send_add_log_messages_to_local_env(
+ msg, prepared_add_items, prepared_update_items_with_original
+ )
+
+ def log_add_messages(self, msg: ScheduleMessageItem):
+ try:
+ userinput_memory_ids = json.loads(msg.content)
+ except Exception as e:
+ logger.error(f"Error: {e}. Content: {msg.content}", exc_info=True)
+ userinput_memory_ids = []
+
+ prepared_add_items = []
+ prepared_update_items_with_original = []
+ missing_ids: list[str] = []
+
+ mem_cube = self.scheduler_context.get_mem_cube()
+
+ for memory_id in userinput_memory_ids:
+ try:
+                mem_item: TextualMemoryItem | None = mem_cube.text_mem.get(
+                    memory_id=memory_id, user_name=msg.mem_cube_id
+                )
+                if mem_item is None:
+                    raise ValueError(f"Memory {memory_id} not found")
+ key = getattr(mem_item.metadata, "key", None) or transform_name_to_key(
+ name=mem_item.memory
+ )
+ exists = False
+ original_content = None
+ original_item_id = None
+
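+                # An existing node with the same key and memory_type means this
+                # is an update of a prior memory rather than a fresh add.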
+ if key and hasattr(mem_cube.text_mem, "graph_store"):
+ candidates = mem_cube.text_mem.graph_store.get_by_metadata(
+ [
+ {"field": "key", "op": "=", "value": key},
+ {
+ "field": "memory_type",
+ "op": "=",
+ "value": mem_item.metadata.memory_type,
+ },
+ ]
+ )
+ if candidates:
+ exists = True
+ original_item_id = candidates[0]
+ original_mem_item = mem_cube.text_mem.get(
+ memory_id=original_item_id, user_name=msg.mem_cube_id
+ )
+ original_content = original_mem_item.memory
+
+ if exists:
+ prepared_update_items_with_original.append(
+ {
+ "new_item": mem_item,
+ "original_content": original_content,
+ "original_item_id": original_item_id,
+ }
+ )
+ else:
+ prepared_add_items.append(mem_item)
+
+ except Exception:
+ missing_ids.append(memory_id)
+ logger.debug(
+ "This MemoryItem %s has already been deleted or an error occurred during preparation.",
+ memory_id,
+ )
+
+ if missing_ids:
+ content_preview = (
+ msg.content[:200] + "..."
+ if isinstance(msg.content, str) and len(msg.content) > 200
+ else msg.content
+ )
+ logger.warning(
+ "Missing TextualMemoryItem(s) during add log preparation. "
+ "memory_ids=%s user_id=%s mem_cube_id=%s task_id=%s item_id=%s redis_msg_id=%s label=%s stream_key=%s content_preview=%s",
+ missing_ids,
+ msg.user_id,
+ msg.mem_cube_id,
+ msg.task_id,
+ msg.item_id,
+ getattr(msg, "redis_message_id", ""),
+ msg.label,
+ getattr(msg, "stream_key", ""),
+ content_preview,
+ )
+
+ if not prepared_add_items and not prepared_update_items_with_original:
+ logger.warning(
+ "No add/update items prepared; skipping addMemory/knowledgeBaseUpdate logs. "
+ "user_id=%s mem_cube_id=%s task_id=%s item_id=%s redis_msg_id=%s label=%s stream_key=%s missing_ids=%s",
+ msg.user_id,
+ msg.mem_cube_id,
+ msg.task_id,
+ msg.item_id,
+ getattr(msg, "redis_message_id", ""),
+ msg.label,
+ getattr(msg, "stream_key", ""),
+ missing_ids,
+ )
+ return prepared_add_items, prepared_update_items_with_original
+
+ def send_add_log_messages_to_local_env(
+ self,
+ msg: ScheduleMessageItem,
+ prepared_add_items,
+ prepared_update_items_with_original,
+ ) -> None:
+ add_content_legacy: list[dict] = []
+ add_meta_legacy: list[dict] = []
+ update_content_legacy: list[dict] = []
+ update_meta_legacy: list[dict] = []
+
+ for item in prepared_add_items:
+ key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory)
+ add_content_legacy.append({"content": f"{key}: {item.memory}", "ref_id": item.id})
+ add_meta_legacy.append(
+ {
+ "ref_id": item.id,
+ "id": item.id,
+ "key": item.metadata.key,
+ "memory": item.memory,
+ "memory_type": item.metadata.memory_type,
+ "status": item.metadata.status,
+ "confidence": item.metadata.confidence,
+ "tags": item.metadata.tags,
+ "updated_at": getattr(item.metadata, "updated_at", None)
+ or getattr(item.metadata, "update_at", None),
+ }
+ )
+
+ for item_data in prepared_update_items_with_original:
+ item = item_data["new_item"]
+ key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory)
+ update_content_legacy.append({"content": f"{key}: {item.memory}", "ref_id": item.id})
+ update_meta_legacy.append(
+ {
+ "ref_id": item.id,
+ "id": item.id,
+ "key": item.metadata.key,
+ "memory": item.memory,
+ "memory_type": item.metadata.memory_type,
+ "status": item.metadata.status,
+ "confidence": item.metadata.confidence,
+ "tags": item.metadata.tags,
+ "updated_at": getattr(item.metadata, "updated_at", None)
+ or getattr(item.metadata, "update_at", None),
+ }
+ )
+
+ events = []
+ if add_content_legacy:
+ event = self.scheduler_context.services.create_event_log(
+ label="addMemory",
+ from_memory_type=USER_INPUT_TYPE,
+ to_memory_type=LONG_TERM_MEMORY_TYPE,
+ user_id=msg.user_id,
+ mem_cube_id=msg.mem_cube_id,
+ mem_cube=self.scheduler_context.get_mem_cube(),
+ memcube_log_content=add_content_legacy,
+ metadata=add_meta_legacy,
+ memory_len=len(add_content_legacy),
+ memcube_name=self.scheduler_context.services.map_memcube_name(msg.mem_cube_id),
+ )
+ event.task_id = msg.task_id
+ events.append(event)
+ if update_content_legacy:
+ event = self.scheduler_context.services.create_event_log(
+ label="updateMemory",
+ from_memory_type=LONG_TERM_MEMORY_TYPE,
+ to_memory_type=LONG_TERM_MEMORY_TYPE,
+ user_id=msg.user_id,
+ mem_cube_id=msg.mem_cube_id,
+ mem_cube=self.scheduler_context.get_mem_cube(),
+ memcube_log_content=update_content_legacy,
+ metadata=update_meta_legacy,
+ memory_len=len(update_content_legacy),
+ memcube_name=self.scheduler_context.services.map_memcube_name(msg.mem_cube_id),
+ )
+ event.task_id = msg.task_id
+ events.append(event)
+ logger.info("send_add_log_messages_to_local_env: %s", len(events))
+ if events:
+ self.scheduler_context.services.submit_web_logs(
+                events, additional_log_info="send_add_log_messages_to_local_env"
+ )
+
+ def send_add_log_messages_to_cloud_env(
+ self,
+ msg: ScheduleMessageItem,
+ prepared_add_items,
+ prepared_update_items_with_original,
+ ) -> None:
+ kb_log_content: list[dict] = []
+ info = msg.info or {}
+
+ for item in prepared_add_items:
+ metadata = getattr(item, "metadata", None)
+ file_ids = getattr(metadata, "file_ids", None) if metadata else None
+ source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None
+ kb_log_content.append(
+ {
+ "log_source": "KNOWLEDGE_BASE_LOG",
+ "trigger_source": info.get("trigger_source", "Messages"),
+ "operation": "ADD",
+ "memory_id": item.id,
+ "content": item.memory,
+ "original_content": None,
+ "source_doc_id": source_doc_id,
+ }
+ )
+
+ for item_data in prepared_update_items_with_original:
+ item = item_data["new_item"]
+ metadata = getattr(item, "metadata", None)
+ file_ids = getattr(metadata, "file_ids", None) if metadata else None
+ source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None
+ kb_log_content.append(
+ {
+ "log_source": "KNOWLEDGE_BASE_LOG",
+ "trigger_source": info.get("trigger_source", "Messages"),
+ "operation": "UPDATE",
+ "memory_id": item.id,
+ "content": item.memory,
+ "original_content": item_data.get("original_content"),
+ "source_doc_id": source_doc_id,
+ }
+ )
+
+ if kb_log_content:
+ logger.info(
+ "[DIAGNOSTIC] add_handler.send_add_log_messages_to_cloud_env: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: %s, mem_cube_id: %s, task_id: %s. KB content: %s",
+ msg.user_id,
+ msg.mem_cube_id,
+ msg.task_id,
+ json.dumps(kb_log_content, indent=2),
+ )
+ event = self.scheduler_context.services.create_event_log(
+ label="knowledgeBaseUpdate",
+ from_memory_type=USER_INPUT_TYPE,
+ to_memory_type=LONG_TERM_MEMORY_TYPE,
+ user_id=msg.user_id,
+ mem_cube_id=msg.mem_cube_id,
+ mem_cube=self.scheduler_context.get_mem_cube(),
+ memcube_log_content=kb_log_content,
+ metadata=None,
+ memory_len=len(kb_log_content),
+ memcube_name=self.scheduler_context.services.map_memcube_name(msg.mem_cube_id),
+ )
+ event.log_content = f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
+ event.task_id = msg.task_id
+ self.scheduler_context.services.submit_web_logs([event])
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/answer_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/answer_handler.py
new file mode 100644
index 000000000..8ca56f859
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/answer_handler.py
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.task_schemas import (
+ ANSWER_TASK_LABEL,
+ NOT_APPLICABLE_TYPE,
+ USER_INPUT_TYPE,
+)
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler
+
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+
+
+class AnswerMessageHandler(BaseSchedulerHandler):
+ @property
+ def expected_task_label(self) -> str:
+ return ANSWER_TASK_LABEL
+
+ def batch_handler(
+ self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]
+ ) -> None:
+ for msg in batch:
+ event = self.scheduler_context.services.create_event_log(
+ label="addMessage",
+ from_memory_type=USER_INPUT_TYPE,
+ to_memory_type=NOT_APPLICABLE_TYPE,
+ user_id=msg.user_id,
+ mem_cube_id=msg.mem_cube_id,
+ mem_cube=self.scheduler_context.get_mem_cube(),
+ memcube_log_content=[
+ {
+ "content": f"[Assistant] {msg.content}",
+ "ref_id": msg.item_id,
+ "role": "assistant",
+ }
+ ],
+ metadata=[],
+ memory_len=1,
+ memcube_name=self.scheduler_context.services.map_memcube_name(msg.mem_cube_id),
+ )
+ event.task_id = msg.task_id
+ self.scheduler_context.services.submit_web_logs([event])
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/feedback_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/feedback_handler.py
new file mode 100644
index 000000000..173d37b50
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/feedback_handler.py
@@ -0,0 +1,196 @@
+from __future__ import annotations
+
+import json
+
+from typing import TYPE_CHECKING
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.task_schemas import (
+ LONG_TERM_MEMORY_TYPE,
+ MEM_FEEDBACK_TASK_LABEL,
+ USER_INPUT_TYPE,
+)
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler
+from memos.mem_scheduler.utils.misc_utils import is_cloud_env
+
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+
+
+class FeedbackMessageHandler(BaseSchedulerHandler):
+ @property
+ def expected_task_label(self) -> str:
+ return MEM_FEEDBACK_TASK_LABEL
+
+ def batch_handler(
+ self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]
+ ) -> None:
+ for message in batch:
+ try:
+ self.process_single_feedback(message)
+ except Exception as e:
+ logger.error(
+ "Error processing feedbackMemory message: %s",
+ e,
+ exc_info=True,
+ )
+
+ def process_single_feedback(self, message: ScheduleMessageItem) -> None:
+ mem_cube = self.scheduler_context.get_mem_cube()
+
+ user_id = message.user_id
+ mem_cube_id = message.mem_cube_id
+ content = message.content
+
+ try:
+ feedback_data = json.loads(content) if isinstance(content, str) else content
+ if not isinstance(feedback_data, dict):
+ logger.error(
+ "Failed to decode feedback_data or it is not a dict: %s", feedback_data
+ )
+ return
+ except json.JSONDecodeError:
+ logger.error("Invalid JSON content for feedback message: %s", content, exc_info=True)
+ return
+
+ task_id = feedback_data.get("task_id") or message.task_id
+ feedback_result = self.scheduler_context.get_feedback_server().process_feedback(
+ user_id=user_id,
+ user_name=mem_cube_id,
+ session_id=feedback_data.get("session_id"),
+ chat_history=feedback_data.get("history", []),
+ retrieved_memory_ids=feedback_data.get("retrieved_memory_ids", []),
+ feedback_content=feedback_data.get("feedback_content"),
+ feedback_time=feedback_data.get("feedback_time"),
+ task_id=task_id,
+ info=feedback_data.get("info", None),
+ )
+
+ logger.info(
+ "Successfully processed feedback for user_id=%s, mem_cube_id=%s",
+ user_id,
+ mem_cube_id,
+ )
+
+ cloud_env = is_cloud_env()
+ if cloud_env:
+ record = feedback_result.get("record") if isinstance(feedback_result, dict) else {}
+ add_records = record.get("add") if isinstance(record, dict) else []
+ update_records = record.get("update") if isinstance(record, dict) else []
+
+ def _extract_fields(mem_item):
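+                # Normalize a feedback record that may be an object with
+                # attributes or a plain dict; returns
+                # (mem_id, mem_memory, original_content, source_doc_id).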
+ mem_id = (
+ getattr(mem_item, "id", None)
+ if not isinstance(mem_item, dict)
+ else mem_item.get("id")
+ )
+                mem_memory = (
+                    getattr(mem_item, "memory", None)
+                    if not isinstance(mem_item, dict)
+                    else mem_item.get("memory") or mem_item.get("text")
+                )
+ original_content = (
+ getattr(mem_item, "origin_memory", None)
+ if not isinstance(mem_item, dict)
+ else mem_item.get("origin_memory")
+ or mem_item.get("old_memory")
+ or mem_item.get("original_content")
+ )
+ source_doc_id = None
+ if isinstance(mem_item, dict):
+ source_doc_id = mem_item.get("source_doc_id", None)
+
+ return mem_id, mem_memory, original_content, source_doc_id
+
+ kb_log_content: list[dict] = []
+
+ for mem_item in add_records or []:
+ mem_id, mem_memory, _, source_doc_id = _extract_fields(mem_item)
+ if mem_id and mem_memory:
+ kb_log_content.append(
+ {
+ "log_source": "KNOWLEDGE_BASE_LOG",
+ "trigger_source": "Feedback",
+ "operation": "ADD",
+ "memory_id": mem_id,
+ "content": mem_memory,
+ "original_content": None,
+ "source_doc_id": source_doc_id,
+ }
+ )
+ else:
+ logger.warning(
+ "Skipping malformed feedback add item. user_id=%s mem_cube_id=%s task_id=%s item=%s",
+ user_id,
+ mem_cube_id,
+ task_id,
+ mem_item,
+ stack_info=True,
+ )
+
+ for mem_item in update_records or []:
+ mem_id, mem_memory, original_content, source_doc_id = _extract_fields(mem_item)
+ if mem_id and mem_memory:
+ kb_log_content.append(
+ {
+ "log_source": "KNOWLEDGE_BASE_LOG",
+ "trigger_source": "Feedback",
+ "operation": "UPDATE",
+ "memory_id": mem_id,
+ "content": mem_memory,
+ "original_content": original_content,
+ "source_doc_id": source_doc_id,
+ }
+ )
+ else:
+ logger.warning(
+ "Skipping malformed feedback update item. user_id=%s mem_cube_id=%s task_id=%s item=%s",
+ user_id,
+ mem_cube_id,
+ task_id,
+ mem_item,
+ stack_info=True,
+ )
+
+ logger.info("[Feedback Scheduler] kb_log_content: %s", kb_log_content)
+ if kb_log_content:
+ logger.info(
+ "[DIAGNOSTIC] feedback_handler: Creating knowledgeBaseUpdate event for feedback. user_id=%s mem_cube_id=%s task_id=%s items=%s",
+ user_id,
+ mem_cube_id,
+ task_id,
+ len(kb_log_content),
+ )
+ event = self.scheduler_context.services.create_event_log(
+ label="knowledgeBaseUpdate",
+ from_memory_type=USER_INPUT_TYPE,
+ to_memory_type=LONG_TERM_MEMORY_TYPE,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ memcube_log_content=kb_log_content,
+ metadata=None,
+ memory_len=len(kb_log_content),
+ memcube_name=self.scheduler_context.services.map_memcube_name(mem_cube_id),
+ )
+ event.log_content = f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
+ event.task_id = task_id
+ self.scheduler_context.services.submit_web_logs([event])
+ else:
+ logger.warning(
+ "No valid feedback content generated for web log. user_id=%s mem_cube_id=%s task_id=%s",
+ user_id,
+ mem_cube_id,
+ task_id,
+ stack_info=True,
+ )
+ else:
+ logger.info(
+ "Skipping web log for feedback. Not in a cloud environment (is_cloud_env=%s)",
+ cloud_env,
+ )
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py
new file mode 100644
index 000000000..5d86c5589
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py
@@ -0,0 +1,409 @@
+from __future__ import annotations
+
+import concurrent.futures
+import contextlib
+import json
+import traceback
+
+from typing import TYPE_CHECKING
+
+from memos.context.context import ContextThreadPoolExecutor
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.task_schemas import (
+ LONG_TERM_MEMORY_TYPE,
+ MEM_READ_TASK_LABEL,
+ USER_INPUT_TYPE,
+)
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler
+from memos.mem_scheduler.utils.filter_utils import transform_name_to_key
+from memos.mem_scheduler.utils.misc_utils import is_cloud_env
+from memos.memories.textual.tree import TreeTextMemory
+
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+ from memos.types.general_types import UserContext
+
+
+class MemReadMessageHandler(BaseSchedulerHandler):
+ @property
+ def expected_task_label(self) -> str:
+ return MEM_READ_TASK_LABEL
+
+ def batch_handler(
+ self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]
+ ) -> None:
+ logger.info(
+ "[DIAGNOSTIC] mem_read_handler batch_handler called. Batch size: %s", len(batch)
+ )
+
+        with ContextThreadPoolExecutor(max_workers=min(8, max(1, len(batch)))) as executor:
+ futures = [executor.submit(self.process_message, msg) for msg in batch]
+ for future in concurrent.futures.as_completed(futures):
+ try:
+ future.result()
+ except Exception as e:
+ logger.error("Thread task failed: %s", e, stack_info=True)
+
+ def process_message(self, message: ScheduleMessageItem):
+ try:
+ user_id = message.user_id
+ mem_cube_id = message.mem_cube_id
+ mem_cube = self.scheduler_context.get_mem_cube()
+ if mem_cube is None:
+ logger.error(
+ "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing",
+ user_id,
+ mem_cube_id,
+ stack_info=True,
+ )
+ return
+
+ content = message.content
+ user_name = message.user_name
+ info = message.info or {}
+ chat_history = message.chat_history
+ user_context = message.user_context
+
+ mem_ids = json.loads(content) if isinstance(content, str) else content
+ if not mem_ids:
+ return
+
+ logger.info(
+ "Processing mem_read for user_id=%s, mem_cube_id=%s, mem_ids=%s",
+ user_id,
+ mem_cube_id,
+ mem_ids,
+ )
+
+ text_mem = mem_cube.text_mem
+ if not isinstance(text_mem, TreeTextMemory):
+ logger.error("Expected TreeTextMemory but got %s", type(text_mem).__name__)
+ return
+
+ self._process_memories_with_reader(
+ mem_ids=mem_ids,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ text_mem=text_mem,
+ user_name=user_name,
+ custom_tags=info.get("custom_tags", None),
+ task_id=message.task_id,
+ info=info,
+ chat_history=chat_history,
+ user_context=user_context,
+ )
+
+ logger.info(
+ "Successfully processed mem_read for user_id=%s, mem_cube_id=%s",
+ user_id,
+ mem_cube_id,
+ )
+
+ except Exception as e:
+ logger.error("Error processing mem_read message: %s", e, stack_info=True)
+
+ def _process_memories_with_reader(
+ self,
+ mem_ids: list[str],
+ user_id: str,
+ mem_cube_id: str,
+ text_mem: TreeTextMemory,
+ user_name: str,
+ custom_tags: list[str] | None = None,
+ task_id: str | None = None,
+ info: dict | None = None,
+ chat_history: list | None = None,
+ user_context: UserContext | None = None,
+ ) -> None:
+ logger.info(
+ "[DIAGNOSTIC] mem_read_handler._process_memories_with_reader called. mem_ids: %s, user_id: %s, mem_cube_id: %s, task_id: %s",
+ mem_ids,
+ user_id,
+ mem_cube_id,
+ task_id,
+ )
+ kb_log_content: list[dict] = []
+ try:
+ mem_reader = self.scheduler_context.get_mem_reader()
+ if mem_reader is None:
+ logger.warning(
+ "mem_reader not available in scheduler, skipping enhanced processing"
+ )
+ return
+
+ # Get the original fast memory (raw chunk) items
+ memory_items = []
+ for mem_id in mem_ids:
+ try:
+ memory_item = text_mem.get(mem_id, user_name=user_name)
+ memory_items.append(memory_item)
+ except Exception as e:
+ logger.warning(
+ "[_process_memories_with_reader] Failed to get memory %s: %s", mem_id, e
+ )
+ continue
+
+ if not memory_items:
+ logger.warning("No valid memory items found for processing")
+ return
+
+ from memos.memories.textual.tree_text_memory.organize.manager import (
+ extract_working_binding_ids,
+ )
+
+ bindings_to_delete = extract_working_binding_ids(memory_items)
+ logger.info(
+ "Extracted %s working_binding ids to cleanup: %s",
+ len(bindings_to_delete),
+ list(bindings_to_delete),
+ )
+
+ logger.info("Processing %s memories with mem_reader", len(memory_items))
+
+ try:
+ processed_memories = mem_reader.fine_transfer_simple_mem(
+ memory_items,
+ type="chat",
+ custom_tags=custom_tags,
+ user_name=user_name,
+ chat_history=chat_history,
+ user_context=user_context,
+ )
+ except Exception as e:
+ logger.warning("%s: Fail to transfer mem: %s", e, memory_items)
+ processed_memories = []
+
+            if processed_memories:
+ flattened_memories = []
+ for memory_list in processed_memories:
+ flattened_memories.extend(memory_list)
+
+ logger.info("mem_reader processed %s enhanced memories", len(flattened_memories))
+
+ if flattened_memories:
+ mem_group = [
+ memory
+ for memory in flattened_memories
+ if memory.metadata.memory_type != "RawFileMemory"
+ ]
+ enhanced_mem_ids = text_mem.add(mem_group, user_name=user_name)
+ logger.info(
+ "Added %s enhanced memories: %s",
+ len(enhanced_mem_ids),
+ enhanced_mem_ids,
+ )
+
+ # add raw file nodes and edges
+ if mem_reader.save_rawfile:
+ raw_file_mem_group = [
+ memory
+ for memory in flattened_memories
+ if memory.metadata.memory_type == "RawFileMemory"
+ ]
+ text_mem.add_rawfile_nodes_n_edges(
+ raw_file_mem_group,
+ enhanced_mem_ids,
+ user_id=user_id,
+ user_name=user_name,
+ )
+ logger.info("Added %s Rawfile memories.", len(raw_file_mem_group))
+
+ # Mark merged_from memories as archived when provided in memory metadata
+ summary_memories = [
+ memory
+ for memory in flattened_memories
+ if memory.metadata.memory_type != "RawFileMemory"
+ ]
+ if mem_reader.graph_db:
+ for memory in summary_memories:
+ merged_from = (memory.metadata.info or {}).get("merged_from")
+ if merged_from:
+ old_ids = (
+ merged_from
+                                    if isinstance(merged_from, (list, tuple, set))
+ else [merged_from]
+ )
+ for old_id in old_ids:
+ try:
+ mem_reader.graph_db.update_node(
+ str(old_id), {"status": "archived"}, user_name=user_name
+ )
+ logger.info(
+ "[Scheduler] Archived merged_from memory: %s",
+ old_id,
+ )
+ except Exception as e:
+ logger.warning(
+ "[Scheduler] Failed to archive merged_from memory %s: %s",
+ old_id,
+ e,
+ )
+ else:
+ has_merged_from = any(
+ (m.metadata.info or {}).get("merged_from") for m in summary_memories
+ )
+ if has_merged_from:
+ logger.warning(
+ "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving."
+ )
+
+ cloud_env = is_cloud_env()
+ if cloud_env:
+ kb_log_content = []
+ for item in flattened_memories:
+ metadata = getattr(item, "metadata", None)
+ file_ids = getattr(metadata, "file_ids", None) if metadata else None
+ source_doc_id = (
+ file_ids[0] if isinstance(file_ids, list) and file_ids else None
+ )
+ kb_log_content.append(
+ {
+ "log_source": "KNOWLEDGE_BASE_LOG",
+ "trigger_source": info.get("trigger_source", "Messages")
+ if info
+ else "Messages",
+ "operation": "ADD",
+ "memory_id": item.id,
+ "content": item.memory,
+ "original_content": None,
+ "source_doc_id": source_doc_id,
+ }
+ )
+ if kb_log_content:
+ logger.info(
+ "[DIAGNOSTIC] mem_read_handler: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: %s, mem_cube_id: %s, task_id: %s. KB content: %s",
+ user_id,
+ mem_cube_id,
+ task_id,
+ json.dumps(kb_log_content, indent=2),
+ )
+ event = self.scheduler_context.services.create_event_log(
+ label="knowledgeBaseUpdate",
+ from_memory_type=USER_INPUT_TYPE,
+ to_memory_type=LONG_TERM_MEMORY_TYPE,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=self.scheduler_context.get_mem_cube(),
+ memcube_log_content=kb_log_content,
+ metadata=None,
+ memory_len=len(kb_log_content),
+ memcube_name=self.scheduler_context.services.map_memcube_name(
+ mem_cube_id
+ ),
+ )
+ event.log_content = (
+ f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
+ )
+ event.task_id = task_id
+ self.scheduler_context.services.submit_web_logs([event])
+ else:
+ add_content_legacy: list[dict] = []
+ add_meta_legacy: list[dict] = []
+                            # Zip with mem_group: enhanced_mem_ids was produced from
+                            # mem_group, so zipping flattened_memories (which may
+                            # still contain RawFileMemory items) would misalign ids.
+                            for item_id, item in zip(
+                                enhanced_mem_ids, mem_group, strict=False
+                            ):
+ key = getattr(item.metadata, "key", None) or transform_name_to_key(
+ name=item.memory
+ )
+ add_content_legacy.append(
+ {"content": f"{key}: {item.memory}", "ref_id": item_id}
+ )
+ add_meta_legacy.append(
+ {
+ "ref_id": item_id,
+ "id": item_id,
+ "key": item.metadata.key,
+ "memory": item.memory,
+ "memory_type": item.metadata.memory_type,
+ "status": item.metadata.status,
+ "confidence": item.metadata.confidence,
+ "tags": item.metadata.tags,
+ "updated_at": getattr(item.metadata, "updated_at", None)
+ or getattr(item.metadata, "update_at", None),
+ }
+ )
+ if add_content_legacy:
+ event = self.scheduler_context.services.create_event_log(
+ label="addMemory",
+ from_memory_type=USER_INPUT_TYPE,
+ to_memory_type=LONG_TERM_MEMORY_TYPE,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=self.scheduler_context.get_mem_cube(),
+ memcube_log_content=add_content_legacy,
+ metadata=add_meta_legacy,
+ memory_len=len(add_content_legacy),
+ memcube_name=self.scheduler_context.services.map_memcube_name(
+ mem_cube_id
+ ),
+ )
+ event.task_id = task_id
+ self.scheduler_context.services.submit_web_logs([event])
+ else:
+ logger.info("No enhanced memories generated by mem_reader")
+ else:
+ logger.info("mem_reader returned no processed memories")
+
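+            # The raw fast-mode chunks (and any working-memory bindings) are
+            # superseded by the enhanced memories above, so clean them up.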
+ delete_ids = list(mem_ids)
+ if bindings_to_delete:
+ delete_ids.extend(list(bindings_to_delete))
+ delete_ids = list(dict.fromkeys(delete_ids))
+ if delete_ids:
+ try:
+ text_mem.delete(delete_ids, user_name=user_name)
+ logger.info(
+ "Delete raw/working mem_ids: %s for user_name: %s", delete_ids, user_name
+ )
+ except Exception as e:
+ logger.warning("Failed to delete some mem_ids %s: %s", delete_ids, e)
+ else:
+ logger.info("No mem_ids to delete (nothing to cleanup)")
+
+ text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name)
+ logger.info("Remove and Refresh Memories")
+ logger.debug("Finished add %s memory: %s", user_id, mem_ids)
+
+ except Exception as exc:
+            logger.error(
+                "Error in _process_memories_with_reader: %s",
+                traceback.format_exc(),
+            )
+ with contextlib.suppress(Exception):
+ cloud_env = is_cloud_env()
+ if cloud_env:
+ if not kb_log_content:
+ trigger_source = (
+ info.get("trigger_source", "Messages") if info else "Messages"
+ )
+ kb_log_content = [
+ {
+ "log_source": "KNOWLEDGE_BASE_LOG",
+ "trigger_source": trigger_source,
+ "operation": "ADD",
+ "memory_id": mem_id,
+ "content": None,
+ "original_content": None,
+ "source_doc_id": None,
+ }
+ for mem_id in mem_ids
+ ]
+ event = self.scheduler_context.services.create_event_log(
+ label="knowledgeBaseUpdate",
+ from_memory_type=USER_INPUT_TYPE,
+ to_memory_type=LONG_TERM_MEMORY_TYPE,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=self.scheduler_context.get_mem_cube(),
+ memcube_log_content=kb_log_content,
+ metadata=None,
+ memory_len=len(kb_log_content),
+ memcube_name=self.scheduler_context.services.map_memcube_name(mem_cube_id),
+ )
+ event.log_content = f"Knowledge Base Memory Update failed: {exc!s}"
+ event.task_id = task_id
+ event.status = "failed"
+ self.scheduler_context.services.submit_web_logs([event])
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_reorganize_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_reorganize_handler.py
new file mode 100644
index 000000000..1d1dbe566
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_reorganize_handler.py
@@ -0,0 +1,252 @@
+from __future__ import annotations
+
+import concurrent.futures
+import contextlib
+import json
+import traceback
+
+from typing import TYPE_CHECKING
+
+from memos.context.context import ContextThreadPoolExecutor
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.task_schemas import (
+ LONG_TERM_MEMORY_TYPE,
+ MEM_ORGANIZE_TASK_LABEL,
+)
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler
+from memos.mem_scheduler.utils.filter_utils import transform_name_to_key
+from memos.memories.textual.tree import TreeTextMemory
+
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+ from memos.memories.textual.item import TextualMemoryItem
+
+
+class MemReorganizeMessageHandler(BaseSchedulerHandler):
+ @property
+ def expected_task_label(self) -> str:
+ return MEM_ORGANIZE_TASK_LABEL
+
+ def batch_handler(
+ self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]
+ ) -> None:
+        with ContextThreadPoolExecutor(max_workers=min(8, max(1, len(batch)))) as executor:
+ futures = [executor.submit(self.process_message, msg) for msg in batch]
+ for future in concurrent.futures.as_completed(futures):
+ try:
+ future.result()
+ except Exception as e:
+ logger.error("Thread task failed: %s", e, exc_info=True)
+
+ def process_message(self, message: ScheduleMessageItem):
+ try:
+ user_id = message.user_id
+ mem_cube_id = message.mem_cube_id
+ mem_cube = self.scheduler_context.get_mem_cube()
+ if mem_cube is None:
+ logger.warning(
+ "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing",
+ user_id,
+ mem_cube_id,
+ )
+ return
+ content = message.content
+ user_name = message.user_name
+
+ mem_ids = json.loads(content) if isinstance(content, str) else content
+ if not mem_ids:
+ return
+
+ logger.info(
+ "Processing mem_reorganize for user_id=%s, mem_cube_id=%s, mem_ids=%s",
+ user_id,
+ mem_cube_id,
+ mem_ids,
+ )
+
+ text_mem = mem_cube.text_mem
+ if not isinstance(text_mem, TreeTextMemory):
+ logger.error("Expected TreeTextMemory but got %s", type(text_mem).__name__)
+ return
+
+ self._process_memories_with_reorganize(
+ mem_ids=mem_ids,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ text_mem=text_mem,
+ user_name=user_name,
+ )
+
+ with contextlib.suppress(Exception):
+ mem_items: list[TextualMemoryItem] = []
+ for mid in mem_ids:
+ with contextlib.suppress(Exception):
+ mem_items.append(text_mem.get(mid, user_name=user_name))
+ if len(mem_items) > 1:
+ keys: list[str] = []
+ memcube_content: list[dict] = []
+ meta: list[dict] = []
+ merged_target_ids: set[str] = set()
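+                    # Follow MERGED_TO edges out of the source ids to locate the
+                    # post-merge node, if the graph store recorded one.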
+ with contextlib.suppress(Exception):
+ if hasattr(text_mem, "graph_store"):
+ for mid in mem_ids:
+ edges = text_mem.graph_store.get_edges(
+ mid, type="MERGED_TO", direction="OUT"
+ )
+ for edge in edges:
+ target = edge.get("to") or edge.get("dst") or edge.get("target")
+ if target:
+ merged_target_ids.add(target)
+ for item in mem_items:
+ key = getattr(
+ getattr(item, "metadata", {}), "key", None
+ ) or transform_name_to_key(getattr(item, "memory", ""))
+ keys.append(key)
+ memcube_content.append(
+ {"content": key or "(no key)", "ref_id": item.id, "type": "merged"}
+ )
+ meta.append(
+ {
+ "ref_id": item.id,
+ "id": item.id,
+ "key": key,
+ "memory": item.memory,
+ "memory_type": item.metadata.memory_type,
+ "status": item.metadata.status,
+ "confidence": item.metadata.confidence,
+ "tags": item.metadata.tags,
+ "updated_at": getattr(item.metadata, "updated_at", None)
+ or getattr(item.metadata, "update_at", None),
+ }
+ )
+ combined_key = keys[0] if keys else ""
+ post_ref_id = None
+ post_meta = {
+ "ref_id": None,
+ "id": None,
+ "key": None,
+ "memory": None,
+ "memory_type": None,
+ "status": None,
+ "confidence": None,
+ "tags": None,
+ "updated_at": None,
+ }
+ if merged_target_ids:
+ post_ref_id = next(iter(merged_target_ids))
+ with contextlib.suppress(Exception):
+ merged_item = text_mem.get(post_ref_id, user_name=user_name)
+ combined_key = (
+ getattr(getattr(merged_item, "metadata", {}), "key", None)
+ or combined_key
+ )
+ post_meta = {
+ "ref_id": post_ref_id,
+ "id": post_ref_id,
+ "key": getattr(getattr(merged_item, "metadata", {}), "key", None),
+ "memory": getattr(merged_item, "memory", None),
+ "memory_type": getattr(
+ getattr(merged_item, "metadata", {}), "memory_type", None
+ ),
+ "status": getattr(
+ getattr(merged_item, "metadata", {}), "status", None
+ ),
+ "confidence": getattr(
+ getattr(merged_item, "metadata", {}), "confidence", None
+ ),
+ "tags": getattr(getattr(merged_item, "metadata", {}), "tags", None),
+ "updated_at": getattr(
+ getattr(merged_item, "metadata", {}), "updated_at", None
+ )
+ or getattr(getattr(merged_item, "metadata", {}), "update_at", None),
+ }
+ if not post_ref_id:
+ import hashlib
+
+ post_ref_id = (
+ "merge-" + hashlib.md5("".join(sorted(mem_ids)).encode()).hexdigest()
+ )
+ post_meta["ref_id"] = post_ref_id
+ post_meta["id"] = post_ref_id
+ if not post_meta.get("key"):
+ post_meta["key"] = combined_key
+ if not keys:
+ keys = [item.id for item in mem_items]
+ memcube_content.append(
+ {
+ "content": combined_key if combined_key else "(no key)",
+ "ref_id": post_ref_id,
+ "type": "postMerge",
+ }
+ )
+ meta.append(post_meta)
+ event = self.scheduler_context.services.create_event_log(
+ label="mergeMemory",
+ from_memory_type=LONG_TERM_MEMORY_TYPE,
+ to_memory_type=LONG_TERM_MEMORY_TYPE,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ memcube_log_content=memcube_content,
+ metadata=meta,
+ memory_len=len(keys),
+ memcube_name=self.scheduler_context.services.map_memcube_name(mem_cube_id),
+ )
+ self.scheduler_context.services.submit_web_logs([event])
+
+ logger.info(
+ "Successfully processed mem_reorganize for user_id=%s, mem_cube_id=%s",
+ user_id,
+ mem_cube_id,
+ )
+
+ except Exception as e:
+ logger.error("Error processing mem_reorganize message: %s", e, exc_info=True)
+
+ def _process_memories_with_reorganize(
+ self,
+ mem_ids: list[str],
+ user_id: str,
+ mem_cube_id: str,
+ mem_cube,
+ text_mem: TreeTextMemory,
+ user_name: str,
+ ) -> None:
+ try:
+ mem_reader = self.scheduler_context.get_mem_reader()
+ if mem_reader is None:
+ logger.warning(
+ "mem_reader not available in scheduler, skipping enhanced processing"
+ )
+ return
+
+ memory_items = []
+ for mem_id in mem_ids:
+ try:
+ memory_item = text_mem.get(mem_id, user_name=user_name)
+ memory_items.append(memory_item)
+ except Exception as e:
+ logger.warning(
+ "Failed to get memory %s: %s|%s", mem_id, e, traceback.format_exc()
+ )
+ continue
+
+ if not memory_items:
+ logger.warning("No valid memory items found for processing")
+ return
+
+ logger.info("Processing %s memories with mem_reader", len(memory_items))
+ text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name)
+ logger.info("Remove and Refresh Memories")
+ logger.debug("Finished add %s memory: %s", user_id, mem_ids)
+
+        except Exception:
+            logger.error(
+                "Error in _process_memories_with_reorganize: %s",
+                traceback.format_exc(),
+            )
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/memory_update_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/memory_update_handler.py
new file mode 100644
index 000000000..a8968e878
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/memory_update_handler.py
@@ -0,0 +1,277 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem
+from memos.mem_scheduler.schemas.task_schemas import (
+ DEFAULT_MAX_QUERY_KEY_WORDS,
+ MEM_UPDATE_TASK_LABEL,
+ QUERY_TASK_LABEL,
+)
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler
+from memos.mem_scheduler.utils.filter_utils import is_all_chinese, is_all_english
+from memos.memories.textual.naive import NaiveTextMemory
+from memos.memories.textual.tree import TreeTextMemory
+
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+ from memos.memories.textual.item import TextualMemoryItem
+ from memos.types import MemCubeID, UserID
+
+
+class MemoryUpdateHandler(BaseSchedulerHandler):
+ @property
+ def expected_task_label(self) -> str:
+ return MEM_UPDATE_TASK_LABEL
+
+ def batch_handler(
+ self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]
+ ) -> None:
+ self.long_memory_update_process(user_id=user_id, mem_cube_id=mem_cube_id, messages=batch)
+
+ def long_memory_update_process(
+ self,
+ user_id: str,
+ mem_cube_id: str,
+ messages: list[ScheduleMessageItem],
+ ) -> None:
+ mem_cube = self.scheduler_context.get_mem_cube()
+ monitor = self.scheduler_context.get_monitor()
+
+ query_key_words_limit = self.scheduler_context.get_query_key_words_limit()
+
+ for msg in messages:
+ monitor.register_query_monitor_if_not_exists(user_id=user_id, mem_cube_id=mem_cube_id)
+
+ query = msg.content
+ query_keywords = monitor.extract_query_keywords(query=query)
+ logger.info(
+ 'Extracted keywords "%s" from query "%s" for user_id=%s',
+ query_keywords,
+ query,
+ user_id,
+ )
+
+ if len(query_keywords) == 0:
+ stripped_query = query.strip()
+ if is_all_english(stripped_query):
+ words = stripped_query.split()
+ elif is_all_chinese(stripped_query):
+ words = stripped_query
+ else:
+ logger.debug(
+ "Mixed-language memory, using character count: %s...",
+ stripped_query[:50],
+ )
+ words = stripped_query
+
+ query_keywords = list(set(words[:query_key_words_limit]))
+                logger.warning(
+                    "Keyword extraction returned no keywords for query '%s' (user_id=%s). Using fallback keywords: %s... (truncated)",
+                    query,
+                    user_id,
+                    query_keywords[:10],
+                )
+
+ item = QueryMonitorItem(
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ query_text=query,
+ keywords=query_keywords,
+ max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS,
+ )
+
+ query_db_manager = monitor.query_monitors[user_id][mem_cube_id]
+ query_db_manager.obj.put(item=item)
+ query_db_manager.sync_with_orm()
+ logger.debug(
+ "Queries in monitor for user_id=%s, mem_cube_id=%s: %s",
+ user_id,
+ mem_cube_id,
+ query_db_manager.obj.get_queries_with_timesort(),
+ )
+
+ queries = [msg.content for msg in messages]
+
+        # process_session_turn returns None when retrieval is not triggered;
+        # guard the unpack so the working-memory replacement is skipped then.
+        session_turn = self.process_session_turn(
+            queries=queries,
+            user_id=user_id,
+            mem_cube_id=mem_cube_id,
+            mem_cube=mem_cube,
+            top_k=self.scheduler_context.get_top_k(),
+        )
+        if session_turn is None:
+            return
+        cur_working_memory, new_candidates = session_turn
+ logger.info(
+ "[long_memory_update_process] Processed %s queries %s and retrieved %s new candidate memories for user_id=%s: "
+ + ("\n- " + "\n- ".join([f"{one.id}: {one.memory}" for one in new_candidates])),
+ len(queries),
+ queries,
+ len(new_candidates),
+ user_id,
+ )
+
+ new_order_working_memory = self.scheduler_context.services.replace_working_memory(
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ original_memory=cur_working_memory,
+ new_memory=new_candidates,
+ )
+ logger.debug(
+ "[long_memory_update_process] Final working memory size: %s memories for user_id=%s",
+ len(new_order_working_memory),
+ user_id,
+ )
+
+ old_memory_texts = "\n- " + "\n- ".join(
+ [f"{one.id}: {one.memory}" for one in cur_working_memory]
+ )
+ new_memory_texts = "\n- " + "\n- ".join(
+ [f"{one.id}: {one.memory}" for one in new_order_working_memory]
+ )
+
+ logger.info(
+ "[long_memory_update_process] For user_id='%s', mem_cube_id='%s': "
+ "Scheduler replaced working memory based on query history %s. "
+ "Old working memory (%s items): %s. "
+ "New working memory (%s items): %s.",
+ user_id,
+ mem_cube_id,
+ queries,
+ len(cur_working_memory),
+ old_memory_texts,
+ len(new_order_working_memory),
+ new_memory_texts,
+ )
+
+ logger.debug(
+ "Activation memory update %s (interval: %ss)",
+ "enabled" if self.scheduler_context.get_enable_activation_memory() else "disabled",
+ monitor.act_mem_update_interval,
+ )
+ if self.scheduler_context.get_enable_activation_memory():
+ self.scheduler_context.services.update_activation_memory_periodically(
+ interval_seconds=monitor.act_mem_update_interval,
+ label=QUERY_TASK_LABEL,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ )
+
+ def process_session_turn(
+ self,
+ queries: str | list[str],
+ user_id: UserID | str,
+ mem_cube_id: MemCubeID | str,
+ mem_cube,
+ top_k: int = 10,
+ ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]] | None:
+ text_mem_base = mem_cube.text_mem
+ if not isinstance(text_mem_base, TreeTextMemory):
+ if isinstance(text_mem_base, NaiveTextMemory):
+ logger.debug(
+ "NaiveTextMemory used for mem_cube_id=%s, processing session turn with simple search.",
+ mem_cube_id,
+ )
+ cur_working_memory = []
+ else:
+ logger.warning(
+ "Not implemented! Expected TreeTextMemory but got %s for mem_cube_id=%s, user_id=%s. text_mem_base value: %s",
+ type(text_mem_base).__name__,
+ mem_cube_id,
+ user_id,
+ text_mem_base,
+ )
+ return [], []
+ else:
+ cur_working_memory = text_mem_base.get_working_memory(user_name=mem_cube_id)
+ cur_working_memory = cur_working_memory[:top_k]
+
+ logger.info(
+ "[process_session_turn] Processing %s queries for user_id=%s, mem_cube_id=%s",
+ len(queries),
+ user_id,
+ mem_cube_id,
+ )
+
+ text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
+ monitor = self.scheduler_context.get_monitor()
+ intent_result = monitor.detect_intent(
+ q_list=queries, text_working_memory=text_working_memory
+ )
+
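+        # A retrieval pass runs if intent detection requests it or if the
+        # periodic query trigger interval has elapsed.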
+        time_trigger_flag = monitor.timed_trigger(
+            last_time=monitor.last_query_consume_time,
+            interval_seconds=monitor.query_trigger_interval,
+        )
+
+ if (not intent_result["trigger_retrieval"]) and (not time_trigger_flag):
+ logger.info(
+ "[process_session_turn] Query schedule not triggered for user_id=%s, mem_cube_id=%s. Intent_result: %s",
+ user_id,
+ mem_cube_id,
+ intent_result,
+ )
+ return
+ if (not intent_result["trigger_retrieval"]) and time_trigger_flag:
+ logger.info(
+ "[process_session_turn] Query schedule forced to trigger due to time ticker for user_id=%s, mem_cube_id=%s",
+ user_id,
+ mem_cube_id,
+ )
+ intent_result["trigger_retrieval"] = True
+ intent_result["missing_evidences"] = queries
+ else:
+ logger.info(
+ "[process_session_turn] Query schedule triggered for user_id=%s, mem_cube_id=%s. Missing evidences: %s",
+ user_id,
+ mem_cube_id,
+ intent_result["missing_evidences"],
+ )
+
+ missing_evidences = intent_result["missing_evidences"]
+ num_evidence = len(missing_evidences)
+ k_per_evidence = max(1, top_k // max(1, num_evidence))
+ new_candidates: list[TextualMemoryItem] = []
+ retriever = self.scheduler_context.get_retriever()
+ search_method = self.scheduler_context.get_search_method()
+
+ for item in missing_evidences:
+ logger.info(
+ "[process_session_turn] Searching for missing evidence: '%s' with top_k=%s for user_id=%s",
+ item,
+ k_per_evidence,
+ user_id,
+ )
+
+ search_args = {}
+ if isinstance(text_mem_base, NaiveTextMemory):
+ try:
+ results = text_mem_base.search(query=item, top_k=k_per_evidence)
+ except Exception as e:
+ logger.warning("NaiveTextMemory search failed: %s", e)
+ results = []
+ else:
+ results = retriever.search(
+ query=item,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ top_k=k_per_evidence,
+ method=search_method,
+ search_args=search_args,
+ )
+
+ logger.info(
+ "[process_session_turn] Search results for missing evidence '%s': \n- %s",
+ item,
+ "\n- ".join([f"{one.id}: {one.memory}" for one in results]),
+ )
+ new_candidates.extend(results)
+ return cur_working_memory, new_candidates
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/pref_add_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/pref_add_handler.py
new file mode 100644
index 000000000..b7dd2fa4c
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/pref_add_handler.py
@@ -0,0 +1,94 @@
+from __future__ import annotations
+
+import concurrent.futures
+import json
+
+from typing import TYPE_CHECKING
+
+from memos.context.context import ContextThreadPoolExecutor
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.task_schemas import PREF_ADD_TASK_LABEL
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler
+from memos.memories.textual.preference import PreferenceTextMemory
+
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+
+
+class PrefAddMessageHandler(BaseSchedulerHandler):
+ @property
+ def expected_task_label(self) -> str:
+ return PREF_ADD_TASK_LABEL
+
+ def batch_handler(
+ self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]
+ ) -> None:
+        with ContextThreadPoolExecutor(max_workers=min(8, max(1, len(batch)))) as executor:
+ futures = [executor.submit(self.process_message, msg) for msg in batch]
+ for future in concurrent.futures.as_completed(futures):
+ try:
+ future.result()
+ except Exception as e:
+ logger.error("Thread task failed: %s", e, exc_info=True)
+
+ def process_message(self, message: ScheduleMessageItem):
+ try:
+ mem_cube = self.scheduler_context.get_mem_cube()
+ if mem_cube is None:
+ logger.warning(
+ "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing",
+ message.user_id,
+ message.mem_cube_id,
+ )
+ return
+
+ user_id = message.user_id
+ session_id = message.session_id
+ mem_cube_id = message.mem_cube_id
+ content = message.content
+            messages_list = json.loads(content) if isinstance(content, str) else content
+ user_context = message.user_context
+ info = message.info or {}
+
+ logger.info("Processing pref_add for user_id=%s, mem_cube_id=%s", user_id, mem_cube_id)
+
+ pref_mem = mem_cube.pref_mem
+ if pref_mem is None:
+ logger.warning(
+ "Preference memory not initialized for mem_cube_id=%s, skipping pref_add processing",
+ mem_cube_id,
+ )
+ return
+ if not isinstance(pref_mem, PreferenceTextMemory):
+ logger.error(
+ "Expected PreferenceTextMemory but got %s for mem_cube_id=%s",
+ type(pref_mem).__name__,
+ mem_cube_id,
+ )
+ return
+
+ pref_memories = pref_mem.get_memory(
+ messages_list,
+ type="chat",
+ info={
+ **info,
+ "user_id": user_id,
+ "session_id": session_id,
+ "mem_cube_id": mem_cube_id,
+ },
+ user_context=user_context,
+ )
+ pref_ids = pref_mem.add(pref_memories)
+
+ logger.info(
+ "Successfully processed and add preferences for user_id=%s, mem_cube_id=%s, pref_ids=%s",
+ user_id,
+ mem_cube_id,
+ pref_ids,
+ )
+
+ except Exception as e:
+ logger.error("Error processing pref_add message: %s", e, exc_info=True)
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/query_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/query_handler.py
new file mode 100644
index 000000000..29f0a88fd
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/query_handler.py
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import (
+ MEM_UPDATE_TASK_LABEL,
+ NOT_APPLICABLE_TYPE,
+ QUERY_TASK_LABEL,
+ USER_INPUT_TYPE,
+)
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler
+
+
+logger = get_logger(__name__)
+
+
+class QueryMessageHandler(BaseSchedulerHandler):
+ @property
+ def expected_task_label(self) -> str:
+ return QUERY_TASK_LABEL
+
+ def batch_handler(
+ self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]
+ ) -> None:
+ mem_update_messages: list[ScheduleMessageItem] = []
+ for msg in batch:
+ try:
+ event = self.scheduler_context.services.create_event_log(
+ label="addMessage",
+ from_memory_type=USER_INPUT_TYPE,
+ to_memory_type=NOT_APPLICABLE_TYPE,
+ user_id=msg.user_id,
+ mem_cube_id=msg.mem_cube_id,
+ mem_cube=self.scheduler_context.get_mem_cube(),
+ memcube_log_content=[
+ {
+ "content": f"[User] {msg.content}",
+ "ref_id": msg.item_id,
+ "role": "user",
+ }
+ ],
+ metadata=[],
+ memory_len=1,
+ memcube_name=self.scheduler_context.services.map_memcube_name(msg.mem_cube_id),
+ )
+ event.task_id = msg.task_id
+ self.scheduler_context.services.submit_web_logs([event])
+ except Exception:
+ logger.exception("Failed to record addMessage log for query")
+
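+            # Re-queue each query as a MEM_UPDATE task so the working-memory
+            # refresh runs asynchronously from the addMessage logging above.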
+ update_msg = ScheduleMessageItem(
+ user_id=msg.user_id,
+ mem_cube_id=msg.mem_cube_id,
+ label=MEM_UPDATE_TASK_LABEL,
+ content=msg.content,
+ session_id=msg.session_id,
+ user_name=msg.user_name,
+ info=msg.info,
+ task_id=msg.task_id,
+ )
+ mem_update_messages.append(update_msg)
+
+ if mem_update_messages:
+ self.scheduler_context.services.submit_messages(messages=mem_update_messages)
diff --git a/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py b/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py
index cb5a49421..af46b3dcd 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py
@@ -17,11 +17,8 @@
from memos.log import get_logger
from memos.mem_scheduler.schemas.task_schemas import (
- ADD_TASK_LABEL,
- ANSWER_TASK_LABEL,
DEFAULT_PENDING_CLAIM_MIN_IDLE_MS,
PREF_ADD_TASK_LABEL,
- QUERY_TASK_LABEL,
TaskPriorityLevel,
)
from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
@@ -38,11 +35,7 @@ def __init__(self):
"""
# Cache of fetched messages grouped by (user_id, mem_cube_id, task_label)
self._cache = None
- self.tasks_priorities = {
- ADD_TASK_LABEL: TaskPriorityLevel.LEVEL_1,
- QUERY_TASK_LABEL: TaskPriorityLevel.LEVEL_1,
- ANSWER_TASK_LABEL: TaskPriorityLevel.LEVEL_1,
- }
+ self.tasks_priorities = {}
# Per-task minimum idle time (ms) before claiming pending messages
# Default fallback handled in `get_task_idle_min`.
@@ -54,6 +47,37 @@ def __init__(self):
def get_stream_priorities(self) -> None | dict:
return None
+ def set_task_config(
+ self,
+ task_label: str,
+ priority: TaskPriorityLevel | None = None,
+ min_idle_ms: int | None = None,
+ ):
+ """
+ Dynamically register or update task configuration.
+
+ Args:
+ task_label: The label of the task.
+ priority: The priority level of the task.
+ min_idle_ms: The minimum idle time (ms) for claiming pending messages.
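+
+        Example (illustrative label and values):
+            orchestrator.set_task_config(
+                "pref_add", priority=TaskPriorityLevel.LEVEL_2, min_idle_ms=600_000
+            )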
+ """
+ if priority is not None:
+ self.tasks_priorities[task_label] = priority
+ if min_idle_ms is not None:
+ self.tasks_min_idle_ms[task_label] = min_idle_ms
+
+ def remove_task_config(self, task_label: str):
+ """
+ Remove task configuration for a specific label.
+
+ Args:
+ task_label: The label of the task to remove configuration for.
+ """
+ if task_label in self.tasks_priorities:
+ del self.tasks_priorities[task_label]
+ if task_label in self.tasks_min_idle_ms:
+ del self.tasks_min_idle_ms[task_label]
+
def get_task_priority(self, task_label: str):
return self.tasks_priorities.get(task_label, TaskPriorityLevel.LEVEL_3)
diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
index d570dccdd..dc7d86752 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
@@ -718,10 +718,10 @@ def _claim_pending_messages(
justid=False,
)
if len(claimed_result) == 2:
- next_id, claimed = claimed_result
- deleted_ids = []
+ _next_id, claimed = claimed_result
+ _deleted_ids = []
elif len(claimed_result) == 3:
- next_id, claimed, deleted_ids = claimed_result
+ _next_id, claimed, _deleted_ids = claimed_result
else:
raise ValueError(
f"Unexpected xautoclaim response length: {len(claimed_result)}"
@@ -745,10 +745,10 @@ def _claim_pending_messages(
justid=False,
)
if len(claimed_result) == 2:
- next_id, claimed = claimed_result
- deleted_ids = []
+ _next_id, claimed = claimed_result
+ _deleted_ids = []
elif len(claimed_result) == 3:
- next_id, claimed, deleted_ids = claimed_result
+ _next_id, claimed, _deleted_ids = claimed_result
else:
raise ValueError(
f"Unexpected xautoclaim response length: {len(claimed_result)}"
diff --git a/src/memos/mem_scheduler/task_schedule_modules/registry.py b/src/memos/mem_scheduler/task_schedule_modules/registry.py
new file mode 100644
index 000000000..f47be933e
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/registry.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
+ from .context import SchedulerHandlerContext
+
+from memos.mem_scheduler.schemas.task_schemas import (
+ ADD_TASK_LABEL,
+ ANSWER_TASK_LABEL,
+ MEM_FEEDBACK_TASK_LABEL,
+ MEM_ORGANIZE_TASK_LABEL,
+ MEM_READ_TASK_LABEL,
+ MEM_UPDATE_TASK_LABEL,
+ PREF_ADD_TASK_LABEL,
+ QUERY_TASK_LABEL,
+ TaskPriorityLevel,
+)
+
+from .handlers.add_handler import AddMessageHandler
+from .handlers.answer_handler import AnswerMessageHandler
+from .handlers.feedback_handler import FeedbackMessageHandler
+from .handlers.mem_read_handler import MemReadMessageHandler
+from .handlers.mem_reorganize_handler import MemReorganizeMessageHandler
+from .handlers.memory_update_handler import MemoryUpdateHandler
+from .handlers.pref_add_handler import PrefAddMessageHandler
+from .handlers.query_handler import QueryMessageHandler
+
+
+class SchedulerHandlerRegistry:
+ def __init__(self, scheduler_context: SchedulerHandlerContext) -> None:
+ self.query = QueryMessageHandler(scheduler_context)
+ self.answer = AnswerMessageHandler(scheduler_context)
+ self.add = AddMessageHandler(scheduler_context)
+ self.memory_update = MemoryUpdateHandler(scheduler_context)
+ self.mem_feedback = FeedbackMessageHandler(scheduler_context)
+ self.mem_read = MemReadMessageHandler(scheduler_context)
+ self.mem_reorganize = MemReorganizeMessageHandler(scheduler_context)
+ self.pref_add = PrefAddMessageHandler(scheduler_context)
+
+ def build_dispatch_map(self) -> dict[str, Callable | tuple]:
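+        # Map values are either a handler instance or a (handler, priority,
+        # min_idle_ms) tuple; a None slot defers to the orchestrator's
+        # defaults when the handler is registered.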
+ predefined_handlers = {
+ QUERY_TASK_LABEL: (self.query, TaskPriorityLevel.LEVEL_1, None),
+ ANSWER_TASK_LABEL: (self.answer, TaskPriorityLevel.LEVEL_1, None),
+ MEM_UPDATE_TASK_LABEL: self.memory_update,
+ ADD_TASK_LABEL: (self.add, TaskPriorityLevel.LEVEL_1, None),
+ MEM_READ_TASK_LABEL: self.mem_read,
+ MEM_ORGANIZE_TASK_LABEL: self.mem_reorganize,
+ PREF_ADD_TASK_LABEL: (self.pref_add, None, 600_000),
+ MEM_FEEDBACK_TASK_LABEL: self.mem_feedback,
+ }
+ return predefined_handlers
diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py
index 46770758d..7e40f1d50 100644
--- a/src/memos/memories/textual/item.py
+++ b/src/memos/memories/textual/item.py
@@ -45,6 +45,43 @@ class SourceMessage(BaseModel):
model_config = ConfigDict(extra="allow")
+class ArchivedTextualMemory(BaseModel):
+ """
+    A lightweight class for storing archived versions of memories.
+
+    When an existing memory item must be updated because it conflicts with or
+    duplicates new memory contents, its previous contents are preserved in two places:
+    1. An ArchivedTextualMemory, which holds only minimal information (memory
+       content, creation time), stored in the 'history' field of the original node.
+    2. A new memory node, which stores the full original information, including
+       sources and embedding, and is referenced by 'archived_memory_id'.
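+
+    Example history entry (illustrative values):
+        ArchivedTextualMemory(
+            version=1,
+            memory="User prefers tea over coffee.",
+            update_type="conflict",
+            archived_memory_id="<archived-node-id>",
+        )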
+ """
+
+ version: int = Field(
+ default=1,
+ description="The version of the archived memory content. Will be compared to the version of the active memory item(in Metadata)",
+ )
+ is_fast: bool = Field(
+ default=False,
+ description="Whether this archived memory was created in fast mode, thus raw.",
+ )
+ memory: str | None = Field(
+ default_factory=lambda: "", description="The content of the archived version of the memory."
+ )
+ update_type: Literal["conflict", "duplicate", "extract", "unrelated"] = Field(
+ default="unrelated",
+ description="The type of the memory (e.g., `conflict`, `duplicate`, `extract`, `unrelated`).",
+ )
+ archived_memory_id: str | None = Field(
+ default=None,
+ description="Link to a memory node with status='archived', storing full original information, including sources and embedding.",
+ )
+ created_at: str | None = Field(
+ default_factory=lambda: datetime.now().isoformat(),
+ description="The time the memory was created.",
+ )
+
+
class TextualMemoryMetadata(BaseModel):
"""Metadata for a memory item.
@@ -60,9 +97,29 @@ class TextualMemoryMetadata(BaseModel):
default=None,
description="The ID of the session during which the memory was created. Useful for tracking context in conversations.",
)
- status: Literal["activated", "archived", "deleted"] | None = Field(
+ status: Literal["activated", "resolving", "archived", "deleted"] | None = Field(
default="activated",
- description="The status of the memory, e.g., 'activated', 'archived', 'deleted'.",
+ description="The status of the memory, e.g., 'activated', 'resolving'(updating with conflicting/duplicating new memories), 'archived', 'deleted'.",
+ )
+ is_fast: bool | None = Field(
+ default=None,
+ description="Whether or not the memory was created in fast mode, carrying raw memory contents that haven't been edited by llms yet.",
+ )
+ evolve_to: list[str] | None = Field(
+ default_factory=list,
+ description="Only valid if a node was once a (raw)fast node. Recording which new memory nodes it 'evolves' to after llm extraction.",
+ )
+ version: int | None = Field(
+ default=None,
+ description="The version of the memory. Will be incremented when the memory is updated.",
+ )
+ history: list[ArchivedTextualMemory] | None = Field(
+ default_factory=list,
+ description="Storing the archived versions of the memory. Only preserving core information of each version.",
+ )
+ working_binding: str | None = Field(
+ default=None,
+ description="The working memory id binding of the (fast) memory.",
)
type: str | None = Field(default=None)
key: str | None = Field(default=None, description="Memory key or title.")
@@ -112,6 +169,7 @@ class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata):
"OuterMemory",
"ToolSchemaMemory",
"ToolTrajectoryMemory",
+ "RawFileMemory",
"SkillMemory",
] = Field(default="WorkingMemory", description="Memory lifecycle type.")
sources: list[SourceMessage] | None = Field(
diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py
index 5e58d23a5..68536da8d 100644
--- a/src/memos/memories/textual/prefer_text_memory/adder.py
+++ b/src/memos/memories/textual/prefer_text_memory/adder.py
@@ -64,7 +64,7 @@ def _judge_update_or_add_fast(self, old_msg: str, new_msg: str) -> bool:
response = result.get("is_same", False)
return response if isinstance(response, bool) else response.lower() == "true"
except Exception as e:
- logger.error(f"Error in judge_update_or_add: {e}")
+ logger.warning(f"Error in judge_update_or_add: {e}")
# Fallback to simple string comparison
return old_msg == new_msg
@@ -80,7 +80,7 @@ def _judge_update_or_add_fine(self, new_mem: str, retrieved_mems: str) -> dict[s
result = json.loads(response)
return result
except Exception as e:
- logger.error(f"Error in judge_update_or_add_fine: {e}")
+ logger.warning(f"Error in judge_update_or_add_fine: {e}")
return None
def _judge_dup_with_text_mem(self, new_pref: MilvusVecDBItem) -> bool:
@@ -118,7 +118,7 @@ def _judge_dup_with_text_mem(self, new_pref: MilvusVecDBItem) -> bool:
exists = result.get("exists", False)
return exists
except Exception as e:
- logger.error(f"Error in judge_dup_with_text_mem: {e}")
+ logger.warning(f"Error in judge_dup_with_text_mem: {e}")
return False
def _judge_update_or_add_trace_op(
@@ -135,7 +135,7 @@ def _judge_update_or_add_trace_op(
result = json.loads(response)
return result
except Exception as e:
- logger.error(f"Error in judge_update_or_add_trace_op: {e}")
+ logger.warning(f"Error in judge_update_or_add_trace_op: {e}")
return None
def _dedup_explicit_pref_by_textual(
@@ -156,7 +156,7 @@ def _dedup_explicit_pref_by_textual(
try:
is_dup_flags[idx] = future.result()
except Exception as e:
- logger.error(
+ logger.warning(
f"Error in _judge_dup_with_text_mem for pref {new_prefs[idx].id}: {e}"
)
is_dup_flags[idx] = False
@@ -407,7 +407,7 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str |
)
except Exception as e:
- logger.error(f"Error processing memory {memory.id}: {e}")
+ logger.warning(f"Error processing memory {memory.id}: {e}")
return None
def process_memory_batch(self, memories: list[TextualMemoryItem], *args, **kwargs) -> list[str]:
@@ -480,7 +480,7 @@ def process_memory_single(
added_ids.append(memory_id)
except Exception as e:
memory = future_to_memory[future]
- logger.error(f"Error processing memory {memory.id}: {e}")
+ logger.warning(f"Error processing memory {memory.id}: {e}")
continue
return added_ids
diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py
index aa4f3cb44..e696e82d4 100644
--- a/src/memos/memories/textual/prefer_text_memory/extractor.py
+++ b/src/memos/memories/textual/prefer_text_memory/extractor.py
@@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from concurrent.futures import as_completed
from datetime import datetime
-from typing import Any
+from typing import TYPE_CHECKING, Any
from memos.context.context import ContextThreadPoolExecutor
from memos.log import get_logger
@@ -25,6 +25,10 @@
from memos.types import MessageList
+if TYPE_CHECKING:
+ from memos.types.general_types import UserContext
+
+
logger = get_logger(__name__)
@@ -177,6 +181,7 @@ def extract(
msg_type: str,
info: dict[str, Any],
max_workers: int = 10,
+ **kwargs,
) -> list[TextualMemoryItem]:
"""Extract preference memories based on the messages using thread pool for acceleration."""
chunks: list[MessageList] = []
@@ -186,6 +191,10 @@ def extract(
if not chunks:
return []
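+        # Merge the optional user context into `info`; because it is unpacked
+        # last, user-context keys override duplicates from `info`.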
+ user_context: UserContext | None = kwargs.get("user_context")
+ user_context_dict = user_context.model_dump() if user_context else {}
+ info = {**info, **user_context_dict}
+
memories = []
with ContextThreadPoolExecutor(max_workers=min(max_workers, len(chunks))) as executor:
futures = {
diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py
index 78f4d6e28..dba321f55 100644
--- a/src/memos/memories/textual/preference.py
+++ b/src/memos/memories/textual/preference.py
@@ -67,7 +67,7 @@ def __init__(self, config: PreferenceTextMemoryConfig):
)
def get_memory(
- self, messages: list[MessageList], type: str, info: dict[str, Any]
+ self, messages: list[MessageList], type: str, info: dict[str, Any], **kwargs
) -> list[TextualMemoryItem]:
"""Get memory based on the messages.
Args:
@@ -75,7 +75,7 @@ def get_memory(
type (str): The type of memory to get.
info (dict[str, Any]): The info to get memory.
"""
- return self.extractor.extract(messages, type, info)
+ return self.extractor.extract(messages, type, info, **kwargs)
def search(
self, query: str, top_k: int, info=None, search_filter=None, **kwargs
diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py
index cc1781f06..db7101744 100644
--- a/src/memos/memories/textual/simple_preference.py
+++ b/src/memos/memories/textual/simple_preference.py
@@ -40,15 +40,16 @@ def __init__(
self.retriever = retriever
def get_memory(
- self, messages: list[MessageList], type: str, info: dict[str, Any]
+ self, messages: list[MessageList], type: str, info: dict[str, Any], **kwargs
) -> list[TextualMemoryItem]:
"""Get memory based on the messages.
Args:
messages (MessageList): The messages to get memory from.
type (str): The type of memory to get.
info (dict[str, Any]): The info to get memory.
+ **kwargs: Additional keyword arguments to pass to the extractor.
"""
- return self.extractor.extract(messages, type, info)
+ return self.extractor.extract(messages, type, info, **kwargs)
def search(
self, query: str, top_k: int, info=None, search_filter=None, **kwargs
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py
index b556db5d7..5faf8aa09 100644
--- a/src/memos/memories/textual/tree.py
+++ b/src/memos/memories/textual/tree.py
@@ -1,18 +1,23 @@
+import concurrent.futures
import json
import os
import shutil
import tempfile
+import time
from datetime import datetime
from pathlib import Path
-from typing import Any
+from typing import Any, Literal
from memos.configs.memory import TreeTextMemoryConfig
from memos.configs.reranker import RerankerConfigFactory
+from memos.context.context import ContextThreadPoolExecutor
+from memos.dependency import require_python_package
from memos.embedders.factory import EmbedderFactory, OllamaEmbedder
from memos.graph_dbs.factory import GraphStoreFactory, Neo4jGraphDB
from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM
from memos.log import get_logger
+from memos.mem_reader.read_multi_modal.utils import detect_lang
from memos.memories.textual.base import BaseTextMemory
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
@@ -23,6 +28,7 @@
from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import (
InternetRetrieverFactory,
)
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import StopwordManager
from memos.reranker.factory import RerankerFactory
from memos.types import MessageList
@@ -164,6 +170,7 @@ def search(
include_skill_memory: bool = False,
skill_mem_top_k: int = 3,
dedup: str | None = None,
+ include_embedding: bool | None = None,
**kwargs,
) -> list[TextualMemoryItem]:
"""Search for memories based on a query.
@@ -187,6 +194,9 @@ def search(
Returns:
list[TextualMemoryItem]: List of matching memories.
"""
+ # Use parameter if provided, otherwise fall back to instance attribute
+ include_emb = include_embedding if include_embedding is not None else self.include_embedding
+
searcher = Searcher(
self.dispatcher_llm,
self.graph_store,
@@ -197,7 +207,7 @@ def search(
search_strategy=self.search_strategy,
manual_close_internet=manual_close_internet,
tokenizer=self.tokenizer,
- include_embedding=self.include_embedding,
+ include_embedding=include_emb,
)
return searcher.search(
query,
@@ -223,6 +233,7 @@ def get_relevant_subgraph(
depth: int = 2,
center_status: str = "activated",
user_name: str | None = None,
+ search_type: Literal["embedding", "fulltext"] = "fulltext",
) -> dict[str, Any]:
"""
Find and merge the local neighborhood sub-graphs of the top-k
@@ -249,13 +260,40 @@ def get_relevant_subgraph(
- 'nodes': List of unique nodes (core + neighbors) in the merged subgraph.
- 'edges': List of unique edges (as dicts with 'from', 'to', 'type') in the merged subgraph.
"""
- # Step 1: Embed query
- query_embedding = self.embedder.embed([query])[0]
+ if search_type == "embedding":
+ # Step 1: Embed query
+ query_embedding = self.embedder.embed([query])[0]
+
+ # Step 2: Get top-1 similar node
+ similar_nodes = self.graph_store.search_by_embedding(
+ query_embedding, top_k=top_k, user_name=user_name
+ )
+
+ elif search_type == "fulltext":
+
+ @require_python_package(
+ import_name="jieba",
+ install_command="pip install jieba",
+ install_link="https://github.com/fxsjy/jieba",
+ )
+ def _tokenize_chinese(text):
+ """split zh jieba"""
+ import jieba
+
+ stopword_manager = StopwordManager()
+ tokens = jieba.lcut(text)
+ tokens = [token.strip() for token in tokens if token.strip()]
+ return stopword_manager.filter_words(tokens)
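+            # Illustrative only: _tokenize_chinese("我喜欢北京") yields tokens like
+            # ["喜欢", "北京"]; exact output depends on jieba's dictionary and the stopword list.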
+
+ lang = detect_lang(query)
+ queries = _tokenize_chinese(query) if lang == "zh" else query.split()
+
+ similar_nodes = self.graph_store.search_by_fulltext(
+ query_words=queries,
+ top_k=top_k,
+ user_name=user_name,
+ )
- # Step 2: Get top-1 similar node
- similar_nodes = self.graph_store.search_by_embedding(
- query_embedding, top_k=top_k, user_name=user_name
- )
if not similar_nodes:
logger.info("No similar nodes found for query embedding.")
return {"core_id": None, "nodes": [], "edges": []}
@@ -328,18 +366,24 @@ def get_by_ids(
def get_all(
self,
- user_name: str,
+ user_name: str | None = None,
user_id: str | None = None,
page: int | None = None,
page_size: int | None = None,
filter: dict | None = None,
+ memory_type: list[str] | None = None,
) -> dict:
"""Get all memories.
Returns:
list[TextualMemoryItem]: List of all memories.
"""
graph_output = self.graph_store.export_graph(
- user_name=user_name, user_id=user_id, page=page, page_size=page_size, filter=filter
+ user_name=user_name,
+ user_id=user_id,
+ page=page,
+ page_size=page_size,
+ filter=filter,
+ memory_type=memory_type,
)
return graph_output
@@ -462,3 +506,100 @@ def _cleanup_old_backups(root_dir: Path, keep_last_n: int) -> None:
logger.info(f"Deleted old backup directory: {old_dir}")
except Exception as e:
logger.warning(f"Failed to delete backup {old_dir}: {e}")
+
+ def add_rawfile_nodes_n_edges(
+ self,
+ raw_file_mem_group: list[TextualMemoryItem],
+ mem_ids: list[str],
+ user_id: str | None = None,
+ user_name: str | None = None,
+ ) -> None:
+ """
+ Add raw file nodes and edges to the graph. Edges are between raw file ids and mem_ids.
+ Args:
+            raw_file_mem_group: List of raw file memory items.
+            mem_ids: Memory IDs of the extracted (summary) memories.
+            user_id: The user ID (used for logging).
+            user_name: The cube ID.
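+
+        Edge semantics (as constructed below):
+            summary -> chunk: MATERIAL; chunk -> summary: SUMMARY;
+            chunk -> next chunk: FOLLOWING; chunk -> previous chunk: PRECEDING.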
+ """
+ rawfile_ids_local: list[str] = self.add(
+ raw_file_mem_group,
+ user_name=user_name,
+ )
+
+ from_ids = []
+ to_ids = []
+ types = []
+
+ for raw_file_mem in raw_file_mem_group:
+            # Add MATERIAL edge (memory -> raw file) and SUMMARY edge (raw file -> memory)
+ if hasattr(raw_file_mem.metadata, "summary_ids") and raw_file_mem.metadata.summary_ids:
+ summary_ids = raw_file_mem.metadata.summary_ids
+ for summary_id in summary_ids:
+ if summary_id in mem_ids:
+ from_ids.append(summary_id)
+ to_ids.append(raw_file_mem.id)
+ types.append("MATERIAL")
+
+ from_ids.append(raw_file_mem.id)
+ to_ids.append(summary_id)
+ types.append("SUMMARY")
+
+ # Add FOLLOWING edge: current chunk -> next chunk
+ if (
+ hasattr(raw_file_mem.metadata, "following_id")
+ and raw_file_mem.metadata.following_id
+ ):
+ following_id = raw_file_mem.metadata.following_id
+ if following_id in rawfile_ids_local:
+ from_ids.append(raw_file_mem.id)
+ to_ids.append(following_id)
+ types.append("FOLLOWING")
+
+            # Add PRECEDING edge: current chunk -> previous chunk
+ if (
+ hasattr(raw_file_mem.metadata, "preceding_id")
+ and raw_file_mem.metadata.preceding_id
+ ):
+ preceding_id = raw_file_mem.metadata.preceding_id
+ if preceding_id in rawfile_ids_local:
+ from_ids.append(raw_file_mem.id)
+ to_ids.append(preceding_id)
+ types.append("PRECEDING")
+
+ start_time = time.time()
+ self.add_graph_edges(
+ from_ids,
+ to_ids,
+ types,
+ user_name=user_name,
+ )
+ end_time = time.time()
+ logger.info(f"[RawFile] Added {len(rawfile_ids_local)} chunks for user {user_id}")
+ logger.info(
+ f"[RawFile] Time taken to add edges: {end_time - start_time} seconds for {len(from_ids)} edges"
+ )
+
+ def add_graph_edges(
+ self, from_ids: list[str], to_ids: list[str], types: list[str], user_name: str | None = None
+ ) -> None:
+ """
+ Add edges to the graph.
+ Args:
+ from_ids: List of source node IDs.
+ to_ids: List of target node IDs.
+ types: List of edge types.
+ user_name: Optional user name.
+ """
+ with ContextThreadPoolExecutor(max_workers=20) as executor:
+ futures = {
+ executor.submit(
+ self.graph_store.add_edge, from_id, to_id, edge_type, user_name=user_name
+ )
+ for from_id, to_id, edge_type in zip(from_ids, to_ids, types, strict=False)
+ }
+
+ for future in concurrent.futures.as_completed(futures):
+ try:
+ future.result()
+ except Exception as e:
+ logger.exception("Add edge error: ", exc_info=e)
diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py
new file mode 100644
index 000000000..1afdc9281
--- /dev/null
+++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py
@@ -0,0 +1,166 @@
+import logging
+
+from typing import Literal
+
+from memos.context.context import ContextThreadPoolExecutor
+from memos.extras.nli_model.client import NLIClient
+from memos.extras.nli_model.types import NLIResult
+from memos.graph_dbs.base import BaseGraphDB
+from memos.memories.textual.item import ArchivedTextualMemory, TextualMemoryItem
+
+
+logger = logging.getLogger(__name__)
+
+CONFLICT_MEMORY_TITLE = "[possibly conflicting memories]"
+DUPLICATE_MEMORY_TITLE = "[possibly duplicate memories]"
+
+
+def _append_related_content(
+ new_item: TextualMemoryItem, duplicates: list[str], conflicts: list[str]
+) -> None:
+ """
+ Append duplicate and conflict memory contents to the new item's memory text,
+ truncated to avoid excessive length.
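+
+    Illustrative result (markers and truncation thresholds as defined below):
+
+        <original memory text>
+
+        [possibly conflicting memories]:
+        - <conflicting snippet>...
+        - ... (more items truncated)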
+ """
+ max_per_item_len = 200
+ max_section_len = 1000
+
+ def _format_section(title: str, items: list[str]) -> str:
+ if not items:
+ return ""
+
+ section_content = ""
+ for mem in items:
+ # Truncate individual item
+ snippet = mem[:max_per_item_len] + "..." if len(mem) > max_per_item_len else mem
+ # Check total section length
+ if len(section_content) + len(snippet) + 5 > max_section_len:
+ section_content += "\n- ... (more items truncated)"
+ break
+ section_content += f"\n- {snippet}"
+
+ return f"\n\n{title}:{section_content}"
+
+ append_text = ""
+ append_text += _format_section(CONFLICT_MEMORY_TITLE, conflicts)
+ append_text += _format_section(DUPLICATE_MEMORY_TITLE, duplicates)
+
+ if append_text:
+ new_item.memory += append_text
+
+
+def _detach_related_content(new_item: TextualMemoryItem) -> None:
+ """
+ Detach duplicate and conflict memory contents from the new item's memory text.
+ """
+ markers = [f"\n\n{CONFLICT_MEMORY_TITLE}:", f"\n\n{DUPLICATE_MEMORY_TITLE}:"]
+
+ cut_index = -1
+ for marker in markers:
+ idx = new_item.memory.find(marker)
+ if idx != -1 and (cut_index == -1 or idx < cut_index):
+ cut_index = idx
+
+ if cut_index != -1:
+ new_item.memory = new_item.memory[:cut_index]
+
+
+class MemoryHistoryManager:
+ def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None:
+ """
+ Initialize the MemoryHistoryManager.
+
+ Args:
+ nli_client: NLIClient for conflict/duplicate detection.
+ graph_db: GraphDB instance for marking operations during history management.
+ """
+ self.nli_client = nli_client
+ self.graph_db = graph_db
+
+ def resolve_history_via_nli(
+ self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem]
+    ) -> list[str]:
+ """
+ Detect relationships (Duplicate/Conflict) between the new item and related items using NLI,
+ and attach them as history to the new fast item.
+
+ Args:
+ new_item: The new memory item being added.
+ related_items: Existing memory items that might be related.
+
+ Returns:
+            The memory contents (strings) judged duplicate or conflicting by the NLI service.
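+
+        Example (sketch; client and DB construction elided):
+            manager = MemoryHistoryManager(nli_client, graph_db)
+            resolved = manager.resolve_history_via_nli(new_item, candidate_items)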
+ """
+ if not related_items:
+ return []
+
+ # 1. Call NLI
+ nli_results = self.nli_client.compare_one_to_many(
+ new_item.memory, [r.memory for r in related_items]
+ )
+
+ # 2. Process results and attach to history
+ duplicate_memories = []
+ conflict_memories = []
+
+ for r_item, nli_res in zip(related_items, nli_results, strict=False):
+ if nli_res == NLIResult.DUPLICATE:
+ update_type = "duplicate"
+ duplicate_memories.append(r_item.memory)
+ elif nli_res == NLIResult.CONTRADICTION:
+ update_type = "conflict"
+ conflict_memories.append(r_item.memory)
+ else:
+ update_type = "unrelated"
+
+ # Safely get created_at, fallback to updated_at
+ created_at = getattr(r_item.metadata, "created_at", None) or r_item.metadata.updated_at
+
+ archived = ArchivedTextualMemory(
+ version=r_item.metadata.version or 1,
+ is_fast=r_item.metadata.is_fast or False,
+ memory=r_item.memory,
+ update_type=update_type,
+ archived_memory_id=r_item.id,
+ created_at=created_at,
+ )
+ new_item.metadata.history.append(archived)
+ logger.info(
+ f"[MemoryHistoryManager] Archived related memory {r_item.id} as {update_type} for new item {new_item.id}"
+ )
+
+ # 3. Concat duplicate/conflict memories to new_item.memory
+        # Those old memories will be marked invisible during fine processing; appending their content here avoids information loss.
+ _append_related_content(new_item, duplicate_memories, conflict_memories)
+
+ return duplicate_memories + conflict_memories
+
+ def mark_memory_status(
+ self,
+ memory_items: list[TextualMemoryItem],
+ status: Literal["activated", "resolving", "archived", "deleted"],
+ ) -> None:
+ """
+ Support status marking operations during history management. Common usages are:
+ 1. Mark conflict/duplicate old memories' status as "resolving",
+ to make them invisible to /search api, but still visible for PreUpdateRetriever.
+ 2. Mark resolved memories' status as "activated", to restore their visibility.
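+
+        Example (sketch):
+            manager.mark_memory_status(conflicting_items, status="resolving")
+            # ... after resolution ...
+            manager.mark_memory_status(conflicting_items, status="activated")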
+ """
+        # Execute the actual marking operation in the graph DB.
+ with ContextThreadPoolExecutor() as executor:
+ futures = []
+ for mem in memory_items:
+ futures.append(
+ executor.submit(
+ self.graph_db.update_node,
+ id=mem.id,
+ fields={"status": status},
+ )
+ )
+
+ # Wait for all tasks to complete and raise any exceptions
+ for future in futures:
+ future.result()
diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py
index 5e9c74f61..cbc349d67 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/manager.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py
@@ -68,12 +68,14 @@ def __init__(
self.current_memory_size = {
"WorkingMemory": 0,
"LongTermMemory": 0,
+ "RawFileMemory": 0,
"UserMemory": 0,
}
if not memory_size:
self.memory_size = {
"WorkingMemory": 20,
"LongTermMemory": 1500,
+ "RawFileMemory": 1500,
"UserMemory": 480,
}
logger.info(f"MemorySize is {self.memory_size}")
@@ -157,7 +159,7 @@ def _add_memories_batch(
graph_node_ids: list[str] = []
for memory in memories:
- working_id = str(uuid.uuid4())
+            working_id = getattr(memory, "id", None) or str(uuid.uuid4())
if memory.metadata.memory_type in (
"WorkingMemory",
@@ -181,11 +183,12 @@ def _add_memories_batch(
"UserMemory",
"ToolSchemaMemory",
"ToolTrajectoryMemory",
+ "RawFileMemory",
"SkillMemory",
):
- if not memory.id:
- logger.error("Memory ID is not set, generating a new one")
- graph_node_id = memory.id or str(uuid.uuid4())
+            graph_node_id = getattr(memory, "id", None) or str(uuid.uuid4())
metadata_dict = memory.metadata.model_dump(exclude_none=True)
metadata_dict["updated_at"] = datetime.now().isoformat()
@@ -315,7 +318,7 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non
ids: list[str] = []
futures = []
- working_id = str(uuid.uuid4())
+        working_id = getattr(memory, "id", None) or str(uuid.uuid4())
with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex:
if memory.metadata.memory_type in (
@@ -334,6 +337,7 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non
"UserMemory",
"ToolSchemaMemory",
"ToolTrajectoryMemory",
+ "RawFileMemory",
"SkillMemory",
):
f_graph = ex.submit(
@@ -386,9 +390,7 @@ def _add_to_graph_memory(
"""
Generalized method to add memory to a graph-based memory type (e.g., LongTermMemory, UserMemory).
"""
- if not memory.id:
- logger.error("Memory ID is not set, generating a new one")
- node_id = memory.id or str(uuid.uuid4())
+        node_id = getattr(memory, "id", None) or str(uuid.uuid4())
# Step 2: Add new node to graph
metadata_dict = memory.metadata.model_dump(exclude_none=True)
tags = metadata_dict.get("tags") or []
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py
index a5fc7e049..cb77d2243 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py
@@ -163,14 +163,14 @@ def keyword_search(
results = []
- # 2. Try seach_by_keywords_tfidf (PolarDB specific)
- if hasattr(self.graph_db, "seach_by_keywords_tfidf"):
+ # 2. Try search_by_keywords_tfidf (PolarDB specific)
+ if hasattr(self.graph_db, "search_by_keywords_tfidf"):
try:
- results = self.graph_db.seach_by_keywords_tfidf(
+ results = self.graph_db.search_by_keywords_tfidf(
query_words=keywords, user_name=user_name, filter=search_filter
)
except Exception as e:
- logger.warning(f"[PreUpdateRetriever] seach_by_keywords_tfidf failed: {e}")
+ logger.warning(f"[PreUpdateRetriever] search_by_keywords_tfidf failed: {e}")
# 3. Fallback to search_by_fulltext
if not results and hasattr(self.graph_db, "search_by_fulltext"):
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
index c9f2ec156..e5e96dd58 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
@@ -67,6 +67,7 @@ def retrieve(
"UserMemory",
"ToolSchemaMemory",
"ToolTrajectoryMemory",
+ "RawFileMemory",
"SkillMemory",
]:
raise ValueError(f"Unsupported memory scope: {memory_scope}")
@@ -391,18 +392,51 @@ def search_path_b():
if not all_hits:
return []
- # merge and deduplicate
- unique_ids = {r["id"] for r in all_hits if r.get("id")}
+ # merge and deduplicate, keeping highest score per ID
+ id_to_score = {}
+ for r in all_hits:
+ rid = r.get("id")
+ if rid:
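+                # Ensure ID is a string and strip any surrounding quotes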
+ rid = str(rid).strip("\"'")
+ score = r.get("score", 0.0)
+ if rid not in id_to_score or score > id_to_score[rid]:
+ id_to_score[rid] = score
+
+ # Sort IDs by score (descending) to preserve ranking
+ sorted_ids = sorted(id_to_score.keys(), key=lambda x: id_to_score[x], reverse=True)
+
node_dicts = (
self.graph_store.get_nodes(
- list(unique_ids),
+ sorted_ids,
include_embedding=self.include_embedding,
cube_name=cube_name,
user_name=user_name,
)
or []
)
- return [TextualMemoryItem.from_dict(n) for n in node_dicts]
+
+ # Restore score-based order and inject scores into metadata
+ id_to_node = {}
+ for n in node_dicts:
+ node_id = n.get("id")
+ if node_id:
+ # Ensure ID is a string and strip any surrounding quotes
+ node_id = str(node_id).strip("\"'")
+ id_to_node[node_id] = n
+
+ ordered_nodes = []
+ for rid in sorted_ids:
+ # Ensure rid is normalized for matching
+ rid_normalized = str(rid).strip("\"'")
+ if rid_normalized in id_to_node:
+ node = id_to_node[rid_normalized]
+ # Inject similarity score as relativity
+ if "metadata" not in node:
+ node["metadata"] = {}
+ node["metadata"]["relativity"] = id_to_score.get(rid, 0.0)
+ ordered_nodes.append(node)
+
+ return [TextualMemoryItem.from_dict(n) for n in ordered_nodes]
def _bm25_recall(
self,
@@ -484,15 +518,49 @@ def _fulltext_recall(
if not all_hits:
return []
- # merge and deduplicate
- unique_ids = {r["id"] for r in all_hits if r.get("id")}
+ # merge and deduplicate, keeping highest score per ID
+ id_to_score = {}
+ for r in all_hits:
+ rid = r.get("id")
+ if rid:
+ # Ensure ID is a string and strip any surrounding quotes
+ rid = str(rid).strip("\"'")
+ score = r.get("score", 0.0)
+ if rid not in id_to_score or score > id_to_score[rid]:
+ id_to_score[rid] = score
+
+ # Sort IDs by score (descending) to preserve ranking
+ sorted_ids = sorted(id_to_score.keys(), key=lambda x: id_to_score[x], reverse=True)
+
node_dicts = (
self.graph_store.get_nodes(
- list(unique_ids),
+ sorted_ids,
include_embedding=self.include_embedding,
cube_name=cube_name,
user_name=user_name,
)
or []
)
- return [TextualMemoryItem.from_dict(n) for n in node_dicts]
+
+ # Restore score-based order and inject scores into metadata
+ id_to_node = {}
+ for n in node_dicts:
+ node_id = n.get("id")
+ if node_id:
+ # Ensure ID is a string and strip any surrounding quotes
+ node_id = str(node_id).strip("\"'")
+ id_to_node[node_id] = n
+
+ ordered_nodes = []
+ for rid in sorted_ids:
+ # Ensure rid is normalized for matching
+ rid_normalized = str(rid).strip("\"'")
+ if rid_normalized in id_to_node:
+ node = id_to_node[rid_normalized]
+ # Inject similarity score as relativity
+ if "metadata" not in node:
+ node["metadata"] = {}
+ node["metadata"]["relativity"] = id_to_score.get(rid, 0.0)
+ ordered_nodes.append(node)
+
+ return [TextualMemoryItem.from_dict(n) for n in ordered_nodes]
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py b/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py
index 861343e20..b8ab813dc 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py
@@ -78,7 +78,11 @@ def rerank(
embeddings = [item.metadata.embedding for item in items_with_embeddings]
if not embeddings:
- return [(item, 0.5) for item in graph_results[:top_k]]
+ # Use relativity from recall stage if available, otherwise default to 0.5
+ return [
+ (item, getattr(item.metadata, "relativity", None) or 0.5)
+ for item in graph_results[:top_k]
+ ]
# Step 2: Compute cosine similarities
similarity_scores = batch_cosine_similarity(query_embedding, embeddings)
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index dcd4e1fba..bc8d76517 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -1,5 +1,7 @@
import traceback
+from concurrent.futures import as_completed
+
from memos.context.context import ContextThreadPoolExecutor
from memos.embedders.factory import OllamaEmbedder
from memos.graph_dbs.factory import Neo4jGraphDB
@@ -88,7 +90,7 @@ def retrieve(
logger.info(
f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}"
)
- parsed_goal, query_embedding, context, query = self._parse_task(
+ parsed_goal, query_embedding, _context, query = self._parse_task(
query,
info,
mode,
@@ -483,8 +485,8 @@ def _retrieve_from_long_term_and_user(
else:
cot_embeddings = query_embedding
- with ContextThreadPoolExecutor(max_workers=2) as executor:
- if memory_type in ["All", "LongTermMemory"]:
+ with ContextThreadPoolExecutor(max_workers=3) as executor:
+ if memory_type in ["All", "AllSummaryMemory", "LongTermMemory"]:
tasks.append(
executor.submit(
self.graph_retriever.retrieve,
@@ -500,7 +502,7 @@ def _retrieve_from_long_term_and_user(
use_fast_graph=self.use_fast_graph,
)
)
- if memory_type in ["All", "UserMemory"]:
+ if memory_type in ["All", "AllSummaryMemory", "UserMemory"]:
tasks.append(
executor.submit(
self.graph_retriever.retrieve,
@@ -516,10 +518,28 @@ def _retrieve_from_long_term_and_user(
use_fast_graph=self.use_fast_graph,
)
)
+ if memory_type in ["RawFileMemory"]:
+ tasks.append(
+ executor.submit(
+ self.graph_retriever.retrieve,
+ query=query,
+ parsed_goal=parsed_goal,
+ query_embedding=cot_embeddings,
+ top_k=top_k * 2,
+ memory_scope="RawFileMemory",
+ search_filter=search_filter,
+ search_priority=search_priority,
+ user_name=user_name,
+ id_filter=id_filter,
+ use_fast_graph=self.use_fast_graph,
+ )
+ )
# Collect results from all tasks
for task in tasks:
results.extend(task.result())
+ results = self._deduplicate_rawfile_results(results, user_name=user_name)
+ results = self._filter_intermediate_content(results)
return self.reranker.rerank(
query=query,
@@ -872,7 +892,7 @@ def _sort_and_trim(
(item, score)
for item, score in results
if item.metadata.memory_type
- in ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"]
+ in ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory", "RawFileMemory"]
]
sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)[:top_k]
@@ -891,6 +911,66 @@ def _sort_and_trim(
)
return final_items
+ @timed
+ def _deduplicate_rawfile_results(self, results, user_name: str | None = None):
+ """
+        Drop summary nodes already covered by retrieved RawFileMemory chunks (linked via outgoing SUMMARY edges).
+ """
+ if not results:
+ return results
+
+ summary_ids_to_remove = set()
+ rawfile_items = [item for item in results if item.metadata.memory_type == "RawFileMemory"]
+ if not rawfile_items:
+ return results
+
+ with ContextThreadPoolExecutor(max_workers=min(len(rawfile_items), 10)) as executor:
+ futures = [
+ executor.submit(
+ self.graph_store.get_edges,
+ rawfile_item.id,
+ type="SUMMARY",
+ direction="OUTGOING",
+ user_name=user_name,
+ )
+ for rawfile_item in rawfile_items
+ ]
+ for future in as_completed(futures):
+ try:
+ edges = future.result()
+ for edge in edges:
+ summary_target_id = edge.get("to")
+ if summary_target_id:
+ summary_ids_to_remove.add(summary_target_id)
+ logger.debug(
+ f"[DEDUP] Marking summary node {summary_target_id} for removal (pointed by RawFileMemory)"
+ )
+ except Exception as e:
+ logger.warning(f"[DEDUP] Failed to get summary target ids: {e}")
+
+ filtered_results = []
+ for item in results:
+ if item.id in summary_ids_to_remove:
+ logger.debug(
+ f"[DEDUP] Removing summary node {item.id} because it is pointed by RawFileMemory"
+ )
+ continue
+ filtered_results.append(item)
+
+ return filtered_results
+
+ def _filter_intermediate_content(self, results):
+ """Filter intermediate content"""
+ filtered_results = []
+ for item in results:
+ if (
+ "File URL:" not in item.memory
+ and "File ID:" not in item.memory
+ and "Filename:" not in item.memory
+ ):
+ filtered_results.append(item)
+ return filtered_results
+
@timed
def _update_usage_history(self, items, info, user_name: str | None = None):
"""Update usage history in graph DB
diff --git a/src/memos/memos_tools/dinding_report_bot.py b/src/memos/memos_tools/dinding_report_bot.py
index d8b762855..5bbd1f4cd 100644
--- a/src/memos/memos_tools/dinding_report_bot.py
+++ b/src/memos/memos_tools/dinding_report_bot.py
@@ -146,7 +146,7 @@ def _text_wh(draw: ImageDraw.ImageDraw, text: str, font: ImageFont.ImageFont):
# center alignment
title_w, title_h = _text_wh(draw, title, font_title)
- sub_w, sub_h = _text_wh(draw, subtitle, font_sub)
+ sub_w, _sub_h = _text_wh(draw, subtitle, font_sub)
title_x = (w - title_w) // 2
title_y = h // 2 - title_h
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index bd026a51d..6da55ce02 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -24,7 +24,9 @@
MEM_READ_TASK_LABEL,
PREF_ADD_TASK_LABEL,
)
+from memos.memories.textual.item import TextualMemoryItem
from memos.multi_mem_cube.views import MemCubeView
+from memos.search import search_text_memories
from memos.templates.mem_reader_prompts import PROMPT_MAPPING
from memos.types.general_types import (
FINE_STRATEGY,
@@ -44,7 +46,6 @@
from memos.mem_cube.navie import NaiveMemCube
from memos.mem_reader.simple_struct import SimpleStructMemReader
from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler
- from memos.memories.textual.item import TextualMemoryItem
@dataclass
@@ -72,6 +73,8 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]:
user_id=add_req.user_id,
mem_cube_id=self.cube_id,
session_id=add_req.session_id or "default_session",
+ manager_user_id=add_req.manager_user_id,
+ project_id=add_req.project_id,
)
target_session_id = add_req.session_id or "default_session"
@@ -266,11 +269,12 @@ def _deep_search(
search_filter=search_filter,
info=info,
)
- formatted_memories = [
- format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr"))
- for data in enhanced_memories
- ]
- return formatted_memories
+ return self._postformat_memories(
+ enhanced_memories,
+ user_context.mem_cube_id,
+ include_embedding=search_req.dedup == "sim",
+ neighbor_discovery=search_req.neighbor_discovery,
+ )
def _agentic_search(
self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int
@@ -278,11 +282,12 @@ def _agentic_search(
deepsearch_results = self.deepsearch_agent.run(
search_req.query, user_id=user_context.mem_cube_id
)
- formatted_memories = [
- format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr"))
- for data in deepsearch_results
- ]
- return formatted_memories
+ return self._postformat_memories(
+ deepsearch_results,
+ user_context.mem_cube_id,
+ include_embedding=search_req.dedup == "sim",
+ neighbor_discovery=search_req.neighbor_discovery,
+ )
def _fine_search(
self,
@@ -323,6 +328,7 @@ def _fine_search(
user_name=user_context.mem_cube_id,
top_k=search_req.top_k,
mode=SearchMode.FINE,
+ memory_type=search_req.search_memory_type,
manual_close_internet=not search_req.internet_search,
moscube=search_req.moscube,
search_filter=search_filter,
@@ -362,7 +368,7 @@ def _fine_search(
user_name=user_context.mem_cube_id,
top_k=retrieval_size,
mode=SearchMode.FAST,
- memory_type="All",
+ memory_type=search_req.search_memory_type,
search_priority=search_priority,
search_filter=search_filter,
info=info,
@@ -390,10 +396,12 @@ def _dedup_by_content(memories: list) -> list:
deduped_memories = (
enhanced_memories if search_req.dedup == "no" else _dedup_by_content(enhanced_memories)
)
- formatted_memories = [
- format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr"))
- for data in deduped_memories
- ]
+ formatted_memories = self._postformat_memories(
+ deduped_memories,
+ user_context.mem_cube_id,
+ include_embedding=search_req.dedup == "sim",
+ neighbor_discovery=search_req.neighbor_discovery,
+ )
logger.info(f"Found {len(formatted_memories)} memories for user {search_req.user_id}")
@@ -435,7 +443,7 @@ def _search_pref(
},
search_filter=search_req.filter,
)
- return [format_memory_item(data) for data in results]
+ return self._postformat_memories(results, user_context.mem_cube_id)
except Exception as e:
self.logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc())
return []
@@ -455,39 +463,73 @@ def _fast_search(
Returns:
List of search results
"""
- target_session_id = search_req.session_id or "default_session"
- search_priority = {"session_id": search_req.session_id} if search_req.session_id else None
- search_filter = search_req.filter or None
- plugin = bool(search_req.source is not None and search_req.source == "plugin")
-
- search_results = self.naive_mem_cube.text_mem.search(
- query=search_req.query,
- user_name=user_context.mem_cube_id,
- top_k=search_req.top_k,
+ search_results = search_text_memories(
+ text_mem=self.naive_mem_cube.text_mem,
+ search_req=search_req,
+ user_context=user_context,
mode=SearchMode.FAST,
- manual_close_internet=not search_req.internet_search,
- memory_type=search_req.search_memory_type,
- search_filter=search_filter,
- search_priority=search_priority,
- info={
- "user_id": search_req.user_id,
- "session_id": target_session_id,
- "chat_history": search_req.chat_history,
- },
- plugin=plugin,
- search_tool_memory=search_req.search_tool_memory,
- tool_mem_top_k=search_req.tool_mem_top_k,
- include_skill_memory=search_req.include_skill_memory,
- skill_mem_top_k=search_req.skill_mem_top_k,
- dedup=search_req.dedup,
+ include_embedding=(search_req.dedup == "mmr"),
)
- formatted_memories = [
- format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr"))
- for data in search_results
- ]
+ return self._postformat_memories(
+ search_results,
+ user_context.mem_cube_id,
+ include_embedding=search_req.dedup == "sim",
+ neighbor_discovery=search_req.neighbor_discovery,
+ )
- return formatted_memories
+ def _postformat_memories(
+ self,
+ search_results: list,
+ user_name: str,
+ include_embedding: bool = False,
+ neighbor_discovery: bool = False,
+ ) -> list:
+ """
+        Postprocess search results: when neighbor_discovery is enabled, expand each
+        RawFileMemory hit with its PRECEDING/FOLLOWING neighbor chunks (at a slightly
+        discounted relativity), then format all items for the API response.
+ """
+
+ def extract_edge_info(edges_info: list[dict], neighbor_relativity: float):
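+            # NOTE: relies on `item` from the enclosing loop below; it is
+            # resolved at call time, so this helper is only valid inside that loop.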
+ edge_mems = []
+ for edge in edges_info:
+ chunk_target_id = edge.get("to")
+ edge_type = edge.get("type")
+ item_neighbor = self.searcher.graph_store.get_node(chunk_target_id)
+ if item_neighbor:
+ item_neighbor_mem = TextualMemoryItem(**item_neighbor)
+ item_neighbor_mem.metadata.relativity = neighbor_relativity
+ edge_mems.append(item_neighbor_mem)
+ item_neighbor_id = item_neighbor.get("id", "None")
+ self.logger.info(
+ f"Add neighbor chunk: {item_neighbor_id}, edge_type: {edge_type} for {item.id}"
+ )
+ return edge_mems
+
+ final_items = []
+ if neighbor_discovery:
+ for item in search_results:
+ if item.metadata.memory_type == "RawFileMemory":
+ neighbor_relativity = item.metadata.relativity * 0.8
+ preceding_info = self.searcher.graph_store.get_edges(
+ item.id, type="PRECEDING", direction="OUTGOING", user_name=user_name
+ )
+ final_items.extend(extract_edge_info(preceding_info, neighbor_relativity))
+
+ final_items.append(item)
+
+ following_info = self.searcher.graph_store.get_edges(
+ item.id, type="FOLLOWING", direction="OUTGOING", user_name=user_name
+ )
+ final_items.extend(extract_edge_info(following_info, neighbor_relativity))
+
+ else:
+ final_items.append(item)
+ else:
+ final_items = search_results
+
+ return [
+ format_memory_item(data, include_embedding=include_embedding) for data in final_items
+ ]
def _mix_search(
self,
@@ -554,6 +596,7 @@ def _schedule_memory_tasks(
user_name=self.cube_id,
info=add_req.info,
chat_history=add_req.chat_history,
+ user_context=user_context,
)
self.mem_scheduler.submit_messages(messages=[message_item_read])
self.logger.info(
@@ -624,6 +667,7 @@ def _process_pref_mem(
info=add_req.info,
user_name=self.cube_id,
task_id=add_req.task_id,
+ user_context=user_context,
)
self.mem_scheduler.submit_messages(messages=[message_item_pref])
self.logger.info(f"[SingleCubeView] cube={self.cube_id} Submitted PREF_ADD async")
@@ -643,6 +687,7 @@ def _process_pref_mem(
"session_id": target_session_id,
"mem_cube_id": user_context.mem_cube_id,
},
+ user_context=user_context,
)
pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local)
self.logger.info(
@@ -809,6 +854,7 @@ def _process_text_mem(
mode=extract_mode,
user_name=user_context.mem_cube_id,
chat_history=add_req.chat_history,
+ user_context=user_context,
)
self.logger.info(
f"Time for get_memory in extract mode {extract_mode}: {time.time() - init_time}"
@@ -824,16 +870,34 @@ def _process_text_mem(
self.logger.info(f"Memory extraction completed for user {add_req.user_id}")
# Add memories to text_mem
+ mem_group = [
+ memory for memory in flattened_local if memory.metadata.memory_type != "RawFileMemory"
+ ]
mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add(
- flattened_local,
+ mem_group,
user_name=user_context.mem_cube_id,
)
+
self.logger.info(
f"Added {len(mem_ids_local)} memories for user {add_req.user_id} "
f"in session {add_req.session_id}: {mem_ids_local}"
)
- # Schedule async/sync tasks
+ # Add raw file nodes and edges
+ if self.mem_reader.save_rawfile and extract_mode == "fine":
+ raw_file_mem_group = [
+ memory
+ for memory in flattened_local
+ if memory.metadata.memory_type == "RawFileMemory"
+ ]
+ self.naive_mem_cube.text_mem.add_rawfile_nodes_n_edges(
+ raw_file_mem_group,
+ mem_ids_local,
+ user_id=add_req.user_id,
+ user_name=user_context.mem_cube_id,
+ )
+
+ # Schedule async/sync tasks: async process raw chunk memory | sync only send messages
self._schedule_memory_tasks(
add_req=add_req,
user_context=user_context,
diff --git a/src/memos/search/__init__.py b/src/memos/search/__init__.py
new file mode 100644
index 000000000..71388c62b
--- /dev/null
+++ b/src/memos/search/__init__.py
@@ -0,0 +1,4 @@
+from .search_service import SearchContext, build_search_context, search_text_memories
+
+
+__all__ = ["SearchContext", "build_search_context", "search_text_memories"]
diff --git a/src/memos/search/search_service.py b/src/memos/search/search_service.py
new file mode 100644
index 000000000..6d57e3605
--- /dev/null
+++ b/src/memos/search/search_service.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+
+if TYPE_CHECKING:
+ from memos.api.product_models import APISearchRequest
+ from memos.types import SearchMode, UserContext
+
+
+@dataclass(frozen=True)
+class SearchContext:
+ target_session_id: str
+ search_priority: dict[str, Any] | None
+ search_filter: dict[str, Any] | None
+ info: dict[str, Any]
+ plugin: bool
+
+
+def build_search_context(
+ search_req: APISearchRequest,
+) -> SearchContext:
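+    """
+    Derive the shared search context from an APISearchRequest: the target
+    session id (defaulting to "default_session"), a session-based search
+    priority, the request filter, the info payload, and the plugin flag.
+    """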
+ target_session_id = search_req.session_id or "default_session"
+ search_priority = {"session_id": search_req.session_id} if search_req.session_id else None
+ return SearchContext(
+ target_session_id=target_session_id,
+ search_priority=search_priority,
+ search_filter=search_req.filter,
+ info={
+ "user_id": search_req.user_id,
+ "session_id": target_session_id,
+ "chat_history": search_req.chat_history,
+ },
+ plugin=bool(search_req.source is not None and search_req.source == "plugin"),
+ )
+
+
+def search_text_memories(
+ text_mem: Any,
+ search_req: APISearchRequest,
+ user_context: UserContext,
+ mode: SearchMode,
+ include_embedding: bool | None = None,
+) -> list[Any]:
+ """
+ Shared text-memory search logic for API and scheduler paths.
+ """
+ ctx = build_search_context(search_req=search_req)
+ return text_mem.search(
+ query=search_req.query,
+ user_name=user_context.mem_cube_id,
+ top_k=search_req.top_k,
+ mode=mode,
+ manual_close_internet=not search_req.internet_search,
+ memory_type=search_req.search_memory_type,
+ search_filter=ctx.search_filter,
+ search_priority=ctx.search_priority,
+ info=ctx.info,
+ plugin=ctx.plugin,
+ search_tool_memory=search_req.search_tool_memory,
+ tool_mem_top_k=search_req.tool_mem_top_k,
+ include_skill_memory=search_req.include_skill_memory,
+ skill_mem_top_k=search_req.skill_mem_top_k,
+ dedup=search_req.dedup,
+ include_embedding=include_embedding,
+ )
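+
+
+# Illustrative usage (assumes a configured text memory and API request objects):
+#
+#     from memos.types import SearchMode
+#
+#     results = search_text_memories(
+#         text_mem=mem_cube.text_mem,
+#         search_req=search_req,
+#         user_context=user_context,
+#         mode=SearchMode.FAST,
+#     )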
diff --git a/src/memos/templates/skill_mem_prompt.py b/src/memos/templates/skill_mem_prompt.py
index df64d736d..200f27c52 100644
--- a/src/memos/templates/skill_mem_prompt.py
+++ b/src/memos/templates/skill_mem_prompt.py
@@ -3,29 +3,29 @@
{{messages}}
# Role
-You are an expert in natural language processing (NLP) and dialogue logic analysis. You excel at organizing logical threads from complex long conversations and accurately extracting users' core intentions.
+You are an expert in natural language processing (NLP) and dialogue logic analysis. You excel at organizing logical threads from complex long conversations and accurately extracting users' core intentions to segment the dialogue into distinct tasks.
# Task
-Please analyze the provided conversation records, identify all independent "tasks" that the user has asked the AI to perform, and assign the corresponding dialogue message numbers to each task.
+Please analyze the provided conversation records, identify all independent "tasks" that the user has asked the AI to perform, and assign the corresponding dialogue message indices to each task.
-**Note**: Tasks should be high-level and general, typically divided by theme or topic. For example: "Travel Planning", "PDF Operations", "Code Review", "Data Analysis", etc. Avoid being too specific or granular.
+**Note**: Tasks should be high-level and general. Group similar activities under broad themes such as "Travel Planning", "Project Engineering & Implementation", "Code Review", "Data Analysis", etc. Avoid being overly specific or granular.
# Rules & Constraints
-1. **Task Independence**: If multiple unrelated topics are discussed in the conversation, identify them as different tasks.
-2. **Non-continuous Processing**: Pay attention to identifying "jumping" conversations. For example, if the user made travel plans in messages 8-11, switched to consulting about weather in messages 12-22, and then returned to making travel plans in messages 23-24, be sure to assign both 8-11 and 23-24 to the task "Making travel plans". However, if messages are continuous and belong to the same task, do not split them apart.
-3. **Filter Chit-chat**: Only extract tasks with clear goals, instructions, or knowledge-based discussions. Ignore meaningless greetings (such as "Hello", "Are you there?") or closing remarks unless they are part of the task context.
-4. **Main Task and Subtasks**: Carefully identify whether subtasks serve a main task. If a subtask supports the main task (e.g., "checking weather" serves "travel planning"), do NOT separate it as an independent task. Instead, include all related conversations in the main task. Only split tasks when they are truly independent and unrelated.
-5. **Output Format**: Please strictly follow the JSON format for output to facilitate my subsequent processing.
-6. **Language Consistency**: The language used in the task_name field must match the language used in the conversation records.
-7. **Generic Task Names**: Use generic, reusable task names, not specific descriptions. For example, use "Travel Planning" instead of "Planning a 5-day trip to Chengdu".
+1. **Task Independence**: If multiple completely unrelated topics are discussed, identify them as different tasks.
+2. **Main Task and Subtasks**: Carefully identify whether a subtask serves a primary objective. If a specific request supports a larger goal (e.g., "checking weather" within a "Travel Planning" thread), do NOT separate it. Include all supporting conversations within the main task. **Only split tasks when they are truly independent and unrelated.**
+3. **Non-continuous Processing**: Identify "jumping" or "interleaved" conversations. For example, if the user works on Travel Planning in messages 8-11, switches topics in 12-22, and returns to Travel Planning in 23-24, assign both [8, 11] and [23, 24] to the same "Travel Planning" task. Conversely, if messages are continuous and belong to the same task, keep them as a single range.
+4. **Filter Chit-chat**: Only extract tasks with clear goals, instructions, or knowledge-based discussions. Ignore meaningless greetings (e.g., "Hello", "Are you there?") or polite closings unless they contain necessary context for the task.
+5. **Output Format**: Strictly follow the JSON format below for automated processing.
+6. **Language Consistency**: The language used in the `task_name` field must match the primary language used in the conversation records.
+7. **Generic Task Names**: Use broad, reusable task categories. For example, use "Travel Planning" instead of "Planning a 5-day trip to Chengdu".
```json
[
{
"task_id": 1,
- "task_name": "Generic task name (e.g., Travel Planning, Code Review, Data Analysis)",
- "message_indices": [[0, 5],[16, 17]], # 0-5 and 16-17 are the message indices for this task
- "reasoning": "Briefly explain why these messages are grouped together"
+ "task_name": "Generic task name (e.g., Travel Planning, Code Review)",
+ "message_indices": [[0, 5], [16, 17]],
+ "reasoning": "Briefly explain the logic behind grouping these indices and how they relate to the core intent."
},
...
]
@@ -34,31 +34,30 @@
TASK_CHUNKING_PROMPT_ZH = """
-# 上下文(对话记录)
+# 上下文(历史对话记录)
{{messages}}
# 角色
-你是自然语言处理(NLP)和对话逻辑分析的专家。你擅长从复杂的长对话中整理逻辑线索,准确提取用户的核心意图。
+你是自然语言处理(NLP)和对话逻辑分析的专家。你擅长从复杂的长对话中整理逻辑线索,准确提取用户的不同意图,从而按照不同的意图对上述对话进行任务划分。
-# 任务
+# 目标
请分析提供的对话记录,识别所有用户要求 AI 执行的独立"任务",并为每个任务分配相应的对话消息编号。
-**注意**:任务应该是高层次和通用的,通常按主题或话题划分。例如:"旅行计划"、"PDF操作"、"代码审查"、"数据分析"等。避免过于具体或细化。
+**注意**:上述划分"任务"应该是高层次且通用的,通常按主题或任务类型划分,对同目标或相似的任务进行合并,例如:"旅行计划"、"项目工程设计与实现"、"代码审查" 等,避免过于具体或细化。
# 规则与约束
-1. **任务独立性**:如果对话中讨论了多个不相关的话题,请将它们识别为不同的任务。
-2. **非连续处理**:注意识别"跳跃式"对话。例如,如果用户在消息 8-11 中制定旅行计划,在消息 12-22 中切换到咨询天气,然后在消息 23-24 中返回到制定旅行计划,请务必将 8-11 和 23-24 都分配给"制定旅行计划"任务。但是,如果消息是连续的且属于同一任务,不能将其分开。
-3. **过滤闲聊**:仅提取具有明确目标、指令或基于知识的讨论的任务。忽略无意义的问候(例如"你好"、"在吗?")或结束语,除非它们是任务上下文的一部分。
-4. **主任务与子任务识别**:仔细识别子任务是否服务于主任务。如果子任务是为主任务服务的(例如"查天气"服务于"旅行规划"),不要将其作为独立任务分离出来,而是将所有相关对话都划分到主任务中。只有真正独立且无关联的任务才需要分开。
+1. **任务独立性**:如果对话中讨论了多个完全不相关的话题,请将它们识别为不同的任务。
+2. **主任务与子任务识别**:仔细识别划分的任务是否服务于主任务。如果某一个任务是为了完成主任务而服务的(例如"旅行规划"的对话中出现了"查天气"),不要将其作为独立任务分离出来,而是将所有相关对话都划分到主任务中。**只有真正独立且无关联的任务才需要分开。**
+3. **非连续处理**:注意识别"跳跃式"对话。例如,如果用户在消息 8-11 中制定旅行计划,在消息 12-22 中切换到其他任务,然后在消息 23-24 中返回到制定旅行计划,请务必将 8-11 和 23-24 都分配给"制定旅行计划"任务。按照规则2的描述,如果消息是连续的且属于同一任务,不能将其分开。
+4. **过滤闲聊**:仅提取具有明确目标、指令或基于知识的讨论的任务。忽略无意义的问候(例如"你好"、"在吗?")或结束语,除非它们是任务上下文的一部分。
5. **输出格式**:请严格遵循 JSON 格式输出,以便我后续处理。
-6. **语言一致性**:task_name 字段使用的语言必须与对话记录中使用的语言相匹配。
-7. **通用任务名称**:使用通用的、可复用的任务名称,而不是具体的描述。例如,使用"旅行规划"而不是"规划成都5日游"。
+6. **通用任务名称**:使用通用的、可复用的任务名称,而不是具体的描述。例如,使用"旅行规划"而不是"规划成都5日游"。
```json
[
{
"task_id": 1,
- "task_name": "通用任务名称(例如:旅行规划、代码审查、数据分析)",
+ "task_name": "通用任务名称",
"message_indices": [[0, 5],[16, 17]], # 0-5 和 16-17 是此任务的消息索引
"reasoning": "简要解释为什么这些消息被分组在一起"
},
@@ -67,7 +66,6 @@
```
"""
-
SKILL_MEMORY_EXTRACTION_PROMPT = """
# Role
You are an expert in skill abstraction and knowledge extraction. You excel at distilling general, reusable methodologies from specific conversations.
@@ -229,6 +227,152 @@
"""
+SKILL_MEMORY_EXTRACTION_PROMPT_MD = """
+# Role
+You are an expert in skill abstraction and knowledge extraction. You excel at distilling general, reusable methodologies and executable workflows from specific conversations to enable direct application in future similar scenarios.
+
+# Task
+Analyze the current messages and chat history to extract a universal, effective skill template. Compare the extracted methodology with existing skill memories (checking descriptions and triggers) to determine if this should be a new entry or an update to an existing one.
+
+# Prerequisites
+## Long Term Relevant Memories
+{old_memories}
+
+## Short Term Conversation
+{chat_history}
+
+## Conversation Messages
+{messages}
+
+# Skill Extraction Principles
+To define the content of a skill, comprehensively analyze the dialogue to create a list of reusable resources, including scripts and reference materials. Generate the skill according to the following principles:
+1. **Generalization**: Extract abstract methodologies that can be applied across scenarios. Avoid specific details (e.g., 'travel planning' rather than 'Beijing travel planning'). Moreover, the skills acquired should be durable and effective, rather than tied to a specific time.
+2. **Similarity Check**: If the skill list in 'existing skill memory' is not empty and there are skills with the **same topic**, you need to set "update": true and "old_memory_id". Otherwise, set "update": false and leave "old_memory_id" empty.
+3. **Language Consistency**: Keep consistent with the language of the dialogue.
+4. **Historical Usage Constraint**: Use the 'historically related dialogues' only as auxiliary context. If the current messages are insufficient to form a complete skill, and the historical dialogue supplies missing information relevant to the current task objectives, execution methods, or constraints, it may be incorporated.
+Note: If the similarity check shows that an existing **skill** description covers the same topic, use the update operation and set old_memory_id to that skill's ID. Do not create a new methodology; merge the new content into the existing skill memory, keeping it coherent while preserving the existing methodology's information.
+
+# Output Format and Field Specifications
+## Output Format
+```json
+{
+ "name": "General skill name (e.g., 'Travel Itinerary Planning', 'Code Review Workflow')",
+ "description": "Universal description of what this skill accomplishes and its scope",
+ "trigger": ["keyword1", "keyword2"],
+ "procedure": "Generic step-by-step process: 1. Step one 2. Step two...",
+ "experience": ["General principles or lessons learned", "Error handling strategies", "Best practices..."],
+ "preference": ["User's general preference patterns", "Preferred approaches or constraints..."],
+ "update": false,
+ "old_memory_id": "",
+ "content_of_current_message": "Summary of core content from current messages",
+ "whether_use_chat_history": false,
+ "content_of_related_chat_history": "",
+ "examples": ["Complete formatted output example in markdown format showing the final deliverable structure, content can be abbreviated with '...' but should demonstrate the format and structure"],
+ "scripts": a TODO list of code and requirements. Use null if no specific code are required.
+ "tool": List of specific external tools required (for example, if links or API information appear in the context, a websearch or external API may be needed), not product names or system tools (e.g., Python, Redis, or MySQL). If no specific tools are needed, please use null.
+ "others": {"reference.md": "A concise summary of other reference need to be provided (e.g., examples, tutorials, or best practices) "}. Only need to give the writing requirements, no need to provide the full documentation content.
+}
+```
+
+## Field Specifications
+- **name**: Generic skill identifier without specific instances.
+- **description**: Universal purpose and applicability.
+- **trigger**: List of keywords that should activate this skill.
+- **procedure**: Abstract, reusable process steps without specific details. Should be generalizable to similar tasks.
+- **experience**: General lessons, principles, or insights.
+- **preference**: User's overarching preference patterns.
+- **update**: true if updating existing skill, false if new.
+- **old_memory_id**: ID of skill being updated, or empty string if new.
+- **whether_use_chat_history**: Indicates whether information from chat_history that does not appear in messages was incorporated into the skill.
+- **content_of_related_chat_history**: If whether_use_chat_history is true, provide a high-level summary of the type of historical information used (e.g., “long-term preference: prioritizes cultural attractions”); do not quote the original dialogue verbatim. If not used, leave this field as an empty string.
+- **examples**: Complete output templates showing the final deliverable format and structure. Should demonstrate how the task result looks when this skill is applied, including format, sections, and content organization. Content can be abbreviated but must show the complete structure. Use markdown format for better readability.
+- **scripts**: If the skill's examples require an implementation involving code, you must provide a TODO list that clearly enumerates: (1) the components or steps that need to be implemented, (2) the expected inputs, and (3) the expected outputs. Detailed code or full implementations are not required. Use null if no specific code is required.
+- **tool**: If links or interface information appear in the context, the skill needs to rely on specific tools (such as websearch, external APIs, or system tools) during the answering process; list the tool names. If no specific tools are detected, use null.
+- **others**: If the skill requires additional supporting sections or other dependencies, structure them as key–value pairs, e.g., {"reference.md": "A concise summary of the reference content"}. Only give the writing requirements; the full documentation content is not needed.
+
+# Key Guidelines
+- Return null if a skill cannot be extracted.
+- Only create a new methodology when necessary; within the same scenario, prefer merging into an existing skill ("update": true).
+For example, keep dietary planning as a single entry: do not add a new "Keto Diet Planning" if "Dietary Planning" already exists, because a skill is a universal template. Instead, update "Dietary Planning" by adding preferences and triggers.
+
+# Output Format
+Output the JSON object only.
+"""
+
+
+SKILL_MEMORY_EXTRACTION_PROMPT_MD_ZH = """
+# 角色
+你是技能抽象和知识提取的专家。你擅长从上下文的具体对话中提炼通用的、可复用的方法流程,从而在后续遇到相似任务时可以直接执行该工作流程及脚本。
+
+# 任务
+通过分析历史相关对话和**给定的当前对话消息**,提取可应用于类似场景的**有效且通用**的技能模板;同时分析现有技能的描述和触发关键字(trigger),判断其与当前对话是否相关,从而决定技能需要新建还是更新。
+
+# 先决条件
+## 长期相关记忆
+{old_memories}
+
+## 短期对话
+{chat_history}
+
+## 当前对话消息
+{messages}
+
+# 技能提取原则
+为了确定技能的内容,综合分析对话内容以创建可重复使用资源的清单,包括脚本、参考资料和资源,请你按照下面的原则来生成技能:
+1. **通用化**:提取可跨场景应用的抽象方法论。避免具体细节(如"旅行规划"而非"北京旅行规划")。 而且提取的技能应该是持久有效的,而非与特定时间绑定。
+2. **相似性检查**:如果“现有技能记忆”中的技能列表不为空,且存在**相同主题**的技能,则需要设置"update": true 及"old_memory_id"。否则设置"update": false 并将"old_memory_id"留空。
+3. **语言一致性**:与对话语言保持一致。
+4. **历史使用约束**:“历史相关对话”作为辅助上下文,若当前消息不足以形成完整的技能,且历史相关对话能提供 messages 中缺失、且与当前任务目标、执行方式或约束相关的信息增量时,可以纳入考虑。
+注意:如果相似性检查显示已有**某个**技能描述的是同一个主题,请务必使用更新操作,并将old_memory_id设置为该已有技能的id;不要新建方法论,而是合理地追加到已有的技能记忆上,保证通顺的同时不丢失已有方法论的信息。
+
+# 输出格式的模版和字段规范描述
+## 输出格式
+```json
+{
+ "name": "通用技能名称(如:'旅行行程规划'、'代码审查流程')",
+ "description": "技能作用的通用描述",
+ "trigger": ["关键词1", "关键词2"],
+ "procedure": "通用的分步流程:1. 步骤一 2. 步骤二...",
+ "experience": ["通用原则或经验教训", "对于可能出现错误的处理情况", "可应用于类似场景的最佳实践..."],
+ "preference": ["用户的通用偏好模式", "偏好的方法或约束..."],
+ "update": false,
+ "old_memory_id": "",
+ "content_of_current_message": "",
+ "whether_use_chat_history": false,
+ "content_of_related_chat_history": "",
+ "examples": ["展示最终交付成果的完整格式范本(使用 markdown 格式), 内容可用'...'省略,但需展示完整格式和结构"],
+  "scripts": "一个代码待办列表和需求说明。如果不需要特定代码,请使用 null。",
+  "tool": "所需特定外部工具列表(例如,如果上下文中出现了链接或接口信息,则需要使用 websearch 或外部 API)。如不需要特定工具,请使用 null。",
+ "others": {"reference.md": "其他对于执行技能必须的参考内容(例如,示例、教程或最佳实践)"}。只需要给出撰写要求,无需完整的文档内容。
+}
+```
+
+## 字段规范
+- **name**:通用技能标识符,不含具体实例
+- **description**:通用用途和适用范围
+- **trigger**:触发技能执行的关键字列表,用于自动识别任务场景
+- **procedure**:抽象的、可复用的流程步骤,不含具体细节。应当能够推广到类似任务
+- **experience**:通用经验、原则或见解
+- **preference**:用户的整体偏好模式
+- **update**:更新现有技能为true,新建为false
+- **old_memory_id**:被更新技能的ID,新建则为空字符串
+- **content_of_current_message**:从当前对话消息中提取的核心内容(简写但必填)
+- **whether_use_chat_history**:是否从 chat_history 中引用了 messages 中没有的内容并提取到skill中
+- **content_of_related_chat_history**:若 whether_use_chat_history 为 true,仅需概括性说明所使用的历史信息类型(如“长期偏好:文化类景点优先”),不要求逐字引用原始对话内容;若未使用,则置为空字符串。
+- **examples**:展示最终任务成果的输出模板,包括格式、章节和内容组织结构。应展示应用此技能后任务结果的样子,包含所有必要的部分。内容可以省略但必须展示完整结构。使用 markdown 格式以提高可读性
+- **scripts**:如果技能 examples 需要实现代码,必须提供一个待办列表,清晰枚举:(1) 需实现的组件或步骤,(2) 预期输入,(3) 预期输出。详细代码或完整实现不是必须的。如果不需要特定代码,请使用 null。
+- **tool**:如果上下文中出现了链接或接口信息,则表明在回答过程中技能需要依赖特定工具(如 websearch 或外部 API),请列出工具名称。如未检测到特定工具,请使用 null。
+- **others**:如果必须要其他支持性章节或其他依赖项,格式为键值对,例如:{"reference.md": "参考内容的简要总结"}。只需要给出撰写要求,无需完整的文档内容。
+
+# 关键指导
+- 无法提取技能时返回null
+- 一定仅在必要时才新建方法论,同样的场景尽量合并("update": true),
+如饮食规划合并为一条,不要已有“饮食规划”的情况下,再新增一个“生酮饮食规划”,因为技能是一个通用的模版,可以选择添加preference和trigger来更新“饮食规划”。
+
+请生成技能模版,仅返回上述 JSON 对象。
+"""
+
+
TASK_QUERY_REWRITE_PROMPT = """
# Role
You are an expert in understanding user intentions and task requirements. You excel at analyzing conversations and extracting the core task description.
@@ -284,3 +428,107 @@
SKILLS_AUTHORING_PROMPT = """
"""
+
+
+SCRIPT_GENERATION_PROMPT = """
+# Role
+You are a Senior Python Developer and Architect.
+
+# Task
+Generate production-ready, executable Python scripts based on the provided requirements and context.
+The scripts will be part of a skill package used by an AI agent or a developer.
+
+# Requirements
+{requirements}
+
+# Context
+{context}
+
+# Instructions
+1. **Completeness**: The code must be fully functional and self-contained. DO NOT use placeholders like `# ...`, `pass` (unless necessary), or `TODO`.
+2. **Robustness**: Include comprehensive error handling (try-except blocks) and input validation.
+3. **Style**: Follow PEP 8 guidelines. Use type hints for all function signatures.
+4. **Dependencies**: Use standard libraries whenever possible. If external libraries are needed, list them in a comment at the top.
+5. **Main Guard**: Include `if __name__ == "__main__":` blocks with example usage or test cases.
+
+# Output Format
+Return ONLY a valid JSON object where keys are filenames (e.g., "utils.py", "main_task.py") and values are the raw code strings.
+```json
+{{
+ "filename.py": "import os\\n\\ndef func():\\n ..."
+}}
+```
+"""
+
+TOOL_GENERATION_PROMPT = """
+# Task
+Analyze the `Requirements` and `Context` to identify the relevant tools from the provided `Available Tools`. Return a list of the **names** of the matching tools.
+
+# Constraints
+1. **Selection Criteria**: Include a tool name only if the tool's schema directly addresses the user's requirements.
+2. **Empty Set Logic**: If `Available Tools` is empty or no relevant tools are found, you **must** return an empty JSON array: `[]`.
+3. **Format Purity**: Return ONLY the JSON array of strings. Do not provide commentary, justifications, or any text outside the JSON block.
+
+# Available Tools
+{tool_schemas}
+
+# Requirements
+{requirements}
+
+# Context
+{context}
+
+# Output
+```json
+[
+ "tool_name_1",
+ "tool_name_2"
+]
+```
+"""
+
+OTHERS_GENERATION_PROMPT = """
+# Task
+Create detailed, well-structured documentation for the file '{filename}' based on the provided summary and context.
+
+# Summary
+{summary}
+
+# Context
+{context}
+
+# Instructions
+1. **Structure**:
+ - **Introduction**: Brief overview of the topic.
+ - **Detailed Content**: The main body of the documentation, organized with headers (##, ###).
+ - **Key Concepts/Reference**: Definitions or reference tables if applicable.
+ - **Conclusion/Next Steps**: Wrap up or point to related resources.
+2. **Formatting**: Use Markdown effectively (lists, tables, code blocks, bold text) to enhance readability.
+3. **Language Consistency**: Keep the output consistent with **the language of the context**.
+
+# Output Format
+Return the content directly in Markdown format.
+"""
+
+OTHERS_GENERATION_PROMPT_ZH = """
+# 任务
+根据提供的摘要和上下文,为文件 '{filename}' 创建详细且结构良好的文档。
+
+# 摘要
+{summary}
+
+# 上下文
+{context}
+
+# 指南
+1. **结构**:
+   - **简介**:对主题进行简要概述。
+   - **详细内容**:文档的主体内容,使用标题(##, ###)进行组织。
+   - **关键概念/参考**:如果适用,提供定义或参考表格。
+   - **结论/下一步**:总结或指向相关资源。
+2. **格式**:有效使用 Markdown(列表、表格、代码块、加粗文本)以增强可读性。
+3. **语言一致性**:保持与**上下文语言**一致。
+
+# 输出格式
+以 Markdown 格式直接返回内容。
+"""
diff --git a/src/memos/types/general_types.py b/src/memos/types/general_types.py
index 44c75ec02..8234caf8b 100644
--- a/src/memos/types/general_types.py
+++ b/src/memos/types/general_types.py
@@ -10,7 +10,7 @@
from enum import Enum
from typing import Literal, NewType, TypeAlias
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
from typing_extensions import TypedDict
from memos.memories.activation.item import ActivationMemoryItem
@@ -149,3 +149,7 @@ class UserContext(BaseModel):
mem_cube_id: str | None = None
session_id: str | None = None
operation: list[PermissionDict] | None = None
+ manager_user_id: str | None = None
+ project_id: str | None = None
+
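+    # Accept and retain unknown fields passed by callers instead of rejecting them.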
+ model_config = ConfigDict(extra="allow")
diff --git a/src/memos/types/openai_chat_completion_types/__init__.py b/src/memos/types/openai_chat_completion_types/__init__.py
index 4a08a9f24..025e75360 100644
--- a/src/memos/types/openai_chat_completion_types/__init__.py
+++ b/src/memos/types/openai_chat_completion_types/__init__.py
@@ -1,4 +1,4 @@
-# ruff: noqa: F403, F401
+# ruff: noqa: F403
from .chat_completion_assistant_message_param import *
from .chat_completion_content_part_image_param import *
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py
index 3c5638788..f28796c2d 100644
--- a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py
@@ -1,4 +1,4 @@
-# ruff: noqa: TC001, TC003
+# ruff: noqa: TC001
from __future__ import annotations
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py
index ea2101229..13a9a89af 100644
--- a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py
@@ -1,4 +1,4 @@
-# ruff: noqa: TC001, TC003
+# ruff: noqa: TC001
from __future__ import annotations
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py
index 99c845d11..f76f2b862 100644
--- a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py
@@ -1,4 +1,4 @@
-# ruff: noqa: TC001, TC003
+# ruff: noqa: TC001
from __future__ import annotations
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py
index 8c004f340..b5bee9842 100644
--- a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py
@@ -1,4 +1,4 @@
-# ruff: noqa: TC001, TC003
+# ruff: noqa: TC001
from __future__ import annotations
diff --git a/tests/mem_reader/test_coarse_memory_type.py b/tests/mem_reader/test_coarse_memory_type.py
index bd90d6a69..1cbb6b2eb 100644
--- a/tests/mem_reader/test_coarse_memory_type.py
+++ b/tests/mem_reader/test_coarse_memory_type.py
@@ -64,7 +64,7 @@ def test_chat_passthrough():
def test_doc_local_file():
- local_path, content = create_temp_file("test local file content")
+ local_path, _content = create_temp_file("test local file content")
result = coerce_scene_data([local_path], "doc")
filename = os.path.basename(local_path)
@@ -108,7 +108,7 @@ def test_doc_plain_text():
def test_doc_mixed():
- local_path, content = create_temp_file("local file content")
+ local_path, _content = create_temp_file("local file content")
url = "https://example.com/x.pdf"
plain = "hello world"
diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py
new file mode 100644
index 000000000..46cf3a1f6
--- /dev/null
+++ b/tests/memories/textual/test_history_manager.py
@@ -0,0 +1,137 @@
+import uuid
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from memos.extras.nli_model.client import NLIClient
+from memos.extras.nli_model.types import NLIResult
+from memos.graph_dbs.base import BaseGraphDB
+from memos.memories.textual.item import (
+ TextualMemoryItem,
+ TextualMemoryMetadata,
+)
+from memos.memories.textual.tree_text_memory.organize.history_manager import (
+ MemoryHistoryManager,
+ _append_related_content,
+ _detach_related_content,
+)
+
+
+@pytest.fixture
+def mock_nli_client():
+ client = MagicMock(spec=NLIClient)
+ return client
+
+
+@pytest.fixture
+def mock_graph_db():
+ return MagicMock(spec=BaseGraphDB)
+
+
+@pytest.fixture
+def history_manager(mock_nli_client, mock_graph_db):
+ return MemoryHistoryManager(nli_client=mock_nli_client, graph_db=mock_graph_db)
+
+
+def test_detach_related_content():
+ original_memory = "This is the original memory content."
+ item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata())
+
+ duplicates = ["Duplicate 1", "Duplicate 2"]
+ conflicts = ["Conflict 1", "Conflict 2"]
+
+ # 1. Append content
+ _append_related_content(item, duplicates, conflicts)
+
+ # Verify content was appended
+ assert item.memory != original_memory
+ assert "[possibly conflicting memories]" in item.memory
+ assert "[possibly duplicate memories]" in item.memory
+ assert "Duplicate 1" in item.memory
+ assert "Conflict 1" in item.memory
+
+ # 2. Detach content
+ _detach_related_content(item)
+
+ # 3. Verify content is restored
+ assert item.memory == original_memory
+
+
+def test_detach_only_conflicts():
+ original_memory = "Original memory."
+ item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata())
+
+ duplicates = []
+ conflicts = ["Conflict A"]
+
+ _append_related_content(item, duplicates, conflicts)
+ assert "Conflict A" in item.memory
+ assert "Duplicate" not in item.memory
+
+ _detach_related_content(item)
+ assert item.memory == original_memory
+
+
+def test_detach_only_duplicates():
+ original_memory = "Original memory."
+ item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata())
+
+ duplicates = ["Duplicate A"]
+ conflicts = []
+
+ _append_related_content(item, duplicates, conflicts)
+ assert "Duplicate A" in item.memory
+ assert "Conflict" not in item.memory
+
+ _detach_related_content(item)
+ assert item.memory == original_memory
+
+
+def test_truncation(history_manager, mock_nli_client):
+ # Setup
+ new_item = TextualMemoryItem(memory="Test")
+ long_memory = "A" * 300
+ related_item = TextualMemoryItem(memory=long_memory)
+
+ mock_nli_client.compare_one_to_many.return_value = [NLIResult.DUPLICATE]
+
+ # Action
+ history_manager.resolve_history_via_nli(new_item, [related_item])
+
+ # Assert
+ assert "possibly duplicate memories" in new_item.memory
+ assert "..." in new_item.memory # Should be truncated
+ assert len(new_item.memory) < 1000 # Ensure reasonable length
+
+
+def test_empty_related_items(history_manager, mock_nli_client):
+ new_item = TextualMemoryItem(memory="Test")
+ history_manager.resolve_history_via_nli(new_item, [])
+
+ mock_nli_client.compare_one_to_many.assert_not_called()
+ assert new_item.metadata.history is None or len(new_item.metadata.history) == 0
+
+
+def test_mark_memory_status(history_manager, mock_graph_db):
+ # Setup
+ id1 = uuid.uuid4().hex
+ id2 = uuid.uuid4().hex
+ id3 = uuid.uuid4().hex
+ items = [
+ TextualMemoryItem(memory="M1", id=id1),
+ TextualMemoryItem(memory="M2", id=id2),
+ TextualMemoryItem(memory="M3", id=id3),
+ ]
+ status = "resolving"
+
+ # Action
+ history_manager.mark_memory_status(items, status)
+
+ # Assert
+ assert mock_graph_db.update_node.call_count == 3
+
+ # Verify we called it correctly
+ mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status})
+ mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status})
+ mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status})
diff --git a/tests/memories/textual/test_tree_searcher.py b/tests/memories/textual/test_tree_searcher.py
index 2a5536cf8..3d1469d00 100644
--- a/tests/memories/textual/test_tree_searcher.py
+++ b/tests/memories/textual/test_tree_searcher.py
@@ -48,11 +48,23 @@ def test_searcher_fast_path(mock_searcher):
mock_searcher.embedder.embed.return_value = [[0.1] * 5, [0.2] * 5]
# working path mock
- mock_searcher.graph_retriever.retrieve.side_effect = [
- [make_item("wm1", 0.9)[0]], # working memory
- [make_item("lt1", 0.8)[0]], # long-term
- [make_item("um1", 0.7)[0]], # user
- ]
+    # For "All", _retrieve_from_working_memory issues one retrieve call (WorkingMemory),
+    # and _retrieve_from_long_term_and_user issues three (LongTermMemory, UserMemory, RawFileMemory).
+    # Use a side-effect function so that concurrent calls are resolved by memory_scope.
+ def retrieve_side_effect(*args, **kwargs):
+ memory_scope = kwargs.get("memory_scope", "")
+ if memory_scope == "WorkingMemory":
+ return [make_item("wm1", 0.9)[0]]
+ elif memory_scope == "LongTermMemory":
+ return [make_item("lt1", 0.8)[0]]
+ elif memory_scope == "UserMemory":
+ return [make_item("um1", 0.7)[0]]
+ elif memory_scope == "RawFileMemory":
+ return [make_item("rm1", 0.6)[0]]
+ else:
+ return []
+
+ mock_searcher.graph_retriever.retrieve.side_effect = retrieve_side_effect
mock_searcher.reranker.rerank.return_value = [
make_item("wm1", 0.9),
make_item("lt1", 0.8),