Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 217 additions & 64 deletions src/memos/mem_reader/multi_modal_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader
from memos.mem_reader.utils import parse_json_result
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
from memos.templates.mem_reader_prompts import MEMORY_MERGE_PROMPT_EN, MEMORY_MERGE_PROMPT_ZH
from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
from memos.types import MessagesType
from memos.utils import timed
Expand Down Expand Up @@ -316,7 +317,6 @@ def _get_llm_response(
custom_tags: list[str] | None = None,
sources: list | None = None,
prompt_type: str = "chat",
related_memories: str | None = None,
) -> dict:
"""
Override parent method to improve language detection by using actual text content
Expand All @@ -327,7 +327,6 @@ def _get_llm_response(
custom_tags: Optional custom tags
sources: Optional list of SourceMessage objects to extract text content from
prompt_type: Type of prompt to use ("chat" or "doc")
related_memories: related_memories in the graph

Returns:
LLM response dictionary
Expand Down Expand Up @@ -362,10 +361,7 @@ def _get_llm_response(
else:
template = PROMPT_DICT["chat"][lang]
examples = PROMPT_DICT["chat"][f"{lang}_example"]
related_memories_str = related_memories if related_memories is not None else ""
prompt = template.replace("${conversation}", mem_str).replace(
"${reference}", related_memories_str
)
prompt = template.replace("${conversation}", mem_str)

custom_tags_prompt = (
PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags))
Expand Down Expand Up @@ -398,6 +394,7 @@ def _get_llm_response(
],
"summary": mem_str,
}
logger.info(f"[MultiModalFine] Task {messages}, Result {response_json}")
return response_json

def _determine_prompt_type(self, sources: list) -> str:
Expand All @@ -418,6 +415,182 @@ def _determine_prompt_type(self, sources: list) -> str:

return prompt_type

def _get_maybe_merged_memory(
self,
extracted_memory_dict: dict,
mem_text: str,
sources: list,
**kwargs,
) -> dict:
"""
Check if extracted memory should be merged with similar existing memories.
If merge is needed, return merged memory dict with merged_from field.
Otherwise, return original memory dict.

Args:
extracted_memory_dict: The extracted memory dict from LLM response
mem_text: The memory text content
sources: Source messages for language detection
**kwargs: Additional parameters (merge_similarity_threshold, etc.)

Returns:
Memory dict (possibly merged) with merged_from field if merged
"""
# If no graph_db or user_name, return original
if not self.graph_db or "user_name" not in kwargs:
return extracted_memory_dict
user_name = kwargs.get("user_name")

# Detect language
lang = "en"
if sources:
for source in sources:
if hasattr(source, "lang") and source.lang:
lang = source.lang
break
elif isinstance(source, dict) and source.get("lang"):
lang = source.get("lang")
break
if lang is None:
lang = detect_lang(mem_text)

# Search for similar memories
merge_threshold = kwargs.get("merge_similarity_threshold", 0.3)

try:
search_results = self.graph_db.search_by_embedding(
vector=self.embedder.embed(mem_text)[0],
top_k=20,
status="activated",
threshold=merge_threshold,
user_name=user_name,
filter={
"or": [
{"memory_type": "LongTermMemory"},
{"memory_type": "UserMemory"},
{"memory_type": "WorkingMemory"},
]
},
)

if not search_results:
# No similar memories found, return original
return extracted_memory_dict

# Get full memory details
similar_memory_ids = [r["id"] for r in search_results if r.get("id")]
similar_memories_list = [
self.graph_db.get_node(mem_id, include_embedding=False)
for mem_id in similar_memory_ids
]

# Filter out None and mode:fast memories
filtered_similar = []
for mem in similar_memories_list:
if not mem:
continue
mem_metadata = mem.get("metadata", {})
tags = mem_metadata.get("tags", [])
if isinstance(tags, list) and "mode:fast" in tags:
continue
filtered_similar.append(
{
"id": mem.get("id"),
"memory": mem.get("memory", ""),
}
)
logger.info(
f"Valid similar memories for {mem_text} is "
f"{len(filtered_similar)}: {filtered_similar}"
)

if not filtered_similar:
# No valid similar memories, return original
return extracted_memory_dict

# Create a temporary TextualMemoryItem for merge check
temp_memory_item = TextualMemoryItem(
memory=mem_text,
metadata=TreeNodeTextualMemoryMetadata(
user_id="",
session_id="",
memory_type=extracted_memory_dict.get("memory_type", "LongTermMemory"),
status="activated",
tags=extracted_memory_dict.get("tags", []),
key=extracted_memory_dict.get("key", ""),
),
)

# Try to merge with LLM
merge_result = self._merge_memories_with_llm(
temp_memory_item, filtered_similar, lang=lang
)

if merge_result:
# Return merged memory dict
merged_dict = extracted_memory_dict.copy()
merged_dict["value"] = merge_result.get("value", mem_text)
merged_dict["merged_from"] = merge_result.get("merged_from", [])
logger.info(
f"[MultiModalFine] Merged memory with {len(merged_dict['merged_from'])} existing memories"
)
return merged_dict
else:
# No merge needed, return original
return extracted_memory_dict

except Exception as e:
logger.error(f"[MultiModalFine] Error in get_maybe_merged_memory: {e}")
# On error, return original
return extracted_memory_dict

def _merge_memories_with_llm(
self,
new_memory: TextualMemoryItem,
similar_memories: list[dict],
lang: str = "en",
) -> dict | None:
"""
Use LLM to merge new memory with similar existing memories.

