2 changes: 1 addition & 1 deletion examples/mem_agent/deepsearch_example.py
@@ -179,7 +179,7 @@ def factory_initialization() -> tuple[DeepSearchMemAgent, dict[str, Any]]:


def main():
-    agent_factory, components_factory = factory_initialization()
+    agent_factory, _components_factory = factory_initialization()
results = agent_factory.run(
"Caroline met up with friends, family, and mentors in early July 2023.",
user_id="locomo_exp_user_0_speaker_b_ct-1118",
27 changes: 8 additions & 19 deletions examples/mem_scheduler/memos_w_scheduler.py
@@ -84,7 +84,9 @@ def init_task():
return conversations, questions


-working_memories = []
+default_mem_update_handler = mem_scheduler.handlers.get(MEM_UPDATE_TASK_LABEL)
+if default_mem_update_handler is None:
+    logger.warning("Default MEM_UPDATE handler not found; custom handler will be a no-op.")


# Define custom query handler function
@@ -100,24 +102,11 @@ def custom_query_handler(messages: list[ScheduleMessageItem]):

# Define custom memory update handler function
def custom_mem_update_handler(messages: list[ScheduleMessageItem]):
-    global working_memories
-    search_args = {}
-    top_k = 2
-    for msg in messages:
-        # Search for memories relevant to the current content in text memory (return top_k=2)
-        results = mem_scheduler.retriever.search(
-            query=msg.content,
-            user_id=msg.user_id,
-            mem_cube_id=msg.mem_cube_id,
-            mem_cube=mem_scheduler.current_mem_cube,
-            top_k=top_k,
-            method=mem_scheduler.search_method,
-            search_args=search_args,
-        )
-        working_memories.extend(results)
-        working_memories = working_memories[-5:]
-        for mem in results:
-            print(f"\n[scheduler] Retrieved memory: {mem.memory}")
+    if default_mem_update_handler is None:
+        logger.error("Default MEM_UPDATE handler missing; cannot process messages.")
+        return
+    # Delegate to the built-in handler to keep behavior aligned with scheduler refactor.
+    default_mem_update_handler(messages)


async def run_with_scheduler():
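The change above replaces the example's hand-rolled retrieval with a lookup of the scheduler's built-in MEM_UPDATE handler via mem_scheduler.handlers, followed by delegation. Below is a minimal standalone sketch of that delegate pattern; it is not part of this PR and assumes only what the diff shows, namely that handlers maps task labels to callables accepting a list of ScheduleMessageItem.

# Illustrative sketch (not part of this PR): build a handler that delegates to the
# scheduler's registered default for a given task label, falling back to a logged
# no-op when that default is missing. Assumes scheduler.handlers maps task labels
# to callables, as the diff above shows.
def make_delegating_handler(scheduler, label, logger):
    default_handler = scheduler.handlers.get(label)

    def handler(messages):
        if default_handler is None:
            logger.error("No default handler registered for %r; dropping %d message(s).", label, len(messages))
            return
        # Keep behavior aligned with the built-in scheduler logic.
        default_handler(messages)

    return handler

With such a helper, the example's custom_mem_update_handler would reduce to make_delegating_handler(mem_scheduler, MEM_UPDATE_TASK_LABEL, logger).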
9 changes: 5 additions & 4 deletions examples/mem_scheduler/try_schedule_modules.py
@@ -167,7 +167,8 @@ def add_msgs(
label=MEM_UPDATE_TASK_LABEL,
content=query,
)
-    # Run one session turn manually to get search candidates
-    mem_scheduler._memory_update_consumer(
-        messages=[message],
-    )
+    # Run one session turn manually via registered handler (public surface)
+    handler = mem_scheduler.handlers.get(MEM_UPDATE_TASK_LABEL)
+    if handler is None:
+        raise RuntimeError("MEM_UPDATE handler not registered on mem_scheduler.")
+    handler([message])
2 changes: 1 addition & 1 deletion src/memos/mem_feedback/feedback.py
@@ -1165,7 +1165,7 @@ def process_feedback(
info,
**kwargs,
)
-        done, pending = concurrent.futures.wait([answer_future, core_future], timeout=30)
+        _done, pending = concurrent.futures.wait([answer_future, core_future], timeout=30)
for fut in pending:
fut.cancel()
try:
10 changes: 10 additions & 0 deletions src/memos/mem_scheduler/base_mixins/__init__.py
@@ -0,0 +1,10 @@
from .memory_ops import BaseSchedulerMemoryMixin
from .queue_ops import BaseSchedulerQueueMixin
from .web_log_ops import BaseSchedulerWebLogMixin


__all__ = [
"BaseSchedulerMemoryMixin",
"BaseSchedulerQueueMixin",
"BaseSchedulerWebLogMixin",
]
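The new base_mixins package splits scheduler behavior into three mixins (memory ops, queue ops, web-log ops). A hypothetical composition sketch follows; the target class name and the absence of any other base classes are assumptions for illustration, not taken from this PR.

# Hypothetical composition sketch (not from this PR): a concrete scheduler class
# that pulls in the three mixins exported above.
from memos.mem_scheduler.base_mixins import (
    BaseSchedulerMemoryMixin,
    BaseSchedulerQueueMixin,
    BaseSchedulerWebLogMixin,
)


class ExampleScheduler(BaseSchedulerMemoryMixin, BaseSchedulerQueueMixin, BaseSchedulerWebLogMixin):
    """Illustrative only: each mixin contributes one slice of scheduler behavior."""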
227 changes: 227 additions & 0 deletions src/memos/mem_scheduler/base_mixins/memory_ops.py
@@ -0,0 +1,227 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from memos.log import get_logger
from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
from memos.mem_scheduler.utils.filter_utils import transform_name_to_key
from memos.memories.textual.naive import NaiveTextMemory
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory


if TYPE_CHECKING:
from memos.types.general_types import MemCubeID, UserID


logger = get_logger(__name__)


class BaseSchedulerMemoryMixin:
def transform_working_memories_to_monitors(
self, query_keywords, memories: list[TextualMemoryItem]
) -> list[MemoryMonitorItem]:
result = []
mem_length = len(memories)
for idx, mem in enumerate(memories):
text_mem = mem.memory
mem_key = transform_name_to_key(name=text_mem)

keywords_score = 0
if query_keywords and text_mem:
for keyword, count in query_keywords.items():
keyword_count = text_mem.count(keyword)
if keyword_count > 0:
keywords_score += keyword_count * count
logger.debug(
"Matched keyword '%s' %s times, added %s to keywords_score",
keyword,
keyword_count,
keywords_score,
)

sorting_score = mem_length - idx

mem_monitor = MemoryMonitorItem(
memory_text=text_mem,
tree_memory_item=mem,
tree_memory_item_mapping_key=mem_key,
sorting_score=sorting_score,
keywords_score=keywords_score,
recording_count=1,
)
result.append(mem_monitor)

logger.info("Transformed %s memories to monitors", len(result))
return result

def replace_working_memory(
self,
user_id: UserID | str,
mem_cube_id: MemCubeID | str,
mem_cube,
original_memory: list[TextualMemoryItem],
new_memory: list[TextualMemoryItem],
) -> None | list[TextualMemoryItem]:
text_mem_base = mem_cube.text_mem
if isinstance(text_mem_base, TreeTextMemory):
query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id]
query_db_manager.sync_with_orm()

query_history = query_db_manager.obj.get_queries_with_timesort()

original_count = len(original_memory)
filtered_original_memory = []
for origin_mem in original_memory:
if "mode:fast" not in origin_mem.metadata.tags:
filtered_original_memory.append(origin_mem)
else:
logger.debug(
"Filtered out memory - ID: %s, Tags: %s",
getattr(origin_mem, "id", "unknown"),
origin_mem.metadata.tags,
)
filtered_count = original_count - len(filtered_original_memory)
remaining_count = len(filtered_original_memory)

logger.info(
"Filtering complete. Removed %s memories with tag 'mode:fast'. Remaining memories: %s",
filtered_count,
remaining_count,
)
original_memory = filtered_original_memory

memories_with_new_order, rerank_success_flag = (
self.retriever.process_and_rerank_memories(
queries=query_history,
original_memory=original_memory,
new_memory=new_memory,
top_k=self.top_k,
)
)

logger.info("Filtering memories based on query history: %s queries", len(query_history))
filtered_memories, filter_success_flag = self.retriever.filter_unrelated_memories(
query_history=query_history,
memories=memories_with_new_order,
)

if filter_success_flag:
logger.info(
"Memory filtering completed successfully. Filtered from %s to %s memories",
len(memories_with_new_order),
len(filtered_memories),
)
memories_with_new_order = filtered_memories
else:
logger.warning(
"Memory filtering failed - keeping all memories as fallback. Original count: %s",
len(memories_with_new_order),
)

query_keywords = query_db_manager.obj.get_keywords_collections()
logger.info(
"Processing %s memories with %s query keywords",
len(memories_with_new_order),
len(query_keywords),
)
new_working_memory_monitors = self.transform_working_memories_to_monitors(
query_keywords=query_keywords,
memories=memories_with_new_order,
)

if not rerank_success_flag:
for one in new_working_memory_monitors:
one.sorting_score = 0

logger.info("update %s working_memory_monitors", len(new_working_memory_monitors))
self.monitor.update_working_memory_monitors(
new_working_memory_monitors=new_working_memory_monitors,
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
)

mem_monitors: list[MemoryMonitorItem] = self.monitor.working_memory_monitors[user_id][
mem_cube_id
].obj.get_sorted_mem_monitors(reverse=True)
new_working_memories = [mem_monitor.tree_memory_item for mem_monitor in mem_monitors]

text_mem_base.replace_working_memory(memories=new_working_memories)

logger.info(
"The working memory has been replaced with %s new memories.",
len(memories_with_new_order),
)
self.log_working_memory_replacement(
original_memory=original_memory,
new_memory=new_working_memories,
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
log_func_callback=self._submit_web_logs,
)
elif isinstance(text_mem_base, NaiveTextMemory):
logger.info(
"NaiveTextMemory: Updating working memory monitors with %s candidates.",
len(new_memory),
)

query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id]
query_db_manager.sync_with_orm()
query_keywords = query_db_manager.obj.get_keywords_collections()

new_working_memory_monitors = self.transform_working_memories_to_monitors(
query_keywords=query_keywords,
memories=new_memory,
)

self.monitor.update_working_memory_monitors(
new_working_memory_monitors=new_working_memory_monitors,
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
)
memories_with_new_order = new_memory
else:
logger.error("memory_base is not supported")
memories_with_new_order = new_memory

return memories_with_new_order

def update_activation_memory(
self,
new_memories: list[str | TextualMemoryItem],
label: str,
user_id: UserID | str,
mem_cube_id: MemCubeID | str,
mem_cube,
) -> None:
if hasattr(self, "activation_memory_manager") and self.activation_memory_manager:
self.activation_memory_manager.update_activation_memory(
new_memories=new_memories,
label=label,
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
)
else:
logger.warning("Activation memory manager not initialized")

def update_activation_memory_periodically(
self,
interval_seconds: int,
label: str,
user_id: UserID | str,
mem_cube_id: MemCubeID | str,
mem_cube,
):
if hasattr(self, "activation_memory_manager") and self.activation_memory_manager:
self.activation_memory_manager.update_activation_memory_periodically(
interval_seconds=interval_seconds,
label=label,
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
)
else:
logger.warning("Activation memory manager not initialized")