diff --git a/src/memos/api/config.py b/src/memos/api/config.py index adbd04e3c..70d9366e3 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -538,6 +538,10 @@ def get_internet_config() -> dict[str, Any]: "chunker": { "backend": "sentence", "config": { + "save_rawfile": os.getenv( + "MEM_READER_SAVE_RAWFILENODE", "true" + ).lower() + == "true", "tokenizer_or_token_counter": "gpt2", "chunk_size": 512, "chunk_overlap": 128, @@ -804,6 +808,8 @@ def get_product_default_config() -> dict[str, Any]: "chunker": { "backend": "sentence", "config": { + "save_rawfile": os.getenv("MEM_READER_SAVE_RAWFILENODE", "true").lower() + == "true", "tokenizer_or_token_counter": "gpt2", "chunk_size": 512, "chunk_overlap": 128, @@ -924,6 +930,8 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene "chunker": { "backend": "sentence", "config": { + "save_rawfile": os.getenv("MEM_READER_SAVE_RAWFILENODE", "true").lower() + == "true", "tokenizer_or_token_counter": "gpt2", "chunk_size": 512, "chunk_overlap": 128, diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py index cecc42c6c..06c4fd223 100644 --- a/src/memos/api/handlers/formatters_handler.py +++ b/src/memos/api/handlers/formatters_handler.py @@ -113,7 +113,7 @@ def post_process_textual_mem( mem for mem in text_formatted_mem if mem["metadata"]["memory_type"] - in ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"] + in ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory", "RawFileMemory"] ] tool_mem = [ mem @@ -157,12 +157,13 @@ def separate_knowledge_and_conversation_mem(memories: list[dict[str, Any]]): for item in memories: sources = item.get("metadata", {}).get("sources", []) if ( - len(sources) > 0 + item["metadata"]["memory_type"] != "RawFileMemory" + and len(sources) > 0 and "type" in sources[0] and sources[0]["type"] == "file" and "content" in sources[0] and sources[0]["content"] != "" - ): # TODO change to memory_type + ): knowledge_mem.append(item) else: conversation_mem.append(item) @@ -203,8 +204,7 @@ def rerank_knowledge_mem( key=lambda item: item.get("metadata", {}).get("relativity", 0.0), reverse=True, ) - - # TODO revoke sources replace memory value + # replace memory value with source.content for LongTermMemory, WorkingMemory or UserMemory for item in reranked_knowledge_mem: item["memory"] = item["metadata"]["sources"][0]["content"] item["metadata"]["sources"] = [] diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index e0f483e0a..feaf55680 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -422,8 +422,8 @@ class APISearchRequest(BaseRequest): ) # Internal field for search memory type search_memory_type: str = Field( - "All", - description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory, SkillMemory", + "AllSummaryMemory", + description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory, RawFileMemory, AllSummaryMemory, SkillMemory", ) # ==== Context ==== @@ -461,6 +461,13 @@ class APISearchRequest(BaseRequest): description="Source of the search query [plugin will router diff search]", ) + neighbor_discovery: bool = Field( + False, + description="Whether to enable neighbor discovery. " + "If enabled, the system will automatically recall neighbor chunks " + "relevant to the query. 
Default: False.", + ) + @model_validator(mode="after") def _convert_deprecated_fields(self) -> "APISearchRequest": """ diff --git a/src/memos/configs/chunker.py b/src/memos/configs/chunker.py index c2af012f0..f9a738415 100644 --- a/src/memos/configs/chunker.py +++ b/src/memos/configs/chunker.py @@ -14,6 +14,7 @@ class BaseChunkerConfig(BaseConfig): chunk_size: int = Field(default=512, description="Maximum tokens per chunk") chunk_overlap: int = Field(default=128, description="Overlap between chunks") min_sentences_per_chunk: int = Field(default=1, description="Minimum sentences in each chunk") + save_rawfile: bool = Field(default=True, description="Whether to save rawfile") # TODO class SentenceChunkerConfig(BaseChunkerConfig): diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index cf67a9bc3..6c6d1821f 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -235,20 +235,16 @@ def _single_add_operation( to_add_memory.metadata.tags = new_memory_item.metadata.tags to_add_memory.memory = new_memory_item.memory to_add_memory.metadata.embedding = new_memory_item.metadata.embedding - to_add_memory.metadata.user_id = new_memory_item.metadata.user_id - to_add_memory.metadata.created_at = to_add_memory.metadata.updated_at = ( - datetime.now().isoformat() - ) - to_add_memory.metadata.background = new_memory_item.metadata.background else: to_add_memory = new_memory_item.model_copy(deep=True) - to_add_memory.metadata.created_at = to_add_memory.metadata.updated_at = ( - datetime.now().isoformat() - ) - to_add_memory.metadata.background = new_memory_item.metadata.background - to_add_memory.id = "" + to_add_memory.metadata.created_at = to_add_memory.metadata.updated_at = ( + datetime.now().isoformat() + ) + to_add_memory.metadata.background = new_memory_item.metadata.background + to_add_memory.metadata.sources = [] + added_ids = self._retry_db_operation( lambda: self.memory_manager.add([to_add_memory], user_name=user_name, use_batch=False) ) @@ -626,10 +622,39 @@ def _info_comparison(self, memory: TextualMemoryItem, _info: dict, include_keys: def _retrieve(self, query: str, info=None, top_k=20, user_name=None): """Retrieve memory items""" - retrieved_mems = self.searcher.search( - query, info=info, user_name=user_name, top_k=top_k, full_recall=True + + def check_has_edges(mem_item: TextualMemoryItem) -> tuple[TextualMemoryItem, bool]: + """Check if a memory item has edges.""" + edges = self.searcher.graph_store.get_edges(mem_item.id, user_name=user_name) + return (mem_item, len(edges) == 0) + + text_mems = self.searcher.search( + query, + info=info, + memory_type="AllSummaryMemory", + user_name=user_name, + top_k=top_k, + full_recall=True, ) - retrieved_mems = [item[0] for item in retrieved_mems if float(item[1]) > 0.01] + text_mems = [item[0] for item in text_mems if float(item[1]) > 0.01] + + # Memory with edges is not modified by feedback + retrieved_mems = [] + with ContextThreadPoolExecutor(max_workers=10) as executor: + futures = {executor.submit(check_has_edges, item): item for item in text_mems} + for future in concurrent.futures.as_completed(futures): + try: + mem_item, has_no_edges = future.result() + if has_no_edges: + retrieved_mems.append(mem_item) + except Exception as e: + logger.error(f"[0107 Feedback Core: _retrieve] Error checking edges: {e}") + + if len(retrieved_mems) < len(text_mems): + logger.info( + f"[0107 Feedback Core: _retrieve] {len(text_mems) - len(retrieved_mems)} " + f"text memories are not modified by 
feedback due to edges." + ) if self.pref_feedback: pref_info = {} diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 1a312868a..8b0968ca1 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -288,6 +288,7 @@ def _build_window_from_items( # Collect all memory texts and sources memory_texts = [] all_sources = [] + seen_content = set() # Track seen source content to avoid duplicates roles = set() aggregated_file_ids: list[str] = [] @@ -301,8 +302,18 @@ def _build_window_from_items( item_sources = [item_sources] for source in item_sources: - # Add source to all_sources - all_sources.append(source) + # Get content from source for deduplication + source_content = None + if isinstance(source, dict): + source_content = source.get("content", "") + else: + source_content = getattr(source, "content", "") or "" + + # Only add if content is different (empty content is considered unique) + content_key = source_content if source_content else None + if content_key and content_key not in seen_content: + seen_content.add(content_key) + all_sources.append(source) # Extract role from source if hasattr(source, "role") and source.role: @@ -464,7 +475,10 @@ def _determine_prompt_type(self, sources: list) -> str: source_role = source.get("role") if source_role in {"user", "assistant", "system", "tool"}: prompt_type = "chat" - + if hasattr(source, "type"): + source_type = source.type + if source_type == "file": + prompt_type = "doc" return prompt_type def _get_maybe_merged_memory( @@ -641,11 +655,14 @@ def _process_string_fine( ) -> list[TextualMemoryItem]: """ Process fast mode memory items through LLM to generate fine mode memories. + Where fast_memory_items are raw chunk memory items, not the final memory items. 
""" if not fast_memory_items: return [] - def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]: + def _process_one_item( + fast_item: TextualMemoryItem, chunk_idx: int, total_chunks: int + ) -> list[TextualMemoryItem]: """Process a single fast memory item and return a list of fine items.""" fine_items: list[TextualMemoryItem] = [] @@ -749,12 +766,40 @@ def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]: except Exception as e: logger.error(f"[MultiModalFine] parse error: {e}") + # save rawfile node + if self.save_rawfile and prompt_type == "doc" and len(fine_items) > 0: + rawfile_chunk = mem_str + file_info = fine_items[0].metadata.sources[0].file_info + source = self.multi_modal_parser.file_content_parser.create_source( + message={"file": file_info}, + info=info_per_item, + chunk_index=chunk_idx, + chunk_total=total_chunks, + chunk_content="", + ) + rawfile_node = self._make_memory_item( + value=rawfile_chunk, + info=info_per_item, + memory_type="RawFileMemory", + tags=[ + "mode:fine", + "multimodal:file", + f"chunk:{chunk_idx + 1}/{total_chunks}", + ], + sources=[source], + ) + rawfile_node.metadata.summary_ids = [mem_node.id for mem_node in fine_items] + fine_items.append(rawfile_node) return fine_items fine_memory_items: list[TextualMemoryItem] = [] + total_chunks_len = len(fast_memory_items) with ContextThreadPoolExecutor(max_workers=30) as executor: - futures = [executor.submit(_process_one_item, item) for item in fast_memory_items] + futures = [ + executor.submit(_process_one_item, item, idx, total_chunks_len) + for idx, item in enumerate[TextualMemoryItem](fast_memory_items) + ] for future in concurrent.futures.as_completed(futures): try: @@ -764,6 +809,63 @@ def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]: except Exception as e: logger.error(f"[MultiModalFine] worker error: {e}") + # related preceding and following rawfilememories + fine_memory_items = self._relate_preceding_following_rawfile_memories(fine_memory_items) + return fine_memory_items + + def _relate_preceding_following_rawfile_memories( + self, fine_memory_items: list[TextualMemoryItem] + ) -> list[TextualMemoryItem]: + """ + Relate RawFileMemory items to each other by setting preceding_id and following_id. 
+ """ + # Filter RawFileMemory items and track their original positions + rawfile_items_with_pos = [] + for idx, item in enumerate[TextualMemoryItem](fine_memory_items): + if ( + hasattr(item.metadata, "memory_type") + and item.metadata.memory_type == "RawFileMemory" + ): + rawfile_items_with_pos.append((idx, item)) + + if len(rawfile_items_with_pos) <= 1: + return fine_memory_items + + def get_chunk_idx(item_with_pos) -> int: + """Extract chunk_idx from item's source metadata.""" + _, item = item_with_pos + if item.metadata.sources and len(item.metadata.sources) > 0: + source = item.metadata.sources[0] + # Handle both SourceMessage object and dict + if isinstance(source, dict): + file_info = source.get("file_info") + if file_info and isinstance(file_info, dict): + chunk_idx = file_info.get("chunk_index") + if chunk_idx is not None: + return chunk_idx + else: + # SourceMessage object + file_info = getattr(source, "file_info", None) + if file_info and isinstance(file_info, dict): + chunk_idx = file_info.get("chunk_index") + if chunk_idx is not None: + return chunk_idx + return float("inf") + + # Sort items by chunk_index + sorted_rawfile_items_with_pos = sorted(rawfile_items_with_pos, key=get_chunk_idx) + + # Relate adjacent items + for i in range(len(sorted_rawfile_items_with_pos) - 1): + _, current_item = sorted_rawfile_items_with_pos[i] + _, next_item = sorted_rawfile_items_with_pos[i + 1] + current_item.metadata.following_id = next_item.id + next_item.metadata.preceding_id = current_item.id + + # Replace sorted items back to original positions in fine_memory_items + for orig_idx, item in sorted_rawfile_items_with_pos: + fine_memory_items[orig_idx] = item + return fine_memory_items def _get_llm_tool_trajectory_response(self, mem_str: str) -> dict: diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 0f3f3ef01..2b49d63ba 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -416,6 +416,7 @@ def parse_fast( # Extract file parameters (all are optional) file_data = file_info.get("file_data", "") file_id = file_info.get("file_id", "") + filename = file_info.get("filename", "") file_url_flag = False # Build content string based on available information content_parts = [] @@ -436,12 +437,25 @@ def parse_fast( # Check if it looks like a URL elif file_data.startswith(("http://", "https://", "file://")): file_url_flag = True + content_parts.append(f"[File URL: {file_data}]") else: # TODO: split into multiple memory items content_parts.append(file_data) else: content_parts.append(f"[File Data: {type(file_data).__name__}]") + # Priority 2: If file_id is provided, reference it + if file_id: + content_parts.append(f"[File ID: {file_id}]") + + # Priority 3: If filename is provided, include it + if filename: + content_parts.append(f"[Filename: {filename}]") + + # If no content can be extracted, create a placeholder + if not content_parts: + content_parts.append("[File: unknown]") + # Combine content parts content = " ".join(content_parts) @@ -793,10 +807,36 @@ def _process_chunk(chunk_idx: int, chunk_text: str) -> TextualMemoryItem: logger.warning(f"[FileContentParser] Fallback to raw for chunk {chunk_idx}") return _make_fallback(chunk_idx, chunk_text) + def _relate_chunks(items: list[TextualMemoryItem]) -> None: + """ + Relate chunks to each other. 
+ """ + if len(items) <= 1: + return [] + + def get_chunk_idx(item: TextualMemoryItem) -> int: + """Extract chunk_idx from item's source metadata.""" + if item.metadata.sources and len(item.metadata.sources) > 0: + source = item.metadata.sources[0] + if source.file_info and isinstance(source.file_info, dict): + chunk_idx = source.file_info.get("chunk_index") + if chunk_idx is not None: + return chunk_idx + return float("inf") + + sorted_items = sorted(items, key=get_chunk_idx) + + # Relate adjacent items + for i in range(len(sorted_items) - 1): + sorted_items[i].metadata.following_id = sorted_items[i + 1].id + sorted_items[i + 1].metadata.preceding_id = sorted_items[i].id + return sorted_items + # Process chunks concurrently with progress bar memory_items = [] chunk_map = dict(valid_chunks) total_chunks = len(valid_chunks) + fallback_count = 0 logger.info(f"[FileContentParser] Processing {total_chunks} chunks with LLM...") @@ -814,20 +854,53 @@ def _process_chunk(chunk_idx: int, chunk_text: str) -> TextualMemoryItem: chunk_idx = futures[future] try: node = future.result() - if node: - memory_items.append(node) + memory_items.append(node) + + # Check if this node is a fallback by checking tags + is_fallback = any(tag.startswith("fallback:") for tag in node.metadata.tags) + if is_fallback: + fallback_count += 1 + + # save raw file + node_id = node.id + if node.memory != node.metadata.sources[0].content: + chunk_node = _make_memory_item( + value=node.metadata.sources[0].content, + mem_type="RawFileMemory", + tags=[ + "mode:fine", + "multimodal:file", + f"chunk:{chunk_idx + 1}/{total_chunks}", + ], + chunk_idx=chunk_idx, + chunk_content="", + ) + chunk_node.metadata.summary_ids = [node_id] + memory_items.append(chunk_node) + except Exception as e: tqdm.write(f"[ERROR] Chunk {chunk_idx} failed: {e}") logger.error(f"[FileContentParser] Future failed for chunk {chunk_idx}: {e}") # Create fallback for failed future if chunk_idx in chunk_map: + fallback_count += 1 memory_items.append( _make_fallback(chunk_idx, chunk_map[chunk_idx], "error") ) + fallback_percentage = (fallback_count / total_chunks * 100) if total_chunks > 0 else 0.0 logger.info( - f"[FileContentParser] Completed processing {len(memory_items)}/{total_chunks} chunks" + f"[FileContentParser] Completed processing {len(memory_items)}/{total_chunks} chunks, " + f"fallback count: {fallback_count}/{total_chunks} ({fallback_percentage:.1f}%)" ) + rawfile_items = [ + memory for memory in memory_items if memory.metadata.memory_type == "RawFileMemory" + ] + mem_items = [ + memory for memory in memory_items if memory.metadata.memory_type != "RawFileMemory" + ] + related_rawfile_items = _relate_chunks(rawfile_items) + memory_items = mem_items + related_rawfile_items return memory_items or [ _make_memory_item( diff --git a/src/memos/mem_reader/read_multi_modal/user_parser.py b/src/memos/mem_reader/read_multi_modal/user_parser.py index abfebc5db..2e5ea6eae 100644 --- a/src/memos/mem_reader/read_multi_modal/user_parser.py +++ b/src/memos/mem_reader/read_multi_modal/user_parser.py @@ -68,6 +68,10 @@ def create_source( part_type = part.get("type", "") if part_type == "text": text_contents.append(part.get("text", "")) + if part_type == "file": + file_info = part.get("file", {}) + file_data = file_info.get("file_data", "") + text_contents.append(file_data) # Detect overall language from all text content overall_lang = "en" diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 2c4fee853..ceaf28bfa 100644 --- 
a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -176,6 +176,7 @@ def __init__(self, config: SimpleStructMemReaderConfig): self.llm = LLMFactory.from_config(config.llm) self.embedder = EmbedderFactory.from_config(config.embedder) self.chunker = ChunkerFactory.from_config(config.chunker) + self.save_rawfile = self.chunker.config.save_rawfile self.memory_max_length = 8000 # Use token-based windowing; default to ~5000 tokens if not configured self.chat_window_max_tokens = getattr(self.config, "chat_window_max_tokens", 1024) diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py index 6bbbd4335..82857c459 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py @@ -183,15 +183,41 @@ def _process_memories_with_reader( logger.info("mem_reader processed %s enhanced memories", len(flattened_memories)) if flattened_memories: - enhanced_mem_ids = text_mem.add(flattened_memories, user_name=user_name) + mem_group = [ + memory + for memory in flattened_memories + if memory.metadata.memory_type != "RawFileMemory" + ] + enhanced_mem_ids = text_mem.add(mem_group, user_name=user_name) logger.info( "Added %s enhanced memories: %s", len(enhanced_mem_ids), enhanced_mem_ids, ) + # add raw file nodes and edges + if mem_reader.save_rawfile: + raw_file_mem_group = [ + memory + for memory in flattened_memories + if memory.metadata.memory_type == "RawFileMemory" + ] + text_mem.add_rawfile_nodes_n_edges( + raw_file_mem_group, + enhanced_mem_ids, + user_id=user_id, + user_name=user_name, + ) + logger.info("Added %s Rawfile memories.", len(raw_file_mem_group)) + + # Mark merged_from memories as archived when provided in memory metadata + summary_memories = [ + memory + for memory in flattened_memories + if memory.metadata.memory_type != "RawFileMemory" + ] if mem_reader.graph_db: - for memory in flattened_memories: + for memory in summary_memories: merged_from = (memory.metadata.info or {}).get("merged_from") if merged_from: old_ids = ( @@ -216,7 +242,7 @@ def _process_memories_with_reader( ) else: has_merged_from = any( - (m.metadata.info or {}).get("merged_from") for m in flattened_memories + (m.metadata.info or {}).get("merged_from") for m in summary_memories ) if has_merged_from: logger.warning( diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 63476c7cc..7e40f1d50 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -169,6 +169,7 @@ class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata): "OuterMemory", "ToolSchemaMemory", "ToolTrajectoryMemory", + "RawFileMemory", "SkillMemory", ] = Field(default="WorkingMemory", description="Memory lifecycle type.") sources: list[SourceMessage] | None = Field( diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 90326a044..5faf8aa09 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -1,7 +1,9 @@ +import concurrent.futures import json import os import shutil import tempfile +import time from datetime import datetime from pathlib import Path @@ -9,6 +11,7 @@ from memos.configs.memory import TreeTextMemoryConfig from memos.configs.reranker import RerankerConfigFactory +from memos.context.context import ContextThreadPoolExecutor from memos.dependency import 
require_python_package from memos.embedders.factory import EmbedderFactory, OllamaEmbedder from memos.graph_dbs.factory import GraphStoreFactory, Neo4jGraphDB @@ -503,3 +506,100 @@ def _cleanup_old_backups(root_dir: Path, keep_last_n: int) -> None: logger.info(f"Deleted old backup directory: {old_dir}") except Exception as e: logger.warning(f"Failed to delete backup {old_dir}: {e}") + + def add_rawfile_nodes_n_edges( + self, + raw_file_mem_group: list[TextualMemoryItem], + mem_ids: list[str], + user_id: str | None = None, + user_name: str | None = None, + ) -> None: + """ + Add raw file nodes and edges to the graph. Edges are between raw file ids and mem_ids. + Args: + raw_file_mem_group: List of raw file memory items. + mem_ids: List of memory IDs. + user_name: cube id. + """ + rawfile_ids_local: list[str] = self.add( + raw_file_mem_group, + user_name=user_name, + ) + + from_ids = [] + to_ids = [] + types = [] + + for raw_file_mem in raw_file_mem_group: + # Add SUMMARY edge: memory -> raw file; raw file -> memory + if hasattr(raw_file_mem.metadata, "summary_ids") and raw_file_mem.metadata.summary_ids: + summary_ids = raw_file_mem.metadata.summary_ids + for summary_id in summary_ids: + if summary_id in mem_ids: + from_ids.append(summary_id) + to_ids.append(raw_file_mem.id) + types.append("MATERIAL") + + from_ids.append(raw_file_mem.id) + to_ids.append(summary_id) + types.append("SUMMARY") + + # Add FOLLOWING edge: current chunk -> next chunk + if ( + hasattr(raw_file_mem.metadata, "following_id") + and raw_file_mem.metadata.following_id + ): + following_id = raw_file_mem.metadata.following_id + if following_id in rawfile_ids_local: + from_ids.append(raw_file_mem.id) + to_ids.append(following_id) + types.append("FOLLOWING") + + # Add PRECEDING edge: previous chunk -> current chunk + if ( + hasattr(raw_file_mem.metadata, "preceding_id") + and raw_file_mem.metadata.preceding_id + ): + preceding_id = raw_file_mem.metadata.preceding_id + if preceding_id in rawfile_ids_local: + from_ids.append(raw_file_mem.id) + to_ids.append(preceding_id) + types.append("PRECEDING") + + start_time = time.time() + self.add_graph_edges( + from_ids, + to_ids, + types, + user_name=user_name, + ) + end_time = time.time() + logger.info(f"[RawFile] Added {len(rawfile_ids_local)} chunks for user {user_id}") + logger.info( + f"[RawFile] Time taken to add edges: {end_time - start_time} seconds for {len(from_ids)} edges" + ) + + def add_graph_edges( + self, from_ids: list[str], to_ids: list[str], types: list[str], user_name: str | None = None + ) -> None: + """ + Add edges to the graph. + Args: + from_ids: List of source node IDs. + to_ids: List of target node IDs. + types: List of edge types. + user_name: Optional user name. 
+ """ + with ContextThreadPoolExecutor(max_workers=20) as executor: + futures = { + executor.submit( + self.graph_store.add_edge, from_id, to_id, edge_type, user_name=user_name + ) + for from_id, to_id, edge_type in zip(from_ids, to_ids, types, strict=False) + } + + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + logger.exception("Add edge error: ", exc_info=e) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 5e9c74f61..cbc349d67 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -68,12 +68,14 @@ def __init__( self.current_memory_size = { "WorkingMemory": 0, "LongTermMemory": 0, + "RawFileMemory": 0, "UserMemory": 0, } if not memory_size: self.memory_size = { "WorkingMemory": 20, "LongTermMemory": 1500, + "RawFileMemory": 1500, "UserMemory": 480, } logger.info(f"MemorySize is {self.memory_size}") @@ -157,7 +159,7 @@ def _add_memories_batch( graph_node_ids: list[str] = [] for memory in memories: - working_id = str(uuid.uuid4()) + working_id = memory.id if hasattr(memory, "id") else memory.id or str(uuid.uuid4()) if memory.metadata.memory_type in ( "WorkingMemory", @@ -181,11 +183,12 @@ def _add_memories_batch( "UserMemory", "ToolSchemaMemory", "ToolTrajectoryMemory", + "RawFileMemory", "SkillMemory", ): - if not memory.id: - logger.error("Memory ID is not set, generating a new one") - graph_node_id = memory.id or str(uuid.uuid4()) + graph_node_id = ( + memory.id if hasattr(memory, "id") else memory.id or str(uuid.uuid4()) + ) metadata_dict = memory.metadata.model_dump(exclude_none=True) metadata_dict["updated_at"] = datetime.now().isoformat() @@ -315,7 +318,7 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non ids: list[str] = [] futures = [] - working_id = str(uuid.uuid4()) + working_id = memory.id if hasattr(memory, "id") else memory.id or str(uuid.uuid4()) with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex: if memory.metadata.memory_type in ( @@ -334,6 +337,7 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non "UserMemory", "ToolSchemaMemory", "ToolTrajectoryMemory", + "RawFileMemory", "SkillMemory", ): f_graph = ex.submit( @@ -386,9 +390,7 @@ def _add_to_graph_memory( """ Generalized method to add memory to a graph-based memory type (e.g., LongTermMemory, UserMemory). 
""" - if not memory.id: - logger.error("Memory ID is not set, generating a new one") - node_id = memory.id or str(uuid.uuid4()) + node_id = memory.id if hasattr(memory, "id") else str(uuid.uuid4()) # Step 2: Add new node to graph metadata_dict = memory.metadata.model_dump(exclude_none=True) tags = metadata_dict.get("tags") or [] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 255394317..e5e96dd58 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -67,6 +67,7 @@ def retrieve( "UserMemory", "ToolSchemaMemory", "ToolTrajectoryMemory", + "RawFileMemory", "SkillMemory", ]: raise ValueError(f"Unsupported memory scope: {memory_scope}") @@ -396,6 +397,7 @@ def search_path_b(): for r in all_hits: rid = r.get("id") if rid: + rid = str(rid).strip("\"'") score = r.get("score", 0.0) if rid not in id_to_score or score > id_to_score[rid]: id_to_score[rid] = score @@ -414,11 +416,20 @@ def search_path_b(): ) # Restore score-based order and inject scores into metadata - id_to_node = {n.get("id"): n for n in node_dicts} + id_to_node = {} + for n in node_dicts: + node_id = n.get("id") + if node_id: + # Ensure ID is a string and strip any surrounding quotes + node_id = str(node_id).strip("\"'") + id_to_node[node_id] = n + ordered_nodes = [] for rid in sorted_ids: - if rid in id_to_node: - node = id_to_node[rid] + # Ensure rid is normalized for matching + rid_normalized = str(rid).strip("\"'") + if rid_normalized in id_to_node: + node = id_to_node[rid_normalized] # Inject similarity score as relativity if "metadata" not in node: node["metadata"] = {} @@ -512,6 +523,8 @@ def _fulltext_recall( for r in all_hits: rid = r.get("id") if rid: + # Ensure ID is a string and strip any surrounding quotes + rid = str(rid).strip("\"'") score = r.get("score", 0.0) if rid not in id_to_score or score > id_to_score[rid]: id_to_score[rid] = score @@ -530,11 +543,20 @@ def _fulltext_recall( ) # Restore score-based order and inject scores into metadata - id_to_node = {n.get("id"): n for n in node_dicts} + id_to_node = {} + for n in node_dicts: + node_id = n.get("id") + if node_id: + # Ensure ID is a string and strip any surrounding quotes + node_id = str(node_id).strip("\"'") + id_to_node[node_id] = n + ordered_nodes = [] for rid in sorted_ids: - if rid in id_to_node: - node = id_to_node[rid] + # Ensure rid is normalized for matching + rid_normalized = str(rid).strip("\"'") + if rid_normalized in id_to_node: + node = id_to_node[rid_normalized] # Inject similarity score as relativity if "metadata" not in node: node["metadata"] = {} diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 39aa4e9ac..356402c90 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -1,5 +1,7 @@ import traceback +from concurrent.futures import as_completed + from memos.context.context import ContextThreadPoolExecutor from memos.embedders.factory import OllamaEmbedder from memos.graph_dbs.factory import Neo4jGraphDB @@ -483,8 +485,8 @@ def _retrieve_from_long_term_and_user( else: cot_embeddings = query_embedding - with ContextThreadPoolExecutor(max_workers=2) as executor: - if memory_type in ["All", "LongTermMemory"]: + with ContextThreadPoolExecutor(max_workers=3) as 
executor: + if memory_type in ["All", "AllSummaryMemory", "LongTermMemory"]: tasks.append( executor.submit( self.graph_retriever.retrieve, @@ -500,7 +502,7 @@ def _retrieve_from_long_term_and_user( use_fast_graph=self.use_fast_graph, ) ) - if memory_type in ["All", "UserMemory"]: + if memory_type in ["All", "AllSummaryMemory", "UserMemory"]: tasks.append( executor.submit( self.graph_retriever.retrieve, @@ -516,10 +518,28 @@ def _retrieve_from_long_term_and_user( use_fast_graph=self.use_fast_graph, ) ) + if memory_type in ["All", "RawFileMemory"]: + tasks.append( + executor.submit( + self.graph_retriever.retrieve, + query=query, + parsed_goal=parsed_goal, + query_embedding=cot_embeddings, + top_k=top_k * 2, + memory_scope="RawFileMemory", + search_filter=search_filter, + search_priority=search_priority, + user_name=user_name, + id_filter=id_filter, + use_fast_graph=self.use_fast_graph, + ) + ) # Collect results from all tasks for task in tasks: results.extend(task.result()) + results = self._deduplicate_rawfile_results(results, user_name=user_name) + results = self._filter_intermediate_content(results) return self.reranker.rerank( query=query, @@ -872,7 +892,7 @@ def _sort_and_trim( (item, score) for item, score in results if item.metadata.memory_type - in ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"] + in ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory", "RawFileMemory"] ] sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)[:top_k] @@ -891,6 +911,66 @@ def _sort_and_trim( ) return final_items + @timed + def _deduplicate_rawfile_results(self, results, user_name: str | None = None): + """ + Deduplicate rawfile related memories by edge + """ + if not results: + return results + + summary_ids_to_remove = set() + rawfile_items = [item for item in results if item.metadata.memory_type == "RawFileMemory"] + if not rawfile_items: + return results + + with ContextThreadPoolExecutor(max_workers=min(len(rawfile_items), 10)) as executor: + futures = [ + executor.submit( + self.graph_store.get_edges, + rawfile_item.id, + type="SUMMARY", + direction="OUTGOING", + user_name=user_name, + ) + for rawfile_item in rawfile_items + ] + for future in as_completed(futures): + try: + edges = future.result() + for edge in edges: + summary_target_id = edge.get("to") + if summary_target_id: + summary_ids_to_remove.add(summary_target_id) + logger.debug( + f"[DEDUP] Marking summary node {summary_target_id} for removal (pointed by RawFileMemory)" + ) + except Exception as e: + logger.warning(f"[DEDUP] Failed to get summary target ids: {e}") + + filtered_results = [] + for item in results: + if item.id in summary_ids_to_remove: + logger.debug( + f"[DEDUP] Removing summary node {item.id} because it is pointed by RawFileMemory" + ) + continue + filtered_results.append(item) + + return filtered_results + + def _filter_intermediate_content(self, results): + """Filter intermediate content""" + filtered_results = [] + for item in results: + if ( + "File URL:" not in item.memory + and "File ID:" not in item.memory + and "Filename:" not in item.memory + ): + filtered_results.append(item) + return filtered_results + @timed def _update_usage_history(self, items, info, user_name: str | None = None): """Update usage history in graph DB diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index a547fd296..6da55ce02 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -24,6 +24,7 @@ 
MEM_READ_TASK_LABEL, PREF_ADD_TASK_LABEL, ) +from memos.memories.textual.item import TextualMemoryItem from memos.multi_mem_cube.views import MemCubeView from memos.search import search_text_memories from memos.templates.mem_reader_prompts import PROMPT_MAPPING @@ -45,7 +46,6 @@ from memos.mem_cube.navie import NaiveMemCube from memos.mem_reader.simple_struct import SimpleStructMemReader from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler - from memos.memories.textual.item import TextualMemoryItem @dataclass @@ -269,11 +269,12 @@ def _deep_search( search_filter=search_filter, info=info, ) - formatted_memories = [ - format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr")) - for data in enhanced_memories - ] - return formatted_memories + return self._postformat_memories( + enhanced_memories, + user_context.mem_cube_id, + include_embedding=search_req.dedup == "sim", + neighbor_discovery=search_req.neighbor_discovery, + ) def _agentic_search( self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int @@ -281,11 +282,12 @@ def _agentic_search( deepsearch_results = self.deepsearch_agent.run( search_req.query, user_id=user_context.mem_cube_id ) - formatted_memories = [ - format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr")) - for data in deepsearch_results - ] - return formatted_memories + return self._postformat_memories( + deepsearch_results, + user_context.mem_cube_id, + include_embedding=search_req.dedup == "sim", + neighbor_discovery=search_req.neighbor_discovery, + ) def _fine_search( self, @@ -326,6 +328,7 @@ def _fine_search( user_name=user_context.mem_cube_id, top_k=search_req.top_k, mode=SearchMode.FINE, + memory_type=search_req.search_memory_type, manual_close_internet=not search_req.internet_search, moscube=search_req.moscube, search_filter=search_filter, @@ -365,7 +368,7 @@ def _fine_search( user_name=user_context.mem_cube_id, top_k=retrieval_size, mode=SearchMode.FAST, - memory_type="All", + memory_type=search_req.search_memory_type, search_priority=search_priority, search_filter=search_filter, info=info, @@ -393,10 +396,12 @@ def _dedup_by_content(memories: list) -> list: deduped_memories = ( enhanced_memories if search_req.dedup == "no" else _dedup_by_content(enhanced_memories) ) - formatted_memories = [ - format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr")) - for data in deduped_memories - ] + formatted_memories = self._postformat_memories( + deduped_memories, + user_context.mem_cube_id, + include_embedding=search_req.dedup == "sim", + neighbor_discovery=search_req.neighbor_discovery, + ) logger.info(f"Found {len(formatted_memories)} memories for user {search_req.user_id}") @@ -438,7 +443,7 @@ def _search_pref( }, search_filter=search_req.filter, ) - return [format_memory_item(data) for data in results] + return self._postformat_memories(results, user_context.mem_cube_id) except Exception as e: self.logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) return [] @@ -466,12 +471,65 @@ def _fast_search( include_embedding=(search_req.dedup == "mmr"), ) - formatted_memories = [ - format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr")) - for data in search_results - ] + return self._postformat_memories( + search_results, + user_context.mem_cube_id, + include_embedding=search_req.dedup == "sim", + neighbor_discovery=search_req.neighbor_discovery, + ) - return formatted_memories + def _postformat_memories( + self, + 
search_results: list, + user_name: str, + include_embedding: bool = False, + neighbor_discovery: bool = False, + ) -> list: + """ + Postprocess search results. + """ + + def extract_edge_info(edges_info: list[dict], neighbor_relativity: float): + edge_mems = [] + for edge in edges_info: + chunk_target_id = edge.get("to") + edge_type = edge.get("type") + item_neighbor = self.searcher.graph_store.get_node(chunk_target_id) + if item_neighbor: + item_neighbor_mem = TextualMemoryItem(**item_neighbor) + item_neighbor_mem.metadata.relativity = neighbor_relativity + edge_mems.append(item_neighbor_mem) + item_neighbor_id = item_neighbor.get("id", "None") + self.logger.info( + f"Add neighbor chunk: {item_neighbor_id}, edge_type: {edge_type} for {item.id}" + ) + return edge_mems + + final_items = [] + if neighbor_discovery: + for item in search_results: + if item.metadata.memory_type == "RawFileMemory": + neighbor_relativity = item.metadata.relativity * 0.8 + preceding_info = self.searcher.graph_store.get_edges( + item.id, type="PRECEDING", direction="OUTGOING", user_name=user_name + ) + final_items.extend(extract_edge_info(preceding_info, neighbor_relativity)) + + final_items.append(item) + + following_info = self.searcher.graph_store.get_edges( + item.id, type="FOLLOWING", direction="OUTGOING", user_name=user_name + ) + final_items.extend(extract_edge_info(following_info, neighbor_relativity)) + + else: + final_items.append(item) + else: + final_items = search_results + + return [ + format_memory_item(data, include_embedding=include_embedding) for data in final_items + ] def _mix_search( self, @@ -812,16 +870,34 @@ def _process_text_mem( self.logger.info(f"Memory extraction completed for user {add_req.user_id}") # Add memories to text_mem + mem_group = [ + memory for memory in flattened_local if memory.metadata.memory_type != "RawFileMemory" + ] mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add( - flattened_local, + mem_group, user_name=user_context.mem_cube_id, ) + self.logger.info( f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " f"in session {add_req.session_id}: {mem_ids_local}" ) - # Schedule async/sync tasks + # Add raw file nodes and edges + if self.mem_reader.save_rawfile and extract_mode == "fine": + raw_file_mem_group = [ + memory + for memory in flattened_local + if memory.metadata.memory_type == "RawFileMemory" + ] + self.naive_mem_cube.text_mem.add_rawfile_nodes_n_edges( + raw_file_mem_group, + mem_ids_local, + user_id=add_req.user_id, + user_name=user_context.mem_cube_id, + ) + + # Schedule async/sync tasks: async process raw chunk memory | sync only send messages self._schedule_memory_tasks( add_req=add_req, user_context=user_context, diff --git a/tests/memories/textual/test_tree_searcher.py b/tests/memories/textual/test_tree_searcher.py index 2a5536cf8..3d1469d00 100644 --- a/tests/memories/textual/test_tree_searcher.py +++ b/tests/memories/textual/test_tree_searcher.py @@ -48,11 +48,23 @@ def test_searcher_fast_path(mock_searcher): mock_searcher.embedder.embed.return_value = [[0.1] * 5, [0.2] * 5] # working path mock - mock_searcher.graph_retriever.retrieve.side_effect = [ - [make_item("wm1", 0.9)[0]], # working memory - [make_item("lt1", 0.8)[0]], # long-term - [make_item("um1", 0.7)[0]], # user - ] + # For "All", _retrieve_from_working_memory calls once (WorkingMemory), + # and _retrieve_from_long_term_and_user calls 3 times (LongTermMemory, UserMemory, RawFileMemory) + # Use a function to handle concurrent calls with different memory_scope + def 
retrieve_side_effect(*args, **kwargs): + memory_scope = kwargs.get("memory_scope", "") + if memory_scope == "WorkingMemory": + return [make_item("wm1", 0.9)[0]] + elif memory_scope == "LongTermMemory": + return [make_item("lt1", 0.8)[0]] + elif memory_scope == "UserMemory": + return [make_item("um1", 0.7)[0]] + elif memory_scope == "RawFileMemory": + return [make_item("rm1", 0.6)[0]] + else: + return [] + + mock_searcher.graph_retriever.retrieve.side_effect = retrieve_side_effect mock_searcher.reranker.rerank.return_value = [ make_item("wm1", 0.9), make_item("lt1", 0.8),
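Note on the chunk-linking step in this patch: _relate_chunks in file_content_parser.py and _relate_preceding_following_rawfile_memories in multi_modal_struct.py both sort RawFileMemory items by the chunk_index stored in their source file_info and wire adjacent items together through preceding_id / following_id, which add_rawfile_nodes_n_edges in tree.py later materializes as PRECEDING / FOLLOWING graph edges. Below is a minimal, self-contained sketch of that linking idea, assuming simplified stand-in types; RawChunk and its fields are illustrative only, not the repo's actual TextualMemoryItem.

# Illustrative sketch only: simplified stand-in type, not the repo's TextualMemoryItem.
import uuid
from dataclasses import dataclass, field


@dataclass
class RawChunk:
    chunk_index: int
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    preceding_id: str | None = None
    following_id: str | None = None


def link_adjacent_chunks(chunks: list[RawChunk]) -> list[RawChunk]:
    """Sort chunks by chunk_index and point each chunk at its neighbors."""
    ordered = sorted(chunks, key=lambda c: c.chunk_index)
    for prev, nxt in zip(ordered, ordered[1:]):
        prev.following_id = nxt.id  # current chunk -> next chunk (FOLLOWING edge)
        nxt.preceding_id = prev.id  # next chunk -> current chunk (PRECEDING edge)
    return ordered


if __name__ == "__main__":
    chunks = [RawChunk(chunk_index=i) for i in (2, 0, 1)]
    for c in link_adjacent_chunks(chunks):
        print(c.chunk_index, bool(c.preceding_id), bool(c.following_id))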
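A companion sketch for the read path: with neighbor_discovery=True, _postformat_memories in single_cube.py walks the PRECEDING and FOLLOWING edges of each RawFileMemory hit (graph_store.get_edges / get_node) and splices the neighbor chunks around it, scoring them at 0.8 * the hit's relativity. The sketch below mimics that expansion against an in-memory adjacency map; neighbors_of-style lookups and the dict-based items are illustrative stand-ins for the graph-store calls used in the PR.

# Illustrative sketch only: dict-based items and an in-memory adjacency map stand in
# for the graph-store lookups (get_edges / get_node) used by _postformat_memories.
from typing import Any

# adjacency[item_id] = {"PRECEDING": neighbor_item | None, "FOLLOWING": neighbor_item | None}
Adjacency = dict[str, dict[str, dict[str, Any] | None]]


def expand_rawfile_neighbors(
    results: list[dict[str, Any]], adjacency: Adjacency
) -> list[dict[str, Any]]:
    """Insert preceding/following chunk neighbors around each RawFileMemory hit."""
    expanded: list[dict[str, Any]] = []
    for item in results:
        if item.get("memory_type") != "RawFileMemory":
            expanded.append(item)
            continue
        neighbor_score = item.get("relativity", 0.0) * 0.8  # neighbors rank below the hit itself
        links = adjacency.get(item["id"], {})
        if links.get("PRECEDING"):
            expanded.append({**links["PRECEDING"], "relativity": neighbor_score})
        expanded.append(item)
        if links.get("FOLLOWING"):
            expanded.append({**links["FOLLOWING"], "relativity": neighbor_score})
    return expanded


if __name__ == "__main__":
    hit = {"id": "c1", "memory_type": "RawFileMemory", "relativity": 0.9}
    prev_chunk = {"id": "c0", "memory_type": "RawFileMemory"}
    next_chunk = {"id": "c2", "memory_type": "RawFileMemory"}
    adj = {"c1": {"PRECEDING": prev_chunk, "FOLLOWING": next_chunk}}
    print([m["id"] for m in expand_rawfile_neighbors([hit], adj)])  # ['c0', 'c1', 'c2']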