diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index f968ea7b9..7af3afe74 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -308,6 +308,7 @@ def init_server() -> dict[str, Any]: mem_reader=mem_reader, searcher=searcher, reranker=feedback_reranker, + pref_mem=pref_mem, ) # Initialize Scheduler diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index 0b3fc3846..fad15a7cd 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -2,6 +2,7 @@ import difflib import json import re +import uuid from datetime import datetime from typing import TYPE_CHECKING, Any @@ -33,6 +34,7 @@ if TYPE_CHECKING: + from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.templates.mem_feedback_prompts import ( FEEDBACK_ANSWER_PROMPT, @@ -90,6 +92,7 @@ def __init__(self, config: MemFeedbackConfig): self.stopword_manager = StopwordManager self.searcher: Searcher = None self.reranker = None + self.pref_mem: SimplePreferenceTextMemory = None self.DB_IDX_READY = False @require_python_package( @@ -115,7 +118,7 @@ def _retry_db_operation(self, operation): return operation() except Exception as e: logger.error( - f"[1223 Feedback Core: _retry_db_operation] DB operation failed: {e}", exc_info=True + f"[1224 Feedback Core: _retry_db_operation] DB operation failed: {e}", exc_info=True ) raise @@ -129,7 +132,7 @@ def _batch_embed(self, texts: list[str], embed_bs: int = 5): results.extend(self._embed_once(batch)) except Exception as e: logger.error( - f"[1223 Feedback Core: process_feedback_core] Embedding batch failed, Cover with all zeros: {len(batch)} entries: {e}" + f"[1224 Feedback Core: process_feedback_core] Embedding batch failed, Cover with all zeros: {len(batch)} entries: {e}" ) results.extend([[0.0] * dim for _ 
in range(len(batch))]) return results @@ -145,7 +148,7 @@ def _pure_add(self, user_name: str, feedback_content: str, feedback_time: str, i lambda: self.memory_manager.add(to_add_memories, user_name=user_name, use_batch=False) ) logger.info( - f"[1223 Feedback Core: _pure_add] Pure added {len(added_ids)} memories for user {user_name}." + f"[1224 Feedback Core: _pure_add] Pure added {len(added_ids)} memories for user {user_name}." ) return { "record": { @@ -182,7 +185,7 @@ def _keyword_replace_judgement(self, feedback_content: str) -> dict | None: return judge_res else: logger.warning( - "[1223 Feedback Core: _feedback_judgement] feedback judgement failed, return []" + "[1224 Feedback Core: _feedback_judgement] feedback judgement failed, return []" ) return {} @@ -207,7 +210,7 @@ def _feedback_judgement( return judge_res else: logger.warning( - "[1223 Feedback Core: _feedback_judgement] feedback judgement failed, return []" + "[1224 Feedback Core: _feedback_judgement] feedback judgement failed, return []" ) return [] @@ -271,6 +274,14 @@ def _single_update_operation( """ Individual update operations """ + if "preference" in old_memory_item.metadata.__dict__: + logger.info( + f"[1224 Feedback Core: _single_update_operation] pref_memory: {old_memory_item.id}" + ) + return self._single_update_pref( + old_memory_item, new_memory_item, user_id, user_name, operation + ) + memory_type = old_memory_item.metadata.memory_type source_doc_id = ( old_memory_item.metadata.file_ids[0] @@ -281,6 +292,7 @@ def _single_update_operation( ) if operation and "text" in operation and operation["text"]: new_memory_item.memory = operation["text"] + new_memory_item.metadata.embedding = self._batch_embed([operation["text"]])[0] if memory_type == "WorkingMemory": fields = { @@ -317,6 +329,68 @@ def _single_update_operation( "origin_memory": old_memory_item.memory, } + def _single_update_pref( + self, + old_memory_item: TextualMemoryItem, + new_memory_item: TextualMemoryItem, + user_id: str, + 
user_name: str, + operation: dict, + ): + """update preference memory""" + + feedback_context = new_memory_item.memory + if operation and "text" in operation and operation["text"]: + new_memory_item.memory = operation["text"] + new_memory_item.metadata.embedding = self._batch_embed([operation["text"]])[0] + + to_add_memory = old_memory_item.model_copy(deep=True) + to_add_memory.metadata.key = new_memory_item.metadata.key + to_add_memory.metadata.tags = new_memory_item.metadata.tags + to_add_memory.memory = new_memory_item.memory + to_add_memory.metadata.preference = new_memory_item.memory + to_add_memory.metadata.embedding = new_memory_item.metadata.embedding + + to_add_memory.metadata.user_id = new_memory_item.metadata.user_id + to_add_memory.metadata.original_text = old_memory_item.memory + to_add_memory.metadata.covered_history = old_memory_item.id + + to_add_memory.metadata.created_at = to_add_memory.metadata.updated_at = ( + datetime.now().isoformat() + ) + to_add_memory.metadata.context_summary = ( + old_memory_item.metadata.context_summary + " \n" + feedback_context + ) + + # add new memory + to_add_memory.id = str(uuid.uuid4()) + added_ids = self._retry_db_operation(lambda: self.pref_mem.add([to_add_memory])) + # delete + deleted_id = old_memory_item.id + collection_name = old_memory_item.metadata.preference_type + self._retry_db_operation( + lambda: self.pref_mem.delete_with_collection_name(collection_name, [deleted_id]) + ) + # add archived + old_memory_item.metadata.status = "archived" + old_memory_item.metadata.original_text = "archived" + old_memory_item.metadata.embedding = [0.0] * 1024 + + archived_ids = self._retry_db_operation(lambda: self.pref_mem.add([old_memory_item])) + + logger.info( + f"[Memory Feedback UPDATE Pref] New Add:{added_ids!s} | Set archived:{archived_ids!s}" + ) + + return { + "id": to_add_memory.id, + "text": new_memory_item.memory, + "source_doc_id": "", + "archived_id": old_memory_item.id, + "origin_memory": 
old_memory_item.memory, + "type": "preference", + } + def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> set[str]: """Delete working memory bindings""" bindings_to_delete = extract_working_binding_ids(mem_items) @@ -334,11 +408,11 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> self.graph_store.delete_node(mid, user_name=user_name) logger.info( - f"[1223 Feedback Core:_del_working_binding] Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}" + f"[1224 Feedback Core:_del_working_binding] Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}" ) except Exception as e: logger.warning( - f"[1223 Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}" + f"[1224 Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}" ) def semantics_feedback( @@ -355,13 +429,12 @@ def semantics_feedback( lang = detect_lang("".join(memory_item.memory)) template = FEEDBACK_PROMPT_DICT["compare"][lang] if current_memories == []: - # retrieve feedback - feedback_retrieved = self._retrieve(memory_item.memory, info=info, user_name=user_name) - - # retrieve question + # retrieve last_user_index = max(i for i, d in enumerate(chat_history_list) if d["role"] == "user") last_qa = " ".join([item["content"] for item in chat_history_list[last_user_index:]]) supplementary_retrieved = self._retrieve(last_qa, info=info, user_name=user_name) + feedback_retrieved = self._retrieve(memory_item.memory, info=info, user_name=user_name) + ids = [] for item in feedback_retrieved + supplementary_retrieved: if item.id not in ids: @@ -385,9 +458,14 @@ def semantics_feedback( with ContextThreadPoolExecutor(max_workers=10) as executor: future_to_chunk_idx = {} for chunk in memory_chunks: - current_memories_str = "\n".join( - [f"{item.id}: {item.memory}" for item in chunk] - ) + chunk_list = [] + for item in chunk: + if "preference" in 
item.metadata.__dict__: + chunk_list.append(f"{item.id}: {item.metadata.preference}") + else: + chunk_list.append(f"{item.id}: {item.memory}") + current_memories_str = "\n".join(chunk_list) + prompt = template.format( now_time=now_time, current_memories=current_memories_str, @@ -408,7 +486,7 @@ def semantics_feedback( all_operations.extend(chunk_operations["operations"]) except Exception as e: logger.error( - f"[1223 Feedback Core: semantics_feedback] Operation failed: {e}" + f"[1224 Feedback Core: semantics_feedback] Operation failed: {e}" ) standard_operations = self.standard_operations(all_operations, current_memories) @@ -458,7 +536,7 @@ def semantics_feedback( update_results.append(result) except Exception as e: logger.error( - f"[1223 Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}", + f"[1224 Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}", exc_info=True, ) if update_results: @@ -486,7 +564,7 @@ def _feedback_memory( ] if filterd_ids: logger.warning( - f"[1223 Feedback Core: _feedback_memory] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." + f"[1224 Feedback Core: _feedback_memory] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." 
) current_memories = [ @@ -518,7 +596,7 @@ def _feedback_memory( results[i] = node except Exception as e: logger.error( - f"[1223 Feedback Core: _feedback_memory] Error processing memory index {i}: {e}", + f"[1224 Feedback Core: _feedback_memory] Error processing memory index {i}: {e}", exc_info=True, ) mem_res = [r for r in results if r] @@ -542,13 +620,18 @@ def _info_comparison(self, memory: TextualMemoryItem, _info: dict, include_keys: record.append(info_v == mem_v) return all(record) - def _retrieve(self, query: str, info=None, top_k=100, user_name=None): + def _retrieve(self, query: str, info=None, top_k=20, user_name=None): """Retrieve memory items""" retrieved_mems = self.searcher.search( query, info=info, user_name=user_name, top_k=top_k, full_recall=True ) retrieved_mems = [item[0] for item in retrieved_mems if float(item[1]) > 0.01] - return retrieved_mems + + pref_info = {} + if info and "user_id" in info: + pref_info = {"user_id": info["user_id"]} + retrieved_prefs = self.pref_mem.search(query, top_k, pref_info) + return retrieved_mems + retrieved_prefs def _vec_query(self, new_memories_embedding: list[float], user_name=None): """Vector retrieval query""" @@ -577,7 +660,7 @@ def _vec_query(self, new_memories_embedding: list[float], user_name=None): if not retrieved_ids: logger.info( - f"[1223 Feedback Core: _vec_query] No similar memories found for embedding query for user {user_name}." + f"[1224 Feedback Core: _vec_query] No similar memories found for embedding query for user {user_name}." ) filterd_ids = [ @@ -585,7 +668,7 @@ def _vec_query(self, new_memories_embedding: list[float], user_name=None): ] if filterd_ids: logger.warning( - f"[1223 Feedback Core: _vec_query] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." + f"[1224 Feedback Core: _vec_query] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." 
) return [ TextualMemoryItem(**item) @@ -639,9 +722,9 @@ def filter_fault_update(self, operations: list[dict]): ): all_judge.extend(judge_res["operations_judgement"]) except Exception as e: - logger.error(f"[1223 Feedback Core: filter_fault_update] Judgement failed: {e}") + logger.error(f"[1224 Feedback Core: filter_fault_update] Judgement failed: {e}") - logger.info(f"[1223 Feedback Core: filter_fault_update] LLM judgement: {all_judge}") + logger.info(f"[1224 Feedback Core: filter_fault_update] LLM judgement: {all_judge}") id2op = {item["id"]: item for item in updated_operations} valid_updates = [] for judge in all_judge: @@ -652,7 +735,7 @@ def filter_fault_update(self, operations: list[dict]): valid_updates.append(valid_update) logger.info( - f"[1223 Feedback Core: filter_fault_update] {len(updated_operations)} -> {len(valid_updates)}" + f"[1224 Feedback Core: filter_fault_update] {len(updated_operations)} -> {len(valid_updates)}" ) return valid_updates + [item for item in operations if item["operation"] != "UPDATE"] @@ -680,11 +763,11 @@ def correct_item(data): and "text" in data and "old_memory" in data and data["operation"].lower() == "update" - ) + ), "Invalid operation item" if not should_keep_update(data["text"], data["old_memory"]): logger.warning( - f"[1223 Feedback Core: semantics_feedback] Due to the excessive proportion of changes, skip update: {data}" + f"[1224 Feedback Core: correct_item] Due to the excessive proportion of changes, skip update: {data}" ) return None @@ -704,14 +787,14 @@ def correct_item(data): return data except Exception: logger.error( - f"[1223 Feedback Core: standard_operations] Error processing operation item: {data}", + f"[1224 Feedback Core: standard_operations] Error processing operation item: {data}", exc_info=True, ) return None dehallu_res = [correct_item(item) for item in operations] dehalluded_operations = [item for item in dehallu_res if item] - logger.info(f"[1223 Feedback Core: dehalluded_operations] 
{dehalluded_operations}") + logger.info(f"[1224 Feedback Core: dehalluded_operations] {dehalluded_operations}") # c add objects add_texts = [] @@ -725,7 +808,7 @@ def correct_item(data): elif item["operation"].lower() == "update": llm_operations.append(item) logger.info( - f"[1223 Feedback Core: deduplicate add] {len(dehalluded_operations)} -> {len(llm_operations)} memories" + f"[1224 Feedback Core: deduplicate add] {len(dehalluded_operations)} -> {len(llm_operations)} memories" ) # Update takes precedence over add @@ -739,7 +822,7 @@ def correct_item(data): ] if filtered_items: logger.info( - f"[1223 Feedback Core: semantics_feedback] Due to have update objects, skip add: {filtered_items}" + f"[1224 Feedback Core: semantics_feedback] Due to have update objects, skip add: {filtered_items}" ) return update_items else: @@ -787,7 +870,7 @@ def _doc_filter(self, doc_scope: str, memories: list[TextualMemoryItem]): memid for inscope_file in inscope_docs for memid in filename2_memid[inscope_file] ] logger.info( - f"[1223 Feedback Core: process_keyword_replace] These docs are in scope : {inscope_docs}, relared memids: {inscope_ids}" + f"[1224 Feedback Core: process_keyword_replace] These docs are in scope : {inscope_docs}, relared memids: {inscope_ids}" ) filter_memories = [mem for mem in memories if mem.id in inscope_ids] return filter_memories @@ -841,7 +924,7 @@ def process_keyword_replace( retrieved_memories = self._doc_filter(doc_scope, retrieved_memories) logger.info( - f"[1223 Feedback Core: process_keyword_replace] Keywords recalled memory for user {user_name}: {len(retrieved_ids)} memories | After filtering: {len(retrieved_memories)} memories." + f"[1224 Feedback Core: process_keyword_replace] Keywords recalled memory for user {user_name}: {len(retrieved_ids)} memories | After filtering: {len(retrieved_memories)} memories." 
) if not retrieved_memories: @@ -926,7 +1009,7 @@ def check_validity(item): info.update({"user_id": user_id, "user_name": user_name, "session_id": session_id}) logger.info( - f"[1223 Feedback Core: process_feedback_core] Starting memory feedback process for user {user_name}" + f"[1224 Feedback Core: process_feedback_core] Starting memory feedback process for user {user_name}" ) # feedback keywords update kwp_judge = self._keyword_replace_judgement(feedback_content) @@ -959,7 +1042,7 @@ def check_validity(item): if not valid_feedback: logger.warning( - f"[1223 Feedback Core: process_feedback_core] No valid judgements for user {user_name}: {raw_judge}." + f"[1224 Feedback Core: process_feedback_core] No valid judgements for user {user_name}: {raw_judge}." ) return {"record": {"add": [], "update": []}} @@ -1007,13 +1090,13 @@ def check_validity(item): add_memories = mem_record["record"]["add"] update_memories = mem_record["record"]["update"] logger.info( - f"[1223 Feedback Core: process_feedback_core] Processed {len(feedback_memories)} feedback | add {len(add_memories)} memories | update {len(update_memories)} memories for user {user_name}." + f"[1224 Feedback Core: process_feedback_core] Processed {len(feedback_memories)} feedback | add {len(add_memories)} memories | update {len(update_memories)} memories for user {user_name}." 
) return mem_record except Exception as e: logger.error( - f"[1223 Feedback Core: process_feedback_core] Error for user {user_name}: {e}" + f"[1224 Feedback Core: process_feedback_core] Error for user {user_name}: {e}" ) return {"record": {"add": [], "update": []}} diff --git a/src/memos/mem_feedback/simple_feedback.py b/src/memos/mem_feedback/simple_feedback.py index 429c2ea20..e32f939c7 100644 --- a/src/memos/mem_feedback/simple_feedback.py +++ b/src/memos/mem_feedback/simple_feedback.py @@ -4,6 +4,7 @@ from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.mem_feedback.feedback import MemFeedback from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import StopwordManager from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher @@ -23,6 +24,7 @@ def __init__( mem_reader: SimpleStructMemReader, searcher: Searcher, reranker: BaseReranker, + pref_mem: SimplePreferenceTextMemory, ): self.llm = llm self.embedder = embedder @@ -31,5 +33,6 @@ def __init__( self.mem_reader = mem_reader self.searcher = searcher self.stopword_manager = StopwordManager + self.pref_mem = pref_mem self.reranker = reranker self.DB_IDX_READY = False diff --git a/src/memos/mem_feedback/utils.py b/src/memos/mem_feedback/utils.py index c32c12328..8cb7f97a3 100644 --- a/src/memos/mem_feedback/utils.py +++ b/src/memos/mem_feedback/utils.py @@ -48,8 +48,11 @@ def calculate_similarity(text1: str, text2: str) -> float: similarity = calculate_similarity(old_text, new_text) change_ratio = 1 - similarity + if change_ratio == float(0): + return False + if old_len < 200: - return change_ratio < 0.5 + return change_ratio < 0.7 else: return change_ratio < 0.2 diff --git 
a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index ba7b558fd..8fd60153d 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -418,6 +418,7 @@ def init_components() -> dict[str, Any]: mem_reader=mem_reader, searcher=searcher, reranker=feedback_reranker, + pref_mem=pref_mem, ) # Return all components as a dictionary for easy access and extension return {"naive_mem_cube": naive_mem_cube, "feedback_server": feedback_server} diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index e1bc0e72b..9e521158d 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -87,6 +87,9 @@ def search( Returns: list[TextualMemoryItem]: List of matching memories. """ + if not isinstance(search_filter, dict): + search_filter = {} + search_filter.update({"status": "activated"}) logger.info(f"search_filter for preference memory: {search_filter}") return self.retriever.retrieve(query, top_k, info, search_filter) diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py index 1f02132bb..ee37d638c 100644 --- a/src/memos/memories/textual/simple_preference.py +++ b/src/memos/memories/textual/simple_preference.py @@ -61,6 +61,9 @@ def search( Returns: list[TextualMemoryItem]: List of matching memories. """ + if not isinstance(search_filter, dict): + search_filter = {} + search_filter.update({"status": "activated"}) return self.retriever.retrieve(query, top_k, info, search_filter) def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: