diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py index 94988295b..ca87d95d2 100644 --- a/src/memos/api/handlers/formatters_handler.py +++ b/src/memos/api/handlers/formatters_handler.py @@ -7,9 +7,13 @@ from typing import Any +from memos.log import get_logger from memos.templates.instruction_completion import instruct_completion +logger = get_logger(__name__) + + def to_iter(running: Any) -> list[Any]: """ Normalize running tasks to a list of task objects. @@ -29,7 +33,9 @@ def to_iter(running: Any) -> list[Any]: return list(running) if running else [] -def format_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]: +def format_memory_item( + memory_data: Any, include_embedding: bool = False, save_sources: bool = True +) -> dict[str, Any]: """ Format a single memory item for API response. @@ -49,7 +55,8 @@ def format_memory_item(memory_data: Any, include_embedding: bool = False) -> dic memory["ref_id"] = ref_id if not include_embedding: memory["metadata"]["embedding"] = [] - memory["metadata"]["sources"] = [] + if not save_sources: + memory["metadata"]["sources"] = [] memory["metadata"]["usage"] = [] memory["metadata"]["ref_id"] = ref_id memory["metadata"]["id"] = memory_id @@ -125,3 +132,96 @@ def post_process_textual_mem( } ) return memories_result + + +def separate_knowledge_and_conversation_mem(memories: list[dict[str, Any]]): + """ + Separate knowledge and conversation memories from retrieval results. + """ + knowledge_mem = [] + conversation_mem = [] + for item in memories: + sources = item["metadata"]["sources"] + if ( + len(sources) > 0 + and "type" in sources[0] + and sources[0]["type"] == "file" + and "content" in sources[0] + and sources[0]["content"] != "" + ): # TODO change to memory_type + knowledge_mem.append(item) + else: + conversation_mem.append(item) + + logger.info( + f"Retrieval results number of knowledge_mem: {len(knowledge_mem)}, conversation_mem: {len(conversation_mem)}" + ) + return knowledge_mem, conversation_mem + + +def rerank_knowledge_mem( + reranker: Any, + query: str, + text_mem: list[dict[str, Any]], + top_k: int, + file_mem_proportion: float = 0.5, +) -> list[dict[str, Any]]: + """ + Rerank knowledge memories and keep conversation memories. + """ + memid2cubeid = {} + memories_list = [] + for memory_group in text_mem: + cube_id = memory_group["cube_id"] + memories = memory_group["memories"] + memories_list.extend(memories) + for memory in memories: + memid2cubeid[memory["id"]] = cube_id + + knowledge_mem, conversation_mem = separate_knowledge_and_conversation_mem(memories_list) + knowledge_mem_top_k = max(int(top_k * file_mem_proportion), int(top_k - len(conversation_mem))) + reranked_knowledge_mem = reranker.rerank(query, knowledge_mem, top_k=len(knowledge_mem)) + reranked_knowledge_mem = [item[0] for item in reranked_knowledge_mem] + + # TODO revoke sources replace memory value + for item in reranked_knowledge_mem: + item["memory"] = item["metadata"]["sources"][0]["content"] + item["metadata"]["sources"] = [] + + for item in conversation_mem: + item["metadata"]["sources"] = [] + + # deduplicate: remove items with duplicate memory content + original_count = len(reranked_knowledge_mem) + seen_memories = set[Any]() + deduplicated_knowledge_mem = [] + for item in reranked_knowledge_mem: + memory_content = item.get("memory", "") + if memory_content and memory_content not in seen_memories: + seen_memories.add(memory_content) + deduplicated_knowledge_mem.append(item) + deduplicated_count = len(deduplicated_knowledge_mem) + logger.info( + f"After filtering duplicate knowledge base text from sources, count changed from {original_count} to {deduplicated_count}" + ) + + reranked_knowledge_mem = deduplicated_knowledge_mem[:knowledge_mem_top_k] + conversation_mem_top_k = top_k - len(reranked_knowledge_mem) + cubeid2memories = {} + text_mem_res = [] + + for memory in reranked_knowledge_mem + conversation_mem[:conversation_mem_top_k]: + cube_id = memid2cubeid[memory["id"]] + if cube_id not in cubeid2memories: + cubeid2memories[cube_id] = [] + cubeid2memories[cube_id].append(memory) + + for cube_id, memories in cubeid2memories.items(): + text_mem_res.append( + { + "cube_id": cube_id, + "memories": memories, + } + ) + + return text_mem_res diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index ef829d757..14bb8eec5 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -204,7 +204,7 @@ def handle_get_memories( preferences, total_pref = naive_mem_cube.pref_mem.get_memory_by_filter( filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size ) - format_preferences = [format_memory_item(item) for item in preferences] + format_preferences = [format_memory_item(item, save_sources=False) for item in preferences] return GetMemoryResponse( message="Memories retrieved successfully", diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 3774410dc..32a970b22 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -5,9 +5,12 @@ using dependency injection for better modularity and testability. """ +import time + from typing import Any from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies +from memos.api.handlers.formatters_handler import rerank_knowledge_mem from memos.api.product_models import APISearchRequest, SearchResponse from memos.log import get_logger from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( @@ -69,6 +72,18 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse # Restore original top_k for downstream logic or response metadata search_req.top_k = original_top_k + start_time = time.time() + text_mem = results["text_mem"] + results["text_mem"] = rerank_knowledge_mem( + self.reranker, + query=search_req.query, + text_mem=text_mem, + top_k=original_top_k, + file_mem_proportion=0.5, + ) + rerank_time = time.time() - start_time + self.logger.info(f"[Knowledge_replace_memory_time] Rerank time: {rerank_time} seconds") + self.logger.info( f"[SearchHandler] Final search results: count={len(results)} results={results}" ) diff --git a/src/memos/reranker/concat.py b/src/memos/reranker/concat.py index 502af18b6..b39496a1c 100644 --- a/src/memos/reranker/concat.py +++ b/src/memos/reranker/concat.py @@ -83,10 +83,18 @@ def concat_original_source( merge_field = ["sources"] if rerank_source is None else rerank_source.split(",") documents = [] for item in graph_results: - memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m + m = item.get("memory") if isinstance(item, dict) else getattr(item, "memory", None) + + memory = _TAG1.sub("", m) if isinstance(m, str) else m + sources = [] for field in merge_field: - source = getattr(item.metadata, field, None) + if isinstance(item, dict): + metadata = item.get("metadata", {}) + source = metadata.get(field) if isinstance(metadata, dict) else None + else: + source = getattr(item.metadata, field, None) if hasattr(item, "metadata") else None + if source is None: continue sources.append((memory, source)) diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py index 4e9054f1e..32034cf6d 100644 --- a/src/memos/reranker/http_bge.py +++ b/src/memos/reranker/http_bge.py @@ -129,7 +129,7 @@ def __init__( def rerank( self, query: str, - graph_results: list[TextualMemoryItem], + graph_results: list[TextualMemoryItem] | list[dict[str, Any]], top_k: int, search_priority: dict | None = None, **kwargs, @@ -164,11 +164,15 @@ def rerank( if self.rerank_source: documents = concat_original_source(graph_results, self.rerank_source) else: - documents = [ - (_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m) - for item in graph_results - ] - documents = [d for d in documents if isinstance(d, str) and d] + documents = [] + filtered_graph_results = [] + for item in graph_results: + m = item.get("memory") if isinstance(item, dict) else getattr(item, "memory", None) + + if isinstance(m, str) and m: + documents.append(_TAG1.sub("", m)) + filtered_graph_results.append(item) + graph_results = filtered_graph_results logger.info(f"[HTTPBGERerankerSample] query: {query} , documents: {documents[:5]}...")