Args:
new_memory: The newly extracted memory item
similar_memories: List of similar memories from graph_db (with id and memory fields)
lang: Language code ("en" or "zh")

Returns:
Merged memory dict with merged_from field, or None if no merge needed
"""
if not similar_memories:
return None

# Build merge prompt using template
similar_memories_text = "\n".join(
[f"[{mem['id']}]: {mem['memory']}" for mem in similar_memories]
)

merge_prompt_template = MEMORY_MERGE_PROMPT_ZH if lang == "zh" else MEMORY_MERGE_PROMPT_EN
merge_prompt = merge_prompt_template.format(
new_memory=new_memory.memory,
similar_memories=similar_memories_text,
)

try:
response_text = self.llm.generate([{"role": "user", "content": merge_prompt}])
merge_result = parse_json_result(response_text)

if merge_result.get("should_merge", False):
return {
"value": merge_result.get("value", new_memory.memory),
"merged_from": merge_result.get(
"merged_from", [mem["id"] for mem in similar_memories]
),
}
except Exception as e:
logger.error(f"[MultiModalFine] Error in merge LLM call: {e}")

return None

def _process_string_fine(
self,
fast_memory_items: list[TextualMemoryItem],
Expand Down Expand Up @@ -460,90 +633,69 @@ def _process_one_item(fast_item: TextualMemoryItem) -> list[TextualMemoryItem]:
# Determine prompt type based on sources
prompt_type = self._determine_prompt_type(sources)

# recall related memories
related_memories = None
memory_ids = []
if self.graph_db:
if "user_name" in kwargs:
memory_ids = self.graph_db.search_by_embedding(
vector=self.embedder.embed(mem_str)[0],
top_k=20,
status="activated",
user_name=kwargs.get("user_name"),
filter={
"or": [
{"memory_type": "LongTermMemory"},
{"memory_type": "UserMemory"},
{"memory_type": "WorkingMemory"},
]
},
)
memory_ids = set({r["id"] for r in memory_ids if r.get("id")})
related_memories_list = self.graph_db.get_nodes(
list(memory_ids),
include_embedding=False,
user_name=kwargs.get("user_name"),
)
related_memories = "\n".join(
["{}: {}".format(mem["id"], mem["memory"]) for mem in related_memories_list]
)
else:
logger.warning("user_name is null when graph_db exists")

# ========== Stage 1: Normal extraction (without reference) ==========
try:
resp = self._get_llm_response(
mem_str, custom_tags, sources, prompt_type, related_memories
)
resp = self._get_llm_response(mem_str, custom_tags, sources, prompt_type)
except Exception as e:
logger.error(f"[MultiModalFine] Error calling LLM: {e}")
return fine_items

if resp.get("memory list", []):
for m in resp.get("memory list", []):
try:
# Check and merge with similar memories if needed
m_maybe_merged = self._get_maybe_merged_memory(
extracted_memory_dict=m,
mem_text=m.get("value", ""),
sources=sources,
**kwargs,
)
# Normalize memory_type (same as simple_struct)
memory_type = (
m.get("memory_type", "LongTermMemory")
m_maybe_merged.get("memory_type", "LongTermMemory")
.replace("长期记忆", "LongTermMemory")
.replace("用户记忆", "UserMemory")
)
if "merged_from" in m:
for merged_id in m["merged_from"]:
if merged_id not in memory_ids:
logger.warning("merged id not valid!!!!!")
info_per_item["merged_from"] = m["merged_from"]
# Create fine mode memory item (same as simple_struct)
node = self._make_memory_item(
value=m.get("value", ""),
value=m_maybe_merged.get("value", ""),
info=info_per_item,
memory_type=memory_type,
tags=m.get("tags", []),
key=m.get("key", ""),
tags=m_maybe_merged.get("tags", []),
key=m_maybe_merged.get("key", ""),
sources=sources, # Preserve sources from fast item
background=resp.get("summary", ""),
**extra_kwargs,
)
# Add merged_from to info if present
if "merged_from" in m_maybe_merged:
node.metadata.info = node.metadata.info or {}
node.metadata.info["merged_from"] = m_maybe_merged["merged_from"]
fine_items.append(node)
except Exception as e:
logger.error(f"[MultiModalFine] parse error: {e}")
elif resp.get("value") and resp.get("key"):
try:
if "merged_from" in resp:
for merged_id in resp["merged_from"]:
if merged_id not in memory_ids:
logger.warning("merged id not valid!!!!!")
info_per_item["merged_from"] = resp["merged_from"]
# Create fine mode memory item (same as simple_struct)
# Check and merge with similar memories if needed
resp_maybe_merged = self._get_maybe_merged_memory(
extracted_memory_dict=resp,
mem_text=resp.get("value", "").strip(),
sources=sources,
**kwargs,
)
node = self._make_memory_item(
value=resp.get("value", "").strip(),
value=resp_maybe_merged.get("value", "").strip(),
info=info_per_item,
memory_type="LongTermMemory",
tags=resp.get("tags", []),
key=resp.get("key", None),
tags=resp_maybe_merged.get("tags", []),
key=resp_maybe_merged.get("key", None),
sources=sources, # Preserve sources from fast item
background=resp.get("summary", ""),
**extra_kwargs,
)
# Add merged_from to info if present
if "merged_from" in resp_maybe_merged:
node.metadata.info = node.metadata.info or {}
node.metadata.info["merged_from"] = resp_maybe_merged["merged_from"]
fine_items.append(node)
except Exception as e:
logger.error(f"[MultiModalFine] parse error: {e}")
Expand Down Expand Up @@ -694,9 +846,7 @@ def _process_multi_modal_data(

@timed
def _process_transfer_multi_modal_data(
self,
raw_node: TextualMemoryItem,
custom_tags: list[str] | None = None,
self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None, **kwargs
) -> list[TextualMemoryItem]:
"""
Process transfer for multimodal data.
Expand All @@ -720,9 +870,11 @@ def _process_transfer_multi_modal_data(
# Part A: call llm in parallel using thread pool
with ContextThreadPoolExecutor(max_workers=2) as executor:
future_string = executor.submit(
self._process_string_fine, [raw_node], info, custom_tags
self._process_string_fine, [raw_node], info, custom_tags, **kwargs
)
future_tool = executor.submit(
self._process_tool_trajectory_fine, [raw_node], info, **kwargs
)
future_tool = executor.submit(self._process_tool_trajectory_fine, [raw_node], info)

# Collect results
fine_memory_items_string_parser = future_string.result()
Expand Down Expand Up @@ -789,6 +941,7 @@ def fine_transfer_simple_mem(
input_memories: list[TextualMemoryItem],
type: str,
custom_tags: list[str] | None = None,
**kwargs,
) -> list[list[TextualMemoryItem]]:
if not input_memories:
return []
Expand All @@ -799,7 +952,7 @@ def fine_transfer_simple_mem(
with ContextThreadPoolExecutor() as executor:
futures = [
executor.submit(
self._process_transfer_multi_modal_data, scene_data_info, custom_tags
self._process_transfer_multi_modal_data, scene_data_info, custom_tags, **kwargs
)
for scene_data_info in input_memories
]
Expand Down
Loading