diff --git a/.gitignore b/.gitignore
index ac31eb41a..97af509ea 100644
--- a/.gitignore
+++ b/.gitignore
@@ -216,3 +216,5 @@ cython_debug/
 outputs
 evaluation/data/temporal_locomo
+test_add_pipeline.py
+test_file_pipeline.py
diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py
index 13dd92189..ba527d602 100644
--- a/src/memos/api/handlers/component_init.py
+++ b/src/memos/api/handlers/component_init.py
@@ -43,6 +43,7 @@
 )
 from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
 from memos.memories.textual.simple_tree import SimpleTreeTextMemory
+from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager
 from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
 from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
 
@@ -190,6 +191,7 @@ def init_server() -> dict[str, Any]:
     )
     embedder = EmbedderFactory.from_config(embedder_config)
     nli_client = NLIClient(base_url=nli_client_config["base_url"])
+    memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db)
     # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection)
     mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db)
     reranker = RerankerFactory.from_config(reranker_config)
@@ -393,4 +395,5 @@ def init_server() -> dict[str, Any]:
         "redis_client": redis_client,
         "deepsearch_agent": deepsearch_agent,
         "nli_client": nli_client,
+        "memory_history_manager": memory_history_manager,
     }
diff --git a/src/memos/chunkers/base.py b/src/memos/chunkers/base.py
index c2a783baa..0c781faf9 100644
--- a/src/memos/chunkers/base.py
+++ b/src/memos/chunkers/base.py
@@ -1,6 +1,8 @@
+import re
+
 from abc import ABC, abstractmethod
 
 from memos.configs.chunker import BaseChunkerConfig
 
 
 class Chunk:
@@ -22,3 +24,46 @@ def __init__(self, config: BaseChunkerConfig):
     @abstractmethod
     def chunk(self, text: str) -> list[Chunk]:
         """Chunk the given text into smaller chunks."""
+
+    def protect_urls(self, text: str) -> tuple[str, dict[str, str]]:
+        """
+        Protect URLs in text from being split during chunking.
+
+        Args:
+            text: Text to process
+
+        Returns:
+            tuple: (Text with URLs replaced by placeholders, URL mapping dictionary)
+        """
+        url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+'
+        url_map = {}
+
+        def replace_url(match):
+            url = match.group(0)
+            placeholder = f"__URL_{len(url_map)}__"
+            url_map[placeholder] = url
+            return placeholder
+
+        protected_text = re.sub(url_pattern, replace_url, text)
+        return protected_text, url_map
+
+    def restore_urls(self, text: str, url_map: dict[str, str]) -> str:
+        """
+        Restore protected URLs in text back to their original form.
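+
+        Example (illustrative):
+            restore_urls("see __URL_0__ now", {"__URL_0__": "https://example.com"})
+            -> "see https://example.com now"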
+
+        Args:
+            text: Text with URL placeholders
+            url_map: URL mapping dictionary from protect_urls
+
+        Returns:
+            str: Text with URLs restored
+        """
+        restored_text = text
+        for placeholder, url in url_map.items():
+            restored_text = restored_text.replace(placeholder, url)
+
+        return restored_text
diff --git a/src/memos/chunkers/charactertext_chunker.py b/src/memos/chunkers/charactertext_chunker.py
index 15c0958ba..25739d96f 100644
--- a/src/memos/chunkers/charactertext_chunker.py
+++ b/src/memos/chunkers/charactertext_chunker.py
@@ -36,6 +36,8 @@ def __init__(
 
     def chunk(self, text: str, **kwargs) -> list[str] | list[Chunk]:
         """Chunk the given text into smaller chunks based on sentences."""
-        chunks = self.chunker.split_text(text)
+        protected_text, url_map = self.protect_urls(text)
+        chunks = self.chunker.split_text(protected_text)
+        chunks = [self.restore_urls(chunk, url_map) for chunk in chunks]
         logger.debug(f"Generated {len(chunks)} chunks from input text")
         return chunks
diff --git a/src/memos/chunkers/markdown_chunker.py b/src/memos/chunkers/markdown_chunker.py
index b7771ac35..8474c4328 100644
--- a/src/memos/chunkers/markdown_chunker.py
+++ b/src/memos/chunkers/markdown_chunker.py
@@ -2,6 +2,8 @@
+import re
+
 from memos.dependency import require_python_package
 from memos.log import get_logger
 
 from .base import BaseChunker, Chunk
 
 
@@ -22,6 +24,7 @@ def __init__(
         chunk_size: int = 1000,
         chunk_overlap: int = 200,
         recursive: bool = False,
+        auto_fix_headers: bool = True,
     ):
         from langchain_text_splitters import (
             MarkdownHeaderTextSplitter,
@@ -29,6 +32,7 @@ def __init__(
         )
 
         self.config = config
+        self.auto_fix_headers = auto_fix_headers
         self.chunker = MarkdownHeaderTextSplitter(
             headers_to_split_on=config.headers_to_split_on
             if config
@@ -46,17 +50,107 @@ def __init__(
 
     def chunk(self, text: str, **kwargs) -> list[str] | list[Chunk]:
         """Chunk the given text into smaller chunks based on sentences."""
-        md_header_splits = self.chunker.split_text(text)
+        # Protect URLs first
+        protected_text, url_map = self.protect_urls(text)
 
+        # Auto-detect and fix malformed header hierarchy if enabled
+        if self.auto_fix_headers and self._detect_malformed_headers(protected_text):
+            logger.info("Detected malformed header hierarchy, attempting to fix...")
+            protected_text = self._fix_header_hierarchy(protected_text)
+            logger.info("Header hierarchy fix completed")
+
+        md_header_splits = self.chunker.split_text(protected_text)
         chunks = []
         if self.chunker_recursive:
             md_header_splits = self.chunker_recursive.split_documents(md_header_splits)
         for doc in md_header_splits:
             try:
                 chunk = " ".join(list(doc.metadata.values())) + "\n" + doc.page_content
+                chunk = self.restore_urls(chunk, url_map)
                 chunks.append(chunk)
             except Exception as e:
                 logger.warning(f"warning chunking document: {e}")
-                chunks.append(doc.page_content)
+                restored_chunk = self.restore_urls(doc.page_content, url_map)
+                chunks.append(restored_chunk)
         logger.info(f"Generated chunks: {chunks[:5]}")
         logger.debug(f"Generated {len(chunks)} chunks from input text")
         return chunks
+
+    def _detect_malformed_headers(self, text: str) -> bool:
+        """Detect if markdown has improper header hierarchy usage."""
+        # Extract all valid markdown header lines
+        header_levels = []
+        pattern = re.compile(r"^#{1,6}\s+.+")
+        for line in text.split("\n"):
+            stripped_line = line.strip()
+            if pattern.match(stripped_line):
+                hash_match = re.match(r"^(#+)", stripped_line)
+                if hash_match:
+                    level = len(hash_match.group(1))
+                    header_levels.append(level)
+
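+        # e.g. "# A\n## B\n# C" yields header_levels == [1, 2, 1];
+        # an all-level-1 doc like "# A\n# B\n# C" yields [1, 1, 1].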
+        total_headers = len(header_levels)
+        if total_headers == 0:
+            logger.debug("No valid headers detected, skipping check")
+            return False
+
+        # Calculate level-1 header ratio
+        level1_count = sum(1 for level in header_levels if level == 1)
+
+        # Determine if malformed: >90% are level-1 when total > 5,
+        # OR all headers are level-1 when total <= 5
+        if total_headers > 5:
+            level1_ratio = level1_count / total_headers
+            if level1_ratio > 0.9:
+                logger.warning(
+                    f"Detected header hierarchy issue: {level1_count}/{total_headers} "
+                    f"({level1_ratio:.1%}) of headers are level 1"
+                )
+                return True
+        elif total_headers <= 5 and level1_count == total_headers:
+            logger.warning(
+                f"Detected header hierarchy issue: all {total_headers} headers are level 1"
+            )
+            return True
+        return False
+
+    def _fix_header_hierarchy(self, text: str) -> str:
+        """
+        Fix markdown header hierarchy by adjusting levels.
+
+        Strategy:
+        1. Keep the first header unchanged as the level-1 parent
+        2. Demote all subsequent headers by one level (capped at level 6)
+        """
+        header_pattern = re.compile(r"^(#{1,6})\s+(.+)$")
+        lines = text.split("\n")
+        fixed_lines = []
+        first_valid_header = False
+
+        for line in lines:
+            stripped_line = line.strip()
+            # Match valid header lines (invalid # lines kept as-is)
+            header_match = header_pattern.match(stripped_line)
+            if header_match:
+                current_hashes, title_content = header_match.groups()
+                current_level = len(current_hashes)
+
+                if not first_valid_header:
+                    # First valid header: keep original level unchanged
+                    fixed_line = f"{current_hashes} {title_content}"
+                    first_valid_header = True
+                    logger.debug(f"Keep first header at level {current_level}: {title_content[:50]}...")
+                else:
+                    # Subsequent headers: increment by 1, cap at level 6
+                    new_level = min(current_level + 1, 6)
+                    new_hashes = "#" * new_level
+                    fixed_line = f"{new_hashes} {title_content}"
+                    logger.debug(f"Adjust header level: {current_level} -> {new_level}: {title_content[:50]}...")
+                fixed_lines.append(fixed_line)
+            else:
+                fixed_lines.append(line)
+
+        # Join with newlines to preserve original formatting
+        fixed_text = "\n".join(fixed_lines)
+        return fixed_text
diff --git a/src/memos/chunkers/sentence_chunker.py b/src/memos/chunkers/sentence_chunker.py
index f39dfb8e2..e695d0d9a 100644
--- a/src/memos/chunkers/sentence_chunker.py
+++ b/src/memos/chunkers/sentence_chunker.py
@@ -43,11 +43,13 @@ def __init__(self, config: SentenceChunkerConfig):
 
     def chunk(self, text: str) -> list[str] | list[Chunk]:
         """Chunk the given text into smaller chunks based on sentences."""
-        chonkie_chunks = self.chunker.chunk(text)
+        protected_text, url_map = self.protect_urls(text)
+        chonkie_chunks = self.chunker.chunk(protected_text)
 
         chunks = []
         for c in chonkie_chunks:
-            chunk = Chunk(text=c.text, token_count=c.token_count, sentences=c.sentences)
+            restored_text = self.restore_urls(c.text, url_map)
+            chunk = Chunk(text=restored_text, token_count=c.token_count, sentences=c.sentences)
             chunks.append(chunk)
 
         logger.debug(f"Generated {len(chunks)} chunks from input text")
diff --git a/src/memos/chunkers/simple_chunker.py b/src/memos/chunkers/simple_chunker.py
index cc0dc40d0..e66bb6bc7 100644
--- a/src/memos/chunkers/simple_chunker.py
+++ b/src/memos/chunkers/simple_chunker.py
@@ -20,12 +20,16 @@ def _simple_split_text(self, text: str, chunk_size: int, chunk_overlap: int) ->
         Returns:
             List of text chunks
         """
-        if not text or len(text) <= chunk_size:
-            return [text] if text.strip() else []
+        protected_text, url_map = self.protect_urls(text)
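+        # Split the protected text so URL placeholders are never cut mid-URL.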
+
+        if not protected_text or len(protected_text) <= chunk_size:
+            chunks = [protected_text] if protected_text.strip() else []
+            return [self.restore_urls(chunk, url_map) for chunk in chunks]
 
         chunks = []
         start = 0
-        text_len = len(text)
+        text_len = len(protected_text)
 
         while start < text_len:
             # Calculate end position
@@ -35,16 +39,16 @@ def _simple_split_text(self, text: str, chunk_size: int, chunk_overlap: int) ->
             if end < text_len:
                 # Try to break at newline, sentence end, or space
                 for separator in ["\n\n", "\n", "。", "!", "?", ". ", "! ", "? ", " "]:
-                    last_sep = text.rfind(separator, start, end)
+                    last_sep = protected_text.rfind(separator, start, end)
                     if last_sep != -1:
                         end = last_sep + len(separator)
                         break
 
-            chunk = text[start:end].strip()
+            chunk = protected_text[start:end].strip()
             if chunk:
                 chunks.append(chunk)
 
             # Move start position with overlap
             start = max(start + 1, end - chunk_overlap)
 
-        return chunks
+        return [self.restore_urls(chunk, url_map) for chunk in chunks]
diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py
index fbc704d0b..462f64c2a 100644
--- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py
@@ -103,14 +103,27 @@ def _handle_url(self, url_str: str, filename: str) -> tuple[str, str | None, boo
                 return response.text, None, True
 
             file_ext = os.path.splitext(filename)[1].lower()
-            if file_ext in [".md", ".markdown", ".txt"]:
+            if file_ext in [".md", ".markdown", ".txt"] or self._is_oss_md(url_str):
                 return response.text, None, True
             with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file_ext) as temp_file:
                 temp_file.write(response.content)
                 return "", temp_file.name, False
         except Exception as e:
             logger.error(f"[FileContentParser] URL processing error: {e}")
-            return f"[File URL download failed: {url_str}]", None
+            return f"[File URL download failed: {url_str}]", None, False
+
+    def _is_oss_md(self, url: str) -> bool:
+        """Check if URL is an OSS markdown/text file based on its path pattern."""
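+        # e.g. "https://bucket.oss-cn-hangzhou.aliyuncs.com/docs/guide.md" -> True
+        # e.g. "https://bucket.oss-cn-hangzhou.aliyuncs.com/images/photo.png" -> False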
+        loose_pattern = re.compile(r"^https?://[^/]*\.aliyuncs\.com/.*/([^/?#]+)")
+        match = loose_pattern.search(url)
+        if not match:
+            return False
+
+        file_name = match.group(1)
+        lower_name = file_name.lower()
+        return lower_name.endswith((".md", ".markdown", ".txt"))
 
     def _is_base64(self, data: str) -> bool:
         """Quick heuristic to check base64-like string."""
@@ -133,7 +146,12 @@ def _handle_local(self, data: str) -> str:
         return ""
 
     def _process_single_image(
-        self, image_url: str, original_ref: str, info: dict[str, Any], **kwargs
+        self,
+        image_url: str,
+        original_ref: str,
+        info: dict[str, Any],
+        header_context: list[str] | None = None,
+        **kwargs,
     ) -> tuple[str, str]:
         """
         Process a single image and return (original_ref, replacement_text).
@@ -142,6 +160,7 @@
             image_url: URL of the image to process
             original_ref: Original markdown image reference to replace
             info: Dictionary containing user_id and session_id
+            header_context: Optional list of header titles providing context for the image
             **kwargs: Additional parameters for ImageParser
 
         Returns:
@@ -167,20 +186,32 @@
             if hasattr(item, "memory") and item.memory:
                 extracted_texts.append(str(item.memory))
 
+            # Prepare header context string if available
+            header_context_str = ""
+            if header_context:
+                # Join headers with " > " to show hierarchy
+                header_hierarchy = " > ".join(header_context)
+                header_context_str = f"[Section: {header_hierarchy}]\n\n"
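+                # e.g. ["Guide", "Install"] -> "[Section: Guide > Install]\n\n"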
+
             if extracted_texts:
                 # Combine all extracted texts
                 extracted_content = "\n".join(extracted_texts)
+                # Build the final replacement text
+                replacement_text = (
+                    f"\n{header_context_str}[Image Content from {image_url}]:\n{extracted_content}\n"
+                )
                 # Replace image with extracted content
                 return (
                     original_ref,
-                    f"\n[Image Content from {image_url}]:\n{extracted_content}\n",
+                    replacement_text,
                 )
             else:
                 # If no content extracted, keep original with a note
                 logger.warning(f"[FileContentParser] No content extracted from image: {image_url}")
                 return (
                     original_ref,
-                    f"\n[Image: {image_url} - No content extracted]\n",
+                    f"\n{header_context_str}[Image: {image_url} - No content extracted]\n",
                 )
 
         except Exception as e:
@@ -188,7 +219,9 @@
             # On error, keep original image reference
             return (original_ref, original_ref)
 
-    def _extract_and_process_images(self, text: str, info: dict[str, Any], **kwargs) -> str:
+    def _extract_and_process_images(
+        self, text: str, info: dict[str, Any], headers: dict[int, dict] | None = None, **kwargs
+    ) -> str:
         """
         Extract all images from markdown text and process them using ImageParser in parallel.
         Replaces image references with extracted text content.
@@ -196,6 +229,7 @@
         Args:
             text: Markdown text containing image references
             info: Dictionary containing user_id and session_id
+            headers: Optional dictionary mapping line numbers to header info
             **kwargs: Additional parameters for ImageParser
 
         Returns:
@@ -219,7 +253,13 @@
         for match in image_matches:
             image_url = match.group(2)
             original_ref = match.group(0)
-            tasks.append((image_url, original_ref))
+            image_position = match.start()
+
+            header_context = None
+            if headers:
+                header_context = self._get_header_context(text, image_position, headers)
+
+            tasks.append((image_url, original_ref, header_context))
 
         # Process images in parallel
         replacements = {}
@@ -228,9 +268,14 @@
         with ContextThreadPoolExecutor(max_workers=max_workers) as executor:
             futures = {
                 executor.submit(
-                    self._process_single_image, image_url, original_ref, info, **kwargs
+                    self._process_single_image,
+                    image_url,
+                    original_ref,
+                    info,
+                    header_context,
+                    **kwargs,
                 ): (image_url, original_ref)
-                for image_url, original_ref in tasks
+                for image_url, original_ref, header_context in tasks
             }
 
         # Collect results with progress tracking
@@ -648,9 +693,17 @@ def parse_fine(
         )
         if not parsed_text:
             return []
+
+        # Extract markdown headers if applicable
+        headers = {}
+        if is_markdown:
+            headers = self._extract_markdown_headers(parsed_text)
+
         # Extract and process images from parsed_text
         if is_markdown and parsed_text and self.image_parser:
-            parsed_text = self._extract_and_process_images(parsed_text, info, **kwargs)
+            parsed_text = self._extract_and_process_images(
+                parsed_text, info, headers=headers if headers else None, **kwargs
+            )
 
         # Extract info fields
         if not info:
@@ -824,3 +877,103 @@ def _process_chunk(chunk_idx: int, chunk_text: str) -> TextualMemoryItem:
                 chunk_idx=None,
             )
         ]
+
+    def _extract_markdown_headers(self, text: str) -> dict[int, dict]:
+        """
+        Extract markdown headers and their positions.
+
+        Args:
+            text: Markdown text to parse
+
+        Returns:
+            dict: Mapping of line number -> {"level", "title", "position"}
+        """
+        if not text:
+            return {}
+
+        headers = {}
+        # Pattern to match markdown headers: # Title, ## Title, etc.
+        header_pattern = r"^(#{1,6})\s+(.+)$"
+
+        lines = text.split("\n")
+        char_position = 0
+
+        for line_num, line in enumerate(lines):
+            # Match header pattern (must be at start of line)
+            match = re.match(header_pattern, line.strip())
+            if match:
+                level = len(match.group(1))  # Number of # symbols (1-6)
+                title = match.group(2).strip()  # Extract title text
+
+                # Store header info with its position
+                headers[line_num] = {"level": level, "title": title, "position": char_position}
+
+                logger.debug(f"[FileContentParser] Found H{level} at line {line_num}: {title}")
+
+            # Update character position for next line (+1 for newline character)
+            char_position += len(line) + 1
+
+        logger.info(f"[FileContentParser] Extracted {len(headers)} headers from markdown")
+        return headers
+
+    def _get_header_context(
+        self, text: str, image_position: int, headers: dict[int, dict]
+    ) -> list[str]:
+        """
+        Get all header levels above an image position in hierarchical order.
+
+        Finds the image's line number, then identifies all preceding headers
+        and constructs the hierarchical path to the image location.
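+
+        Example (illustrative): for "# A\n## B\ntext ![img](u)\n# C", an image
+        on line 2 yields ["A", "B"]; the later "# C" is ignored.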
+
+        Args:
+            text: Full markdown text
+            image_position: Character position of the image in text
+            headers: Dict of headers from _extract_markdown_headers
+
+        Returns:
+            list[str]: Header titles from outermost to innermost section.
+        """
+        if not headers:
+            return []
+
+        # Find the line number corresponding to the image position
+        lines = text.split("\n")
+        char_count = 0
+        image_line = 0
+
+        for i, line in enumerate(lines):
+            char_count += len(line) + 1  # +1 for newline
+            if char_count > image_position:
+                image_line = i
+                break
+
+        # Filter headers that appear before the image
+        preceding_headers = {
+            line_num: info for line_num, info in headers.items() if line_num < image_line
+        }
+
+        if not preceding_headers:
+            return []
+
+        # Build hierarchical header stack
+        header_stack = []
+
+        for line_num in sorted(preceding_headers.keys()):
+            header = preceding_headers[line_num]
+            level = header["level"]
+            title = header["title"]
+
+            # Pop headers of same or lower level
+            while header_stack and header_stack[-1]["level"] >= level:
+                removed = header_stack.pop()
+                logger.debug(f"[FileContentParser] Popped H{removed['level']}: {removed['title']}")
+
+            # Push current header onto stack
+            header_stack.append({"level": level, "title": title})
+
+        # Return titles in order
+        result = [h["title"] for h in header_stack]
+        return result
diff --git a/src/memos/mem_reader/read_multi_modal/utils.py b/src/memos/mem_reader/read_multi_modal/utils.py
index a6d910e54..6e25fca12 100644
--- a/src/memos/mem_reader/read_multi_modal/utils.py
+++ b/src/memos/mem_reader/read_multi_modal/utils.py
@@ -346,7 +346,8 @@ def detect_lang(text):
         r"\b(user|assistant|query|answer)\s*:", "", cleaned_text, flags=re.IGNORECASE
     )
     cleaned_text = re.sub(r"\[[\d\-:\s]+\]", "", cleaned_text)
-
+    # Remove URLs so long ASCII links do not dilute the Chinese-character ratio
+    cleaned_text = re.sub(r'https?://[^\s<>"{}|\\^`\[\]]+', "", cleaned_text)
     # extract chinese characters
     chinese_pattern = r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\uf900-\ufaff]"
     chinese_chars = re.findall(chinese_pattern, cleaned_text)
diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py
index 46770758d..63476c7cc 100644
--- a/src/memos/memories/textual/item.py
+++ b/src/memos/memories/textual/item.py
@@ -45,6 +45,47 @@ class SourceMessage(BaseModel):
     model_config = ConfigDict(extra="allow")
 
 
+class ArchivedTextualMemory(BaseModel):
+    """
+    A lightweight model for storing archived versions of memories.
+
+    When an existing memory item must be updated because it conflicts with or
+    duplicates new memory content, its previous content is preserved in two places:
+    1. An ArchivedTextualMemory entry holding minimal information (memory content,
+    creation time), stored in the 'history' field of the original node.
+    2. A new memory node storing the full original information, including sources
+    and embedding, referenced by 'archived_memory_id'.
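+
+    Illustrative example: after one conflicting update, the active item's metadata
+    carries history=[ArchivedTextualMemory(version=1, memory="<old text>",
+    update_type="conflict", archived_memory_id="<archived node id>")].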
+    """
+
+    version: int = Field(
+        default=1,
+        description="The version of the archived memory content. Compared against the version of the active memory item (in metadata).",
+    )
+    is_fast: bool = Field(
+        default=False,
+        description="Whether this archived memory was created in fast mode and therefore contains raw content.",
+    )
+    memory: str | None = Field(
+        default_factory=lambda: "", description="The content of the archived version of the memory."
+    )
+    update_type: Literal["conflict", "duplicate", "extract", "unrelated"] = Field(
+        default="unrelated",
+        description="The type of the update (e.g., `conflict`, `duplicate`, `extract`, `unrelated`).",
+    )
+    archived_memory_id: str | None = Field(
+        default=None,
+        description="Link to a memory node with status='archived', storing the full original information, including sources and embedding.",
+    )
+    created_at: str | None = Field(
+        default_factory=lambda: datetime.now().isoformat(),
+        description="The time the memory was created.",
+    )
+
+
 class TextualMemoryMetadata(BaseModel):
     """Metadata for a memory item.
 
@@ -60,9 +101,29 @@
         default=None,
         description="The ID of the session during which the memory was created. Useful for tracking context in conversations.",
     )
-    status: Literal["activated", "archived", "deleted"] | None = Field(
+    status: Literal["activated", "resolving", "archived", "deleted"] | None = Field(
         default="activated",
-        description="The status of the memory, e.g., 'activated', 'archived', 'deleted'.",
+        description="The status of the memory, e.g., 'activated', 'resolving' (being updated due to conflicting/duplicate new memories), 'archived', 'deleted'.",
+    )
+    is_fast: bool | None = Field(
+        default=None,
+        description="Whether the memory was created in fast mode, i.e., carries raw content not yet edited by LLMs.",
+    )
+    evolve_to: list[str] | None = Field(
+        default_factory=list,
+        description="Only valid if the node was once a (raw) fast node; records which new memory nodes it evolves to after LLM extraction.",
+    )
+    version: int | None = Field(
+        default=None,
+        description="The version of the memory. Incremented when the memory is updated.",
+    )
+    history: list[ArchivedTextualMemory] | None = Field(
+        default_factory=list,
+        description="The archived versions of the memory, preserving only core information of each version.",
+    )
+    working_binding: str | None = Field(
+        default=None,
+        description="The id of the working-memory node bound to this (fast) memory.",
     )
     type: str | None = Field(default=None)
     key: str | None = Field(default=None, description="Memory key or title.")
diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py
new file mode 100644
index 000000000..1afdc9281
--- /dev/null
+++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py
@@ -0,0 +1,173 @@
+import logging
+
+from typing import Literal
+
+from memos.context.context import ContextThreadPoolExecutor
+from memos.extras.nli_model.client import NLIClient
+from memos.extras.nli_model.types import NLIResult
+from memos.graph_dbs.base import BaseGraphDB
+from memos.memories.textual.item import ArchivedTextualMemory, TextualMemoryItem
+
+
+logger = logging.getLogger(__name__)
+
+CONFLICT_MEMORY_TITLE = "[possibly conflicting memories]"
+DUPLICATE_MEMORY_TITLE = "[possibly duplicate memories]"
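+
+# Appended section format (illustrative):
+#   [possibly conflicting memories]:
+#   - <snippet truncated to 200 chars>...
+#   - ... (more items truncated)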
+
+
+def _append_related_content(
+    new_item: TextualMemoryItem, duplicates: list[str], conflicts: list[str]
+) -> None:
+    """
+    Append duplicate and conflict memory contents to the new item's memory text,
+    truncated to avoid excessive length.
+    """
+    max_per_item_len = 200
+    max_section_len = 1000
+
+    def _format_section(title: str, items: list[str]) -> str:
+        if not items:
+            return ""
+
+        section_content = ""
+        for mem in items:
+            # Truncate individual item
+            snippet = mem[:max_per_item_len] + "..." if len(mem) > max_per_item_len else mem
+            # Check total section length
+            if len(section_content) + len(snippet) + 5 > max_section_len:
+                section_content += "\n- ... (more items truncated)"
+                break
+            section_content += f"\n- {snippet}"
+
+        return f"\n\n{title}:{section_content}"
+
+    append_text = ""
+    append_text += _format_section(CONFLICT_MEMORY_TITLE, conflicts)
+    append_text += _format_section(DUPLICATE_MEMORY_TITLE, duplicates)
+
+    if append_text:
+        new_item.memory += append_text
+
+
+def _detach_related_content(new_item: TextualMemoryItem) -> None:
+    """
+    Detach duplicate and conflict memory contents from the new item's memory text.
+    """
+    markers = [f"\n\n{CONFLICT_MEMORY_TITLE}:", f"\n\n{DUPLICATE_MEMORY_TITLE}:"]
+
+    cut_index = -1
+    for marker in markers:
+        idx = new_item.memory.find(marker)
+        if idx != -1 and (cut_index == -1 or idx < cut_index):
+            cut_index = idx
+
+    if cut_index != -1:
+        new_item.memory = new_item.memory[:cut_index]
+
+
+class MemoryHistoryManager:
+    def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None:
+        """
+        Initialize the MemoryHistoryManager.
+
+        Args:
+            nli_client: NLIClient for conflict/duplicate detection.
+            graph_db: GraphDB instance for marking operations during history management.
+        """
+        self.nli_client = nli_client
+        self.graph_db = graph_db
+
+    def resolve_history_via_nli(
+        self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem]
+    ) -> list[str]:
+        """
+        Detect relationships (duplicate/conflict) between the new item and related items using NLI,
+        and attach them as history to the new fast item.
+
+        Args:
+            new_item: The new memory item being added.
+            related_items: Existing memory items that might be related.
+
+        Returns:
+            List of duplicate or conflicting memory contents as judged by the NLI service.
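+
+        Example (illustrative):
+            contents = manager.resolve_history_via_nli(new_item, candidates)
+            # new_item.memory now ends with the appended "[possibly ...]" sections;
+            # new_item.metadata.history records each candidate with its update_type.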
+        """
+        if not related_items:
+            return []
+
+        # 1. Call NLI
+        nli_results = self.nli_client.compare_one_to_many(
+            new_item.memory, [r.memory for r in related_items]
+        )
+
+        # 2. Process results and attach to history
+        duplicate_memories = []
+        conflict_memories = []
+
+        for r_item, nli_res in zip(related_items, nli_results, strict=False):
+            if nli_res == NLIResult.DUPLICATE:
+                update_type = "duplicate"
+                duplicate_memories.append(r_item.memory)
+            elif nli_res == NLIResult.CONTRADICTION:
+                update_type = "conflict"
+                conflict_memories.append(r_item.memory)
+            else:
+                update_type = "unrelated"
+
+            # Safely get created_at, falling back to updated_at
+            created_at = getattr(r_item.metadata, "created_at", None) or r_item.metadata.updated_at
+
+            archived = ArchivedTextualMemory(
+                version=r_item.metadata.version or 1,
+                is_fast=r_item.metadata.is_fast or False,
+                memory=r_item.memory,
+                update_type=update_type,
+                archived_memory_id=r_item.id,
+                created_at=created_at,
+            )
+            new_item.metadata.history.append(archived)
+            logger.info(
+                f"[MemoryHistoryManager] Archived related memory {r_item.id} as {update_type} for new item {new_item.id}"
+            )
+
+        # 3. Concatenate duplicate/conflict memories onto new_item.memory; the old
+        # memories are hidden during fine processing, so this avoids information loss.
+        _append_related_content(new_item, duplicate_memories, conflict_memories)
+
+        return duplicate_memories + conflict_memories
+
+    def mark_memory_status(
+        self,
+        memory_items: list[TextualMemoryItem],
+        status: Literal["activated", "resolving", "archived", "deleted"],
+    ) -> None:
+        """
+        Mark memory status during history management. Common usages:
+        1. Mark conflicting/duplicate old memories as "resolving" to make them
+        invisible to the /search API while keeping them visible to PreUpdateRetriever.
+        2. Mark resolved memories as "activated" to restore their visibility.
+        """
+        # Execute the actual marking operation in the DB.
+        with ContextThreadPoolExecutor() as executor:
+            futures = []
+            for mem in memory_items:
+                futures.append(
+                    executor.submit(
+                        self.graph_db.update_node,
+                        id=mem.id,
+                        fields={"status": status},
+                    )
+                )
+
+            # Wait for all tasks to complete and surface any exceptions
+            for future in futures:
+                future.result()
diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py
new file mode 100644
index 000000000..46cf3a1f6
--- /dev/null
+++ b/tests/memories/textual/test_history_manager.py
@@ -0,0 +1,138 @@
+import uuid
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from memos.extras.nli_model.client import NLIClient
+from memos.extras.nli_model.types import NLIResult
+from memos.graph_dbs.base import BaseGraphDB
+from memos.memories.textual.item import (
+    TextualMemoryItem,
+    TextualMemoryMetadata,
+)
+from memos.memories.textual.tree_text_memory.organize.history_manager import (
+    MemoryHistoryManager,
+    _append_related_content,
+    _detach_related_content,
+)
+
+
+@pytest.fixture
+def mock_nli_client():
+    client = MagicMock(spec=NLIClient)
+    return client
+
+
+@pytest.fixture
+def mock_graph_db():
+    return MagicMock(spec=BaseGraphDB)
+
+
+@pytest.fixture
+def history_manager(mock_nli_client, mock_graph_db):
+    return MemoryHistoryManager(nli_client=mock_nli_client, graph_db=mock_graph_db)
+
+
+def test_detach_related_content():
+    original_memory = "This is the original memory content."
+    item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata())
+
+    duplicates = ["Duplicate 1", "Duplicate 2"]
+    conflicts = ["Conflict 1", "Conflict 2"]
+
+    # 1. Append content
+    _append_related_content(item, duplicates, conflicts)
+
+    # Verify content was appended
+    assert item.memory != original_memory
+    assert "[possibly conflicting memories]" in item.memory
+    assert "[possibly duplicate memories]" in item.memory
+    assert "Duplicate 1" in item.memory
+    assert "Conflict 1" in item.memory
+
+    # 2. Detach content
+    _detach_related_content(item)
+
+    # 3. Verify content is restored
+    assert item.memory == original_memory
+
+
+def test_detach_only_conflicts():
+    original_memory = "Original memory."
+    item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata())
+
+    duplicates = []
+    conflicts = ["Conflict A"]
+
+    _append_related_content(item, duplicates, conflicts)
+    assert "Conflict A" in item.memory
+    assert "Duplicate" not in item.memory
+
+    _detach_related_content(item)
+    assert item.memory == original_memory
+
+
+def test_detach_only_duplicates():
+    original_memory = "Original memory."
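+    # Duplicates-only path: only the duplicate section is appended, then detached.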
+ item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) + + duplicates = ["Duplicate A"] + conflicts = [] + + _append_related_content(item, duplicates, conflicts) + assert "Duplicate A" in item.memory + assert "Conflict" not in item.memory + + _detach_related_content(item) + assert item.memory == original_memory + + +def test_truncation(history_manager, mock_nli_client): + # Setup + new_item = TextualMemoryItem(memory="Test") + long_memory = "A" * 300 + related_item = TextualMemoryItem(memory=long_memory) + + mock_nli_client.compare_one_to_many.return_value = [NLIResult.DUPLICATE] + + # Action + history_manager.resolve_history_via_nli(new_item, [related_item]) + + # Assert + assert "possibly duplicate memories" in new_item.memory + assert "..." in new_item.memory # Should be truncated + assert len(new_item.memory) < 1000 # Ensure reasonable length + + +def test_empty_related_items(history_manager, mock_nli_client): + new_item = TextualMemoryItem(memory="Test") + history_manager.resolve_history_via_nli(new_item, []) + + mock_nli_client.compare_one_to_many.assert_not_called() + assert new_item.metadata.history is None or len(new_item.metadata.history) == 0 + + +def test_mark_memory_status(history_manager, mock_graph_db): + # Setup + id1 = uuid.uuid4().hex + id2 = uuid.uuid4().hex + id3 = uuid.uuid4().hex + items = [ + TextualMemoryItem(memory="M1", id=id1), + TextualMemoryItem(memory="M2", id=id2), + TextualMemoryItem(memory="M3", id=id3), + ] + status = "resolving" + + # Action + history_manager.mark_memory_status(items, status) + + # Assert + assert mock_graph_db.update_node.call_count == 3 + + # Verify we called it correctly + mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}) + mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}) + mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status})