Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/memos/mem_reader/multi_modal_struct.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import concurrent.futures
import json
import re
import time
import traceback

from typing import Any
Expand Down Expand Up @@ -788,6 +789,7 @@ def _process_multi_modal_data(
better understanding via calling llm
**kwargs: Additional parameters (mode, etc.)
"""
init_time = time.time()
# Pop custom_tags from info (same as simple_struct.py)
# must pop here, avoid add to info, only used in sync fine mode
custom_tags = info.pop("custom_tags", None) if isinstance(info, dict) else None
Expand All @@ -798,14 +800,21 @@ def _process_multi_modal_data(
# Parse each message in the list
all_memory_items = []
for msg in scene_data_info:
items = self.multi_modal_parser.parse(msg, info, mode="fast", **kwargs)
items = self.multi_modal_parser.parse(
msg, info, mode="fast", is_need_emb=False, **kwargs
)
all_memory_items.extend(items)
else:
# Parse as single message
all_memory_items = self.multi_modal_parser.parse(
scene_data_info, info, mode="fast", **kwargs
scene_data_info, info, mode="fast", is_need_emb=False, **kwargs
)

print(f"time for multi_modal_parser.parse: {time.time() - init_time}")
init_time = time.time()

fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
print(f"time for _concat_multi_modal_memories: {time.time() - init_time}")
if mode == "fast":
return fast_memory_items
else:
Expand Down
3 changes: 2 additions & 1 deletion src/memos/mem_reader/read_multi_modal/user_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def parse_fast(
info: dict[str, Any],
**kwargs,
) -> list[TextualMemoryItem]:
is_need_emb = kwargs.get("is_need_emb", True)
if not isinstance(message, dict):
logger.warning(f"[UserParser] Expected dict, got {type(message)}")
return []
Expand Down Expand Up @@ -192,7 +193,7 @@ def parse_fast(
status="activated",
tags=["mode:fast"],
key=_derive_key(line),
embedding=self.embedder.embed([line])[0],
embedding=self.embedder.embed([line])[0] if is_need_emb else None,
usage=[],
sources=sources,
background="",
Expand Down
10 changes: 10 additions & 0 deletions src/memos/multi_mem_cube/single_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import os
import time
import traceback

from dataclasses import dataclass
Expand Down Expand Up @@ -772,6 +773,8 @@ def _process_text_mem(
Returns:
List of formatted memory responses
"""
init_time = time.time()
print(f"init time: {init_time}")
target_session_id = add_req.session_id or "default_session"

# Decide extraction mode:
Expand Down Expand Up @@ -806,6 +809,12 @@ def _process_text_mem(
)
flattened_local = [mm for m in memories_local for mm in m]

print(
f"Time for get_memory: {time.time() - init_time}, "
f"total memories: {len(flattened_local)}"
)
init_time = time.time()

# Explicitly set source_doc_id to metadata if present in info
source_doc_id = (add_req.info or {}).get("source_doc_id")
if source_doc_id:
Expand All @@ -819,6 +828,7 @@ def _process_text_mem(
flattened_local,
user_name=user_context.mem_cube_id,
)
print(f"time for add: {time.time() - init_time}")
self.logger.info(
f"Added {len(mem_ids_local)} memories for user {add_req.user_id} "
f"in session {add_req.session_id}: {mem_ids_local}"
Expand Down