From c3472c41ec0c6f91bd65f5af04c9ef53d4c7503d Mon Sep 17 00:00:00 2001
From: chentang
Date: Thu, 25 Dec 2025 15:45:16 +0800
Subject: [PATCH 1/2] fix: fix bugs when running local queue for memos

---
 .../task_schedule_modules/dispatcher.py       |   2 -
 .../task_schedule_modules/local_queue.py      | 135 +++++++++++++++---
 .../task_schedule_modules/redis_queue.py      |  95 +++++++-----
 .../task_schedule_modules/task_queue.py       |  25 +--
 .../mem_scheduler/utils/status_tracker.py     |  17 ++-
 5 files changed, 196 insertions(+), 78 deletions(-)

diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
index 35df3db64..cdd491183 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
@@ -129,8 +129,6 @@ def status_tracker(self) -> TaskStatusTracker | None:
         try:
             self._status_tracker = TaskStatusTracker(self.redis)
             # Propagate to submodules when created lazily
-            if self.dispatcher:
-                self.dispatcher.status_tracker = self._status_tracker
             if self.memos_message_queue:
                 self.memos_message_queue.set_status_tracker(self._status_tracker)
         except Exception as e:
diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
index 69cfc0af9..72dccfd98 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
@@ -4,9 +4,18 @@
 the local memos_message_queue functionality in BaseScheduler.
 """
 
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
 from memos.log import get_logger
 from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
 from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import DEFAULT_STREAM_KEY_PREFIX
+from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
+from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
 from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
 
@@ -16,26 +25,38 @@ class SchedulerLocalQueue(RedisSchedulerModule):
 
     def __init__(
         self,
-        maxsize: int,
+        maxsize: int = 0,
+        stream_key_prefix: str = DEFAULT_STREAM_KEY_PREFIX,
+        orchestrator: SchedulerOrchestrator | None = None,
+        status_tracker: TaskStatusTracker | None = None,
     ):
         """
         Initialize the SchedulerLocalQueue with a maximum queue size limit.
+        Arguments match SchedulerRedisQueue so the two queues are interchangeable.
 
         Args:
-            maxsize (int): Maximum number of messages allowed
-            in each individual queue.
-            If exceeded, subsequent puts will block
-            or raise an exception based on `block` parameter.
+            maxsize (int): Maximum number of messages allowed in each individual queue.
+            stream_key_prefix (str): Prefix for the simulated stream keys.
+            orchestrator: SchedulerOrchestrator instance (stored for compatibility; unused locally).
+            status_tracker: TaskStatusTracker instance (stored for compatibility; unused locally).
""" super().__init__() - self.stream_key_prefix = "local_queue" + self.stream_key_prefix = stream_key_prefix or "local_queue" self.max_internal_message_queue_size = maxsize + # Dictionary to hold per-stream queues: key = stream_key, value = Queue[ScheduleMessageItem] self.queue_streams: dict[str, Queue[ScheduleMessageItem]] = {} + + self.orchestrator = orchestrator + self.status_tracker = status_tracker + + self._is_listening = False + self._message_handler: Callable[[ScheduleMessageItem], None] | None = None + logger.info( - f"SchedulerLocalQueue initialized with max_internal_message_queue_size={maxsize}" + f"SchedulerLocalQueue initialized with max_internal_message_queue_size={self.max_internal_message_queue_size}" ) def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str: @@ -62,7 +83,7 @@ def put( Exception: Any underlying error during queue.put() operation. """ stream_key = self.get_stream_key( - user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.task_label + user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.label ) message.stream_key = stream_key @@ -86,7 +107,7 @@ def get( stream_key: str, block: bool = True, timeout: float | None = None, - batch_size: int | None = None, + batch_size: int | None = 1, ) -> list[ScheduleMessageItem]: if batch_size is not None and batch_size <= 0: logger.warning( @@ -99,47 +120,85 @@ def get( logger.error(f"Stream {stream_key} does not exist when trying to get messages.") return [] + # Ensure we always request a batch so we get a list back + effective_batch_size = batch_size if batch_size is not None else 1 + # Note: Assumes custom Queue implementation supports batch_size parameter res = self.queue_streams[stream_key].get( - block=block, timeout=timeout, batch_size=batch_size + block=block, timeout=timeout, batch_size=effective_batch_size ) logger.debug( f"Retrieved {len(res)} messages from queue '{stream_key}'. Current size: {self.queue_streams[stream_key].qsize()}" ) return res - def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem]: + def get_nowait(self, stream_key: str, batch_size: int | None = 1) -> list[ScheduleMessageItem]: """ - Non-blocking version of get(). Equivalent to get(block=False, batch_size=batch_size). + Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size). Returns immediately with available messages or an empty list if queue is empty. Args: + stream_key (str): The stream/queue identifier. batch_size (int | None): Number of messages to retrieve in a batch. If None, retrieves one message. Returns: List[ScheduleMessageItem]: Retrieved messages or empty list if queue is empty. """ - logger.debug(f"get_nowait() called with batch_size: {batch_size}") - return self.get(block=False, batch_size=batch_size) + logger.debug(f"get_nowait() called for {stream_key} with batch_size: {batch_size}") + return self.get(stream_key=stream_key, block=False, batch_size=batch_size) + + def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: + """ + Get messages from all streams in round-robin or sequential fashion. + Equivalent to SchedulerRedisQueue.get_messages. + """ + messages = [] + # Snapshot keys to avoid runtime modification issues + stream_keys = list(self.queue_streams.keys()) + + # Simple strategy: try to get up to batch_size messages across all streams + # We can just iterate and collect. + + # Calculate how many to get per stream to be fair? + # Or just greedy? 
Redis implementation uses a complex logic. + # For local, let's keep it simple: just iterate and take what's available (non-blocking) + + for stream_key in stream_keys: + if len(messages) >= batch_size: + break + + needed = batch_size - len(messages) + # Use get_nowait to avoid blocking + fetched = self.get_nowait(stream_key=stream_key, batch_size=needed) + messages.extend(fetched) + + return messages def qsize(self) -> dict: """ Return the current size of all internal queues as a dictionary. Each key is the stream name, and each value is the number of messages in that queue. + Also includes 'total_size'. Returns: Dict[str, int]: Mapping from stream name to current queue size. """ sizes = {stream: queue.qsize() for stream, queue in self.queue_streams.items()} + total_size = sum(sizes.values()) + sizes["total_size"] = total_size logger.debug(f"Current queue sizes: {sizes}") return sizes - def clear(self) -> None: - for queue in self.queue_streams.values(): - queue.clear() + def clear(self, stream_key: str | None = None) -> None: + if stream_key: + if stream_key in self.queue_streams: + self.queue_streams[stream_key].clear() + else: + for queue in self.queue_streams.values(): + queue.clear() @property def unfinished_tasks(self) -> int: @@ -151,6 +210,50 @@ def unfinished_tasks(self) -> int: Returns: int: Sum of all message counts in all internal queues. """ - total = sum(self.qsize().values()) + # qsize() now includes "total_size", so we need to be careful not to double count if we use qsize() values + # But qsize() implementation above sums values from queue_streams, then adds total_size. + # So sum(self.queue_streams.values().qsize()) is safer. + total = sum(queue.qsize() for queue in self.queue_streams.values()) logger.debug(f"Total unfinished tasks across all queues: {total}") return total + + def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]: + """ + Return list of active stream keys. + """ + prefix = stream_key_prefix or self.stream_key_prefix + return [k for k in self.queue_streams if k.startswith(prefix)] + + def size(self) -> int: + """ + Total size of all queues. + """ + return sum(q.qsize() for q in self.queue_streams.values()) + + def empty(self) -> bool: + """ + Check if all queues are empty. + """ + return self.size() == 0 + + def full(self) -> bool: + """ + Check if any queue is full (approximate). + """ + if self.max_internal_message_queue_size <= 0: + return False + return any( + q.qsize() >= self.max_internal_message_queue_size for q in self.queue_streams.values() + ) + + def ack_message( + self, + user_id: str, + mem_cube_id: str, + task_label: str, + redis_message_id, + message: ScheduleMessageItem | None, + ) -> None: + """ + Acknowledge a message (no-op for local queue as messages are popped immediately). + """ diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 1c57f18f0..c27eb5f50 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -5,7 +5,6 @@ the local memos_message_queue functionality in BaseScheduler. 
""" -import contextlib import os import re import threading @@ -201,6 +200,20 @@ def _refresh_stream_keys( recent_seconds=DEFAULT_STREAM_RECENT_ACTIVE_SECONDS, now_sec=now_sec, ) + + # Ensure consumer groups for newly discovered active streams + with self._stream_keys_lock: + # Identify keys we haven't seen yet + new_streams = [k for k in active_stream_keys if k not in self.seen_streams] + + # Create groups outside the lock to avoid blocking + for key in new_streams: + self._ensure_consumer_group(key) + + if new_streams: + with self._stream_keys_lock: + self.seen_streams.update(new_streams) + deleted_count = self._delete_streams(keys_to_delete) self._update_stream_cache_with_log( stream_key_prefix=stream_key_prefix, @@ -560,10 +573,7 @@ def _read_new_messages_batch( return {} # Pre-ensure consumer groups to avoid NOGROUP during batch reads - for stream_key in stream_keys: - with contextlib.suppress(Exception): - self._ensure_consumer_group(stream_key=stream_key) - + # (Optimization: rely on put() and _refresh_stream_keys() to ensure groups) pipe = self._redis_conn.pipeline(transaction=False) for stream_key in stream_keys: pipe.xreadgroup( @@ -676,13 +686,6 @@ def _batch_claim_pending_messages( Returns: A list of (stream_key, claimed_entries) pairs for all successful claims. """ - if not self._redis_conn or not claims_spec: - return [] - - # Ensure consumer groups exist to avoid NOGROUP errors during batch claim - for stream_key, _need_count, _label in claims_spec: - with contextlib.suppress(Exception): - self._ensure_consumer_group(stream_key=stream_key) pipe = self._redis_conn.pipeline(transaction=False) for stream_key, need_count, label in claims_spec: @@ -696,26 +699,42 @@ def _batch_claim_pending_messages( justid=False, ) - results = [] try: - results = pipe.execute() - except Exception: - # Fallback: attempt sequential xautoclaim for robustness - for stream_key, need_count, label in claims_spec: - try: - self._ensure_consumer_group(stream_key=stream_key) - res = self._redis_conn.xautoclaim( - name=stream_key, - groupname=self.consumer_group, - consumername=self.consumer_name, - min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), - start_id="0-0", - count=need_count, - justid=False, - ) - results.append(res) - except Exception: - continue + # Execute with raise_on_error=False so we get exceptions in the results list + # instead of aborting the whole batch. + results = pipe.execute(raise_on_error=False) + except Exception as e: + logger.error(f"Pipeline execution critical failure: {e}") + results = [e] * len(claims_spec) + + # Handle individual failures (e.g. 
NOGROUP) by retrying just that stream + final_results = [] + for i, res in enumerate(results): + if isinstance(res, Exception): + err_msg = str(res).lower() + if "nogroup" in err_msg or "no such key" in err_msg: + stream_key, need_count, label = claims_spec[i] + try: + self._ensure_consumer_group(stream_key=stream_key) + retry_res = self._redis_conn.xautoclaim( + name=stream_key, + groupname=self.consumer_group, + consumername=self.consumer_name, + min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), + start_id="0-0", + count=need_count, + justid=False, + ) + final_results.append(retry_res) + except Exception as retry_err: + logger.warning(f"Retry xautoclaim failed for {stream_key}: {retry_err}") + final_results.append(None) + else: + final_results.append(None) + else: + final_results.append(res) + + results = final_results claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = [] for (stream_key, _need_count, _label), claimed_result in zip( @@ -1159,10 +1178,14 @@ def _delete_streams(self, keys_to_delete: list[str]) -> int: del_pipe.delete(key) del_pipe.execute() deleted_count = len(keys_to_delete) - # Clean up empty-tracking state for deleted keys + # Clean up empty-tracking state and seen_streams for deleted keys with self._empty_stream_seen_lock: for key in keys_to_delete: self._empty_stream_seen_times.pop(key, None) + + with self._stream_keys_lock: + for key in keys_to_delete: + self.seen_streams.discard(key) except Exception: for key in keys_to_delete: try: @@ -1170,6 +1193,8 @@ def _delete_streams(self, keys_to_delete: list[str]) -> int: deleted_count += 1 with self._empty_stream_seen_lock: self._empty_stream_seen_times.pop(key, None) + with self._stream_keys_lock: + self.seen_streams.discard(key) except Exception: pass return deleted_count @@ -1190,8 +1215,6 @@ def _update_stream_cache_with_log( self._stream_keys_last_refresh = time.time() cache_count = len(self._stream_keys_cache) logger.info( - f"[REDIS_QUEUE] Stream keys refresh: prefix='{stream_key_prefix}', " - f"total={len(candidate_keys)}, active={len(active_stream_keys)}, cached={cache_count}, " - f"active_threshold_sec={int(active_threshold_sec)}, deleted={deleted_count}, " - f"inactive_threshold_sec={int(DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS)}" + f"Refreshed stream keys cache: {cache_count} active keys, " + f"{deleted_count} deleted, {len(candidate_keys)} candidates examined." 
        )
diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py
index c20243242..b49db2b36 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py
@@ -153,28 +153,9 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt
         )
 
     def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
-        if isinstance(self.memos_message_queue, SchedulerRedisQueue):
-            return self.memos_message_queue.get_messages(batch_size=batch_size)
-        stream_keys = self.get_stream_keys()
-
-        if len(stream_keys) == 0:
-            return []
-
-        messages: list[ScheduleMessageItem] = []
-
-        for stream_key in stream_keys:
-            fetched = self.memos_message_queue.get(
-                stream_key=stream_key,
-                block=False,
-                batch_size=batch_size,
-            )
-
-            messages.extend(fetched)
-        if len(messages) > 0:
-            logger.debug(
-                f"Fetched {len(messages)} messages across users with per-user batch_size={batch_size}"
-            )
-        return messages
+        # Both queue implementations now expose the same get_messages() API,
+        # so the isinstance branching is no longer needed.
+        return self.memos_message_queue.get_messages(batch_size=batch_size)
 
     def clear(self):
         self.memos_message_queue.clear()
diff --git a/src/memos/mem_scheduler/utils/status_tracker.py b/src/memos/mem_scheduler/utils/status_tracker.py
index d8c8d2cee..c42ef0d0f 100644
--- a/src/memos/mem_scheduler/utils/status_tracker.py
+++ b/src/memos/mem_scheduler/utils/status_tracker.py
@@ -13,10 +13,10 @@ class TaskStatusTracker:
 
     @require_python_package(import_name="redis", install_command="pip install redis")
-    def __init__(self, redis_client: "redis.Redis"):
+    def __init__(self, redis_client: "redis.Redis | None"):
         self.redis = redis_client
 
     def _get_key(self, user_id: str) -> str:
         return f"memos:task_meta:{user_id}"
 
     def _get_task_items_key(self, user_id: str, task_id: str) -> str:
@@ -61,6 +61,9 @@ def task_submitted(
         self.redis.expire(key, timedelta(days=7))
 
     def task_started(self, task_id: str, user_id: str):
+        if not self.redis:
+            return
+
         key = self._get_key(user_id)
         existing_data_json = self.redis.hget(key, task_id)
         if not existing_data_json:
@@ -77,6 +80,9 @@ def task_started(self, task_id: str, user_id: str):
         self.redis.expire(key, timedelta(days=7))
 
     def task_completed(self, task_id: str, user_id: str):
+        if not self.redis:
+            return
+
         key = self._get_key(user_id)
         existing_data_json = self.redis.hget(key, task_id)
         if not existing_data_json:
@@ -108,11 +114,17 @@ def task_failed(self, task_id: str, user_id: str, error_message: str):
         self.redis.expire(key, timedelta(days=7))
 
     def get_task_status(self, task_id: str, user_id: str) -> dict | None:
+        if not self.redis:
+            return None
+
         key = self._get_key(user_id)
         data = self.redis.hget(key, task_id)
         return json.loads(data) if data else None
 
     def get_all_tasks_for_user(self, user_id: str) -> dict[str, dict]:
+        if not self.redis:
+            return {}
+
         key = self._get_key(user_id)
         all_tasks = self.redis.hgetall(key)
         return {tid: json.loads(t_data) for tid, t_data in all_tasks.items()}
@@ -180,6 +192,9 @@ def get_all_tasks_global(self) -> dict[str, dict[str, dict]]:
         Returns:
             dict: {user_id: {task_id: task_data, ...}, ...}
         """
+        if not self.redis:
+            return {}
+
         all_users_tasks = {}
         cursor: int | str = 0
         while True:

From 4ef6fec2ea98bf61c90ec1993c17c851a73aec65 Mon Sep 17 00:00:00 2001
From: chentang
Date: Thu, 25 Dec 2025 16:30:02 +0800
Subject: [PATCH 2/2] fix: remove unnecessary ack_message from SchedulerLocalQueue

---
 .../task_schedule_modules/local_queue.py | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
index 72dccfd98..791cedf41 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
@@ -241,15 +241,3 @@ def full(self) -> bool:
         return any(
             q.qsize() >= self.max_internal_message_queue_size for q in self.queue_streams.values()
         )
-
-    def ack_message(
-        self,
-        user_id: str,
-        mem_cube_id: str,
-        task_label: str,
-        redis_message_id,
-        message: ScheduleMessageItem | None,
-    ) -> None:
-        """
-        Acknowledge a message (no-op for the local queue, since messages are popped immediately).
-        """
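
A quick way to exercise the interface parity the first patch aims for is a short script against the local queue alone. This is a sketch, not part of the patch: the ScheduleMessageItem constructor fields below are assumed from the attribute accesses in the diff (user_id, mem_cube_id, label) and may not match the real schema.

# Sketch: exercise SchedulerLocalQueue's Redis-compatible surface.
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue

queue = SchedulerLocalQueue(maxsize=0)  # defaults now mirror SchedulerRedisQueue

# user_id/mem_cube_id/label are the attributes put() reads; construction is assumed.
msg = ScheduleMessageItem(user_id="u1", mem_cube_id="cube1", label="test_label")
queue.put(msg)

print(queue.qsize())  # per-stream sizes plus the aggregate "total_size" entry
assert queue.unfinished_tasks == 1  # aggregate key must not be double counted

batch = queue.get_messages(batch_size=10)  # greedy, non-blocking sweep over streams
assert len(batch) == 1 and queue.empty()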
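
The batch-claim rewrite leans on a real redis-py behavior: pipeline.execute(raise_on_error=False) returns exception objects in place rather than raising on the first failed command. A minimal standalone illustration, assuming a reachable local Redis on the default port:

# Sketch: per-command error handling with raise_on_error=False (redis-py).
import redis

r = redis.Redis()
pipe = r.pipeline(transaction=False)
pipe.xautoclaim("missing-stream", "grp", "consumer-1", min_idle_time=0)  # NOGROUP expected
pipe.set("k", "v")  # unrelated command in the same batch still runs

for i, res in enumerate(pipe.execute(raise_on_error=False)):
    if isinstance(res, Exception):
        # Mirror the patch: inspect the error text and retry only this command
        print(f"command {i} failed: {res}")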
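
With the guards added in status_tracker.py, a tracker constructed without a Redis client degrades to a no-op instead of raising. A sketch of the intended behavior (task_submitted and task_failed are not guarded in this diff, so they are left out here):

# Sketch: TaskStatusTracker degrades gracefully when no Redis client is given.
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker

tracker = TaskStatusTracker(redis_client=None)
tracker.task_started("task-1", "user-1")    # returns early, no AttributeError
tracker.task_completed("task-1", "user-1")  # likewise a no-op

assert tracker.get_task_status("task-1", "user-1") is None
assert tracker.get_all_tasks_for_user("user-1") == {}
assert tracker.get_all_tasks_global() == {}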