diff --git a/examples/data/config/mem_scheduler/mem_cube_config.yaml b/examples/data/config/mem_scheduler/mem_cube_config.yaml new file mode 100644 index 000000000..398d8dbb3 --- /dev/null +++ b/examples/data/config/mem_scheduler/mem_cube_config.yaml @@ -0,0 +1,21 @@ +user_id: "user_test" +cube_id: "user_test/mem_cube_naive" +text_mem: + backend: "naive_text" + config: + extractor_llm: + backend: "huggingface_singleton" + config: + model_name_or_path: "Qwen/Qwen3-0.6B" + temperature: 0.1 + max_tokens: 1024 +act_mem: + backend: "kv_cache" + config: + memory_filename: "activation_memory.pickle" + extractor_llm: + backend: "huggingface_singleton" + config: + model_name_or_path: "Qwen/Qwen3-0.6B" + temperature: 0.8 + max_tokens: 1024 diff --git a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml index bd9910300..a5e91dc4e 100644 --- a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml +++ b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml @@ -10,16 +10,12 @@ mem_reader: backend: "simple_struct" config: llm: - backend: "openai" + backend: "huggingface_singleton" config: - model_name_or_path: "gpt-4o-mini" - temperature: 0.8 - max_tokens: 4096 - top_p: 0.9 - top_k: 50 + model_name_or_path: "Qwen/Qwen3-1.7B" + temperature: 0.1 remove_think_prefix: true - api_key: "sk-xxxxxx" - api_base: "https://api.openai.com/v1" + max_tokens: 4096 embedder: backend: "ollama" config: diff --git a/examples/mem_scheduler/quick_start_examples.py b/examples/mem_scheduler/quick_start_examples.py new file mode 100644 index 000000000..c71869e76 --- /dev/null +++ b/examples/mem_scheduler/quick_start_examples.py @@ -0,0 +1,253 @@ +import json +import shutil +import sys +import uuid + +from pathlib import Path + +from transformers import DynamicCache + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.configs.memory import MemoryConfigFactory +from memos.mem_cube.general import GeneralMemCube +from memos.mem_os.main import MOS +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import ( + ANSWER_TASK_LABEL, + MEM_UPDATE_TASK_LABEL, + QUERY_TASK_LABEL, +) +from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.mem_scheduler.utils.misc_utils import parse_yaml +from memos.memories.activation.item import KVCacheItem +from memos.memories.factory import MemoryFactory + + +FILE_PATH = Path(__file__).absolute() +BASE_DIR = FILE_PATH.parent.parent.parent +sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory + + +def get_cache_info(cache): + if not cache: + return None + + num_layers = 0 + total_size_bytes = 0 + + if hasattr(cache, "layers"): + num_layers = len(cache.layers) + for layer in cache.layers: + if hasattr(layer, "key_cache") and layer.key_cache is not None: + total_size_bytes += layer.key_cache.nelement() * layer.key_cache.element_size() + if hasattr(layer, "value_cache") and layer.value_cache is not None: + total_size_bytes += layer.value_cache.nelement() * layer.value_cache.element_size() + + if hasattr(layer, "keys") and layer.keys is not None: + total_size_bytes += layer.keys.nelement() * layer.keys.element_size() + if hasattr(layer, "values") and layer.values is not None: + total_size_bytes += layer.values.nelement() * layer.values.element_size() + + elif hasattr(cache, "key_cache") and hasattr(cache, "value_cache"): + num_layers = 
len(cache.key_cache)
+        for k, v in zip(cache.key_cache, cache.value_cache, strict=False):
+            if k is not None:
+                total_size_bytes += k.nelement() * k.element_size()
+            if v is not None:
+                total_size_bytes += v.nelement() * v.element_size()
+
+    return {
+        "num_layers": num_layers,
+        "size_bytes": total_size_bytes,
+        "size_mb": f"{total_size_bytes / (1024 * 1024):.2f} MB",
+    }
+
+
+def serialize_item(obj):
+    if isinstance(obj, list):
+        return [serialize_item(x) for x in obj]
+
+    if isinstance(obj, KVCacheItem):
+        return {
+            "id": obj.id,
+            "metadata": obj.metadata,
+            "records": obj.records.model_dump()
+            if hasattr(obj.records, "model_dump")
+            else obj.records,
+            "memory": get_cache_info(obj.memory),
+        }
+
+    if isinstance(obj, DynamicCache):
+        return get_cache_info(obj)
+
+    return str(obj)
+
+
+def kv_cache_only():
+    # Create a config for KVCacheMemory (HuggingFace backend)
+    config = MemoryConfigFactory(
+        backend="kv_cache",
+        config={
+            "extractor_llm": {
+                "backend": "huggingface",
+                "config": {
+                    "model_name_or_path": "Qwen/Qwen3-0.6B",
+                    "max_tokens": 32,
+                    "add_generation_prompt": True,
+                    "remove_think_prefix": True,
+                },
+            },
+        },
+    )
+
+    # Instantiate KVCacheMemory
+    kv_mem = MemoryFactory.from_config(config)
+
+    # Extract a KVCacheItem (DynamicCache)
+    prompt = [
+        {"role": "user", "content": "What is MemOS?"},
+        {"role": "assistant", "content": "MemOS is a memory operating system for LLMs."},
+    ]
+    print("===== Extract KVCacheItem =====")
+    cache_item = kv_mem.extract(prompt)
+    print(json.dumps(serialize_item(cache_item), indent=2, default=str))
+
+    # Add the cache to memory
+    kv_mem.add([cache_item])
+    print("All caches:")
+    print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))
+
+    # Fetch by ID
+    retrieved = kv_mem.get(cache_item.id)
+    print("Retrieved:")
+    print(json.dumps(serialize_item(retrieved), indent=2, default=str))
+
+    # Merge caches
+    item2 = kv_mem.extract([{"role": "user", "content": "Tell me a joke."}])
+    kv_mem.add([item2])
+    merged = kv_mem.get_cache([cache_item.id, item2.id])
+    print("Merged cache:")
+    print(json.dumps(serialize_item(merged), indent=2, default=str))
+
+    # Delete one of the caches
+    kv_mem.delete([cache_item.id])
+    print("After delete:")
+    print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))
+
+    # Dump and reload caches
+    kv_mem.dump("tmp/kv_mem")
+    print("Dumped to tmp/kv_mem")
+    kv_mem.delete_all()
+    kv_mem.load("tmp/kv_mem")
+    print("Loaded caches:")
+    print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))
+
+
+def run_scheduler_example():
+    # Load the main MOS config with the MemScheduler enabled
+    config = parse_yaml(
+        f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml"
+    )
+    mos_config = MOSConfig(**config)
+    mos = MOS(mos_config)
+
+    # Create a dynamic user ID
+    user_id = str(uuid.uuid4())
+    mos.create_user(user_id=user_id)
+
+    # Create the MemCube config and dump it
+    config = GeneralMemCubeConfig.from_yaml_file(
+        f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
+    )
+    mem_cube_id = "mem_cube_5"
+    mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
+
+    # Remove the old directory if it exists
+    if Path(mem_cube_name_or_path).exists():
+        shutil.rmtree(mem_cube_name_or_path)
+        print(f"{mem_cube_name_or_path} already exists and has been removed.")
+
+    # Dump a fresh MemCube
+    mem_cube = GeneralMemCube(config)
+    mem_cube.dump(mem_cube_name_or_path)
+
+    # Register the MemCube for this user
+    mos.register_mem_cube(
+        mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
+    )
+
+    # Define custom scheduler handlers
+    def custom_query_handler(messages: list[ScheduleMessageItem]):
+        for msg in messages:
+            print(f"\n[scheduler] user submitted a query: {msg.content}")
+            # Trigger mem_update manually
+            new_msg = msg.model_copy(update={"label": MEM_UPDATE_TASK_LABEL})
+            mos.mem_scheduler.submit_messages([new_msg])
+
+    def custom_answer_handler(messages: list[ScheduleMessageItem]):
+        for msg in messages:
+            mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
+            kv_mem = mem_cube.act_mem
+            for cache_item in kv_mem.get_all():
+                print(
+                    f"[scheduler] act memory: {get_cache_info(cache_item.memory)} ({cache_item.records})"
+                )
+            print(f"\n[scheduler] LLM replied with an answer: {msg.content}")
+
+    def custom_mem_update_handler(messages: list[ScheduleMessageItem]):
+        for msg in messages:
+            mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
+            kv_mem = mem_cube.act_mem
+            if mem_cube and mem_cube.text_mem:
+                results = mem_cube.text_mem.search(msg.content, top_k=3)
+                for mem in results:
+                    print(f"\n[scheduler] searched memories: {mem.memory}")
+
+                    cache_item = kv_mem.extract(mem.memory)
+                    cache_item.records.text_memories = [mem.memory]
+                    cache_item.records.timestamp = get_utc_now()
+                    kv_mem.add([cache_item])
+
+    # Register custom handlers
+    mos.mem_scheduler.dispatcher.register_handlers(
+        {
+            QUERY_TASK_LABEL: custom_query_handler,
+            ANSWER_TASK_LABEL: custom_answer_handler,
+            MEM_UPDATE_TASK_LABEL: custom_mem_update_handler,
+        }
+    )
+
+    # Add messages
+    messages = [
+        {"role": "user", "content": "I like playing football."},
+        {"role": "assistant", "content": "I like playing football too."},
+    ]
+    mos.add(messages, user_id=user_id, mem_cube_id=mem_cube_id)
+
+    # Chat loop: display TreeTextMemory nodes + KVCache
+    while True:
+        user_input = input("👤 [You] ").strip()
+        print()
+        response = mos.chat(user_input, user_id=user_id)
+        retrieved_memories = mos.get_all(mem_cube_id=mem_cube_id, user_id=user_id)
+
+        print(f"🤖 [Assistant] {response}")
+
+        # Display each type of node in TreeTextMemory
+        text_memories = retrieved_memories["text_mem"][0]["memories"]
+        # Handle different memory structures (NaiveTextMemory returns a list, TreeTextMemory returns a dict with nodes)
+        if isinstance(text_memories, dict) and "nodes" in text_memories:
+            for node in text_memories["nodes"]:
+                mem_type = node["metadata"].get("memory_type", "Unknown")
+                print(f"[{mem_type}] {node['memory']}")
+        elif isinstance(text_memories, list):
+            for mem in text_memories:
+                # Naive memory items might not have memory_type metadata, or it might differ
+                print(f"[TextMemory] {mem.memory if hasattr(mem, 'memory') else mem}")
+
+
+if __name__ == "__main__":
+    kv_cache_only()
+
+    run_scheduler_example()
diff --git a/src/memos/llms/hf.py b/src/memos/llms/hf.py
index d46db7c9e..b5fc4ba13 100644
--- a/src/memos/llms/hf.py
+++ b/src/memos/llms/hf.py
@@ -2,13 +2,7 @@
 from typing import Any

 from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
     DynamicCache,
