diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index fad15a7cd..15d7c336a 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -5,7 +5,7 @@ import uuid from datetime import datetime -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -18,6 +18,8 @@ from memos.log import get_logger from memos.mem_feedback.base import BaseMemFeedback from memos.mem_feedback.utils import ( + extract_bracket_content, + extract_square_brackets_content, general_split_into_chunks, make_mem_item, should_keep_update, @@ -118,7 +120,7 @@ def _retry_db_operation(self, operation): return operation() except Exception as e: logger.error( - f"[1224 Feedback Core: _retry_db_operation] DB operation failed: {e}", exc_info=True + f"[0107 Feedback Core: _retry_db_operation] DB operation failed: {e}", exc_info=True ) raise @@ -132,7 +134,7 @@ def _batch_embed(self, texts: list[str], embed_bs: int = 5): results.extend(self._embed_once(batch)) except Exception as e: logger.error( - f"[1224 Feedback Core: process_feedback_core] Embedding batch failed, Cover with all zeros: {len(batch)} entries: {e}" + f"[0107 Feedback Core: process_feedback_core] Embedding batch failed, Cover with all zeros: {len(batch)} entries: {e}" ) results.extend([[0.0] * dim for _ in range(len(batch))]) return results @@ -148,7 +150,7 @@ def _pure_add(self, user_name: str, feedback_content: str, feedback_time: str, i lambda: self.memory_manager.add(to_add_memories, user_name=user_name, use_batch=False) ) logger.info( - f"[1224 Feedback Core: _pure_add] Pure added {len(added_ids)} memories for user {user_name}." + f"[0107 Feedback Core: _pure_add] Pure added {len(added_ids)} memories for user {user_name}." 
) return { "record": { @@ -180,12 +182,12 @@ def _keyword_replace_judgement(self, feedback_content: str) -> dict | None: user_feedback=feedback_content, ) - judge_res = self._get_llm_response(prompt) + judge_res = self._get_llm_response(prompt, load_type="bracket") if judge_res: return judge_res else: logger.warning( - "[1224 Feedback Core: _feedback_judgement] feedback judgement failed, return []" + "[0107 Feedback Core: _feedback_judgement] feedback judgement failed, return []" ) return {} @@ -205,12 +207,12 @@ def _feedback_judgement( feedback_time=feedback_time, ) - judge_res = self._get_llm_response(prompt) + judge_res = self._get_llm_response(prompt, load_type="square_bracket") if judge_res: return judge_res else: logger.warning( - "[1224 Feedback Core: _feedback_judgement] feedback judgement failed, return []" + "[0107 Feedback Core: _feedback_judgement] feedback judgement failed, return []" ) return [] @@ -276,7 +278,7 @@ def _single_update_operation( """ if "preference" in old_memory_item.metadata.__dict__: logger.info( - f"[1224 Feedback Core: _single_update_operation] pref_memory: {old_memory_item.id}" + f"[0107 Feedback Core: _single_update_operation] pref_memory: {old_memory_item.id}" ) return self._single_update_pref( old_memory_item, new_memory_item, user_id, user_name, operation @@ -408,11 +410,11 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> self.graph_store.delete_node(mid, user_name=user_name) logger.info( - f"[1224 Feedback Core:_del_working_binding] Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}" + f"[0107 Feedback Core:_del_working_binding] Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}" ) except Exception as e: logger.warning( - f"[1224 Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}" + f"[0107 Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}" ) def semantics_feedback( @@ -473,7 
+475,7 @@ def semantics_feedback( chat_history=history_str, ) - future = executor.submit(self._get_llm_response, prompt) + future = executor.submit(self._get_llm_response, prompt, load_type="bracket") future_to_chunk_idx[future] = chunk for future in concurrent.futures.as_completed(future_to_chunk_idx): try: @@ -486,7 +488,7 @@ def semantics_feedback( all_operations.extend(chunk_operations["operations"]) except Exception as e: logger.error( - f"[1224 Feedback Core: semantics_feedback] Operation failed: {e}" + f"[0107 Feedback Core: semantics_feedback] Operation failed: {e}" ) standard_operations = self.standard_operations(all_operations, current_memories) @@ -536,7 +538,7 @@ def semantics_feedback( update_results.append(result) except Exception as e: logger.error( - f"[1224 Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}", + f"[0107 Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}", exc_info=True, ) if update_results: @@ -564,7 +566,7 @@ def _feedback_memory( ] if filterd_ids: logger.warning( - f"[1224 Feedback Core: _feedback_memory] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." + f"[0107 Feedback Core: _feedback_memory] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." ) current_memories = [ @@ -596,7 +598,7 @@ def _feedback_memory( results[i] = node except Exception as e: logger.error( - f"[1224 Feedback Core: _feedback_memory] Error processing memory index {i}: {e}", + f"[0107 Feedback Core: _feedback_memory] Error processing memory index {i}: {e}", exc_info=True, ) mem_res = [r for r in results if r] @@ -660,7 +662,7 @@ def _vec_query(self, new_memories_embedding: list[float], user_name=None): if not retrieved_ids: logger.info( - f"[1224 Feedback Core: _vec_query] No similar memories found for embedding query for user {user_name}." 
+ f"[0107 Feedback Core: _vec_query] No similar memories found for embedding query for user {user_name}." ) filterd_ids = [ @@ -668,7 +670,7 @@ def _vec_query(self, new_memories_embedding: list[float], user_name=None): ] if filterd_ids: logger.warning( - f"[1224 Feedback Core: _vec_query] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." + f"[0107 Feedback Core: _vec_query] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." ) return [ TextualMemoryItem(**item) @@ -676,22 +678,41 @@ def _vec_query(self, new_memories_embedding: list[float], user_name=None): if "mode:fast" not in item["metadata"]["tags"] ] - def _get_llm_response(self, prompt: str, dsl: bool = True) -> dict: + def _get_llm_response( + self, + prompt: str, + dsl: bool = True, + load_type: Literal["bracket", "square_bracket"] | None = None, + ) -> dict: messages = [{"role": "user", "content": prompt}] + response_text = "" try: response_text = self.llm.generate(messages, temperature=0.3, timeout=60) - if dsl: + if not dsl: + return response_text + try: response_text = response_text.replace("```", "").replace("json", "") cleaned_text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", "", response_text) response_json = json.loads(cleaned_text) - else: - return response_text + return response_json + except (json.JSONDecodeError, ValueError) as e: + if load_type == "bracket": + response_json = extract_bracket_content(response_text) + return response_json + elif load_type == "square_bracket": + response_json = extract_square_brackets_content(response_text) + return response_json + else: + logger.error( + f"[Feedback Core LLM Error] Exception during chat generation: {e} | response_text: {response_text}" + ) + return None + except Exception as e: logger.error( f"[Feedback Core LLM Error] Exception during chat generation: {e} | response_text: {response_text}" ) - response_json = None - return response_json + return None def 
filter_fault_update(self, operations: list[dict]): """To address the randomness of large model outputs, it is necessary to conduct validity evaluation on the texts used for memory override operations.""" @@ -710,7 +731,7 @@ def filter_fault_update(self, operations: list[dict]): raw_operations_str = {"operations": chunk} prompt = template.format(raw_operations=str(raw_operations_str)) - future = executor.submit(self._get_llm_response, prompt) + future = executor.submit(self._get_llm_response, prompt, load_type="bracket") future_to_chunk_idx[future] = chunk for future in concurrent.futures.as_completed(future_to_chunk_idx): try: @@ -722,9 +743,9 @@ def filter_fault_update(self, operations: list[dict]): ): all_judge.extend(judge_res["operations_judgement"]) except Exception as e: - logger.error(f"[1224 Feedback Core: filter_fault_update] Judgement failed: {e}") + logger.error(f"[0107 Feedback Core: filter_fault_update] Judgement failed: {e}") - logger.info(f"[1224 Feedback Core: filter_fault_update] LLM judgement: {all_judge}") + logger.info(f"[0107 Feedback Core: filter_fault_update] LLM judgement: {all_judge}") id2op = {item["id"]: item for item in updated_operations} valid_updates = [] for judge in all_judge: @@ -735,7 +756,7 @@ def filter_fault_update(self, operations: list[dict]): valid_updates.append(valid_update) logger.info( - f"[1224 Feedback Core: filter_fault_update] {len(updated_operations)} -> {len(valid_updates)}" + f"[0107 Feedback Core: filter_fault_update] {len(updated_operations)} -> {len(valid_updates)}" ) return valid_updates + [item for item in operations if item["operation"] != "UPDATE"] @@ -767,7 +788,7 @@ def correct_item(data): if not should_keep_update(data["text"], data["old_memory"]): logger.warning( - f"[1224 Feedback Core: correct_item] Due to the excessive proportion of changes, skip update: {data}" + f"[0107 Feedback Core: correct_item] Due to the excessive proportion of changes, skip update: {data}" ) return None @@ -787,14 +808,14 @@ 
def correct_item(data): return data except Exception: logger.error( - f"[1224 Feedback Core: standard_operations] Error processing operation item: {data}", + f"[0107 Feedback Core: standard_operations] Error processing operation item: {data}", exc_info=True, ) return None dehallu_res = [correct_item(item) for item in operations] dehalluded_operations = [item for item in dehallu_res if item] - logger.info(f"[1224 Feedback Core: dehalluded_operations] {dehalluded_operations}") + logger.info(f"[0107 Feedback Core: dehalluded_operations] {dehalluded_operations}") # c add objects add_texts = [] @@ -808,7 +829,7 @@ def correct_item(data): elif item["operation"].lower() == "update": llm_operations.append(item) logger.info( - f"[1224 Feedback Core: deduplicate add] {len(dehalluded_operations)} -> {len(llm_operations)} memories" + f"[0107 Feedback Core: deduplicate add] {len(dehalluded_operations)} -> {len(llm_operations)} memories" ) # Update takes precedence over add @@ -822,7 +843,7 @@ def correct_item(data): ] if filtered_items: logger.info( - f"[1224 Feedback Core: semantics_feedback] Due to have update objects, skip add: {filtered_items}" + f"[0107 Feedback Core: semantics_feedback] Due to have update objects, skip add: {filtered_items}" ) return update_items else: @@ -870,7 +891,7 @@ def _doc_filter(self, doc_scope: str, memories: list[TextualMemoryItem]): memid for inscope_file in inscope_docs for memid in filename2_memid[inscope_file] ] logger.info( - f"[1224 Feedback Core: process_keyword_replace] These docs are in scope : {inscope_docs}, relared memids: {inscope_ids}" + f"[0107 Feedback Core: process_keyword_replace] These docs are in scope : {inscope_docs}, relared memids: {inscope_ids}" ) filter_memories = [mem for mem in memories if mem.id in inscope_ids] return filter_memories @@ -924,7 +945,7 @@ def process_keyword_replace( retrieved_memories = self._doc_filter(doc_scope, retrieved_memories) logger.info( - f"[1224 Feedback Core: process_keyword_replace] 
Keywords recalled memory for user {user_name}: {len(retrieved_ids)} memories | After filtering: {len(retrieved_memories)} memories." + f"[0107 Feedback Core: process_keyword_replace] Keywords recalled memory for user {user_name}: {len(retrieved_ids)} memories | After filtering: {len(retrieved_memories)} memories." ) if not retrieved_memories: @@ -1009,7 +1030,7 @@ def check_validity(item): info.update({"user_id": user_id, "user_name": user_name, "session_id": session_id}) logger.info( - f"[1224 Feedback Core: process_feedback_core] Starting memory feedback process for user {user_name}" + f"[0107 Feedback Core: process_feedback_core] Starting memory feedback process for user {user_name}" ) # feedback keywords update kwp_judge = self._keyword_replace_judgement(feedback_content) @@ -1042,7 +1063,7 @@ def check_validity(item): if not valid_feedback: logger.warning( - f"[1224 Feedback Core: process_feedback_core] No valid judgements for user {user_name}: {raw_judge}." + f"[0107 Feedback Core: process_feedback_core] No valid judgements for user {user_name}: {raw_judge}." ) return {"record": {"add": [], "update": []}} @@ -1090,13 +1111,13 @@ def check_validity(item): add_memories = mem_record["record"]["add"] update_memories = mem_record["record"]["update"] logger.info( - f"[1224 Feedback Core: process_feedback_core] Processed {len(feedback_memories)} feedback | add {len(add_memories)} memories | update {len(update_memories)} memories for user {user_name}." + f"[0107 Feedback Core: process_feedback_core] Processed {len(feedback_memories)} feedback | add {len(add_memories)} memories | update {len(update_memories)} memories for user {user_name}." 
# strict=False lets string values contain raw control characters (for example a
# literal newline), which the default decoder rejects.
_LENIENT_JSON = json.JSONDecoder(strict=False)


def _top_level_json_values(text, opener):
    """Decode every complete JSON value in ``text`` that begins at an ``opener`` character.

    Scans left to right. A position that decodes contributes its value and the
    scan resumes after that value, so values nested inside it are not reported
    separately. A position that fails to decode is skipped by one character.
    """
    values = []
    pos = text.find(opener)
    while pos != -1:
        try:
            value, resume_at = _LENIENT_JSON.raw_decode(text, pos)
        except json.JSONDecodeError:
            resume_at = pos + 1
        else:
            values.append(value)
        pos = text.find(opener, resume_at)
    return values


def _extract_delimited_json(text, opener, closer, label):
    """Shared implementation behind the two public extractors.

    Strategies, in order:
      1. Parse the span from the first ``opener`` to the last ``closer`` as one
         JSON value, first as written and then with doubled braces
         (``{{`` / ``}}``) collapsed to single ones.
      2. Otherwise return the last complete top-level value found by scanning
         the same two variants.

    Raises:
        ValueError: if ``text`` is not a string, contains no opener/closer pair,
            or nothing inside it can be decoded.
    """
    if not isinstance(text, str):
        raise ValueError(f"No {label} content found in text: {text}")

    start = text.find(opener)
    end = text.rfind(closer)
    if start == -1 or end < start:
        raise ValueError(f"No {label} content found in text: {text}")

    span = text[start : end + 1]
    candidates = [span]
    dedoubled = span.replace("{{", "{").replace("}}", "}")
    if dedoubled != span:
        candidates.append(dedoubled)

    # Whole-span parses are tried on every variant before any salvage, so a
    # complete object is always preferred over one of its inner fragments.
    for candidate in candidates:
        try:
            return _LENIENT_JSON.decode(candidate)
        except json.JSONDecodeError:
            continue

    for candidate in candidates:
        values = _top_level_json_values(candidate, opener)
        if values:
            # When several values are present, the last one is returned.
            return values[-1]

    raise ValueError(f"Failed to parse JSON content from {label}s. Text preview: {text}")


def extract_bracket_content(text: str) -> dict:
    """
    Extract and parse JSON content enclosed in curly braces {} from text.

    Surrounding prose (even prose that itself contains braces) is tolerated,
    nested objects are returned whole, and doubled braces are collapsed when
    that makes the outermost span parse.

    Raises:
        ValueError: if no JSON object can be recovered from ``text``.
    """
    return _extract_delimited_json(text, "{", "}", "curly brace")


def extract_square_brackets_content(text: str) -> list:
    """
    Extract and parse JSON content enclosed in square brackets [] from text.

    Follows the same strategies as ``extract_bracket_content``, but looks for a
    JSON array instead of an object.

    Raises:
        ValueError: if no JSON array can be recovered from ``text``.
    """
    return _extract_delimited_json(text, "[", "]", "square bracket")