diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 8292e027b..fe250bbd6 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -7,6 +7,7 @@ import asyncio import json +import os import re import time import traceback @@ -23,6 +24,7 @@ APIADDRequest, APIChatCompleteRequest, APISearchRequest, + ChatBusinessRequest, ChatPlaygroundRequest, ChatRequest, ) @@ -759,6 +761,195 @@ def generate_chat_response() -> Generator[str, None, None]: ) raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err + def handle_chat_stream_for_business_user( + self, chat_req: ChatBusinessRequest + ) -> StreamingResponse: + """Chat API for business user.""" + self.logger.info(f"[ChatBusinessHandler] Chat Req is: {chat_req}") + + # Validate business_key permission + business_chat_keys = os.environ.get("BUSINESS_CHAT_KEYS", "[]") + allowed_keys = json.loads(business_chat_keys) + + if not allowed_keys or chat_req.business_key not in allowed_keys: + self.logger.warning( + f"[ChatBusinessHandler] Unauthorized access attempt with business_key: {chat_req.business_key}" + ) + raise HTTPException( + status_code=403, + detail="Access denied: Invalid business_key. You do not have permission to use this service.", + ) + + try: + + def generate_chat_response() -> Generator[str, None, None]: + """Generate chat stream response as SSE stream.""" + try: + if chat_req.need_search: + # Resolve readable cube IDs (for search) + readable_cube_ids = chat_req.readable_cube_ids or ( + [chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id] + ) + + search_req = APISearchRequest( + query=chat_req.query, + user_id=chat_req.user_id, + readable_cube_ids=readable_cube_ids, + mode=chat_req.mode, + internet_search=chat_req.internet_search, + top_k=chat_req.top_k, + chat_history=chat_req.history, + session_id=chat_req.session_id, + include_preference=chat_req.include_preference, + pref_top_k=chat_req.pref_top_k, + filter=chat_req.filter, + ) + + search_response = self.search_handler.handle_search_memories(search_req) + + # Extract memories from search results + memories_list = [] + if search_response.data and search_response.data.get("text_mem"): + text_mem_results = search_response.data["text_mem"] + if text_mem_results and text_mem_results[0].get("memories"): + memories_list = text_mem_results[0]["memories"] + + # Drop internet memories forced + memories_list = [ + mem + for mem in memories_list + if mem.get("metadata", {}).get("memory_type") != "OuterMemory" + ] + + # Filter memories by threshold + filtered_memories = self._filter_memories_by_threshold(memories_list) + + # Step 2: Build system prompt with memories + system_prompt = self._build_system_prompt( + query=chat_req.query, + memories=filtered_memories, + pref_string=search_response.data.get("pref_string", ""), + base_prompt=chat_req.system_prompt, + ) + + self.logger.info( + f"[ChatBusinessHandler] chat stream user_id: {chat_req.user_id}, readable_cube_ids: {readable_cube_ids}, " + f"current_system_prompt: {system_prompt}" + ) + else: + system_prompt = self._build_system_prompt( + query=chat_req.query, + memories=None, + pref_string=None, + base_prompt=chat_req.system_prompt, + ) + + # Prepare messages + history_info = chat_req.history[-20:] if chat_req.history else [] + current_messages = [ + {"role": "system", "content": system_prompt}, + *history_info, + {"role": "user", "content": chat_req.query}, + ] + + # Step 3: Generate streaming response from LLM + 
+                    if (
+                        chat_req.model_name_or_path
+                        and chat_req.model_name_or_path not in self.chat_llms
+                    ):
+                        raise HTTPException(
+                            status_code=400,
+                            detail=f"Model {chat_req.model_name_or_path} not supported, choose from {list(self.chat_llms.keys())}",
+                        )
+
+                    model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys()))
+                    self.logger.info(f"[ChatBusinessHandler] Chat Stream Model: {model}")
+
+                    start = time.time()
+                    response_stream = self.chat_llms[model].generate_stream(
+                        current_messages, model_name_or_path=model
+                    )
+
+                    # Stream the response
+                    buffer = ""
+                    full_response = ""
+                    in_think = False
+
+                    for chunk in response_stream:
+                        if chunk == "<think>":
+                            in_think = True
+                            continue
+                        if chunk == "</think>":
+                            in_think = False
+                            continue
+
+                        if in_think:
+                            chunk_data = f"data: {json.dumps({'type': 'reasoning', 'data': chunk}, ensure_ascii=False)}\n\n"
+                            yield chunk_data
+                            continue
+
+                        buffer += chunk
+                        full_response += chunk
+
+                        chunk_data = f"data: {json.dumps({'type': 'text', 'data': chunk}, ensure_ascii=False)}\n\n"
+                        yield chunk_data
+
+                    end = time.time()
+                    self.logger.info(
+                        f"[ChatBusinessHandler] Chat Stream Time: {end - start} seconds"
+                    )
+
+                    self.logger.info(
+                        f"[ChatBusinessHandler] Chat Stream LLM Input: {json.dumps(current_messages, ensure_ascii=False)} Chat Stream LLM Response: {full_response}"
+                    )
+
+                    current_messages.append({"role": "assistant", "content": full_response})
+                    if chat_req.add_message_on_answer:
+                        # Resolve writable cube IDs (for add)
+                        writable_cube_ids = chat_req.writable_cube_ids or (
+                            [chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id]
+                        )
+                        start = time.time()
+                        self._start_add_to_memory(
+                            user_id=chat_req.user_id,
+                            writable_cube_ids=writable_cube_ids,
+                            session_id=chat_req.session_id or "default_session",
+                            query=chat_req.query,
+                            full_response=full_response,
+                            async_mode="async",
+                        )
+                        end = time.time()
+                        self.logger.info(
+                            f"[ChatBusinessHandler] Chat Stream Add Time: {end - start} seconds"
+                        )
+                except Exception as e:
+                    self.logger.error(
+                        f"[ChatBusinessHandler] Error in chat stream: {e}", exc_info=True
+                    )
+                    error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n"
+                    yield error_data
+
+            return StreamingResponse(
+                generate_chat_response(),
+                media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "keep-alive",
+                    "Content-Type": "text/event-stream",
+                    "Access-Control-Allow-Origin": "*",
+                    "Access-Control-Allow-Headers": "*",
+                    "Access-Control-Allow-Methods": "*",
+                },
+            )
+
+        except ValueError as err:
+            raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
+        except Exception as err:
+            self.logger.error(
+                f"[ChatBusinessHandler] Failed to start chat stream: {traceback.format_exc()}"
+            )
+            raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
+
     def _dedup_and_supplement_memories(
         self, first_filtered_memories: list, second_filtered_memories: list
     ) -> list:
diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py
index e8bc5b640..2b9c137ca 100644
--- a/src/memos/api/handlers/memory_handler.py
+++ b/src/memos/api/handlers/memory_handler.py
@@ -109,6 +109,7 @@ def handle_get_subgraph(
     query: str,
     top_k: int,
     naive_mem_cube: Any,
+    search_type: Literal["embedding", "fulltext"],
 ) -> MemoryResponse:
     """
     Main handler for getting memory subgraph based on query.
@@ -128,7 +129,7 @@ def handle_get_subgraph( try: # Get relevant subgraph from text memory memories = naive_mem_cube.text_mem.get_relevant_subgraph( - query, top_k=top_k, user_name=mem_cube_id + query, top_k=top_k, user_name=mem_cube_id, search_type=search_type ) # Format and convert to tree structure @@ -139,7 +140,7 @@ def handle_get_subgraph( "UserMemory": 0.40, } tree_result, node_type_count = convert_graph_to_tree_forworkmem( - memories_cleaned, target_node_count=150, type_ratios=custom_type_ratios + memories_cleaned, target_node_count=200, type_ratios=custom_type_ratios ) # Ensure all node IDs are unique in the tree structure tree_result = ensure_unique_tree_ids(tree_result) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index d8fa784a3..d11573610 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -46,6 +46,7 @@ class GetMemoryPlaygroundRequest(BaseRequest): ) mem_cube_ids: list[str] | None = Field(None, description="Cube IDs") search_query: str | None = Field(None, description="Search query") + search_type: Literal["embedding", "fulltext"] = Field("fulltext", description="Search type") # Start API Models @@ -167,6 +168,13 @@ class ChatPlaygroundRequest(ChatRequest): ) +class ChatBusinessRequest(ChatRequest): + """Request model for chat operations for business user.""" + + business_key: str = Field(..., description="Business User Key") + need_search: bool = Field(False, description="Whether to need search before chat") + + class ChatCompleteRequest(BaseRequest): """Request model for chat operations. will (Deprecated), instead use APIChatCompleteRequest.""" @@ -1217,3 +1225,26 @@ class ExistMemCubeIdRequest(BaseRequest): class ExistMemCubeIdResponse(BaseResponse[dict[str, bool]]): """Response model for checking if mem cube id exists.""" + + +class DeleteMemoryByRecordIdRequest(BaseRequest): + """Request model for deleting memory by record id.""" + + mem_cube_id: str = Field(..., description="Mem cube ID") + record_id: str = Field(..., description="Record ID") + hard_delete: bool = Field(False, description="Hard delete") + + +class DeleteMemoryByRecordIdResponse(BaseResponse[dict]): + """Response model for deleting memory by record id.""" + + +class RecoverMemoryByRecordIdRequest(BaseRequest): + """Request model for recovering memory by record id.""" + + mem_cube_id: str = Field(..., description="Mem cube ID") + delete_record_id: str = Field(..., description="Delete record ID") + + +class RecoverMemoryByRecordIdResponse(BaseResponse[dict]): + """Response model for recovering memory by record id.""" diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 736c328ac..243fb36cd 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -29,8 +29,11 @@ APIChatCompleteRequest, APIFeedbackRequest, APISearchRequest, + ChatBusinessRequest, ChatPlaygroundRequest, ChatRequest, + DeleteMemoryByRecordIdRequest, + DeleteMemoryByRecordIdResponse, DeleteMemoryRequest, DeleteMemoryResponse, ExistMemCubeIdRequest, @@ -41,6 +44,8 @@ GetUserNamesByMemoryIdsRequest, GetUserNamesByMemoryIdsResponse, MemoryResponse, + RecoverMemoryByRecordIdRequest, + RecoverMemoryByRecordIdResponse, SearchResponse, StatusResponse, SuggestionRequest, @@ -290,8 +295,9 @@ def get_all_memories(memory_req: GetMemoryPlaygroundRequest): memory_req.mem_cube_ids[0] if memory_req.mem_cube_ids else memory_req.user_id ), query=memory_req.search_query, - top_k=20, + top_k=200, 
naive_mem_cube=naive_mem_cube, + search_type=memory_req.search_type, ) else: return handlers.memory_handler.handle_get_all_memories( @@ -394,9 +400,59 @@ def get_user_names_by_memory_ids(request: GetUserNamesByMemoryIdsRequest): response_model=ExistMemCubeIdResponse, ) def exist_mem_cube_id(request: ExistMemCubeIdRequest): - """Check if mem cube id exists.""" + """(inner) Check if mem cube id exists.""" return ExistMemCubeIdResponse( code=200, message="Successfully", data=graph_db.exist_user_name(user_name=request.mem_cube_id), ) + + +@router.post("/chat/stream/business_user", summary="Chat with MemOS for business user") +def chat_stream_business_user(chat_req: ChatBusinessRequest): + """(inner) Chat with MemOS for a specific business user. Returns SSE stream.""" + if chat_handler is None: + raise HTTPException( + status_code=503, detail="Chat service is not available. Chat handler not initialized." + ) + + return chat_handler.handle_chat_stream_for_business_user(chat_req) + + +@router.post( + "/delete_memory_by_record_id", + summary="Delete memory by record id", + response_model=DeleteMemoryByRecordIdResponse, +) +def delete_memory_by_record_id(memory_req: DeleteMemoryByRecordIdRequest): + """(inner) Delete memory nodes by mem_cube_id (user_name) and delete_record_id. Record id is inner field, just for delete and recover memory, not for user to set.""" + graph_db.delete_node_by_mem_cube_id( + mem_cube_id=memory_req.mem_cube_id, + delete_record_id=memory_req.record_id, + hard_delete=memory_req.hard_delete, + ) + + return DeleteMemoryByRecordIdResponse( + code=200, + message="Called Successfully", + data={"status": "success"}, + ) + + +@router.post( + "/recover_memory_by_record_id", + summary="Recover memory by record id", + response_model=RecoverMemoryByRecordIdResponse, +) +def recover_memory_by_record_id(memory_req: RecoverMemoryByRecordIdRequest): + """(inner) Recover memory nodes by mem_cube_id (user_name) and delete_record_id. Record id is inner field, just for delete and recover memory, not for user to set.""" + graph_db.recover_memory_by_mem_cube_id( + mem_cube_id=memory_req.mem_cube_id, + delete_record_id=memory_req.delete_record_id, + ) + + return RecoverMemoryByRecordIdResponse( + code=200, + message="Called Successfully", + data={"status": "success"}, + ) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 2bd2e5a46..8f1955d16 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -1953,35 +1953,35 @@ def exist_user_name(self, user_name: str) -> dict[str, bool]: def delete_node_by_mem_cube_id( self, - mem_kube_id: dict | None = None, + mem_cube_id: dict | None = None, delete_record_id: dict | None = None, - deleted_type: bool = False, + hard_delete: bool = False, ) -> int: """ - Delete nodes by mem_kube_id (user_name) and delete_record_id. + Delete nodes by mem_cube_id (user_name) and delete_record_id. Args: - mem_kube_id: The mem_kube_id which corresponds to user_name in the table. + mem_cube_id: The mem_cube_id which corresponds to user_name in the table. Can be dict or str. If dict, will extract the value. delete_record_id: The delete_record_id to match. Can be dict or str. If dict, will extract the value. - deleted_type: If True, performs hard delete (directly deletes records). + hard_delete: If True, performs hard delete (directly deletes records). If False, performs soft delete (updates status to 'deleted' and sets delete_record_id and delete_time). Returns: int: Number of nodes deleted or updated. 
""" # Handle dict type parameters (extract value if dict) - if isinstance(mem_kube_id, dict): + if isinstance(mem_cube_id, dict): # Try to get a value from dict, use first value if multiple - mem_kube_id = next(iter(mem_kube_id.values())) if mem_kube_id else None + mem_cube_id = next(iter(mem_cube_id.values())) if mem_cube_id else None if isinstance(delete_record_id, dict): delete_record_id = next(iter(delete_record_id.values())) if delete_record_id else None # Validate required parameters - if not mem_kube_id: - logger.warning("[delete_node_by_mem_cube_id] mem_kube_id is required but not provided") + if not mem_cube_id: + logger.warning("[delete_node_by_mem_cube_id] mem_cube_id is required but not provided") return 0 if not delete_record_id: @@ -1991,27 +1991,27 @@ def delete_node_by_mem_cube_id( return 0 # Convert to string if needed - mem_kube_id = str(mem_kube_id) if mem_kube_id else None + mem_cube_id = str(mem_cube_id) if mem_cube_id else None delete_record_id = str(delete_record_id) if delete_record_id else None logger.info( - f"[delete_node_by_mem_cube_id] mem_kube_id={mem_kube_id}, " - f"delete_record_id={delete_record_id}, deleted_type={deleted_type}" + f"[delete_node_by_mem_cube_id] mem_cube_id={mem_cube_id}, " + f"delete_record_id={delete_record_id}, hard_delete={hard_delete}" ) try: with self.driver.session(database=self.db_name) as session: - if deleted_type: - # Hard delete: WHERE user_name = mem_kube_id AND delete_record_id = $delete_record_id + if hard_delete: + # Hard delete: WHERE user_name = mem_cube_id AND delete_record_id = $delete_record_id query = """ MATCH (n:Memory) - WHERE n.user_name = $mem_kube_id AND n.delete_record_id = $delete_record_id + WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id DETACH DELETE n """ logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {query}") result = session.run( - query, mem_kube_id=mem_kube_id, delete_record_id=delete_record_id + query, mem_cube_id=mem_cube_id, delete_record_id=delete_record_id ) summary = result.consume() deleted_count = summary.counters.nodes_deleted if summary.counters else 0 @@ -2019,12 +2019,12 @@ def delete_node_by_mem_cube_id( logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes") return deleted_count else: - # Soft delete: WHERE user_name = mem_kube_id (only user_name condition) + # Soft delete: WHERE user_name = mem_cube_id (only user_name condition) current_time = datetime.utcnow().isoformat() query = """ MATCH (n:Memory) - WHERE n.user_name = $mem_kube_id + WHERE n.user_name = $mem_cube_id SET n.status = $status, n.delete_record_id = $delete_record_id, n.delete_time = $delete_time @@ -2034,7 +2034,7 @@ def delete_node_by_mem_cube_id( result = session.run( query, - mem_kube_id=mem_kube_id, + mem_cube_id=mem_cube_id, status="deleted", delete_record_id=delete_record_id, delete_time=current_time, @@ -2053,38 +2053,38 @@ def delete_node_by_mem_cube_id( ) raise - def recover_memory_by_mem_kube_id( + def recover_memory_by_mem_cube_id( self, - mem_kube_id: str | None = None, + mem_cube_id: str | None = None, delete_record_id: str | None = None, ) -> int: """ - Recover memory nodes by mem_kube_id (user_name) and delete_record_id. + Recover memory nodes by mem_cube_id (user_name) and delete_record_id. This function updates the status to 'activated', and clears delete_record_id and delete_time. Args: - mem_kube_id: The mem_kube_id which corresponds to user_name in the table. + mem_cube_id: The mem_cube_id which corresponds to user_name in the table. 
delete_record_id: The delete_record_id to match. Returns: int: Number of nodes recovered (updated). """ # Validate required parameters - if not mem_kube_id: + if not mem_cube_id: logger.warning( - "[recover_memory_by_mem_kube_id] mem_kube_id is required but not provided" + "[recover_memory_by_mem_cube_id] mem_cube_id is required but not provided" ) return 0 if not delete_record_id: logger.warning( - "[recover_memory_by_mem_kube_id] delete_record_id is required but not provided" + "[recover_memory_by_mem_cube_id] delete_record_id is required but not provided" ) return 0 logger.info( - f"[recover_memory_by_mem_kube_id] mem_kube_id={mem_kube_id}, " + f"[recover_memory_by_mem_cube_id] mem_cube_id={mem_cube_id}, " f"delete_record_id={delete_record_id}" ) @@ -2092,17 +2092,17 @@ def recover_memory_by_mem_kube_id( with self.driver.session(database=self.db_name) as session: query = """ MATCH (n:Memory) - WHERE n.user_name = $mem_kube_id AND n.delete_record_id = $delete_record_id + WHERE n.user_name = $mem_cube_id AND n.delete_record_id = $delete_record_id SET n.status = $status, n.delete_record_id = $delete_record_id_empty, n.delete_time = $delete_time_empty RETURN count(n) AS updated_count """ - logger.info(f"[recover_memory_by_mem_kube_id] Update query: {query}") + logger.info(f"[recover_memory_by_mem_cube_id] Update query: {query}") result = session.run( query, - mem_kube_id=mem_kube_id, + mem_cube_id=mem_cube_id, delete_record_id=delete_record_id, status="activated", delete_record_id_empty="", @@ -2112,12 +2112,12 @@ def recover_memory_by_mem_kube_id( updated_count = record["updated_count"] if record else 0 logger.info( - f"[recover_memory_by_mem_kube_id] Recovered (updated) {updated_count} nodes" + f"[recover_memory_by_mem_cube_id] Recovered (updated) {updated_count} nodes" ) return updated_count except Exception as e: logger.error( - f"[recover_memory_by_mem_kube_id] Failed to recover nodes: {e}", exc_info=True + f"[recover_memory_by_mem_cube_id] Failed to recover nodes: {e}", exc_info=True ) raise diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 0cdd089e4..18778532f 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -5470,15 +5470,27 @@ def escape_user_name(un: str) -> str: @timed def delete_node_by_mem_cube_id( self, - mem_cube_id: str | None = None, - delete_record_id: str | None = None, + mem_cube_id: dict | None = None, + delete_record_id: dict | None = None, hard_delete: bool = False, ) -> int: - logger.info( - f"delete_node_by_mem_cube_id mem_cube_id:{mem_cube_id}, " - f"delete_record_id:{delete_record_id}, hard_delete:{hard_delete}" - ) + """ + (inner) Delete memory nodes by mem_cube_id (user_name) and delete_record_id. Record id is inner field, just for delete and recover memory, not for user to set. + + Args: + mem_cube_id: The mem_cube_id which corresponds to user_name in the table. + delete_record_id: The delete_record_id to match. + hard_delete: Whether to hard delete the nodes. 
+ """ + # Handle dict type parameters (extract value if dict) + if isinstance(mem_cube_id, dict): + # Try to get a value from dict, use first value if multiple + mem_cube_id = next(iter(mem_cube_id.values())) if mem_cube_id else None + + if isinstance(delete_record_id, dict): + delete_record_id = next(iter(delete_record_id.values())) if delete_record_id else None + # Validate required parameters if not mem_cube_id: logger.warning("[delete_node_by_mem_cube_id] mem_cube_id is required but not provided") return 0 @@ -5489,18 +5501,32 @@ def delete_node_by_mem_cube_id( ) return 0 + # Convert to string if needed + mem_cube_id = str(mem_cube_id) if mem_cube_id else None + delete_record_id = str(delete_record_id) if delete_record_id else None + + logger.info( + f"[delete_node_by_mem_cube_id] mem_cube_id={mem_cube_id}, " + f"delete_record_id={delete_record_id}, hard_delete={hard_delete}" + ) + conn = None try: conn = self._get_connection() with conn.cursor() as cursor: + # Build WHERE clause for user_name using parameter binding + # user_name must match mem_cube_id user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" + # Prepare parameter for user_name user_name_param = self.format_param_value(mem_cube_id) if hard_delete: + # Hard delete: WHERE user_name = mem_cube_id AND delete_record_id = $delete_record_id delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype" where_clause = f"{user_name_condition} AND {delete_record_id_condition}" + # Prepare parameters for WHERE clause (user_name and delete_record_id) where_params = [user_name_param, self.format_param_value(delete_record_id)] delete_query = f""" @@ -5515,39 +5541,40 @@ def delete_node_by_mem_cube_id( logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes") return deleted_count else: - delete_time_empty_condition = ( - "(ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) IS NULL " - "OR ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) = '\"\"'::agtype)" - ) - delete_record_id_empty_condition = ( - "(ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) IS NULL " - "OR ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = '\"\"'::agtype)" - ) - where_clause = f"{user_name_condition} AND {delete_time_empty_condition} AND {delete_record_id_empty_condition}" + # Soft delete: WHERE user_name = mem_cube_id (only user_name condition) + where_clause = user_name_condition current_time = datetime.utcnow().isoformat() + # Build update properties JSON with status, delete_time, and delete_record_id + # Use PostgreSQL JSONB merge operator (||) to update properties + # Convert agtype to jsonb, merge with new values, then convert back to agtype update_query = f""" UPDATE "{self.db_name}_graph"."Memory" SET properties = ( properties::jsonb || %s::jsonb - )::text::agtype, - deletetime = %s + )::text::agtype WHERE {where_clause} """ + # Create update JSON with the three fields to update update_properties = { "status": "deleted", "delete_time": current_time, "delete_record_id": delete_record_id, } logger.info( - f"delete_node_by_mem_cube_id Soft delete update_query:{update_query},update_properties:{update_properties},deletetime:{current_time}" + f"[delete_node_by_mem_cube_id] Soft delete update_query: {update_query}" ) - update_params = [json.dumps(update_properties), current_time, user_name_param] + logger.info( + 
f"[delete_node_by_mem_cube_id] update_properties: {update_properties}" + ) + + # Combine update_properties JSON with user_name parameter (only user_name, no delete_record_id) + update_params = [json.dumps(update_properties), user_name_param] cursor.execute(update_query, update_params) updated_count = cursor.rowcount logger.info( - f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes" + f"[delete_node_by_mem_cube_id] Soft deleted (updated) {updated_count} nodes" ) return updated_count @@ -5566,7 +5593,7 @@ def recover_memory_by_mem_cube_id( delete_record_id: str | None = None, ) -> int: """ - Recover memory nodes by mem_cube_id (user_name) and delete_record_id. + (inner) Recover memory nodes by mem_cube_id (user_name) and delete_record_id. Record id is inner field, just for delete and recover memory, not for user to set. This function updates the status to 'activated', and clears delete_record_id and delete_time. @@ -5582,17 +5609,19 @@ def recover_memory_by_mem_cube_id( ) # Validate required parameters if not mem_cube_id: - logger.warning("recover_memory_by_mem_cube_id mem_cube_id is required but not provided") + logger.warning( + "[recover_memory_by_mem_cube_id] mem_cube_id is required but not provided" + ) return 0 if not delete_record_id: logger.warning( - "recover_memory_by_mem_cube_id delete_record_id is required but not provided" + "[recover_memory_by_mem_cube_id] delete_record_id is required but not provided" ) return 0 logger.info( - f"recover_memory_by_mem_cube_id mem_cube_id={mem_cube_id}, " + f"[recover_memory_by_mem_cube_id] mem_cube_id={mem_cube_id}, " f"delete_record_id={delete_record_id}" ) @@ -5600,15 +5629,19 @@ def recover_memory_by_mem_cube_id( try: conn = self._get_connection() with conn.cursor() as cursor: + # Build WHERE clause for user_name and delete_record_id using parameter binding user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype" where_clause = f"{user_name_condition} AND {delete_record_id_condition}" + # Prepare parameters for WHERE clause where_params = [ self.format_param_value(mem_cube_id), self.format_param_value(delete_record_id), ] + # Build update properties: status='activated', delete_record_id='', delete_time='' + # Use PostgreSQL JSONB merge operator (||) to update properties update_properties = { "status": "activated", "delete_record_id": "", @@ -5619,8 +5652,7 @@ def recover_memory_by_mem_cube_id( UPDATE "{self.db_name}_graph"."Memory" SET properties = ( properties::jsonb || %s::jsonb - )::text::agtype, - deletetime = NULL + )::text::agtype WHERE {where_clause} """ @@ -5629,6 +5661,7 @@ def recover_memory_by_mem_cube_id( f"[recover_memory_by_mem_cube_id] update_properties: {update_properties}" ) + # Combine update_properties JSON with where_params update_params = [json.dumps(update_properties), *where_params] cursor.execute(update_query, update_params) updated_count = cursor.rowcount diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index 5e58d23a5..68536da8d 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -64,7 +64,7 @@ def _judge_update_or_add_fast(self, old_msg: str, new_msg: str) -> bool: response = result.get("is_same", False) return response if isinstance(response, bool) else 
response.lower() == "true" except Exception as e: - logger.error(f"Error in judge_update_or_add: {e}") + logger.warning(f"Error in judge_update_or_add: {e}") # Fallback to simple string comparison return old_msg == new_msg @@ -80,7 +80,7 @@ def _judge_update_or_add_fine(self, new_mem: str, retrieved_mems: str) -> dict[s result = json.loads(response) return result except Exception as e: - logger.error(f"Error in judge_update_or_add_fine: {e}") + logger.warning(f"Error in judge_update_or_add_fine: {e}") return None def _judge_dup_with_text_mem(self, new_pref: MilvusVecDBItem) -> bool: @@ -118,7 +118,7 @@ def _judge_dup_with_text_mem(self, new_pref: MilvusVecDBItem) -> bool: exists = result.get("exists", False) return exists except Exception as e: - logger.error(f"Error in judge_dup_with_text_mem: {e}") + logger.warning(f"Error in judge_dup_with_text_mem: {e}") return False def _judge_update_or_add_trace_op( @@ -135,7 +135,7 @@ def _judge_update_or_add_trace_op( result = json.loads(response) return result except Exception as e: - logger.error(f"Error in judge_update_or_add_trace_op: {e}") + logger.warning(f"Error in judge_update_or_add_trace_op: {e}") return None def _dedup_explicit_pref_by_textual( @@ -156,7 +156,7 @@ def _dedup_explicit_pref_by_textual( try: is_dup_flags[idx] = future.result() except Exception as e: - logger.error( + logger.warning( f"Error in _judge_dup_with_text_mem for pref {new_prefs[idx].id}: {e}" ) is_dup_flags[idx] = False @@ -407,7 +407,7 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str | ) except Exception as e: - logger.error(f"Error processing memory {memory.id}: {e}") + logger.warning(f"Error processing memory {memory.id}: {e}") return None def process_memory_batch(self, memories: list[TextualMemoryItem], *args, **kwargs) -> list[str]: @@ -480,7 +480,7 @@ def process_memory_single( added_ids.append(memory_id) except Exception as e: memory = future_to_memory[future] - logger.error(f"Error processing memory {memory.id}: {e}") + logger.warning(f"Error processing memory {memory.id}: {e}") continue return added_ids diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index b556db5d7..ea3d536c4 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -5,14 +5,16 @@ from datetime import datetime from pathlib import Path -from typing import Any +from typing import Any, Literal from memos.configs.memory import TreeTextMemoryConfig from memos.configs.reranker import RerankerConfigFactory +from memos.dependency import require_python_package from memos.embedders.factory import EmbedderFactory, OllamaEmbedder from memos.graph_dbs.factory import GraphStoreFactory, Neo4jGraphDB from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM from memos.log import get_logger +from memos.mem_reader.read_multi_modal.utils import detect_lang from memos.memories.textual.base import BaseTextMemory from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager @@ -23,6 +25,7 @@ from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import StopwordManager from memos.reranker.factory import RerankerFactory from memos.types import MessageList @@ -223,6 +226,7 @@ def get_relevant_subgraph( depth: int = 2, center_status: str = "activated", 
         user_name: str | None = None,
+        search_type: Literal["embedding", "fulltext"] = "fulltext",
     ) -> dict[str, Any]:
         """
         Find and merge the local neighborhood sub-graphs of the top-k
@@ -249,13 +253,40 @@
             - 'nodes': List of unique nodes (core + neighbors) in the merged subgraph.
             - 'edges': List of unique edges (as dicts with 'from', 'to', 'type') in the merged subgraph.
         """
-        # Step 1: Embed query
-        query_embedding = self.embedder.embed([query])[0]
+        if search_type == "embedding":
+            # Step 1: Embed query
+            query_embedding = self.embedder.embed([query])[0]
+
+            # Step 2: Get the top-k most similar nodes
+            similar_nodes = self.graph_store.search_by_embedding(
+                query_embedding, top_k=top_k, user_name=user_name
+            )
+
+        elif search_type == "fulltext":
+
+            @require_python_package(
+                import_name="jieba",
+                install_command="pip install jieba",
+                install_link="https://github.com/fxsjy/jieba",
+            )
+            def _tokenize_chinese(text):
+                """Tokenize Chinese text with jieba and filter out stopwords."""
+                import jieba
+
+                stopword_manager = StopwordManager()
+                tokens = jieba.lcut(text)
+                tokens = [token.strip() for token in tokens if token.strip()]
+                return stopword_manager.filter_words(tokens)
+
+            lang = detect_lang(query)
+            queries = _tokenize_chinese(query) if lang == "zh" else query.split()
+
+            similar_nodes = self.graph_store.search_by_fulltext(
+                query_words=queries,
+                top_k=top_k,
+                user_name=user_name,
+            )

-        # Step 2: Get top-1 similar node
-        similar_nodes = self.graph_store.search_by_embedding(
-            query_embedding, top_k=top_k, user_name=user_name
-        )
         if not similar_nodes:
             logger.info("No similar nodes found for query embedding.")
             return {"core_id": None, "nodes": [], "edges": []}
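For reviewers, a minimal client-side sketch of how the new /chat/stream/business_user endpoint might be exercised. The host, port, router prefix, and business_key value are placeholders; the payload fields (business_key, need_search, user_id, query) and the `data: {"type": ...}` SSE event shape are taken from the diff above, but this snippet is not part of the change itself.

# Hypothetical smoke test for POST /chat/stream/business_user (URL and key are placeholders).
import json

import requests

payload = {
    "business_key": "example-key",  # must appear in the BUSINESS_CHAT_KEYS env var (a JSON array)
    "user_id": "user-123",
    "query": "What did we decide about the Q3 roadmap?",
    "need_search": True,  # run a memory search before answering
}

with requests.post(
    "http://localhost:8000/chat/stream/business_user",  # adjust host/port/prefix for your deployment
    json=payload,
    stream=True,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank SSE separators
        event = json.loads(line[len("data: "):])
        if event["type"] == "text":
            print(event["data"], end="", flush=True)  # answer tokens
        elif event["type"] == "reasoning":
            pass  # reasoning chunks are streamed separately; ignored here
        elif event["type"] == "error":
            raise RuntimeError(event["content"])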