diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index 2b9c137ca..a35caf4d9 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -14,6 +14,7 @@ from memos.api.product_models import ( DeleteMemoryRequest, DeleteMemoryResponse, + GetMemoryDashboardRequest, GetMemoryRequest, GetMemoryResponse, MemoryResponse, @@ -353,3 +354,76 @@ def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: message="Memories deleted successfully", data={"status": "success"}, ) + + +# ============================================================================= +# Other handler functions Endpoints (for internal use) +# ============================================================================= + + +def handle_get_memories_dashboard( + get_mem_req: GetMemoryDashboardRequest, naive_mem_cube: NaiveMemCube +) -> GetMemoryResponse: + results: dict[str, Any] = {"text_mem": [], "pref_mem": [], "tool_mem": [], "skill_mem": []} + memories = naive_mem_cube.text_mem.get_all( + user_name=get_mem_req.mem_cube_id, + user_id=get_mem_req.user_id, + page=get_mem_req.page, + page_size=get_mem_req.page_size, + filter=get_mem_req.filter, + )["nodes"] + + results = post_process_textual_mem(results, memories, get_mem_req.mem_cube_id) + + if not get_mem_req.include_tool_memory: + results["tool_mem"] = [] + if not get_mem_req.include_skill_memory: + results["skill_mem"] = [] + + preferences: list[TextualMemoryItem] = [] + + format_preferences = [] + if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: + filter_params: dict[str, Any] = {} + if get_mem_req.user_id is not None: + filter_params["user_id"] = get_mem_req.user_id + if get_mem_req.mem_cube_id is not None: + filter_params["mem_cube_id"] = get_mem_req.mem_cube_id + if get_mem_req.filter is not None: + # Check and remove user_id/mem_cube_id from filter if present + filter_copy = get_mem_req.filter.copy() + removed_fields = [] + + if "user_id" in filter_copy: + filter_copy.pop("user_id") + removed_fields.append("user_id") + if "mem_cube_id" in filter_copy: + filter_copy.pop("mem_cube_id") + removed_fields.append("mem_cube_id") + + if removed_fields: + logger.warning( + f"Fields {removed_fields} found in filter will be ignored. " + f"Use request-level user_id/mem_cube_id parameters instead." + ) + + filter_params.update(filter_copy) + + preferences, _ = naive_mem_cube.pref_mem.get_memory_by_filter( + filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size + ) + format_preferences = [format_memory_item(item, save_sources=False) for item in preferences] + + results = post_process_pref_mem( + results, format_preferences, get_mem_req.mem_cube_id, get_mem_req.include_preference + ) + + # Filter to only keep text_mem, pref_mem, tool_mem + filtered_results = { + "text_mem": results.get("text_mem", []), + "pref_mem": results.get("pref_mem", []), + "tool_mem": results.get("tool_mem", []), + "skill_mem": results.get("skill_mem", []), + } + + return GetMemoryResponse(message="Memories retrieved successfully", data=filtered_results) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 2b6c8b420..7ba31669c 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -507,6 +507,8 @@ class APIADDRequest(BaseRequest): description="Session ID. If not provided, a default session will be used.", ) task_id: str | None = Field(None, description="Task ID for monitering async tasks") + manager_user_id: str | None = Field(None, description="Manager User ID") + project_id: str | None = Field(None, description="Project ID") # ==== Multi-cube writing ==== writable_cube_ids: list[str] | None = Field( @@ -814,6 +816,12 @@ class GetMemoryRequest(BaseRequest): ) +class GetMemoryDashboardRequest(GetMemoryRequest): + """Request model for getting memories for dashboard.""" + + mem_cube_id: str | None = Field(None, description="Cube ID") + + class DeleteMemoryRequest(BaseRequest): """Request model for deleting memories.""" diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 243fb36cd..83079239f 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -38,6 +38,7 @@ DeleteMemoryResponse, ExistMemCubeIdRequest, ExistMemCubeIdResponse, + GetMemoryDashboardRequest, GetMemoryPlaygroundRequest, GetMemoryRequest, GetMemoryResponse, @@ -456,3 +457,13 @@ def recover_memory_by_record_id(memory_req: RecoverMemoryByRecordIdRequest): message="Called Successfully", data={"status": "success"}, ) + + +@router.post( + "/get_memory_dashboard", summary="Get memories for dashboard", response_model=GetMemoryResponse +) +def get_memories_dashboard(memory_req: GetMemoryDashboardRequest): + return handlers.memory_handler.handle_get_memories_dashboard( + get_mem_req=memory_req, + naive_mem_cube=naive_mem_cube, + ) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 409b3a967..1bedfdc40 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1691,7 +1691,7 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: raise NotImplementedError @timed - def seach_by_keywords_like( + def search_by_keywords_like( self, query_word: str, scope: str | None = None, @@ -1761,7 +1761,7 @@ def seach_by_keywords_like( params = (query_word,) logger.info( - f"[seach_by_keywords_LIKE start:] user_name: {user_name}, query: {query}, params: {params}" + f"[search_by_keywords_LIKE start:] user_name: {user_name}, query: {query}, params: {params}" ) conn = None try: @@ -1773,16 +1773,18 @@ def seach_by_keywords_like( for row in results: oldid = row[0] id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] output.append({"id": id_val}) logger.info( - f"[seach_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + f"[search_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" ) return output finally: self._return_connection(conn) @timed - def seach_by_keywords_tfidf( + def search_by_keywords_tfidf( self, query_words: list[str], scope: str | None = None, @@ -1858,7 +1860,7 @@ def seach_by_keywords_tfidf( params = (tsquery_string,) logger.info( - f"[seach_by_keywords_TFIDF start:] user_name: {user_name}, query: {query}, params: {params}" + f"[search_by_keywords_TFIDF start:] user_name: {user_name}, query: {query}, params: {params}" ) conn = None try: @@ -1870,10 +1872,12 @@ def seach_by_keywords_tfidf( for row in results: oldid = row[0] id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] output.append({"id": id_val}) logger.info( - f"[seach_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" ) return output finally: @@ -2003,6 +2007,8 @@ def search_by_fulltext( rank = row[2] # rank score id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] score_val = float(rank) # Apply threshold filter if specified @@ -2167,6 +2173,8 @@ def search_by_embedding( oldid = row[3] # old_id score = row[4] # scope id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] score_val = float(score) score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score if threshold is None or score_val >= threshold: diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index e38318a64..6e24ca7a5 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -924,7 +924,7 @@ def process_keyword_replace( ) must_part = f"{' & '.join(queries)}" if len(queries) > 1 else queries[0] - retrieved_ids = self.graph_store.seach_by_keywords_tfidf( + retrieved_ids = self.graph_store.search_by_keywords_tfidf( [must_part], user_name=user_name, filter=filter_dict ) if len(retrieved_ids) < 1: @@ -932,7 +932,7 @@ def process_keyword_replace( queries, top_k=100, user_name=user_name, filter=filter_dict ) else: - retrieved_ids = self.graph_store.seach_by_keywords_like( + retrieved_ids = self.graph_store.search_by_keywords_like( f"%{original_word}%", user_name=user_name, filter=filter_dict ) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index ce75f6dc5..1a312868a 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -3,7 +3,7 @@ import re import traceback -from typing import Any +from typing import TYPE_CHECKING, Any from memos import log from memos.configs.mem_reader import MultiModalStructMemReaderConfig @@ -20,6 +20,10 @@ from memos.utils import timed +if TYPE_CHECKING: + from memos.types.general_types import UserContext + + logger = log.get_logger(__name__) @@ -667,6 +671,12 @@ def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]: if file_ids: extra_kwargs["file_ids"] = file_ids + # Extract manager_user_id and project_id from user_context + user_context: UserContext | None = kwargs.get("user_context") + if user_context: + extra_kwargs["manager_user_id"] = user_context.manager_user_id + extra_kwargs["project_id"] = user_context.project_id + # Determine prompt type based on sources prompt_type = self._determine_prompt_type(sources) @@ -782,6 +792,11 @@ def _process_tool_trajectory_fine( fine_memory_items = [] + # Extract manager_user_id and project_id from user_context + user_context: UserContext | None = kwargs.get("user_context") + manager_user_id = user_context.manager_user_id if user_context else None + project_id = user_context.project_id if user_context else None + for fast_item in fast_memory_items: # Extract memory text (string content) mem_str = fast_item.memory or "" @@ -808,6 +823,8 @@ def _process_tool_trajectory_fine( correctness=m.get("correctness", ""), experience=m.get("experience", ""), tool_used_status=m.get("tool_used_status", []), + manager_user_id=manager_user_id, + project_id=project_id, ) fine_memory_items.append(node) except Exception as e: diff --git a/src/memos/mem_reader/read_multi_modal/assistant_parser.py b/src/memos/mem_reader/read_multi_modal/assistant_parser.py index 89d4fec7f..bac9deaad 100644 --- a/src/memos/mem_reader/read_multi_modal/assistant_parser.py +++ b/src/memos/mem_reader/read_multi_modal/assistant_parser.py @@ -2,7 +2,7 @@ import json -from typing import Any +from typing import TYPE_CHECKING, Any from memos.embedders.base import BaseEmbedder from memos.llms.base import BaseLLM @@ -18,6 +18,10 @@ from .utils import detect_lang +if TYPE_CHECKING: + from memos.types.general_types import UserContext + + logger = get_logger(__name__) @@ -281,6 +285,11 @@ def parse_fast( user_id = info_.pop("user_id", "") session_id = info_.pop("session_id", "") + # Extract manager_user_id and project_id from user_context + user_context: UserContext | None = kwargs.get("user_context") + manager_user_id = user_context.manager_user_id if user_context else None + project_id = user_context.project_id if user_context else None + # Create memory item (equivalent to _make_memory_item) memory_item = TextualMemoryItem( memory=line, @@ -298,6 +307,8 @@ def parse_fast( confidence=0.99, type="fact", info=info_, + manager_user_id=manager_user_id, + project_id=project_id, ), ) diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 9f4ab94c2..0f3f3ef01 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -5,7 +5,7 @@ import re import tempfile -from typing import Any +from typing import TYPE_CHECKING, Any from tqdm import tqdm @@ -34,6 +34,10 @@ from memos.types.openai_chat_completion_types import File +if TYPE_CHECKING: + from memos.types.general_types import UserContext + + logger = get_logger(__name__) # Prompt dictionary for doc processing (shared by simple_struct and file_content_parser) @@ -451,6 +455,11 @@ def parse_fast( user_id = info_.pop("user_id", "") session_id = info_.pop("session_id", "") + # Extract manager_user_id and project_id from user_context + user_context: UserContext | None = kwargs.get("user_context") + manager_user_id = user_context.manager_user_id if user_context else None + project_id = user_context.project_id if user_context else None + # For file content parts, default to LongTermMemory # (since we don't have role information at this level) memory_type = "LongTermMemory" @@ -495,6 +504,8 @@ def parse_fast( type="fact", info=info_, file_ids=file_ids, + manager_user_id=manager_user_id, + project_id=project_id, ), ) memory_items.append(memory_item) @@ -527,6 +538,8 @@ def parse_fast( type="fact", info=info_, file_ids=file_ids, + manager_user_id=manager_user_id, + project_id=project_id, ), ) memory_items.append(memory_item) @@ -644,6 +657,12 @@ def parse_fine( info_ = info.copy() user_id = info_.pop("user_id", "") session_id = info_.pop("session_id", "") + + # Extract manager_user_id and project_id from user_context + user_context: UserContext | None = kwargs.get("user_context") + manager_user_id = user_context.manager_user_id if user_context else None + project_id = user_context.project_id if user_context else None + if file_id: info_["file_id"] = file_id file_ids = [file_id] if file_id else [] @@ -702,6 +721,8 @@ def _make_memory_item( type="fact", info=info_, file_ids=file_ids, + manager_user_id=manager_user_id, + project_id=project_id, ), ) diff --git a/src/memos/mem_reader/read_multi_modal/image_parser.py b/src/memos/mem_reader/read_multi_modal/image_parser.py index 9322b9bc9..97400ca26 100644 --- a/src/memos/mem_reader/read_multi_modal/image_parser.py +++ b/src/memos/mem_reader/read_multi_modal/image_parser.py @@ -3,7 +3,7 @@ import json import re -from typing import Any +from typing import TYPE_CHECKING, Any from memos.embedders.base import BaseEmbedder from memos.llms.base import BaseLLM @@ -20,6 +20,10 @@ from .utils import detect_lang +if TYPE_CHECKING: + from memos.types.general_types import UserContext + + logger = get_logger(__name__) @@ -212,6 +216,7 @@ def parse_fine( key=_derive_key(summary), sources=[source], background=summary, + **kwargs, ) ) return memory_items @@ -252,6 +257,7 @@ def parse_fine( key=key if key else _derive_key(value), sources=[source], background=background, + **kwargs, ) memory_items.append(memory_item) except Exception as e: @@ -273,6 +279,7 @@ def parse_fine( key=_derive_key(fallback_value), sources=[source], background="Image processing encountered an error.", + **kwargs, ) ] @@ -333,12 +340,18 @@ def _create_memory_item( key: str, sources: list[SourceMessage], background: str = "", + **kwargs, ) -> TextualMemoryItem: """Create a TextualMemoryItem with the given parameters.""" info_ = info.copy() user_id = info_.pop("user_id", "") session_id = info_.pop("session_id", "") + # Extract manager_user_id and project_id from user_context + user_context: UserContext | None = kwargs.get("user_context") + manager_user_id = user_context.manager_user_id if user_context else None + project_id = user_context.project_id if user_context else None + return TextualMemoryItem( memory=value, metadata=TreeNodeTextualMemoryMetadata( @@ -355,5 +368,7 @@ def _create_memory_item( confidence=0.99, type="fact", info=info_, + manager_user_id=manager_user_id, + project_id=project_id, ), ) diff --git a/src/memos/mem_reader/read_multi_modal/string_parser.py b/src/memos/mem_reader/read_multi_modal/string_parser.py index b6e18fda3..220cf6e58 100644 --- a/src/memos/mem_reader/read_multi_modal/string_parser.py +++ b/src/memos/mem_reader/read_multi_modal/string_parser.py @@ -3,7 +3,7 @@ Handles simple string messages that need to be converted to memory items. """ -from typing import Any +from typing import TYPE_CHECKING, Any from memos.embedders.base import BaseEmbedder from memos.llms.base import BaseLLM @@ -17,6 +17,10 @@ from .base import BaseMessageParser, _add_lang_to_source, _derive_key +if TYPE_CHECKING: + from memos.types.general_types import UserContext + + logger = get_logger(__name__) @@ -92,6 +96,11 @@ def parse_fast( user_id = info_.pop("user_id", "") session_id = info_.pop("session_id", "") + # Extract manager_user_id and project_id from user_context + user_context: UserContext | None = kwargs.get("user_context") + manager_user_id = user_context.manager_user_id if user_context else None + project_id = user_context.project_id if user_context else None + # For string messages, default to LongTermMemory memory_type = "LongTermMemory" @@ -120,6 +129,8 @@ def parse_fast( confidence=0.99, type="fact", info=info_, + manager_user_id=manager_user_id, + project_id=project_id, ), ) memory_items.append(memory_item) diff --git a/src/memos/mem_reader/read_multi_modal/system_parser.py b/src/memos/mem_reader/read_multi_modal/system_parser.py index 03a49afd8..74545ceee 100644 --- a/src/memos/mem_reader/read_multi_modal/system_parser.py +++ b/src/memos/mem_reader/read_multi_modal/system_parser.py @@ -6,7 +6,7 @@ import re import uuid -from typing import Any +from typing import TYPE_CHECKING, Any from memos.embedders.base import BaseEmbedder from memos.llms.base import BaseLLM @@ -21,6 +21,10 @@ from .base import BaseMessageParser, _add_lang_to_source +if TYPE_CHECKING: + from memos.types.general_types import UserContext + + logger = get_logger(__name__) @@ -242,6 +246,11 @@ def format_tool_schema_readable(tool_schema): user_id = info_.pop("user_id", "") session_id = info_.pop("session_id", "") + # Extract manager_user_id and project_id from user_context + user_context: UserContext | None = kwargs.get("user_context") + manager_user_id = user_context.manager_user_id if user_context else None + project_id = user_context.project_id if user_context else None + # Split parsed text into chunks content_chunks = self._split_text(msg_line) @@ -260,6 +269,8 @@ def format_tool_schema_readable(tool_schema): tags=["mode:fast"], sources=[source], info=info_, + manager_user_id=manager_user_id, + project_id=project_id, ), ) memory_items.append(memory_item) @@ -294,6 +305,11 @@ def parse_fine( user_id = info_.pop("user_id", "") session_id = info_.pop("session_id", "") + # Extract manager_user_id and project_id from user_context + user_context: UserContext | None = kwargs.get("user_context") + manager_user_id = user_context.manager_user_id if user_context else None + project_id = user_context.project_id if user_context else None + # Deduplicate tool schemas based on memory content # Use hash as key for efficiency, but store original string to handle collisions seen_memories = {} # hash -> memory_str mapping @@ -321,6 +337,8 @@ def parse_fine( status="activated", embedding=self.embedder.embed([json.dumps(schema, ensure_ascii=False)])[0], info=info_, + manager_user_id=manager_user_id, + project_id=project_id, ), ) for schema in unique_schemas diff --git a/src/memos/mem_reader/read_multi_modal/text_content_parser.py b/src/memos/mem_reader/read_multi_modal/text_content_parser.py index 549f74852..9fdcf8c58 100644 --- a/src/memos/mem_reader/read_multi_modal/text_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/text_content_parser.py @@ -4,7 +4,7 @@ Text content parts are typically used in user/assistant messages with multimodal content. """ -from typing import Any +from typing import TYPE_CHECKING, Any from memos.embedders.base import BaseEmbedder from memos.llms.base import BaseLLM @@ -19,6 +19,10 @@ from .base import BaseMessageParser, _add_lang_to_source, _derive_key +if TYPE_CHECKING: + from memos.types.general_types import UserContext + + logger = get_logger(__name__) @@ -92,6 +96,11 @@ def parse_fast( user_id = info_.pop("user_id", "") session_id = info_.pop("session_id", "") + # Extract manager_user_id and project_id from user_context + user_context: UserContext | None = kwargs.get("user_context") + manager_user_id = user_context.manager_user_id if user_context else None + project_id = user_context.project_id if user_context else None + # For text content parts, default to LongTermMemory # (since we don't have role information at this level) memory_type = "LongTermMemory" @@ -113,6 +122,8 @@ def parse_fast( confidence=0.99, type="fact", info=info_, + manager_user_id=manager_user_id, + project_id=project_id, ), ) diff --git a/src/memos/mem_reader/read_multi_modal/tool_parser.py b/src/memos/mem_reader/read_multi_modal/tool_parser.py index caf5ffaa6..4718f87ba 100644 --- a/src/memos/mem_reader/read_multi_modal/tool_parser.py +++ b/src/memos/mem_reader/read_multi_modal/tool_parser.py @@ -2,7 +2,7 @@ import json -from typing import Any +from typing import TYPE_CHECKING, Any from memos.embedders.base import BaseEmbedder from memos.llms.base import BaseLLM @@ -18,6 +18,10 @@ from .utils import detect_lang +if TYPE_CHECKING: + from memos.types.general_types import UserContext + + logger = get_logger(__name__) @@ -179,6 +183,11 @@ def parse_fast( user_id = info_.pop("user_id", "") session_id = info_.pop("session_id", "") + # Extract manager_user_id and project_id from user_context + user_context: UserContext | None = kwargs.get("user_context") + manager_user_id = user_context.manager_user_id if user_context else None + project_id = user_context.project_id if user_context else None + content_chunks = self._split_text(line) memory_items = [] for _chunk_idx, chunk_text in enumerate(content_chunks): @@ -195,6 +204,8 @@ def parse_fast( tags=["mode:fast"], sources=sources, info=info_, + manager_user_id=manager_user_id, + project_id=project_id, ), ) memory_items.append(memory_item) diff --git a/src/memos/mem_reader/read_multi_modal/user_parser.py b/src/memos/mem_reader/read_multi_modal/user_parser.py index 1ab48c82e..abfebc5db 100644 --- a/src/memos/mem_reader/read_multi_modal/user_parser.py +++ b/src/memos/mem_reader/read_multi_modal/user_parser.py @@ -1,6 +1,6 @@ """Parser for user messages.""" -from typing import Any +from typing import TYPE_CHECKING, Any from memos.embedders.base import BaseEmbedder from memos.llms.base import BaseLLM @@ -16,6 +16,10 @@ from .utils import detect_lang +if TYPE_CHECKING: + from memos.types.general_types import UserContext + + logger = get_logger(__name__) @@ -183,6 +187,11 @@ def parse_fast( user_id = info_.pop("user_id", "") session_id = info_.pop("session_id", "") + # Extract manager_user_id and project_id from user_context + user_context: UserContext | None = kwargs.get("user_context") + manager_user_id = user_context.manager_user_id if user_context else None + project_id = user_context.project_id if user_context else None + # Create memory item (equivalent to _make_memory_item) memory_item = TextualMemoryItem( memory=line, @@ -200,6 +209,8 @@ def parse_fast( confidence=0.99, type="fact", info=info_, + manager_user_id=manager_user_id, + project_id=project_id, ), ) diff --git a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py index 6bd18808d..7bd3f3ebb 100644 --- a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py +++ b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py @@ -7,7 +7,7 @@ from concurrent.futures import as_completed from datetime import datetime from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from memos.context.context import ContextThreadPoolExecutor from memos.dependency import require_python_package @@ -29,6 +29,10 @@ from memos.types import MessageList +if TYPE_CHECKING: + from memos.types.general_types import UserContext + + logger = get_logger(__name__) @@ -494,12 +498,20 @@ def _write_skills_to_file( def create_skill_memory_item( - skill_memory: dict[str, Any], info: dict[str, Any], embedder: BaseEmbedder | None = None + skill_memory: dict[str, Any], + info: dict[str, Any], + embedder: BaseEmbedder | None = None, + **kwargs: Any, ) -> TextualMemoryItem: info_ = info.copy() user_id = info_.pop("user_id", "") session_id = info_.pop("session_id", "") + # Extract manager_user_id and project_id from user_context + user_context: UserContext | None = kwargs.get("user_context") + manager_user_id = user_context.manager_user_id if user_context else None + project_id = user_context.project_id if user_context else None + # Use description as the memory content memory_content = skill_memory.get("description", "") @@ -530,6 +542,8 @@ def create_skill_memory_item( scripts=skill_memory.get("scripts"), others=skill_memory.get("others"), url=skill_memory.get("url", ""), + manager_user_id=manager_user_id, + project_id=project_id, ) # If this is an update, use the old memory ID @@ -748,7 +762,7 @@ def process_skill_memory_fine( skill_memory_items = [] for skill_memory in skill_memories: try: - memory_item = create_skill_memory_item(skill_memory, info, embedder) + memory_item = create_skill_memory_item(skill_memory, info, embedder, **kwargs) skill_memory_items.append(memory_item) except Exception as e: logger.warning(f"[PROCESS_SKILLS] Error creating skill memory item: {e}") diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 74e50a514..44bea3dec 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -41,6 +41,7 @@ MemCubeID, UserID, ) +from memos.types.general_types import UserContext logger = get_logger(__name__) @@ -765,6 +766,7 @@ def process_message(message: ScheduleMessageItem): user_name = message.user_name info = message.info or {} chat_history = message.chat_history + user_context = message.user_context # Parse the memory IDs from content mem_ids = json.loads(content) if isinstance(content, str) else content @@ -792,6 +794,7 @@ def process_message(message: ScheduleMessageItem): task_id=message.task_id, info=info, chat_history=chat_history, + user_context=user_context, ) logger.info( @@ -820,6 +823,7 @@ def _process_memories_with_reader( task_id: str | None = None, info: dict | None = None, chat_history: list | None = None, + user_context: UserContext | None = None, ) -> None: logger.info( f"[DIAGNOSTIC] general_scheduler._process_memories_with_reader called. mem_ids: {mem_ids}, user_id: {user_id}, mem_cube_id: {mem_cube_id}, task_id: {task_id}" @@ -882,6 +886,7 @@ def _process_memories_with_reader( custom_tags=custom_tags, user_name=user_name, chat_history=chat_history, + user_context=user_context, ) except Exception as e: logger.warning(f"{e}: Fail to transfer mem: {memory_items}") @@ -1340,6 +1345,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id = message.mem_cube_id content = message.content messages_list = json.loads(content) + user_context = message.user_context info = message.info or {} logger.info(f"Processing pref_add for user_id={user_id}, mem_cube_id={mem_cube_id}") @@ -1369,6 +1375,7 @@ def process_message(message: ScheduleMessageItem): "session_id": session_id, "mem_cube_id": mem_cube_id, }, + user_context=user_context, ) # Add pref_mem to vector db pref_ids = pref_mem.add(pref_memories) diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index c7f270f19..d7ef0ea24 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -9,6 +9,7 @@ from memos.log import get_logger from memos.mem_scheduler.general_modules.misc import DictConversionMixin from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.types.general_types import UserContext from .general_schemas import NOT_INITIALIZED @@ -55,6 +56,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): description="Optional business-level task ID. Multiple items can share the same task_id.", ) chat_history: list | None = Field(default=None, description="user chat history") + user_context: UserContext | None = Field(default=None, description="user context") # Pydantic V2 model configuration model_config = ConfigDict( @@ -91,6 +93,9 @@ def to_dict(self) -> dict: "user_name": self.user_name, "task_id": self.task_id if self.task_id is not None else "", "chat_history": self.chat_history if self.chat_history is not None else [], + "user_context": self.user_context.model_dump(exclude_none=True) + if self.user_context + else None, } @classmethod @@ -107,6 +112,9 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem": user_name=data.get("user_name"), task_id=data.get("task_id"), chat_history=data.get("chat_history"), + user_context=UserContext.model_validate(data.get("user_context")) + if data.get("user_context") + else None, ) diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index aa4f3cb44..e696e82d4 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from concurrent.futures import as_completed from datetime import datetime -from typing import Any +from typing import TYPE_CHECKING, Any from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger @@ -25,6 +25,10 @@ from memos.types import MessageList +if TYPE_CHECKING: + from memos.types.general_types import UserContext + + logger = get_logger(__name__) @@ -177,6 +181,7 @@ def extract( msg_type: str, info: dict[str, Any], max_workers: int = 10, + **kwargs, ) -> list[TextualMemoryItem]: """Extract preference memories based on the messages using thread pool for acceleration.""" chunks: list[MessageList] = [] @@ -186,6 +191,10 @@ def extract( if not chunks: return [] + user_context: UserContext | None = kwargs.get("user_context") + user_context_dict = user_context.model_dump() if user_context else {} + info = {**info, **user_context_dict} + memories = [] with ContextThreadPoolExecutor(max_workers=min(max_workers, len(chunks))) as executor: futures = { diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 78f4d6e28..dba321f55 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -67,7 +67,7 @@ def __init__(self, config: PreferenceTextMemoryConfig): ) def get_memory( - self, messages: list[MessageList], type: str, info: dict[str, Any] + self, messages: list[MessageList], type: str, info: dict[str, Any], **kwargs ) -> list[TextualMemoryItem]: """Get memory based on the messages. Args: @@ -75,7 +75,7 @@ def get_memory( type (str): The type of memory to get. info (dict[str, Any]): The info to get memory. """ - return self.extractor.extract(messages, type, info) + return self.extractor.extract(messages, type, info, **kwargs) def search( self, query: str, top_k: int, info=None, search_filter=None, **kwargs diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py index cc1781f06..db7101744 100644 --- a/src/memos/memories/textual/simple_preference.py +++ b/src/memos/memories/textual/simple_preference.py @@ -40,15 +40,16 @@ def __init__( self.retriever = retriever def get_memory( - self, messages: list[MessageList], type: str, info: dict[str, Any] + self, messages: list[MessageList], type: str, info: dict[str, Any], **kwargs ) -> list[TextualMemoryItem]: """Get memory based on the messages. Args: messages (MessageList): The messages to get memory from. type (str): The type of memory to get. info (dict[str, Any]): The info to get memory. + **kwargs: Additional keyword arguments to pass to the extractor. """ - return self.extractor.extract(messages, type, info) + return self.extractor.extract(messages, type, info, **kwargs) def search( self, query: str, top_k: int, info=None, search_filter=None, **kwargs diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index dbe6912c9..72e43d15f 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -363,7 +363,7 @@ def get_by_ids( def get_all( self, - user_name: str, + user_name: str | None = None, user_id: str | None = None, page: int | None = None, page_size: int | None = None, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py index a5fc7e049..cb77d2243 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py @@ -163,14 +163,14 @@ def keyword_search( results = [] - # 2. Try seach_by_keywords_tfidf (PolarDB specific) - if hasattr(self.graph_db, "seach_by_keywords_tfidf"): + # 2. Try search_by_keywords_tfidf (PolarDB specific) + if hasattr(self.graph_db, "search_by_keywords_tfidf"): try: - results = self.graph_db.seach_by_keywords_tfidf( + results = self.graph_db.search_by_keywords_tfidf( query_words=keywords, user_name=user_name, filter=search_filter ) except Exception as e: - logger.warning(f"[PreUpdateRetriever] seach_by_keywords_tfidf failed: {e}") + logger.warning(f"[PreUpdateRetriever] search_by_keywords_tfidf failed: {e}") # 3. Fallback to search_by_fulltext if not results and hasattr(self.graph_db, "search_by_fulltext"): diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 15e3b1bb9..4550c4f60 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -72,6 +72,8 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: user_id=add_req.user_id, mem_cube_id=self.cube_id, session_id=add_req.session_id or "default_session", + manager_user_id=add_req.manager_user_id, + project_id=add_req.project_id, ) target_session_id = add_req.session_id or "default_session" @@ -555,6 +557,7 @@ def _schedule_memory_tasks( user_name=self.cube_id, info=add_req.info, chat_history=add_req.chat_history, + user_context=user_context, ) self.mem_scheduler.submit_messages(messages=[message_item_read]) self.logger.info( @@ -625,6 +628,7 @@ def _process_pref_mem( info=add_req.info, user_name=self.cube_id, task_id=add_req.task_id, + user_context=user_context, ) self.mem_scheduler.submit_messages(messages=[message_item_pref]) self.logger.info(f"[SingleCubeView] cube={self.cube_id} Submitted PREF_ADD async") @@ -644,6 +648,7 @@ def _process_pref_mem( "session_id": target_session_id, "mem_cube_id": user_context.mem_cube_id, }, + user_context=user_context, ) pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local) self.logger.info( @@ -810,6 +815,7 @@ def _process_text_mem( mode=extract_mode, user_name=user_context.mem_cube_id, chat_history=add_req.chat_history, + user_context=user_context, ) self.logger.info( f"Time for get_memory in extract mode {extract_mode}: {time.time() - init_time}" diff --git a/src/memos/types/general_types.py b/src/memos/types/general_types.py index 44c75ec02..8234caf8b 100644 --- a/src/memos/types/general_types.py +++ b/src/memos/types/general_types.py @@ -10,7 +10,7 @@ from enum import Enum from typing import Literal, NewType, TypeAlias -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from typing_extensions import TypedDict from memos.memories.activation.item import ActivationMemoryItem @@ -149,3 +149,7 @@ class UserContext(BaseModel): mem_cube_id: str | None = None session_id: str | None = None operation: list[PermissionDict] | None = None + manager_user_id: str | None = None + project_id: str | None = None + + model_config = ConfigDict(extra="allow")