-    LogitsProcessorList,
-    TemperatureLogitsWarper,
-    TopKLogitsWarper,
-    TopPLogitsWarper,
 )

 from memos.configs.llm import HFLLMConfig
@@ -30,6 +24,17 @@ def __init__(self, config: HFLLMConfig):
         """
         Initialize the HFLLM model and tokenizer, and set up logits processors for sampling.
         
""" + import torch + + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + LogitsProcessorList, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + ) + self.config = config # Default model if not specified @@ -37,9 +42,14 @@ def __init__(self, config: HFLLMConfig): self.config.model_name_or_path = "Qwen/Qwen3-1.7B" # Initialize hf model - self.model = AutoModelForCausalLM.from_pretrained( - self.config.model_name_or_path, torch_dtype="auto", device_map="auto" - ) + if torch.backends.mps.is_available(): + self.model = AutoModelForCausalLM.from_pretrained( + self.config.model_name_or_path, torch_dtype="auto" + ).to("mps") + else: + self.model = AutoModelForCausalLM.from_pretrained( + self.config.model_name_or_path, torch_dtype="auto", device_map="auto" + ) self.tokenizer = AutoTokenizer.from_pretrained( self.config.model_name_or_path, use_fast=True ) @@ -355,6 +365,7 @@ def build_kv_cache(self, messages) -> DynamicCache: DynamicCache: The constructed KV cache object. """ import torch + import transformers # Accept multiple input types and convert to standard chat messages if isinstance(messages, str): @@ -391,7 +402,7 @@ def build_kv_cache(self, messages) -> DynamicCache: # Convert from legacy tuple format to DynamicCache if needed if isinstance(kv, tuple): - kv = DynamicCache.from_legacy_cache(kv) + kv = transformers.DynamicCache.from_legacy_cache(kv) # Handle compatibility between old and new transformers versions # In newer versions, DynamicCache uses 'layers' attribute diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 1a88fa831..e7f01ec3e 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -311,7 +311,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = past_key_values = None if self.config.enable_activation_memory: - if self.config.chat_model.backend != "huggingface": + if self.config.chat_model.backend not in ["huggingface", "huggingface_singleton"]: logger.error( "Activation memory only used for huggingface backend. Skipping activation memory." ) @@ -498,7 +498,9 @@ def register_mem_cube( existing_cube = self.user_manager.get_cube(mem_cube_id) # check the embedder is it consistent with MOSConfig - if self.config.mem_reader.config.embedder != ( + if hasattr( + self.mem_cubes[mem_cube_id].text_mem.config, "embedder" + ) and self.config.mem_reader.config.embedder != ( cube_embedder := self.mem_cubes[mem_cube_id].text_mem.config.embedder ): logger.warning( diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py index 0114fc0da..0dc6ab209 100644 --- a/src/memos/mem_os/main.py +++ b/src/memos/mem_os/main.py @@ -310,7 +310,7 @@ def _generate_enhanced_response_with_context( # Handle activation memory if enabled (same as core method) past_key_values = None if self.config.enable_activation_memory: - if self.config.chat_model.backend != "huggingface": + if self.config.chat_model.backend not in ["huggingface", "huggingface_singleton"]: logger.error( "Activation memory only used for huggingface backend. Skipping activation memory." 
) diff --git a/src/memos/mem_os/utils/format_utils.py b/src/memos/mem_os/utils/format_utils.py index 5fdb59058..f6e33bb31 100644 --- a/src/memos/mem_os/utils/format_utils.py +++ b/src/memos/mem_os/utils/format_utils.py @@ -1087,38 +1087,64 @@ def convert_activation_memory_to_serializable( serializable_items = [] for item in act_mem_items: + key_layers = 0 + val_layers = 0 + device = "unknown" + dtype = "unknown" + key_shapes = [] + value_shapes = [] + + if item.memory: + if hasattr(item.memory, "layers"): + key_layers = len(item.memory.layers) + val_layers = len(item.memory.layers) + if key_layers > 0: + l0 = item.memory.layers[0] + k0 = getattr(l0, "key_cache", getattr(l0, "keys", None)) + if k0 is not None: + device = str(k0.device) + dtype = str(k0.dtype) + + for i, layer in enumerate(item.memory.layers): + k = getattr(layer, "key_cache", getattr(layer, "keys", None)) + v = getattr(layer, "value_cache", getattr(layer, "values", None)) + if k is not None: + key_shapes.append({"layer": i, "shape": list(k.shape)}) + if v is not None: + value_shapes.append({"layer": i, "shape": list(v.shape)}) + + elif hasattr(item.memory, "key_cache"): + key_layers = len(item.memory.key_cache) + val_layers = len(item.memory.value_cache) + if key_layers > 0 and item.memory.key_cache[0] is not None: + device = str(item.memory.key_cache[0].device) + dtype = str(item.memory.key_cache[0].dtype) + + for i, key_tensor in enumerate(item.memory.key_cache): + if key_tensor is not None: + key_shapes.append({"layer": i, "shape": list(key_tensor.shape)}) + + for i, val_tensor in enumerate(item.memory.value_cache): + if val_tensor is not None: + value_shapes.append({"layer": i, "shape": list(val_tensor.shape)}) + # Extract basic information that can be serialized serializable_item = { "id": item.id, "metadata": item.metadata, "memory_info": { "type": "DynamicCache", - "key_cache_layers": len(item.memory.key_cache) if item.memory else 0, - "value_cache_layers": len(item.memory.value_cache) if item.memory else 0, - "device": str(item.memory.key_cache[0].device) - if item.memory and item.memory.key_cache - else "unknown", - "dtype": str(item.memory.key_cache[0].dtype) - if item.memory and item.memory.key_cache - else "unknown", + "key_cache_layers": key_layers, + "value_cache_layers": val_layers, + "device": device, + "dtype": dtype, }, } # Add tensor shape information if available - if item.memory and item.memory.key_cache: - key_shapes = [] - value_shapes = [] - - for i, key_tensor in enumerate(item.memory.key_cache): - if key_tensor is not None: - key_shapes.append({"layer": i, "shape": list(key_tensor.shape)}) - - if i < len(item.memory.value_cache) and item.memory.value_cache[i] is not None: - value_shapes.append( - {"layer": i, "shape": list(item.memory.value_cache[i].shape)} - ) - + if key_shapes: serializable_item["memory_info"]["key_shapes"] = key_shapes + if value_shapes: serializable_item["memory_info"]["value_shapes"] = value_shapes serializable_items.append(serializable_item) @@ -1144,7 +1170,19 @@ def convert_activation_memory_summary(act_mem_items: list[KVCacheItem]) -> dict[ total_parameters = 0 for item in act_mem_items: - if item.memory and item.memory.key_cache: + if not item.memory: + continue + + if hasattr(item.memory, "layers"): + total_layers += len(item.memory.layers) + for layer in item.memory.layers: + k = getattr(layer, "key_cache", getattr(layer, "keys", None)) + v = getattr(layer, "value_cache", getattr(layer, "values", None)) + if k is not None: + total_parameters += k.numel() + if v is not 
None: + total_parameters += v.numel() + elif hasattr(item.memory, "key_cache"): total_layers += len(item.memory.key_cache) # Calculate approximate parameter count diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 70472958e..61a7d2b6d 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -618,7 +618,7 @@ def _read_memory( messages=combined_messages, memory_list=original_memory_group, user_only=os.getenv("SIMPLE_STRUCT_REWRITE_USER_ONLY", "true").lower() - == "true", + == "false", ) serialized_revised_memories = json.dumps( [one.memory for one in revised_memory_list], indent=2 diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 728203f5b..3f5c90b67 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -74,6 +74,7 @@ from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule from memos.memories.activation.kv import KVCacheMemory from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory +from memos.memories.textual.naive import NaiveTextMemory from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE @@ -198,13 +199,16 @@ def init_mem_cube( logger.error("mem_cube is None, cannot initialize", stack_info=True) self.mem_cube = mem_cube self.text_mem: TreeTextMemory = self.mem_cube.text_mem - self.reranker: HTTPBGEReranker = self.text_mem.reranker + self.reranker: HTTPBGEReranker = getattr(self.text_mem, "reranker", None) if searcher is None: - self.searcher: Searcher = self.text_mem.get_searcher( - manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", - moscube=False, - process_llm=self.process_llm, - ) + if hasattr(self.text_mem, "get_searcher"): + self.searcher: Searcher = self.text_mem.get_searcher( + manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", + moscube=False, + process_llm=self.process_llm, + ) + else: + self.searcher = None else: self.searcher = searcher self.feedback_server = feedback_server @@ -540,6 +544,29 @@ def replace_working_memory( mem_cube=mem_cube, log_func_callback=self._submit_web_logs, ) + elif isinstance(text_mem_base, NaiveTextMemory): + # For NaiveTextMemory, we populate the monitors with the new candidates so activation memory can pick them up + logger.info( + f"NaiveTextMemory: Updating working memory monitors with {len(new_memory)} candidates." 
+ ) + + # Use query keywords if available, otherwise just basic monitoring + query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] + query_db_manager.sync_with_orm() + query_keywords = query_db_manager.obj.get_keywords_collections() + + new_working_memory_monitors = self.transform_working_memories_to_monitors( + query_keywords=query_keywords, + memories=new_memory, + ) + + self.monitor.update_working_memory_monitors( + new_working_memory_monitors=new_working_memory_monitors, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + memories_with_new_order = new_memory else: logger.error("memory_base is not supported") memories_with_new_order = new_memory @@ -1008,6 +1035,9 @@ def _monitor_loop(self): try: q_sizes = self.memos_message_queue.qsize() + if not isinstance(q_sizes, dict): + continue + for stream_key, queue_length in q_sizes.items(): # Skip aggregate keys like 'total_size' if stream_key == "total_size": diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py index 57d78676f..fd83ec86f 100644 --- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py +++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py @@ -55,7 +55,11 @@ def create_autofilled_log_item( "mem_cube is None — this should not happen in production!", stack_info=True ) text_mem_base: TreeTextMemory = mem_cube.text_mem - current_memory_sizes = text_mem_base.get_current_memory_size(user_name=mem_cube_id) + + current_memory_sizes = {} + if hasattr(text_mem_base, "get_current_memory_size"): + current_memory_sizes = text_mem_base.get_current_memory_size(user_name=mem_cube_id) + current_memory_sizes = { "long_term_memory_size": current_memory_sizes.get("LongTermMemory", 0), "user_memory_size": current_memory_sizes.get("UserMemory", 0), @@ -63,14 +67,32 @@ def create_autofilled_log_item( "transformed_act_memory_size": NOT_INITIALIZED, "parameter_memory_size": NOT_INITIALIZED, } + memory_capacities = { - "long_term_memory_capacity": text_mem_base.memory_manager.memory_size["LongTermMemory"], - "user_memory_capacity": text_mem_base.memory_manager.memory_size["UserMemory"], - "working_memory_capacity": text_mem_base.memory_manager.memory_size["WorkingMemory"], + "long_term_memory_capacity": 0, + "user_memory_capacity": 0, + "working_memory_capacity": 0, "transformed_act_memory_capacity": NOT_INITIALIZED, "parameter_memory_capacity": NOT_INITIALIZED, } + if hasattr(text_mem_base, "memory_manager") and hasattr( + text_mem_base.memory_manager, "memory_size" + ): + memory_capacities.update( + { + "long_term_memory_capacity": text_mem_base.memory_manager.memory_size.get( + "LongTermMemory", 0 + ), + "user_memory_capacity": text_mem_base.memory_manager.memory_size.get( + "UserMemory", 0 + ), + "working_memory_capacity": text_mem_base.memory_manager.memory_size.get( + "WorkingMemory", 0 + ), + } + ) + if hasattr(self, "monitor"): if ( user_id in self.monitor.activation_memory_monitors diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 86066f346..9b19e9ecb 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -34,6 +34,7 @@ is_cloud_env, ) from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.naive import NaiveTextMemory from memos.memories.textual.preference import PreferenceTextMemory from memos.memories.textual.tree import TreeTextMemory from memos.types import 
(
@@ -846,7 +847,9 @@
                     memory_item = text_mem.get(mem_id, user_name=user_name)
                     memory_items.append(memory_item)
                 except Exception as e:
-                    logger.warning(f"Failed to get memory {mem_id}: {e}")
+                    logger.warning(
+                        f"[_process_memories_with_reader] Failed to get memory {mem_id}: {e}"
+                    )
                     continue
 
         if not memory_items:
@@ -1364,22 +1367,31 @@
         text_mem_base = mem_cube.text_mem
 
         if not isinstance(text_mem_base, TreeTextMemory):
-            logger.error(
-                f"Not implemented! Expected TreeTextMemory but got {type(text_mem_base).__name__} "
-                f"for mem_cube_id={mem_cube_id}, user_id={user_id}. "
-                f"text_mem_base value: {text_mem_base}",
-                exc_info=True,
+            if isinstance(text_mem_base, NaiveTextMemory):
+                logger.debug(
+                    f"NaiveTextMemory used for mem_cube_id={mem_cube_id}, processing session turn with simple search."
+                )
+                # NaiveTextMemory has no distinct "current working memory", so start
+                # from an empty candidate set; the retrieval below supplies the
+                # working-memory candidates for activation memory.
+                cur_working_memory = []
+            else:
+                logger.warning(
+                    f"Not implemented! Expected TreeTextMemory but got {type(text_mem_base).__name__} "
+                    f"for mem_cube_id={mem_cube_id}, user_id={user_id}. "
+                    f"text_mem_base value: {text_mem_base}"
+                )
+                return [], []
+        else:
+            cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory(
+                user_name=mem_cube_id
             )
-            return
+            cur_working_memory = cur_working_memory[:top_k]
 
         logger.info(
             f"[process_session_turn] Processing {len(queries)} queries for user_id={user_id}, mem_cube_id={mem_cube_id}"
         )
-        cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory(
-            user_name=mem_cube_id
-        )
-        cur_working_memory = cur_working_memory[:top_k]
         text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
         intent_result = self.monitor.detect_intent(
             q_list=queries, text_working_memory=text_working_memory
@@ -1419,15 +1431,25 @@
                 )
                 search_args = {}
 
-                results: list[TextualMemoryItem] = self.retriever.search(
-                    query=item,
-                    user_id=user_id,
-                    mem_cube_id=mem_cube_id,
-                    mem_cube=mem_cube,
-                    top_k=k_per_evidence,
-                    method=self.search_method,
-                    search_args=search_args,
-                )
+                if isinstance(text_mem_base, NaiveTextMemory):
+                    # NaiveTextMemory.search only accepts query and top_k, so bypass
+                    # SchedulerRetriever (whose method dispatch targets tree memories)
+                    # and query the naive memory directly.
+                    try:
+                        results = text_mem_base.search(query=item, top_k=k_per_evidence)
+                    except Exception as e:
+                        logger.warning(f"NaiveTextMemory search failed: {e}")
+                        results = []
+                else:
+                    results: list[TextualMemoryItem] = self.retriever.search(
+                        query=item,
+                        user_id=user_id,
+                        mem_cube_id=mem_cube_id,
+                        mem_cube=mem_cube,
+                        top_k=k_per_evidence,
+                        method=self.search_method,
+                        search_args=search_args,
+                    )
 
                 logger.info(
                     f"[process_session_turn] Search results for missing evidence '{item}': "
diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py
index b097b1e2d..d75d6ee75 100644
--- a/src/memos/mem_scheduler/monitors/general_monitor.py
+++ b/src/memos/mem_scheduler/monitors/general_monitor.py
@@ -200,15 +200,19 @@ def update_working_memory_monitors(
         mem_cube_id: 
str, mem_cube: GeneralMemCube, ): - text_mem_base: TreeTextMemory = mem_cube.text_mem - assert isinstance(text_mem_base, TreeTextMemory) - self.working_mem_monitor_capacity = min( - DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT, - ( - int(text_mem_base.memory_manager.memory_size["WorkingMemory"]) - + self.partial_retention_number - ), - ) + text_mem_base = mem_cube.text_mem + + if isinstance(text_mem_base, TreeTextMemory): + self.working_mem_monitor_capacity = min( + DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT, + ( + int(text_mem_base.memory_manager.memory_size["WorkingMemory"]) + + self.partial_retention_number + ), + ) + else: + # Fallback for NaiveTextMemory and others + self.working_mem_monitor_capacity = DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT # register monitors self.register_memory_manager_if_not_exists( diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index 35df3db64..e2c1621d4 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py +++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -128,9 +128,6 @@ def status_tracker(self) -> TaskStatusTracker | None: if self._status_tracker is None: try: self._status_tracker = TaskStatusTracker(self.redis) - # Propagate to submodules when created lazily - if self.dispatcher: - self.dispatcher.status_tracker = self._status_tracker if self.memos_message_queue: self.memos_message_queue.set_status_tracker(self._status_tracker) except Exception as e: diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py index eae70f8ef..791cedf41 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py @@ -4,9 +4,18 @@ the local memos_message_queue functionality in BaseScheduler. """ +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Callable + from memos.log import get_logger from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import DEFAULT_STREAM_KEY_PREFIX +from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator +from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule @@ -16,26 +25,38 @@ class SchedulerLocalQueue(RedisSchedulerModule): def __init__( self, - maxsize: int, + maxsize: int = 0, + stream_key_prefix: str = DEFAULT_STREAM_KEY_PREFIX, + orchestrator: SchedulerOrchestrator | None = None, + status_tracker: TaskStatusTracker | None = None, ): """ Initialize the SchedulerLocalQueue with a maximum queue size limit. + Arguments match SchedulerRedisQueue for compatibility. Args: - maxsize (int): Maximum number of messages allowed - in each individual queue. - If exceeded, subsequent puts will block - or raise an exception based on `block` parameter. + maxsize (int): Maximum number of messages allowed in each individual queue. + stream_key_prefix (str): Prefix for stream keys (simulated). + orchestrator: SchedulerOrchestrator instance (ignored). + status_tracker: TaskStatusTracker instance (ignored). 
""" super().__init__() - self.stream_key_prefix = "local_queue" + self.stream_key_prefix = stream_key_prefix or "local_queue" self.max_internal_message_queue_size = maxsize + # Dictionary to hold per-stream queues: key = stream_key, value = Queue[ScheduleMessageItem] self.queue_streams: dict[str, Queue[ScheduleMessageItem]] = {} + + self.orchestrator = orchestrator + self.status_tracker = status_tracker + + self._is_listening = False + self._message_handler: Callable[[ScheduleMessageItem], None] | None = None + logger.info( - f"SchedulerLocalQueue initialized with max_internal_message_queue_size={maxsize}" + f"SchedulerLocalQueue initialized with max_internal_message_queue_size={self.max_internal_message_queue_size}" ) def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str: @@ -86,7 +107,7 @@ def get( stream_key: str, block: bool = True, timeout: float | None = None, - batch_size: int | None = None, + batch_size: int | None = 1, ) -> list[ScheduleMessageItem]: if batch_size is not None and batch_size <= 0: logger.warning( @@ -99,18 +120,19 @@ def get( logger.error(f"Stream {stream_key} does not exist when trying to get messages.") return [] + # Ensure we always request a batch so we get a list back + effective_batch_size = batch_size if batch_size is not None else 1 + # Note: Assumes custom Queue implementation supports batch_size parameter res = self.queue_streams[stream_key].get( - block=block, timeout=timeout, batch_size=batch_size + block=block, timeout=timeout, batch_size=effective_batch_size ) logger.debug( f"Retrieved {len(res)} messages from queue '{stream_key}'. Current size: {self.queue_streams[stream_key].qsize()}" ) return res - def get_nowait( - self, stream_key: str, batch_size: int | None = None - ) -> list[ScheduleMessageItem]: + def get_nowait(self, stream_key: str, batch_size: int | None = 1) -> list[ScheduleMessageItem]: """ Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size). @@ -170,35 +192,13 @@ def qsize(self) -> dict: logger.debug(f"Current queue sizes: {sizes}") return sizes - def size(self) -> int: - """ - Get the current size of the queue (total message count). - Compatible with SchedulerRedisQueue. - """ - return self.unfinished_tasks - - def empty(self) -> bool: - """ - Check if the queue is empty. - Compatible with SchedulerRedisQueue. - """ - return self.size() == 0 - - def full(self) -> bool: - """ - Check if the queue is full. - Compatible with SchedulerRedisQueue. - """ - # Local queue limits are per-stream (max_internal_message_queue_size). - # It is considered full only if all streams are full. - if not self.queue_streams: - return False - - return all(queue.full() for queue in self.queue_streams.values()) - - def clear(self) -> None: - for queue in self.queue_streams.values(): - queue.clear() + def clear(self, stream_key: str | None = None) -> None: + if stream_key: + if stream_key in self.queue_streams: + self.queue_streams[stream_key].clear() + else: + for queue in self.queue_streams.values(): + queue.clear() @property def unfinished_tasks(self) -> int: @@ -216,3 +216,32 @@ def unfinished_tasks(self) -> int: total = sum(queue.qsize() for queue in self.queue_streams.values()) logger.debug(f"Total unfinished tasks across all queues: {total}") return total + + def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]: + """ + Return list of active stream keys. 
+ """ + prefix = stream_key_prefix or self.stream_key_prefix + return [k for k in self.queue_streams if k.startswith(prefix)] + + def size(self) -> int: + """ + Total size of all queues. + """ + return sum(q.qsize() for q in self.queue_streams.values()) + + def empty(self) -> bool: + """ + Check if all queues are empty. + """ + return self.size() == 0 + + def full(self) -> bool: + """ + Check if any queue is full (approximate). + """ + if self.max_internal_message_queue_size <= 0: + return False + return any( + q.qsize() >= self.max_internal_message_queue_size for q in self.queue_streams.values() + ) diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 2f4318003..941c52164 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -787,7 +787,7 @@ def qsize(self) -> dict: Total number of messages across all matching streams. """ if not self._redis_conn: - return 0 + return {} total_size = 0 try: diff --git a/src/memos/memories/activation/kv.py b/src/memos/memories/activation/kv.py index 98d611dbf..1981b958f 100644 --- a/src/memos/memories/activation/kv.py +++ b/src/memos/memories/activation/kv.py @@ -2,9 +2,7 @@ import pickle from datetime import datetime -from importlib.metadata import version -from packaging.version import Version from transformers import DynamicCache from memos.configs.memory import KVCacheMemoryConfig @@ -211,10 +209,24 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache: return caches[0] merged = DynamicCache() - num_layers = len(caches[0].key_cache) - if Version(version("transformers")) >= Version("4.54.0"): - merged.append_new_layers(num_layers - 1) + # Check for new structure (layers) + if hasattr(caches[0], "layers"): + num_layers = len(caches[0].layers) + + # Ensure merged has layers attribute and populate it + if not hasattr(merged, "layers"): + merged.layers = [] + + if num_layers > 0: + # Get the class of the layer from the first cache + # We assume all caches use the same layer class + layer_cls = type(caches[0].layers[0]) + + # Populate merged.layers + while len(merged.layers) < num_layers: + merged.layers.append(layer_cls()) + for layer in range(num_layers): # gather all K and V for this layer keys = [c.layers[layer].keys for c in caches] @@ -223,7 +235,10 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache: merged.layers[layer].keys = torch.cat(keys, dim=-2) merged.layers[layer].values = torch.cat(vals, dim=-2) - else: + # Check for old structure (key_cache) + elif hasattr(caches[0], "key_cache"): + num_layers = len(caches[0].key_cache) + for layer in range(num_layers): # gather all K and V for this layer keys = [c.key_cache[layer] for c in caches] @@ -232,6 +247,11 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache: merged.key_cache.append(torch.cat(keys, dim=-2)) merged.value_cache.append(torch.cat(vals, dim=-2)) + else: + raise AttributeError( + "DynamicCache object has neither 'layers' nor 'key_cache' attributes" + ) + return merged diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index 40971c77e..26795a2b1 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -622,7 +622,7 @@ 专注于从图像中提取事实性、可观察的信息。除非与用户记忆明显相关,否则避免推测。""" -SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT = """ +SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT_BACKUP = """ You are a 
strict, language-preserving memory validator and rewriter. Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. @@ -655,6 +655,39 @@ Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. """ +SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT = """ +You are a strict, language-preserving memory validator and rewriter. + +Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. + +Rules: +1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. +2. **Strict Factual Grounding**: Include only what is explicitly stated by the user in messages marked as [user]. Remove or flag anything not directly present in the user’s utterances—no assumptions, interpretations, predictions, generalizations, or content originating solely from [assistant]. +3. **Source Attribution Requirement**: + - Every memory must be clearly traceable to its source: + - If a fact appears **only in [assistant] messages** and **is not affirmed by [user]**, label it as “[assistant] memory”. + - If [assistant] states something and [user] explicitly contradicts or denies it, label it as “[assistant] memory, but [user] [brief quote or summary of denial]”. + - If a fact is stated by [user] —whether or not [assistant] also mentions it— it is attributed to “[user]” and may be retained without qualification. +4. **Timestamp Exception**: Memories may include timestamps (e.g., "On December 19, 2026") derived from conversation metadata. If such a date likely reflects the conversation time (even if not in the `messages` list), do NOT treat it as hallucinated—but still attribute it to “[user]” only if the user mentioned or confirmed the date. + +Inputs: +messages: +{messages_inline} + +memories: +{memories_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. +- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} +- The "reason" must be brief and precise, e.g.: + - "contains unsupported inference from [assistant]" + - "[assistant] memory, but [user] said 'I don't have a dog'" + - "fully grounded in [user]" + +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. +""" + SIMPLE_STRUCT_REWRITE_MEMORY_USER_ONLY_PROMPT = """ You are a strict, language-preserving memory validator and rewriter. diff --git a/tests/llms/test_hf.py b/tests/llms/test_hf.py index 595995ad1..375bf2247 100644 --- a/tests/llms/test_hf.py +++ b/tests/llms/test_hf.py @@ -11,8 +11,8 @@ from memos.llms.hf import HFLLM -@patch("memos.llms.hf.AutoModelForCausalLM", MagicMock()) -@patch("memos.llms.hf.AutoTokenizer", MagicMock()) +@patch("transformers.AutoModelForCausalLM", MagicMock()) +@patch("transformers.AutoTokenizer", MagicMock()) class TestHFLLM(unittest.TestCase): def setUp(self): self.mock_inputs = MagicMock()