2 changes: 0 additions & 2 deletions src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
@@ -129,8 +129,6 @@ def status_tracker(self) -> TaskStatusTracker | None:
try:
self._status_tracker = TaskStatusTracker(self.redis)
# Propagate to submodules when created lazily
if self.dispatcher:
self.dispatcher.status_tracker = self._status_tracker
if self.memos_message_queue:
self.memos_message_queue.set_status_tracker(self._status_tracker)
except Exception as e:
127 changes: 109 additions & 18 deletions src/memos/mem_scheduler/task_schedule_modules/local_queue.py
@@ -4,9 +4,18 @@
the local memos_message_queue functionality in BaseScheduler.
"""

from typing import TYPE_CHECKING


if TYPE_CHECKING:
from collections.abc import Callable

from memos.log import get_logger
from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.schemas.task_schemas import DEFAULT_STREAM_KEY_PREFIX
from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule


@@ -16,26 +25,38 @@
class SchedulerLocalQueue(RedisSchedulerModule):
def __init__(
self,
maxsize: int,
maxsize: int = 0,
stream_key_prefix: str = DEFAULT_STREAM_KEY_PREFIX,
orchestrator: SchedulerOrchestrator | None = None,
status_tracker: TaskStatusTracker | None = None,
):
"""
Initialize the SchedulerLocalQueue with a maximum queue size limit.
Arguments match SchedulerRedisQueue for compatibility.

Args:
maxsize (int): Maximum number of messages allowed
in each individual queue.
If exceeded, subsequent puts will block
or raise an exception based on `block` parameter.
maxsize (int): Maximum number of messages allowed in each individual queue.
stream_key_prefix (str): Prefix for stream keys (simulated).
orchestrator: SchedulerOrchestrator instance (ignored).
status_tracker: TaskStatusTracker instance (ignored).
"""
super().__init__()

self.stream_key_prefix = "local_queue"
self.stream_key_prefix = stream_key_prefix or "local_queue"

self.max_internal_message_queue_size = maxsize

# Dictionary to hold per-stream queues: key = stream_key, value = Queue[ScheduleMessageItem]
self.queue_streams: dict[str, Queue[ScheduleMessageItem]] = {}

self.orchestrator = orchestrator
self.status_tracker = status_tracker

self._is_listening = False
self._message_handler: Callable[[ScheduleMessageItem], None] | None = None

logger.info(
f"SchedulerLocalQueue initialized with max_internal_message_queue_size={maxsize}"
f"SchedulerLocalQueue initialized with max_internal_message_queue_size={self.max_internal_message_queue_size}"
)

def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str:
@@ -62,7 +83,7 @@ def put(
Exception: Any underlying error during queue.put() operation.
"""
stream_key = self.get_stream_key(
user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.task_label
user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.label
)

message.stream_key = stream_key
@@ -86,7 +107,7 @@ def get(
stream_key: str,
block: bool = True,
timeout: float | None = None,
batch_size: int | None = None,
batch_size: int | None = 1,
) -> list[ScheduleMessageItem]:
if batch_size is not None and batch_size <= 0:
logger.warning(
@@ -99,47 +120,85 @@ def get(
logger.error(f"Stream {stream_key} does not exist when trying to get messages.")
return []

# Ensure we always request a batch so we get a list back
effective_batch_size = batch_size if batch_size is not None else 1

# Note: Assumes custom Queue implementation supports batch_size parameter
res = self.queue_streams[stream_key].get(
block=block, timeout=timeout, batch_size=batch_size
block=block, timeout=timeout, batch_size=effective_batch_size
)
logger.debug(
f"Retrieved {len(res)} messages from queue '{stream_key}'. Current size: {self.queue_streams[stream_key].qsize()}"
)
return res

def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem]:
def get_nowait(self, stream_key: str, batch_size: int | None = 1) -> list[ScheduleMessageItem]:
"""
Non-blocking version of get(). Equivalent to get(block=False, batch_size=batch_size).
Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size).

Returns immediately with any available messages, or an empty list if the queue is empty.

Args:
stream_key (str): The stream/queue identifier.
batch_size (int | None): Number of messages to retrieve in a batch.
If None, retrieves one message.

Returns:
List[ScheduleMessageItem]: Retrieved messages, or an empty list if the queue is empty.
"""
logger.debug(f"get_nowait() called with batch_size: {batch_size}")
return self.get(block=False, batch_size=batch_size)
logger.debug(f"get_nowait() called for {stream_key} with batch_size: {batch_size}")
return self.get(stream_key=stream_key, block=False, batch_size=batch_size)

def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
"""
Get messages from all streams in round-robin or sequential fashion.
Equivalent to SchedulerRedisQueue.get_messages.
"""
messages = []
# Snapshot keys to avoid runtime modification issues
stream_keys = list(self.queue_streams.keys())

# Simple strategy: iterate over the streams and greedily collect up to batch_size messages.
# The Redis implementation allocates reads per stream more carefully; for the local queue
# we keep it simple and take whatever is available without blocking.

for stream_key in stream_keys:
if len(messages) >= batch_size:
break

needed = batch_size - len(messages)
# Use get_nowait to avoid blocking
fetched = self.get_nowait(stream_key=stream_key, batch_size=needed)
messages.extend(fetched)

return messages

def qsize(self) -> dict:
"""
Return the current size of all internal queues as a dictionary.

Each key is the stream name, and each value is the number of messages in that queue.
Also includes 'total_size'.

Returns:
Dict[str, int]: Mapping from stream name to current queue size.
"""
sizes = {stream: queue.qsize() for stream, queue in self.queue_streams.items()}
total_size = sum(sizes.values())
sizes["total_size"] = total_size
logger.debug(f"Current queue sizes: {sizes}")
return sizes

def clear(self) -> None:
for queue in self.queue_streams.values():
queue.clear()
def clear(self, stream_key: str | None = None) -> None:
if stream_key:
if stream_key in self.queue_streams:
self.queue_streams[stream_key].clear()
else:
for queue in self.queue_streams.values():
queue.clear()

@property
def unfinished_tasks(self) -> int:
@@ -151,6 +210,38 @@ def unfinished_tasks(self) -> int:
Returns:
int: Sum of all message counts in all internal queues.
"""
total = sum(self.qsize().values())
# qsize() now includes a "total_size" entry, so summing its values would double count.
# Sum directly over the per-stream queues instead.
total = sum(queue.qsize() for queue in self.queue_streams.values())
logger.debug(f"Total unfinished tasks across all queues: {total}")
return total

def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]:
"""
Return list of active stream keys.
"""
prefix = stream_key_prefix or self.stream_key_prefix
return [k for k in self.queue_streams if k.startswith(prefix)]

def size(self) -> int:
"""
Total size of all queues.
"""
return sum(q.qsize() for q in self.queue_streams.values())

def empty(self) -> bool:
"""
Check if all queues are empty.
"""
return self.size() == 0

def full(self) -> bool:
"""
Check if any queue is full (approximate).
"""
if self.max_internal_message_queue_size <= 0:
return False
return any(
q.qsize() >= self.max_internal_message_queue_size for q in self.queue_streams.values()
)
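
Reviewer note on the new local-queue API: `get_messages` drains the per-stream queues sequentially until `batch_size` messages have been collected. Below is a minimal, stdlib-only sketch of that per-stream-queue pattern, not the project's actual classes; the `LocalStreams` name, the drop-oldest-when-full behaviour, and the demo keys are assumptions for illustration.

```python
from collections import deque


class LocalStreams:
    """Stdlib sketch of the per-stream queue pattern used by SchedulerLocalQueue."""

    def __init__(self, maxsize: int = 0):
        # 0 means unbounded, mirroring queue.Queue semantics.
        self.maxsize = maxsize
        self.streams: dict[str, deque] = {}

    def put(self, stream_key: str, item) -> None:
        q = self.streams.setdefault(stream_key, deque())
        if self.maxsize > 0 and len(q) >= self.maxsize:
            q.popleft()  # assumed auto-dropping behaviour: discard the oldest entry
        q.append(item)

    def get_nowait(self, stream_key: str, batch_size: int = 1) -> list:
        q = self.streams.get(stream_key)
        if not q:
            return []
        batch = []
        while q and len(batch) < batch_size:
            batch.append(q.popleft())
        return batch

    def get_messages(self, batch_size: int) -> list:
        # Greedy sequential drain across all streams, as in the patch above.
        messages = []
        for stream_key in list(self.streams):
            if len(messages) >= batch_size:
                break
            messages.extend(self.get_nowait(stream_key, batch_size - len(messages)))
        return messages


if __name__ == "__main__":
    q = LocalStreams(maxsize=2)
    q.put("stream:a", "a1")
    q.put("stream:a", "a2")
    q.put("stream:b", "b1")
    print(q.get_messages(batch_size=3))  # ['a1', 'a2', 'b1']
```

Run directly, it prints `['a1', 'a2', 'b1']`: the two messages in `stream:a` are taken first, then the drain moves on to `stream:b`.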
95 changes: 59 additions & 36 deletions src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
@@ -5,7 +5,6 @@
the local memos_message_queue functionality in BaseScheduler.
"""

import contextlib
import os
import re
import threading
@@ -201,6 +200,20 @@ def _refresh_stream_keys(
recent_seconds=DEFAULT_STREAM_RECENT_ACTIVE_SECONDS,
now_sec=now_sec,
)

# Ensure consumer groups for newly discovered active streams
with self._stream_keys_lock:
# Identify keys we haven't seen yet
new_streams = [k for k in active_stream_keys if k not in self.seen_streams]

# Create groups outside the lock to avoid blocking
for key in new_streams:
self._ensure_consumer_group(key)

if new_streams:
with self._stream_keys_lock:
self.seen_streams.update(new_streams)

deleted_count = self._delete_streams(keys_to_delete)
self._update_stream_cache_with_log(
stream_key_prefix=stream_key_prefix,
@@ -560,10 +573,7 @@ def _read_new_messages_batch(
return {}

# Pre-ensure consumer groups to avoid NOGROUP during batch reads
for stream_key in stream_keys:
with contextlib.suppress(Exception):
self._ensure_consumer_group(stream_key=stream_key)

# (Optimization: rely on put() and _refresh_stream_keys() to ensure groups)
pipe = self._redis_conn.pipeline(transaction=False)
for stream_key in stream_keys:
pipe.xreadgroup(
@@ -676,13 +686,6 @@ def _batch_claim_pending_messages(
Returns:
A list of (stream_key, claimed_entries) pairs for all successful claims.
"""
if not self._redis_conn or not claims_spec:
return []

# Ensure consumer groups exist to avoid NOGROUP errors during batch claim
for stream_key, _need_count, _label in claims_spec:
with contextlib.suppress(Exception):
self._ensure_consumer_group(stream_key=stream_key)

pipe = self._redis_conn.pipeline(transaction=False)
for stream_key, need_count, label in claims_spec:
@@ -696,26 +699,42 @@
justid=False,
)

results = []
try:
results = pipe.execute()
except Exception:
# Fallback: attempt sequential xautoclaim for robustness
for stream_key, need_count, label in claims_spec:
try:
self._ensure_consumer_group(stream_key=stream_key)
res = self._redis_conn.xautoclaim(
name=stream_key,
groupname=self.consumer_group,
consumername=self.consumer_name,
min_idle_time=self.orchestrator.get_task_idle_min(task_label=label),
start_id="0-0",
count=need_count,
justid=False,
)
results.append(res)
except Exception:
continue
# Execute with raise_on_error=False so we get exceptions in the results list
# instead of aborting the whole batch.
results = pipe.execute(raise_on_error=False)
except Exception as e:
logger.error(f"Pipeline execution critical failure: {e}")
results = [e] * len(claims_spec)

# Handle individual failures (e.g. NOGROUP) by retrying just that stream
final_results = []
for i, res in enumerate(results):
if isinstance(res, Exception):
err_msg = str(res).lower()
if "nogroup" in err_msg or "no such key" in err_msg:
stream_key, need_count, label = claims_spec[i]
try:
self._ensure_consumer_group(stream_key=stream_key)
retry_res = self._redis_conn.xautoclaim(
name=stream_key,
groupname=self.consumer_group,
consumername=self.consumer_name,
min_idle_time=self.orchestrator.get_task_idle_min(task_label=label),
start_id="0-0",
count=need_count,
justid=False,
)
final_results.append(retry_res)
except Exception as retry_err:
logger.warning(f"Retry xautoclaim failed for {stream_key}: {retry_err}")
final_results.append(None)
else:
final_results.append(None)
else:
final_results.append(res)

results = final_results

claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = []
for (stream_key, _need_count, _label), claimed_result in zip(
@@ -1159,17 +1178,23 @@ def _delete_streams(self, keys_to_delete: list[str]) -> int:
del_pipe.delete(key)
del_pipe.execute()
deleted_count = len(keys_to_delete)
# Clean up empty-tracking state for deleted keys
# Clean up empty-tracking state and seen_streams for deleted keys
with self._empty_stream_seen_lock:
for key in keys_to_delete:
self._empty_stream_seen_times.pop(key, None)

with self._stream_keys_lock:
for key in keys_to_delete:
self.seen_streams.discard(key)
except Exception:
for key in keys_to_delete:
try:
self._redis_conn.delete(key)
deleted_count += 1
with self._empty_stream_seen_lock:
self._empty_stream_seen_times.pop(key, None)
with self._stream_keys_lock:
self.seen_streams.discard(key)
except Exception:
pass
return deleted_count
@@ -1190,8 +1215,6 @@ def _update_stream_cache_with_log(
self._stream_keys_last_refresh = time.time()
cache_count = len(self._stream_keys_cache)
logger.info(
f"[REDIS_QUEUE] Stream keys refresh: prefix='{stream_key_prefix}', "
f"total={len(candidate_keys)}, active={len(active_stream_keys)}, cached={cache_count}, "
f"active_threshold_sec={int(active_threshold_sec)}, deleted={deleted_count}, "
f"inactive_threshold_sec={int(DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS)}"
f"Refreshed stream keys cache: {cache_count} active keys, "
f"{deleted_count} deleted, {len(candidate_keys)} candidates examined."
)
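
Reviewer note on the pipeline change in `_batch_claim_pending_messages`: `Pipeline.execute(raise_on_error=False)` returns each command's exception in the results list instead of aborting the whole batch, so NOGROUP failures can be retried per stream after creating the consumer group. A hedged redis-py sketch of that pattern follows; it assumes a Redis server at localhost:6379, and the stream keys, group/consumer names, and idle threshold are made up for illustration.

```python
import redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)

# Hypothetical names for illustration only.
stream_keys = ["scheduler:stream:demo1", "scheduler:stream:demo2"]
group, consumer = "demo-group", "demo-consumer"
IDLE_MS = 60_000  # assumed minimum idle time (ms) before a pending entry may be claimed


def ensure_group(key: str) -> None:
    # XGROUP CREATE ... MKSTREAM; ignore BUSYGROUP if the group already exists.
    try:
        r.xgroup_create(name=key, groupname=group, id="0", mkstream=True)
    except redis.ResponseError as e:
        if "BUSYGROUP" not in str(e):
            raise


pipe = r.pipeline(transaction=False)
for key in stream_keys:
    pipe.xautoclaim(
        name=key,
        groupname=group,
        consumername=consumer,
        min_idle_time=IDLE_MS,
        start_id="0-0",
        count=10,
        justid=False,
    )

# raise_on_error=False leaves each command's exception in the results list
# rather than raising and discarding the rest of the batch.
results = pipe.execute(raise_on_error=False)

claimed = []
for key, res in zip(stream_keys, results):
    if isinstance(res, Exception):
        if "nogroup" not in str(res).lower():
            continue
        ensure_group(key)
        res = r.xautoclaim(key, group, consumer, IDLE_MS, start_id="0-0", count=10)
    # XAUTOCLAIM replies with (next_cursor, claimed_entries) and, on Redis 7+,
    # a third element listing deleted IDs; the entries are what we care about.
    claimed.append((key, res[1]))

print(claimed)
```

The trade-off mirrors the patch: consumer-group creation is kept off the hot path and performed only when a claim actually fails with NOGROUP.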