From 695e597da794712018807ad640a2b71a59803e8d Mon Sep 17 00:00:00 2001 From: fancy Date: Tue, 3 Feb 2026 19:30:27 +0800 Subject: [PATCH 01/14] refactor(scheduler): modularize handlers and search - extract scheduler handlers into dedicated modules - split retriever into pipelines (search/enhance/rerank/filter) - centralize text search logic for API and scheduler Refs #1003 --- .../mem_scheduler/base_mixins/__init__.py | 9 + .../mem_scheduler/base_mixins/memory_ops.py | 344 ++++ .../mem_scheduler/base_mixins/queue_ops.py | 405 +++++ .../mem_scheduler/base_mixins/web_log_ops.py | 108 ++ src/memos/mem_scheduler/base_scheduler.py | 979 +---------- src/memos/mem_scheduler/general_scheduler.py | 1517 +---------------- src/memos/mem_scheduler/handlers/__init__.py | 8 + .../mem_scheduler/handlers/add_handler.py | 302 ++++ .../mem_scheduler/handlers/answer_handler.py | 52 + src/memos/mem_scheduler/handlers/base.py | 8 + src/memos/mem_scheduler/handlers/context.py | 39 + .../handlers/feedback_handler.py | 178 ++ .../handlers/mem_read_handler.py | 353 ++++ .../handlers/mem_reorganize_handler.py | 245 +++ .../handlers/memory_update_handler.py | 278 +++ .../handlers/pref_add_handler.py | 84 + .../mem_scheduler/handlers/query_handler.py | 70 + src/memos/mem_scheduler/handlers/registry.py | 47 + .../enhancement_pipeline.py | 275 +++ .../memory_manage_modules/filter_pipeline.py | 24 + .../memory_manage_modules/rerank_pipeline.py | 110 ++ .../memory_manage_modules/retriever.py | 546 +----- .../memory_manage_modules/search_pipeline.py | 94 + .../mem_scheduler/optimized_scheduler.py | 41 +- src/memos/multi_mem_cube/single_cube.py | 29 +- src/memos/search/__init__.py | 3 + src/memos/search/search_service.py | 64 + 27 files changed, 3228 insertions(+), 2984 deletions(-) create mode 100644 src/memos/mem_scheduler/base_mixins/__init__.py create mode 100644 src/memos/mem_scheduler/base_mixins/memory_ops.py create mode 100644 src/memos/mem_scheduler/base_mixins/queue_ops.py create mode 100644 src/memos/mem_scheduler/base_mixins/web_log_ops.py create mode 100644 src/memos/mem_scheduler/handlers/__init__.py create mode 100644 src/memos/mem_scheduler/handlers/add_handler.py create mode 100644 src/memos/mem_scheduler/handlers/answer_handler.py create mode 100644 src/memos/mem_scheduler/handlers/base.py create mode 100644 src/memos/mem_scheduler/handlers/context.py create mode 100644 src/memos/mem_scheduler/handlers/feedback_handler.py create mode 100644 src/memos/mem_scheduler/handlers/mem_read_handler.py create mode 100644 src/memos/mem_scheduler/handlers/mem_reorganize_handler.py create mode 100644 src/memos/mem_scheduler/handlers/memory_update_handler.py create mode 100644 src/memos/mem_scheduler/handlers/pref_add_handler.py create mode 100644 src/memos/mem_scheduler/handlers/query_handler.py create mode 100644 src/memos/mem_scheduler/handlers/registry.py create mode 100644 src/memos/mem_scheduler/memory_manage_modules/enhancement_pipeline.py create mode 100644 src/memos/mem_scheduler/memory_manage_modules/filter_pipeline.py create mode 100644 src/memos/mem_scheduler/memory_manage_modules/rerank_pipeline.py create mode 100644 src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py create mode 100644 src/memos/search/__init__.py create mode 100644 src/memos/search/search_service.py diff --git a/src/memos/mem_scheduler/base_mixins/__init__.py b/src/memos/mem_scheduler/base_mixins/__init__.py new file mode 100644 index 000000000..471d30f06 --- /dev/null +++ b/src/memos/mem_scheduler/base_mixins/__init__.py @@ 
-0,0 +1,9 @@ +from .memory_ops import BaseSchedulerMemoryMixin +from .queue_ops import BaseSchedulerQueueMixin +from .web_log_ops import BaseSchedulerWebLogMixin + +__all__ = [ + "BaseSchedulerMemoryMixin", + "BaseSchedulerQueueMixin", + "BaseSchedulerWebLogMixin", +] diff --git a/src/memos/mem_scheduler/base_mixins/memory_ops.py b/src/memos/mem_scheduler/base_mixins/memory_ops.py new file mode 100644 index 000000000..5ad197a9e --- /dev/null +++ b/src/memos/mem_scheduler/base_mixins/memory_ops.py @@ -0,0 +1,344 @@ +from __future__ import annotations + +from datetime import datetime + +from memos.log import get_logger +from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem +from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.mem_scheduler.utils.filter_utils import transform_name_to_key +from memos.memories.activation.kv import KVCacheMemory +from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory +from memos.memories.textual.naive import NaiveTextMemory +from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE +from memos.types.general_types import MemCubeID, UserID + + +logger = get_logger(__name__) + + +class BaseSchedulerMemoryMixin: + def transform_working_memories_to_monitors( + self, query_keywords, memories: list[TextualMemoryItem] + ) -> list[MemoryMonitorItem]: + result = [] + mem_length = len(memories) + for idx, mem in enumerate(memories): + text_mem = mem.memory + mem_key = transform_name_to_key(name=text_mem) + + keywords_score = 0 + if query_keywords and text_mem: + for keyword, count in query_keywords.items(): + keyword_count = text_mem.count(keyword) + if keyword_count > 0: + keywords_score += keyword_count * count + logger.debug( + "Matched keyword '%s' %s times; keywords_score is now %s", + keyword, + keyword_count, + keywords_score, + ) + + sorting_score = mem_length - idx + + mem_monitor = MemoryMonitorItem( + memory_text=text_mem, + tree_memory_item=mem, + tree_memory_item_mapping_key=mem_key, + sorting_score=sorting_score, + keywords_score=keywords_score, + recording_count=1, + ) + result.append(mem_monitor) + + logger.info("Transformed %s memories to monitors", len(result)) + return result + + def replace_working_memory( + self, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, + mem_cube, + original_memory: list[TextualMemoryItem], + new_memory: list[TextualMemoryItem], + ) -> None | list[TextualMemoryItem]: + text_mem_base = mem_cube.text_mem + if isinstance(text_mem_base, TreeTextMemory): + text_mem_base: TreeTextMemory = text_mem_base + + query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] + query_db_manager.sync_with_orm() + + query_history = query_db_manager.obj.get_queries_with_timesort() + + original_count = len(original_memory) + filtered_original_memory = [] + for origin_mem in original_memory: + if "mode:fast" not in origin_mem.metadata.tags: + filtered_original_memory.append(origin_mem) + else: + logger.debug( + "Filtered out memory - ID: %s, Tags: %s", + getattr(origin_mem, "id", "unknown"), + origin_mem.metadata.tags, + ) + filtered_count = original_count - len(filtered_original_memory) + remaining_count = len(filtered_original_memory) + + logger.info( + "Filtering complete. Removed %s memories with tag 'mode:fast'. 
Remaining memories: %s", + filtered_count, + remaining_count, + ) + original_memory = filtered_original_memory + + memories_with_new_order, rerank_success_flag = ( + self.retriever.process_and_rerank_memories( + queries=query_history, + original_memory=original_memory, + new_memory=new_memory, + top_k=self.top_k, + ) + ) + + logger.info("Filtering memories based on query history: %s queries", len(query_history)) + filtered_memories, filter_success_flag = self.retriever.filter_unrelated_memories( + query_history=query_history, + memories=memories_with_new_order, + ) + + if filter_success_flag: + logger.info( + "Memory filtering completed successfully. Filtered from %s to %s memories", + len(memories_with_new_order), + len(filtered_memories), + ) + memories_with_new_order = filtered_memories + else: + logger.warning( + "Memory filtering failed - keeping all memories as fallback. Original count: %s", + len(memories_with_new_order), + ) + + query_keywords = query_db_manager.obj.get_keywords_collections() + logger.info( + "Processing %s memories with %s query keywords", + len(memories_with_new_order), + len(query_keywords), + ) + new_working_memory_monitors = self.transform_working_memories_to_monitors( + query_keywords=query_keywords, + memories=memories_with_new_order, + ) + + if not rerank_success_flag: + for one in new_working_memory_monitors: + one.sorting_score = 0 + + logger.info("update %s working_memory_monitors", len(new_working_memory_monitors)) + self.monitor.update_working_memory_monitors( + new_working_memory_monitors=new_working_memory_monitors, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + + mem_monitors: list[MemoryMonitorItem] = self.monitor.working_memory_monitors[user_id][ + mem_cube_id + ].obj.get_sorted_mem_monitors(reverse=True) + new_working_memories = [mem_monitor.tree_memory_item for mem_monitor in mem_monitors] + + text_mem_base.replace_working_memory(memories=new_working_memories) + + logger.info( + "The working memory has been replaced with %s new memories.", + len(memories_with_new_order), + ) + self.log_working_memory_replacement( + original_memory=original_memory, + new_memory=new_working_memories, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + log_func_callback=self._submit_web_logs, + ) + elif isinstance(text_mem_base, NaiveTextMemory): + logger.info( + "NaiveTextMemory: Updating working memory monitors with %s candidates.", + len(new_memory), + ) + + query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] + query_db_manager.sync_with_orm() + query_keywords = query_db_manager.obj.get_keywords_collections() + + new_working_memory_monitors = self.transform_working_memories_to_monitors( + query_keywords=query_keywords, + memories=new_memory, + ) + + self.monitor.update_working_memory_monitors( + new_working_memory_monitors=new_working_memory_monitors, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + memories_with_new_order = new_memory + else: + logger.error("memory_base is not supported") + memories_with_new_order = new_memory + + return memories_with_new_order + + def update_activation_memory( + self, + new_memories: list[str | TextualMemoryItem], + label: str, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, + mem_cube, + ) -> None: + if len(new_memories) == 0: + logger.error("update_activation_memory: new_memory is empty.") + return + if isinstance(new_memories[0], TextualMemoryItem): + new_text_memories = [mem.memory for mem in new_memories] + elif 
isinstance(new_memories[0], str): + new_text_memories = new_memories + else: + logger.error("update_activation_memory: unsupported memory item type.") + return + + try: + if isinstance(mem_cube.act_mem, VLLMKVCacheMemory): + act_mem: VLLMKVCacheMemory = mem_cube.act_mem + elif isinstance(mem_cube.act_mem, KVCacheMemory): + act_mem = mem_cube.act_mem + else: + logger.error("update_activation_memory: unsupported activation memory backend.") + return + + new_text_memory = MEMORY_ASSEMBLY_TEMPLATE.format( + memory_text="".join( + [ + f"{i + 1}. {sentence.strip()}\n" + for i, sentence in enumerate(new_text_memories) + if sentence.strip() + ] + ) + ) + + original_cache_items: list[VLLMKVCacheItem] = act_mem.get_all() + original_text_memories = [] + if len(original_cache_items) > 0: + pre_cache_item: VLLMKVCacheItem = original_cache_items[-1] + original_text_memories = pre_cache_item.records.text_memories + original_composed_text_memory = pre_cache_item.records.composed_text_memory + if original_composed_text_memory == new_text_memory: + logger.warning( + "Skipping memory update - new composition matches existing cache: %s", + new_text_memory[:50] + "..." if len(new_text_memory) > 50 else new_text_memory, + ) + return + act_mem.delete_all() + + cache_item = act_mem.extract(new_text_memory) + cache_item.records.text_memories = new_text_memories + cache_item.records.timestamp = get_utc_now() + + act_mem.add([cache_item]) + act_mem.dump(self.act_mem_dump_path) + + self.log_activation_memory_update( + original_text_memories=original_text_memories, + new_text_memories=new_text_memories, + label=label, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + log_func_callback=self._submit_web_logs, + ) + + except Exception as e: + logger.error("MOS-based activation memory update failed: %s", e, exc_info=True) + + def update_activation_memory_periodically( + self, + interval_seconds: int, + label: str, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, + mem_cube, + ): + try: + if ( + self.monitor.last_activation_mem_update_time == datetime.min + or self.monitor.timed_trigger( + last_time=self.monitor.last_activation_mem_update_time, + interval_seconds=interval_seconds, + ) + ): + logger.info( + "Updating activation memory for user %s and mem_cube %s", + user_id, + mem_cube_id, + ) + + if ( + user_id not in self.monitor.working_memory_monitors + or mem_cube_id not in self.monitor.working_memory_monitors[user_id] + or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].obj.memories) + == 0 + ): + logger.warning( + "No memories found in working_memory_monitors, activation memory update is skipped" + ) + return + + self.monitor.update_activation_memory_monitors( + user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube + ) + + activation_db_manager = self.monitor.activation_memory_monitors[user_id][mem_cube_id] + activation_db_manager.sync_with_orm() + new_activation_memories = [ + m.memory_text for m in activation_db_manager.obj.memories + ] + + logger.info( + "Collected %s new memory entries for processing", + len(new_activation_memories), + ) + for i, memory in enumerate(new_activation_memories[:5], 1): + logger.info( + "Part of New Activation Memories | %s/%s: %s", + i, + len(new_activation_memories), + memory[:20], + ) + + self.update_activation_memory( + new_memories=new_activation_memories, + label=label, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + + self.monitor.last_activation_mem_update_time = get_utc_now() + + logger.debug( + "Activation memory update completed at %s", + self.monitor.last_activation_mem_update_time, +
) + + else: + logger.info( + "Skipping update - %s second interval not yet reached. Last update time is %s and now is %s", + interval_seconds, + self.monitor.last_activation_mem_update_time, + get_utc_now(), + ) + except Exception as e: + logger.error("Error in update_activation_memory_periodically: %s", e, exc_info=True) diff --git a/src/memos/mem_scheduler/base_mixins/queue_ops.py b/src/memos/mem_scheduler/base_mixins/queue_ops.py new file mode 100644 index 000000000..ffe230c84 --- /dev/null +++ b/src/memos/mem_scheduler/base_mixins/queue_ops.py @@ -0,0 +1,405 @@ +from __future__ import annotations + +import multiprocessing +import time + +from collections.abc import Callable +from contextlib import suppress +from datetime import datetime, timezone + +from memos.context.context import ContextThread, RequestContext, get_current_context, get_current_trace_id, set_request_context +from memos.log import get_logger +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.general_schemas import STARTUP_BY_PROCESS +from memos.mem_scheduler.schemas.task_schemas import TaskPriorityLevel +from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube +from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso + + +logger = get_logger(__name__) + + +class BaseSchedulerQueueMixin: + def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): + if isinstance(messages, ScheduleMessageItem): + messages = [messages] + + if not messages: + return + + current_trace_id = get_current_trace_id() + + immediate_msgs: list[ScheduleMessageItem] = [] + queued_msgs: list[ScheduleMessageItem] = [] + + for msg in messages: + if current_trace_id: + msg.trace_id = current_trace_id + + with suppress(Exception): + self.metrics.task_enqueued(user_id=msg.user_id, task_type=msg.label) + + if getattr(msg, "timestamp", None) is None: + msg.timestamp = get_utc_now() + + if self.status_tracker: + try: + self.status_tracker.task_submitted( + task_id=msg.item_id, + user_id=msg.user_id, + task_type=msg.label, + mem_cube_id=msg.mem_cube_id, + business_task_id=msg.task_id, + ) + except Exception: + logger.warning("status_tracker.task_submitted failed", exc_info=True) + + if self.disabled_handlers and msg.label in self.disabled_handlers: + logger.info("Skipping disabled handler: %s - %s", msg.label, msg.content) + continue + + task_priority = self.orchestrator.get_task_priority(task_label=msg.label) + if task_priority == TaskPriorityLevel.LEVEL_1: + immediate_msgs.append(msg) + else: + queued_msgs.append(msg) + + if immediate_msgs: + for m in immediate_msgs: + emit_monitor_event( + "enqueue", + m, + { + "enqueue_ts": to_iso(getattr(m, "timestamp", None)), + "event_duration_ms": 0, + "total_duration_ms": 0, + }, + ) + + for m in immediate_msgs: + try: + now = time.time() + enqueue_ts_obj = getattr(m, "timestamp", None) + enqueue_epoch = None + if isinstance(enqueue_ts_obj, int | float): + enqueue_epoch = float(enqueue_ts_obj) + elif hasattr(enqueue_ts_obj, "timestamp"): + dt = enqueue_ts_obj + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + enqueue_epoch = dt.timestamp() + + queue_wait_ms = None + if enqueue_epoch is not None: + queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000 + + object.__setattr__(m, "_dequeue_ts", now) + emit_monitor_event( + "dequeue", + m, + { + "enqueue_ts": to_iso(enqueue_ts_obj), + "dequeue_ts": 
datetime.fromtimestamp(now, tz=timezone.utc).isoformat(), + "queue_wait_ms": queue_wait_ms, + "event_duration_ms": queue_wait_ms, + "total_duration_ms": queue_wait_ms, + }, + ) + self.metrics.task_dequeued(user_id=m.user_id, task_type=m.label) + except Exception: + logger.debug("Failed to emit dequeue for immediate task", exc_info=True) + + user_cube_groups = group_messages_by_user_and_mem_cube(immediate_msgs) + for user_id, cube_groups in user_cube_groups.items(): + for mem_cube_id, user_cube_msgs in cube_groups.items(): + label_groups: dict[str, list[ScheduleMessageItem]] = {} + for m in user_cube_msgs: + label_groups.setdefault(m.label, []).append(m) + + for label, msgs_by_label in label_groups.items(): + handler = self.dispatcher.handlers.get( + label, self.dispatcher._default_message_handler + ) + self.dispatcher.execute_task( + user_id=user_id, + mem_cube_id=mem_cube_id, + task_label=label, + msgs=msgs_by_label, + handler_call_back=handler, + ) + + if queued_msgs: + self.memos_message_queue.submit_messages(messages=queued_msgs) + + def _message_consumer(self) -> None: + while self._running: + try: + if self.enable_parallel_dispatch and self.dispatcher: + running_tasks = self.dispatcher.get_running_task_count() + if running_tasks >= self.dispatcher.max_workers: + time.sleep(self._consume_interval) + continue + + messages = self.memos_message_queue.get_messages(batch_size=self.consume_batch) + + if messages: + now = time.time() + for msg in messages: + prev_context = get_current_context() + try: + msg_context = RequestContext( + trace_id=msg.trace_id, + user_name=msg.user_name, + ) + set_request_context(msg_context) + + enqueue_ts_obj = getattr(msg, "timestamp", None) + enqueue_epoch = None + if isinstance(enqueue_ts_obj, int | float): + enqueue_epoch = float(enqueue_ts_obj) + elif hasattr(enqueue_ts_obj, "timestamp"): + dt = enqueue_ts_obj + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + enqueue_epoch = dt.timestamp() + + queue_wait_ms = None + if enqueue_epoch is not None: + queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000 + + object.__setattr__(msg, "_dequeue_ts", now) + emit_monitor_event( + "dequeue", + msg, + { + "enqueue_ts": to_iso(enqueue_ts_obj), + "dequeue_ts": datetime.fromtimestamp(now, tz=timezone.utc).isoformat(), + "queue_wait_ms": queue_wait_ms, + "event_duration_ms": queue_wait_ms, + "total_duration_ms": queue_wait_ms, + }, + ) + self.metrics.task_dequeued(user_id=msg.user_id, task_type=msg.label) + finally: + set_request_context(prev_context) + try: + with suppress(Exception): + if messages: + self.dispatcher.on_messages_enqueued(messages) + + self.dispatcher.dispatch(messages) + except Exception as e: + logger.error("Error dispatching messages: %s", e) + + time.sleep(self._consume_interval) + + except Exception as e: + if "No messages available in Redis queue" not in str(e): + logger.error("Unexpected error in message consumer: %s", e, exc_info=True) + time.sleep(self._consume_interval) + + def _monitor_loop(self): + while self._running: + try: + q_sizes = self.memos_message_queue.qsize() + + if not isinstance(q_sizes, dict): + continue + + for stream_key, queue_length in q_sizes.items(): + if stream_key == "total_size": + continue + + parts = stream_key.split(":") + if len(parts) >= 3: + user_id = parts[-3] + self.metrics.update_queue_length(queue_length, user_id) + else: + if ":" not in stream_key: + self.metrics.update_queue_length(queue_length, stream_key) + + except Exception as e: + logger.error("Error in metrics monitor loop: %s", e, 
exc_info=True) + + time.sleep(15) + + def start(self) -> None: + if self.enable_parallel_dispatch: + logger.info( + "Initializing dispatcher thread pool with %s workers", + self.thread_pool_max_workers, + ) + + self.start_consumer() + self.start_background_monitor() + + def start_background_monitor(self): + if self._monitor_thread and self._monitor_thread.is_alive(): + return + self._monitor_thread = ContextThread( + target=self._monitor_loop, daemon=True, name="SchedulerMetricsMonitor" + ) + self._monitor_thread.start() + logger.info("Scheduler metrics monitor thread started.") + + def start_consumer(self) -> None: + if self._running: + logger.warning("Memory Scheduler consumer is already running") + return + + self._running = True + + if self.scheduler_startup_mode == STARTUP_BY_PROCESS: + self._consumer_process = multiprocessing.Process( + target=self._message_consumer, + daemon=True, + name="MessageConsumerProcess", + ) + self._consumer_process.start() + logger.info("Message consumer process started") + else: + self._consumer_thread = ContextThread( + target=self._message_consumer, + daemon=True, + name="MessageConsumerThread", + ) + self._consumer_thread.start() + logger.info("Message consumer thread started") + + def stop_consumer(self) -> None: + if not self._running: + logger.warning("Memory Scheduler consumer is not running") + return + + self._running = False + + if self.scheduler_startup_mode == STARTUP_BY_PROCESS and self._consumer_process: + if self._consumer_process.is_alive(): + self._consumer_process.join(timeout=5.0) + if self._consumer_process.is_alive(): + logger.warning("Consumer process did not stop gracefully, terminating...") + self._consumer_process.terminate() + self._consumer_process.join(timeout=2.0) + if self._consumer_process.is_alive(): + logger.error("Consumer process could not be terminated") + else: + logger.info("Consumer process terminated") + else: + logger.info("Consumer process stopped") + self._consumer_process = None + elif self._consumer_thread and self._consumer_thread.is_alive(): + self._consumer_thread.join(timeout=5.0) + if self._consumer_thread.is_alive(): + logger.warning("Consumer thread did not stop gracefully") + else: + logger.info("Consumer thread stopped") + self._consumer_thread = None + + logger.info("Memory Scheduler consumer stopped") + + def stop(self) -> None: + if not self._running: + logger.warning("Memory Scheduler is not running") + return + + self.stop_consumer() + + if self._monitor_thread: + self._monitor_thread.join(timeout=2.0) + + if self.dispatcher: + logger.info("Shutting down dispatcher...") + self.dispatcher.shutdown() + + if self.dispatcher_monitor: + logger.info("Shutting down monitor...") + self.dispatcher_monitor.stop() + + @property + def handlers(self) -> dict[str, Callable]: + if not self.dispatcher: + logger.warning("Dispatcher is not initialized, returning empty handlers dict") + return {} + + return self.dispatcher.handlers + + def register_handlers(self, handlers: dict[str, Callable[[list[ScheduleMessageItem]], None]]) -> None: + if not self.dispatcher: + logger.warning("Dispatcher is not initialized, cannot register handlers") + return + + self.dispatcher.register_handlers(handlers) + + def unregister_handlers(self, labels: list[str]) -> dict[str, bool]: + if not self.dispatcher: + logger.warning("Dispatcher is not initialized, cannot unregister handlers") + return dict.fromkeys(labels, False) + + return self.dispatcher.unregister_handlers(labels) + + def get_running_tasks(self, filter_func: Callable | 
None = None) -> dict[str, dict]: + if not self.dispatcher: + logger.warning("Dispatcher is not initialized, returning empty tasks dict") + return {} + + running_tasks = self.dispatcher.get_running_tasks(filter_func=filter_func) + + result = {} + for task_id, task_item in running_tasks.items(): + result[task_id] = { + "item_id": task_item.item_id, + "user_id": task_item.user_id, + "mem_cube_id": task_item.mem_cube_id, + "task_info": task_item.task_info, + "task_name": task_item.task_name, + "start_time": task_item.start_time, + "end_time": task_item.end_time, + "status": task_item.status, + "result": task_item.result, + "error_message": task_item.error_message, + "messages": task_item.messages, + } + + return result + + def get_tasks_status(self): + return self.task_schedule_monitor.get_tasks_status() + + def print_tasks_status(self, tasks_status: dict | None = None) -> None: + self.task_schedule_monitor.print_tasks_status(tasks_status=tasks_status) + + def _gather_queue_stats(self) -> dict: + memos_message_queue = self.memos_message_queue.memos_message_queue + stats: dict[str, int | float | str] = {} + stats["use_redis_queue"] = bool(self.use_redis_queue) + if not self.use_redis_queue: + try: + stats["qsize"] = int(memos_message_queue.qsize()) + except Exception: + stats["qsize"] = -1 + try: + stats["unfinished_tasks"] = int( + getattr(memos_message_queue, "unfinished_tasks", 0) or 0 + ) + except Exception: + stats["unfinished_tasks"] = -1 + stats["maxsize"] = int(self.max_internal_message_queue_size) + try: + maxsize = int(self.max_internal_message_queue_size) or 1 + qsize = int(stats.get("qsize", 0)) + stats["utilization"] = min(1.0, max(0.0, qsize / maxsize)) + except Exception: + stats["utilization"] = 0.0 + try: + d_stats = self.dispatcher.stats() + stats.update( + { + "running": int(d_stats.get("running", 0)), + "inflight": int(d_stats.get("inflight", 0)), + "handlers": int(d_stats.get("handlers", 0)), + } + ) + except Exception: + stats.update({"running": 0, "inflight": 0, "handlers": 0}) + return stats diff --git a/src/memos/mem_scheduler/base_mixins/web_log_ops.py b/src/memos/mem_scheduler/base_mixins/web_log_ops.py new file mode 100644 index 000000000..beac47500 --- /dev/null +++ b/src/memos/mem_scheduler/base_mixins/web_log_ops.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from memos.log import get_logger +from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem +from memos.mem_scheduler.schemas.task_schemas import ( + ADD_TASK_LABEL, + ANSWER_TASK_LABEL, + MEM_ARCHIVE_TASK_LABEL, + MEM_ORGANIZE_TASK_LABEL, + MEM_UPDATE_TASK_LABEL, + QUERY_TASK_LABEL, +) + + +logger = get_logger(__name__) + + +class BaseSchedulerWebLogMixin: + def _submit_web_logs( + self, + messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem], + additional_log_info: str | None = None, + ) -> None: + if isinstance(messages, ScheduleLogForWebItem): + messages = [messages] + + for message in messages: + if self.rabbitmq_config is None: + return + try: + logger.info( + "[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish %s", + message.model_dump_json(indent=2), + ) + self.rabbitmq_publish_message(message=message.to_dict()) + logger.info( + "[DIAGNOSTIC] base_scheduler._submit_web_logs: publish dispatched item_id=%s task_id=%s label=%s", + message.item_id, + message.task_id, + message.label, + ) + except Exception as e: + logger.error( + "[DIAGNOSTIC] base_scheduler._submit_web_logs failed: %s", + e, + exc_info=True, + ) + + logger.debug( + "%s submitted. 
%s in queue. additional_log_info: %s", + len(messages), + self._web_log_message_queue.qsize(), + additional_log_info, + ) + + def get_web_log_messages(self) -> list[dict]: + raw_items: list[ScheduleLogForWebItem] = [] + while True: + try: + raw_items.append(self._web_log_message_queue.get_nowait()) + except Exception: + break + + def _map_label(label: str) -> str: + mapping = { + QUERY_TASK_LABEL: "addMessage", + ANSWER_TASK_LABEL: "addMessage", + ADD_TASK_LABEL: "addMemory", + MEM_UPDATE_TASK_LABEL: "updateMemory", + MEM_ORGANIZE_TASK_LABEL: "mergeMemory", + MEM_ARCHIVE_TASK_LABEL: "archiveMemory", + } + return mapping.get(label, label) + + def _normalize_item(item: ScheduleLogForWebItem) -> dict: + data = item.to_dict() + data["label"] = _map_label(data.get("label")) + memcube_content = getattr(item, "memcube_log_content", None) or [] + metadata = getattr(item, "metadata", None) or [] + + memcube_name = getattr(item, "memcube_name", None) + if not memcube_name and hasattr(self, "_map_memcube_name"): + memcube_name = self._map_memcube_name(item.mem_cube_id) + data["memcube_name"] = memcube_name + + memory_len = getattr(item, "memory_len", None) + if memory_len is None: + if data["label"] == "mergeMemory": + memory_len = len([c for c in memcube_content if c.get("type") != "postMerge"]) + elif memcube_content: + memory_len = len(memcube_content) + else: + memory_len = 1 if item.log_content else 0 + + data["memcube_log_content"] = memcube_content + data["memory_len"] = memory_len + + def _with_memory_time(meta: dict) -> dict: + enriched = dict(meta) + if "memory_time" not in enriched: + enriched["memory_time"] = enriched.get("updated_at") or enriched.get("update_at") + return enriched + + data["metadata"] = [_with_memory_time(m) for m in metadata] + data["log_title"] = "" + return data + + return [_normalize_item(it) for it in raw_items] diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 5ab524128..d14248f02 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -1,28 +1,20 @@ -import multiprocessing import os import threading -import time -from collections.abc import Callable -from contextlib import suppress -from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING, Union from sqlalchemy.engine import Engine from memos.configs.mem_scheduler import AuthConfig, BaseSchedulerConfig -from memos.context.context import ( - ContextThread, - RequestContext, - get_current_context, - get_current_trace_id, - set_request_context, +from memos.mem_scheduler.base_mixins import ( + BaseSchedulerMemoryMixin, + BaseSchedulerQueueMixin, + BaseSchedulerWebLogMixin, ) from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.base import BaseMemCube -from memos.mem_cube.general import GeneralMemCube from memos.mem_feedback.simple_feedback import SimpleMemFeedback from memos.mem_scheduler.general_modules.init_components_for_scheduler import init_components from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue @@ -42,42 +34,18 @@ DEFAULT_THREAD_POOL_MAX_WORKERS, DEFAULT_TOP_K, DEFAULT_USE_REDIS_QUEUE, - STARTUP_BY_PROCESS, TreeTextMemory_SEARCH_METHOD, ) -from memos.mem_scheduler.schemas.message_schemas import ( - ScheduleLogForWebItem, - ScheduleMessageItem, -) -from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem -from memos.mem_scheduler.schemas.task_schemas import ( - ADD_TASK_LABEL, - 
ANSWER_TASK_LABEL, - MEM_ARCHIVE_TASK_LABEL, - MEM_ORGANIZE_TASK_LABEL, - MEM_UPDATE_TASK_LABEL, - QUERY_TASK_LABEL, - TaskPriorityLevel, -) +from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue from memos.mem_scheduler.utils import metrics -from memos.mem_scheduler.utils.db_utils import get_utc_now -from memos.mem_scheduler.utils.filter_utils import ( - transform_name_to_key, -) -from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube -from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker from memos.mem_scheduler.webservice_modules.rabbitmq_service import RabbitMQSchedulerModule from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule -from memos.memories.activation.kv import KVCacheMemory -from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory -from memos.memories.textual.naive import NaiveTextMemory -from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher -from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE from memos.types.general_types import ( MemCubeID, UserID, @@ -93,7 +61,14 @@ logger = get_logger(__name__) -class BaseScheduler(RabbitMQSchedulerModule, RedisSchedulerModule, SchedulerLoggerModule): +class BaseScheduler( + RabbitMQSchedulerModule, + RedisSchedulerModule, + SchedulerLoggerModule, + BaseSchedulerWebLogMixin, + BaseSchedulerMemoryMixin, + BaseSchedulerQueueMixin, +): """Base class for all mem_scheduler.""" def __init__(self, config: BaseSchedulerConfig): @@ -391,929 +366,5 @@ def mem_cubes(self, value: dict[str, BaseMemCube]) -> None: f"Failed to initialize current_mem_cube from mem_cubes: {e}", exc_info=True ) - def transform_working_memories_to_monitors( - self, query_keywords, memories: list[TextualMemoryItem] - ) -> list[MemoryMonitorItem]: - """ - Convert a list of TextualMemoryItem objects into MemoryMonitorItem objects - with importance scores based on keyword matching. - - Args: - memories: List of TextualMemoryItem objects to be transformed. - - Returns: - List of MemoryMonitorItem objects with computed importance scores. 
- """ - - result = [] - mem_length = len(memories) - for idx, mem in enumerate(memories): - text_mem = mem.memory - mem_key = transform_name_to_key(name=text_mem) - - # Calculate importance score based on keyword matches - keywords_score = 0 - if query_keywords and text_mem: - for keyword, count in query_keywords.items(): - keyword_count = text_mem.count(keyword) - if keyword_count > 0: - keywords_score += keyword_count * count - logger.debug( - f"Matched keyword '{keyword}' {keyword_count} times, added {keywords_score} to keywords_score" - ) - - # rank score - sorting_score = mem_length - idx - - mem_monitor = MemoryMonitorItem( - memory_text=text_mem, - tree_memory_item=mem, - tree_memory_item_mapping_key=mem_key, - sorting_score=sorting_score, - keywords_score=keywords_score, - recording_count=1, - ) - result.append(mem_monitor) - - logger.info(f"Transformed {len(result)} memories to monitors") - return result - - def replace_working_memory( - self, - user_id: UserID | str, - mem_cube_id: MemCubeID | str, - mem_cube: GeneralMemCube, - original_memory: list[TextualMemoryItem], - new_memory: list[TextualMemoryItem], - ) -> None | list[TextualMemoryItem]: - """Replace working memory with new memories after reranking.""" - text_mem_base = mem_cube.text_mem - if isinstance(text_mem_base, TreeTextMemory): - text_mem_base: TreeTextMemory = text_mem_base - - # process rerank memories with llm - query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] - # Sync with database to get latest query history - query_db_manager.sync_with_orm() - - query_history = query_db_manager.obj.get_queries_with_timesort() - - original_count = len(original_memory) - # Filter out memories tagged with "mode:fast" - filtered_original_memory = [] - for origin_mem in original_memory: - if "mode:fast" not in origin_mem.metadata.tags: - filtered_original_memory.append(origin_mem) - else: - logger.debug( - f"Filtered out memory - ID: {getattr(origin_mem, 'id', 'unknown')}, Tags: {origin_mem.metadata.tags}" - ) - # Calculate statistics - filtered_count = original_count - len(filtered_original_memory) - remaining_count = len(filtered_original_memory) - - logger.info( - f"Filtering complete. Removed {filtered_count} memories with tag 'mode:fast'. Remaining memories: {remaining_count}" - ) - original_memory = filtered_original_memory - - memories_with_new_order, rerank_success_flag = ( - self.retriever.process_and_rerank_memories( - queries=query_history, - original_memory=original_memory, - new_memory=new_memory, - top_k=self.top_k, - ) - ) - - # Filter completely unrelated memories according to query_history - logger.info(f"Filtering memories based on query history: {len(query_history)} queries") - filtered_memories, filter_success_flag = self.retriever.filter_unrelated_memories( - query_history=query_history, - memories=memories_with_new_order, - ) - - if filter_success_flag: - logger.info( - f"Memory filtering completed successfully. " - f"Filtered from {len(memories_with_new_order)} to {len(filtered_memories)} memories" - ) - memories_with_new_order = filtered_memories - else: - logger.warning( - "Memory filtering failed - keeping all memories as fallback. 
" - f"Original count: {len(memories_with_new_order)}" - ) - - # Update working memory monitors - query_keywords = query_db_manager.obj.get_keywords_collections() - logger.info( - f"Processing {len(memories_with_new_order)} memories with {len(query_keywords)} query keywords" - ) - new_working_memory_monitors = self.transform_working_memories_to_monitors( - query_keywords=query_keywords, - memories=memories_with_new_order, - ) - - if not rerank_success_flag: - for one in new_working_memory_monitors: - one.sorting_score = 0 - - logger.info(f"update {len(new_working_memory_monitors)} working_memory_monitors") - self.monitor.update_working_memory_monitors( - new_working_memory_monitors=new_working_memory_monitors, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - ) - - mem_monitors: list[MemoryMonitorItem] = self.monitor.working_memory_monitors[user_id][ - mem_cube_id - ].obj.get_sorted_mem_monitors(reverse=True) - new_working_memories = [mem_monitor.tree_memory_item for mem_monitor in mem_monitors] - - text_mem_base.replace_working_memory(memories=new_working_memories) - - logger.info( - f"The working memory has been replaced with {len(memories_with_new_order)} new memories." - ) - self.log_working_memory_replacement( - original_memory=original_memory, - new_memory=new_working_memories, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - log_func_callback=self._submit_web_logs, - ) - elif isinstance(text_mem_base, NaiveTextMemory): - # For NaiveTextMemory, we populate the monitors with the new candidates so activation memory can pick them up - logger.info( - f"NaiveTextMemory: Updating working memory monitors with {len(new_memory)} candidates." - ) - - # Use query keywords if available, otherwise just basic monitoring - query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] - query_db_manager.sync_with_orm() - query_keywords = query_db_manager.obj.get_keywords_collections() - - new_working_memory_monitors = self.transform_working_memories_to_monitors( - query_keywords=query_keywords, - memories=new_memory, - ) - - self.monitor.update_working_memory_monitors( - new_working_memory_monitors=new_working_memory_monitors, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - ) - memories_with_new_order = new_memory - else: - logger.error("memory_base is not supported") - memories_with_new_order = new_memory - - return memories_with_new_order - - def update_activation_memory( - self, - new_memories: list[str | TextualMemoryItem], - label: str, - user_id: UserID | str, - mem_cube_id: MemCubeID | str, - mem_cube: GeneralMemCube, - ) -> None: - """ - Update activation memory by extracting KVCacheItems from new_memory (list of str), - add them to a KVCacheMemory instance, and dump to disk. - """ - if len(new_memories) == 0: - logger.error("update_activation_memory: new_memory is empty.") - return - if isinstance(new_memories[0], TextualMemoryItem): - new_text_memories = [mem.memory for mem in new_memories] - elif isinstance(new_memories[0], str): - new_text_memories = new_memories - else: - logger.error("Not Implemented.") - return - - try: - if isinstance(mem_cube.act_mem, VLLMKVCacheMemory): - act_mem: VLLMKVCacheMemory = mem_cube.act_mem - elif isinstance(mem_cube.act_mem, KVCacheMemory): - act_mem: KVCacheMemory = mem_cube.act_mem - else: - logger.error("Not Implemented.") - return - - new_text_memory = MEMORY_ASSEMBLY_TEMPLATE.format( - memory_text="".join( - [ - f"{i + 1}. 
{sentence.strip()}\n" - for i, sentence in enumerate(new_text_memories) - if sentence.strip() # Skip empty strings - ] - ) - ) - - # huggingface or vllm kv cache - original_cache_items: list[VLLMKVCacheItem] = act_mem.get_all() - original_text_memories = [] - if len(original_cache_items) > 0: - pre_cache_item: VLLMKVCacheItem = original_cache_items[-1] - original_text_memories = pre_cache_item.records.text_memories - original_composed_text_memory = pre_cache_item.records.composed_text_memory - if original_composed_text_memory == new_text_memory: - logger.warning( - "Skipping memory update - new composition matches existing cache: %s", - new_text_memory[:50] + "..." - if len(new_text_memory) > 50 - else new_text_memory, - ) - return - act_mem.delete_all() - - cache_item = act_mem.extract(new_text_memory) - cache_item.records.text_memories = new_text_memories - cache_item.records.timestamp = get_utc_now() - - act_mem.add([cache_item]) - act_mem.dump(self.act_mem_dump_path) - - self.log_activation_memory_update( - original_text_memories=original_text_memories, - new_text_memories=new_text_memories, - label=label, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - log_func_callback=self._submit_web_logs, - ) - - except Exception as e: - logger.error(f"MOS-based activation memory update failed: {e}", exc_info=True) - # Re-raise the exception if it's critical for the operation - # For now, we'll continue execution but this should be reviewed - - def update_activation_memory_periodically( - self, - interval_seconds: int, - label: str, - user_id: UserID | str, - mem_cube_id: MemCubeID | str, - mem_cube: GeneralMemCube, - ): - try: - if ( - self.monitor.last_activation_mem_update_time == datetime.min - or self.monitor.timed_trigger( - last_time=self.monitor.last_activation_mem_update_time, - interval_seconds=interval_seconds, - ) - ): - logger.info( - f"Updating activation memory for user {user_id} and mem_cube {mem_cube_id}" - ) - - if ( - user_id not in self.monitor.working_memory_monitors - or mem_cube_id not in self.monitor.working_memory_monitors[user_id] - or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].obj.memories) - == 0 - ): - logger.warning( - "No memories found in working_memory_monitors, activation memory update is skipped" - ) - return - - self.monitor.update_activation_memory_monitors( - user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube - ) - - # Sync with database to get latest activation memories - activation_db_manager = self.monitor.activation_memory_monitors[user_id][ - mem_cube_id - ] - activation_db_manager.sync_with_orm() - new_activation_memories = [ - m.memory_text for m in activation_db_manager.obj.memories - ] - - logger.info( - f"Collected {len(new_activation_memories)} new memory entries for processing" - ) - # Print the content of each new activation memory - for i, memory in enumerate(new_activation_memories[:5], 1): - logger.info( - f"Part of New Activation Memorires | {i}/{len(new_activation_memories)}: {memory[:20]}" - ) - - self.update_activation_memory( - new_memories=new_activation_memories, - label=label, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - ) - - self.monitor.last_activation_mem_update_time = get_utc_now() - - logger.debug( - f"Activation memory update completed at {self.monitor.last_activation_mem_update_time}" - ) - - else: - logger.info( - f"Skipping update - {interval_seconds} second interval not yet reached. 
" - f"Last update time is {self.monitor.last_activation_mem_update_time} and now is " - f"{get_utc_now()}" - ) - except Exception as e: - logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True) - - def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): - """Submit messages for processing, with priority-aware dispatch. - - - LEVEL_1 tasks dispatch immediately to the appropriate handler. - - Lower-priority tasks are enqueued via the configured message queue. - """ - if isinstance(messages, ScheduleMessageItem): - messages = [messages] - - if not messages: - return - - current_trace_id = get_current_trace_id() - - immediate_msgs: list[ScheduleMessageItem] = [] - queued_msgs: list[ScheduleMessageItem] = [] - - for msg in messages: - # propagate request trace_id when available so monitor logs align with request logs - if current_trace_id: - msg.trace_id = current_trace_id - - # basic metrics and status tracking - with suppress(Exception): - self.metrics.task_enqueued(user_id=msg.user_id, task_type=msg.label) - - # ensure timestamp exists for monitoring - if getattr(msg, "timestamp", None) is None: - msg.timestamp = get_utc_now() - - if self.status_tracker: - try: - self.status_tracker.task_submitted( - task_id=msg.item_id, - user_id=msg.user_id, - task_type=msg.label, - mem_cube_id=msg.mem_cube_id, - business_task_id=msg.task_id, - ) - except Exception: - logger.warning("status_tracker.task_submitted failed", exc_info=True) - - # honor disabled handlers - if self.disabled_handlers and msg.label in self.disabled_handlers: - logger.info(f"Skipping disabled handler: {msg.label} - {msg.content}") - continue - - # decide priority path - task_priority = self.orchestrator.get_task_priority(task_label=msg.label) - if task_priority == TaskPriorityLevel.LEVEL_1: - immediate_msgs.append(msg) - else: - queued_msgs.append(msg) - - # Dispatch high-priority tasks immediately - if immediate_msgs: - # emit enqueue events for consistency - for m in immediate_msgs: - emit_monitor_event( - "enqueue", - m, - { - "enqueue_ts": to_iso(getattr(m, "timestamp", None)), - "event_duration_ms": 0, - "total_duration_ms": 0, - }, - ) - - # simulate dequeue for immediately dispatched messages so monitor logs stay complete - for m in immediate_msgs: - try: - now = time.time() - enqueue_ts_obj = getattr(m, "timestamp", None) - enqueue_epoch = None - if isinstance(enqueue_ts_obj, int | float): - enqueue_epoch = float(enqueue_ts_obj) - elif hasattr(enqueue_ts_obj, "timestamp"): - dt = enqueue_ts_obj - if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) - enqueue_epoch = dt.timestamp() - - queue_wait_ms = None - if enqueue_epoch is not None: - queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000 - - object.__setattr__(m, "_dequeue_ts", now) - emit_monitor_event( - "dequeue", - m, - { - "enqueue_ts": to_iso(enqueue_ts_obj), - "dequeue_ts": datetime.fromtimestamp(now, tz=timezone.utc).isoformat(), - "queue_wait_ms": queue_wait_ms, - "event_duration_ms": queue_wait_ms, - "total_duration_ms": queue_wait_ms, - }, - ) - self.metrics.task_dequeued(user_id=m.user_id, task_type=m.label) - except Exception: - logger.debug("Failed to emit dequeue for immediate task", exc_info=True) - - user_cube_groups = group_messages_by_user_and_mem_cube(immediate_msgs) - for user_id, cube_groups in user_cube_groups.items(): - for mem_cube_id, user_cube_msgs in cube_groups.items(): - label_groups: dict[str, list[ScheduleMessageItem]] = {} - for m in user_cube_msgs: - 
label_groups.setdefault(m.label, []).append(m) - - for label, msgs_by_label in label_groups.items(): - handler = self.dispatcher.handlers.get( - label, self.dispatcher._default_message_handler - ) - self.dispatcher.execute_task( - user_id=user_id, - mem_cube_id=mem_cube_id, - task_label=label, - msgs=msgs_by_label, - handler_call_back=handler, - ) - - # Enqueue lower-priority tasks - if queued_msgs: - self.memos_message_queue.submit_messages(messages=queued_msgs) - - def _submit_web_logs( - self, - messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem], - additional_log_info: str | None = None, - ) -> None: - """Submit log messages to the web log queue and optionally to RabbitMQ. - - Args: - messages: Single log message or list of log messages - """ - if isinstance(messages, ScheduleLogForWebItem): - messages = [messages] # transform single message to list - - for message in messages: - if self.rabbitmq_config is None: - return - try: - # Always call publish; the publisher now caches when offline and flushes after reconnect - logger.info( - f"[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish {message.model_dump_json(indent=2)}" - ) - self.rabbitmq_publish_message(message=message.to_dict()) - logger.info( - "[DIAGNOSTIC] base_scheduler._submit_web_logs: publish dispatched " - "item_id=%s task_id=%s label=%s", - message.item_id, - message.task_id, - message.label, - ) - except Exception as e: - logger.error( - f"[DIAGNOSTIC] base_scheduler._submit_web_logs failed: {e}", exc_info=True - ) - - logger.debug( - f"{len(messages)} submitted. {self._web_log_message_queue.qsize()} in queue. additional_log_info: {additional_log_info}" - ) - - def get_web_log_messages(self) -> list[dict]: - """ - Retrieve structured log messages from the queue and return JSON-serializable dicts. 
- """ - raw_items: list[ScheduleLogForWebItem] = [] - while True: - try: - raw_items.append(self._web_log_message_queue.get_nowait()) - except Exception: - break - - def _map_label(label: str) -> str: - mapping = { - QUERY_TASK_LABEL: "addMessage", - ANSWER_TASK_LABEL: "addMessage", - ADD_TASK_LABEL: "addMemory", - MEM_UPDATE_TASK_LABEL: "updateMemory", - MEM_ORGANIZE_TASK_LABEL: "mergeMemory", - MEM_ARCHIVE_TASK_LABEL: "archiveMemory", - } - return mapping.get(label, label) - - def _normalize_item(item: ScheduleLogForWebItem) -> dict: - data = item.to_dict() - data["label"] = _map_label(data.get("label")) - memcube_content = getattr(item, "memcube_log_content", None) or [] - metadata = getattr(item, "metadata", None) or [] - - memcube_name = getattr(item, "memcube_name", None) - if not memcube_name and hasattr(self, "_map_memcube_name"): - memcube_name = self._map_memcube_name(item.mem_cube_id) - data["memcube_name"] = memcube_name - - memory_len = getattr(item, "memory_len", None) - if memory_len is None: - if data["label"] == "mergeMemory": - memory_len = len([c for c in memcube_content if c.get("type") != "postMerge"]) - elif memcube_content: - memory_len = len(memcube_content) - else: - memory_len = 1 if item.log_content else 0 - - data["memcube_log_content"] = memcube_content - data["memory_len"] = memory_len - - def _with_memory_time(meta: dict) -> dict: - enriched = dict(meta) - if "memory_time" not in enriched: - enriched["memory_time"] = enriched.get("updated_at") or enriched.get( - "update_at" - ) - return enriched - - data["metadata"] = [_with_memory_time(m) for m in metadata] - data["log_title"] = "" - return data - - return [_normalize_item(it) for it in raw_items] - - def _message_consumer(self) -> None: - """ - Continuously checks the queue for messages and dispatches them. - - Runs in a dedicated thread to process messages at regular intervals. - For Redis queue, this method starts the Redis listener. 
- """ - - # Original local queue logic - while self._running: # Use a running flag for graceful shutdown - try: - # Check dispatcher thread pool status to avoid overloading - if self.enable_parallel_dispatch and self.dispatcher: - running_tasks = self.dispatcher.get_running_task_count() - if running_tasks >= self.dispatcher.max_workers: - # Thread pool is full, wait and retry - time.sleep(self._consume_interval) - continue - - # Get messages in batches based on consume_batch setting - - messages = self.memos_message_queue.get_messages(batch_size=self.consume_batch) - - if messages: - now = time.time() - for msg in messages: - prev_context = get_current_context() - try: - # Set context for this message - msg_context = RequestContext( - trace_id=msg.trace_id, - user_name=msg.user_name, - ) - set_request_context(msg_context) - - enqueue_ts_obj = getattr(msg, "timestamp", None) - enqueue_epoch = None - if isinstance(enqueue_ts_obj, int | float): - enqueue_epoch = float(enqueue_ts_obj) - elif hasattr(enqueue_ts_obj, "timestamp"): - dt = enqueue_ts_obj - if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) - enqueue_epoch = dt.timestamp() - - queue_wait_ms = None - if enqueue_epoch is not None: - queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000 - - # Avoid pydantic field enforcement by using object.__setattr__ - object.__setattr__(msg, "_dequeue_ts", now) - emit_monitor_event( - "dequeue", - msg, - { - "enqueue_ts": to_iso(enqueue_ts_obj), - "dequeue_ts": datetime.fromtimestamp( - now, tz=timezone.utc - ).isoformat(), - "queue_wait_ms": queue_wait_ms, - "event_duration_ms": queue_wait_ms, - "total_duration_ms": queue_wait_ms, - }, - ) - self.metrics.task_dequeued(user_id=msg.user_id, task_type=msg.label) - finally: - # Restore the prior context of the consumer thread - set_request_context(prev_context) - try: - import contextlib - - with contextlib.suppress(Exception): - if messages: - self.dispatcher.on_messages_enqueued(messages) - - self.dispatcher.dispatch(messages) - except Exception as e: - logger.error(f"Error dispatching messages: {e!s}") - - # Sleep briefly to prevent busy waiting - time.sleep(self._consume_interval) # Adjust interval as needed - - except Exception as e: - # Don't log error for "No messages available in Redis queue" as it's expected - if "No messages available in Redis queue" not in str(e): - logger.error(f"Unexpected error in message consumer: {e!s}", exc_info=True) - time.sleep(self._consume_interval) # Prevent tight error loops - - def _monitor_loop(self): - while self._running: - try: - q_sizes = self.memos_message_queue.qsize() - - if not isinstance(q_sizes, dict): - continue - - for stream_key, queue_length in q_sizes.items(): - # Skip aggregate keys like 'total_size' - if stream_key == "total_size": - continue - - # Key format: ...:{user_id}:{mem_cube_id}:{task_label} - # We want to extract user_id, which is the 3rd component from the end. - parts = stream_key.split(":") - if len(parts) >= 3: - user_id = parts[-3] - self.metrics.update_queue_length(queue_length, user_id) - else: - # Fallback for unexpected key formats (e.g. legacy or testing) - # Try to use the key itself if it looks like a user_id (no colons) - # or just log a warning? 
- # For now, let's assume if it's not total_size and short, it might be a direct user_id key - # (though that shouldn't happen with current queue implementations) - if ":" not in stream_key: - self.metrics.update_queue_length(queue_length, stream_key) - - except Exception as e: - logger.error(f"Error in metrics monitor loop: {e}", exc_info=True) - - time.sleep(15) # 每 15 秒采样一次 - - def start(self) -> None: - """ - Start the message consumer thread/process and initialize dispatcher resources. - - Initializes and starts: - 1. Message consumer thread or process (based on startup_mode) - 2. Dispatcher thread pool (if parallel dispatch enabled) - """ - # Initialize dispatcher resources - if self.enable_parallel_dispatch: - logger.info( - f"Initializing dispatcher thread pool with {self.thread_pool_max_workers} workers" - ) - - self.start_consumer() - self.start_background_monitor() - def start_background_monitor(self): - if self._monitor_thread and self._monitor_thread.is_alive(): - return - self._monitor_thread = ContextThread( - target=self._monitor_loop, daemon=True, name="SchedulerMetricsMonitor" - ) - self._monitor_thread.start() - logger.info("Scheduler metrics monitor thread started.") - - def start_consumer(self) -> None: - """ - Start only the message consumer thread/process. - - This method can be used to restart the consumer after it has been stopped - with stop_consumer(), without affecting other scheduler components. - """ - if self._running: - logger.warning("Memory Scheduler consumer is already running") - return - - # Start consumer based on startup mode - self._running = True - - if self.scheduler_startup_mode == STARTUP_BY_PROCESS: - # Start consumer process - self._consumer_process = multiprocessing.Process( - target=self._message_consumer, - daemon=True, - name="MessageConsumerProcess", - ) - self._consumer_process.start() - logger.info("Message consumer process started") - else: - # Default to thread mode - self._consumer_thread = ContextThread( - target=self._message_consumer, - daemon=True, - name="MessageConsumerThread", - ) - self._consumer_thread.start() - logger.info("Message consumer thread started") - - def stop_consumer(self) -> None: - """Stop only the message consumer thread/process gracefully. - - This method stops the consumer without affecting other components like - dispatcher or monitors. Useful when you want to pause message processing - while keeping other scheduler components running. 
- """ - if not self._running: - logger.warning("Memory Scheduler consumer is not running") - return - - # Signal consumer thread/process to stop - self._running = False - - # Wait for consumer thread or process - if self.scheduler_startup_mode == STARTUP_BY_PROCESS and self._consumer_process: - if self._consumer_process.is_alive(): - self._consumer_process.join(timeout=5.0) - if self._consumer_process.is_alive(): - logger.warning("Consumer process did not stop gracefully, terminating...") - self._consumer_process.terminate() - self._consumer_process.join(timeout=2.0) - if self._consumer_process.is_alive(): - logger.error("Consumer process could not be terminated") - else: - logger.info("Consumer process terminated") - else: - logger.info("Consumer process stopped") - self._consumer_process = None - elif self._consumer_thread and self._consumer_thread.is_alive(): - self._consumer_thread.join(timeout=5.0) - if self._consumer_thread.is_alive(): - logger.warning("Consumer thread did not stop gracefully") - else: - logger.info("Consumer thread stopped") - self._consumer_thread = None - - logger.info("Memory Scheduler consumer stopped") - - def stop(self) -> None: - """Stop all scheduler components gracefully. - - 1. Stops message consumer thread/process - 2. Shuts down dispatcher thread pool - 3. Cleans up resources - """ - if not self._running: - logger.warning("Memory Scheduler is not running") - return - - # Stop consumer first - self.stop_consumer() - - if self._monitor_thread: - self._monitor_thread.join(timeout=2.0) - - # Shutdown dispatcher - if self.dispatcher: - logger.info("Shutting down dispatcher...") - self.dispatcher.shutdown() - - # Shutdown dispatcher_monitor - if self.dispatcher_monitor: - logger.info("Shutting down monitor...") - self.dispatcher_monitor.stop() - - @property - def handlers(self) -> dict[str, Callable]: - """ - Access the dispatcher's handlers dictionary. - - Returns: - dict[str, Callable]: Dictionary mapping labels to handler functions - """ - if not self.dispatcher: - logger.warning("Dispatcher is not initialized, returning empty handlers dict") - return {} - - return self.dispatcher.handlers - - def register_handlers( - self, handlers: dict[str, Callable[[list[ScheduleMessageItem]], None]] - ) -> None: - """ - Bulk register multiple handlers from a dictionary. - - Args: - handlers: Dictionary mapping labels to handler functions - Format: {label: handler_callable} - """ - if not self.dispatcher: - logger.warning("Dispatcher is not initialized, cannot register handlers") - return - - self.dispatcher.register_handlers(handlers) - - def unregister_handlers(self, labels: list[str]) -> dict[str, bool]: - """ - Unregister handlers from the dispatcher by their labels. 
- - Args: - labels: List of labels to unregister handlers for - - Returns: - dict[str, bool]: Dictionary mapping each label to whether it was successfully unregistered - """ - if not self.dispatcher: - logger.warning("Dispatcher is not initialized, cannot unregister handlers") - return dict.fromkeys(labels, False) - - return self.dispatcher.unregister_handlers(labels) - - def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, dict]: - if not self.dispatcher: - logger.warning("Dispatcher is not initialized, returning empty tasks dict") - return {} - - running_tasks = self.dispatcher.get_running_tasks(filter_func=filter_func) - - # Convert RunningTaskItem objects to dictionaries for easier consumption - result = {} - for task_id, task_item in running_tasks.items(): - result[task_id] = { - "item_id": task_item.item_id, - "user_id": task_item.user_id, - "mem_cube_id": task_item.mem_cube_id, - "task_info": task_item.task_info, - "task_name": task_item.task_name, - "start_time": task_item.start_time, - "end_time": task_item.end_time, - "status": task_item.status, - "result": task_item.result, - "error_message": task_item.error_message, - "messages": task_item.messages, - } - - return result - - def get_tasks_status(self): - """Delegate status collection to TaskScheduleMonitor.""" - return self.task_schedule_monitor.get_tasks_status() - - def print_tasks_status(self, tasks_status: dict | None = None) -> None: - """Delegate pretty printing to TaskScheduleMonitor.""" - self.task_schedule_monitor.print_tasks_status(tasks_status=tasks_status) - - def _gather_queue_stats(self) -> dict: - """Collect queue/dispatcher stats for reporting.""" - memos_message_queue = self.memos_message_queue.memos_message_queue - stats: dict[str, int | float | str] = {} - stats["use_redis_queue"] = bool(self.use_redis_queue) - # local queue metrics - if not self.use_redis_queue: - try: - stats["qsize"] = int(memos_message_queue.qsize()) - except Exception: - stats["qsize"] = -1 - # unfinished_tasks if available - try: - stats["unfinished_tasks"] = int( - getattr(memos_message_queue, "unfinished_tasks", 0) or 0 - ) - except Exception: - stats["unfinished_tasks"] = -1 - stats["maxsize"] = int(self.max_internal_message_queue_size) - try: - maxsize = int(self.max_internal_message_queue_size) or 1 - qsize = int(stats.get("qsize", 0)) - stats["utilization"] = min(1.0, max(0.0, qsize / maxsize)) - except Exception: - stats["utilization"] = 0.0 - # dispatcher stats - try: - d_stats = self.dispatcher.stats() - stats.update( - { - "running": int(d_stats.get("running", 0)), - "inflight": int(d_stats.get("inflight", 0)), - "handlers": int(d_stats.get("handlers", 0)), - } - ) - except Exception: - stats.update({"running": 0, "inflight": 0, "handlers": 0}) - return stats + # Methods moved to mixins in mem_scheduler.base_mixins. 
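The hunk above removes these lifecycle, handler-registry, and stats helpers from BaseScheduler and re-homes them as mixins under mem_scheduler.base_mixins. A minimal sketch of the composition pattern this implies follows; the mixin class names and stub bodies below are illustrative assumptions, not the actual contents of base_mixins:

# Hypothetical sketch of the mixin composition used by this refactor.
# Class names below are assumptions; the real implementations live in
# memos.mem_scheduler.base_mixins and keep the method bodies removed above.


class QueueOpsMixin:
    """Consumer lifecycle: start_consumer()/stop_consumer(), _message_consumer(), _monitor_loop()."""

    def start_consumer(self) -> None: ...

    def stop_consumer(self) -> None: ...


class WebLogOpsMixin:
    """Web-log emission helpers such as _submit_web_logs()."""


class MemoryOpsMixin:
    """Memory-facing helpers such as replace_working_memory()."""


class BaseScheduler(MemoryOpsMixin, QueueOpsMixin, WebLogOpsMixin):
    # The public surface is unchanged: callers still invoke scheduler.start(),
    # scheduler.stop_consumer(), etc.; Python's MRO resolves these names to the
    # mixins instead of to methods defined inline on BaseScheduler.
    ...

Because the mixins only relocate method definitions, no call sites change; only the modules that host the helper implementations move.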
diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 74e50a514..66801def6 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -1,49 +1,12 @@ -import concurrent.futures -import contextlib -import json -import traceback +from __future__ import annotations from memos.configs.mem_scheduler import GeneralSchedulerConfig -from memos.context.context import ContextThreadPoolExecutor -from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.base_scheduler import BaseScheduler -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem -from memos.mem_scheduler.schemas.task_schemas import ( - ADD_TASK_LABEL, - ANSWER_TASK_LABEL, - DEFAULT_MAX_QUERY_KEY_WORDS, - LONG_TERM_MEMORY_TYPE, - MEM_FEEDBACK_TASK_LABEL, - MEM_ORGANIZE_TASK_LABEL, - MEM_READ_TASK_LABEL, - MEM_UPDATE_TASK_LABEL, - NOT_APPLICABLE_TYPE, - PREF_ADD_TASK_LABEL, - QUERY_TASK_LABEL, - USER_INPUT_TYPE, +from memos.mem_scheduler.handlers import ( + SchedulerHandlerContext, + SchedulerHandlerRegistry, + SchedulerHandlerServices, ) -from memos.mem_scheduler.utils.filter_utils import ( - is_all_chinese, - is_all_english, - transform_name_to_key, -) -from memos.mem_scheduler.utils.misc_utils import ( - group_messages_by_user_and_mem_cube, - is_cloud_env, -) -from memos.memories.textual.item import TextualMemoryItem -from memos.memories.textual.naive import NaiveTextMemory -from memos.memories.textual.preference import PreferenceTextMemory -from memos.memories.textual.tree import TreeTextMemory -from memos.types import ( - MemCubeID, - UserID, -) - - -logger = get_logger(__name__) class GeneralScheduler(BaseScheduler): @@ -53,1447 +16,29 @@ def __init__(self, config: GeneralSchedulerConfig): self.query_key_words_limit = self.config.get("query_key_words_limit", 20) - # register handlers - handlers = { - QUERY_TASK_LABEL: self._query_message_consumer, - ANSWER_TASK_LABEL: self._answer_message_consumer, - MEM_UPDATE_TASK_LABEL: self._memory_update_consumer, - ADD_TASK_LABEL: self._add_message_consumer, - MEM_READ_TASK_LABEL: self._mem_read_message_consumer, - MEM_ORGANIZE_TASK_LABEL: self._mem_reorganize_message_consumer, - PREF_ADD_TASK_LABEL: self._pref_add_message_consumer, - MEM_FEEDBACK_TASK_LABEL: self._mem_feedback_message_consumer, - } - self.dispatcher.register_handlers(handlers) - - def long_memory_update_process( - self, user_id: str, mem_cube_id: str, messages: list[ScheduleMessageItem] - ): - mem_cube = self.mem_cube - - # update query monitors - for msg in messages: - self.monitor.register_query_monitor_if_not_exists( - user_id=user_id, mem_cube_id=mem_cube_id - ) - - query = msg.content - query_keywords = self.monitor.extract_query_keywords(query=query) - logger.info( - f'Extracted keywords "{query_keywords}" from query "{query}" for user_id={user_id}' - ) - - if len(query_keywords) == 0: - stripped_query = query.strip() - # Determine measurement method based on language - if is_all_english(stripped_query): - words = stripped_query.split() # Word count for English - elif is_all_chinese(stripped_query): - words = stripped_query # Character count for Chinese - else: - logger.debug( - f"Mixed-language memory, using character count: {stripped_query[:50]}..." 
- ) - words = stripped_query # Default to character count - - query_keywords = list(set(words[: self.query_key_words_limit])) - logger.error( - f"Keyword extraction failed for query '{query}' (user_id={user_id}). Using fallback keywords: {query_keywords[:10]}... (truncated)", - exc_info=True, - ) - - item = QueryMonitorItem( - user_id=user_id, - mem_cube_id=mem_cube_id, - query_text=query, - keywords=query_keywords, - max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS, - ) - - query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] - query_db_manager.obj.put(item=item) - # Sync with database after adding new item - query_db_manager.sync_with_orm() - logger.debug( - f"Queries in monitor for user_id={user_id}, mem_cube_id={mem_cube_id}: {query_db_manager.obj.get_queries_with_timesort()}" - ) - - queries = [msg.content for msg in messages] - - # recall - cur_working_memory, new_candidates = self.process_session_turn( - queries=queries, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - top_k=self.top_k, - ) - logger.info( - # Build the candidate preview string outside the f-string to avoid backslashes in expression - f"[long_memory_update_process] Processed {len(queries)} queries {queries} and retrieved {len(new_candidates)} " - f"new candidate memories for user_id={user_id}: " - + ("\n- " + "\n- ".join([f"{one.id}: {one.memory}" for one in new_candidates])) - ) - - # rerank - new_order_working_memory = self.replace_working_memory( - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - original_memory=cur_working_memory, - new_memory=new_candidates, - ) - logger.debug( - f"[long_memory_update_process] Final working memory size: {len(new_order_working_memory)} memories for user_id={user_id}" - ) - - old_memory_texts = "\n- " + "\n- ".join( - [f"{one.id}: {one.memory}" for one in cur_working_memory] - ) - new_memory_texts = "\n- " + "\n- ".join( - [f"{one.id}: {one.memory}" for one in new_order_working_memory] - ) - - logger.info( - f"[long_memory_update_process] For user_id='{user_id}', mem_cube_id='{mem_cube_id}': " - f"Scheduler replaced working memory based on query history {queries}. " - f"Old working memory ({len(cur_working_memory)} items): {old_memory_texts}. " - f"New working memory ({len(new_order_working_memory)} items): {new_memory_texts}." 
- ) - - # update activation memories - logger.debug( - f"Activation memory update {'enabled' if self.enable_activation_memory else 'disabled'} " - f"(interval: {self.monitor.act_mem_update_interval}s)" - ) - if self.enable_activation_memory: - self.update_activation_memory_periodically( - interval_seconds=self.monitor.act_mem_update_interval, - label=QUERY_TASK_LABEL, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=self.mem_cube, - ) - - def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {ADD_TASK_LABEL} handler.") - # Process the query in a session turn - grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - - self.validate_schedule_messages(messages=messages, label=ADD_TASK_LABEL) - try: - for user_id in grouped_messages: - for mem_cube_id in grouped_messages[user_id]: - batch = grouped_messages[user_id][mem_cube_id] - if not batch: - continue - - # Process each message in the batch - for msg in batch: - prepared_add_items, prepared_update_items_with_original = ( - self.log_add_messages(msg=msg) - ) - logger.info( - f"prepared_add_items: {prepared_add_items};\n prepared_update_items_with_original: {prepared_update_items_with_original}" - ) - # Conditional Logging: Knowledge Base (Cloud Service) vs. Playground/Default - cloud_env = is_cloud_env() - - if cloud_env: - self.send_add_log_messages_to_cloud_env( - msg, prepared_add_items, prepared_update_items_with_original - ) - else: - self.send_add_log_messages_to_local_env( - msg, prepared_add_items, prepared_update_items_with_original - ) - - except Exception as e: - logger.error(f"Error: {e}", exc_info=True) - - def _memory_update_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {MEM_UPDATE_TASK_LABEL} handler.") - - grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - - self.validate_schedule_messages(messages=messages, label=MEM_UPDATE_TASK_LABEL) - - for user_id in grouped_messages: - for mem_cube_id in grouped_messages[user_id]: - batch = grouped_messages[user_id][mem_cube_id] - if not batch: - continue - # Process the whole batch once; no need to iterate per message - self.long_memory_update_process( - user_id=user_id, mem_cube_id=mem_cube_id, messages=batch - ) - - def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - """ - Process and handle query trigger messages from the queue. 
- - Args: - messages: List of query messages to process - """ - logger.info(f"Messages {messages} assigned to {QUERY_TASK_LABEL} handler.") - - grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - - self.validate_schedule_messages(messages=messages, label=QUERY_TASK_LABEL) - - mem_update_messages = [] - for user_id in grouped_messages: - for mem_cube_id in grouped_messages[user_id]: - batch = grouped_messages[user_id][mem_cube_id] - if not batch: - continue - - for msg in batch: - try: - event = self.create_event_log( - label="addMessage", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=NOT_APPLICABLE_TYPE, - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=[ - { - "content": f"[User] {msg.content}", - "ref_id": msg.item_id, - "role": "user", - } - ], - metadata=[], - memory_len=1, - memcube_name=self._map_memcube_name(msg.mem_cube_id), - ) - event.task_id = msg.task_id - self._submit_web_logs([event]) - except Exception: - logger.exception("Failed to record addMessage log for query") - # Re-submit the message with label changed to mem_update - update_msg = ScheduleMessageItem( - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - label=MEM_UPDATE_TASK_LABEL, - content=msg.content, - session_id=msg.session_id, - user_name=msg.user_name, - info=msg.info, - task_id=msg.task_id, - ) - mem_update_messages.append(update_msg) - - self.submit_messages(messages=mem_update_messages) - - def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - """ - Process and handle answer trigger messages from the queue. - - Args: - messages: List of answer messages to process - """ - logger.info(f"Messages {messages} assigned to {ANSWER_TASK_LABEL} handler.") - grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - - self.validate_schedule_messages(messages=messages, label=ANSWER_TASK_LABEL) - - for user_id in grouped_messages: - for mem_cube_id in grouped_messages[user_id]: - batch = grouped_messages[user_id][mem_cube_id] - if not batch: - continue - try: - for msg in batch: - event = self.create_event_log( - label="addMessage", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=NOT_APPLICABLE_TYPE, - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=[ - { - "content": f"[Assistant] {msg.content}", - "ref_id": msg.item_id, - "role": "assistant", - } - ], - metadata=[], - memory_len=1, - memcube_name=self._map_memcube_name(msg.mem_cube_id), - ) - event.task_id = msg.task_id - self._submit_web_logs([event]) - except Exception: - logger.exception("Failed to record addMessage log for answer") - - def log_add_messages(self, msg: ScheduleMessageItem): - try: - userinput_memory_ids = json.loads(msg.content) - except Exception as e: - logger.error(f"Error: {e}. 
Content: {msg.content}", exc_info=True) - userinput_memory_ids = [] - - # Prepare data for both logging paths, fetching original content for updates - prepared_add_items = [] - prepared_update_items_with_original = [] - missing_ids: list[str] = [] - - for memory_id in userinput_memory_ids: - try: - # This mem_item represents the NEW content that was just added/processed - mem_item: TextualMemoryItem | None = None - mem_item = self.mem_cube.text_mem.get( - memory_id=memory_id, user_name=msg.mem_cube_id - ) - if mem_item is None: - raise ValueError(f"Memory {memory_id} not found after retries") - # Check if a memory with the same key already exists (determining if it's an update) - key = getattr(mem_item.metadata, "key", None) or transform_name_to_key( - name=mem_item.memory - ) - exists = False - original_content = None - original_item_id = None - - # Only check graph_store if a key exists and the text_mem has a graph_store - if key and hasattr(self.mem_cube.text_mem, "graph_store"): - candidates = self.mem_cube.text_mem.graph_store.get_by_metadata( - [ - {"field": "key", "op": "=", "value": key}, - { - "field": "memory_type", - "op": "=", - "value": mem_item.metadata.memory_type, - }, - ] - ) - if candidates: - exists = True - original_item_id = candidates[0] - # Crucial step: Fetch the original content for updates - # This `get` is for the *existing* memory that will be updated - original_mem_item = self.mem_cube.text_mem.get( - memory_id=original_item_id, user_name=msg.mem_cube_id - ) - original_content = original_mem_item.memory - - if exists: - prepared_update_items_with_original.append( - { - "new_item": mem_item, - "original_content": original_content, - "original_item_id": original_item_id, - } - ) - else: - prepared_add_items.append(mem_item) - - except Exception: - missing_ids.append(memory_id) - logger.debug( - f"This MemoryItem {memory_id} has already been deleted or an error occurred during preparation." - ) - - if missing_ids: - content_preview = ( - msg.content[:200] + "..." - if isinstance(msg.content, str) and len(msg.content) > 200 - else msg.content - ) - logger.warning( - "Missing TextualMemoryItem(s) during add log preparation. " - "memory_ids=%s user_id=%s mem_cube_id=%s task_id=%s item_id=%s redis_msg_id=%s label=%s stream_key=%s content_preview=%s", - missing_ids, - msg.user_id, - msg.mem_cube_id, - msg.task_id, - msg.item_id, - getattr(msg, "redis_message_id", ""), - msg.label, - getattr(msg, "stream_key", ""), - content_preview, - ) - - if not prepared_add_items and not prepared_update_items_with_original: - logger.warning( - "No add/update items prepared; skipping addMemory/knowledgeBaseUpdate logs. 
" - "user_id=%s mem_cube_id=%s task_id=%s item_id=%s redis_msg_id=%s label=%s stream_key=%s missing_ids=%s", - msg.user_id, - msg.mem_cube_id, - msg.task_id, - msg.item_id, - getattr(msg, "redis_message_id", ""), - msg.label, - getattr(msg, "stream_key", ""), - missing_ids, - ) - return prepared_add_items, prepared_update_items_with_original - - def send_add_log_messages_to_local_env( - self, msg: ScheduleMessageItem, prepared_add_items, prepared_update_items_with_original - ): - # Existing: Playground/Default Logging - # Reconstruct add_content/add_meta/update_content/update_meta from prepared_items - # This ensures existing logging path continues to work with pre-existing data structures - add_content_legacy: list[dict] = [] - add_meta_legacy: list[dict] = [] - update_content_legacy: list[dict] = [] - update_meta_legacy: list[dict] = [] - - for item in prepared_add_items: - key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory) - add_content_legacy.append({"content": f"{key}: {item.memory}", "ref_id": item.id}) - add_meta_legacy.append( - { - "ref_id": item.id, - "id": item.id, - "key": item.metadata.key, - "memory": item.memory, - "memory_type": item.metadata.memory_type, - "status": item.metadata.status, - "confidence": item.metadata.confidence, - "tags": item.metadata.tags, - "updated_at": getattr(item.metadata, "updated_at", None) - or getattr(item.metadata, "update_at", None), - } - ) - - for item_data in prepared_update_items_with_original: - item = item_data["new_item"] - key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory) - update_content_legacy.append({"content": f"{key}: {item.memory}", "ref_id": item.id}) - update_meta_legacy.append( - { - "ref_id": item.id, - "id": item.id, - "key": item.metadata.key, - "memory": item.memory, - "memory_type": item.metadata.memory_type, - "status": item.metadata.status, - "confidence": item.metadata.confidence, - "tags": item.metadata.tags, - "updated_at": getattr(item.metadata, "updated_at", None) - or getattr(item.metadata, "update_at", None), - } - ) - - events = [] - if add_content_legacy: - event = self.create_event_log( - label="addMemory", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=add_content_legacy, - metadata=add_meta_legacy, - memory_len=len(add_content_legacy), - memcube_name=self._map_memcube_name(msg.mem_cube_id), - ) - event.task_id = msg.task_id - events.append(event) - if update_content_legacy: - event = self.create_event_log( - label="updateMemory", - from_memory_type=LONG_TERM_MEMORY_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=update_content_legacy, - metadata=update_meta_legacy, - memory_len=len(update_content_legacy), - memcube_name=self._map_memcube_name(msg.mem_cube_id), - ) - event.task_id = msg.task_id - events.append(event) - logger.info(f"send_add_log_messages_to_local_env: {len(events)}") - if events: - self._submit_web_logs(events, additional_log_info="send_add_log_messages_to_cloud_env") - - def send_add_log_messages_to_cloud_env( - self, msg: ScheduleMessageItem, prepared_add_items, prepared_update_items_with_original - ): - """ - Cloud logging path for add/update events. 
- """ - kb_log_content: list[dict] = [] - info = msg.info or {} - - # Process added items - for item in prepared_add_items: - metadata = getattr(item, "metadata", None) - file_ids = getattr(metadata, "file_ids", None) if metadata else None - source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None - kb_log_content.append( - { - "log_source": "KNOWLEDGE_BASE_LOG", - "trigger_source": info.get("trigger_source", "Messages"), - "operation": "ADD", - "memory_id": item.id, - "content": item.memory, - "original_content": None, - "source_doc_id": source_doc_id, - } - ) - - # Process updated items - for item_data in prepared_update_items_with_original: - item = item_data["new_item"] - metadata = getattr(item, "metadata", None) - file_ids = getattr(metadata, "file_ids", None) if metadata else None - source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None - kb_log_content.append( - { - "log_source": "KNOWLEDGE_BASE_LOG", - "trigger_source": info.get("trigger_source", "Messages"), - "operation": "UPDATE", - "memory_id": item.id, - "content": item.memory, - "original_content": item_data.get("original_content"), - "source_doc_id": source_doc_id, - } - ) - - if kb_log_content: - logger.info( - f"[DIAGNOSTIC] general_scheduler.send_add_log_messages_to_cloud_env: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: {msg.user_id}, mem_cube_id: {msg.mem_cube_id}, task_id: {msg.task_id}. KB content: {json.dumps(kb_log_content, indent=2)}" - ) - event = self.create_event_log( - label="knowledgeBaseUpdate", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=kb_log_content, - metadata=None, - memory_len=len(kb_log_content), - memcube_name=self._map_memcube_name(msg.mem_cube_id), - ) - event.log_content = f"Knowledge Base Memory Update: {len(kb_log_content)} changes." 
- event.task_id = msg.task_id - self._submit_web_logs([event]) - - def _mem_feedback_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - try: - if not messages: - return - message = messages[0] - mem_cube = self.mem_cube - - user_id = message.user_id - mem_cube_id = message.mem_cube_id - content = message.content - - try: - feedback_data = json.loads(content) if isinstance(content, str) else content - if not isinstance(feedback_data, dict): - logger.error( - f"Failed to decode feedback_data or it is not a dict: {feedback_data}" - ) - return - except json.JSONDecodeError: - logger.error(f"Invalid JSON content for feedback message: {content}", exc_info=True) - return - - task_id = feedback_data.get("task_id") or message.task_id - feedback_result = self.feedback_server.process_feedback( - user_id=user_id, - user_name=mem_cube_id, - session_id=feedback_data.get("session_id"), - chat_history=feedback_data.get("history", []), - retrieved_memory_ids=feedback_data.get("retrieved_memory_ids", []), - feedback_content=feedback_data.get("feedback_content"), - feedback_time=feedback_data.get("feedback_time"), - task_id=task_id, - info=feedback_data.get("info", None), - ) - - logger.info( - f"Successfully processed feedback for user_id={user_id}, mem_cube_id={mem_cube_id}" - ) - - cloud_env = is_cloud_env() - if cloud_env: - record = feedback_result.get("record") if isinstance(feedback_result, dict) else {} - add_records = record.get("add") if isinstance(record, dict) else [] - update_records = record.get("update") if isinstance(record, dict) else [] - - def _extract_fields(mem_item): - mem_id = ( - getattr(mem_item, "id", None) - if not isinstance(mem_item, dict) - else mem_item.get("id") - ) - mem_memory = ( - getattr(mem_item, "memory", None) - if not isinstance(mem_item, dict) - else mem_item.get("memory") or mem_item.get("text") - ) - if mem_memory is None and isinstance(mem_item, dict): - mem_memory = mem_item.get("text") - original_content = ( - getattr(mem_item, "origin_memory", None) - if not isinstance(mem_item, dict) - else mem_item.get("origin_memory") - or mem_item.get("old_memory") - or mem_item.get("original_content") - ) - source_doc_id = None - if isinstance(mem_item, dict): - source_doc_id = mem_item.get("source_doc_id", None) - - return mem_id, mem_memory, original_content, source_doc_id - - kb_log_content: list[dict] = [] - - for mem_item in add_records or []: - mem_id, mem_memory, _, source_doc_id = _extract_fields(mem_item) - if mem_id and mem_memory: - kb_log_content.append( - { - "log_source": "KNOWLEDGE_BASE_LOG", - "trigger_source": "Feedback", - "operation": "ADD", - "memory_id": mem_id, - "content": mem_memory, - "original_content": None, - "source_doc_id": source_doc_id, - } - ) - else: - logger.warning( - "Skipping malformed feedback add item. user_id=%s mem_cube_id=%s task_id=%s item=%s", - user_id, - mem_cube_id, - task_id, - mem_item, - stack_info=True, - ) - - for mem_item in update_records or []: - mem_id, mem_memory, original_content, source_doc_id = _extract_fields(mem_item) - if mem_id and mem_memory: - kb_log_content.append( - { - "log_source": "KNOWLEDGE_BASE_LOG", - "trigger_source": "Feedback", - "operation": "UPDATE", - "memory_id": mem_id, - "content": mem_memory, - "original_content": original_content, - "source_doc_id": source_doc_id, - } - ) - else: - logger.warning( - "Skipping malformed feedback update item. 
user_id=%s mem_cube_id=%s task_id=%s item=%s", - user_id, - mem_cube_id, - task_id, - mem_item, - stack_info=True, - ) - - logger.info(f"[Feedback Scheduler] kb_log_content: {kb_log_content!s}") - if kb_log_content: - logger.info( - "[DIAGNOSTIC] general_scheduler._mem_feedback_message_consumer: Creating knowledgeBaseUpdate event for feedback. user_id=%s mem_cube_id=%s task_id=%s items=%s", - user_id, - mem_cube_id, - task_id, - len(kb_log_content), - ) - event = self.create_event_log( - label="knowledgeBaseUpdate", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - memcube_log_content=kb_log_content, - metadata=None, - memory_len=len(kb_log_content), - memcube_name=self._map_memcube_name(mem_cube_id), - ) - event.log_content = ( - f"Knowledge Base Memory Update: {len(kb_log_content)} changes." - ) - event.task_id = task_id - self._submit_web_logs([event]) - else: - logger.warning( - "No valid feedback content generated for web log. user_id=%s mem_cube_id=%s task_id=%s", - user_id, - mem_cube_id, - task_id, - stack_info=True, - ) - else: - logger.info( - "Skipping web log for feedback. Not in a cloud environment (is_cloud_env=%s)", - cloud_env, - ) - - except Exception as e: - logger.error(f"Error processing feedbackMemory message: {e}", exc_info=True) - - def _mem_read_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info( - f"[DIAGNOSTIC] general_scheduler._mem_read_message_consumer called. Received messages: {[msg.model_dump_json(indent=2) for msg in messages]}" - ) - logger.info(f"Messages {messages} assigned to {MEM_READ_TASK_LABEL} handler.") - - def process_message(message: ScheduleMessageItem): - try: - user_id = message.user_id - mem_cube_id = message.mem_cube_id - mem_cube = self.mem_cube - if mem_cube is None: - logger.error( - f"mem_cube is None for user_id={user_id}, mem_cube_id={mem_cube_id}, skipping processing", - stack_info=True, - ) - return - - content = message.content - user_name = message.user_name - info = message.info or {} - chat_history = message.chat_history - - # Parse the memory IDs from content - mem_ids = json.loads(content) if isinstance(content, str) else content - if not mem_ids: - return - - logger.info( - f"Processing mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}, mem_ids={mem_ids}" - ) - - # Get the text memory from the mem_cube - text_mem = mem_cube.text_mem - if not isinstance(text_mem, TreeTextMemory): - logger.error(f"Expected TreeTextMemory but got {type(text_mem).__name__}") - return - - # Use mem_reader to process the memories - self._process_memories_with_reader( - mem_ids=mem_ids, - user_id=user_id, - mem_cube_id=mem_cube_id, - text_mem=text_mem, - user_name=user_name, - custom_tags=info.get("custom_tags", None), - task_id=message.task_id, - info=info, - chat_history=chat_history, - ) - - logger.info( - f"Successfully processed mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}" - ) - - except Exception as e: - logger.error(f"Error processing mem_read message: {e}", stack_info=True) - - with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: - futures = [executor.submit(process_message, msg) for msg in messages] - for future in concurrent.futures.as_completed(futures): - try: - future.result() - except Exception as e: - logger.error(f"Thread task failed: {e}", stack_info=True) - - def _process_memories_with_reader( - self, - mem_ids: list[str], - user_id: str, - mem_cube_id: str, - 
text_mem: TreeTextMemory, - user_name: str, - custom_tags: list[str] | None = None, - task_id: str | None = None, - info: dict | None = None, - chat_history: list | None = None, - ) -> None: - logger.info( - f"[DIAGNOSTIC] general_scheduler._process_memories_with_reader called. mem_ids: {mem_ids}, user_id: {user_id}, mem_cube_id: {mem_cube_id}, task_id: {task_id}" - ) - """ - Process memories using mem_reader for enhanced memory processing. - - Args: - mem_ids: List of memory IDs to process - user_id: User ID - mem_cube_id: Memory cube ID - text_mem: Text memory instance - custom_tags: Optional list of custom tags for memory processing - """ - kb_log_content: list[dict] = [] - try: - # Get the mem_reader from the parent MOSCore - if not hasattr(self, "mem_reader") or self.mem_reader is None: - logger.warning( - "mem_reader not available in scheduler, skipping enhanced processing" - ) - return - - # Get the original memory items - memory_items = [] - for mem_id in mem_ids: - try: - memory_item = text_mem.get(mem_id, user_name=user_name) - memory_items.append(memory_item) - except Exception as e: - logger.warning( - f"[_process_memories_with_reader] Failed to get memory {mem_id}: {e}" - ) - continue - - if not memory_items: - logger.warning("No valid memory items found for processing") - return - - # parse working_binding ids from the *original* memory_items (the raw items created in /add) - # these still carry metadata.background with "[working_binding:...]" so we can know - # which WorkingMemory clones should be cleaned up later. - from memos.memories.textual.tree_text_memory.organize.manager import ( - extract_working_binding_ids, - ) - - bindings_to_delete = extract_working_binding_ids(memory_items) - logger.info( - f"Extracted {len(bindings_to_delete)} working_binding ids to cleanup: {list(bindings_to_delete)}" - ) - - # Use mem_reader to process the memories - logger.info(f"Processing {len(memory_items)} memories with mem_reader") - - # Extract memories using mem_reader - try: - processed_memories = self.mem_reader.fine_transfer_simple_mem( - memory_items, - type="chat", - custom_tags=custom_tags, - user_name=user_name, - chat_history=chat_history, - ) - except Exception as e: - logger.warning(f"{e}: Fail to transfer mem: {memory_items}") - processed_memories = [] - - if processed_memories and len(processed_memories) > 0: - # Flatten the results (mem_reader returns list of lists) - flattened_memories = [] - for memory_list in processed_memories: - flattened_memories.extend(memory_list) - - logger.info(f"mem_reader processed {len(flattened_memories)} enhanced memories") - - # Add the enhanced memories back to the memory system - if flattened_memories: - enhanced_mem_ids = text_mem.add(flattened_memories, user_name=user_name) - logger.info( - f"Added {len(enhanced_mem_ids)} enhanced memories: {enhanced_mem_ids}" - ) - - # Mark merged_from memories as archived when provided in memory metadata - if self.mem_reader.graph_db: - for memory in flattened_memories: - merged_from = (memory.metadata.info or {}).get("merged_from") - if merged_from: - old_ids = ( - merged_from - if isinstance(merged_from, (list | tuple | set)) - else [merged_from] - ) - for old_id in old_ids: - try: - self.mem_reader.graph_db.update_node( - str(old_id), {"status": "archived"}, user_name=user_name - ) - logger.info( - f"[Scheduler] Archived merged_from memory: {old_id}" - ) - except Exception as e: - logger.warning( - f"[Scheduler] Failed to archive merged_from memory {old_id}: {e}" - ) - else: - # Check if any 
memory has merged_from but graph_db is unavailable - has_merged_from = any( - (m.metadata.info or {}).get("merged_from") for m in flattened_memories - ) - if has_merged_from: - logger.warning( - "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving." - ) - - # LOGGING BLOCK START - # This block is replicated from _add_message_consumer to ensure consistent logging - cloud_env = is_cloud_env() - if cloud_env: - # New: Knowledge Base Logging (Cloud Service) - kb_log_content = [] - for item in flattened_memories: - metadata = getattr(item, "metadata", None) - file_ids = getattr(metadata, "file_ids", None) if metadata else None - source_doc_id = ( - file_ids[0] if isinstance(file_ids, list) and file_ids else None - ) - kb_log_content.append( - { - "log_source": "KNOWLEDGE_BASE_LOG", - "trigger_source": info.get("trigger_source", "Messages") - if info - else "Messages", - "operation": "ADD", - "memory_id": item.id, - "content": item.memory, - "original_content": None, - "source_doc_id": source_doc_id, - } - ) - if kb_log_content: - logger.info( - f"[DIAGNOSTIC] general_scheduler._process_memories_with_reader: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: {user_id}, mem_cube_id: {mem_cube_id}, task_id: {task_id}. KB content: {json.dumps(kb_log_content, indent=2)}" - ) - event = self.create_event_log( - label="knowledgeBaseUpdate", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=kb_log_content, - metadata=None, - memory_len=len(kb_log_content), - memcube_name=self._map_memcube_name(mem_cube_id), - ) - event.log_content = ( - f"Knowledge Base Memory Update: {len(kb_log_content)} changes." - ) - event.task_id = task_id - self._submit_web_logs([event]) - else: - # Existing: Playground/Default Logging - add_content_legacy: list[dict] = [] - add_meta_legacy: list[dict] = [] - for item_id, item in zip( - enhanced_mem_ids, flattened_memories, strict=False - ): - key = getattr(item.metadata, "key", None) or transform_name_to_key( - name=item.memory - ) - add_content_legacy.append( - {"content": f"{key}: {item.memory}", "ref_id": item_id} - ) - add_meta_legacy.append( - { - "ref_id": item_id, - "id": item_id, - "key": item.metadata.key, - "memory": item.memory, - "memory_type": item.metadata.memory_type, - "status": item.metadata.status, - "confidence": item.metadata.confidence, - "tags": item.metadata.tags, - "updated_at": getattr(item.metadata, "updated_at", None) - or getattr(item.metadata, "update_at", None), - } - ) - if add_content_legacy: - event = self.create_event_log( - label="addMemory", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=add_content_legacy, - metadata=add_meta_legacy, - memory_len=len(add_content_legacy), - memcube_name=self._map_memcube_name(mem_cube_id), - ) - event.task_id = task_id - self._submit_web_logs([event]) - # LOGGING BLOCK END - else: - logger.info("No enhanced memories generated by mem_reader") - else: - logger.info("mem_reader returned no processed memories") - - # build full delete list: - # - original raw mem_ids (temporary fast memories) - # - any bound working memories referenced by the enhanced memories - delete_ids = list(mem_ids) - if bindings_to_delete: - delete_ids.extend(list(bindings_to_delete)) - # deduplicate - delete_ids = list(dict.fromkeys(delete_ids)) - if 
delete_ids: - try: - text_mem.delete(delete_ids, user_name=user_name) - logger.info( - f"Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}" - ) - except Exception as e: - logger.warning(f"Failed to delete some mem_ids {delete_ids}: {e}") - else: - logger.info("No mem_ids to delete (nothing to cleanup)") - - text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) - logger.info("Remove and Refresh Memories") - logger.debug(f"Finished add {user_id} memory: {mem_ids}") - - except Exception as exc: - logger.error( - f"Error in _process_memories_with_reader: {traceback.format_exc()}", exc_info=True - ) - with contextlib.suppress(Exception): - cloud_env = is_cloud_env() - if cloud_env: - if not kb_log_content: - trigger_source = ( - info.get("trigger_source", "Messages") if info else "Messages" - ) - kb_log_content = [ - { - "log_source": "KNOWLEDGE_BASE_LOG", - "trigger_source": trigger_source, - "operation": "ADD", - "memory_id": mem_id, - "content": None, - "original_content": None, - "source_doc_id": None, - } - for mem_id in mem_ids - ] - event = self.create_event_log( - label="knowledgeBaseUpdate", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=kb_log_content, - metadata=None, - memory_len=len(kb_log_content), - memcube_name=self._map_memcube_name(mem_cube_id), - ) - event.log_content = f"Knowledge Base Memory Update failed: {exc!s}" - event.task_id = task_id - event.status = "failed" - self._submit_web_logs([event]) - - def _mem_reorganize_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {MEM_ORGANIZE_TASK_LABEL} handler.") - - def process_message(message: ScheduleMessageItem): - try: - user_id = message.user_id - mem_cube_id = message.mem_cube_id - mem_cube = self.mem_cube - if mem_cube is None: - logger.warning( - f"mem_cube is None for user_id={user_id}, mem_cube_id={mem_cube_id}, skipping processing" - ) - return - content = message.content - user_name = message.user_name - - # Parse the memory IDs from content - mem_ids = json.loads(content) if isinstance(content, str) else content - if not mem_ids: - return - - logger.info( - f"Processing mem_reorganize for user_id={user_id}, mem_cube_id={mem_cube_id}, mem_ids={mem_ids}" - ) - - # Get the text memory from the mem_cube - text_mem = mem_cube.text_mem - if not isinstance(text_mem, TreeTextMemory): - logger.error(f"Expected TreeTextMemory but got {type(text_mem).__name__}") - return - - # Use mem_reader to process the memories - self._process_memories_with_reorganize( - mem_ids=mem_ids, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - text_mem=text_mem, - user_name=user_name, - ) - - with contextlib.suppress(Exception): - mem_items: list[TextualMemoryItem] = [] - for mid in mem_ids: - with contextlib.suppress(Exception): - mem_items.append(text_mem.get(mid, user_name=user_name)) - if len(mem_items) > 1: - keys: list[str] = [] - memcube_content: list[dict] = [] - meta: list[dict] = [] - merged_target_ids: set[str] = set() - with contextlib.suppress(Exception): - if hasattr(text_mem, "graph_store"): - for mid in mem_ids: - edges = text_mem.graph_store.get_edges( - mid, type="MERGED_TO", direction="OUT" - ) - for edge in edges: - target = ( - edge.get("to") or edge.get("dst") or edge.get("target") - ) - if target: - merged_target_ids.add(target) - for item in mem_items: - key = getattr( - 
getattr(item, "metadata", {}), "key", None - ) or transform_name_to_key(getattr(item, "memory", "")) - keys.append(key) - memcube_content.append( - {"content": key or "(no key)", "ref_id": item.id, "type": "merged"} - ) - meta.append( - { - "ref_id": item.id, - "id": item.id, - "key": key, - "memory": item.memory, - "memory_type": item.metadata.memory_type, - "status": item.metadata.status, - "confidence": item.metadata.confidence, - "tags": item.metadata.tags, - "updated_at": getattr(item.metadata, "updated_at", None) - or getattr(item.metadata, "update_at", None), - } - ) - combined_key = keys[0] if keys else "" - post_ref_id = None - post_meta = { - "ref_id": None, - "id": None, - "key": None, - "memory": None, - "memory_type": None, - "status": None, - "confidence": None, - "tags": None, - "updated_at": None, - } - if merged_target_ids: - post_ref_id = next(iter(merged_target_ids)) - with contextlib.suppress(Exception): - merged_item = text_mem.get(post_ref_id, user_name=user_name) - combined_key = ( - getattr(getattr(merged_item, "metadata", {}), "key", None) - or combined_key - ) - post_meta = { - "ref_id": post_ref_id, - "id": post_ref_id, - "key": getattr( - getattr(merged_item, "metadata", {}), "key", None - ), - "memory": getattr(merged_item, "memory", None), - "memory_type": getattr( - getattr(merged_item, "metadata", {}), "memory_type", None - ), - "status": getattr( - getattr(merged_item, "metadata", {}), "status", None - ), - "confidence": getattr( - getattr(merged_item, "metadata", {}), "confidence", None - ), - "tags": getattr( - getattr(merged_item, "metadata", {}), "tags", None - ), - "updated_at": getattr( - getattr(merged_item, "metadata", {}), "updated_at", None - ) - or getattr( - getattr(merged_item, "metadata", {}), "update_at", None - ), - } - if not post_ref_id: - import hashlib - - post_ref_id = f"merge-{hashlib.md5(''.join(sorted(mem_ids)).encode()).hexdigest()}" - post_meta["ref_id"] = post_ref_id - post_meta["id"] = post_ref_id - if not post_meta.get("key"): - post_meta["key"] = combined_key - if not keys: - keys = [item.id for item in mem_items] - memcube_content.append( - { - "content": combined_key if combined_key else "(no key)", - "ref_id": post_ref_id, - "type": "postMerge", - } - ) - meta.append(post_meta) - event = self.create_event_log( - label="mergeMemory", - from_memory_type=LONG_TERM_MEMORY_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - memcube_log_content=memcube_content, - metadata=meta, - memory_len=len(keys), - memcube_name=self._map_memcube_name(mem_cube_id), - ) - self._submit_web_logs([event]) - - logger.info( - f"Successfully processed mem_reorganize for user_id={user_id}, mem_cube_id={mem_cube_id}" - ) - - except Exception as e: - logger.error(f"Error processing mem_reorganize message: {e}", exc_info=True) - - with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: - futures = [executor.submit(process_message, msg) for msg in messages] - for future in concurrent.futures.as_completed(futures): - try: - future.result() - except Exception as e: - logger.error(f"Thread task failed: {e}", exc_info=True) - - def _process_memories_with_reorganize( - self, - mem_ids: list[str], - user_id: str, - mem_cube_id: str, - mem_cube: GeneralMemCube, - text_mem: TreeTextMemory, - user_name: str, - ) -> None: - """ - Process memories using mem_reorganize for enhanced memory processing. 
- - Args: - mem_ids: List of memory IDs to process - user_id: User ID - mem_cube_id: Memory cube ID - mem_cube: Memory cube instance - text_mem: Text memory instance - """ - try: - # Get the mem_reader from the parent MOSCore - if not hasattr(self, "mem_reader") or self.mem_reader is None: - logger.warning( - "mem_reader not available in scheduler, skipping enhanced processing" - ) - return - - # Get the original memory items - memory_items = [] - for mem_id in mem_ids: - try: - memory_item = text_mem.get(mem_id, user_name=user_name) - memory_items.append(memory_item) - except Exception as e: - logger.warning(f"Failed to get memory {mem_id}: {e}|{traceback.format_exc()}") - continue - - if not memory_items: - logger.warning("No valid memory items found for processing") - return - - # Use mem_reader to process the memories - logger.info(f"Processing {len(memory_items)} memories with mem_reader") - text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) - logger.info("Remove and Refresh Memories") - logger.debug(f"Finished add {user_id} memory: {mem_ids}") - - except Exception: - logger.error( - f"Error in _process_memories_with_reorganize: {traceback.format_exc()}", - exc_info=True, - ) - - def _pref_add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {PREF_ADD_TASK_LABEL} handler.") - - def process_message(message: ScheduleMessageItem): - try: - mem_cube = self.mem_cube - if mem_cube is None: - logger.warning( - f"mem_cube is None for user_id={message.user_id}, mem_cube_id={message.mem_cube_id}, skipping processing" - ) - return - - user_id = message.user_id - session_id = message.session_id - mem_cube_id = message.mem_cube_id - content = message.content - messages_list = json.loads(content) - info = message.info or {} - - logger.info(f"Processing pref_add for user_id={user_id}, mem_cube_id={mem_cube_id}") - - # Get the preference memory from the mem_cube - pref_mem = mem_cube.pref_mem - if pref_mem is None: - logger.warning( - f"Preference memory not initialized for mem_cube_id={mem_cube_id}, " - f"skipping pref_add processing" - ) - return - if not isinstance(pref_mem, PreferenceTextMemory): - logger.error( - f"Expected PreferenceTextMemory but got {type(pref_mem).__name__} " - f"for mem_cube_id={mem_cube_id}" - ) - return - - # Use pref_mem.get_memory to process the memories - pref_memories = pref_mem.get_memory( - messages_list, - type="chat", - info={ - **info, - "user_id": user_id, - "session_id": session_id, - "mem_cube_id": mem_cube_id, - }, - ) - # Add pref_mem to vector db - pref_ids = pref_mem.add(pref_memories) - - logger.info( - f"Successfully processed and add preferences for user_id={user_id}, mem_cube_id={mem_cube_id}, pref_ids={pref_ids}" - ) - - except Exception as e: - logger.error(f"Error processing pref_add message: {e}", exc_info=True) - - with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: - futures = [executor.submit(process_message, msg) for msg in messages] - for future in concurrent.futures.as_completed(futures): - try: - future.result() - except Exception as e: - logger.error(f"Thread task failed: {e}", exc_info=True) - - def process_session_turn( - self, - queries: str | list[str], - user_id: UserID | str, - mem_cube_id: MemCubeID | str, - mem_cube: GeneralMemCube, - top_k: int = 10, - ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]] | None: - """ - Process a dialog turn: - - If q_list reaches window size, trigger retrieval; - - Immediately 
switch to the new memory if retrieval is triggered. - """ - - text_mem_base = mem_cube.text_mem - if not isinstance(text_mem_base, TreeTextMemory): - if isinstance(text_mem_base, NaiveTextMemory): - logger.debug( - f"NaiveTextMemory used for mem_cube_id={mem_cube_id}, processing session turn with simple search." - ) - # Treat NaiveTextMemory similar to TreeTextMemory but with simpler logic - # We will perform retrieval to get "working memory" candidates for activation memory - # But we won't have a distinct "current working memory" - cur_working_memory = [] - else: - logger.warning( - f"Not implemented! Expected TreeTextMemory but got {type(text_mem_base).__name__} " - f"for mem_cube_id={mem_cube_id}, user_id={user_id}. " - f"text_mem_base value: {text_mem_base}" - ) - return [], [] - else: - cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory( - user_name=mem_cube_id - ) - cur_working_memory = cur_working_memory[:top_k] - - logger.info( - f"[process_session_turn] Processing {len(queries)} queries for user_id={user_id}, mem_cube_id={mem_cube_id}" - ) - - text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory] - intent_result = self.monitor.detect_intent( - q_list=queries, text_working_memory=text_working_memory - ) - - time_trigger_flag = False - if self.monitor.timed_trigger( - last_time=self.monitor.last_query_consume_time, - interval_seconds=self.monitor.query_trigger_interval, - ): - time_trigger_flag = True - - if (not intent_result["trigger_retrieval"]) and (not time_trigger_flag): - logger.info( - f"[process_session_turn] Query schedule not triggered for user_id={user_id}, mem_cube_id={mem_cube_id}. Intent_result: {intent_result}" - ) - return - elif (not intent_result["trigger_retrieval"]) and time_trigger_flag: - logger.info( - f"[process_session_turn] Query schedule forced to trigger due to time ticker for user_id={user_id}, mem_cube_id={mem_cube_id}" - ) - intent_result["trigger_retrieval"] = True - intent_result["missing_evidences"] = queries - else: - logger.info( - f"[process_session_turn] Query schedule triggered for user_id={user_id}, mem_cube_id={mem_cube_id}. 
" - f"Missing evidences: {intent_result['missing_evidences']}" - ) - - missing_evidences = intent_result["missing_evidences"] - num_evidence = len(missing_evidences) - k_per_evidence = max(1, top_k // max(1, num_evidence)) - new_candidates = [] - for item in missing_evidences: - logger.info( - f"[process_session_turn] Searching for missing evidence: '{item}' with top_k={k_per_evidence} for user_id={user_id}" - ) - - search_args = {} - if isinstance(text_mem_base, NaiveTextMemory): - # NaiveTextMemory doesn't support complex search args usually, but let's see - # self.retriever.search calls mem_cube.text_mem.search - # NaiveTextMemory.search takes query and top_k - # SchedulerRetriever.search handles method dispatch - # For NaiveTextMemory, we might need to bypass retriever or extend it - # But let's try calling naive memory directly if retriever fails or doesn't support it - try: - results = text_mem_base.search(query=item, top_k=k_per_evidence) - except Exception as e: - logger.warning(f"NaiveTextMemory search failed: {e}") - results = [] - else: - results: list[TextualMemoryItem] = self.retriever.search( - query=item, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - top_k=k_per_evidence, - method=self.search_method, - search_args=search_args, - ) - - logger.info( - f"[process_session_turn] Search results for missing evidence '{item}': " - + ("\n- " + "\n- ".join([f"{one.id}: {one.memory}" for one in results])) - ) - new_candidates.extend(results) - return cur_working_memory, new_candidates + services = SchedulerHandlerServices( + validate_messages=self.validate_schedule_messages, + submit_messages=self.submit_messages, + create_event_log=self.create_event_log, + submit_web_logs=self._submit_web_logs, + map_memcube_name=self._map_memcube_name, + update_activation_memory_periodically=self.update_activation_memory_periodically, + replace_working_memory=self.replace_working_memory, + transform_working_memories_to_monitors=self.transform_working_memories_to_monitors, + log_working_memory_replacement=self.log_working_memory_replacement, + ) + ctx = SchedulerHandlerContext( + get_mem_cube=lambda: self.mem_cube, + get_monitor=lambda: self.monitor, + get_retriever=lambda: self.retriever, + get_mem_reader=lambda: self.mem_reader, + get_feedback_server=lambda: self.feedback_server, + get_search_method=lambda: self.search_method, + get_top_k=lambda: self.top_k, + get_enable_activation_memory=lambda: self.enable_activation_memory, + get_query_key_words_limit=lambda: self.query_key_words_limit, + services=services, + ) + + self._handler_registry = SchedulerHandlerRegistry(ctx) + self.register_handlers(self._handler_registry.build_dispatch_map()) diff --git a/src/memos/mem_scheduler/handlers/__init__.py b/src/memos/mem_scheduler/handlers/__init__.py new file mode 100644 index 000000000..283740d1c --- /dev/null +++ b/src/memos/mem_scheduler/handlers/__init__.py @@ -0,0 +1,8 @@ +from .context import SchedulerHandlerContext, SchedulerHandlerServices +from .registry import SchedulerHandlerRegistry + +__all__ = [ + "SchedulerHandlerContext", + "SchedulerHandlerRegistry", + "SchedulerHandlerServices", +] diff --git a/src/memos/mem_scheduler/handlers/add_handler.py b/src/memos/mem_scheduler/handlers/add_handler.py new file mode 100644 index 000000000..900550952 --- /dev/null +++ b/src/memos/mem_scheduler/handlers/add_handler.py @@ -0,0 +1,302 @@ +from __future__ import annotations + +import json + +from memos.log import get_logger +from memos.mem_scheduler.handlers.base import 
BaseSchedulerHandler +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import ( + ADD_TASK_LABEL, + LONG_TERM_MEMORY_TYPE, + USER_INPUT_TYPE, +) +from memos.mem_scheduler.utils.filter_utils import transform_name_to_key +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube, is_cloud_env +from memos.memories.textual.item import TextualMemoryItem + + +logger = get_logger(__name__) + + +class AddMessageHandler(BaseSchedulerHandler): + def handle(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {ADD_TASK_LABEL} handler.") + grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) + + self.ctx.services.validate_messages(messages=messages, label=ADD_TASK_LABEL) + try: + for user_id in grouped_messages: + for mem_cube_id in grouped_messages[user_id]: + batch = grouped_messages[user_id][mem_cube_id] + if not batch: + continue + + for msg in batch: + prepared_add_items, prepared_update_items_with_original = ( + self.log_add_messages(msg=msg) + ) + logger.info( + "prepared_add_items: %s;\n prepared_update_items_with_original: %s", + prepared_add_items, + prepared_update_items_with_original, + ) + cloud_env = is_cloud_env() + + if cloud_env: + self.send_add_log_messages_to_cloud_env( + msg, prepared_add_items, prepared_update_items_with_original + ) + else: + self.send_add_log_messages_to_local_env( + msg, prepared_add_items, prepared_update_items_with_original + ) + + except Exception as e: + logger.error(f"Error: {e}", exc_info=True) + + def log_add_messages(self, msg: ScheduleMessageItem): + try: + userinput_memory_ids = json.loads(msg.content) + except Exception as e: + logger.error(f"Error: {e}. Content: {msg.content}", exc_info=True) + userinput_memory_ids = [] + + prepared_add_items = [] + prepared_update_items_with_original = [] + missing_ids: list[str] = [] + + mem_cube = self.ctx.get_mem_cube() + + for memory_id in userinput_memory_ids: + try: + mem_item: TextualMemoryItem | None = None + mem_item = mem_cube.text_mem.get(memory_id=memory_id, user_name=msg.mem_cube_id) + if mem_item is None: + raise ValueError(f"Memory {memory_id} not found after retries") + key = getattr(mem_item.metadata, "key", None) or transform_name_to_key( + name=mem_item.memory + ) + exists = False + original_content = None + original_item_id = None + + if key and hasattr(mem_cube.text_mem, "graph_store"): + candidates = mem_cube.text_mem.graph_store.get_by_metadata( + [ + {"field": "key", "op": "=", "value": key}, + { + "field": "memory_type", + "op": "=", + "value": mem_item.metadata.memory_type, + }, + ] + ) + if candidates: + exists = True + original_item_id = candidates[0] + original_mem_item = mem_cube.text_mem.get( + memory_id=original_item_id, user_name=msg.mem_cube_id + ) + original_content = original_mem_item.memory + + if exists: + prepared_update_items_with_original.append( + { + "new_item": mem_item, + "original_content": original_content, + "original_item_id": original_item_id, + } + ) + else: + prepared_add_items.append(mem_item) + + except Exception: + missing_ids.append(memory_id) + logger.debug( + "This MemoryItem %s has already been deleted or an error occurred during preparation.", + memory_id, + ) + + if missing_ids: + content_preview = ( + msg.content[:200] + "..." + if isinstance(msg.content, str) and len(msg.content) > 200 + else msg.content + ) + logger.warning( + "Missing TextualMemoryItem(s) during add log preparation. 
" + "memory_ids=%s user_id=%s mem_cube_id=%s task_id=%s item_id=%s redis_msg_id=%s label=%s stream_key=%s content_preview=%s", + missing_ids, + msg.user_id, + msg.mem_cube_id, + msg.task_id, + msg.item_id, + getattr(msg, "redis_message_id", ""), + msg.label, + getattr(msg, "stream_key", ""), + content_preview, + ) + + if not prepared_add_items and not prepared_update_items_with_original: + logger.warning( + "No add/update items prepared; skipping addMemory/knowledgeBaseUpdate logs. " + "user_id=%s mem_cube_id=%s task_id=%s item_id=%s redis_msg_id=%s label=%s stream_key=%s missing_ids=%s", + msg.user_id, + msg.mem_cube_id, + msg.task_id, + msg.item_id, + getattr(msg, "redis_message_id", ""), + msg.label, + getattr(msg, "stream_key", ""), + missing_ids, + ) + return prepared_add_items, prepared_update_items_with_original + + def send_add_log_messages_to_local_env( + self, + msg: ScheduleMessageItem, + prepared_add_items, + prepared_update_items_with_original, + ) -> None: + add_content_legacy: list[dict] = [] + add_meta_legacy: list[dict] = [] + update_content_legacy: list[dict] = [] + update_meta_legacy: list[dict] = [] + + for item in prepared_add_items: + key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory) + add_content_legacy.append({"content": f"{key}: {item.memory}", "ref_id": item.id}) + add_meta_legacy.append( + { + "ref_id": item.id, + "id": item.id, + "key": item.metadata.key, + "memory": item.memory, + "memory_type": item.metadata.memory_type, + "status": item.metadata.status, + "confidence": item.metadata.confidence, + "tags": item.metadata.tags, + "updated_at": getattr(item.metadata, "updated_at", None) + or getattr(item.metadata, "update_at", None), + } + ) + + for item_data in prepared_update_items_with_original: + item = item_data["new_item"] + key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory) + update_content_legacy.append({"content": f"{key}: {item.memory}", "ref_id": item.id}) + update_meta_legacy.append( + { + "ref_id": item.id, + "id": item.id, + "key": item.metadata.key, + "memory": item.memory, + "memory_type": item.metadata.memory_type, + "status": item.metadata.status, + "confidence": item.metadata.confidence, + "tags": item.metadata.tags, + "updated_at": getattr(item.metadata, "updated_at", None) + or getattr(item.metadata, "update_at", None), + } + ) + + events = [] + if add_content_legacy: + event = self.ctx.services.create_event_log( + label="addMemory", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=msg.user_id, + mem_cube_id=msg.mem_cube_id, + mem_cube=self.ctx.get_mem_cube(), + memcube_log_content=add_content_legacy, + metadata=add_meta_legacy, + memory_len=len(add_content_legacy), + memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id), + ) + event.task_id = msg.task_id + events.append(event) + if update_content_legacy: + event = self.ctx.services.create_event_log( + label="updateMemory", + from_memory_type=LONG_TERM_MEMORY_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=msg.user_id, + mem_cube_id=msg.mem_cube_id, + mem_cube=self.ctx.get_mem_cube(), + memcube_log_content=update_content_legacy, + metadata=update_meta_legacy, + memory_len=len(update_content_legacy), + memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id), + ) + event.task_id = msg.task_id + events.append(event) + logger.info("send_add_log_messages_to_local_env: %s", len(events)) + if events: + self.ctx.services.submit_web_logs(events, 
additional_log_info="send_add_log_messages_to_local_env")
+
+    def send_add_log_messages_to_cloud_env(
+        self,
+        msg: ScheduleMessageItem,
+        prepared_add_items,
+        prepared_update_items_with_original,
+    ) -> None:
+        kb_log_content: list[dict] = []
+        info = msg.info or {}
+
+        for item in prepared_add_items:
+            metadata = getattr(item, "metadata", None)
+            file_ids = getattr(metadata, "file_ids", None) if metadata else None
+            source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None
+            kb_log_content.append(
+                {
+                    "log_source": "KNOWLEDGE_BASE_LOG",
+                    "trigger_source": info.get("trigger_source", "Messages"),
+                    "operation": "ADD",
+                    "memory_id": item.id,
+                    "content": item.memory,
+                    "original_content": None,
+                    "source_doc_id": source_doc_id,
+                }
+            )
+
+        for item_data in prepared_update_items_with_original:
+            item = item_data["new_item"]
+            metadata = getattr(item, "metadata", None)
+            file_ids = getattr(metadata, "file_ids", None) if metadata else None
+            source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None
+            kb_log_content.append(
+                {
+                    "log_source": "KNOWLEDGE_BASE_LOG",
+                    "trigger_source": info.get("trigger_source", "Messages"),
+                    "operation": "UPDATE",
+                    "memory_id": item.id,
+                    "content": item.memory,
+                    "original_content": item_data.get("original_content"),
+                    "source_doc_id": source_doc_id,
+                }
+            )
+
+        if kb_log_content:
+            logger.info(
+                "[DIAGNOSTIC] add_handler.send_add_log_messages_to_cloud_env: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: %s, mem_cube_id: %s, task_id: %s. KB content: %s",
+                msg.user_id,
+                msg.mem_cube_id,
+                msg.task_id,
+                json.dumps(kb_log_content, indent=2),
+            )
+            event = self.ctx.services.create_event_log(
+                label="knowledgeBaseUpdate",
+                from_memory_type=USER_INPUT_TYPE,
+                to_memory_type=LONG_TERM_MEMORY_TYPE,
+                user_id=msg.user_id,
+                mem_cube_id=msg.mem_cube_id,
+                mem_cube=self.ctx.get_mem_cube(),
+                memcube_log_content=kb_log_content,
+                metadata=None,
+                memory_len=len(kb_log_content),
+                memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id),
+            )
+            event.log_content = f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
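+            # Propagate the originating task id on the event (presumably used downstream
+            # to correlate this KB-update event with its scheduler message).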
+ event.task_id = msg.task_id + self.ctx.services.submit_web_logs([event]) diff --git a/src/memos/mem_scheduler/handlers/answer_handler.py b/src/memos/mem_scheduler/handlers/answer_handler.py new file mode 100644 index 000000000..aa6bf708e --- /dev/null +++ b/src/memos/mem_scheduler/handlers/answer_handler.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from memos.log import get_logger +from memos.mem_scheduler.handlers.base import BaseSchedulerHandler +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import ( + ANSWER_TASK_LABEL, + NOT_APPLICABLE_TYPE, + USER_INPUT_TYPE, +) +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube + + +logger = get_logger(__name__) + + +class AnswerMessageHandler(BaseSchedulerHandler): + def handle(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {ANSWER_TASK_LABEL} handler.") + grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) + + self.ctx.services.validate_messages(messages=messages, label=ANSWER_TASK_LABEL) + + for user_id in grouped_messages: + for mem_cube_id in grouped_messages[user_id]: + batch = grouped_messages[user_id][mem_cube_id] + if not batch: + continue + try: + for msg in batch: + event = self.ctx.services.create_event_log( + label="addMessage", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=NOT_APPLICABLE_TYPE, + user_id=msg.user_id, + mem_cube_id=msg.mem_cube_id, + mem_cube=self.ctx.get_mem_cube(), + memcube_log_content=[ + { + "content": f"[Assistant] {msg.content}", + "ref_id": msg.item_id, + "role": "assistant", + } + ], + metadata=[], + memory_len=1, + memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id), + ) + event.task_id = msg.task_id + self.ctx.services.submit_web_logs([event]) + except Exception: + logger.exception("Failed to record addMessage log for answer") diff --git a/src/memos/mem_scheduler/handlers/base.py b/src/memos/mem_scheduler/handlers/base.py new file mode 100644 index 000000000..f0d8246a3 --- /dev/null +++ b/src/memos/mem_scheduler/handlers/base.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from memos.mem_scheduler.handlers.context import SchedulerHandlerContext + + +class BaseSchedulerHandler: + def __init__(self, ctx: SchedulerHandlerContext) -> None: + self.ctx = ctx diff --git a/src/memos/mem_scheduler/handlers/context.py b/src/memos/mem_scheduler/handlers/context.py new file mode 100644 index 000000000..848739fa3 --- /dev/null +++ b/src/memos/mem_scheduler/handlers/context.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from collections.abc import Callable + +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem +from memos.memories.textual.item import TextualMemoryItem + + +@dataclass(frozen=True) +class SchedulerHandlerServices: + validate_messages: Callable[[list[ScheduleMessageItem], str], None] + submit_messages: Callable[[list[ScheduleMessageItem]], None] + create_event_log: Callable[..., Any] + submit_web_logs: Callable[..., None] + map_memcube_name: Callable[[str], str] + update_activation_memory_periodically: Callable[..., None] + replace_working_memory: Callable[ + [str, str, Any, list[TextualMemoryItem], list[TextualMemoryItem]], + list[TextualMemoryItem] | None, + ] + transform_working_memories_to_monitors: 
Callable[..., list[MemoryMonitorItem]] + log_working_memory_replacement: Callable[..., None] + + +@dataclass(frozen=True) +class SchedulerHandlerContext: + get_mem_cube: Callable[[], Any] + get_monitor: Callable[[], Any] + get_retriever: Callable[[], Any] + get_mem_reader: Callable[[], Any] + get_feedback_server: Callable[[], Any] + get_search_method: Callable[[], str] + get_top_k: Callable[[], int] + get_enable_activation_memory: Callable[[], bool] + get_query_key_words_limit: Callable[[], int] + services: SchedulerHandlerServices diff --git a/src/memos/mem_scheduler/handlers/feedback_handler.py b/src/memos/mem_scheduler/handlers/feedback_handler.py new file mode 100644 index 000000000..55dbd6add --- /dev/null +++ b/src/memos/mem_scheduler/handlers/feedback_handler.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import json + +from memos.log import get_logger +from memos.mem_scheduler.handlers.base import BaseSchedulerHandler +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import LONG_TERM_MEMORY_TYPE, USER_INPUT_TYPE +from memos.mem_scheduler.utils.misc_utils import is_cloud_env + + +logger = get_logger(__name__) + + +class FeedbackMessageHandler(BaseSchedulerHandler): + def handle(self, messages: list[ScheduleMessageItem]) -> None: + try: + if not messages: + return + message = messages[0] + mem_cube = self.ctx.get_mem_cube() + + user_id = message.user_id + mem_cube_id = message.mem_cube_id + content = message.content + + try: + feedback_data = json.loads(content) if isinstance(content, str) else content + if not isinstance(feedback_data, dict): + logger.error("Failed to decode feedback_data or it is not a dict: %s", feedback_data) + return + except json.JSONDecodeError: + logger.error("Invalid JSON content for feedback message: %s", content, exc_info=True) + return + + task_id = feedback_data.get("task_id") or message.task_id + feedback_result = self.ctx.get_feedback_server().process_feedback( + user_id=user_id, + user_name=mem_cube_id, + session_id=feedback_data.get("session_id"), + chat_history=feedback_data.get("history", []), + retrieved_memory_ids=feedback_data.get("retrieved_memory_ids", []), + feedback_content=feedback_data.get("feedback_content"), + feedback_time=feedback_data.get("feedback_time"), + task_id=task_id, + info=feedback_data.get("info", None), + ) + + logger.info( + "Successfully processed feedback for user_id=%s, mem_cube_id=%s", + user_id, + mem_cube_id, + ) + + cloud_env = is_cloud_env() + if cloud_env: + record = feedback_result.get("record") if isinstance(feedback_result, dict) else {} + add_records = record.get("add") if isinstance(record, dict) else [] + update_records = record.get("update") if isinstance(record, dict) else [] + + def _extract_fields(mem_item): + mem_id = ( + getattr(mem_item, "id", None) + if not isinstance(mem_item, dict) + else mem_item.get("id") + ) + mem_memory = ( + getattr(mem_item, "memory", None) + if not isinstance(mem_item, dict) + else mem_item.get("memory") or mem_item.get("text") + ) + if mem_memory is None and isinstance(mem_item, dict): + mem_memory = mem_item.get("text") + original_content = ( + getattr(mem_item, "origin_memory", None) + if not isinstance(mem_item, dict) + else mem_item.get("origin_memory") + or mem_item.get("old_memory") + or mem_item.get("original_content") + ) + source_doc_id = None + if isinstance(mem_item, dict): + source_doc_id = mem_item.get("source_doc_id", None) + + return mem_id, mem_memory, original_content, 
source_doc_id + + kb_log_content: list[dict] = [] + + for mem_item in add_records or []: + mem_id, mem_memory, _, source_doc_id = _extract_fields(mem_item) + if mem_id and mem_memory: + kb_log_content.append( + { + "log_source": "KNOWLEDGE_BASE_LOG", + "trigger_source": "Feedback", + "operation": "ADD", + "memory_id": mem_id, + "content": mem_memory, + "original_content": None, + "source_doc_id": source_doc_id, + } + ) + else: + logger.warning( + "Skipping malformed feedback add item. user_id=%s mem_cube_id=%s task_id=%s item=%s", + user_id, + mem_cube_id, + task_id, + mem_item, + stack_info=True, + ) + + for mem_item in update_records or []: + mem_id, mem_memory, original_content, source_doc_id = _extract_fields(mem_item) + if mem_id and mem_memory: + kb_log_content.append( + { + "log_source": "KNOWLEDGE_BASE_LOG", + "trigger_source": "Feedback", + "operation": "UPDATE", + "memory_id": mem_id, + "content": mem_memory, + "original_content": original_content, + "source_doc_id": source_doc_id, + } + ) + else: + logger.warning( + "Skipping malformed feedback update item. user_id=%s mem_cube_id=%s task_id=%s item=%s", + user_id, + mem_cube_id, + task_id, + mem_item, + stack_info=True, + ) + + logger.info("[Feedback Scheduler] kb_log_content: %s", kb_log_content) + if kb_log_content: + logger.info( + "[DIAGNOSTIC] feedback_handler: Creating knowledgeBaseUpdate event for feedback. user_id=%s mem_cube_id=%s task_id=%s items=%s", + user_id, + mem_cube_id, + task_id, + len(kb_log_content), + ) + event = self.ctx.services.create_event_log( + label="knowledgeBaseUpdate", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + memcube_log_content=kb_log_content, + metadata=None, + memory_len=len(kb_log_content), + memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), + ) + event.log_content = ( + f"Knowledge Base Memory Update: {len(kb_log_content)} changes." + ) + event.task_id = task_id + self.ctx.services.submit_web_logs([event]) + else: + logger.warning( + "No valid feedback content generated for web log. user_id=%s mem_cube_id=%s task_id=%s", + user_id, + mem_cube_id, + task_id, + stack_info=True, + ) + else: + logger.info( + "Skipping web log for feedback. 
Not in a cloud environment (is_cloud_env=%s)", + cloud_env, + ) + + except Exception as e: + logger.error("Error processing feedbackMemory message: %s", e, exc_info=True) diff --git a/src/memos/mem_scheduler/handlers/mem_read_handler.py b/src/memos/mem_scheduler/handlers/mem_read_handler.py new file mode 100644 index 000000000..cb1c7631c --- /dev/null +++ b/src/memos/mem_scheduler/handlers/mem_read_handler.py @@ -0,0 +1,353 @@ +from __future__ import annotations + +import concurrent.futures +import contextlib +import json +import traceback + +from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger +from memos.mem_scheduler.handlers.base import BaseSchedulerHandler +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import ( + LONG_TERM_MEMORY_TYPE, + MEM_READ_TASK_LABEL, + USER_INPUT_TYPE, +) +from memos.mem_scheduler.utils.filter_utils import transform_name_to_key +from memos.mem_scheduler.utils.misc_utils import is_cloud_env +from memos.memories.textual.tree import TreeTextMemory + + +logger = get_logger(__name__) + + +class MemReadMessageHandler(BaseSchedulerHandler): + def handle(self, messages: list[ScheduleMessageItem]) -> None: + logger.info( + "[DIAGNOSTIC] mem_read_handler called. Received messages: %s", + [msg.model_dump_json(indent=2) for msg in messages], + ) + logger.info(f"Messages {messages} assigned to {MEM_READ_TASK_LABEL} handler.") + + def process_message(message: ScheduleMessageItem): + try: + user_id = message.user_id + mem_cube_id = message.mem_cube_id + mem_cube = self.ctx.get_mem_cube() + if mem_cube is None: + logger.error( + "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing", + user_id, + mem_cube_id, + stack_info=True, + ) + return + + content = message.content + user_name = message.user_name + info = message.info or {} + chat_history = message.chat_history + + mem_ids = json.loads(content) if isinstance(content, str) else content + if not mem_ids: + return + + logger.info( + "Processing mem_read for user_id=%s, mem_cube_id=%s, mem_ids=%s", + user_id, + mem_cube_id, + mem_ids, + ) + + text_mem = mem_cube.text_mem + if not isinstance(text_mem, TreeTextMemory): + logger.error("Expected TreeTextMemory but got %s", type(text_mem).__name__) + return + + self._process_memories_with_reader( + mem_ids=mem_ids, + user_id=user_id, + mem_cube_id=mem_cube_id, + text_mem=text_mem, + user_name=user_name, + custom_tags=info.get("custom_tags", None), + task_id=message.task_id, + info=info, + chat_history=chat_history, + ) + + logger.info( + "Successfully processed mem_read for user_id=%s, mem_cube_id=%s", + user_id, + mem_cube_id, + ) + + except Exception as e: + logger.error("Error processing mem_read message: %s", e, stack_info=True) + + with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + futures = [executor.submit(process_message, msg) for msg in messages] + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + logger.error("Thread task failed: %s", e, stack_info=True) + + def _process_memories_with_reader( + self, + mem_ids: list[str], + user_id: str, + mem_cube_id: str, + text_mem: TreeTextMemory, + user_name: str, + custom_tags: list[str] | None = None, + task_id: str | None = None, + info: dict | None = None, + chat_history: list | None = None, + ) -> None: + logger.info( + "[DIAGNOSTIC] mem_read_handler._process_memories_with_reader called. 
mem_ids: %s, user_id: %s, mem_cube_id: %s, task_id: %s", + mem_ids, + user_id, + mem_cube_id, + task_id, + ) + kb_log_content: list[dict] = [] + try: + mem_reader = self.ctx.get_mem_reader() + if mem_reader is None: + logger.warning("mem_reader not available in scheduler, skipping enhanced processing") + return + + memory_items = [] + for mem_id in mem_ids: + try: + memory_item = text_mem.get(mem_id, user_name=user_name) + memory_items.append(memory_item) + except Exception as e: + logger.warning("[_process_memories_with_reader] Failed to get memory %s: %s", mem_id, e) + continue + + if not memory_items: + logger.warning("No valid memory items found for processing") + return + + from memos.memories.textual.tree_text_memory.organize.manager import ( + extract_working_binding_ids, + ) + + bindings_to_delete = extract_working_binding_ids(memory_items) + logger.info( + "Extracted %s working_binding ids to cleanup: %s", + len(bindings_to_delete), + list(bindings_to_delete), + ) + + logger.info("Processing %s memories with mem_reader", len(memory_items)) + + try: + processed_memories = mem_reader.fine_transfer_simple_mem( + memory_items, + type="chat", + custom_tags=custom_tags, + user_name=user_name, + chat_history=chat_history, + ) + except Exception as e: + logger.warning("%s: Fail to transfer mem: %s", e, memory_items) + processed_memories = [] + + if processed_memories and len(processed_memories) > 0: + flattened_memories = [] + for memory_list in processed_memories: + flattened_memories.extend(memory_list) + + logger.info("mem_reader processed %s enhanced memories", len(flattened_memories)) + + if flattened_memories: + enhanced_mem_ids = text_mem.add(flattened_memories, user_name=user_name) + logger.info( + "Added %s enhanced memories: %s", + len(enhanced_mem_ids), + enhanced_mem_ids, + ) + + if mem_reader.graph_db: + for memory in flattened_memories: + merged_from = (memory.metadata.info or {}).get("merged_from") + if merged_from: + old_ids = ( + merged_from + if isinstance(merged_from, (list | tuple | set)) + else [merged_from] + ) + for old_id in old_ids: + try: + mem_reader.graph_db.update_node( + str(old_id), {"status": "archived"}, user_name=user_name + ) + logger.info( + "[Scheduler] Archived merged_from memory: %s", + old_id, + ) + except Exception as e: + logger.warning( + "[Scheduler] Failed to archive merged_from memory %s: %s", + old_id, + e, + ) + else: + has_merged_from = any( + (m.metadata.info or {}).get("merged_from") for m in flattened_memories + ) + if has_merged_from: + logger.warning( + "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving." + ) + + cloud_env = is_cloud_env() + if cloud_env: + kb_log_content = [] + for item in flattened_memories: + metadata = getattr(item, "metadata", None) + file_ids = getattr(metadata, "file_ids", None) if metadata else None + source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None + kb_log_content.append( + { + "log_source": "KNOWLEDGE_BASE_LOG", + "trigger_source": info.get("trigger_source", "Messages") + if info + else "Messages", + "operation": "ADD", + "memory_id": item.id, + "content": item.memory, + "original_content": None, + "source_doc_id": source_doc_id, + } + ) + if kb_log_content: + logger.info( + "[DIAGNOSTIC] mem_read_handler: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: %s, mem_cube_id: %s, task_id: %s. 
KB content: %s", + user_id, + mem_cube_id, + task_id, + json.dumps(kb_log_content, indent=2), + ) + event = self.ctx.services.create_event_log( + label="knowledgeBaseUpdate", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=self.ctx.get_mem_cube(), + memcube_log_content=kb_log_content, + metadata=None, + memory_len=len(kb_log_content), + memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), + ) + event.log_content = ( + f"Knowledge Base Memory Update: {len(kb_log_content)} changes." + ) + event.task_id = task_id + self.ctx.services.submit_web_logs([event]) + else: + add_content_legacy: list[dict] = [] + add_meta_legacy: list[dict] = [] + for item_id, item in zip(enhanced_mem_ids, flattened_memories, strict=False): + key = getattr(item.metadata, "key", None) or transform_name_to_key( + name=item.memory + ) + add_content_legacy.append( + {"content": f"{key}: {item.memory}", "ref_id": item_id} + ) + add_meta_legacy.append( + { + "ref_id": item_id, + "id": item_id, + "key": item.metadata.key, + "memory": item.memory, + "memory_type": item.metadata.memory_type, + "status": item.metadata.status, + "confidence": item.metadata.confidence, + "tags": item.metadata.tags, + "updated_at": getattr(item.metadata, "updated_at", None) + or getattr(item.metadata, "update_at", None), + } + ) + if add_content_legacy: + event = self.ctx.services.create_event_log( + label="addMemory", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=self.ctx.get_mem_cube(), + memcube_log_content=add_content_legacy, + metadata=add_meta_legacy, + memory_len=len(add_content_legacy), + memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), + ) + event.task_id = task_id + self.ctx.services.submit_web_logs([event]) + else: + logger.info("No enhanced memories generated by mem_reader") + else: + logger.info("mem_reader returned no processed memories") + + delete_ids = list(mem_ids) + if bindings_to_delete: + delete_ids.extend(list(bindings_to_delete)) + delete_ids = list(dict.fromkeys(delete_ids)) + if delete_ids: + try: + text_mem.delete(delete_ids, user_name=user_name) + logger.info("Delete raw/working mem_ids: %s for user_name: %s", delete_ids, user_name) + except Exception as e: + logger.warning("Failed to delete some mem_ids %s: %s", delete_ids, e) + else: + logger.info("No mem_ids to delete (nothing to cleanup)") + + text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) + logger.info("Remove and Refresh Memories") + logger.debug("Finished add %s memory: %s", user_id, mem_ids) + + except Exception as exc: + logger.error( + "Error in _process_memories_with_reader: %s", + traceback.format_exc(), + exc_info=True, + ) + with contextlib.suppress(Exception): + cloud_env = is_cloud_env() + if cloud_env: + if not kb_log_content: + trigger_source = info.get("trigger_source", "Messages") if info else "Messages" + kb_log_content = [ + { + "log_source": "KNOWLEDGE_BASE_LOG", + "trigger_source": trigger_source, + "operation": "ADD", + "memory_id": mem_id, + "content": None, + "original_content": None, + "source_doc_id": None, + } + for mem_id in mem_ids + ] + event = self.ctx.services.create_event_log( + label="knowledgeBaseUpdate", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=self.ctx.get_mem_cube(), + memcube_log_content=kb_log_content, + metadata=None, + 
memory_len=len(kb_log_content), + memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), + ) + event.log_content = f"Knowledge Base Memory Update failed: {exc!s}" + event.task_id = task_id + event.status = "failed" + self.ctx.services.submit_web_logs([event]) diff --git a/src/memos/mem_scheduler/handlers/mem_reorganize_handler.py b/src/memos/mem_scheduler/handlers/mem_reorganize_handler.py new file mode 100644 index 000000000..ce320fc8d --- /dev/null +++ b/src/memos/mem_scheduler/handlers/mem_reorganize_handler.py @@ -0,0 +1,245 @@ +from __future__ import annotations + +import concurrent.futures +import contextlib +import json +import traceback + +from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger +from memos.mem_scheduler.handlers.base import BaseSchedulerHandler +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import LONG_TERM_MEMORY_TYPE, MEM_ORGANIZE_TASK_LABEL +from memos.mem_scheduler.utils.filter_utils import transform_name_to_key +from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.tree import TreeTextMemory + + +logger = get_logger(__name__) + + +class MemReorganizeMessageHandler(BaseSchedulerHandler): + def handle(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {MEM_ORGANIZE_TASK_LABEL} handler.") + + def process_message(message: ScheduleMessageItem): + try: + user_id = message.user_id + mem_cube_id = message.mem_cube_id + mem_cube = self.ctx.get_mem_cube() + if mem_cube is None: + logger.warning( + "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing", + user_id, + mem_cube_id, + ) + return + content = message.content + user_name = message.user_name + + mem_ids = json.loads(content) if isinstance(content, str) else content + if not mem_ids: + return + + logger.info( + "Processing mem_reorganize for user_id=%s, mem_cube_id=%s, mem_ids=%s", + user_id, + mem_cube_id, + mem_ids, + ) + + text_mem = mem_cube.text_mem + if not isinstance(text_mem, TreeTextMemory): + logger.error("Expected TreeTextMemory but got %s", type(text_mem).__name__) + return + + self._process_memories_with_reorganize( + mem_ids=mem_ids, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + text_mem=text_mem, + user_name=user_name, + ) + + with contextlib.suppress(Exception): + mem_items: list[TextualMemoryItem] = [] + for mid in mem_ids: + with contextlib.suppress(Exception): + mem_items.append(text_mem.get(mid, user_name=user_name)) + if len(mem_items) > 1: + keys: list[str] = [] + memcube_content: list[dict] = [] + meta: list[dict] = [] + merged_target_ids: set[str] = set() + with contextlib.suppress(Exception): + if hasattr(text_mem, "graph_store"): + for mid in mem_ids: + edges = text_mem.graph_store.get_edges( + mid, type="MERGED_TO", direction="OUT" + ) + for edge in edges: + target = ( + edge.get("to") + or edge.get("dst") + or edge.get("target") + ) + if target: + merged_target_ids.add(target) + for item in mem_items: + key = getattr(getattr(item, "metadata", {}), "key", None) or transform_name_to_key( + getattr(item, "memory", "") + ) + keys.append(key) + memcube_content.append( + {"content": key or "(no key)", "ref_id": item.id, "type": "merged"} + ) + meta.append( + { + "ref_id": item.id, + "id": item.id, + "key": key, + "memory": item.memory, + "memory_type": item.metadata.memory_type, + "status": item.metadata.status, + "confidence": item.metadata.confidence, 
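+                                    # "updated_at" below also checks the alternate "update_at"
+                                    # spelling as a fallback, mirroring the add/update log schema.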
+ "tags": item.metadata.tags, + "updated_at": getattr(item.metadata, "updated_at", None) + or getattr(item.metadata, "update_at", None), + } + ) + combined_key = keys[0] if keys else "" + post_ref_id = None + post_meta = { + "ref_id": None, + "id": None, + "key": None, + "memory": None, + "memory_type": None, + "status": None, + "confidence": None, + "tags": None, + "updated_at": None, + } + if merged_target_ids: + post_ref_id = next(iter(merged_target_ids)) + with contextlib.suppress(Exception): + merged_item = text_mem.get(post_ref_id, user_name=user_name) + combined_key = ( + getattr(getattr(merged_item, "metadata", {}), "key", None) + or combined_key + ) + post_meta = { + "ref_id": post_ref_id, + "id": post_ref_id, + "key": getattr(getattr(merged_item, "metadata", {}), "key", None), + "memory": getattr(merged_item, "memory", None), + "memory_type": getattr( + getattr(merged_item, "metadata", {}), "memory_type", None + ), + "status": getattr( + getattr(merged_item, "metadata", {}), "status", None + ), + "confidence": getattr( + getattr(merged_item, "metadata", {}), "confidence", None + ), + "tags": getattr( + getattr(merged_item, "metadata", {}), "tags", None + ), + "updated_at": getattr( + getattr(merged_item, "metadata", {}), "updated_at", None + ) + or getattr( + getattr(merged_item, "metadata", {}), "update_at", None + ), + } + if not post_ref_id: + import hashlib + + post_ref_id = ( + "merge-" + hashlib.md5("".join(sorted(mem_ids)).encode()).hexdigest() + ) + post_meta["ref_id"] = post_ref_id + post_meta["id"] = post_ref_id + if not post_meta.get("key"): + post_meta["key"] = combined_key + if not keys: + keys = [item.id for item in mem_items] + memcube_content.append( + { + "content": combined_key if combined_key else "(no key)", + "ref_id": post_ref_id, + "type": "postMerge", + } + ) + meta.append(post_meta) + event = self.ctx.services.create_event_log( + label="mergeMemory", + from_memory_type=LONG_TERM_MEMORY_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + memcube_log_content=memcube_content, + metadata=meta, + memory_len=len(keys), + memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), + ) + self.ctx.services.submit_web_logs([event]) + + logger.info( + "Successfully processed mem_reorganize for user_id=%s, mem_cube_id=%s", + user_id, + mem_cube_id, + ) + + except Exception as e: + logger.error("Error processing mem_reorganize message: %s", e, exc_info=True) + + with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + futures = [executor.submit(process_message, msg) for msg in messages] + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + logger.error("Thread task failed: %s", e, exc_info=True) + + def _process_memories_with_reorganize( + self, + mem_ids: list[str], + user_id: str, + mem_cube_id: str, + mem_cube, + text_mem: TreeTextMemory, + user_name: str, + ) -> None: + try: + mem_reader = self.ctx.get_mem_reader() + if mem_reader is None: + logger.warning("mem_reader not available in scheduler, skipping enhanced processing") + return + + memory_items = [] + for mem_id in mem_ids: + try: + memory_item = text_mem.get(mem_id, user_name=user_name) + memory_items.append(memory_item) + except Exception as e: + logger.warning("Failed to get memory %s: %s|%s", mem_id, e, traceback.format_exc()) + continue + + if not memory_items: + logger.warning("No valid memory items found for processing") + return + + logger.info("Processing 
%s memories with mem_reader", len(memory_items)) + text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) + logger.info("Remove and Refresh Memories") + logger.debug("Finished add %s memory: %s", user_id, mem_ids) + + except Exception: + logger.error( + "Error in _process_memories_with_reorganize: %s", + traceback.format_exc(), + exc_info=True, + ) diff --git a/src/memos/mem_scheduler/handlers/memory_update_handler.py b/src/memos/mem_scheduler/handlers/memory_update_handler.py new file mode 100644 index 000000000..0775fc91a --- /dev/null +++ b/src/memos/mem_scheduler/handlers/memory_update_handler.py @@ -0,0 +1,278 @@ +from __future__ import annotations + +from memos.log import get_logger +from memos.mem_scheduler.handlers.base import BaseSchedulerHandler +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem +from memos.mem_scheduler.schemas.task_schemas import ( + DEFAULT_MAX_QUERY_KEY_WORDS, + MEM_UPDATE_TASK_LABEL, +) +from memos.mem_scheduler.utils.filter_utils import is_all_chinese, is_all_english +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube +from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.naive import NaiveTextMemory +from memos.memories.textual.tree import TreeTextMemory +from memos.types import MemCubeID, UserID + + +logger = get_logger(__name__) + + +class MemoryUpdateHandler(BaseSchedulerHandler): + def handle(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {MEM_UPDATE_TASK_LABEL} handler.") + + grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) + + self.ctx.services.validate_messages(messages=messages, label=MEM_UPDATE_TASK_LABEL) + + for user_id in grouped_messages: + for mem_cube_id in grouped_messages[user_id]: + batch = grouped_messages[user_id][mem_cube_id] + if not batch: + continue + self.long_memory_update_process( + user_id=user_id, mem_cube_id=mem_cube_id, messages=batch + ) + + def long_memory_update_process( + self, + user_id: str, + mem_cube_id: str, + messages: list[ScheduleMessageItem], + ) -> None: + mem_cube = self.ctx.get_mem_cube() + monitor = self.ctx.get_monitor() + + query_key_words_limit = self.ctx.get_query_key_words_limit() + + for msg in messages: + monitor.register_query_monitor_if_not_exists(user_id=user_id, mem_cube_id=mem_cube_id) + + query = msg.content + query_keywords = monitor.extract_query_keywords(query=query) + logger.info( + 'Extracted keywords "%s" from query "%s" for user_id=%s', + query_keywords, + query, + user_id, + ) + + if len(query_keywords) == 0: + stripped_query = query.strip() + if is_all_english(stripped_query): + words = stripped_query.split() + elif is_all_chinese(stripped_query): + words = stripped_query + else: + logger.debug( + "Mixed-language memory, using character count: %s...", + stripped_query[:50], + ) + words = stripped_query + + query_keywords = list(set(words[:query_key_words_limit])) + logger.error( + "Keyword extraction failed for query '%s' (user_id=%s). Using fallback keywords: %s... 
(truncated)", + query, + user_id, + query_keywords[:10], + exc_info=True, + ) + + item = QueryMonitorItem( + user_id=user_id, + mem_cube_id=mem_cube_id, + query_text=query, + keywords=query_keywords, + max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS, + ) + + query_db_manager = monitor.query_monitors[user_id][mem_cube_id] + query_db_manager.obj.put(item=item) + query_db_manager.sync_with_orm() + logger.debug( + "Queries in monitor for user_id=%s, mem_cube_id=%s: %s", + user_id, + mem_cube_id, + query_db_manager.obj.get_queries_with_timesort(), + ) + + queries = [msg.content for msg in messages] + + cur_working_memory, new_candidates = self.process_session_turn( + queries=queries, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + top_k=self.ctx.get_top_k(), + ) + logger.info( + "[long_memory_update_process] Processed %s queries %s and retrieved %s new candidate memories for user_id=%s: " + + ("\n- " + "\n- ".join([f"{one.id}: {one.memory}" for one in new_candidates])), + len(queries), + queries, + len(new_candidates), + user_id, + ) + + new_order_working_memory = self.ctx.services.replace_working_memory( + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + original_memory=cur_working_memory, + new_memory=new_candidates, + ) + logger.debug( + "[long_memory_update_process] Final working memory size: %s memories for user_id=%s", + len(new_order_working_memory), + user_id, + ) + + old_memory_texts = "\n- " + "\n- ".join( + [f"{one.id}: {one.memory}" for one in cur_working_memory] + ) + new_memory_texts = "\n- " + "\n- ".join( + [f"{one.id}: {one.memory}" for one in new_order_working_memory] + ) + + logger.info( + "[long_memory_update_process] For user_id='%s', mem_cube_id='%s': " + "Scheduler replaced working memory based on query history %s. " + "Old working memory (%s items): %s. " + "New working memory (%s items): %s.", + user_id, + mem_cube_id, + queries, + len(cur_working_memory), + old_memory_texts, + len(new_order_working_memory), + new_memory_texts, + ) + + logger.debug( + "Activation memory update %s (interval: %ss)", + "enabled" if self.ctx.get_enable_activation_memory() else "disabled", + monitor.act_mem_update_interval, + ) + if self.ctx.get_enable_activation_memory(): + self.ctx.services.update_activation_memory_periodically( + interval_seconds=monitor.act_mem_update_interval, + label=MEM_UPDATE_TASK_LABEL, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + + def process_session_turn( + self, + queries: str | list[str], + user_id: UserID | str, + mem_cube_id: MemCubeID | str, + mem_cube, + top_k: int = 10, + ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]] | None: + text_mem_base = mem_cube.text_mem + if not isinstance(text_mem_base, TreeTextMemory): + if isinstance(text_mem_base, NaiveTextMemory): + logger.debug( + "NaiveTextMemory used for mem_cube_id=%s, processing session turn with simple search.", + mem_cube_id, + ) + cur_working_memory = [] + else: + logger.warning( + "Not implemented! Expected TreeTextMemory but got %s for mem_cube_id=%s, user_id=%s. 
text_mem_base value: %s", + type(text_mem_base).__name__, + mem_cube_id, + user_id, + text_mem_base, + ) + return [], [] + else: + cur_working_memory = text_mem_base.get_working_memory(user_name=mem_cube_id) + cur_working_memory = cur_working_memory[:top_k] + + logger.info( + "[process_session_turn] Processing %s queries for user_id=%s, mem_cube_id=%s", + len(queries), + user_id, + mem_cube_id, + ) + + text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory] + monitor = self.ctx.get_monitor() + intent_result = monitor.detect_intent(q_list=queries, text_working_memory=text_working_memory) + + time_trigger_flag = False + if monitor.timed_trigger( + last_time=monitor.last_query_consume_time, + interval_seconds=monitor.query_trigger_interval, + ): + time_trigger_flag = True + + if (not intent_result["trigger_retrieval"]) and (not time_trigger_flag): + logger.info( + "[process_session_turn] Query schedule not triggered for user_id=%s, mem_cube_id=%s. Intent_result: %s", + user_id, + mem_cube_id, + intent_result, + ) + return + if (not intent_result["trigger_retrieval"]) and time_trigger_flag: + logger.info( + "[process_session_turn] Query schedule forced to trigger due to time ticker for user_id=%s, mem_cube_id=%s", + user_id, + mem_cube_id, + ) + intent_result["trigger_retrieval"] = True + intent_result["missing_evidences"] = queries + else: + logger.info( + "[process_session_turn] Query schedule triggered for user_id=%s, mem_cube_id=%s. Missing evidences: %s", + user_id, + mem_cube_id, + intent_result["missing_evidences"], + ) + + missing_evidences = intent_result["missing_evidences"] + num_evidence = len(missing_evidences) + k_per_evidence = max(1, top_k // max(1, num_evidence)) + new_candidates: list[TextualMemoryItem] = [] + retriever = self.ctx.get_retriever() + search_method = self.ctx.get_search_method() + + for item in missing_evidences: + logger.info( + "[process_session_turn] Searching for missing evidence: '%s' with top_k=%s for user_id=%s", + item, + k_per_evidence, + user_id, + ) + + search_args = {} + if isinstance(text_mem_base, NaiveTextMemory): + try: + results = text_mem_base.search(query=item, top_k=k_per_evidence) + except Exception as e: + logger.warning("NaiveTextMemory search failed: %s", e) + results = [] + else: + results = retriever.search( + query=item, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + top_k=k_per_evidence, + method=search_method, + search_args=search_args, + ) + + logger.info( + "[process_session_turn] Search results for missing evidence '%s': \n- %s", + item, + "\n- ".join([f"{one.id}: {one.memory}" for one in results]), + ) + new_candidates.extend(results) + return cur_working_memory, new_candidates diff --git a/src/memos/mem_scheduler/handlers/pref_add_handler.py b/src/memos/mem_scheduler/handlers/pref_add_handler.py new file mode 100644 index 000000000..195b35385 --- /dev/null +++ b/src/memos/mem_scheduler/handlers/pref_add_handler.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import concurrent.futures +import json + +from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger +from memos.mem_scheduler.handlers.base import BaseSchedulerHandler +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import PREF_ADD_TASK_LABEL +from memos.memories.textual.preference import PreferenceTextMemory + + +logger = get_logger(__name__) + + +class PrefAddMessageHandler(BaseSchedulerHandler): + def 
handle(self, messages: list[ScheduleMessageItem]) -> None:
+        logger.info(f"Messages {messages} assigned to {PREF_ADD_TASK_LABEL} handler.")
+
+        def process_message(message: ScheduleMessageItem):
+            try:
+                mem_cube = self.ctx.get_mem_cube()
+                if mem_cube is None:
+                    logger.warning(
+                        "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing",
+                        message.user_id,
+                        message.mem_cube_id,
+                    )
+                    return
+
+                user_id = message.user_id
+                session_id = message.session_id
+                mem_cube_id = message.mem_cube_id
+                content = message.content
+                messages_list = json.loads(content) if isinstance(content, str) else content
+                info = message.info or {}
+
+                logger.info("Processing pref_add for user_id=%s, mem_cube_id=%s", user_id, mem_cube_id)
+
+                pref_mem = mem_cube.pref_mem
+                if pref_mem is None:
+                    logger.warning(
+                        "Preference memory not initialized for mem_cube_id=%s, skipping pref_add processing",
+                        mem_cube_id,
+                    )
+                    return
+                if not isinstance(pref_mem, PreferenceTextMemory):
+                    logger.error(
+                        "Expected PreferenceTextMemory but got %s for mem_cube_id=%s",
+                        type(pref_mem).__name__,
+                        mem_cube_id,
+                    )
+                    return
+
+                pref_memories = pref_mem.get_memory(
+                    messages_list,
+                    type="chat",
+                    info={
+                        **info,
+                        "user_id": user_id,
+                        "session_id": session_id,
+                        "mem_cube_id": mem_cube_id,
+                    },
+                )
+                pref_ids = pref_mem.add(pref_memories)
+
+                logger.info(
+                    "Successfully processed and added preferences for user_id=%s, mem_cube_id=%s, pref_ids=%s",
+                    user_id,
+                    mem_cube_id,
+                    pref_ids,
+                )
+
+            except Exception as e:
+                logger.error("Error processing pref_add message: %s", e, exc_info=True)
+
+        with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor:
+            futures = [executor.submit(process_message, msg) for msg in messages]
+            for future in concurrent.futures.as_completed(futures):
+                try:
+                    future.result()
+                except Exception as e:
+                    logger.error("Thread task failed: %s", e, exc_info=True)
diff --git a/src/memos/mem_scheduler/handlers/query_handler.py b/src/memos/mem_scheduler/handlers/query_handler.py
new file mode 100644
index 000000000..4d3a09368
--- /dev/null
+++ b/src/memos/mem_scheduler/handlers/query_handler.py
@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+from memos.log import get_logger
+from memos.mem_scheduler.handlers.base import BaseSchedulerHandler
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import (
+    MEM_UPDATE_TASK_LABEL,
+    NOT_APPLICABLE_TYPE,
+    QUERY_TASK_LABEL,
+    USER_INPUT_TYPE,
+)
+from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
+
+
+logger = get_logger(__name__)
+
+
+class QueryMessageHandler(BaseSchedulerHandler):
+    def handle(self, messages: list[ScheduleMessageItem]) -> None:
+        logger.info(f"Messages {messages} assigned to {QUERY_TASK_LABEL} handler.")
+
+        grouped_messages = group_messages_by_user_and_mem_cube(messages=messages)
+
+        self.ctx.services.validate_messages(messages=messages, label=QUERY_TASK_LABEL)
+
+        mem_update_messages: list[ScheduleMessageItem] = []
+        for user_id in grouped_messages:
+            for mem_cube_id in grouped_messages[user_id]:
+                batch = grouped_messages[user_id][mem_cube_id]
+                if not batch:
+                    continue
+
+                for msg in batch:
+                    try:
+                        event = self.ctx.services.create_event_log(
+                            label="addMessage",
+                            from_memory_type=USER_INPUT_TYPE,
+                            to_memory_type=NOT_APPLICABLE_TYPE,
+                            user_id=msg.user_id,
+                            mem_cube_id=msg.mem_cube_id,
+                            mem_cube=self.ctx.get_mem_cube(),
+                            memcube_log_content=[
+                                {
+                                    "content": f"[User] {msg.content}",
+                                    "ref_id": msg.item_id,
+                                    "role": "user",
+                                }
], + metadata=[], + memory_len=1, + memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id), + ) + event.task_id = msg.task_id + self.ctx.services.submit_web_logs([event]) + except Exception: + logger.exception("Failed to record addMessage log for query") + + update_msg = ScheduleMessageItem( + user_id=msg.user_id, + mem_cube_id=msg.mem_cube_id, + label=MEM_UPDATE_TASK_LABEL, + content=msg.content, + session_id=msg.session_id, + user_name=msg.user_name, + info=msg.info, + task_id=msg.task_id, + ) + mem_update_messages.append(update_msg) + + self.ctx.services.submit_messages(messages=mem_update_messages) diff --git a/src/memos/mem_scheduler/handlers/registry.py b/src/memos/mem_scheduler/handlers/registry.py new file mode 100644 index 000000000..1e1db0404 --- /dev/null +++ b/src/memos/mem_scheduler/handlers/registry.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from collections.abc import Callable + +from memos.mem_scheduler.handlers.add_handler import AddMessageHandler +from memos.mem_scheduler.handlers.answer_handler import AnswerMessageHandler +from memos.mem_scheduler.handlers.context import SchedulerHandlerContext +from memos.mem_scheduler.handlers.feedback_handler import FeedbackMessageHandler +from memos.mem_scheduler.handlers.mem_read_handler import MemReadMessageHandler +from memos.mem_scheduler.handlers.mem_reorganize_handler import MemReorganizeMessageHandler +from memos.mem_scheduler.handlers.memory_update_handler import MemoryUpdateHandler +from memos.mem_scheduler.handlers.pref_add_handler import PrefAddMessageHandler +from memos.mem_scheduler.handlers.query_handler import QueryMessageHandler +from memos.mem_scheduler.schemas.task_schemas import ( + ADD_TASK_LABEL, + ANSWER_TASK_LABEL, + MEM_FEEDBACK_TASK_LABEL, + MEM_ORGANIZE_TASK_LABEL, + MEM_READ_TASK_LABEL, + MEM_UPDATE_TASK_LABEL, + PREF_ADD_TASK_LABEL, + QUERY_TASK_LABEL, +) + + +class SchedulerHandlerRegistry: + def __init__(self, ctx: SchedulerHandlerContext) -> None: + self.query = QueryMessageHandler(ctx) + self.answer = AnswerMessageHandler(ctx) + self.add = AddMessageHandler(ctx) + self.memory_update = MemoryUpdateHandler(ctx) + self.mem_feedback = FeedbackMessageHandler(ctx) + self.mem_read = MemReadMessageHandler(ctx) + self.mem_reorganize = MemReorganizeMessageHandler(ctx) + self.pref_add = PrefAddMessageHandler(ctx) + + def build_dispatch_map(self) -> dict[str, Callable]: + return { + QUERY_TASK_LABEL: self.query.handle, + ANSWER_TASK_LABEL: self.answer.handle, + MEM_UPDATE_TASK_LABEL: self.memory_update.handle, + ADD_TASK_LABEL: self.add.handle, + MEM_READ_TASK_LABEL: self.mem_read.handle, + MEM_ORGANIZE_TASK_LABEL: self.mem_reorganize.handle, + PREF_ADD_TASK_LABEL: self.pref_add.handle, + MEM_FEEDBACK_TASK_LABEL: self.mem_feedback.handle, + } diff --git a/src/memos/mem_scheduler/memory_manage_modules/enhancement_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/enhancement_pipeline.py new file mode 100644 index 000000000..5dd5e95d3 --- /dev/null +++ b/src/memos/mem_scheduler/memory_manage_modules/enhancement_pipeline.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +import time +from collections.abc import Callable + +from memos.log import get_logger +from memos.mem_scheduler.schemas.general_schemas import ( + DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE, + DEFAULT_SCHEDULER_RETRIEVER_RETRIES, +) +from memos.mem_scheduler.utils.misc_utils import extract_json_obj, extract_list_items_in_answer +from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata 
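+# NOTE: FINE_STRATEGY (imported below) is resolved at import time; per the code in this
+# module, REWRITE keeps each item's id/metadata and only replaces its text, while
+# RECREATE builds new LongTermMemory items from the LLM output.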
+from memos.types.general_types import FINE_STRATEGY, FineStrategy
+
+
+logger = get_logger(__name__)
+
+
+class EnhancementPipeline:
+    def __init__(self, process_llm, config, build_prompt: Callable[..., str]):
+        self.process_llm = process_llm
+        self.config = config
+        self.build_prompt = build_prompt
+        self.batch_size: int | None = getattr(
+            config, "scheduler_retriever_batch_size", DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE
+        )
+        self.retries: int = getattr(
+            config, "scheduler_retriever_enhance_retries", DEFAULT_SCHEDULER_RETRIEVER_RETRIES
+        )
+
+    def evaluate_memory_answer_ability(
+        self, query: str, memory_texts: list[str], top_k: int | None = None
+    ) -> bool:
+        limited_memories = memory_texts[:top_k] if top_k is not None else memory_texts
+        prompt = self.build_prompt(
+            template_name="memory_answer_ability_evaluation",
+            query=query,
+            memory_list="\n".join([f"- {memory}" for memory in limited_memories])
+            if limited_memories
+            else "No memories available",
+        )
+
+        response = self.process_llm.generate([{"role": "user", "content": prompt}])
+
+        try:
+            result = extract_json_obj(response)
+
+            if "result" in result:
+                logger.info(
+                    "Answerability: result=%s; reason=%s; evaluated=%s",
+                    result["result"],
+                    result.get("reason", "n/a"),
+                    len(limited_memories),
+                )
+                return result["result"]
+            logger.warning("Answerability: invalid LLM JSON structure; payload=%s", result)
+            return False
+
+        except Exception as e:
+            logger.error("Answerability: parse failed; err=%s; raw=%s...", e, str(response)[:200])
+            return False
+
+    def _build_enhancement_prompt(self, query_history: list[str], batch_texts: list[str]) -> str:
+        if len(query_history) == 1:
+            query_history = query_history[0]
+        else:
+            query_history = [f"[{i}] {query}" for i, query in enumerate(query_history)]
+        if FINE_STRATEGY == FineStrategy.REWRITE:
+            text_memories = "\n".join([f"- [{i}] {mem}" for i, mem in enumerate(batch_texts)])
+            prompt_name = "memory_rewrite_enhancement"
+        else:
+            text_memories = "\n".join([f"- {mem}" for mem in batch_texts])
+            prompt_name = "memory_recreate_enhancement"
+        return self.build_prompt(
+            prompt_name,
+            query_history=query_history,
+            memories=text_memories,
+        )
+
+    def _process_enhancement_batch(
+        self,
+        batch_index: int,
+        query_history: list[str],
+        memories: list[TextualMemoryItem],
+        retries: int,
+    ) -> tuple[list[TextualMemoryItem], bool]:
+        attempt = 0
+        text_memories = [one.memory for one in memories]
+
+        prompt = self._build_enhancement_prompt(
+            query_history=query_history, batch_texts=text_memories
+        )
+
+        llm_response = None
+        while attempt < max(1, retries) + 1:
+            try:
+                llm_response = self.process_llm.generate([{"role": "user", "content": prompt}])
+                processed_text_memories = extract_list_items_in_answer(llm_response)
+                if len(processed_text_memories) > 0:
+                    enhanced_memories = []
+                    user_id = memories[0].metadata.user_id
+                    if FINE_STRATEGY == FineStrategy.RECREATE:
+                        for new_mem in processed_text_memories:
+                            enhanced_memories.append(
+                                TextualMemoryItem(
+                                    memory=new_mem,
+                                    metadata=TextualMemoryMetadata(
+                                        user_id=user_id, memory_type="LongTermMemory"
+                                    ),
+                                )
+                            )
+                    elif FINE_STRATEGY == FineStrategy.REWRITE:
+                        def _parse_index_and_text(s: str) -> tuple[int | None, str]:
+                            import re
+
+                            s = (s or "").strip()
+                            m = re.match(r"^\s*\[(\d+)\]\s*(.+)$", s)
+                            if m:
+                                return int(m.group(1)), m.group(2).strip()
+                            m = re.match(r"^\s*(\d+)\s*[:\-\)]\s*(.+)$", s)
+                            if m:
+                                return int(m.group(1)), m.group(2).strip()
+                            return None, s
+
+                        idx_to_original = dict(enumerate(memories))
+                        for j, item in enumerate(processed_text_memories):
+                            idx, new_text = _parse_index_and_text(item)
+                            if idx is not None and idx in idx_to_original:
+                                orig = idx_to_original[idx]
+                            else:
+                                orig = memories[j] if j < len(memories) else None
+                            if not orig:
+                                continue
+                            enhanced_memories.append(
+                                TextualMemoryItem(
+                                    id=orig.id,
+                                    memory=new_text,
+                                    metadata=orig.metadata,
+                                )
+                            )
+                    else:
+                        logger.error("Fine search strategy %s does not exist", FINE_STRATEGY)
+
+                    logger.info(
+                        "[enhance_memories_with_query] done | Strategy=%s | prompt=%s | llm_response=%s",
+                        FINE_STRATEGY,
+                        prompt,
+                        llm_response,
+                    )
+                    return enhanced_memories, True
+                raise ValueError(
+                    "Failed to run memory enhancement; retry %s/%s; processed_text_memories: %s"
+                    % (attempt, max(1, retries) + 1, processed_text_memories)
+                )
+            except Exception as e:
+                attempt += 1
+                time.sleep(1)
+                logger.debug(
+                    "[enhance_memories_with_query][batch=%s] retry %s/%s failed: %s",
+                    batch_index,
+                    attempt,
+                    max(1, retries) + 1,
+                    e,
+                )
+        logger.error(
+            "Failed to run memory enhancement; prompt: %s;\n llm_response: %s",
+            prompt,
+            llm_response,
+        )
+        return memories, False
+
+    @staticmethod
+    def _split_batches(
+        memories: list[TextualMemoryItem], batch_size: int
+    ) -> list[tuple[int, int, list[TextualMemoryItem]]]:
+        batches: list[tuple[int, int, list[TextualMemoryItem]]] = []
+        start = 0
+        n = len(memories)
+        while start < n:
+            end = min(start + batch_size, n)
+            batches.append((start, end, memories[start:end]))
+            start = end
+        return batches
+
+    def recall_for_missing_memories(self, query: str, memories: list[str]) -> tuple[str, bool]:
+        text_memories = "\n".join([f"- {mem}" for mem in memories])
+
+        prompt = self.build_prompt(
+            template_name="enlarge_recall",
+            query=query,
+            memories_inline=text_memories,
+        )
+        llm_response = self.process_llm.generate([{"role": "user", "content": prompt}])
+
+        json_result: dict = extract_json_obj(llm_response)
+
+        logger.info(
+            "[recall_for_missing_memories] done | prompt=%s | llm_response=%s",
+            prompt,
+            llm_response,
+        )
+
+        hint = json_result.get("hint", "")
+        if len(hint) == 0:
+            return hint, False
+        return hint, json_result.get("trigger_recall", False)
+
+    def enhance_memories_with_query(
+        self,
+        query_history: list[str],
+        memories: list[TextualMemoryItem],
+    ) -> tuple[list[TextualMemoryItem], bool]:
+        if not memories:
+            logger.warning("[Enhance] skipped (no memories to process)")
+            return memories, True
+
+        batch_size = self.batch_size
+        retries = self.retries
+        num_of_memories = len(memories)
+        try:
+            if batch_size is None or num_of_memories <= batch_size:
+                enhanced_memories, success_flag = self._process_enhancement_batch(
+                    batch_index=0,
+                    query_history=query_history,
+                    memories=memories,
+                    retries=retries,
+                )
+
+                all_success = success_flag
+            else:
+                batches = self._split_batches(memories=memories, batch_size=batch_size)
+
+                all_success = True
+                failed_batches = 0
+                from concurrent.futures import as_completed
+
+                from memos.context.context import ContextThreadPoolExecutor
+
+                with ContextThreadPoolExecutor(max_workers=len(batches)) as executor:
+                    future_map = {
+                        executor.submit(
+                            self._process_enhancement_batch, bi, query_history, texts, retries
+                        ): (bi, s, e)
+                        for bi, (s, e, texts) in enumerate(batches)
+                    }
+                    enhanced_memories = []
+                    for fut in as_completed(future_map):
+                        bi, s, e = future_map[fut]
+
+                        batch_memories, ok = fut.result()
+                        enhanced_memories.extend(batch_memories)
+                        if not ok:
+                            all_success = False
+                            failed_batches += 1
+                logger.info(
+                    "[Enhance] multi-batch done | batches=%s | enhanced=%s | failed_batches=%s | success=%s",
+                    len(batches),
+                    len(enhanced_memories),
+                    failed_batches,
+                    all_success,
+                )
+
+        except Exception as e:
+            logger.error("[Enhance] fatal error: %s", e, exc_info=True)
+            all_success = False
+            enhanced_memories = memories
+
+        if not enhanced_memories:
+            logger.error("[Enhance] fatal error: enhanced_memories is empty")
+        return enhanced_memories, all_success
diff --git a/src/memos/mem_scheduler/memory_manage_modules/filter_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/filter_pipeline.py
new file mode 100644
index 000000000..8bebe2456
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/filter_pipeline.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from memos.mem_scheduler.memory_manage_modules.memory_filter import MemoryFilter
+from memos.memories.textual.tree import TextualMemoryItem
+
+
+class FilterPipeline:
+    def __init__(self, process_llm, config):
+        self.memory_filter = MemoryFilter(process_llm=process_llm, config=config)
+
+    def filter_unrelated_memories(
+        self, query_history: list[str], memories: list[TextualMemoryItem]
+    ) -> tuple[list[TextualMemoryItem], bool]:
+        return self.memory_filter.filter_unrelated_memories(query_history, memories)
+
+    def filter_redundant_memories(
+        self, query_history: list[str], memories: list[TextualMemoryItem]
+    ) -> tuple[list[TextualMemoryItem], bool]:
+        return self.memory_filter.filter_redundant_memories(query_history, memories)
+
+    def filter_unrelated_and_redundant_memories(
+        self, query_history: list[str], memories: list[TextualMemoryItem]
+    ) -> tuple[list[TextualMemoryItem], bool]:
+        return self.memory_filter.filter_unrelated_and_redundant_memories(query_history, memories)
diff --git a/src/memos/mem_scheduler/memory_manage_modules/rerank_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/rerank_pipeline.py
new file mode 100644
index 000000000..21dabedd9
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/rerank_pipeline.py
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+from memos.log import get_logger
+from memos.mem_scheduler.utils.filter_utils import (
+    filter_too_short_memories,
+    filter_vector_based_similar_memories,
+    transform_name_to_key,
+)
+from memos.mem_scheduler.utils.misc_utils import extract_json_obj
+from memos.memories.textual.item import TextualMemoryItem
+
+
+logger = get_logger(__name__)
+
+
+class RerankPipeline:
+    def __init__(
+        self,
+        process_llm,
+        similarity_threshold: float,
+        min_length_threshold: int,
+        build_prompt,
+    ):
+        self.process_llm = process_llm
+        self.filter_similarity_threshold = similarity_threshold
+        self.filter_min_length_threshold = min_length_threshold
+        self.build_prompt = build_prompt
+
+    def rerank_memories(
+        self, queries: list[str], original_memories: list[str], top_k: int
+    ) -> tuple[list[str], bool]:
+        logger.info("Starting memory reranking for %s memories", len(original_memories))
+
+        prompt = self.build_prompt(
+            "memory_reranking",
+            queries=[f"[0] {queries[0]}"],
+            current_order=[f"[{i}] {mem}" for i, mem in enumerate(original_memories)],
+        )
+        logger.debug("Generated reranking prompt: %s...", prompt[:200])
+
+        response = self.process_llm.generate([{"role": "user", "content": prompt}])
+        logger.debug("Received LLM response: %s...", response[:200])
+
+        try:
+            response = extract_json_obj(response)
+            new_order = 
response["new_order"][:top_k] + text_memories_with_new_order = [original_memories[idx] for idx in new_order] + logger.info( + "Successfully reranked memories. Returning top %s items; Ranking reasoning: %s", + len(text_memories_with_new_order), + response["reasoning"], + ) + success_flag = True + except Exception as e: + logger.error( + "Failed to rerank memories with LLM. Exception: %s. Raw response: %s ", + e, + response, + exc_info=True, + ) + text_memories_with_new_order = original_memories[:top_k] + success_flag = False + return text_memories_with_new_order, success_flag + + def process_and_rerank_memories( + self, + queries: list[str], + original_memory: list[TextualMemoryItem], + new_memory: list[TextualMemoryItem], + top_k: int = 10, + ) -> tuple[list[TextualMemoryItem], bool]: + combined_memory = original_memory + new_memory + + memory_map = { + transform_name_to_key(name=mem_obj.memory): mem_obj for mem_obj in combined_memory + } + + combined_text_memory = [m.memory for m in combined_memory] + + filtered_combined_text_memory = filter_vector_based_similar_memories( + text_memories=combined_text_memory, + similarity_threshold=self.filter_similarity_threshold, + ) + + filtered_combined_text_memory = filter_too_short_memories( + text_memories=filtered_combined_text_memory, + min_length_threshold=self.filter_min_length_threshold, + ) + + unique_memory = list(dict.fromkeys(filtered_combined_text_memory)) + + text_memories_with_new_order, success_flag = self.rerank_memories( + queries=queries, + original_memories=unique_memory, + top_k=top_k, + ) + + memories_with_new_order = [] + for text in text_memories_with_new_order: + normalized_text = transform_name_to_key(name=text) + if normalized_text in memory_map: + memories_with_new_order.append(memory_map[normalized_text]) + else: + logger.warning( + "Memory text not found in memory map. 
text: %s;\nKeys of memory_map: %s", + text, + memory_map.keys(), + ) + + return memories_with_new_order, success_flag diff --git a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py index f205766f0..41e268cef 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py +++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py @@ -1,457 +1,94 @@ -import time +from __future__ import annotations -from concurrent.futures import as_completed - -from memos.configs.mem_scheduler import BaseSchedulerConfig -from memos.context.context import ContextThreadPoolExecutor -from memos.llms.base import BaseLLM from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.general_modules.base import BaseSchedulerModule -from memos.mem_scheduler.schemas.general_schemas import ( - DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE, - DEFAULT_SCHEDULER_RETRIEVER_RETRIES, - TreeTextMemory_FINE_SEARCH_METHOD, - TreeTextMemory_SEARCH_METHOD, -) -from memos.mem_scheduler.utils.filter_utils import ( - filter_too_short_memories, - filter_vector_based_similar_memories, - transform_name_to_key, -) -from memos.mem_scheduler.utils.misc_utils import extract_json_obj, extract_list_items_in_answer -from memos.memories.textual.item import TextualMemoryMetadata -from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory -from memos.types.general_types import ( - FINE_STRATEGY, - FineStrategy, - SearchMode, -) - -# Extract JSON response -from .memory_filter import MemoryFilter +from memos.mem_scheduler.memory_manage_modules.enhancement_pipeline import EnhancementPipeline +from memos.mem_scheduler.memory_manage_modules.filter_pipeline import FilterPipeline +from memos.mem_scheduler.memory_manage_modules.rerank_pipeline import RerankPipeline +from memos.mem_scheduler.memory_manage_modules.search_pipeline import SearchPipeline +from memos.memories.textual.item import TextualMemoryItem logger = get_logger(__name__) class SchedulerRetriever(BaseSchedulerModule): - def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): + def __init__(self, process_llm, config): super().__init__() - # hyper-parameters self.filter_similarity_threshold = 0.75 self.filter_min_length_threshold = 6 - self.memory_filter = MemoryFilter(process_llm=process_llm, config=config) self.process_llm = process_llm self.config = config - # Configure enhancement batching & retries from config with safe defaults - self.batch_size: int | None = getattr( - config, "scheduler_retriever_batch_size", DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE + self.search_pipeline = SearchPipeline() + self.enhancement_pipeline = EnhancementPipeline( + process_llm=process_llm, + config=config, + build_prompt=self.build_prompt, ) - self.retries: int = getattr( - config, "scheduler_retriever_enhance_retries", DEFAULT_SCHEDULER_RETRIEVER_RETRIES + self.rerank_pipeline = RerankPipeline( + process_llm=process_llm, + similarity_threshold=self.filter_similarity_threshold, + min_length_threshold=self.filter_min_length_threshold, + build_prompt=self.build_prompt, ) + self.filter_pipeline = FilterPipeline(process_llm=process_llm, config=config) + self.memory_filter = self.filter_pipeline.memory_filter def evaluate_memory_answer_ability( self, query: str, memory_texts: list[str], top_k: int | None = None ) -> bool: - limited_memories = memory_texts[:top_k] if top_k is not None else memory_texts - # Build prompt using the template - prompt = 
self.build_prompt( - template_name="memory_answer_ability_evaluation", + return self.enhancement_pipeline.evaluate_memory_answer_ability( query=query, - memory_list="\n".join([f"- {memory}" for memory in limited_memories]) - if limited_memories - else "No memories available", - ) - - # Use the process LLM to generate response - response = self.process_llm.generate([{"role": "user", "content": prompt}]) - - try: - result = extract_json_obj(response) - - # Validate response structure - if "result" in result: - logger.info( - f"Answerability: result={result['result']}; reason={result.get('reason', 'n/a')}; evaluated={len(limited_memories)}" - ) - return result["result"] - else: - logger.warning(f"Answerability: invalid LLM JSON structure; payload={result}") - return False - - except Exception as e: - logger.error(f"Answerability: parse failed; err={e}; raw={str(response)[:200]}...") - # Fallback: return False if we can't determine answer ability - return False - - # ---------------------- Enhancement helpers ---------------------- - def _build_enhancement_prompt(self, query_history: list[str], batch_texts: list[str]) -> str: - if len(query_history) == 1: - query_history = query_history[0] - else: - query_history = ( - [f"[{i}] {query}" for i, query in enumerate(query_history)] - if len(query_history) > 1 - else query_history[0] - ) - # Include numbering for rewrite mode to help LLM reference original memory IDs - if FINE_STRATEGY == FineStrategy.REWRITE: - text_memories = "\n".join([f"- [{i}] {mem}" for i, mem in enumerate(batch_texts)]) - prompt_name = "memory_rewrite_enhancement" - else: - text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(batch_texts)]) - prompt_name = "memory_recreate_enhancement" - return self.build_prompt( - prompt_name, - query_history=query_history, - memories=text_memories, - ) - - def _process_enhancement_batch( - self, - batch_index: int, - query_history: list[str], - memories: list[TextualMemoryItem], - retries: int, - ) -> tuple[list[TextualMemoryItem], bool]: - attempt = 0 - text_memories = [one.memory for one in memories] - - prompt = self._build_enhancement_prompt( - query_history=query_history, batch_texts=text_memories - ) - - llm_response = None - while attempt <= max(0, retries) + 1: - try: - llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) - processed_text_memories = extract_list_items_in_answer(llm_response) - if len(processed_text_memories) > 0: - # create new - enhanced_memories = [] - user_id = memories[0].metadata.user_id - if FINE_STRATEGY == FineStrategy.RECREATE: - for new_mem in processed_text_memories: - enhanced_memories.append( - TextualMemoryItem( - memory=new_mem, - metadata=TextualMemoryMetadata( - user_id=user_id, memory_type="LongTermMemory" - ), # TODO add memory_type - ) - ) - elif FINE_STRATEGY == FineStrategy.REWRITE: - # Parse index from each processed line and rewrite corresponding original memory - def _parse_index_and_text(s: str) -> tuple[int | None, str]: - import re - - s = (s or "").strip() - # Preferred: [index] text - m = re.match(r"^\s*\[(\d+)\]\s*(.+)$", s) - if m: - return int(m.group(1)), m.group(2).strip() - # Fallback: index: text or index - text - m = re.match(r"^\s*(\d+)\s*[:\-\)]\s*(.+)$", s) - if m: - return int(m.group(1)), m.group(2).strip() - return None, s - - idx_to_original = dict(enumerate(memories)) - for j, item in enumerate(processed_text_memories): - idx, new_text = _parse_index_and_text(item) - if idx is not None and idx in idx_to_original: - orig = 
idx_to_original[idx] - else: - # Fallback: align by order if index missing/invalid - orig = memories[j] if j < len(memories) else None - if not orig: - continue - enhanced_memories.append( - TextualMemoryItem( - id=orig.id, - memory=new_text, - metadata=orig.metadata, - ) - ) - else: - logger.error(f"Fine search strategy {FINE_STRATEGY} not exists") - - logger.info( - f"[enhance_memories_with_query] ✅ done | Strategy={FINE_STRATEGY} | prompt={prompt} | llm_response={llm_response}" - ) - return enhanced_memories, True - else: - raise ValueError( - f"Fail to run memory enhancement; retry {attempt}/{max(1, retries) + 1}; processed_text_memories: {processed_text_memories}" - ) - except Exception as e: - attempt += 1 - time.sleep(1) - logger.debug( - f"[enhance_memories_with_query][batch={batch_index}] 🔁 retry {attempt}/{max(1, retries) + 1} failed: {e}" - ) - logger.error( - f"Fail to run memory enhancement; prompt: {prompt};\n llm_response: {llm_response}", - exc_info=True, - ) - return memories, False - - @staticmethod - def _split_batches( - memories: list[TextualMemoryItem], batch_size: int - ) -> list[tuple[int, int, list[TextualMemoryItem]]]: - batches: list[tuple[int, int, list[TextualMemoryItem]]] = [] - start = 0 - n = len(memories) - while start < n: - end = min(start + batch_size, n) - batches.append((start, end, memories[start:end])) - start = end - return batches - - def recall_for_missing_memories( - self, - query: str, - memories: list[str], - ) -> tuple[str, bool]: - text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(memories)]) - - prompt = self.build_prompt( - template_name="enlarge_recall", - query=query, - memories_inline=text_memories, - ) - llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) - - json_result: dict = extract_json_obj(llm_response) - - logger.info( - f"[recall_for_missing_memories] ✅ done | prompt={prompt} | llm_response={llm_response}" + memory_texts=memory_texts, + top_k=top_k, ) - hint = json_result.get("hint", "") - if len(hint) == 0: - return hint, False - return hint, json_result.get("trigger_recall", False) - def search( self, query: str, user_id: str, mem_cube_id: str, - mem_cube: GeneralMemCube, + mem_cube, top_k: int, - method: str = TreeTextMemory_SEARCH_METHOD, + method: str, search_args: dict | None = None, ) -> list[TextualMemoryItem]: - """Search in text memory with the given query. 
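# A minimal, self-contained sketch of the index-parsing rule the REWRITE
# strategy above relies on (same regexes as _parse_index_and_text; the
# top-level function name here is illustrative, not part of the module):
import re

def parse_index_and_text(s: str) -> tuple[int | None, str]:
    s = (s or "").strip()
    m = re.match(r"^\s*\[(\d+)\]\s*(.+)$", s)  # preferred: "[3] rewritten text"
    if m:
        return int(m.group(1)), m.group(2).strip()
    m = re.match(r"^\s*(\d+)\s*[:\-\)]\s*(.+)$", s)  # fallback: "3: text" / "3 - text"
    if m:
        return int(m.group(1)), m.group(2).strip()
    return None, s  # no index found; the caller falls back to positional order

assert parse_index_and_text("[2] Alice moved to Berlin") == (2, "Alice moved to Berlin")
assert parse_index_and_text("plain line") == (None, "plain line")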
- - Args: - query: The search query string - top_k: Number of top results to return - method: Search method to use - - Returns: - Search results or None if not implemented - """ - text_mem_base = mem_cube.text_mem - # Normalize default for mutable argument - search_args = search_args or {} - try: - if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]: - assert isinstance(text_mem_base, TreeTextMemory) - session_id = search_args.get("session_id", "default_session") - target_session_id = session_id - search_priority = ( - {"session_id": target_session_id} if "session_id" in search_args else None - ) - search_filter = search_args.get("filter") - search_source = search_args.get("source") - plugin = bool(search_source is not None and search_source == "plugin") - user_name = search_args.get("user_name", mem_cube_id) - internet_search = search_args.get("internet_search", False) - chat_history = search_args.get("chat_history") - search_tool_memory = search_args.get("search_tool_memory", False) - tool_mem_top_k = search_args.get("tool_mem_top_k", 6) - playground_search_goal_parser = search_args.get( - "playground_search_goal_parser", False - ) - - info = search_args.get( - "info", - { - "user_id": user_id, - "session_id": target_session_id, - "chat_history": chat_history, - }, - ) - - results_long_term = mem_cube.text_mem.search( - query=query, - user_name=user_name, - top_k=top_k, - mode=SearchMode.FAST, - manual_close_internet=not internet_search, - memory_type="LongTermMemory", - search_filter=search_filter, - search_priority=search_priority, - info=info, - plugin=plugin, - search_tool_memory=search_tool_memory, - tool_mem_top_k=tool_mem_top_k, - playground_search_goal_parser=playground_search_goal_parser, - ) - - results_user = mem_cube.text_mem.search( - query=query, - user_name=user_name, - top_k=top_k, - mode=SearchMode.FAST, - manual_close_internet=not internet_search, - memory_type="UserMemory", - search_filter=search_filter, - search_priority=search_priority, - info=info, - plugin=plugin, - search_tool_memory=search_tool_memory, - tool_mem_top_k=tool_mem_top_k, - playground_search_goal_parser=playground_search_goal_parser, - ) - results = results_long_term + results_user - else: - raise NotImplementedError(str(type(text_mem_base))) - except Exception as e: - logger.error(f"Fail to search. The exeption is {e}.", exc_info=True) - results = [] - return results + return self.search_pipeline.search( + query=query, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + top_k=top_k, + method=method, + search_args=search_args, + ) def enhance_memories_with_query( self, query_history: list[str], memories: list[TextualMemoryItem], - ) -> (list[TextualMemoryItem], bool): - """ - Enhance memories by adding context and making connections to better answer queries. 
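# For reference, the batching used by enhance_memories_with_query is plain
# order-preserving slicing over the memory list; a standalone sketch of the
# same contract as _split_batches:
def split_batches(items: list, batch_size: int) -> list[tuple[int, int, list]]:
    batches = []
    start = 0
    while start < len(items):
        end = min(start + batch_size, len(items))
        batches.append((start, end, items[start:end]))  # (start, end, slice)
        start = end
    return batches

assert split_batches(list("abcde"), 2) == [(0, 2, ["a", "b"]), (2, 4, ["c", "d"]), (4, 5, ["e"])]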
- - Args: - query_history: List of user queries in chronological order - memories: List of memory items to enhance - - Returns: - Tuple of (enhanced_memories, success_flag) - """ - if not memories: - logger.warning("[Enhance] ⚠️ skipped (no memories to process)") - return memories, True - - batch_size = self.batch_size - retries = self.retries - num_of_memories = len(memories) - try: - # no parallel - if batch_size is None or num_of_memories <= batch_size: - # Single batch path with retry - enhanced_memories, success_flag = self._process_enhancement_batch( - batch_index=0, - query_history=query_history, - memories=memories, - retries=retries, - ) - - all_success = success_flag - else: - # parallel running batches - # Split into batches preserving order - batches = self._split_batches(memories=memories, batch_size=batch_size) - - # Process batches concurrently - all_success = True - failed_batches = 0 - with ContextThreadPoolExecutor(max_workers=len(batches)) as executor: - future_map = { - executor.submit( - self._process_enhancement_batch, bi, query_history, texts, retries - ): (bi, s, e) - for bi, (s, e, texts) in enumerate(batches) - } - enhanced_memories = [] - for fut in as_completed(future_map): - bi, s, e = future_map[fut] - - batch_memories, ok = fut.result() - enhanced_memories.extend(batch_memories) - if not ok: - all_success = False - failed_batches += 1 - logger.info( - f"[Enhance] ✅ multi-batch done | batches={len(batches)} | enhanced={len(enhanced_memories)} |" - f" failed_batches={failed_batches} | success={all_success}" - ) - - except Exception as e: - logger.error(f"[Enhance] ❌ fatal error: {e}", exc_info=True) - all_success = False - enhanced_memories = memories + ) -> tuple[list[TextualMemoryItem], bool]: + return self.enhancement_pipeline.enhance_memories_with_query( + query_history=query_history, + memories=memories, + ) - if len(enhanced_memories) == 0: - enhanced_memories = [] - logger.error("[Enhance] ❌ fatal error: enhanced_memories is empty", exc_info=True) - return enhanced_memories, all_success + def recall_for_missing_memories(self, query: str, memories: list[str]) -> tuple[str, bool]: + return self.enhancement_pipeline.recall_for_missing_memories( + query=query, + memories=memories, + ) def rerank_memories( self, queries: list[str], original_memories: list[str], top_k: int - ) -> (list[str], bool): - """ - Rerank memories based on relevance to given queries using LLM. 
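# The reranking contract is a JSON permutation returned by the LLM; applying
# it is index selection truncated to top_k. A sketch assuming the response
# shape the memory_reranking prompt asks for
# ({"new_order": [...], "reasoning": "..."}):
def apply_new_order(memories: list[str], response: dict, top_k: int) -> list[str]:
    new_order = response["new_order"][:top_k]
    return [memories[idx] for idx in new_order]

assert apply_new_order(["a", "b", "c"], {"new_order": [2, 0, 1]}, 2) == ["c", "a"]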
- - Args: - queries: List of query strings to determine relevance - original_memories: List of memory strings to be reranked - top_k: Number of top memories to return after reranking - - Returns: - List of reranked memory strings (length <= top_k) - - Note: - If LLM reranking fails, falls back to original order (truncated to top_k) - """ - - logger.info(f"Starting memory reranking for {len(original_memories)} memories") - - # Build LLM prompt for memory reranking - prompt = self.build_prompt( - "memory_reranking", - queries=[f"[0] {queries[0]}"], - current_order=[f"[{i}] {mem}" for i, mem in enumerate(original_memories)], + ) -> tuple[list[str], bool]: + return self.rerank_pipeline.rerank_memories( + queries=queries, + original_memories=original_memories, + top_k=top_k, ) - logger.debug(f"Generated reranking prompt: {prompt[:200]}...") # Log first 200 chars - - # Get LLM response - response = self.process_llm.generate([{"role": "user", "content": prompt}]) - logger.debug(f"Received LLM response: {response[:200]}...") # Log first 200 chars - - try: - # Parse JSON response - response = extract_json_obj(response) - new_order = response["new_order"][:top_k] - text_memories_with_new_order = [original_memories[idx] for idx in new_order] - logger.info( - f"Successfully reranked memories. Returning top {len(text_memories_with_new_order)} items;" - f"Ranking reasoning: {response['reasoning']}" - ) - success_flag = True - except Exception as e: - logger.error( - f"Failed to rerank memories with LLM. Exception: {e}. Raw response: {response} ", - exc_info=True, - ) - text_memories_with_new_order = original_memories[:top_k] - success_flag = False - return text_memories_with_new_order, success_flag def process_and_rerank_memories( self, @@ -459,89 +96,40 @@ def process_and_rerank_memories( original_memory: list[TextualMemoryItem], new_memory: list[TextualMemoryItem], top_k: int = 10, - ) -> list[TextualMemoryItem] | None: - """ - Process and rerank memory items by combining original and new memories, - applying filters, and then reranking based on relevance to queries. 
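# Two small invariants carry process_and_rerank_memories: dict.fromkeys gives
# order-preserving dedup, and a map keyed by normalized text recovers the
# original items after reranking. A sketch of both steps (str.lower stands in
# for transform_name_to_key, whose real rules live in filter_utils):
def map_back(reranked_texts: list[str], items: list) -> list:
    normalize = str.lower  # placeholder normalizer, not the real one
    memory_map = {normalize(it): it for it in items}
    unique = list(dict.fromkeys(reranked_texts))  # dedup, first occurrence wins
    return [memory_map[normalize(t)] for t in unique if normalize(t) in memory_map]

assert map_back(["B", "a", "B"], ["a", "b"]) == ["b", "a"]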
- - Args: - queries: List of query strings to rerank memories against - original_memory: List of original TextualMemoryItem objects - new_memory: List of new TextualMemoryItem objects to merge - top_k: Maximum number of memories to return after reranking - - Returns: - List of reranked TextualMemoryItem objects, or None if processing fails - """ - # Combine original and new memories into a single list - combined_memory = original_memory + new_memory - - # Create a mapping from normalized text to memory objects - memory_map = { - transform_name_to_key(name=mem_obj.memory): mem_obj for mem_obj in combined_memory - } - - # Extract normalized text representations from all memory items - combined_text_memory = [m.memory for m in combined_memory] - - # Apply similarity filter to remove overly similar memories - filtered_combined_text_memory = filter_vector_based_similar_memories( - text_memories=combined_text_memory, - similarity_threshold=self.filter_similarity_threshold, - ) - - # Apply length filter to remove memories that are too short - filtered_combined_text_memory = filter_too_short_memories( - text_memories=filtered_combined_text_memory, - min_length_threshold=self.filter_min_length_threshold, - ) - - # Ensure uniqueness of memory texts using dictionary keys (preserves order) - unique_memory = list(dict.fromkeys(filtered_combined_text_memory)) - - # Rerank the filtered memories based on relevance to the queries - text_memories_with_new_order, success_flag = self.rerank_memories( + ) -> tuple[list[TextualMemoryItem], bool]: + return self.rerank_pipeline.process_and_rerank_memories( queries=queries, - original_memories=unique_memory, + original_memory=original_memory, + new_memory=new_memory, top_k=top_k, ) - # Map reranked text entries back to their original memory objects - memories_with_new_order = [] - for text in text_memories_with_new_order: - normalized_text = transform_name_to_key(name=text) - if normalized_text in memory_map: # Ensure correct key matching - memories_with_new_order.append(memory_map[normalized_text]) - else: - logger.warning( - f"Memory text not found in memory map. text: {text};\n" - f"Keys of memory_map: {memory_map.keys()}" - ) - - return memories_with_new_order, success_flag - def filter_unrelated_memories( self, query_history: list[str], memories: list[TextualMemoryItem], - ) -> (list[TextualMemoryItem], bool): - return self.memory_filter.filter_unrelated_memories(query_history, memories) + ) -> tuple[list[TextualMemoryItem], bool]: + return self.filter_pipeline.filter_unrelated_memories( + query_history=query_history, + memories=memories, + ) def filter_redundant_memories( self, query_history: list[str], memories: list[TextualMemoryItem], - ) -> (list[TextualMemoryItem], bool): - return self.memory_filter.filter_redundant_memories(query_history, memories) + ) -> tuple[list[TextualMemoryItem], bool]: + return self.filter_pipeline.filter_redundant_memories( + query_history=query_history, + memories=memories, + ) def filter_unrelated_and_redundant_memories( self, query_history: list[str], memories: list[TextualMemoryItem], - ) -> (list[TextualMemoryItem], bool): - """ - Filter out both unrelated and redundant memories using LLM analysis. - - This method delegates to the MemoryFilter class. 
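# After this refactor SchedulerRetriever is a thin facade: __init__ wires four
# pipeline objects and every public method is a one-line delegation, so
# existing callers keep the old API. A hypothetical reduction of the shape,
# not the module's actual code:
class RetrieverFacade:
    def __init__(self, filter_pipeline):
        self._filters = filter_pipeline  # extracted collaborator

    def filter_unrelated_memories(self, query_history, memories):
        # Public signature unchanged; the work lives in the pipeline now.
        return self._filters.filter_unrelated_memories(query_history, memories)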
- """ - return self.memory_filter.filter_unrelated_and_redundant_memories(query_history, memories) + ) -> tuple[list[TextualMemoryItem], bool]: + return self.filter_pipeline.filter_unrelated_and_redundant_memories( + query_history=query_history, + memories=memories, + ) diff --git a/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py new file mode 100644 index 000000000..65496b478 --- /dev/null +++ b/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from memos.log import get_logger +from memos.mem_scheduler.schemas.general_schemas import ( + TreeTextMemory_FINE_SEARCH_METHOD, + TreeTextMemory_SEARCH_METHOD, +) +from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.types.general_types import SearchMode + + +logger = get_logger(__name__) + + +class SearchPipeline: + def search( + self, + query: str, + user_id: str, + mem_cube_id: str, + mem_cube, + top_k: int, + method: str = TreeTextMemory_SEARCH_METHOD, + search_args: dict | None = None, + ) -> list[TextualMemoryItem]: + text_mem_base = mem_cube.text_mem + search_args = search_args or {} + try: + if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]: + assert isinstance(text_mem_base, TreeTextMemory) + session_id = search_args.get("session_id", "default_session") + target_session_id = session_id + search_priority = ( + {"session_id": target_session_id} if "session_id" in search_args else None + ) + search_filter = search_args.get("filter") + search_source = search_args.get("source") + plugin = bool(search_source is not None and search_source == "plugin") + user_name = search_args.get("user_name", mem_cube_id) + internet_search = search_args.get("internet_search", False) + chat_history = search_args.get("chat_history") + search_tool_memory = search_args.get("search_tool_memory", False) + tool_mem_top_k = search_args.get("tool_mem_top_k", 6) + playground_search_goal_parser = search_args.get( + "playground_search_goal_parser", False + ) + + info = search_args.get( + "info", + { + "user_id": user_id, + "session_id": target_session_id, + "chat_history": chat_history, + }, + ) + + results_long_term = mem_cube.text_mem.search( + query=query, + user_name=user_name, + top_k=top_k, + mode=SearchMode.FAST, + manual_close_internet=not internet_search, + memory_type="LongTermMemory", + search_filter=search_filter, + search_priority=search_priority, + info=info, + plugin=plugin, + search_tool_memory=search_tool_memory, + tool_mem_top_k=tool_mem_top_k, + playground_search_goal_parser=playground_search_goal_parser, + ) + + results_user = mem_cube.text_mem.search( + query=query, + user_name=user_name, + top_k=top_k, + mode=SearchMode.FAST, + manual_close_internet=not internet_search, + memory_type="UserMemory", + search_filter=search_filter, + search_priority=search_priority, + info=info, + plugin=plugin, + search_tool_memory=search_tool_memory, + tool_mem_top_k=tool_mem_top_k, + playground_search_goal_parser=playground_search_goal_parser, + ) + results = results_long_term + results_user + else: + raise NotImplementedError(str(type(text_mem_base))) + except Exception as e: + logger.error("Fail to search. 
The exeption is %s.", e, exc_info=True) + results = [] + return results diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index 497d19ac6..0e7390ed5 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -19,6 +19,7 @@ from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.search import build_search_context, search_text_memories from memos.types import ( MemCubeID, SearchMode, @@ -104,29 +105,13 @@ def search_memories( mem_cube: NaiveMemCube, mode: SearchMode, ): - """Fine search memories function copied from server_router to avoid circular import""" - target_session_id = search_req.session_id - if not target_session_id: - target_session_id = "default_session" - search_priority = {"session_id": search_req.session_id} if search_req.session_id else None - search_filter = search_req.filter - - # Create MemCube and perform search - search_results = mem_cube.text_mem.search( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, + """Shared text-memory search via centralized search service.""" + return search_text_memories( + text_mem=mem_cube.text_mem, + search_req=search_req, + user_context=user_context, mode=mode, - manual_close_internet=not search_req.internet_search, - search_filter=search_filter, - search_priority=search_priority, - info={ - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - }, ) - return search_results def mix_search_memories( self, @@ -157,19 +142,13 @@ def mix_search_memories( ] # Get mem_cube for fast search - target_session_id = search_req.session_id - if not target_session_id: - target_session_id = "default_session" - search_priority = {"session_id": search_req.session_id} if search_req.session_id else None - search_filter = search_req.filter + search_ctx = build_search_context(search_req=search_req, user_context=user_context) + search_priority = search_ctx.search_priority + search_filter = search_ctx.search_filter # Rerank Memories - reranker expects TextualMemoryItem objects - info = { - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - } + info = search_ctx.info raw_retrieved_memories = self.searcher.retrieve( query=search_req.query, diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index bd026a51d..5058ca805 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -25,6 +25,7 @@ PREF_ADD_TASK_LABEL, ) from memos.multi_mem_cube.views import MemCubeView +from memos.search import search_text_memories from memos.templates.mem_reader_prompts import PROMPT_MAPPING from memos.types.general_types import ( FINE_STRATEGY, @@ -455,31 +456,11 @@ def _fast_search( Returns: List of search results """ - target_session_id = search_req.session_id or "default_session" - search_priority = {"session_id": search_req.session_id} if search_req.session_id else None - search_filter = search_req.filter or None - plugin = bool(search_req.source is not None and search_req.source == "plugin") - - search_results = self.naive_mem_cube.text_mem.search( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, + search_results = 
search_text_memories( + text_mem=self.naive_mem_cube.text_mem, + search_req=search_req, + user_context=user_context, mode=SearchMode.FAST, - manual_close_internet=not search_req.internet_search, - memory_type=search_req.search_memory_type, - search_filter=search_filter, - search_priority=search_priority, - info={ - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - }, - plugin=plugin, - search_tool_memory=search_req.search_tool_memory, - tool_mem_top_k=search_req.tool_mem_top_k, - include_skill_memory=search_req.include_skill_memory, - skill_mem_top_k=search_req.skill_mem_top_k, - dedup=search_req.dedup, ) formatted_memories = [ diff --git a/src/memos/search/__init__.py b/src/memos/search/__init__.py new file mode 100644 index 000000000..1fa4e6819 --- /dev/null +++ b/src/memos/search/__init__.py @@ -0,0 +1,3 @@ +from .search_service import SearchContext, build_search_context, search_text_memories + +__all__ = ["SearchContext", "build_search_context", "search_text_memories"] diff --git a/src/memos/search/search_service.py b/src/memos/search/search_service.py new file mode 100644 index 000000000..9f6280355 --- /dev/null +++ b/src/memos/search/search_service.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from memos.api.product_models import APISearchRequest +from memos.types import SearchMode, UserContext + + +@dataclass(frozen=True) +class SearchContext: + target_session_id: str + search_priority: dict[str, Any] | None + search_filter: dict[str, Any] | None + info: dict[str, Any] + plugin: bool + + +def build_search_context( + search_req: APISearchRequest, + user_context: UserContext, +) -> SearchContext: + target_session_id = search_req.session_id or "default_session" + search_priority = {"session_id": search_req.session_id} if search_req.session_id else None + return SearchContext( + target_session_id=target_session_id, + search_priority=search_priority, + search_filter=search_req.filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, + plugin=bool(search_req.source is not None and search_req.source == "plugin"), + ) + + +def search_text_memories( + text_mem: Any, + search_req: APISearchRequest, + user_context: UserContext, + mode: SearchMode, +) -> list[Any]: + """ + Shared text-memory search logic for API and scheduler paths. 
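# Usage sketch for the new service: derive the context once, then search.
# _Req is a hypothetical reduction of APISearchRequest carrying only the
# fields build_search_context reads; the real request has more, as the diff
# above shows. user_context is accepted but unused at this point in the series.
from dataclasses import dataclass

from memos.search import build_search_context

@dataclass
class _Req:
    session_id: str | None = None
    user_id: str = "u1"
    chat_history: list | None = None
    filter: dict | None = None
    source: str | None = None

ctx = build_search_context(search_req=_Req(), user_context=None)
assert ctx.target_session_id == "default_session"  # fallback when session_id is unset
assert ctx.search_priority is None  # only set when a session_id is provided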
+ """ + ctx = build_search_context(search_req=search_req, user_context=user_context) + return text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=mode, + manual_close_internet=not search_req.internet_search, + memory_type=search_req.search_memory_type, + search_filter=ctx.search_filter, + search_priority=ctx.search_priority, + info=ctx.info, + plugin=ctx.plugin, + search_tool_memory=search_req.search_tool_memory, + tool_mem_top_k=search_req.tool_mem_top_k, + include_skill_memory=search_req.include_skill_memory, + skill_mem_top_k=search_req.skill_mem_top_k, + dedup=search_req.dedup, + ) From 83b92faf5948cc97b6b4dbfce537e2b4550195dc Mon Sep 17 00:00:00 2001 From: fancy Date: Tue, 3 Feb 2026 20:00:10 +0800 Subject: [PATCH 02/14] fix: resolve ruff lint errors and address Copilot review feedback - Fix TC001/TC002/TC003: move type-only imports into TYPE_CHECKING blocks - Fix RUF059: prefix unused variables with underscore - Fix typos: "Memorires" -> "Memories", "exeption" -> "exception" - Remove self-assignment: `text_mem_base = text_mem_base` - Remove unused `user_context` param from `build_search_context` - Restore original `QUERY_TASK_LABEL` in activation memory update - Apply ruff format to all modified files --- .../mem_scheduler/base_mixins/__init__.py | 1 + .../mem_scheduler/base_mixins/memory_ops.py | 18 ++++++---- .../mem_scheduler/base_mixins/queue_ops.py | 25 ++++++++++---- .../mem_scheduler/base_mixins/web_log_ops.py | 4 ++- src/memos/mem_scheduler/base_scheduler.py | 30 ++++++++--------- src/memos/mem_scheduler/general_scheduler.py | 6 +++- src/memos/mem_scheduler/handlers/__init__.py | 1 + .../mem_scheduler/handlers/add_handler.py | 13 ++++++-- .../mem_scheduler/handlers/answer_handler.py | 6 +++- src/memos/mem_scheduler/handlers/base.py | 6 +++- src/memos/mem_scheduler/handlers/context.py | 13 +++++--- .../handlers/feedback_handler.py | 14 ++++++-- .../handlers/mem_read_handler.py | 30 +++++++++++++---- .../handlers/mem_reorganize_handler.py | 33 ++++++++++++------- .../handlers/memory_update_handler.py | 17 +++++++--- .../handlers/pref_add_handler.py | 10 ++++-- src/memos/mem_scheduler/handlers/registry.py | 9 +++-- .../enhancement_pipeline.py | 17 +++++++--- .../memory_manage_modules/filter_pipeline.py | 7 +++- .../memory_manage_modules/rerank_pipeline.py | 7 +++- .../memory_manage_modules/retriever.py | 7 +++- .../memory_manage_modules/search_pipeline.py | 2 +- .../mem_scheduler/optimized_scheduler.py | 2 +- .../task_schedule_modules/redis_queue.py | 12 +++---- src/memos/search/__init__.py | 1 + src/memos/search/search_service.py | 11 ++++--- 26 files changed, 210 insertions(+), 92 deletions(-) diff --git a/src/memos/mem_scheduler/base_mixins/__init__.py b/src/memos/mem_scheduler/base_mixins/__init__.py index 471d30f06..7e01cffc0 100644 --- a/src/memos/mem_scheduler/base_mixins/__init__.py +++ b/src/memos/mem_scheduler/base_mixins/__init__.py @@ -2,6 +2,7 @@ from .queue_ops import BaseSchedulerQueueMixin from .web_log_ops import BaseSchedulerWebLogMixin + __all__ = [ "BaseSchedulerMemoryMixin", "BaseSchedulerQueueMixin", diff --git a/src/memos/mem_scheduler/base_mixins/memory_ops.py b/src/memos/mem_scheduler/base_mixins/memory_ops.py index 5ad197a9e..87f284898 100644 --- a/src/memos/mem_scheduler/base_mixins/memory_ops.py +++ b/src/memos/mem_scheduler/base_mixins/memory_ops.py @@ -1,6 +1,7 @@ from __future__ import annotations from datetime import datetime +from typing import TYPE_CHECKING from memos.log import 
get_logger from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem @@ -11,7 +12,10 @@ from memos.memories.textual.naive import NaiveTextMemory from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE -from memos.types.general_types import MemCubeID, UserID + + +if TYPE_CHECKING: + from memos.types.general_types import MemCubeID, UserID logger = get_logger(__name__) @@ -65,8 +69,6 @@ def replace_working_memory( ) -> None | list[TextualMemoryItem]: text_mem_base = mem_cube.text_mem if isinstance(text_mem_base, TreeTextMemory): - text_mem_base = text_mem_base - query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] query_db_manager.sync_with_orm() @@ -238,7 +240,9 @@ def update_activation_memory( if original_composed_text_memory == new_text_memory: logger.warning( "Skipping memory update - new composition matches existing cache: %s", - new_text_memory[:50] + "..." if len(new_text_memory) > 50 else new_text_memory, + new_text_memory[:50] + "..." + if len(new_text_memory) > 50 + else new_text_memory, ) return act_mem.delete_all() @@ -300,7 +304,9 @@ def update_activation_memory_periodically( user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube ) - activation_db_manager = self.monitor.activation_memory_monitors[user_id][mem_cube_id] + activation_db_manager = self.monitor.activation_memory_monitors[user_id][ + mem_cube_id + ] activation_db_manager.sync_with_orm() new_activation_memories = [ m.memory_text for m in activation_db_manager.obj.memories @@ -312,7 +318,7 @@ def update_activation_memory_periodically( ) for i, memory in enumerate(new_activation_memories[:5], 1): logger.info( - "Part of New Activation Memorires | %s/%s: %s", + "Part of New Activation Memories | %s/%s: %s", i, len(new_activation_memories), memory[:20], diff --git a/src/memos/mem_scheduler/base_mixins/queue_ops.py b/src/memos/mem_scheduler/base_mixins/queue_ops.py index ffe230c84..e5709ff36 100644 --- a/src/memos/mem_scheduler/base_mixins/queue_ops.py +++ b/src/memos/mem_scheduler/base_mixins/queue_ops.py @@ -3,14 +3,20 @@ import multiprocessing import time -from collections.abc import Callable from contextlib import suppress from datetime import datetime, timezone - -from memos.context.context import ContextThread, RequestContext, get_current_context, get_current_trace_id, set_request_context +from typing import TYPE_CHECKING + +from memos.context.context import ( + ContextThread, + RequestContext, + get_current_context, + get_current_trace_id, + set_request_context, +) from memos.log import get_logger -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.general_schemas import STARTUP_BY_PROCESS +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import TaskPriorityLevel from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube @@ -19,6 +25,9 @@ logger = get_logger(__name__) +if TYPE_CHECKING: + from collections.abc import Callable + class BaseSchedulerQueueMixin: def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): @@ -174,7 +183,9 @@ def _message_consumer(self) -> None: msg, { "enqueue_ts": to_iso(enqueue_ts_obj), - "dequeue_ts": datetime.fromtimestamp(now, tz=timezone.utc).isoformat(), + "dequeue_ts": datetime.fromtimestamp( + now, tz=timezone.utc + 
).isoformat(), "queue_wait_ms": queue_wait_ms, "event_duration_ms": queue_wait_ms, "total_duration_ms": queue_wait_ms, @@ -324,7 +335,9 @@ def handlers(self) -> dict[str, Callable]: return self.dispatcher.handlers - def register_handlers(self, handlers: dict[str, Callable[[list[ScheduleMessageItem]], None]]) -> None: + def register_handlers( + self, handlers: dict[str, Callable[[list[ScheduleMessageItem]], None]] + ) -> None: if not self.dispatcher: logger.warning("Dispatcher is not initialized, cannot register handlers") return diff --git a/src/memos/mem_scheduler/base_mixins/web_log_ops.py b/src/memos/mem_scheduler/base_mixins/web_log_ops.py index beac47500..64b5348d3 100644 --- a/src/memos/mem_scheduler/base_mixins/web_log_ops.py +++ b/src/memos/mem_scheduler/base_mixins/web_log_ops.py @@ -98,7 +98,9 @@ def _normalize_item(item: ScheduleLogForWebItem) -> dict: def _with_memory_time(meta: dict) -> dict: enriched = dict(meta) if "memory_time" not in enriched: - enriched["memory_time"] = enriched.get("updated_at") or enriched.get("update_at") + enriched["memory_time"] = enriched.get("updated_at") or enriched.get( + "update_at" + ) return enriched data["metadata"] = [_with_memory_time(m) for m in metadata] diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index d14248f02..2cb104343 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -1,21 +1,18 @@ +from __future__ import annotations + import os import threading from pathlib import Path -from typing import TYPE_CHECKING, Union - -from sqlalchemy.engine import Engine +from typing import TYPE_CHECKING from memos.configs.mem_scheduler import AuthConfig, BaseSchedulerConfig +from memos.log import get_logger from memos.mem_scheduler.base_mixins import ( BaseSchedulerMemoryMixin, BaseSchedulerQueueMixin, BaseSchedulerWebLogMixin, ) -from memos.llms.base import BaseLLM -from memos.log import get_logger -from memos.mem_cube.base import BaseMemCube -from memos.mem_feedback.simple_feedback import SimpleMemFeedback from memos.mem_scheduler.general_modules.init_components_for_scheduler import init_components from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule @@ -36,7 +33,6 @@ DEFAULT_USE_REDIS_QUEUE, TreeTextMemory_SEARCH_METHOD, ) -from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue @@ -44,18 +40,21 @@ from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker from memos.mem_scheduler.webservice_modules.rabbitmq_service import RabbitMQSchedulerModule from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule -from memos.memories.textual.tree import TreeTextMemory -from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher -from memos.types.general_types import ( - MemCubeID, - UserID, -) if TYPE_CHECKING: import redis + from sqlalchemy.engine import Engine + + from memos.llms.base import BaseLLM + from memos.mem_cube.base import BaseMemCube + from memos.mem_feedback.simple_feedback import SimpleMemFeedback + from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem + from 
memos.memories.textual.tree import TreeTextMemory + from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.http_bge import HTTPBGEReranker + from memos.types.general_types import MemCubeID, UserID logger = get_logger(__name__) @@ -194,7 +193,7 @@ def initialize_modules( process_llm: BaseLLM | None = None, db_engine: Engine | None = None, mem_reader=None, - redis_client: Union["redis.Redis", None] = None, + redis_client: redis.Redis | None = None, ): if process_llm is None: process_llm = chat_llm @@ -366,5 +365,4 @@ def mem_cubes(self, value: dict[str, BaseMemCube]) -> None: f"Failed to initialize current_mem_cube from mem_cubes: {e}", exc_info=True ) - # Methods moved to mixins in mem_scheduler.base_mixins. diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 66801def6..1fc3317d8 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -1,6 +1,10 @@ from __future__ import annotations -from memos.configs.mem_scheduler import GeneralSchedulerConfig +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from memos.configs.mem_scheduler import GeneralSchedulerConfig from memos.mem_scheduler.base_scheduler import BaseScheduler from memos.mem_scheduler.handlers import ( SchedulerHandlerContext, diff --git a/src/memos/mem_scheduler/handlers/__init__.py b/src/memos/mem_scheduler/handlers/__init__.py index 283740d1c..75c56791a 100644 --- a/src/memos/mem_scheduler/handlers/__init__.py +++ b/src/memos/mem_scheduler/handlers/__init__.py @@ -1,6 +1,7 @@ from .context import SchedulerHandlerContext, SchedulerHandlerServices from .registry import SchedulerHandlerRegistry + __all__ = [ "SchedulerHandlerContext", "SchedulerHandlerRegistry", diff --git a/src/memos/mem_scheduler/handlers/add_handler.py b/src/memos/mem_scheduler/handlers/add_handler.py index 900550952..5d1a8d3e0 100644 --- a/src/memos/mem_scheduler/handlers/add_handler.py +++ b/src/memos/mem_scheduler/handlers/add_handler.py @@ -2,9 +2,10 @@ import json +from typing import TYPE_CHECKING + from memos.log import get_logger from memos.mem_scheduler.handlers.base import BaseSchedulerHandler -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import ( ADD_TASK_LABEL, LONG_TERM_MEMORY_TYPE, @@ -12,7 +13,11 @@ ) from memos.mem_scheduler.utils.filter_utils import transform_name_to_key from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube, is_cloud_env -from memos.memories.textual.item import TextualMemoryItem + + +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + from memos.memories.textual.item import TextualMemoryItem logger = get_logger(__name__) @@ -233,7 +238,9 @@ def send_add_log_messages_to_local_env( events.append(event) logger.info("send_add_log_messages_to_local_env: %s", len(events)) if events: - self.ctx.services.submit_web_logs(events, additional_log_info="send_add_log_messages_to_cloud_env") + self.ctx.services.submit_web_logs( + events, additional_log_info="send_add_log_messages_to_cloud_env" + ) def send_add_log_messages_to_cloud_env( self, diff --git a/src/memos/mem_scheduler/handlers/answer_handler.py b/src/memos/mem_scheduler/handlers/answer_handler.py index aa6bf708e..9ec4086a4 100644 --- a/src/memos/mem_scheduler/handlers/answer_handler.py +++ b/src/memos/mem_scheduler/handlers/answer_handler.py @@ -1,8 +1,9 @@ from 
__future__ import annotations +from typing import TYPE_CHECKING + from memos.log import get_logger from memos.mem_scheduler.handlers.base import BaseSchedulerHandler -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import ( ANSWER_TASK_LABEL, NOT_APPLICABLE_TYPE, @@ -13,6 +14,9 @@ logger = get_logger(__name__) +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + class AnswerMessageHandler(BaseSchedulerHandler): def handle(self, messages: list[ScheduleMessageItem]) -> None: diff --git a/src/memos/mem_scheduler/handlers/base.py b/src/memos/mem_scheduler/handlers/base.py index f0d8246a3..e04add7d7 100644 --- a/src/memos/mem_scheduler/handlers/base.py +++ b/src/memos/mem_scheduler/handlers/base.py @@ -1,6 +1,10 @@ from __future__ import annotations -from memos.mem_scheduler.handlers.context import SchedulerHandlerContext +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from memos.mem_scheduler.handlers.context import SchedulerHandlerContext class BaseSchedulerHandler: diff --git a/src/memos/mem_scheduler/handlers/context.py b/src/memos/mem_scheduler/handlers/context.py index 848739fa3..d5c1ea9af 100644 --- a/src/memos/mem_scheduler/handlers/context.py +++ b/src/memos/mem_scheduler/handlers/context.py @@ -1,12 +1,15 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any -from collections.abc import Callable +from typing import TYPE_CHECKING, Any -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem -from memos.memories.textual.item import TextualMemoryItem + +if TYPE_CHECKING: + from collections.abc import Callable + + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem + from memos.memories.textual.item import TextualMemoryItem @dataclass(frozen=True) diff --git a/src/memos/mem_scheduler/handlers/feedback_handler.py b/src/memos/mem_scheduler/handlers/feedback_handler.py index 55dbd6add..cf52470dd 100644 --- a/src/memos/mem_scheduler/handlers/feedback_handler.py +++ b/src/memos/mem_scheduler/handlers/feedback_handler.py @@ -2,15 +2,19 @@ import json +from typing import TYPE_CHECKING + from memos.log import get_logger from memos.mem_scheduler.handlers.base import BaseSchedulerHandler -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import LONG_TERM_MEMORY_TYPE, USER_INPUT_TYPE from memos.mem_scheduler.utils.misc_utils import is_cloud_env logger = get_logger(__name__) +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + class FeedbackMessageHandler(BaseSchedulerHandler): def handle(self, messages: list[ScheduleMessageItem]) -> None: @@ -27,10 +31,14 @@ def handle(self, messages: list[ScheduleMessageItem]) -> None: try: feedback_data = json.loads(content) if isinstance(content, str) else content if not isinstance(feedback_data, dict): - logger.error("Failed to decode feedback_data or it is not a dict: %s", feedback_data) + logger.error( + "Failed to decode feedback_data or it is not a dict: %s", feedback_data + ) return except json.JSONDecodeError: - logger.error("Invalid JSON content for feedback message: %s", content, exc_info=True) + logger.error( + "Invalid JSON content for feedback message: %s", content, 
exc_info=True + ) return task_id = feedback_data.get("task_id") or message.task_id diff --git a/src/memos/mem_scheduler/handlers/mem_read_handler.py b/src/memos/mem_scheduler/handlers/mem_read_handler.py index cb1c7631c..76789f113 100644 --- a/src/memos/mem_scheduler/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/handlers/mem_read_handler.py @@ -5,10 +5,11 @@ import json import traceback +from typing import TYPE_CHECKING + from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_scheduler.handlers.base import BaseSchedulerHandler -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import ( LONG_TERM_MEMORY_TYPE, MEM_READ_TASK_LABEL, @@ -21,6 +22,9 @@ logger = get_logger(__name__) +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + class MemReadMessageHandler(BaseSchedulerHandler): def handle(self, messages: list[ScheduleMessageItem]) -> None: @@ -117,7 +121,9 @@ def _process_memories_with_reader( try: mem_reader = self.ctx.get_mem_reader() if mem_reader is None: - logger.warning("mem_reader not available in scheduler, skipping enhanced processing") + logger.warning( + "mem_reader not available in scheduler, skipping enhanced processing" + ) return memory_items = [] @@ -126,7 +132,9 @@ def _process_memories_with_reader( memory_item = text_mem.get(mem_id, user_name=user_name) memory_items.append(memory_item) except Exception as e: - logger.warning("[_process_memories_with_reader] Failed to get memory %s: %s", mem_id, e) + logger.warning( + "[_process_memories_with_reader] Failed to get memory %s: %s", mem_id, e + ) continue if not memory_items: @@ -212,7 +220,9 @@ def _process_memories_with_reader( for item in flattened_memories: metadata = getattr(item, "metadata", None) file_ids = getattr(metadata, "file_ids", None) if metadata else None - source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None + source_doc_id = ( + file_ids[0] if isinstance(file_ids, list) and file_ids else None + ) kb_log_content.append( { "log_source": "KNOWLEDGE_BASE_LOG", @@ -254,7 +264,9 @@ def _process_memories_with_reader( else: add_content_legacy: list[dict] = [] add_meta_legacy: list[dict] = [] - for item_id, item in zip(enhanced_mem_ids, flattened_memories, strict=False): + for item_id, item in zip( + enhanced_mem_ids, flattened_memories, strict=False + ): key = getattr(item.metadata, "key", None) or transform_name_to_key( name=item.memory ) @@ -302,7 +314,9 @@ def _process_memories_with_reader( if delete_ids: try: text_mem.delete(delete_ids, user_name=user_name) - logger.info("Delete raw/working mem_ids: %s for user_name: %s", delete_ids, user_name) + logger.info( + "Delete raw/working mem_ids: %s for user_name: %s", delete_ids, user_name + ) except Exception as e: logger.warning("Failed to delete some mem_ids %s: %s", delete_ids, e) else: @@ -322,7 +336,9 @@ def _process_memories_with_reader( cloud_env = is_cloud_env() if cloud_env: if not kb_log_content: - trigger_source = info.get("trigger_source", "Messages") if info else "Messages" + trigger_source = ( + info.get("trigger_source", "Messages") if info else "Messages" + ) kb_log_content = [ { "log_source": "KNOWLEDGE_BASE_LOG", diff --git a/src/memos/mem_scheduler/handlers/mem_reorganize_handler.py b/src/memos/mem_scheduler/handlers/mem_reorganize_handler.py index ce320fc8d..d437ebbd6 100644 --- a/src/memos/mem_scheduler/handlers/mem_reorganize_handler.py 
+++ b/src/memos/mem_scheduler/handlers/mem_reorganize_handler.py @@ -5,18 +5,22 @@ import json import traceback +from typing import TYPE_CHECKING + from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_scheduler.handlers.base import BaseSchedulerHandler -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import LONG_TERM_MEMORY_TYPE, MEM_ORGANIZE_TASK_LABEL from memos.mem_scheduler.utils.filter_utils import transform_name_to_key -from memos.memories.textual.item import TextualMemoryItem from memos.memories.textual.tree import TreeTextMemory logger = get_logger(__name__) +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + from memos.memories.textual.item import TextualMemoryItem + class MemReorganizeMessageHandler(BaseSchedulerHandler): def handle(self, messages: list[ScheduleMessageItem]) -> None: @@ -80,16 +84,14 @@ def process_message(message: ScheduleMessageItem): ) for edge in edges: target = ( - edge.get("to") - or edge.get("dst") - or edge.get("target") + edge.get("to") or edge.get("dst") or edge.get("target") ) if target: merged_target_ids.add(target) for item in mem_items: - key = getattr(getattr(item, "metadata", {}), "key", None) or transform_name_to_key( - getattr(item, "memory", "") - ) + key = getattr( + getattr(item, "metadata", {}), "key", None + ) or transform_name_to_key(getattr(item, "memory", "")) keys.append(key) memcube_content.append( {"content": key or "(no key)", "ref_id": item.id, "type": "merged"} @@ -132,7 +134,9 @@ def process_message(message: ScheduleMessageItem): post_meta = { "ref_id": post_ref_id, "id": post_ref_id, - "key": getattr(getattr(merged_item, "metadata", {}), "key", None), + "key": getattr( + getattr(merged_item, "metadata", {}), "key", None + ), "memory": getattr(merged_item, "memory", None), "memory_type": getattr( getattr(merged_item, "metadata", {}), "memory_type", None @@ -157,7 +161,8 @@ def process_message(message: ScheduleMessageItem): import hashlib post_ref_id = ( - "merge-" + hashlib.md5("".join(sorted(mem_ids)).encode()).hexdigest() + "merge-" + + hashlib.md5("".join(sorted(mem_ids)).encode()).hexdigest() ) post_meta["ref_id"] = post_ref_id post_meta["id"] = post_ref_id @@ -216,7 +221,9 @@ def _process_memories_with_reorganize( try: mem_reader = self.ctx.get_mem_reader() if mem_reader is None: - logger.warning("mem_reader not available in scheduler, skipping enhanced processing") + logger.warning( + "mem_reader not available in scheduler, skipping enhanced processing" + ) return memory_items = [] @@ -225,7 +232,9 @@ def _process_memories_with_reorganize( memory_item = text_mem.get(mem_id, user_name=user_name) memory_items.append(memory_item) except Exception as e: - logger.warning("Failed to get memory %s: %s|%s", mem_id, e, traceback.format_exc()) + logger.warning( + "Failed to get memory %s: %s|%s", mem_id, e, traceback.format_exc() + ) continue if not memory_items: diff --git a/src/memos/mem_scheduler/handlers/memory_update_handler.py b/src/memos/mem_scheduler/handlers/memory_update_handler.py index 0775fc91a..0d3d1719e 100644 --- a/src/memos/mem_scheduler/handlers/memory_update_handler.py +++ b/src/memos/mem_scheduler/handlers/memory_update_handler.py @@ -1,23 +1,28 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from memos.log import get_logger from memos.mem_scheduler.handlers.base import BaseSchedulerHandler -from 
memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem from memos.mem_scheduler.schemas.task_schemas import ( DEFAULT_MAX_QUERY_KEY_WORDS, MEM_UPDATE_TASK_LABEL, + QUERY_TASK_LABEL, ) from memos.mem_scheduler.utils.filter_utils import is_all_chinese, is_all_english from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube -from memos.memories.textual.item import TextualMemoryItem from memos.memories.textual.naive import NaiveTextMemory from memos.memories.textual.tree import TreeTextMemory -from memos.types import MemCubeID, UserID logger = get_logger(__name__) +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + from memos.memories.textual.item import TextualMemoryItem + from memos.types import MemCubeID, UserID + class MemoryUpdateHandler(BaseSchedulerHandler): def handle(self, messages: list[ScheduleMessageItem]) -> None: @@ -159,7 +164,7 @@ def long_memory_update_process( if self.ctx.get_enable_activation_memory(): self.ctx.services.update_activation_memory_periodically( interval_seconds=monitor.act_mem_update_interval, - label=MEM_UPDATE_TASK_LABEL, + label=QUERY_TASK_LABEL, user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube, @@ -203,7 +208,9 @@ def process_session_turn( text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory] monitor = self.ctx.get_monitor() - intent_result = monitor.detect_intent(q_list=queries, text_working_memory=text_working_memory) + intent_result = monitor.detect_intent( + q_list=queries, text_working_memory=text_working_memory + ) time_trigger_flag = False if monitor.timed_trigger( diff --git a/src/memos/mem_scheduler/handlers/pref_add_handler.py b/src/memos/mem_scheduler/handlers/pref_add_handler.py index 195b35385..4d17b0847 100644 --- a/src/memos/mem_scheduler/handlers/pref_add_handler.py +++ b/src/memos/mem_scheduler/handlers/pref_add_handler.py @@ -3,16 +3,20 @@ import concurrent.futures import json +from typing import TYPE_CHECKING + from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_scheduler.handlers.base import BaseSchedulerHandler -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import PREF_ADD_TASK_LABEL from memos.memories.textual.preference import PreferenceTextMemory logger = get_logger(__name__) +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + class PrefAddMessageHandler(BaseSchedulerHandler): def handle(self, messages: list[ScheduleMessageItem]) -> None: @@ -36,7 +40,9 @@ def process_message(message: ScheduleMessageItem): messages_list = json.loads(content) info = message.info or {} - logger.info("Processing pref_add for user_id=%s, mem_cube_id=%s", user_id, mem_cube_id) + logger.info( + "Processing pref_add for user_id=%s, mem_cube_id=%s", user_id, mem_cube_id + ) pref_mem = mem_cube.pref_mem if pref_mem is None: diff --git a/src/memos/mem_scheduler/handlers/registry.py b/src/memos/mem_scheduler/handlers/registry.py index 1e1db0404..2a62aa57f 100644 --- a/src/memos/mem_scheduler/handlers/registry.py +++ b/src/memos/mem_scheduler/handlers/registry.py @@ -1,10 +1,15 @@ from __future__ import annotations -from collections.abc import Callable +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Callable + + from memos.mem_scheduler.handlers.context 
import SchedulerHandlerContext from memos.mem_scheduler.handlers.add_handler import AddMessageHandler from memos.mem_scheduler.handlers.answer_handler import AnswerMessageHandler -from memos.mem_scheduler.handlers.context import SchedulerHandlerContext from memos.mem_scheduler.handlers.feedback_handler import FeedbackMessageHandler from memos.mem_scheduler.handlers.mem_read_handler import MemReadMessageHandler from memos.mem_scheduler.handlers.mem_reorganize_handler import MemReorganizeMessageHandler diff --git a/src/memos/mem_scheduler/memory_manage_modules/enhancement_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/enhancement_pipeline.py index 5dd5e95d3..98125c13b 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/enhancement_pipeline.py +++ b/src/memos/mem_scheduler/memory_manage_modules/enhancement_pipeline.py @@ -1,7 +1,8 @@ from __future__ import annotations import time -from collections.abc import Callable + +from typing import TYPE_CHECKING from memos.log import get_logger from memos.mem_scheduler.schemas.general_schemas import ( @@ -15,6 +16,9 @@ logger = get_logger(__name__) +if TYPE_CHECKING: + from collections.abc import Callable + class EnhancementPipeline: def __init__(self, process_llm, config, build_prompt: Callable[..., str]): @@ -114,6 +118,7 @@ def _process_enhancement_batch( ) ) elif FINE_STRATEGY == FineStrategy.REWRITE: + def _parse_index_and_text(s: str) -> tuple[int | None, str]: import re @@ -153,8 +158,9 @@ def _parse_index_and_text(s: str) -> tuple[int | None, str]: ) return enhanced_memories, True raise ValueError( - "Fail to run memory enhancement; retry %s/%s; processed_text_memories: %s" - % (attempt, max(1, retries) + 1, processed_text_memories) + "Fail to run memory enhancement; retry " + f"{attempt}/{max(1, retries) + 1}; " + f"processed_text_memories: {processed_text_memories}" ) except Exception as e: attempt += 1 @@ -237,9 +243,10 @@ def enhance_memories_with_query( all_success = True failed_batches = 0 - from memos.context.context import ContextThreadPoolExecutor from concurrent.futures import as_completed + from memos.context.context import ContextThreadPoolExecutor + with ContextThreadPoolExecutor(max_workers=len(batches)) as executor: future_map = { executor.submit( @@ -249,7 +256,7 @@ def enhance_memories_with_query( } enhanced_memories = [] for fut in as_completed(future_map): - bi, s, e = future_map[fut] + _bi, _s, _e = future_map[fut] batch_memories, ok = fut.result() enhanced_memories.extend(batch_memories) diff --git a/src/memos/mem_scheduler/memory_manage_modules/filter_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/filter_pipeline.py index 8bebe2456..315f821a9 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/filter_pipeline.py +++ b/src/memos/mem_scheduler/memory_manage_modules/filter_pipeline.py @@ -1,7 +1,12 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from memos.mem_scheduler.memory_manage_modules.memory_filter import MemoryFilter -from memos.memories.textual.tree import TextualMemoryItem + + +if TYPE_CHECKING: + from memos.memories.textual.tree import TextualMemoryItem class FilterPipeline: diff --git a/src/memos/mem_scheduler/memory_manage_modules/rerank_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/rerank_pipeline.py index 21dabedd9..0e347df6a 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/rerank_pipeline.py +++ b/src/memos/mem_scheduler/memory_manage_modules/rerank_pipeline.py @@ -1,5 +1,7 @@ from __future__ import annotations +from 
typing import TYPE_CHECKING + from memos.log import get_logger from memos.mem_scheduler.utils.filter_utils import ( filter_too_short_memories, @@ -7,7 +9,10 @@ transform_name_to_key, ) from memos.mem_scheduler.utils.misc_utils import extract_json_obj -from memos.memories.textual.item import TextualMemoryItem + + +if TYPE_CHECKING: + from memos.memories.textual.item import TextualMemoryItem logger = get_logger(__name__) diff --git a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py index 41e268cef..3e849f470 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py +++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py @@ -1,12 +1,17 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule from memos.mem_scheduler.memory_manage_modules.enhancement_pipeline import EnhancementPipeline from memos.mem_scheduler.memory_manage_modules.filter_pipeline import FilterPipeline from memos.mem_scheduler.memory_manage_modules.rerank_pipeline import RerankPipeline from memos.mem_scheduler.memory_manage_modules.search_pipeline import SearchPipeline -from memos.memories.textual.item import TextualMemoryItem + + +if TYPE_CHECKING: + from memos.memories.textual.item import TextualMemoryItem logger = get_logger(__name__) diff --git a/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py index 65496b478..a346622c5 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py +++ b/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py @@ -89,6 +89,6 @@ def search( else: raise NotImplementedError(str(type(text_mem_base))) except Exception as e: - logger.error("Fail to search. The exeption is %s.", e, exc_info=True) + logger.error("Fail to search. 
The exception is %s.", e, exc_info=True) results = [] return results diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index 0e7390ed5..e535d6f73 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -142,7 +142,7 @@ def mix_search_memories( ] # Get mem_cube for fast search - search_ctx = build_search_context(search_req=search_req, user_context=user_context) + search_ctx = build_search_context(search_req=search_req) search_priority = search_ctx.search_priority search_filter = search_ctx.search_filter diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index d570dccdd..dc7d86752 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -718,10 +718,10 @@ def _claim_pending_messages( justid=False, ) if len(claimed_result) == 2: - next_id, claimed = claimed_result - deleted_ids = [] + _next_id, claimed = claimed_result + _deleted_ids = [] elif len(claimed_result) == 3: - next_id, claimed, deleted_ids = claimed_result + _next_id, claimed, _deleted_ids = claimed_result else: raise ValueError( f"Unexpected xautoclaim response length: {len(claimed_result)}" @@ -745,10 +745,10 @@ def _claim_pending_messages( justid=False, ) if len(claimed_result) == 2: - next_id, claimed = claimed_result - deleted_ids = [] + _next_id, claimed = claimed_result + _deleted_ids = [] elif len(claimed_result) == 3: - next_id, claimed, deleted_ids = claimed_result + _next_id, claimed, _deleted_ids = claimed_result else: raise ValueError( f"Unexpected xautoclaim response length: {len(claimed_result)}" diff --git a/src/memos/search/__init__.py b/src/memos/search/__init__.py index 1fa4e6819..71388c62b 100644 --- a/src/memos/search/__init__.py +++ b/src/memos/search/__init__.py @@ -1,3 +1,4 @@ from .search_service import SearchContext, build_search_context, search_text_memories + __all__ = ["SearchContext", "build_search_context", "search_text_memories"] diff --git a/src/memos/search/search_service.py b/src/memos/search/search_service.py index 9f6280355..79c9a43e5 100644 --- a/src/memos/search/search_service.py +++ b/src/memos/search/search_service.py @@ -1,10 +1,12 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any -from memos.api.product_models import APISearchRequest -from memos.types import SearchMode, UserContext + +if TYPE_CHECKING: + from memos.api.product_models import APISearchRequest + from memos.types import SearchMode, UserContext @dataclass(frozen=True) @@ -18,7 +20,6 @@ class SearchContext: def build_search_context( search_req: APISearchRequest, - user_context: UserContext, ) -> SearchContext: target_session_id = search_req.session_id or "default_session" search_priority = {"session_id": search_req.session_id} if search_req.session_id else None @@ -44,7 +45,7 @@ def search_text_memories( """ Shared text-memory search logic for API and scheduler paths. 
""" - ctx = build_search_context(search_req=search_req, user_context=user_context) + ctx = build_search_context(search_req=search_req) return text_mem.search( query=search_req.query, user_name=user_context.mem_cube_id, From fdb7374ed46a06d0c208f21ad2bbe67369fb432d Mon Sep 17 00:00:00 2001 From: fancy Date: Wed, 4 Feb 2026 16:03:38 +0800 Subject: [PATCH 03/14] fix(redis): serialize schedule messages for streams - json-encode list/dict fields for Redis XADD - decode chat_history safely when reading from streams --- .../mem_scheduler/schemas/message_schemas.py | 56 +++++++++++++++---- 1 file changed, 45 insertions(+), 11 deletions(-) diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index c7f270f19..f08c8a067 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -1,3 +1,5 @@ +import json + from datetime import datetime from typing import Any from uuid import uuid4 @@ -79,7 +81,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): def to_dict(self) -> dict: """Convert model to dictionary suitable for Redis Stream""" - return { + raw = { "item_id": self.item_id, "user_id": self.user_id, "cube_id": self.mem_cube_id, @@ -92,21 +94,53 @@ def to_dict(self) -> dict: "task_id": self.task_id if self.task_id is not None else "", "chat_history": self.chat_history if self.chat_history is not None else [], } + return {key: self._serialize_redis_value(value) for key, value in raw.items()} + + @staticmethod + def _serialize_redis_value(value: Any) -> Any: + if value is None: + return "" + if isinstance(value, (list, dict)): + return json.dumps(value, ensure_ascii=False) + return value @classmethod def from_dict(cls, data: dict) -> "ScheduleMessageItem": """Create model from Redis Stream dictionary""" + def _decode(val: Any) -> Any: + if isinstance(val, (bytes, bytearray)): + return val.decode("utf-8") + return val + + raw_chat_history = _decode(data.get("chat_history")) + if isinstance(raw_chat_history, str): + if raw_chat_history: + try: + chat_history = json.loads(raw_chat_history) + except Exception: + chat_history = None + else: + chat_history = None + else: + chat_history = raw_chat_history + + raw_timestamp = _decode(data.get("timestamp")) + timestamp = ( + datetime.fromisoformat(raw_timestamp) + if raw_timestamp + else get_utc_now() + ) return cls( - item_id=data.get("item_id", str(uuid4())), - user_id=data["user_id"], - mem_cube_id=data["cube_id"], - trace_id=data.get("trace_id", generate_trace_id()), - label=data["label"], - content=data["content"], - timestamp=datetime.fromisoformat(data["timestamp"]), - user_name=data.get("user_name"), - task_id=data.get("task_id"), - chat_history=data.get("chat_history"), + item_id=_decode(data.get("item_id", str(uuid4()))), + user_id=_decode(data["user_id"]), + mem_cube_id=_decode(data["cube_id"]), + trace_id=_decode(data.get("trace_id", generate_trace_id())), + label=_decode(data["label"]), + content=_decode(data["content"]), + timestamp=timestamp, + user_name=_decode(data.get("user_name")), + task_id=_decode(data.get("task_id")), + chat_history=chat_history, ) From 6e3333be367ab935b72bf61be31af639c69782ed Mon Sep 17 00:00:00 2001 From: fancy Date: Wed, 4 Feb 2026 16:04:16 +0800 Subject: [PATCH 04/14] fix(examples): use scheduler handlers - avoid private _memory_update_consumer call - delegate mem update handler to built-in handler --- examples/mem_scheduler/memos_w_scheduler.py | 27 ++++++------------- 
.../mem_scheduler/try_schedule_modules.py | 9 ++++--- 2 files changed, 13 insertions(+), 23 deletions(-) diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py index b7250a677..f3c1c7f87 100644 --- a/examples/mem_scheduler/memos_w_scheduler.py +++ b/examples/mem_scheduler/memos_w_scheduler.py @@ -84,7 +84,9 @@ def init_task(): return conversations, questions -working_memories = [] +default_mem_update_handler = mem_scheduler.handlers.get(MEM_UPDATE_TASK_LABEL) +if default_mem_update_handler is None: + logger.warning("Default MEM_UPDATE handler not found; custom handler will be a no-op.") # Define custom query handler function @@ -100,24 +102,11 @@ def custom_query_handler(messages: list[ScheduleMessageItem]): # Define custom memory update handler function def custom_mem_update_handler(messages: list[ScheduleMessageItem]): - global working_memories - search_args = {} - top_k = 2 - for msg in messages: - # Search for memories relevant to the current content in text memory (return top_k=2) - results = mem_scheduler.retriever.search( - query=msg.content, - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - mem_cube=mem_scheduler.current_mem_cube, - top_k=top_k, - method=mem_scheduler.search_method, - search_args=search_args, - ) - working_memories.extend(results) - working_memories = working_memories[-5:] - for mem in results: - print(f"\n[scheduler] Retrieved memory: {mem.memory}") + if default_mem_update_handler is None: + logger.error("Default MEM_UPDATE handler missing; cannot process messages.") + return + # Delegate to the built-in handler to keep behavior aligned with scheduler refactor. + default_mem_update_handler(messages) async def run_with_scheduler(): diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py index d942aad4e..51d87b1aa 100644 --- a/examples/mem_scheduler/try_schedule_modules.py +++ b/examples/mem_scheduler/try_schedule_modules.py @@ -167,7 +167,8 @@ def add_msgs( label=MEM_UPDATE_TASK_LABEL, content=query, ) - # Run one session turn manually to get search candidates - mem_scheduler._memory_update_consumer( - messages=[message], - ) + # Run one session turn manually via registered handler (public surface) + handler = mem_scheduler.handlers.get(MEM_UPDATE_TASK_LABEL) + if handler is None: + raise RuntimeError("MEM_UPDATE handler not registered on mem_scheduler.") + handler([message]) From ed039231724008f9db7c19c0318fda5285169cb4 Mon Sep 17 00:00:00 2001 From: fancy Date: Wed, 4 Feb 2026 16:18:30 +0800 Subject: [PATCH 05/14] style(ruff): fix isinstance union syntax - apply X | Y form to satisfy UP038 --- src/memos/mem_scheduler/schemas/message_schemas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index f08c8a067..979061383 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -100,7 +100,7 @@ def to_dict(self) -> dict: def _serialize_redis_value(value: Any) -> Any: if value is None: return "" - if isinstance(value, (list, dict)): + if isinstance(value, list | dict): return json.dumps(value, ensure_ascii=False) return value @@ -108,7 +108,7 @@ def _serialize_redis_value(value: Any) -> Any: def from_dict(cls, data: dict) -> "ScheduleMessageItem": """Create model from Redis Stream dictionary""" def _decode(val: Any) -> Any: - if isinstance(val, (bytes, bytearray)): + if 
isinstance(val, bytes | bytearray):
             return val.decode("utf-8")
         return val

From 5c356ddf3314cf0adc22d2d8bbbb18489a0a40d0 Mon Sep 17 00:00:00 2001
From: fancy
Date: Wed, 4 Feb 2026 16:32:30 +0800
Subject: [PATCH 06/14] style(ruff): format message schemas

- align datetime line breaks with ruff format

---
 src/memos/mem_scheduler/schemas/message_schemas.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py
index 979061383..869fcc91f 100644
--- a/src/memos/mem_scheduler/schemas/message_schemas.py
+++ b/src/memos/mem_scheduler/schemas/message_schemas.py
@@ -107,6 +107,7 @@ def _serialize_redis_value(value: Any) -> Any:
     @classmethod
     def from_dict(cls, data: dict) -> "ScheduleMessageItem":
         """Create model from Redis Stream dictionary"""
+
         def _decode(val: Any) -> Any:
             if isinstance(val, bytes | bytearray):
                 return val.decode("utf-8")
             return val
@@ -125,11 +126,7 @@ def _decode(val: Any) -> Any:
             chat_history = raw_chat_history
 
         raw_timestamp = _decode(data.get("timestamp"))
-        timestamp = (
-            datetime.fromisoformat(raw_timestamp)
-            if raw_timestamp
-            else get_utc_now()
-        )
+        timestamp = datetime.fromisoformat(raw_timestamp) if raw_timestamp else get_utc_now()
         return cls(
             item_id=_decode(data.get("item_id", str(uuid4()))),
             user_id=_decode(data["user_id"]),

From f8dcb6ae8c4db1a6c3481b4a019fd16355039431 Mon Sep 17 00:00:00 2001
From: chentang
Date: Wed, 4 Feb 2026 18:47:55 +0800
Subject: [PATCH 07/14] refactor: integrate architectural improvements from refactor-scheduler-stage2

This commit merges key modularization benefits and bug fixes from the
refactor-scheduler-stage2 branch into fancy-scheduler:

- Modularize activation memory logic into ActivationMemoryManager
- Introduce SchedulerSearchService for unified memory search coordination
- Extract filtering and reranking logic into MemoryPostProcessor
- Maintain intentional search scope of LongTermMemory and UserMemory in
  SearchPipeline and SchedulerSearchService
- Update BaseScheduler to initialize and manage lifecycle of new modules
- Refactor BaseSchedulerMemoryMixin to delegate tasks to specialized managers

---
 .../mem_scheduler/base_mixins/memory_ops.py | 153 +--------
 src/memos/mem_scheduler/base_scheduler.py | 69 ++++
 .../activation_memory_manager.py | 186 +++++++++++
 .../memory_manage_modules/post_processor.py | 307 ++++++++++++++++++
 .../memory_manage_modules/search_service.py | 297 +++++++++++++++++
 5 files changed, 874 insertions(+), 138 deletions(-)
 create mode 100644 src/memos/mem_scheduler/memory_manage_modules/activation_memory_manager.py
 create mode 100644 src/memos/mem_scheduler/memory_manage_modules/post_processor.py
 create mode 100644 src/memos/mem_scheduler/memory_manage_modules/search_service.py

diff --git a/src/memos/mem_scheduler/base_mixins/memory_ops.py b/src/memos/mem_scheduler/base_mixins/memory_ops.py
index 87f284898..35e095422 100644
--- a/src/memos/mem_scheduler/base_mixins/memory_ops.py
+++ b/src/memos/mem_scheduler/base_mixins/memory_ops.py
@@ -1,17 +1,12 @@
 from __future__ import annotations
 
-from datetime import datetime
 from typing import TYPE_CHECKING
 
 from memos.log import get_logger
 from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
-from memos.mem_scheduler.utils.db_utils import get_utc_now
 from memos.mem_scheduler.utils.filter_utils import transform_name_to_key
-from memos.memories.activation.kv import KVCacheMemory
-from memos.memories.activation.vllmkv import VLLMKVCacheItem, 
VLLMKVCacheMemory from memos.memories.textual.naive import NaiveTextMemory from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory -from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE if TYPE_CHECKING: @@ -201,71 +196,16 @@ def update_activation_memory( mem_cube_id: MemCubeID | str, mem_cube, ) -> None: - if len(new_memories) == 0: - logger.error("update_activation_memory: new_memory is empty.") - return - if isinstance(new_memories[0], TextualMemoryItem): - new_text_memories = [mem.memory for mem in new_memories] - elif isinstance(new_memories[0], str): - new_text_memories = new_memories - else: - logger.error("Not Implemented.") - return - - try: - if isinstance(mem_cube.act_mem, VLLMKVCacheMemory): - act_mem: VLLMKVCacheMemory = mem_cube.act_mem - elif isinstance(mem_cube.act_mem, KVCacheMemory): - act_mem = mem_cube.act_mem - else: - logger.error("Not Implemented.") - return - - new_text_memory = MEMORY_ASSEMBLY_TEMPLATE.format( - memory_text="".join( - [ - f"{i + 1}. {sentence.strip()}\n" - for i, sentence in enumerate(new_text_memories) - if sentence.strip() - ] - ) - ) - - original_cache_items: list[VLLMKVCacheItem] = act_mem.get_all() - original_text_memories = [] - if len(original_cache_items) > 0: - pre_cache_item: VLLMKVCacheItem = original_cache_items[-1] - original_text_memories = pre_cache_item.records.text_memories - original_composed_text_memory = pre_cache_item.records.composed_text_memory - if original_composed_text_memory == new_text_memory: - logger.warning( - "Skipping memory update - new composition matches existing cache: %s", - new_text_memory[:50] + "..." - if len(new_text_memory) > 50 - else new_text_memory, - ) - return - act_mem.delete_all() - - cache_item = act_mem.extract(new_text_memory) - cache_item.records.text_memories = new_text_memories - cache_item.records.timestamp = get_utc_now() - - act_mem.add([cache_item]) - act_mem.dump(self.act_mem_dump_path) - - self.log_activation_memory_update( - original_text_memories=original_text_memories, - new_text_memories=new_text_memories, + if hasattr(self, "activation_memory_manager") and self.activation_memory_manager: + self.activation_memory_manager.update_activation_memory( + new_memories=new_memories, label=label, user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube, - log_func_callback=self._submit_web_logs, ) - - except Exception as e: - logger.error("MOS-based activation memory update failed: %s", e, exc_info=True) + else: + logger.warning("Activation memory manager not initialized") def update_activation_memory_periodically( self, @@ -275,76 +215,13 @@ def update_activation_memory_periodically( mem_cube_id: MemCubeID | str, mem_cube, ): - try: - if ( - self.monitor.last_activation_mem_update_time == datetime.min - or self.monitor.timed_trigger( - last_time=self.monitor.last_activation_mem_update_time, - interval_seconds=interval_seconds, - ) - ): - logger.info( - "Updating activation memory for user %s and mem_cube %s", - user_id, - mem_cube_id, - ) - - if ( - user_id not in self.monitor.working_memory_monitors - or mem_cube_id not in self.monitor.working_memory_monitors[user_id] - or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].obj.memories) - == 0 - ): - logger.warning( - "No memories found in working_memory_monitors, activation memory update is skipped" - ) - return - - self.monitor.update_activation_memory_monitors( - user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube - ) - - activation_db_manager = 
self.monitor.activation_memory_monitors[user_id][ - mem_cube_id - ] - activation_db_manager.sync_with_orm() - new_activation_memories = [ - m.memory_text for m in activation_db_manager.obj.memories - ] - - logger.info( - "Collected %s new memory entries for processing", - len(new_activation_memories), - ) - for i, memory in enumerate(new_activation_memories[:5], 1): - logger.info( - "Part of New Activation Memories | %s/%s: %s", - i, - len(new_activation_memories), - memory[:20], - ) - - self.update_activation_memory( - new_memories=new_activation_memories, - label=label, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - ) - - self.monitor.last_activation_mem_update_time = get_utc_now() - - logger.debug( - "Activation memory update completed at %s", - self.monitor.last_activation_mem_update_time, - ) - - else: - logger.info( - "Skipping update - %s second interval not yet reached. Last update time is %s and now is %s", - interval_seconds, - self.monitor.last_activation_mem_update_time, - get_utc_now(), - ) - except Exception as e: - logger.error("Error in update_activation_memory_periodically: %s", e, exc_info=True) + if hasattr(self, "activation_memory_manager") and self.activation_memory_manager: + self.activation_memory_manager.update_activation_memory_periodically( + interval_seconds=interval_seconds, + label=label, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + else: + logger.warning("Activation memory manager not initialized") diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 2cb104343..2e408b222 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -16,7 +16,12 @@ from memos.mem_scheduler.general_modules.init_components_for_scheduler import init_components from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule +from memos.mem_scheduler.memory_manage_modules.activation_memory_manager import ( + ActivationMemoryManager, +) +from memos.mem_scheduler.memory_manage_modules.post_processor import MemoryPostProcessor from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever +from memos.mem_scheduler.memory_manage_modules.search_service import SchedulerSearchService from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor from memos.mem_scheduler.monitors.task_schedule_monitor import TaskScheduleMonitor @@ -51,6 +56,7 @@ from memos.mem_cube.base import BaseMemCube from memos.mem_feedback.simple_feedback import SimpleMemFeedback from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem + from memos.memories.textual.item import TextualMemoryItem from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.http_bge import HTTPBGEReranker @@ -118,6 +124,9 @@ def __init__(self, config: BaseSchedulerConfig): self.orchestrator = SchedulerOrchestrator() self.searcher: Searcher | None = None + self.search_service: SchedulerSearchService | None = None + self.post_processor: MemoryPostProcessor | None = None + self.activation_memory_manager: ActivationMemoryManager | None = None self.retriever: SchedulerRetriever | None = None self.db_engine: Engine | None = None self.monitor: 
SchedulerGeneralMonitor | None = None @@ -187,6 +196,9 @@ def init_mem_cube( self.searcher = searcher self.feedback_server = feedback_server + # Initialize search service with the searcher + self.search_service = SchedulerSearchService(searcher=self.searcher) + def initialize_modules( self, chat_llm: BaseLLM, @@ -217,6 +229,21 @@ def initialize_modules( self.dispatcher_monitor = SchedulerDispatcherMonitor(config=self.config) self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config) + # Initialize search service (will be updated with searcher when mem_cube is initialized) + self.search_service = SchedulerSearchService(searcher=self.searcher) + + # Initialize post-processor for memory enhancement and filtering + self.post_processor = MemoryPostProcessor( + process_llm=self.process_llm, config=self.config + ) + + self.activation_memory_manager = ActivationMemoryManager( + act_mem_dump_path=self.act_mem_dump_path, + monitor=self.monitor, + log_func_callback=self._submit_web_logs, + log_activation_memory_update_func=self.log_activation_memory_update, + ) + if mem_reader: self.mem_reader = mem_reader @@ -366,3 +393,45 @@ def mem_cubes(self, value: dict[str, BaseMemCube]) -> None: ) # Methods moved to mixins in mem_scheduler.base_mixins. + + def update_activation_memory( + self, + new_memories: list[str | TextualMemoryItem], + label: str, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, + mem_cube: BaseMemCube, + ) -> None: + """ + Update activation memory by extracting KVCacheItems from new_memory (list of str), + add them to a KVCacheMemory instance, and dump to disk. + """ + if self.activation_memory_manager: + self.activation_memory_manager.update_activation_memory( + new_memories=new_memories, + label=label, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + else: + logger.warning("Activation memory manager not initialized") + + def update_activation_memory_periodically( + self, + interval_seconds: int, + label: str, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, + mem_cube: BaseMemCube, + ): + if self.activation_memory_manager: + self.activation_memory_manager.update_activation_memory_periodically( + interval_seconds=interval_seconds, + label=label, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + else: + logger.warning("Activation memory manager not initialized") diff --git a/src/memos/mem_scheduler/memory_manage_modules/activation_memory_manager.py b/src/memos/mem_scheduler/memory_manage_modules/activation_memory_manager.py new file mode 100644 index 000000000..589d0e421 --- /dev/null +++ b/src/memos/mem_scheduler/memory_manage_modules/activation_memory_manager.py @@ -0,0 +1,186 @@ +from collections.abc import Callable +from datetime import datetime + +from memos.log import get_logger +from memos.mem_cube.general import GeneralMemCube +from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor +from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.memories.activation.kv import KVCacheMemory +from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory +from memos.memories.textual.tree import TextualMemoryItem +from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE +from memos.types.general_types import MemCubeID, UserID + + +logger = get_logger(__name__) + + +class ActivationMemoryManager: + def __init__( + self, + act_mem_dump_path: str, + monitor: SchedulerGeneralMonitor, + log_func_callback: Callable, + 
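+        # Both callbacks are injected by BaseScheduler.initialize_modules
+        # (as _submit_web_logs and log_activation_memory_update), which keeps
+        # this manager decoupled from the scheduler class itself.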
log_activation_memory_update_func: Callable, + ): + self.act_mem_dump_path = act_mem_dump_path + self.monitor = monitor + self.log_func_callback = log_func_callback + self.log_activation_memory_update_func = log_activation_memory_update_func + + def update_activation_memory( + self, + new_memories: list[str | TextualMemoryItem], + label: str, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, + mem_cube: GeneralMemCube, + ) -> None: + """ + Update activation memory by extracting KVCacheItems from new_memory (list of str), + add them to a KVCacheMemory instance, and dump to disk. + """ + if len(new_memories) == 0: + logger.error("update_activation_memory: new_memory is empty.") + return + if isinstance(new_memories[0], TextualMemoryItem): + new_text_memories = [mem.memory for mem in new_memories] + elif isinstance(new_memories[0], str): + new_text_memories = new_memories + else: + logger.error("Not Implemented.") + return + + try: + if isinstance(mem_cube.act_mem, VLLMKVCacheMemory): + act_mem: VLLMKVCacheMemory = mem_cube.act_mem + elif isinstance(mem_cube.act_mem, KVCacheMemory): + act_mem: KVCacheMemory = mem_cube.act_mem + else: + logger.error("Not Implemented.") + return + + new_text_memory = MEMORY_ASSEMBLY_TEMPLATE.format( + memory_text="".join( + [ + f"{i + 1}. {sentence.strip()}\n" + for i, sentence in enumerate(new_text_memories) + if sentence.strip() # Skip empty strings + ] + ) + ) + + # huggingface or vllm kv cache + original_cache_items: list[VLLMKVCacheItem] = act_mem.get_all() + original_text_memories = [] + if len(original_cache_items) > 0: + pre_cache_item: VLLMKVCacheItem = original_cache_items[-1] + original_text_memories = pre_cache_item.records.text_memories + original_composed_text_memory = pre_cache_item.records.composed_text_memory + if original_composed_text_memory == new_text_memory: + logger.warning( + "Skipping memory update - new composition matches existing cache: %s", + new_text_memory[:50] + "..." 
+                        if len(new_text_memory) > 50
+                        else new_text_memory,
+                    )
+                    return
+                act_mem.delete_all()
+
+            cache_item = act_mem.extract(new_text_memory)
+            cache_item.records.text_memories = new_text_memories
+            cache_item.records.timestamp = get_utc_now()
+
+            act_mem.add([cache_item])
+            act_mem.dump(self.act_mem_dump_path)
+
+            self.log_activation_memory_update_func(
+                original_text_memories=original_text_memories,
+                new_text_memories=new_text_memories,
+                label=label,
+                user_id=user_id,
+                mem_cube_id=mem_cube_id,
+                mem_cube=mem_cube,
+                log_func_callback=self.log_func_callback,
+            )
+
+        except Exception as e:
+            logger.error(f"MOS-based activation memory update failed: {e}", exc_info=True)
+            # Re-raise the exception if it's critical for the operation
+            # For now, we'll continue execution but this should be reviewed
+
+    def update_activation_memory_periodically(
+        self,
+        interval_seconds: int,
+        label: str,
+        user_id: UserID | str,
+        mem_cube_id: MemCubeID | str,
+        mem_cube: GeneralMemCube,
+    ):
+        try:
+            if (
+                self.monitor.last_activation_mem_update_time == datetime.min
+                or self.monitor.timed_trigger(
+                    last_time=self.monitor.last_activation_mem_update_time,
+                    interval_seconds=interval_seconds,
+                )
+            ):
+                logger.info(
+                    f"Updating activation memory for user {user_id} and mem_cube {mem_cube_id}"
+                )
+
+                if (
+                    user_id not in self.monitor.working_memory_monitors
+                    or mem_cube_id not in self.monitor.working_memory_monitors[user_id]
+                    or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].obj.memories)
+                    == 0
+                ):
+                    logger.warning(
+                        "No memories found in working_memory_monitors, activation memory update is skipped"
+                    )
+                    return
+
+                self.monitor.update_activation_memory_monitors(
+                    user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube
+                )
+
+                # Sync with database to get latest activation memories
+                activation_db_manager = self.monitor.activation_memory_monitors[user_id][
+                    mem_cube_id
+                ]
+                activation_db_manager.sync_with_orm()
+                new_activation_memories = [
+                    m.memory_text for m in activation_db_manager.obj.memories
+                ]
+
+                logger.info(
+                    f"Collected {len(new_activation_memories)} new memory entries for processing"
+                )
+                # Print the content of each new activation memory
+                for i, memory in enumerate(new_activation_memories[:5], 1):
+                    logger.info(
+                        f"Part of New Activation Memories | {i}/{len(new_activation_memories)}: {memory[:20]}"
+                    )
+
+                self.update_activation_memory(
+                    new_memories=new_activation_memories,
+                    label=label,
+                    user_id=user_id,
+                    mem_cube_id=mem_cube_id,
+                    mem_cube=mem_cube,
+                )
+
+                self.monitor.last_activation_mem_update_time = get_utc_now()
+
+                logger.debug(
+                    f"Activation memory update completed at {self.monitor.last_activation_mem_update_time}"
+                )
+
+            else:
+                logger.info(
+                    f"Skipping update - {interval_seconds} second interval not yet reached. "
+                    f"Last update time is {self.monitor.last_activation_mem_update_time} and now is "
+                    f"{get_utc_now()}"
+                )
+        except Exception as e:
+            logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True)

diff --git a/src/memos/mem_scheduler/memory_manage_modules/post_processor.py b/src/memos/mem_scheduler/memory_manage_modules/post_processor.py
new file mode 100644
index 000000000..2e1821e1e
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/post_processor.py
@@ -0,0 +1,307 @@
+"""
+Memory Post-Processor - Handles post-retrieval memory filtering and reranking. 
+ +This module provides post-processing operations for retrieved memories, +including filtering and reranking operations specific to the scheduler's needs. + +Note: Memory enhancement operations (enhance_memories_with_query, recall_for_missing_memories) +have been moved to AdvancedSearcher for better architectural separation. +""" + +from memos.configs.mem_scheduler import BaseSchedulerConfig +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.mem_scheduler.general_modules.base import BaseSchedulerModule +from memos.mem_scheduler.schemas.general_schemas import ( + DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE, + DEFAULT_SCHEDULER_RETRIEVER_RETRIES, +) +from memos.mem_scheduler.utils.filter_utils import ( + filter_too_short_memories, + filter_vector_based_similar_memories, + transform_name_to_key, +) +from memos.memories.textual.item import TextualMemoryItem +from memos.utils import extract_json_obj + +from .memory_filter import MemoryFilter + + +logger = get_logger(__name__) + + +class MemoryPostProcessor(BaseSchedulerModule): + """ + Post-processor for retrieved memories. + + This class handles scheduler-specific post-retrieval operations: + - Memory filtering: Remove unrelated or redundant memories + - Memory reranking: Reorder memories by relevance + - Memory evaluation: Assess memory's ability to answer queries + + Design principles: + - Single Responsibility: Only handles filtering/reranking, not enhancement or retrieval + - Composable: Can be used independently or chained together + - Testable: Each operation can be tested in isolation + + Note: Memory enhancement operations have been moved to AdvancedSearcher. + + Usage: + processor = MemoryPostProcessor(process_llm=llm, config=config) + + # Filter out unrelated memories + filtered, _ = processor.filter_unrelated_memories( + query_history=["What is Python?"], + memories=raw_memories + ) + + # Rerank memories by relevance + reranked, _ = processor.process_and_rerank_memories( + queries=["What is Python?"], + original_memory=filtered, + new_memory=[], + top_k=10 + ) + """ + + def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): + """ + Initialize the post-processor. + + Args: + process_llm: LLM instance for enhancement and filtering operations + config: Scheduler configuration containing batch sizes and retry settings + """ + super().__init__() + + # Core dependencies + self.process_llm = process_llm + self.config = config + self.memory_filter = MemoryFilter(process_llm=process_llm, config=config) + + # Configuration + self.filter_similarity_threshold = 0.75 + self.filter_min_length_threshold = 6 + + # NOTE: Config keys still use "scheduler_retriever_*" prefix for backward compatibility + # TODO: Consider renaming to "post_processor_*" in future config refactor + self.batch_size: int | None = getattr( + config, "scheduler_retriever_batch_size", DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE + ) + self.retries: int = getattr( + config, "scheduler_retriever_enhance_retries", DEFAULT_SCHEDULER_RETRIEVER_RETRIES + ) + + def evaluate_memory_answer_ability( + self, query: str, memory_texts: list[str], top_k: int | None = None + ) -> bool: + """ + Evaluate whether the given memories can answer the query. + + This method uses LLM to assess if the provided memories contain + sufficient information to answer the given query. 
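+
+        Illustrative contract (values invented for this example): for the
+        query "Where does the user work?" and memory_texts
+        ["The user works at Acme"], the LLM is expected to reply with JSON
+        like {"result": true, "reason": "..."}, and this method then
+        returns True.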
+ + Args: + query: The query to be answered + memory_texts: List of memory text strings + top_k: Optional limit on number of memories to consider + + Returns: + Boolean indicating whether memories can answer the query + """ + limited_memories = memory_texts[:top_k] if top_k is not None else memory_texts + + # Build prompt using the template + prompt = self.build_prompt( + template_name="memory_answer_ability_evaluation", + query=query, + memory_list="\n".join([f"- {memory}" for memory in limited_memories]) + if limited_memories + else "No memories available", + ) + + # Use the process LLM to generate response + response = self.process_llm.generate([{"role": "user", "content": prompt}]) + + try: + result = extract_json_obj(response) + + # Validate response structure + if "result" in result: + logger.info( + f"[Answerability] result={result['result']}; " + f"reason={result.get('reason', 'n/a')}; " + f"evaluated={len(limited_memories)}" + ) + return result["result"] + else: + logger.warning(f"[Answerability] invalid LLM JSON structure; payload={result}") + return False + + except Exception as e: + logger.error(f"[Answerability] parse failed; err={e}; raw={str(response)[:200]}...") + return False + + def rerank_memories( + self, queries: list[str], original_memories: list[str], top_k: int + ) -> tuple[list[str], bool]: + """ + Rerank memories based on relevance to given queries using LLM. + + Args: + queries: List of query strings to determine relevance + original_memories: List of memory strings to be reranked + top_k: Number of top memories to return after reranking + + Returns: + Tuple of (reranked_memories, success_flag) + - reranked_memories: List of reranked memory strings (length <= top_k) + - success_flag: True if reranking succeeded + + Note: + If LLM reranking fails, falls back to original order (truncated to top_k) + """ + logger.info(f"Starting memory reranking for {len(original_memories)} memories") + + # Build LLM prompt for memory reranking + prompt = self.build_prompt( + "memory_reranking", + queries=[f"[0] {queries[0]}"], + current_order=[f"[{i}] {mem}" for i, mem in enumerate(original_memories)], + ) + logger.debug(f"Generated reranking prompt: {prompt[:200]}...") + + # Get LLM response + response = self.process_llm.generate([{"role": "user", "content": prompt}]) + logger.debug(f"Received LLM response: {response[:200]}...") + + try: + # Parse JSON response + response = extract_json_obj(response) + new_order = response["new_order"][:top_k] + text_memories_with_new_order = [original_memories[idx] for idx in new_order] + logger.info( + f"Successfully reranked memories. Returning top {len(text_memories_with_new_order)} items; " + f"Ranking reasoning: {response['reasoning']}" + ) + success_flag = True + except Exception as e: + logger.error( + f"Failed to rerank memories with LLM. Exception: {e}. Raw response: {response} ", + exc_info=True, + ) + text_memories_with_new_order = original_memories[:top_k] + success_flag = False + + return text_memories_with_new_order, success_flag + + def process_and_rerank_memories( + self, + queries: list[str], + original_memory: list[TextualMemoryItem], + new_memory: list[TextualMemoryItem], + top_k: int = 10, + ) -> tuple[list[TextualMemoryItem], bool]: + """ + Process and rerank memory items by combining, filtering, and reranking. + + This is a higher-level method that combines multiple post-processing steps: + 1. Merge original and new memories + 2. Apply similarity filtering + 3. Apply length filtering + 4. Remove duplicates + 5. 
Rerank by relevance + + Args: + queries: List of query strings to rerank memories against + original_memory: List of original TextualMemoryItem objects + new_memory: List of new TextualMemoryItem objects to merge + top_k: Maximum number of memories to return after reranking + + Returns: + Tuple of (reranked_memories, success_flag) + - reranked_memories: List of reranked TextualMemoryItem objects + - success_flag: True if reranking succeeded + """ + # Combine original and new memories + combined_memory = original_memory + new_memory + + # Create mapping from normalized text to memory objects + memory_map = { + transform_name_to_key(name=mem_obj.memory): mem_obj for mem_obj in combined_memory + } + + # Extract text representations + combined_text_memory = [m.memory for m in combined_memory] + + # Apply similarity filter + filtered_combined_text_memory = filter_vector_based_similar_memories( + text_memories=combined_text_memory, + similarity_threshold=self.filter_similarity_threshold, + ) + + # Apply length filter + filtered_combined_text_memory = filter_too_short_memories( + text_memories=filtered_combined_text_memory, + min_length_threshold=self.filter_min_length_threshold, + ) + + # Remove duplicates (preserving order) + unique_memory = list(dict.fromkeys(filtered_combined_text_memory)) + + # Rerank memories + text_memories_with_new_order, success_flag = self.rerank_memories( + queries=queries, + original_memories=unique_memory, + top_k=top_k, + ) + + # Map reranked texts back to memory objects + memories_with_new_order = [] + for text in text_memories_with_new_order: + normalized_text = transform_name_to_key(name=text) + if normalized_text in memory_map: + memories_with_new_order.append(memory_map[normalized_text]) + else: + logger.warning( + f"Memory text not found in memory map. text: {text};\n" + f"Keys of memory_map: {memory_map.keys()}" + ) + + return memories_with_new_order, success_flag + + def filter_unrelated_memories( + self, + query_history: list[str], + memories: list[TextualMemoryItem], + ) -> tuple[list[TextualMemoryItem], bool]: + """ + Filter out memories unrelated to the query history. + + Delegates to MemoryFilter for the actual filtering logic. + """ + return self.memory_filter.filter_unrelated_memories(query_history, memories) + + def filter_redundant_memories( + self, + query_history: list[str], + memories: list[TextualMemoryItem], + ) -> tuple[list[TextualMemoryItem], bool]: + """ + Filter out redundant memories from the list. + + Delegates to MemoryFilter for the actual filtering logic. + """ + return self.memory_filter.filter_redundant_memories(query_history, memories) + + def filter_unrelated_and_redundant_memories( + self, + query_history: list[str], + memories: list[TextualMemoryItem], + ) -> tuple[list[TextualMemoryItem], bool]: + """ + Filter out both unrelated and redundant memories using LLM analysis. + + Delegates to MemoryFilter for the actual filtering logic. + """ + return self.memory_filter.filter_unrelated_and_redundant_memories(query_history, memories) diff --git a/src/memos/mem_scheduler/memory_manage_modules/search_service.py b/src/memos/mem_scheduler/memory_manage_modules/search_service.py new file mode 100644 index 000000000..3f00372bd --- /dev/null +++ b/src/memos/mem_scheduler/memory_manage_modules/search_service.py @@ -0,0 +1,297 @@ +""" +Scheduler Search Service - Unified search interface for the scheduler. 
+ +This module provides a clean abstraction over the Searcher class, +adapting it for scheduler-specific use cases while maintaining compatibility. +""" + +from memos.log import get_logger +from memos.mem_cube.general import GeneralMemCube +from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.tree import TreeTextMemory +from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher +from memos.types.general_types import SearchMode + + +logger = get_logger(__name__) + + +class SchedulerSearchService: + """ + Unified search service for the scheduler. + + This service provides a clean interface for memory search operations, + delegating to the Searcher class while handling scheduler-specific + parameter adaptations. + + Design principles: + - Single Responsibility: Only handles search coordination + - Dependency Injection: Searcher is injected, not created + - Fail-safe: Falls back to direct text_mem.search() if Searcher unavailable + + Usage: + service = SchedulerSearchService(searcher=searcher) + results = service.search( + query="user query", + user_id="user_123", + mem_cube=mem_cube, + top_k=10 + ) + """ + + def __init__(self, searcher: Searcher | None = None): + """ + Initialize the search service. + + Args: + searcher: Optional Searcher instance. If None, will fall back to + direct mem_cube.text_mem.search() calls. + """ + self.searcher = searcher + + def search( + self, + query: str, + user_id: str, + mem_cube: GeneralMemCube, + top_k: int, + mode: SearchMode = SearchMode.FAST, + search_filter: dict | None = None, + search_priority: dict | None = None, + session_id: str = "default_session", + internet_search: bool = False, + chat_history: list | None = None, + plugin: bool = False, + search_tool_memory: bool = False, + tool_mem_top_k: int = 6, + playground_search_goal_parser: bool = False, + mem_cube_id: str | None = None, + ) -> list[TextualMemoryItem]: + """ + Search for memories across both LongTermMemory and UserMemory. + + This method provides a unified interface for memory search, automatically + handling the search across different memory types and merging results. 
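+
+        Example (illustrative values; session scoping is expressed through
+        search_priority, mirroring memos.search.build_search_context):
+
+            items = service.search(
+                query="favorite food",
+                user_id="u_123",
+                mem_cube=mem_cube,
+                top_k=5,
+                mode=SearchMode.FINE,
+                search_priority={"session_id": "sess_42"},
+            )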
+
+        Args:
+            query: The search query string
+            user_id: User identifier
+            mem_cube: Memory cube instance containing text memory
+            top_k: Number of top results to return per memory type
+            mode: Search mode (FAST or FINE)
+            search_filter: Optional metadata filters for search results
+            search_priority: Optional metadata priority for search results
+            session_id: Session identifier for session-scoped search
+            internet_search: Whether to enable internet search
+            chat_history: Chat history for context
+            plugin: Whether this is a plugin-initiated search
+            search_tool_memory: Whether to search tool memory
+            tool_mem_top_k: Top-k for tool memory search
+            playground_search_goal_parser: Whether to use playground goal parser
+            mem_cube_id: Memory cube identifier (defaults to user_id if not provided)
+
+        Returns:
+            List of TextualMemoryItem objects sorted by relevance
+
+        Raises:
+            Exception: Propagates exceptions from underlying search implementations
+        """
+        mem_cube_id = mem_cube_id or user_id
+        user_name = mem_cube_id
+        text_mem_base = mem_cube.text_mem
+
+        # Build info dict for tracking
+        info = {
+            "user_id": user_id,
+            "session_id": session_id,
+            "chat_history": chat_history,
+        }
+
+        try:
+            if self.searcher:
+                # Use injected Searcher (preferred path)
+                results = self._search_with_searcher(
+                    query=query,
+                    user_name=user_name,
+                    top_k=top_k,
+                    mode=mode,
+                    search_filter=search_filter,
+                    search_priority=search_priority,
+                    info=info,
+                    internet_search=internet_search,
+                    plugin=plugin,
+                    search_tool_memory=search_tool_memory,
+                    tool_mem_top_k=tool_mem_top_k,
+                    playground_search_goal_parser=playground_search_goal_parser,
+                )
+                logger.info(
+                    f"[SchedulerSearchService] Searched via Searcher: "
+                    f"query='{query}' results={len(results)}"
+                )
+            else:
+                # Fallback: Direct text_mem.search() call
+                results = self._search_with_text_mem(
+                    text_mem_base=text_mem_base,
+                    query=query,
+                    user_name=user_name,
+                    top_k=top_k,
+                    mode=mode,
+                    search_filter=search_filter,
+                    search_priority=search_priority,
+                    info=info,
+                    internet_search=internet_search,
+                    plugin=plugin,
+                    search_tool_memory=search_tool_memory,
+                    tool_mem_top_k=tool_mem_top_k,
+                    playground_search_goal_parser=playground_search_goal_parser,
+                )
+                logger.info(
+                    f"[SchedulerSearchService] Searched via text_mem (fallback): "
+                    f"query='{query}' results={len(results)}"
+                )
+
+            return results
+
+        except Exception as e:
+            logger.error(
+                f"[SchedulerSearchService] Search failed for query='{query}': {e}",
+                exc_info=True,
+            )
+            return []
+
+    def _search_with_searcher(
+        self,
+        query: str,
+        user_name: str,
+        top_k: int,
+        mode: SearchMode,
+        search_filter: dict | None,
+        search_priority: dict | None,
+        info: dict,
+        internet_search: bool,
+        plugin: bool,
+        search_tool_memory: bool,
+        tool_mem_top_k: int,
+        playground_search_goal_parser: bool,
+    ) -> list[TextualMemoryItem]:
+        """
+        Search using the injected Searcher instance.
+
+        NOTE: LongTermMemory and UserMemory are queried in two separate calls,
+        keeping the intentional search scope of this refactor. Because
+        Searcher.search() applies deduplication and top_k limiting per call,
+        the concatenated result may contain up to 2*top_k items; any final
+        merging or trimming is left to downstream post-processing. 
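+
+        Illustrative follow-up (the wiring is assumed, not part of this
+        class): a caller that needs a strict top_k can pass the combined
+        list through MemoryPostProcessor.process_and_rerank_memories:
+
+            combined = service.search(query=q, user_id=uid, mem_cube=cube, top_k=10)
+            final, ok = post_processor.process_and_rerank_memories(
+                queries=[q], original_memory=combined, new_memory=[], top_k=10
+            )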
+ """ + # Preserve original internet search setting + original_manual_close = getattr(self.searcher, "manual_close_internet", None) + + try: + # Configure internet search + if original_manual_close is not None: + self.searcher.manual_close_internet = not internet_search + + # Search LongTermMemory + results_long_term = self.searcher.search( + query=query, + user_name=user_name, + top_k=top_k, + mode=mode, + memory_type="LongTermMemory", + search_filter=search_filter, + search_priority=search_priority, + info=info, + plugin=plugin, + search_tool_memory=search_tool_memory, + tool_mem_top_k=tool_mem_top_k, + playground_search_goal_parser=playground_search_goal_parser, + ) + + # Search UserMemory + results_user = self.searcher.search( + query=query, + user_name=user_name, + top_k=top_k, + mode=mode, + memory_type="UserMemory", + search_filter=search_filter, + search_priority=search_priority, + info=info, + plugin=plugin, + search_tool_memory=search_tool_memory, + tool_mem_top_k=tool_mem_top_k, + playground_search_goal_parser=playground_search_goal_parser, + ) + + return results_long_term + results_user + + finally: + # Restore original setting + if original_manual_close is not None: + self.searcher.manual_close_internet = original_manual_close + + def _search_with_text_mem( + self, + text_mem_base: TreeTextMemory, + query: str, + user_name: str, + top_k: int, + mode: SearchMode, + search_filter: dict | None, + search_priority: dict | None, + info: dict, + internet_search: bool, + plugin: bool, + search_tool_memory: bool, + tool_mem_top_k: int, + playground_search_goal_parser: bool, + ) -> list[TextualMemoryItem]: + """ + Fallback: Search using direct text_mem.search() calls. + + This is used when no Searcher instance is available, providing + backward compatibility with the original implementation. + + NOTE: TreeTextMemory.search() with memory_type="All" will internally + search both LongTermMemory and UserMemory and properly merge results. 
+ """ + assert isinstance(text_mem_base, TreeTextMemory), ( + f"Fallback search requires TreeTextMemory, got {type(text_mem_base)}" + ) + + # Search LongTermMemory + results_long_term = text_mem_base.search( + query=query, + user_name=user_name, + top_k=top_k, + mode=mode, + manual_close_internet=not internet_search, + memory_type="LongTermMemory", + search_filter=search_filter, + search_priority=search_priority, + info=info, + plugin=plugin, + search_tool_memory=search_tool_memory, + tool_mem_top_k=tool_mem_top_k, + playground_search_goal_parser=playground_search_goal_parser, + ) + + # Search UserMemory + results_user = text_mem_base.search( + query=query, + user_name=user_name, + top_k=top_k, + mode=mode, + manual_close_internet=not internet_search, + memory_type="UserMemory", + search_filter=search_filter, + search_priority=search_priority, + info=info, + plugin=plugin, + search_tool_memory=search_tool_memory, + tool_mem_top_k=tool_mem_top_k, + playground_search_goal_parser=playground_search_goal_parser, + ) + + return results_long_term + results_user From e5cfafb31f6f15c2ae9a08de907cf4a0e2956987 Mon Sep 17 00:00:00 2001 From: chentang Date: Thu, 5 Feb 2026 14:24:14 +0800 Subject: [PATCH 08/14] refactor: extract common logic to BaseSchedulerHandler and fix import paths --- src/memos/mem_scheduler/general_scheduler.py | 2 +- .../mem_scheduler/handlers/answer_handler.py | 56 ---- src/memos/mem_scheduler/handlers/base.py | 12 - .../handlers/feedback_handler.py | 186 ------------- .../handlers/mem_reorganize_handler.py | 254 ------------------ .../handlers/pref_add_handler.py | 90 ------- .../mem_scheduler/handlers/query_handler.py | 70 ----- src/memos/mem_scheduler/handlers/registry.py | 52 ---- .../task_schedule_modules/base_handler.py | 68 +++++ .../handlers/__init__.py | 0 .../handlers/add_handler.py | 82 +++--- .../handlers/answer_handler.py | 48 ++++ .../handlers/context.py | 0 .../handlers/feedback_handler.py | 196 ++++++++++++++ .../handlers/mem_read_handler.py | 152 ++++++----- .../handlers/mem_reorganize_handler.py | 252 +++++++++++++++++ .../handlers/memory_update_handler.py | 46 ++-- .../handlers/pref_add_handler.py | 92 +++++++ .../handlers/query_handler.py | 64 +++++ .../handlers/registry.py | 53 ++++ 20 files changed, 909 insertions(+), 866 deletions(-) delete mode 100644 src/memos/mem_scheduler/handlers/answer_handler.py delete mode 100644 src/memos/mem_scheduler/handlers/base.py delete mode 100644 src/memos/mem_scheduler/handlers/feedback_handler.py delete mode 100644 src/memos/mem_scheduler/handlers/mem_reorganize_handler.py delete mode 100644 src/memos/mem_scheduler/handlers/pref_add_handler.py delete mode 100644 src/memos/mem_scheduler/handlers/query_handler.py delete mode 100644 src/memos/mem_scheduler/handlers/registry.py create mode 100644 src/memos/mem_scheduler/task_schedule_modules/base_handler.py rename src/memos/mem_scheduler/{ => task_schedule_modules}/handlers/__init__.py (100%) rename src/memos/mem_scheduler/{ => task_schedule_modules}/handlers/add_handler.py (80%) create mode 100644 src/memos/mem_scheduler/task_schedule_modules/handlers/answer_handler.py rename src/memos/mem_scheduler/{ => task_schedule_modules}/handlers/context.py (100%) create mode 100644 src/memos/mem_scheduler/task_schedule_modules/handlers/feedback_handler.py rename src/memos/mem_scheduler/{ => task_schedule_modules}/handlers/mem_read_handler.py (78%) create mode 100644 src/memos/mem_scheduler/task_schedule_modules/handlers/mem_reorganize_handler.py rename src/memos/mem_scheduler/{ => 
task_schedule_modules}/handlers/memory_update_handler.py (86%) create mode 100644 src/memos/mem_scheduler/task_schedule_modules/handlers/pref_add_handler.py create mode 100644 src/memos/mem_scheduler/task_schedule_modules/handlers/query_handler.py create mode 100644 src/memos/mem_scheduler/task_schedule_modules/handlers/registry.py diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 1fc3317d8..9860e3984 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from memos.configs.mem_scheduler import GeneralSchedulerConfig from memos.mem_scheduler.base_scheduler import BaseScheduler -from memos.mem_scheduler.handlers import ( +from memos.mem_scheduler.task_schedule_modules.handlers import ( SchedulerHandlerContext, SchedulerHandlerRegistry, SchedulerHandlerServices, diff --git a/src/memos/mem_scheduler/handlers/answer_handler.py b/src/memos/mem_scheduler/handlers/answer_handler.py deleted file mode 100644 index 9ec4086a4..000000000 --- a/src/memos/mem_scheduler/handlers/answer_handler.py +++ /dev/null @@ -1,56 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from memos.log import get_logger -from memos.mem_scheduler.handlers.base import BaseSchedulerHandler -from memos.mem_scheduler.schemas.task_schemas import ( - ANSWER_TASK_LABEL, - NOT_APPLICABLE_TYPE, - USER_INPUT_TYPE, -) -from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube - - -logger = get_logger(__name__) - -if TYPE_CHECKING: - from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem - - -class AnswerMessageHandler(BaseSchedulerHandler): - def handle(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {ANSWER_TASK_LABEL} handler.") - grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - - self.ctx.services.validate_messages(messages=messages, label=ANSWER_TASK_LABEL) - - for user_id in grouped_messages: - for mem_cube_id in grouped_messages[user_id]: - batch = grouped_messages[user_id][mem_cube_id] - if not batch: - continue - try: - for msg in batch: - event = self.ctx.services.create_event_log( - label="addMessage", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=NOT_APPLICABLE_TYPE, - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - mem_cube=self.ctx.get_mem_cube(), - memcube_log_content=[ - { - "content": f"[Assistant] {msg.content}", - "ref_id": msg.item_id, - "role": "assistant", - } - ], - metadata=[], - memory_len=1, - memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id), - ) - event.task_id = msg.task_id - self.ctx.services.submit_web_logs([event]) - except Exception: - logger.exception("Failed to record addMessage log for answer") diff --git a/src/memos/mem_scheduler/handlers/base.py b/src/memos/mem_scheduler/handlers/base.py deleted file mode 100644 index e04add7d7..000000000 --- a/src/memos/mem_scheduler/handlers/base.py +++ /dev/null @@ -1,12 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - - -if TYPE_CHECKING: - from memos.mem_scheduler.handlers.context import SchedulerHandlerContext - - -class BaseSchedulerHandler: - def __init__(self, ctx: SchedulerHandlerContext) -> None: - self.ctx = ctx diff --git a/src/memos/mem_scheduler/handlers/feedback_handler.py b/src/memos/mem_scheduler/handlers/feedback_handler.py deleted file mode 100644 index cf52470dd..000000000 --- 
a/src/memos/mem_scheduler/handlers/feedback_handler.py +++ /dev/null @@ -1,186 +0,0 @@ -from __future__ import annotations - -import json - -from typing import TYPE_CHECKING - -from memos.log import get_logger -from memos.mem_scheduler.handlers.base import BaseSchedulerHandler -from memos.mem_scheduler.schemas.task_schemas import LONG_TERM_MEMORY_TYPE, USER_INPUT_TYPE -from memos.mem_scheduler.utils.misc_utils import is_cloud_env - - -logger = get_logger(__name__) - -if TYPE_CHECKING: - from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem - - -class FeedbackMessageHandler(BaseSchedulerHandler): - def handle(self, messages: list[ScheduleMessageItem]) -> None: - try: - if not messages: - return - message = messages[0] - mem_cube = self.ctx.get_mem_cube() - - user_id = message.user_id - mem_cube_id = message.mem_cube_id - content = message.content - - try: - feedback_data = json.loads(content) if isinstance(content, str) else content - if not isinstance(feedback_data, dict): - logger.error( - "Failed to decode feedback_data or it is not a dict: %s", feedback_data - ) - return - except json.JSONDecodeError: - logger.error( - "Invalid JSON content for feedback message: %s", content, exc_info=True - ) - return - - task_id = feedback_data.get("task_id") or message.task_id - feedback_result = self.ctx.get_feedback_server().process_feedback( - user_id=user_id, - user_name=mem_cube_id, - session_id=feedback_data.get("session_id"), - chat_history=feedback_data.get("history", []), - retrieved_memory_ids=feedback_data.get("retrieved_memory_ids", []), - feedback_content=feedback_data.get("feedback_content"), - feedback_time=feedback_data.get("feedback_time"), - task_id=task_id, - info=feedback_data.get("info", None), - ) - - logger.info( - "Successfully processed feedback for user_id=%s, mem_cube_id=%s", - user_id, - mem_cube_id, - ) - - cloud_env = is_cloud_env() - if cloud_env: - record = feedback_result.get("record") if isinstance(feedback_result, dict) else {} - add_records = record.get("add") if isinstance(record, dict) else [] - update_records = record.get("update") if isinstance(record, dict) else [] - - def _extract_fields(mem_item): - mem_id = ( - getattr(mem_item, "id", None) - if not isinstance(mem_item, dict) - else mem_item.get("id") - ) - mem_memory = ( - getattr(mem_item, "memory", None) - if not isinstance(mem_item, dict) - else mem_item.get("memory") or mem_item.get("text") - ) - if mem_memory is None and isinstance(mem_item, dict): - mem_memory = mem_item.get("text") - original_content = ( - getattr(mem_item, "origin_memory", None) - if not isinstance(mem_item, dict) - else mem_item.get("origin_memory") - or mem_item.get("old_memory") - or mem_item.get("original_content") - ) - source_doc_id = None - if isinstance(mem_item, dict): - source_doc_id = mem_item.get("source_doc_id", None) - - return mem_id, mem_memory, original_content, source_doc_id - - kb_log_content: list[dict] = [] - - for mem_item in add_records or []: - mem_id, mem_memory, _, source_doc_id = _extract_fields(mem_item) - if mem_id and mem_memory: - kb_log_content.append( - { - "log_source": "KNOWLEDGE_BASE_LOG", - "trigger_source": "Feedback", - "operation": "ADD", - "memory_id": mem_id, - "content": mem_memory, - "original_content": None, - "source_doc_id": source_doc_id, - } - ) - else: - logger.warning( - "Skipping malformed feedback add item. 
user_id=%s mem_cube_id=%s task_id=%s item=%s", - user_id, - mem_cube_id, - task_id, - mem_item, - stack_info=True, - ) - - for mem_item in update_records or []: - mem_id, mem_memory, original_content, source_doc_id = _extract_fields(mem_item) - if mem_id and mem_memory: - kb_log_content.append( - { - "log_source": "KNOWLEDGE_BASE_LOG", - "trigger_source": "Feedback", - "operation": "UPDATE", - "memory_id": mem_id, - "content": mem_memory, - "original_content": original_content, - "source_doc_id": source_doc_id, - } - ) - else: - logger.warning( - "Skipping malformed feedback update item. user_id=%s mem_cube_id=%s task_id=%s item=%s", - user_id, - mem_cube_id, - task_id, - mem_item, - stack_info=True, - ) - - logger.info("[Feedback Scheduler] kb_log_content: %s", kb_log_content) - if kb_log_content: - logger.info( - "[DIAGNOSTIC] feedback_handler: Creating knowledgeBaseUpdate event for feedback. user_id=%s mem_cube_id=%s task_id=%s items=%s", - user_id, - mem_cube_id, - task_id, - len(kb_log_content), - ) - event = self.ctx.services.create_event_log( - label="knowledgeBaseUpdate", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - memcube_log_content=kb_log_content, - metadata=None, - memory_len=len(kb_log_content), - memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), - ) - event.log_content = ( - f"Knowledge Base Memory Update: {len(kb_log_content)} changes." - ) - event.task_id = task_id - self.ctx.services.submit_web_logs([event]) - else: - logger.warning( - "No valid feedback content generated for web log. user_id=%s mem_cube_id=%s task_id=%s", - user_id, - mem_cube_id, - task_id, - stack_info=True, - ) - else: - logger.info( - "Skipping web log for feedback. 
Not in a cloud environment (is_cloud_env=%s)", - cloud_env, - ) - - except Exception as e: - logger.error("Error processing feedbackMemory message: %s", e, exc_info=True) diff --git a/src/memos/mem_scheduler/handlers/mem_reorganize_handler.py b/src/memos/mem_scheduler/handlers/mem_reorganize_handler.py deleted file mode 100644 index d437ebbd6..000000000 --- a/src/memos/mem_scheduler/handlers/mem_reorganize_handler.py +++ /dev/null @@ -1,254 +0,0 @@ -from __future__ import annotations - -import concurrent.futures -import contextlib -import json -import traceback - -from typing import TYPE_CHECKING - -from memos.context.context import ContextThreadPoolExecutor -from memos.log import get_logger -from memos.mem_scheduler.handlers.base import BaseSchedulerHandler -from memos.mem_scheduler.schemas.task_schemas import LONG_TERM_MEMORY_TYPE, MEM_ORGANIZE_TASK_LABEL -from memos.mem_scheduler.utils.filter_utils import transform_name_to_key -from memos.memories.textual.tree import TreeTextMemory - - -logger = get_logger(__name__) - -if TYPE_CHECKING: - from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem - from memos.memories.textual.item import TextualMemoryItem - - -class MemReorganizeMessageHandler(BaseSchedulerHandler): - def handle(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {MEM_ORGANIZE_TASK_LABEL} handler.") - - def process_message(message: ScheduleMessageItem): - try: - user_id = message.user_id - mem_cube_id = message.mem_cube_id - mem_cube = self.ctx.get_mem_cube() - if mem_cube is None: - logger.warning( - "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing", - user_id, - mem_cube_id, - ) - return - content = message.content - user_name = message.user_name - - mem_ids = json.loads(content) if isinstance(content, str) else content - if not mem_ids: - return - - logger.info( - "Processing mem_reorganize for user_id=%s, mem_cube_id=%s, mem_ids=%s", - user_id, - mem_cube_id, - mem_ids, - ) - - text_mem = mem_cube.text_mem - if not isinstance(text_mem, TreeTextMemory): - logger.error("Expected TreeTextMemory but got %s", type(text_mem).__name__) - return - - self._process_memories_with_reorganize( - mem_ids=mem_ids, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - text_mem=text_mem, - user_name=user_name, - ) - - with contextlib.suppress(Exception): - mem_items: list[TextualMemoryItem] = [] - for mid in mem_ids: - with contextlib.suppress(Exception): - mem_items.append(text_mem.get(mid, user_name=user_name)) - if len(mem_items) > 1: - keys: list[str] = [] - memcube_content: list[dict] = [] - meta: list[dict] = [] - merged_target_ids: set[str] = set() - with contextlib.suppress(Exception): - if hasattr(text_mem, "graph_store"): - for mid in mem_ids: - edges = text_mem.graph_store.get_edges( - mid, type="MERGED_TO", direction="OUT" - ) - for edge in edges: - target = ( - edge.get("to") or edge.get("dst") or edge.get("target") - ) - if target: - merged_target_ids.add(target) - for item in mem_items: - key = getattr( - getattr(item, "metadata", {}), "key", None - ) or transform_name_to_key(getattr(item, "memory", "")) - keys.append(key) - memcube_content.append( - {"content": key or "(no key)", "ref_id": item.id, "type": "merged"} - ) - meta.append( - { - "ref_id": item.id, - "id": item.id, - "key": key, - "memory": item.memory, - "memory_type": item.metadata.memory_type, - "status": item.metadata.status, - "confidence": item.metadata.confidence, - "tags": item.metadata.tags, - 
"updated_at": getattr(item.metadata, "updated_at", None) - or getattr(item.metadata, "update_at", None), - } - ) - combined_key = keys[0] if keys else "" - post_ref_id = None - post_meta = { - "ref_id": None, - "id": None, - "key": None, - "memory": None, - "memory_type": None, - "status": None, - "confidence": None, - "tags": None, - "updated_at": None, - } - if merged_target_ids: - post_ref_id = next(iter(merged_target_ids)) - with contextlib.suppress(Exception): - merged_item = text_mem.get(post_ref_id, user_name=user_name) - combined_key = ( - getattr(getattr(merged_item, "metadata", {}), "key", None) - or combined_key - ) - post_meta = { - "ref_id": post_ref_id, - "id": post_ref_id, - "key": getattr( - getattr(merged_item, "metadata", {}), "key", None - ), - "memory": getattr(merged_item, "memory", None), - "memory_type": getattr( - getattr(merged_item, "metadata", {}), "memory_type", None - ), - "status": getattr( - getattr(merged_item, "metadata", {}), "status", None - ), - "confidence": getattr( - getattr(merged_item, "metadata", {}), "confidence", None - ), - "tags": getattr( - getattr(merged_item, "metadata", {}), "tags", None - ), - "updated_at": getattr( - getattr(merged_item, "metadata", {}), "updated_at", None - ) - or getattr( - getattr(merged_item, "metadata", {}), "update_at", None - ), - } - if not post_ref_id: - import hashlib - - post_ref_id = ( - "merge-" - + hashlib.md5("".join(sorted(mem_ids)).encode()).hexdigest() - ) - post_meta["ref_id"] = post_ref_id - post_meta["id"] = post_ref_id - if not post_meta.get("key"): - post_meta["key"] = combined_key - if not keys: - keys = [item.id for item in mem_items] - memcube_content.append( - { - "content": combined_key if combined_key else "(no key)", - "ref_id": post_ref_id, - "type": "postMerge", - } - ) - meta.append(post_meta) - event = self.ctx.services.create_event_log( - label="mergeMemory", - from_memory_type=LONG_TERM_MEMORY_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - memcube_log_content=memcube_content, - metadata=meta, - memory_len=len(keys), - memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), - ) - self.ctx.services.submit_web_logs([event]) - - logger.info( - "Successfully processed mem_reorganize for user_id=%s, mem_cube_id=%s", - user_id, - mem_cube_id, - ) - - except Exception as e: - logger.error("Error processing mem_reorganize message: %s", e, exc_info=True) - - with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: - futures = [executor.submit(process_message, msg) for msg in messages] - for future in concurrent.futures.as_completed(futures): - try: - future.result() - except Exception as e: - logger.error("Thread task failed: %s", e, exc_info=True) - - def _process_memories_with_reorganize( - self, - mem_ids: list[str], - user_id: str, - mem_cube_id: str, - mem_cube, - text_mem: TreeTextMemory, - user_name: str, - ) -> None: - try: - mem_reader = self.ctx.get_mem_reader() - if mem_reader is None: - logger.warning( - "mem_reader not available in scheduler, skipping enhanced processing" - ) - return - - memory_items = [] - for mem_id in mem_ids: - try: - memory_item = text_mem.get(mem_id, user_name=user_name) - memory_items.append(memory_item) - except Exception as e: - logger.warning( - "Failed to get memory %s: %s|%s", mem_id, e, traceback.format_exc() - ) - continue - - if not memory_items: - logger.warning("No valid memory items found for processing") - return - - logger.info("Processing %s memories 
with mem_reader", len(memory_items)) - text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) - logger.info("Remove and Refresh Memories") - logger.debug("Finished add %s memory: %s", user_id, mem_ids) - - except Exception: - logger.error( - "Error in _process_memories_with_reorganize: %s", - traceback.format_exc(), - exc_info=True, - ) diff --git a/src/memos/mem_scheduler/handlers/pref_add_handler.py b/src/memos/mem_scheduler/handlers/pref_add_handler.py deleted file mode 100644 index 4d17b0847..000000000 --- a/src/memos/mem_scheduler/handlers/pref_add_handler.py +++ /dev/null @@ -1,90 +0,0 @@ -from __future__ import annotations - -import concurrent.futures -import json - -from typing import TYPE_CHECKING - -from memos.context.context import ContextThreadPoolExecutor -from memos.log import get_logger -from memos.mem_scheduler.handlers.base import BaseSchedulerHandler -from memos.mem_scheduler.schemas.task_schemas import PREF_ADD_TASK_LABEL -from memos.memories.textual.preference import PreferenceTextMemory - - -logger = get_logger(__name__) - -if TYPE_CHECKING: - from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem - - -class PrefAddMessageHandler(BaseSchedulerHandler): - def handle(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {PREF_ADD_TASK_LABEL} handler.") - - def process_message(message: ScheduleMessageItem): - try: - mem_cube = self.ctx.get_mem_cube() - if mem_cube is None: - logger.warning( - "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing", - message.user_id, - message.mem_cube_id, - ) - return - - user_id = message.user_id - session_id = message.session_id - mem_cube_id = message.mem_cube_id - content = message.content - messages_list = json.loads(content) - info = message.info or {} - - logger.info( - "Processing pref_add for user_id=%s, mem_cube_id=%s", user_id, mem_cube_id - ) - - pref_mem = mem_cube.pref_mem - if pref_mem is None: - logger.warning( - "Preference memory not initialized for mem_cube_id=%s, skipping pref_add processing", - mem_cube_id, - ) - return - if not isinstance(pref_mem, PreferenceTextMemory): - logger.error( - "Expected PreferenceTextMemory but got %s for mem_cube_id=%s", - type(pref_mem).__name__, - mem_cube_id, - ) - return - - pref_memories = pref_mem.get_memory( - messages_list, - type="chat", - info={ - **info, - "user_id": user_id, - "session_id": session_id, - "mem_cube_id": mem_cube_id, - }, - ) - pref_ids = pref_mem.add(pref_memories) - - logger.info( - "Successfully processed and add preferences for user_id=%s, mem_cube_id=%s, pref_ids=%s", - user_id, - mem_cube_id, - pref_ids, - ) - - except Exception as e: - logger.error("Error processing pref_add message: %s", e, exc_info=True) - - with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: - futures = [executor.submit(process_message, msg) for msg in messages] - for future in concurrent.futures.as_completed(futures): - try: - future.result() - except Exception as e: - logger.error("Thread task failed: %s", e, exc_info=True) diff --git a/src/memos/mem_scheduler/handlers/query_handler.py b/src/memos/mem_scheduler/handlers/query_handler.py deleted file mode 100644 index 4d3a09368..000000000 --- a/src/memos/mem_scheduler/handlers/query_handler.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations - -from memos.log import get_logger -from memos.mem_scheduler.handlers.base import BaseSchedulerHandler -from memos.mem_scheduler.schemas.message_schemas 
import ScheduleMessageItem -from memos.mem_scheduler.schemas.task_schemas import ( - MEM_UPDATE_TASK_LABEL, - NOT_APPLICABLE_TYPE, - QUERY_TASK_LABEL, - USER_INPUT_TYPE, -) -from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube - - -logger = get_logger(__name__) - - -class QueryMessageHandler(BaseSchedulerHandler): - def handle(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {QUERY_TASK_LABEL} handler.") - - grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - - self.ctx.services.validate_messages(messages=messages, label=QUERY_TASK_LABEL) - - mem_update_messages: list[ScheduleMessageItem] = [] - for user_id in grouped_messages: - for mem_cube_id in grouped_messages[user_id]: - batch = grouped_messages[user_id][mem_cube_id] - if not batch: - continue - - for msg in batch: - try: - event = self.ctx.services.create_event_log( - label="addMessage", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=NOT_APPLICABLE_TYPE, - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - mem_cube=self.ctx.get_mem_cube(), - memcube_log_content=[ - { - "content": f"[User] {msg.content}", - "ref_id": msg.item_id, - "role": "user", - } - ], - metadata=[], - memory_len=1, - memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id), - ) - event.task_id = msg.task_id - self.ctx.services.submit_web_logs([event]) - except Exception: - logger.exception("Failed to record addMessage log for query") - - update_msg = ScheduleMessageItem( - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - label=MEM_UPDATE_TASK_LABEL, - content=msg.content, - session_id=msg.session_id, - user_name=msg.user_name, - info=msg.info, - task_id=msg.task_id, - ) - mem_update_messages.append(update_msg) - - self.ctx.services.submit_messages(messages=mem_update_messages) diff --git a/src/memos/mem_scheduler/handlers/registry.py b/src/memos/mem_scheduler/handlers/registry.py deleted file mode 100644 index 2a62aa57f..000000000 --- a/src/memos/mem_scheduler/handlers/registry.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - - -if TYPE_CHECKING: - from collections.abc import Callable - - from memos.mem_scheduler.handlers.context import SchedulerHandlerContext - -from memos.mem_scheduler.handlers.add_handler import AddMessageHandler -from memos.mem_scheduler.handlers.answer_handler import AnswerMessageHandler -from memos.mem_scheduler.handlers.feedback_handler import FeedbackMessageHandler -from memos.mem_scheduler.handlers.mem_read_handler import MemReadMessageHandler -from memos.mem_scheduler.handlers.mem_reorganize_handler import MemReorganizeMessageHandler -from memos.mem_scheduler.handlers.memory_update_handler import MemoryUpdateHandler -from memos.mem_scheduler.handlers.pref_add_handler import PrefAddMessageHandler -from memos.mem_scheduler.handlers.query_handler import QueryMessageHandler -from memos.mem_scheduler.schemas.task_schemas import ( - ADD_TASK_LABEL, - ANSWER_TASK_LABEL, - MEM_FEEDBACK_TASK_LABEL, - MEM_ORGANIZE_TASK_LABEL, - MEM_READ_TASK_LABEL, - MEM_UPDATE_TASK_LABEL, - PREF_ADD_TASK_LABEL, - QUERY_TASK_LABEL, -) - - -class SchedulerHandlerRegistry: - def __init__(self, ctx: SchedulerHandlerContext) -> None: - self.query = QueryMessageHandler(ctx) - self.answer = AnswerMessageHandler(ctx) - self.add = AddMessageHandler(ctx) - self.memory_update = MemoryUpdateHandler(ctx) - self.mem_feedback = FeedbackMessageHandler(ctx) - self.mem_read = 
MemReadMessageHandler(ctx) - self.mem_reorganize = MemReorganizeMessageHandler(ctx) - self.pref_add = PrefAddMessageHandler(ctx) - - def build_dispatch_map(self) -> dict[str, Callable]: - return { - QUERY_TASK_LABEL: self.query.handle, - ANSWER_TASK_LABEL: self.answer.handle, - MEM_UPDATE_TASK_LABEL: self.memory_update.handle, - ADD_TASK_LABEL: self.add.handle, - MEM_READ_TASK_LABEL: self.mem_read.handle, - MEM_ORGANIZE_TASK_LABEL: self.mem_reorganize.handle, - PREF_ADD_TASK_LABEL: self.pref_add.handle, - MEM_FEEDBACK_TASK_LABEL: self.mem_feedback.handle, - } diff --git a/src/memos/mem_scheduler/task_schedule_modules/base_handler.py b/src/memos/mem_scheduler/task_schedule_modules/base_handler.py new file mode 100644 index 000000000..5e40588e0 --- /dev/null +++ b/src/memos/mem_scheduler/task_schedule_modules/base_handler.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING + +from memos.log import get_logger +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube + + +if TYPE_CHECKING: + from collections.abc import Callable + + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + from memos.mem_scheduler.task_schedule_modules.handlers.context import SchedulerHandlerContext + + +logger = get_logger(__name__) + + +class BaseSchedulerHandler: + def __init__(self, scheduler_context: SchedulerHandlerContext) -> None: + self.scheduler_context = scheduler_context + + @property + @abstractmethod + def expected_task_label(self) -> str: + """The expected task label for this handler.""" + ... + + def validate_and_log_messages(self, messages: list[ScheduleMessageItem], label: str) -> None: + logger.info(f"Messages {messages} assigned to {label} handler.") + self.scheduler_context.services.validate_messages(messages=messages, label=label) + + def handle_exception(self, e: Exception, message: str = "Error processing messages") -> None: + logger.error(f"{message}: {e}", exc_info=True) + + def process_grouped_messages( + self, + messages: list[ScheduleMessageItem], + message_handler: Callable[[str, str, list[ScheduleMessageItem]], None], + ) -> None: + grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) + for user_id, user_batches in grouped_messages.items(): + for mem_cube_id, batch in user_batches.items(): + if not batch: + continue + try: + message_handler(user_id, mem_cube_id, batch) + except Exception as e: + self.handle_exception( + e, f"Error processing batch for user {user_id}, mem_cube {mem_cube_id}" + ) + + @abstractmethod + def batch_handler( + self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem] + ) -> None: ... + + def __call__(self, messages: list[ScheduleMessageItem]) -> None: + """ + Process the messages. 
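+ + Shared flow: validate the batch against expected_task_label, group messages by (user_id, mem_cube_id), and pass each non-empty group to batch_handler; a failure in one group is logged without aborting the others. + + A minimal sketch of a concrete handler built on this base; EchoHandler and its "echo" label are hypothetical, for illustration only: + + class EchoHandler(BaseSchedulerHandler): + @property + def expected_task_label(self) -> str: + return "echo" + + def batch_handler(self, user_id, mem_cube_id, batch): + # Log each message in the already-validated, already-grouped batch. + for msg in batch: + logger.info("echo %s (%s/%s)", msg.content, user_id, mem_cube_id) + + The registry can then map a task label directly to the handler instance, since __call__ runs this pipeline.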
+ """ + self.validate_and_log_messages(messages=messages, label=self.expected_task_label) + + self.process_grouped_messages( + messages=messages, + message_handler=self.batch_handler, + ) diff --git a/src/memos/mem_scheduler/handlers/__init__.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/__init__.py similarity index 100% rename from src/memos/mem_scheduler/handlers/__init__.py rename to src/memos/mem_scheduler/task_schedule_modules/handlers/__init__.py diff --git a/src/memos/mem_scheduler/handlers/add_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py similarity index 80% rename from src/memos/mem_scheduler/handlers/add_handler.py rename to src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py index 5d1a8d3e0..63718fd92 100644 --- a/src/memos/mem_scheduler/handlers/add_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py @@ -5,14 +5,14 @@ from typing import TYPE_CHECKING from memos.log import get_logger -from memos.mem_scheduler.handlers.base import BaseSchedulerHandler from memos.mem_scheduler.schemas.task_schemas import ( ADD_TASK_LABEL, LONG_TERM_MEMORY_TYPE, USER_INPUT_TYPE, ) +from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler from memos.mem_scheduler.utils.filter_utils import transform_name_to_key -from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube, is_cloud_env +from memos.mem_scheduler.utils.misc_utils import is_cloud_env if TYPE_CHECKING: @@ -24,40 +24,30 @@ class AddMessageHandler(BaseSchedulerHandler): - def handle(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {ADD_TASK_LABEL} handler.") - grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) + @property + def expected_task_label(self) -> str: + return ADD_TASK_LABEL - self.ctx.services.validate_messages(messages=messages, label=ADD_TASK_LABEL) - try: - for user_id in grouped_messages: - for mem_cube_id in grouped_messages[user_id]: - batch = grouped_messages[user_id][mem_cube_id] - if not batch: - continue - - for msg in batch: - prepared_add_items, prepared_update_items_with_original = ( - self.log_add_messages(msg=msg) - ) - logger.info( - "prepared_add_items: %s;\n prepared_update_items_with_original: %s", - prepared_add_items, - prepared_update_items_with_original, - ) - cloud_env = is_cloud_env() - - if cloud_env: - self.send_add_log_messages_to_cloud_env( - msg, prepared_add_items, prepared_update_items_with_original - ) - else: - self.send_add_log_messages_to_local_env( - msg, prepared_add_items, prepared_update_items_with_original - ) + def batch_handler( + self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem] + ) -> None: + for msg in batch: + prepared_add_items, prepared_update_items_with_original = self.log_add_messages(msg=msg) + logger.info( + "prepared_add_items: %s;\n prepared_update_items_with_original: %s", + prepared_add_items, + prepared_update_items_with_original, + ) + cloud_env = is_cloud_env() - except Exception as e: - logger.error(f"Error: {e}", exc_info=True) + if cloud_env: + self.send_add_log_messages_to_cloud_env( + msg, prepared_add_items, prepared_update_items_with_original + ) + else: + self.send_add_log_messages_to_local_env( + msg, prepared_add_items, prepared_update_items_with_original + ) def log_add_messages(self, msg: ScheduleMessageItem): try: @@ -70,7 +60,7 @@ def log_add_messages(self, msg: ScheduleMessageItem): 
prepared_update_items_with_original = [] missing_ids: list[str] = [] - mem_cube = self.ctx.get_mem_cube() + mem_cube = self.scheduler_context.get_mem_cube() for memory_id in userinput_memory_ids: try: @@ -207,38 +197,38 @@ def send_add_log_messages_to_local_env( events = [] if add_content_legacy: - event = self.ctx.services.create_event_log( + event = self.scheduler_context.services.create_event_log( label="addMemory", from_memory_type=USER_INPUT_TYPE, to_memory_type=LONG_TERM_MEMORY_TYPE, user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, - mem_cube=self.ctx.get_mem_cube(), + mem_cube=self.scheduler_context.get_mem_cube(), memcube_log_content=add_content_legacy, metadata=add_meta_legacy, memory_len=len(add_content_legacy), - memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id), + memcube_name=self.scheduler_context.services.map_memcube_name(msg.mem_cube_id), ) event.task_id = msg.task_id events.append(event) if update_content_legacy: - event = self.ctx.services.create_event_log( + event = self.scheduler_context.services.create_event_log( label="updateMemory", from_memory_type=LONG_TERM_MEMORY_TYPE, to_memory_type=LONG_TERM_MEMORY_TYPE, user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, - mem_cube=self.ctx.get_mem_cube(), + mem_cube=self.scheduler_context.get_mem_cube(), memcube_log_content=update_content_legacy, metadata=update_meta_legacy, memory_len=len(update_content_legacy), - memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id), + memcube_name=self.scheduler_context.services.map_memcube_name(msg.mem_cube_id), ) event.task_id = msg.task_id events.append(event) logger.info("send_add_log_messages_to_local_env: %s", len(events)) if events: - self.ctx.services.submit_web_logs( + self.scheduler_context.services.submit_web_logs( events, additional_log_info="send_add_log_messages_to_local_env" ) @@ -292,18 +282,18 @@ def send_add_log_messages_to_cloud_env( msg.task_id, json.dumps(kb_log_content, indent=2), ) - event = self.ctx.services.create_event_log( + event = self.scheduler_context.services.create_event_log( label="knowledgeBaseUpdate", from_memory_type=USER_INPUT_TYPE, to_memory_type=LONG_TERM_MEMORY_TYPE, user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, - mem_cube=self.ctx.get_mem_cube(), + mem_cube=self.scheduler_context.get_mem_cube(), memcube_log_content=kb_log_content, metadata=None, memory_len=len(kb_log_content), - memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id), + memcube_name=self.scheduler_context.services.map_memcube_name(msg.mem_cube_id), ) event.log_content = f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
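+ # A single knowledgeBaseUpdate event aggregates every kb_log_content entry for this message; memory_len and log_content above both reflect the total number of knowledge-base changes.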
event.task_id = msg.task_id - self.ctx.services.submit_web_logs([event]) + self.scheduler_context.services.submit_web_logs([event]) diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/answer_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/answer_handler.py new file mode 100644 index 000000000..8ca56f859 --- /dev/null +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/answer_handler.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from memos.log import get_logger +from memos.mem_scheduler.schemas.task_schemas import ( + ANSWER_TASK_LABEL, + NOT_APPLICABLE_TYPE, + USER_INPUT_TYPE, +) +from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler + + +logger = get_logger(__name__) + +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + + +class AnswerMessageHandler(BaseSchedulerHandler): + @property + def expected_task_label(self) -> str: + return ANSWER_TASK_LABEL + + def batch_handler( + self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem] + ) -> None: + for msg in batch: + event = self.scheduler_context.services.create_event_log( + label="addMessage", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=NOT_APPLICABLE_TYPE, + user_id=msg.user_id, + mem_cube_id=msg.mem_cube_id, + mem_cube=self.scheduler_context.get_mem_cube(), + memcube_log_content=[ + { + "content": f"[Assistant] {msg.content}", + "ref_id": msg.item_id, + "role": "assistant", + } + ], + metadata=[], + memory_len=1, + memcube_name=self.scheduler_context.services.map_memcube_name(msg.mem_cube_id), + ) + event.task_id = msg.task_id + self.scheduler_context.services.submit_web_logs([event]) diff --git a/src/memos/mem_scheduler/handlers/context.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/context.py similarity index 100% rename from src/memos/mem_scheduler/handlers/context.py rename to src/memos/mem_scheduler/task_schedule_modules/handlers/context.py diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/feedback_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/feedback_handler.py new file mode 100644 index 000000000..173d37b50 --- /dev/null +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/feedback_handler.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import json + +from typing import TYPE_CHECKING + +from memos.log import get_logger +from memos.mem_scheduler.schemas.task_schemas import ( + LONG_TERM_MEMORY_TYPE, + MEM_FEEDBACK_TASK_LABEL, + USER_INPUT_TYPE, +) +from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler +from memos.mem_scheduler.utils.misc_utils import is_cloud_env + + +logger = get_logger(__name__) + +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + + +class FeedbackMessageHandler(BaseSchedulerHandler): + @property + def expected_task_label(self) -> str: + return MEM_FEEDBACK_TASK_LABEL + + def batch_handler( + self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem] + ) -> None: + for message in batch: + try: + self.process_single_feedback(message) + except Exception as e: + logger.error( + "Error processing feedbackMemory message: %s", + e, + exc_info=True, + ) + + def process_single_feedback(self, message: ScheduleMessageItem) -> None: + mem_cube = self.scheduler_context.get_mem_cube() + + user_id = message.user_id + mem_cube_id = message.mem_cube_id + content = 
message.content + + try: + feedback_data = json.loads(content) if isinstance(content, str) else content + if not isinstance(feedback_data, dict): + logger.error( + "Failed to decode feedback_data or it is not a dict: %s", feedback_data + ) + return + except json.JSONDecodeError: + logger.error("Invalid JSON content for feedback message: %s", content, exc_info=True) + return + + task_id = feedback_data.get("task_id") or message.task_id + feedback_result = self.scheduler_context.get_feedback_server().process_feedback( + user_id=user_id, + user_name=mem_cube_id, + session_id=feedback_data.get("session_id"), + chat_history=feedback_data.get("history", []), + retrieved_memory_ids=feedback_data.get("retrieved_memory_ids", []), + feedback_content=feedback_data.get("feedback_content"), + feedback_time=feedback_data.get("feedback_time"), + task_id=task_id, + info=feedback_data.get("info", None), + ) + + logger.info( + "Successfully processed feedback for user_id=%s, mem_cube_id=%s", + user_id, + mem_cube_id, + ) + + cloud_env = is_cloud_env() + if cloud_env: + record = feedback_result.get("record") if isinstance(feedback_result, dict) else {} + add_records = record.get("add") if isinstance(record, dict) else [] + update_records = record.get("update") if isinstance(record, dict) else [] + + def _extract_fields(mem_item): + mem_id = ( + getattr(mem_item, "id", None) + if not isinstance(mem_item, dict) + else mem_item.get("id") + ) + mem_memory = ( + getattr(mem_item, "memory", None) + if not isinstance(mem_item, dict) + else mem_item.get("memory") or mem_item.get("text") + ) + if mem_memory is None and isinstance(mem_item, dict): + mem_memory = mem_item.get("text") + original_content = ( + getattr(mem_item, "origin_memory", None) + if not isinstance(mem_item, dict) + else mem_item.get("origin_memory") + or mem_item.get("old_memory") + or mem_item.get("original_content") + ) + source_doc_id = None + if isinstance(mem_item, dict): + source_doc_id = mem_item.get("source_doc_id", None) + + return mem_id, mem_memory, original_content, source_doc_id + + kb_log_content: list[dict] = [] + + for mem_item in add_records or []: + mem_id, mem_memory, _, source_doc_id = _extract_fields(mem_item) + if mem_id and mem_memory: + kb_log_content.append( + { + "log_source": "KNOWLEDGE_BASE_LOG", + "trigger_source": "Feedback", + "operation": "ADD", + "memory_id": mem_id, + "content": mem_memory, + "original_content": None, + "source_doc_id": source_doc_id, + } + ) + else: + logger.warning( + "Skipping malformed feedback add item. user_id=%s mem_cube_id=%s task_id=%s item=%s", + user_id, + mem_cube_id, + task_id, + mem_item, + stack_info=True, + ) + + for mem_item in update_records or []: + mem_id, mem_memory, original_content, source_doc_id = _extract_fields(mem_item) + if mem_id and mem_memory: + kb_log_content.append( + { + "log_source": "KNOWLEDGE_BASE_LOG", + "trigger_source": "Feedback", + "operation": "UPDATE", + "memory_id": mem_id, + "content": mem_memory, + "original_content": original_content, + "source_doc_id": source_doc_id, + } + ) + else: + logger.warning( + "Skipping malformed feedback update item. user_id=%s mem_cube_id=%s task_id=%s item=%s", + user_id, + mem_cube_id, + task_id, + mem_item, + stack_info=True, + ) + + logger.info("[Feedback Scheduler] kb_log_content: %s", kb_log_content) + if kb_log_content: + logger.info( + "[DIAGNOSTIC] feedback_handler: Creating knowledgeBaseUpdate event for feedback. 
user_id=%s mem_cube_id=%s task_id=%s items=%s", + user_id, + mem_cube_id, + task_id, + len(kb_log_content), + ) + event = self.scheduler_context.services.create_event_log( + label="knowledgeBaseUpdate", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + memcube_log_content=kb_log_content, + metadata=None, + memory_len=len(kb_log_content), + memcube_name=self.scheduler_context.services.map_memcube_name(mem_cube_id), + ) + event.log_content = f"Knowledge Base Memory Update: {len(kb_log_content)} changes." + event.task_id = task_id + self.scheduler_context.services.submit_web_logs([event]) + else: + logger.warning( + "No valid feedback content generated for web log. user_id=%s mem_cube_id=%s task_id=%s", + user_id, + mem_cube_id, + task_id, + stack_info=True, + ) + else: + logger.info( + "Skipping web log for feedback. Not in a cloud environment (is_cloud_env=%s)", + cloud_env, + ) diff --git a/src/memos/mem_scheduler/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py similarity index 78% rename from src/memos/mem_scheduler/handlers/mem_read_handler.py rename to src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py index 76789f113..3bbed09e3 100644 --- a/src/memos/mem_scheduler/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py @@ -9,12 +9,12 @@ from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger -from memos.mem_scheduler.handlers.base import BaseSchedulerHandler from memos.mem_scheduler.schemas.task_schemas import ( LONG_TERM_MEMORY_TYPE, MEM_READ_TASK_LABEL, USER_INPUT_TYPE, ) +from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler from memos.mem_scheduler.utils.filter_utils import transform_name_to_key from memos.mem_scheduler.utils.misc_utils import is_cloud_env from memos.memories.textual.tree import TreeTextMemory @@ -27,76 +27,80 @@ class MemReadMessageHandler(BaseSchedulerHandler): - def handle(self, messages: list[ScheduleMessageItem]) -> None: + @property + def expected_task_label(self) -> str: + return MEM_READ_TASK_LABEL + + def batch_handler( + self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem] + ) -> None: logger.info( - "[DIAGNOSTIC] mem_read_handler called. Received messages: %s", - [msg.model_dump_json(indent=2) for msg in messages], + "[DIAGNOSTIC] mem_read_handler batch_handler called. 
Batch size: %s", len(batch) ) - logger.info(f"Messages {messages} assigned to {MEM_READ_TASK_LABEL} handler.") - - def process_message(message: ScheduleMessageItem): - try: - user_id = message.user_id - mem_cube_id = message.mem_cube_id - mem_cube = self.ctx.get_mem_cube() - if mem_cube is None: - logger.error( - "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing", - user_id, - mem_cube_id, - stack_info=True, - ) - return - content = message.content - user_name = message.user_name - info = message.info or {} - chat_history = message.chat_history - - mem_ids = json.loads(content) if isinstance(content, str) else content - if not mem_ids: - return + with ContextThreadPoolExecutor(max_workers=min(8, len(batch))) as executor: + futures = [executor.submit(self.process_message, msg) for msg in batch] + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + logger.error("Thread task failed: %s", e, stack_info=True) - logger.info( - "Processing mem_read for user_id=%s, mem_cube_id=%s, mem_ids=%s", + def process_message(self, message: ScheduleMessageItem): + try: + user_id = message.user_id + mem_cube_id = message.mem_cube_id + mem_cube = self.scheduler_context.get_mem_cube() + if mem_cube is None: + logger.error( + "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing", user_id, mem_cube_id, - mem_ids, + stack_info=True, ) + return - text_mem = mem_cube.text_mem - if not isinstance(text_mem, TreeTextMemory): - logger.error("Expected TreeTextMemory but got %s", type(text_mem).__name__) - return + content = message.content + user_name = message.user_name + info = message.info or {} + chat_history = message.chat_history - self._process_memories_with_reader( - mem_ids=mem_ids, - user_id=user_id, - mem_cube_id=mem_cube_id, - text_mem=text_mem, - user_name=user_name, - custom_tags=info.get("custom_tags", None), - task_id=message.task_id, - info=info, - chat_history=chat_history, - ) + mem_ids = json.loads(content) if isinstance(content, str) else content + if not mem_ids: + return - logger.info( - "Successfully processed mem_read for user_id=%s, mem_cube_id=%s", - user_id, - mem_cube_id, - ) + logger.info( + "Processing mem_read for user_id=%s, mem_cube_id=%s, mem_ids=%s", + user_id, + mem_cube_id, + mem_ids, + ) - except Exception as e: - logger.error("Error processing mem_read message: %s", e, stack_info=True) + text_mem = mem_cube.text_mem + if not isinstance(text_mem, TreeTextMemory): + logger.error("Expected TreeTextMemory but got %s", type(text_mem).__name__) + return - with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: - futures = [executor.submit(process_message, msg) for msg in messages] - for future in concurrent.futures.as_completed(futures): - try: - future.result() - except Exception as e: - logger.error("Thread task failed: %s", e, stack_info=True) + self._process_memories_with_reader( + mem_ids=mem_ids, + user_id=user_id, + mem_cube_id=mem_cube_id, + text_mem=text_mem, + user_name=user_name, + custom_tags=info.get("custom_tags", None), + task_id=message.task_id, + info=info, + chat_history=chat_history, + ) + + logger.info( + "Successfully processed mem_read for user_id=%s, mem_cube_id=%s", + user_id, + mem_cube_id, + ) + + except Exception as e: + logger.error("Error processing mem_read message: %s", e, stack_info=True) def _process_memories_with_reader( self, @@ -119,7 +123,7 @@ def _process_memories_with_reader( ) kb_log_content: list[dict] = [] try: - mem_reader = 
self.ctx.get_mem_reader() + mem_reader = self.scheduler_context.get_mem_reader() if mem_reader is None: logger.warning( "mem_reader not available in scheduler, skipping enhanced processing" @@ -244,23 +248,25 @@ def _process_memories_with_reader( task_id, json.dumps(kb_log_content, indent=2), ) - event = self.ctx.services.create_event_log( + event = self.scheduler_context.services.create_event_log( label="knowledgeBaseUpdate", from_memory_type=USER_INPUT_TYPE, to_memory_type=LONG_TERM_MEMORY_TYPE, user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=self.ctx.get_mem_cube(), + mem_cube=self.scheduler_context.get_mem_cube(), memcube_log_content=kb_log_content, metadata=None, memory_len=len(kb_log_content), - memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), + memcube_name=self.scheduler_context.services.map_memcube_name( + mem_cube_id + ), ) event.log_content = ( f"Knowledge Base Memory Update: {len(kb_log_content)} changes." ) event.task_id = task_id - self.ctx.services.submit_web_logs([event]) + self.scheduler_context.services.submit_web_logs([event]) else: add_content_legacy: list[dict] = [] add_meta_legacy: list[dict] = [] @@ -288,20 +294,22 @@ def _process_memories_with_reader( } ) if add_content_legacy: - event = self.ctx.services.create_event_log( + event = self.scheduler_context.services.create_event_log( label="addMemory", from_memory_type=USER_INPUT_TYPE, to_memory_type=LONG_TERM_MEMORY_TYPE, user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=self.ctx.get_mem_cube(), + mem_cube=self.scheduler_context.get_mem_cube(), memcube_log_content=add_content_legacy, metadata=add_meta_legacy, memory_len=len(add_content_legacy), - memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), + memcube_name=self.scheduler_context.services.map_memcube_name( + mem_cube_id + ), ) event.task_id = task_id - self.ctx.services.submit_web_logs([event]) + self.scheduler_context.services.submit_web_logs([event]) else: logger.info("No enhanced memories generated by mem_reader") else: @@ -351,19 +359,19 @@ def _process_memories_with_reader( } for mem_id in mem_ids ] - event = self.ctx.services.create_event_log( + event = self.scheduler_context.services.create_event_log( label="knowledgeBaseUpdate", from_memory_type=USER_INPUT_TYPE, to_memory_type=LONG_TERM_MEMORY_TYPE, user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=self.ctx.get_mem_cube(), + mem_cube=self.scheduler_context.get_mem_cube(), memcube_log_content=kb_log_content, metadata=None, memory_len=len(kb_log_content), - memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), + memcube_name=self.scheduler_context.services.map_memcube_name(mem_cube_id), ) event.log_content = f"Knowledge Base Memory Update failed: {exc!s}" event.task_id = task_id event.status = "failed" - self.ctx.services.submit_web_logs([event]) + self.scheduler_context.services.submit_web_logs([event]) diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_reorganize_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_reorganize_handler.py new file mode 100644 index 000000000..1d1dbe566 --- /dev/null +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_reorganize_handler.py @@ -0,0 +1,252 @@ +from __future__ import annotations + +import concurrent.futures +import contextlib +import json +import traceback + +from typing import TYPE_CHECKING + +from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger +from memos.mem_scheduler.schemas.task_schemas import ( + LONG_TERM_MEMORY_TYPE, 
+ MEM_ORGANIZE_TASK_LABEL, +) +from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler +from memos.mem_scheduler.utils.filter_utils import transform_name_to_key +from memos.memories.textual.tree import TreeTextMemory + + +logger = get_logger(__name__) + +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + from memos.memories.textual.item import TextualMemoryItem + + +class MemReorganizeMessageHandler(BaseSchedulerHandler): + @property + def expected_task_label(self) -> str: + return MEM_ORGANIZE_TASK_LABEL + + def batch_handler( + self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem] + ) -> None: + with ContextThreadPoolExecutor(max_workers=min(8, len(batch))) as executor: + futures = [executor.submit(self.process_message, msg) for msg in batch] + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + logger.error("Thread task failed: %s", e, exc_info=True) + + def process_message(self, message: ScheduleMessageItem): + try: + user_id = message.user_id + mem_cube_id = message.mem_cube_id + mem_cube = self.scheduler_context.get_mem_cube() + if mem_cube is None: + logger.warning( + "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing", + user_id, + mem_cube_id, + ) + return + content = message.content + user_name = message.user_name + + mem_ids = json.loads(content) if isinstance(content, str) else content + if not mem_ids: + return + + logger.info( + "Processing mem_reorganize for user_id=%s, mem_cube_id=%s, mem_ids=%s", + user_id, + mem_cube_id, + mem_ids, + ) + + text_mem = mem_cube.text_mem + if not isinstance(text_mem, TreeTextMemory): + logger.error("Expected TreeTextMemory but got %s", type(text_mem).__name__) + return + + self._process_memories_with_reorganize( + mem_ids=mem_ids, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + text_mem=text_mem, + user_name=user_name, + ) + + with contextlib.suppress(Exception): + mem_items: list[TextualMemoryItem] = [] + for mid in mem_ids: + with contextlib.suppress(Exception): + mem_items.append(text_mem.get(mid, user_name=user_name)) + if len(mem_items) > 1: + keys: list[str] = [] + memcube_content: list[dict] = [] + meta: list[dict] = [] + merged_target_ids: set[str] = set() + with contextlib.suppress(Exception): + if hasattr(text_mem, "graph_store"): + for mid in mem_ids: + edges = text_mem.graph_store.get_edges( + mid, type="MERGED_TO", direction="OUT" + ) + for edge in edges: + target = edge.get("to") or edge.get("dst") or edge.get("target") + if target: + merged_target_ids.add(target) + for item in mem_items: + key = getattr( + getattr(item, "metadata", {}), "key", None + ) or transform_name_to_key(getattr(item, "memory", "")) + keys.append(key) + memcube_content.append( + {"content": key or "(no key)", "ref_id": item.id, "type": "merged"} + ) + meta.append( + { + "ref_id": item.id, + "id": item.id, + "key": key, + "memory": item.memory, + "memory_type": item.metadata.memory_type, + "status": item.metadata.status, + "confidence": item.metadata.confidence, + "tags": item.metadata.tags, + "updated_at": getattr(item.metadata, "updated_at", None) + or getattr(item.metadata, "update_at", None), + } + ) + combined_key = keys[0] if keys else "" + post_ref_id = None + post_meta = { + "ref_id": None, + "id": None, + "key": None, + "memory": None, + "memory_type": None, + "status": None, + "confidence": None, + "tags": None, + "updated_at": None, + } + if 
merged_target_ids: + post_ref_id = next(iter(merged_target_ids)) + with contextlib.suppress(Exception): + merged_item = text_mem.get(post_ref_id, user_name=user_name) + combined_key = ( + getattr(getattr(merged_item, "metadata", {}), "key", None) + or combined_key + ) + post_meta = { + "ref_id": post_ref_id, + "id": post_ref_id, + "key": getattr(getattr(merged_item, "metadata", {}), "key", None), + "memory": getattr(merged_item, "memory", None), + "memory_type": getattr( + getattr(merged_item, "metadata", {}), "memory_type", None + ), + "status": getattr( + getattr(merged_item, "metadata", {}), "status", None + ), + "confidence": getattr( + getattr(merged_item, "metadata", {}), "confidence", None + ), + "tags": getattr(getattr(merged_item, "metadata", {}), "tags", None), + "updated_at": getattr( + getattr(merged_item, "metadata", {}), "updated_at", None + ) + or getattr(getattr(merged_item, "metadata", {}), "update_at", None), + } + if not post_ref_id: + import hashlib + + post_ref_id = ( + "merge-" + hashlib.md5("".join(sorted(mem_ids)).encode()).hexdigest() + ) + post_meta["ref_id"] = post_ref_id + post_meta["id"] = post_ref_id + if not post_meta.get("key"): + post_meta["key"] = combined_key + if not keys: + keys = [item.id for item in mem_items] + memcube_content.append( + { + "content": combined_key if combined_key else "(no key)", + "ref_id": post_ref_id, + "type": "postMerge", + } + ) + meta.append(post_meta) + event = self.scheduler_context.services.create_event_log( + label="mergeMemory", + from_memory_type=LONG_TERM_MEMORY_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + memcube_log_content=memcube_content, + metadata=meta, + memory_len=len(keys), + memcube_name=self.scheduler_context.services.map_memcube_name(mem_cube_id), + ) + self.scheduler_context.services.submit_web_logs([event]) + + logger.info( + "Successfully processed mem_reorganize for user_id=%s, mem_cube_id=%s", + user_id, + mem_cube_id, + ) + + except Exception as e: + logger.error("Error processing mem_reorganize message: %s", e, exc_info=True) + + def _process_memories_with_reorganize( + self, + mem_ids: list[str], + user_id: str, + mem_cube_id: str, + mem_cube, + text_mem: TreeTextMemory, + user_name: str, + ) -> None: + try: + mem_reader = self.scheduler_context.get_mem_reader() + if mem_reader is None: + logger.warning( + "mem_reader not available in scheduler, skipping enhanced processing" + ) + return + + memory_items = [] + for mem_id in mem_ids: + try: + memory_item = text_mem.get(mem_id, user_name=user_name) + memory_items.append(memory_item) + except Exception as e: + logger.warning( + "Failed to get memory %s: %s|%s", mem_id, e, traceback.format_exc() + ) + continue + + if not memory_items: + logger.warning("No valid memory items found for processing") + return + + logger.info("Processing %s memories with mem_reader", len(memory_items)) + text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) + logger.info("Removed and refreshed memories") + logger.debug("Finished reorganizing memories for user %s: %s", user_id, mem_ids) + + except Exception: + logger.error( + "Error in _process_memories_with_reorganize: %s", + traceback.format_exc(), + exc_info=True, + ) diff --git a/src/memos/mem_scheduler/handlers/memory_update_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/memory_update_handler.py similarity index 86% rename from src/memos/mem_scheduler/handlers/memory_update_handler.py rename to
src/memos/mem_scheduler/task_schedule_modules/handlers/memory_update_handler.py index 0d3d1719e..a8968e878 100644 --- a/src/memos/mem_scheduler/handlers/memory_update_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/memory_update_handler.py @@ -3,15 +3,14 @@ from typing import TYPE_CHECKING from memos.log import get_logger -from memos.mem_scheduler.handlers.base import BaseSchedulerHandler from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem from memos.mem_scheduler.schemas.task_schemas import ( DEFAULT_MAX_QUERY_KEY_WORDS, MEM_UPDATE_TASK_LABEL, QUERY_TASK_LABEL, ) +from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler from memos.mem_scheduler.utils.filter_utils import is_all_chinese, is_all_english -from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube from memos.memories.textual.naive import NaiveTextMemory from memos.memories.textual.tree import TreeTextMemory @@ -25,21 +24,14 @@ class MemoryUpdateHandler(BaseSchedulerHandler): - def handle(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {MEM_UPDATE_TASK_LABEL} handler.") + @property + def expected_task_label(self) -> str: + return MEM_UPDATE_TASK_LABEL - grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - - self.ctx.services.validate_messages(messages=messages, label=MEM_UPDATE_TASK_LABEL) - - for user_id in grouped_messages: - for mem_cube_id in grouped_messages[user_id]: - batch = grouped_messages[user_id][mem_cube_id] - if not batch: - continue - self.long_memory_update_process( - user_id=user_id, mem_cube_id=mem_cube_id, messages=batch - ) + def batch_handler( + self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem] + ) -> None: + self.long_memory_update_process(user_id=user_id, mem_cube_id=mem_cube_id, messages=batch) def long_memory_update_process( self, @@ -47,10 +39,10 @@ def long_memory_update_process( mem_cube_id: str, messages: list[ScheduleMessageItem], ) -> None: - mem_cube = self.ctx.get_mem_cube() - monitor = self.ctx.get_monitor() + mem_cube = self.scheduler_context.get_mem_cube() + monitor = self.scheduler_context.get_monitor() - query_key_words_limit = self.ctx.get_query_key_words_limit() + query_key_words_limit = self.scheduler_context.get_query_key_words_limit() for msg in messages: monitor.register_query_monitor_if_not_exists(user_id=user_id, mem_cube_id=mem_cube_id) @@ -111,7 +103,7 @@ def long_memory_update_process( user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube, - top_k=self.ctx.get_top_k(), + top_k=self.scheduler_context.get_top_k(), ) logger.info( "[long_memory_update_process] Processed %s queries %s and retrieved %s new candidate memories for user_id=%s: " @@ -122,7 +114,7 @@ def long_memory_update_process( user_id, ) - new_order_working_memory = self.ctx.services.replace_working_memory( + new_order_working_memory = self.scheduler_context.services.replace_working_memory( user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube, @@ -158,11 +150,11 @@ def long_memory_update_process( logger.debug( "Activation memory update %s (interval: %ss)", - "enabled" if self.ctx.get_enable_activation_memory() else "disabled", + "enabled" if self.scheduler_context.get_enable_activation_memory() else "disabled", monitor.act_mem_update_interval, ) - if self.ctx.get_enable_activation_memory(): - self.ctx.services.update_activation_memory_periodically( + if self.scheduler_context.get_enable_activation_memory(): + 
self.scheduler_context.services.update_activation_memory_periodically( interval_seconds=monitor.act_mem_update_interval, label=QUERY_TASK_LABEL, user_id=user_id, @@ -207,7 +199,7 @@ def process_session_turn( ) text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory] - monitor = self.ctx.get_monitor() + monitor = self.scheduler_context.get_monitor() intent_result = monitor.detect_intent( q_list=queries, text_working_memory=text_working_memory ) @@ -247,8 +239,8 @@ def process_session_turn( num_evidence = len(missing_evidences) k_per_evidence = max(1, top_k // max(1, num_evidence)) new_candidates: list[TextualMemoryItem] = [] - retriever = self.ctx.get_retriever() - search_method = self.ctx.get_search_method() + retriever = self.scheduler_context.get_retriever() + search_method = self.scheduler_context.get_search_method() for item in missing_evidences: logger.info( diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/pref_add_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/pref_add_handler.py new file mode 100644 index 000000000..1d03e0476 --- /dev/null +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/pref_add_handler.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import concurrent.futures +import json + +from typing import TYPE_CHECKING + +from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger +from memos.mem_scheduler.schemas.task_schemas import PREF_ADD_TASK_LABEL +from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler +from memos.memories.textual.preference import PreferenceTextMemory + + +logger = get_logger(__name__) + +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + + +class PrefAddMessageHandler(BaseSchedulerHandler): + @property + def expected_task_label(self) -> str: + return PREF_ADD_TASK_LABEL + + def batch_handler( + self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem] + ) -> None: + with ContextThreadPoolExecutor(max_workers=min(8, len(batch))) as executor: + futures = [executor.submit(self.process_message, msg) for msg in batch] + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + logger.error("Thread task failed: %s", e, exc_info=True) + + def process_message(self, message: ScheduleMessageItem): + try: + mem_cube = self.scheduler_context.get_mem_cube() + if mem_cube is None: + logger.warning( + "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing", + message.user_id, + message.mem_cube_id, + ) + return + + user_id = message.user_id + session_id = message.session_id + mem_cube_id = message.mem_cube_id + content = message.content + messages_list = json.loads(content) + info = message.info or {} + + logger.info("Processing pref_add for user_id=%s, mem_cube_id=%s", user_id, mem_cube_id) + + pref_mem = mem_cube.pref_mem + if pref_mem is None: + logger.warning( + "Preference memory not initialized for mem_cube_id=%s, skipping pref_add processing", + mem_cube_id, + ) + return + if not isinstance(pref_mem, PreferenceTextMemory): + logger.error( + "Expected PreferenceTextMemory but got %s for mem_cube_id=%s", + type(pref_mem).__name__, + mem_cube_id, + ) + return + + pref_memories = pref_mem.get_memory( + messages_list, + type="chat", + info={ + **info, + "user_id": user_id, + "session_id": session_id, + "mem_cube_id": mem_cube_id, + }, + ) + pref_ids = pref_mem.add(pref_memories) + + logger.info( 
+                "Successfully processed and added preferences for user_id=%s, mem_cube_id=%s, pref_ids=%s",
+                user_id,
+                mem_cube_id,
+                pref_ids,
+            )
+
+        except Exception as e:
+            logger.error("Error processing pref_add message: %s", e, exc_info=True)
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/query_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/query_handler.py
new file mode 100644
index 000000000..29f0a88fd
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/query_handler.py
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import (
+    MEM_UPDATE_TASK_LABEL,
+    NOT_APPLICABLE_TYPE,
+    QUERY_TASK_LABEL,
+    USER_INPUT_TYPE,
+)
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler
+
+
+logger = get_logger(__name__)
+
+
+class QueryMessageHandler(BaseSchedulerHandler):
+    @property
+    def expected_task_label(self) -> str:
+        return QUERY_TASK_LABEL
+
+    def batch_handler(
+        self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]
+    ) -> None:
+        mem_update_messages: list[ScheduleMessageItem] = []
+        for msg in batch:
+            try:
+                event = self.scheduler_context.services.create_event_log(
+                    label="addMessage",
+                    from_memory_type=USER_INPUT_TYPE,
+                    to_memory_type=NOT_APPLICABLE_TYPE,
+                    user_id=msg.user_id,
+                    mem_cube_id=msg.mem_cube_id,
+                    mem_cube=self.scheduler_context.get_mem_cube(),
+                    memcube_log_content=[
+                        {
+                            "content": f"[User] {msg.content}",
+                            "ref_id": msg.item_id,
+                            "role": "user",
+                        }
+                    ],
+                    metadata=[],
+                    memory_len=1,
+                    memcube_name=self.scheduler_context.services.map_memcube_name(msg.mem_cube_id),
+                )
+                event.task_id = msg.task_id
+                self.scheduler_context.services.submit_web_logs([event])
+            except Exception:
+                logger.exception("Failed to record addMessage log for query")
+
+            update_msg = ScheduleMessageItem(
+                user_id=msg.user_id,
+                mem_cube_id=msg.mem_cube_id,
+                label=MEM_UPDATE_TASK_LABEL,
+                content=msg.content,
+                session_id=msg.session_id,
+                user_name=msg.user_name,
+                info=msg.info,
+                task_id=msg.task_id,
+            )
+            mem_update_messages.append(update_msg)
+
+        if mem_update_messages:
+            self.scheduler_context.services.submit_messages(messages=mem_update_messages)
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/registry.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/registry.py
new file mode 100644
index 000000000..98b3bbb53
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/registry.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    from .context import SchedulerHandlerContext
+
+from memos.mem_scheduler.schemas.task_schemas import (
+    ADD_TASK_LABEL,
+    ANSWER_TASK_LABEL,
+    MEM_FEEDBACK_TASK_LABEL,
+    MEM_ORGANIZE_TASK_LABEL,
+    MEM_READ_TASK_LABEL,
+    MEM_UPDATE_TASK_LABEL,
+    PREF_ADD_TASK_LABEL,
+    QUERY_TASK_LABEL,
+)
+
+from .add_handler import AddMessageHandler
+from .answer_handler import AnswerMessageHandler
+from .feedback_handler import FeedbackMessageHandler
+from .mem_read_handler import MemReadMessageHandler
+from .mem_reorganize_handler import MemReorganizeMessageHandler
+from .memory_update_handler import MemoryUpdateHandler
+from .pref_add_handler import PrefAddMessageHandler
+from .query_handler import QueryMessageHandler
+
+
+class SchedulerHandlerRegistry:
+    def 
__init__(self, ctx: SchedulerHandlerContext) -> None: + self.query = QueryMessageHandler(ctx) + self.answer = AnswerMessageHandler(ctx) + self.add = AddMessageHandler(ctx) + self.memory_update = MemoryUpdateHandler(ctx) + self.mem_feedback = FeedbackMessageHandler(ctx) + self.mem_read = MemReadMessageHandler(ctx) + self.mem_reorganize = MemReorganizeMessageHandler(ctx) + self.pref_add = PrefAddMessageHandler(ctx) + + def build_dispatch_map(self) -> dict[str, Callable]: + return { + QUERY_TASK_LABEL: self.query, + ANSWER_TASK_LABEL: self.answer, + MEM_UPDATE_TASK_LABEL: self.memory_update, + ADD_TASK_LABEL: self.add, + MEM_READ_TASK_LABEL: self.mem_read, + MEM_ORGANIZE_TASK_LABEL: self.mem_reorganize, + PREF_ADD_TASK_LABEL: self.pref_add, + MEM_FEEDBACK_TASK_LABEL: self.mem_feedback, + } From 92d27254e8bc03672ba89ac5c584bc5fa3c8b6a4 Mon Sep 17 00:00:00 2001 From: chentang Date: Thu, 5 Feb 2026 14:42:02 +0800 Subject: [PATCH 09/14] refactor: rename ctx to scheduler_context in mem_scheduler Rename abbreviation 'ctx' to 'scheduler_context' in GeneralScheduler and SchedulerHandlerRegistry to improve code readability and clarity. --- src/memos/mem_scheduler/general_scheduler.py | 4 ++-- .../task_schedule_modules/handlers/registry.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 9860e3984..6d6f38d95 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -31,7 +31,7 @@ def __init__(self, config: GeneralSchedulerConfig): transform_working_memories_to_monitors=self.transform_working_memories_to_monitors, log_working_memory_replacement=self.log_working_memory_replacement, ) - ctx = SchedulerHandlerContext( + scheduler_context = SchedulerHandlerContext( get_mem_cube=lambda: self.mem_cube, get_monitor=lambda: self.monitor, get_retriever=lambda: self.retriever, @@ -44,5 +44,5 @@ def __init__(self, config: GeneralSchedulerConfig): services=services, ) - self._handler_registry = SchedulerHandlerRegistry(ctx) + self._handler_registry = SchedulerHandlerRegistry(scheduler_context) self.register_handlers(self._handler_registry.build_dispatch_map()) diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/registry.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/registry.py index 98b3bbb53..8b12b44ba 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/registry.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/registry.py @@ -30,15 +30,15 @@ class SchedulerHandlerRegistry: - def __init__(self, ctx: SchedulerHandlerContext) -> None: - self.query = QueryMessageHandler(ctx) - self.answer = AnswerMessageHandler(ctx) - self.add = AddMessageHandler(ctx) - self.memory_update = MemoryUpdateHandler(ctx) - self.mem_feedback = FeedbackMessageHandler(ctx) - self.mem_read = MemReadMessageHandler(ctx) - self.mem_reorganize = MemReorganizeMessageHandler(ctx) - self.pref_add = PrefAddMessageHandler(ctx) + def __init__(self, scheduler_context: SchedulerHandlerContext) -> None: + self.query = QueryMessageHandler(scheduler_context) + self.answer = AnswerMessageHandler(scheduler_context) + self.add = AddMessageHandler(scheduler_context) + self.memory_update = MemoryUpdateHandler(scheduler_context) + self.mem_feedback = FeedbackMessageHandler(scheduler_context) + self.mem_read = MemReadMessageHandler(scheduler_context) + self.mem_reorganize = 
MemReorganizeMessageHandler(scheduler_context) + self.pref_add = PrefAddMessageHandler(scheduler_context) def build_dispatch_map(self) -> dict[str, Callable]: return { From 48701b7bac9ca5f591a316d9b76ac73071271ba0 Mon Sep 17 00:00:00 2001 From: chentang Date: Thu, 5 Feb 2026 15:31:56 +0800 Subject: [PATCH 10/14] refactor: sync orchestrator config and fix scheduler imports - Sync orchestrator config removal when unregistering handlers in dispatcher - Fix missing TaskPriorityLevel import in dispatcher - Fix register_handlers signature in BaseSchedulerQueueMixin - Fix handler registry imports and initialization map - Fix relative imports in handlers package --- .../mem_scheduler/base_mixins/queue_ops.py | 9 +++- .../{handlers => }/context.py | 0 .../task_schedule_modules/dispatcher.py | 52 ++++++++++++++++--- .../handlers/__init__.py | 7 ++- .../task_schedule_modules/orchestrator.py | 40 +++++++++++--- .../{handlers => }/registry.py | 28 +++++----- 6 files changed, 105 insertions(+), 31 deletions(-) rename src/memos/mem_scheduler/task_schedule_modules/{handlers => }/context.py (100%) rename src/memos/mem_scheduler/task_schedule_modules/{handlers => }/registry.py (60%) diff --git a/src/memos/mem_scheduler/base_mixins/queue_ops.py b/src/memos/mem_scheduler/base_mixins/queue_ops.py index e5709ff36..590189c24 100644 --- a/src/memos/mem_scheduler/base_mixins/queue_ops.py +++ b/src/memos/mem_scheduler/base_mixins/queue_ops.py @@ -336,7 +336,14 @@ def handlers(self) -> dict[str, Callable]: return self.dispatcher.handlers def register_handlers( - self, handlers: dict[str, Callable[[list[ScheduleMessageItem]], None]] + self, + handlers: dict[ + str, + Callable[[list[ScheduleMessageItem]], None] + | tuple[ + Callable[[list[ScheduleMessageItem]], None], TaskPriorityLevel | None, int | None + ], + ], ) -> None: if not self.dispatcher: logger.warning("Dispatcher is not initialized, cannot register handlers") diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/context.py b/src/memos/mem_scheduler/task_schedule_modules/context.py similarity index 100% rename from src/memos/mem_scheduler/task_schedule_modules/handlers/context.py rename to src/memos/mem_scheduler/task_schedule_modules/context.py diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index 2099da5a1..74ab15209 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py +++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -20,7 +20,7 @@ DEFAULT_STOP_WAIT, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem, ScheduleMessageItem -from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem +from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem, TaskPriorityLevel from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue @@ -428,34 +428,70 @@ def get_running_task_count(self) -> int: with self._task_lock: return len(self._running_tasks) - def register_handler(self, label: str, handler: Callable[[list[ScheduleMessageItem]], None]): + def register_handler( + self, + label: str, + handler: Callable[[list[ScheduleMessageItem]], None], + priority: TaskPriorityLevel | None = None, + min_idle_ms: int | None = None, + ): """ Register a handler function for a specific message label. 
Args: label: Message label to handle handler: Callable that processes messages of this label + priority: Optional priority level for the task + min_idle_ms: Optional minimum idle time for task claiming """ self.handlers[label] = handler + if self.orchestrator: + self.orchestrator.set_task_config( + task_label=label, priority=priority, min_idle_ms=min_idle_ms + ) def register_handlers( - self, handlers: dict[str, Callable[[list[ScheduleMessageItem]], None]] + self, + handlers: dict[ + str, + Callable[[list[ScheduleMessageItem]], None] + | tuple[ + Callable[[list[ScheduleMessageItem]], None], TaskPriorityLevel | None, int | None + ], + ], ) -> None: """ Bulk register multiple handlers from a dictionary. Args: - handlers: Dictionary mapping labels to handler functions - Format: {label: handler_callable} + handlers: Dictionary where key is label and value is either: + - handler_callable + - tuple(handler_callable, priority, min_idle_ms) """ - for label, handler in handlers.items(): + for label, value in handlers.items(): if not isinstance(label, str): logger.error(f"Invalid label type: {type(label)}. Expected str.") continue + + if isinstance(value, tuple): + if len(value) != 3: + logger.error( + f"Invalid handler tuple for label '{label}'. Expected (handler, priority, min_idle_ms)." + ) + continue + handler, priority, min_idle_ms = value + else: + handler = value + priority = None + min_idle_ms = None + if not callable(handler): logger.error(f"Handler for label '{label}' is not callable.") continue - self.register_handler(label=label, handler=handler) + + self.register_handler( + label=label, handler=handler, priority=priority, min_idle_ms=min_idle_ms + ) logger.info(f"Registered {len(handlers)} handlers in bulk") def unregister_handler(self, label: str) -> bool: @@ -470,6 +506,8 @@ def unregister_handler(self, label: str) -> bool: """ if label in self.handlers: del self.handlers[label] + if self.orchestrator: + self.orchestrator.remove_task_config(label) logger.info(f"Unregistered handler for label: {label}") return True else: diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/__init__.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/__init__.py index 75c56791a..e5700e641 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/__init__.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/__init__.py @@ -1,5 +1,8 @@ -from .context import SchedulerHandlerContext, SchedulerHandlerServices -from .registry import SchedulerHandlerRegistry +from memos.mem_scheduler.task_schedule_modules.context import ( + SchedulerHandlerContext, + SchedulerHandlerServices, +) +from memos.mem_scheduler.task_schedule_modules.registry import SchedulerHandlerRegistry __all__ = [ diff --git a/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py b/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py index cb5a49421..af46b3dcd 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py +++ b/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py @@ -17,11 +17,8 @@ from memos.log import get_logger from memos.mem_scheduler.schemas.task_schemas import ( - ADD_TASK_LABEL, - ANSWER_TASK_LABEL, DEFAULT_PENDING_CLAIM_MIN_IDLE_MS, PREF_ADD_TASK_LABEL, - QUERY_TASK_LABEL, TaskPriorityLevel, ) from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule @@ -38,11 +35,7 @@ def __init__(self): """ # Cache of fetched messages grouped by (user_id, mem_cube_id, task_label) self._cache = None - self.tasks_priorities = { - 
ADD_TASK_LABEL: TaskPriorityLevel.LEVEL_1, - QUERY_TASK_LABEL: TaskPriorityLevel.LEVEL_1, - ANSWER_TASK_LABEL: TaskPriorityLevel.LEVEL_1, - } + self.tasks_priorities = {} # Per-task minimum idle time (ms) before claiming pending messages # Default fallback handled in `get_task_idle_min`. @@ -54,6 +47,37 @@ def __init__(self): def get_stream_priorities(self) -> None | dict: return None + def set_task_config( + self, + task_label: str, + priority: TaskPriorityLevel | None = None, + min_idle_ms: int | None = None, + ): + """ + Dynamically register or update task configuration. + + Args: + task_label: The label of the task. + priority: The priority level of the task. + min_idle_ms: The minimum idle time (ms) for claiming pending messages. + """ + if priority is not None: + self.tasks_priorities[task_label] = priority + if min_idle_ms is not None: + self.tasks_min_idle_ms[task_label] = min_idle_ms + + def remove_task_config(self, task_label: str): + """ + Remove task configuration for a specific label. + + Args: + task_label: The label of the task to remove configuration for. + """ + if task_label in self.tasks_priorities: + del self.tasks_priorities[task_label] + if task_label in self.tasks_min_idle_ms: + del self.tasks_min_idle_ms[task_label] + def get_task_priority(self, task_label: str): return self.tasks_priorities.get(task_label, TaskPriorityLevel.LEVEL_3) diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/registry.py b/src/memos/mem_scheduler/task_schedule_modules/registry.py similarity index 60% rename from src/memos/mem_scheduler/task_schedule_modules/handlers/registry.py rename to src/memos/mem_scheduler/task_schedule_modules/registry.py index 8b12b44ba..962c8b954 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/registry.py +++ b/src/memos/mem_scheduler/task_schedule_modules/registry.py @@ -17,16 +17,17 @@ MEM_UPDATE_TASK_LABEL, PREF_ADD_TASK_LABEL, QUERY_TASK_LABEL, + TaskPriorityLevel, ) -from .add_handler import AddMessageHandler -from .answer_handler import AnswerMessageHandler -from .feedback_handler import FeedbackMessageHandler -from .mem_read_handler import MemReadMessageHandler -from .mem_reorganize_handler import MemReorganizeMessageHandler -from .memory_update_handler import MemoryUpdateHandler -from .pref_add_handler import PrefAddMessageHandler -from .query_handler import QueryMessageHandler +from .handlers.add_handler import AddMessageHandler +from .handlers.answer_handler import AnswerMessageHandler +from .handlers.feedback_handler import FeedbackMessageHandler +from .handlers.mem_read_handler import MemReadMessageHandler +from .handlers.mem_reorganize_handler import MemReorganizeMessageHandler +from .handlers.memory_update_handler import MemoryUpdateHandler +from .handlers.pref_add_handler import PrefAddMessageHandler +from .handlers.query_handler import QueryMessageHandler class SchedulerHandlerRegistry: @@ -41,13 +42,14 @@ def __init__(self, scheduler_context: SchedulerHandlerContext) -> None: self.pref_add = PrefAddMessageHandler(scheduler_context) def build_dispatch_map(self) -> dict[str, Callable]: - return { - QUERY_TASK_LABEL: self.query, - ANSWER_TASK_LABEL: self.answer, + predefined_handlers = { + QUERY_TASK_LABEL: (self.query, TaskPriorityLevel.LEVEL_1, None), + ANSWER_TASK_LABEL: (self.answer, TaskPriorityLevel.LEVEL_1, None), MEM_UPDATE_TASK_LABEL: self.memory_update, - ADD_TASK_LABEL: self.add, + ADD_TASK_LABEL: (self.add, TaskPriorityLevel.LEVEL_1, None), MEM_READ_TASK_LABEL: self.mem_read, MEM_ORGANIZE_TASK_LABEL: 
self.mem_reorganize, - PREF_ADD_TASK_LABEL: self.pref_add, + PREF_ADD_TASK_LABEL: (self.pref_add, None, 600_000), MEM_FEEDBACK_TASK_LABEL: self.mem_feedback, } + return predefined_handlers From 9031d0f19552a7bb6bf63945e324d1f3c2880f10 Mon Sep 17 00:00:00 2001 From: "glin1993@outlook.com" <> Date: Thu, 5 Feb 2026 18:11:50 +0800 Subject: [PATCH 11/14] fix: resolve PR #1 review issues (P0 imports, P3 types/init) --- src/memos/mem_scheduler/base_scheduler.py | 3 +-- .../mem_scheduler/memory_manage_modules/post_processor.py | 2 +- src/memos/mem_scheduler/task_schedule_modules/base_handler.py | 2 +- src/memos/mem_scheduler/task_schedule_modules/registry.py | 2 +- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 2e408b222..f733615e0 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -229,8 +229,7 @@ def initialize_modules( self.dispatcher_monitor = SchedulerDispatcherMonitor(config=self.config) self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config) - # Initialize search service (will be updated with searcher when mem_cube is initialized) - self.search_service = SchedulerSearchService(searcher=self.searcher) + # Initialize post-processor for memory enhancement and filtering self.post_processor = MemoryPostProcessor( diff --git a/src/memos/mem_scheduler/memory_manage_modules/post_processor.py b/src/memos/mem_scheduler/memory_manage_modules/post_processor.py index 2e1821e1e..569f07667 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/post_processor.py +++ b/src/memos/mem_scheduler/memory_manage_modules/post_processor.py @@ -22,7 +22,7 @@ transform_name_to_key, ) from memos.memories.textual.item import TextualMemoryItem -from memos.utils import extract_json_obj +from memos.mem_scheduler.utils.misc_utils import extract_json_obj from .memory_filter import MemoryFilter diff --git a/src/memos/mem_scheduler/task_schedule_modules/base_handler.py b/src/memos/mem_scheduler/task_schedule_modules/base_handler.py index 5e40588e0..603b038e1 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/base_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/base_handler.py @@ -11,7 +11,7 @@ from collections.abc import Callable from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem - from memos.mem_scheduler.task_schedule_modules.handlers.context import SchedulerHandlerContext + from memos.mem_scheduler.task_schedule_modules.context import SchedulerHandlerContext logger = get_logger(__name__) diff --git a/src/memos/mem_scheduler/task_schedule_modules/registry.py b/src/memos/mem_scheduler/task_schedule_modules/registry.py index 962c8b954..f47be933e 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/registry.py +++ b/src/memos/mem_scheduler/task_schedule_modules/registry.py @@ -41,7 +41,7 @@ def __init__(self, scheduler_context: SchedulerHandlerContext) -> None: self.mem_reorganize = MemReorganizeMessageHandler(scheduler_context) self.pref_add = PrefAddMessageHandler(scheduler_context) - def build_dispatch_map(self) -> dict[str, Callable]: + def build_dispatch_map(self) -> dict[str, Callable | tuple]: predefined_handlers = { QUERY_TASK_LABEL: (self.query, TaskPriorityLevel.LEVEL_1, None), ANSWER_TASK_LABEL: (self.answer, TaskPriorityLevel.LEVEL_1, None), From d28c7982804e53e73ee1d74718cbd1c27af0bcec Mon Sep 17 00:00:00 2001 From: fancy Date: Thu, 5 Feb 2026 19:24:28 +0800 Subject: 
[PATCH 12/14] chore(ruff): fix unused unpacked vars - prefix unused unpacked vars with underscore - apply ruff format changes --- examples/mem_agent/deepsearch_example.py | 2 +- src/memos/mem_feedback/feedback.py | 2 +- src/memos/mem_scheduler/base_scheduler.py | 2 -- .../mem_scheduler/memory_manage_modules/post_processor.py | 2 +- .../memories/textual/tree_text_memory/retrieve/searcher.py | 2 +- src/memos/memos_tools/dinding_report_bot.py | 2 +- src/memos/types/openai_chat_completion_types/__init__.py | 2 +- .../chat_completion_assistant_message_param.py | 2 +- .../chat_completion_system_message_param.py | 2 +- .../chat_completion_tool_message_param.py | 2 +- .../chat_completion_user_message_param.py | 2 +- tests/mem_reader/test_coarse_memory_type.py | 4 ++-- 12 files changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py index 6a9405456..6dbe202c2 100644 --- a/examples/mem_agent/deepsearch_example.py +++ b/examples/mem_agent/deepsearch_example.py @@ -179,7 +179,7 @@ def factory_initialization() -> tuple[DeepSearchMemAgent, dict[str, Any]]: def main(): - agent_factory, components_factory = factory_initialization() + agent_factory, _components_factory = factory_initialization() results = agent_factory.run( "Caroline met up with friends, family, and mentors in early July 2023.", user_id="locomo_exp_user_0_speaker_b_ct-1118", diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index e38318a64..b793c49a2 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -1165,7 +1165,7 @@ def process_feedback( info, **kwargs, ) - done, pending = concurrent.futures.wait([answer_future, core_future], timeout=30) + _done, pending = concurrent.futures.wait([answer_future, core_future], timeout=30) for fut in pending: fut.cancel() try: diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index f733615e0..7c26336ed 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -229,8 +229,6 @@ def initialize_modules( self.dispatcher_monitor = SchedulerDispatcherMonitor(config=self.config) self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config) - - # Initialize post-processor for memory enhancement and filtering self.post_processor = MemoryPostProcessor( process_llm=self.process_llm, config=self.config diff --git a/src/memos/mem_scheduler/memory_manage_modules/post_processor.py b/src/memos/mem_scheduler/memory_manage_modules/post_processor.py index 569f07667..28dc22925 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/post_processor.py +++ b/src/memos/mem_scheduler/memory_manage_modules/post_processor.py @@ -21,8 +21,8 @@ filter_vector_based_similar_memories, transform_name_to_key, ) -from memos.memories.textual.item import TextualMemoryItem from memos.mem_scheduler.utils.misc_utils import extract_json_obj +from memos.memories.textual.item import TextualMemoryItem from .memory_filter import MemoryFilter diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index dcd4e1fba..39aa4e9ac 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -88,7 +88,7 @@ def retrieve( logger.info( f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, 
memory_type={memory_type}" ) - parsed_goal, query_embedding, context, query = self._parse_task( + parsed_goal, query_embedding, _context, query = self._parse_task( query, info, mode, diff --git a/src/memos/memos_tools/dinding_report_bot.py b/src/memos/memos_tools/dinding_report_bot.py index d8b762855..5bbd1f4cd 100644 --- a/src/memos/memos_tools/dinding_report_bot.py +++ b/src/memos/memos_tools/dinding_report_bot.py @@ -146,7 +146,7 @@ def _text_wh(draw: ImageDraw.ImageDraw, text: str, font: ImageFont.ImageFont): # center alignment title_w, title_h = _text_wh(draw, title, font_title) - sub_w, sub_h = _text_wh(draw, subtitle, font_sub) + sub_w, _sub_h = _text_wh(draw, subtitle, font_sub) title_x = (w - title_w) // 2 title_y = h // 2 - title_h diff --git a/src/memos/types/openai_chat_completion_types/__init__.py b/src/memos/types/openai_chat_completion_types/__init__.py index 4a08a9f24..025e75360 100644 --- a/src/memos/types/openai_chat_completion_types/__init__.py +++ b/src/memos/types/openai_chat_completion_types/__init__.py @@ -1,4 +1,4 @@ -# ruff: noqa: F403, F401 +# ruff: noqa: F403 from .chat_completion_assistant_message_param import * from .chat_completion_content_part_image_param import * diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py index 3c5638788..f28796c2d 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py @@ -1,4 +1,4 @@ -# ruff: noqa: TC001, TC003 +# ruff: noqa: TC001 from __future__ import annotations diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py index ea2101229..13a9a89af 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py @@ -1,4 +1,4 @@ -# ruff: noqa: TC001, TC003 +# ruff: noqa: TC001 from __future__ import annotations diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py index 99c845d11..f76f2b862 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py @@ -1,4 +1,4 @@ -# ruff: noqa: TC001, TC003 +# ruff: noqa: TC001 from __future__ import annotations diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py index 8c004f340..b5bee9842 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py @@ -1,4 +1,4 @@ -# ruff: noqa: TC001, TC003 +# ruff: noqa: TC001 from __future__ import annotations diff --git a/tests/mem_reader/test_coarse_memory_type.py b/tests/mem_reader/test_coarse_memory_type.py index bd90d6a69..1cbb6b2eb 100644 --- a/tests/mem_reader/test_coarse_memory_type.py +++ b/tests/mem_reader/test_coarse_memory_type.py @@ -64,7 +64,7 @@ def test_chat_passthrough(): def test_doc_local_file(): - local_path, content = 
create_temp_file("test local file content")
+    local_path, _content = create_temp_file("test local file content")
     result = coerce_scene_data([local_path], "doc")
 
     filename = os.path.basename(local_path)
@@ -108,7 +108,7 @@ def test_doc_plain_text():
 
 
 def test_doc_mixed():
-    local_path, content = create_temp_file("local file content")
+    local_path, _content = create_temp_file("local file content")
     url = "https://example.com/x.pdf"
     plain = "hello world"
 
From f97eeae3dd97233071a9478308a189eec4a3a7d5 Mon Sep 17 00:00:00 2001
From: "glin1993@outlook.com" <>
Date: Fri, 6 Feb 2026 14:46:28 +0800
Subject: [PATCH 13/14] perf(search): include embeddings for mmr

---
 src/memos/mem_scheduler/optimized_scheduler.py | 1 +
 src/memos/multi_mem_cube/single_cube.py        | 1 +
 src/memos/search/search_service.py             | 2 ++
 3 files changed, 4 insertions(+)

diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py
index e535d6f73..d6b566dfe 100644
--- a/src/memos/mem_scheduler/optimized_scheduler.py
+++ b/src/memos/mem_scheduler/optimized_scheduler.py
@@ -111,6 +111,7 @@ def search_memories(
             search_req=search_req,
             user_context=user_context,
             mode=mode,
+            include_embedding=(search_req.dedup == "mmr"),
         )
 
     def mix_search_memories(
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index d75f8576e..a547fd296 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -463,6 +463,7 @@ def _fast_search(
             search_req=search_req,
             user_context=user_context,
             mode=SearchMode.FAST,
+            include_embedding=(search_req.dedup == "mmr"),
         )
 
         formatted_memories = [
diff --git a/src/memos/search/search_service.py b/src/memos/search/search_service.py
index 79c9a43e5..6d57e3605 100644
--- a/src/memos/search/search_service.py
+++ b/src/memos/search/search_service.py
@@ -41,6 +41,7 @@ def search_text_memories(
         search_req: APISearchRequest,
         user_context: UserContext,
         mode: SearchMode,
+        include_embedding: bool | None = None,
     ) -> list[Any]:
         """
         Shared text-memory search logic for API and scheduler paths.
@@ -62,4 +63,5 @@
         include_skill_memory=search_req.include_skill_memory,
         skill_mem_top_k=search_req.skill_mem_top_k,
         dedup=search_req.dedup,
+        include_embedding=include_embedding,
     )

From 428a01c380e56c1f386c50c909112f297892e9da Mon Sep 17 00:00:00 2001
From: chentang
Date: Fri, 6 Feb 2026 15:12:32 +0800
Subject: [PATCH 14/14] fix: Pass user_context in mem_read and pref_add handlers

- Update MemReadMessageHandler to extract user_context from message and pass
  it to _process_memories_with_reader and transfer_mem.
- Update PrefAddMessageHandler to extract user_context from message and pass
  it to pref_mem.get_memory.
- This ensures user context information is available during memory reading and
  preference adding operations.
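
For reference, the resulting call shape in PrefAddMessageHandler (a condensed
sketch of the code this patch touches; user_context may be None when the
producing side did not set it on the message):

    user_context = message.user_context
    pref_memories = pref_mem.get_memory(
        messages_list,
        type="chat",
        info={**info, "user_id": user_id, "session_id": session_id, "mem_cube_id": mem_cube_id},
        user_context=user_context,  # newly threaded through from the message
    )
    pref_ids = pref_mem.add(pref_memories)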
--- .../task_schedule_modules/handlers/mem_read_handler.py | 5 +++++ .../task_schedule_modules/handlers/pref_add_handler.py | 2 ++ 2 files changed, 7 insertions(+) diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py index 3bbed09e3..6bbbd4335 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py @@ -24,6 +24,7 @@ if TYPE_CHECKING: from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + from memos.types.general_types import UserContext class MemReadMessageHandler(BaseSchedulerHandler): @@ -64,6 +65,7 @@ def process_message(self, message: ScheduleMessageItem): user_name = message.user_name info = message.info or {} chat_history = message.chat_history + user_context = message.user_context mem_ids = json.loads(content) if isinstance(content, str) else content if not mem_ids: @@ -91,6 +93,7 @@ def process_message(self, message: ScheduleMessageItem): task_id=message.task_id, info=info, chat_history=chat_history, + user_context=user_context, ) logger.info( @@ -113,6 +116,7 @@ def _process_memories_with_reader( task_id: str | None = None, info: dict | None = None, chat_history: list | None = None, + user_context: UserContext | None = None, ) -> None: logger.info( "[DIAGNOSTIC] mem_read_handler._process_memories_with_reader called. mem_ids: %s, user_id: %s, mem_cube_id: %s, task_id: %s", @@ -165,6 +169,7 @@ def _process_memories_with_reader( custom_tags=custom_tags, user_name=user_name, chat_history=chat_history, + user_context=user_context, ) except Exception as e: logger.warning("%s: Fail to transfer mem: %s", e, memory_items) diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/pref_add_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/pref_add_handler.py index 1d03e0476..b7dd2fa4c 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/pref_add_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/pref_add_handler.py @@ -50,6 +50,7 @@ def process_message(self, message: ScheduleMessageItem): mem_cube_id = message.mem_cube_id content = message.content messages_list = json.loads(content) + user_context = message.user_context info = message.info or {} logger.info("Processing pref_add for user_id=%s, mem_cube_id=%s", user_id, mem_cube_id) @@ -78,6 +79,7 @@ def process_message(self, message: ScheduleMessageItem): "session_id": session_id, "mem_cube_id": mem_cube_id, }, + user_context=user_context, ) pref_ids = pref_mem.add(pref_memories)
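
Note on the handler-registration surface this series settles on (PATCH 10 and
PATCH 11): a dispatch-map entry may bind a task label either to a bare
callable or to a (handler, priority, min_idle_ms) tuple, and the dispatcher
mirrors that configuration into the orchestrator via set_task_config. Below is
a minimal sketch of registering a custom task under these conventions; it
assumes `scheduler` is an initialized scheduler exposing register_handlers,
and the CUSTOM_TASK_LABEL constant and handler body are hypothetical,
for illustration only:

    from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
    from memos.mem_scheduler.schemas.task_schemas import TaskPriorityLevel

    CUSTOM_TASK_LABEL = "custom_task"  # hypothetical label, not part of this series

    def handle_custom(messages: list[ScheduleMessageItem]) -> None:
        for msg in messages:
            ...  # process one message carrying this label

    # Bare callable: the orchestrator falls back to TaskPriorityLevel.LEVEL_3
    # and the default pending-claim idle window.
    scheduler.register_handlers({CUSTOM_TASK_LABEL: handle_custom})

    # Tuple form: explicit priority plus a 10-minute min-idle claim window,
    # mirroring the PREF_ADD_TASK_LABEL entry in the registry.
    scheduler.register_handlers(
        {CUSTOM_TASK_LABEL: (handle_custom, TaskPriorityLevel.LEVEL_1, 600_000)}
    )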