src/memos/mem_reader/multi_modal_struct.py (32 additions, 23 deletions)
@@ -108,10 +108,10 @@ def _create_chunk_item(chunk):
             )
             return split_item
 
-        # Use thread pool to parallel process chunks
+        # Use thread pool to parallel process chunks, but keep the original order
         with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
             futures = [executor.submit(_create_chunk_item, chunk) for chunk in chunks]
-            for future in concurrent.futures.as_completed(futures):
+            for future in futures:
                 split_item = future.result()
                 if split_item is not None:
                     split_items.append(split_item)
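
Note on this hunk: both loops wait for every result, but `as_completed` yields futures in completion order, while iterating the submission list blocks on each future in turn and so returns results in input order. A minimal standalone sketch (not part of the diff; `slow_square` is a made-up task) illustrating the difference:

import concurrent.futures
import time

def slow_square(x):
    # Later inputs sleep less, so completion order is the reverse of input order.
    time.sleep(0.3 - 0.1 * x)
    return x * x

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(slow_square, x) for x in range(3)]
    # Completion order: the fastest task tends to come back first.
    print([f.result() for f in concurrent.futures.as_completed(futures)])  # likely [4, 1, 0]

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(slow_square, x) for x in range(3)]
    # Submission order: result() blocks until each future finishes,
    # so output order always matches input order.
    print([f.result() for f in futures])  # [0, 1, 4]

Total wall time is the same either way here, since the loop needs every result before it can finish; the in-order loop only changes which future the main thread blocks on first.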
@@ -146,26 +146,33 @@ def _concat_multi_modal_memories(
         parallel_chunking = True
 
         if parallel_chunking:
-            # parallel chunk large memory items
+            # parallel chunk large memory items, but keep the original order
             with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
-                future_to_item = {
-                    executor.submit(self._split_large_memory_item, item, max_tokens): item
-                    for item in all_memory_items
-                    if (item.memory or "") and self._count_tokens(item.memory) > max_tokens
-                }
-                processed_items.extend(
-                    [
-                        item
-                        for item in all_memory_items
-                        if not (
-                            (item.memory or "") and self._count_tokens(item.memory) > max_tokens
-                        )
-                    ]
-                )
-                # collect split items from futures
-                for future in concurrent.futures.as_completed(future_to_item):
-                    split_items = future.result()
-                    processed_items.extend(split_items)
+                # Create a list to hold futures with their original index
+                futures = []
+                for idx, item in enumerate(all_memory_items):
+                    if (item.memory or "") and self._count_tokens(item.memory) > max_tokens:
+                        future = executor.submit(self._split_large_memory_item, item, max_tokens)
+                        futures.append(
+                            (idx, future, True)
+                        )  # True indicates this item needs splitting
+                    else:
+                        futures.append((idx, item, False))  # False indicates no splitting needed
+
+                # Process results in original order
+                temp_results = [None] * len(all_memory_items)
+                for idx, future_or_item, needs_splitting in futures:
+                    if needs_splitting:
+                        # Wait for the future to complete and get the split items
+                        split_items = future_or_item.result()
+                        temp_results[idx] = split_items
+                    else:
+                        # No splitting needed, use the original item
+                        temp_results[idx] = [future_or_item]
+
+                # Flatten the results while preserving order
+                for items in temp_results:
+                    processed_items.extend(items)
         else:
             # serial chunk large memory items
             for item in all_memory_items:
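
Note on this hunk: the old code collected untouched items first and appended split results in completion order, which interleaved the two groups unpredictably. The new code submits only oversized items to the pool and records an (index, future-or-item, needs_splitting) triple for every item, so each result can be written back to its original slot. A standalone sketch of the same pattern, with hypothetical needs_split and split helpers standing in for the PR's token-count check and _split_large_memory_item:

import concurrent.futures

def needs_split(item: str) -> bool:
    # Hypothetical stand-in for the PR's token-count check.
    return len(item) > 5

def split(item: str) -> list[str]:
    # Hypothetical stand-in for _split_large_memory_item.
    return [item[i:i + 5] for i in range(0, len(item), 5)]

def process_in_order(items: list[str]) -> list[str]:
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        slots = []  # (original index, future or untouched item, submitted?)
        for idx, item in enumerate(items):
            if needs_split(item):
                slots.append((idx, executor.submit(split, item), True))
            else:
                slots.append((idx, item, False))

        # Write each result back to its original index, then flatten.
        results = [None] * len(items)
        for idx, future_or_item, submitted in slots:
            results[idx] = future_or_item.result() if submitted else [future_or_item]
    return [piece for group in results for piece in group]

print(process_in_order(["tiny", "a much longer item", "ok"]))
# ['tiny', 'a muc', 'h lon', 'ger i', 'tem', 'ok']

Since the triples are appended in index order, the write-back buffer is strictly redundant here; the PR keeps it (as temp_results) so the flattening step stays correct even if the collection order ever changes.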
@@ -831,8 +838,9 @@ def _process_multi_modal_data(
         if isinstance(scene_data_info, list):
             # Parse each message in the list
             all_memory_items = []
-            # Use thread pool to parse each message in parallel
+            # Use thread pool to parse each message in parallel, but keep the original order
             with ContextThreadPoolExecutor(max_workers=30) as executor:
+                # submit tasks and keep the original order
                 futures = [
                     executor.submit(
                         self.multi_modal_parser.parse,
@@ -844,7 +852,8 @@
                     )
                     for msg in scene_data_info
                 ]
-                for future in concurrent.futures.as_completed(futures):
+                # collect results in original order
+                for future in futures:
                     try:
                         items = future.result()
                         all_memory_items.extend(items)
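
Note on this hunk: it applies the same in-order collection to message parsing, with a try around each future.result() so one failed parse does not abort the rest. A minimal sketch of that shape (the parse function and messages are invented; the real code calls self.multi_modal_parser.parse inside a ContextThreadPoolExecutor):

import concurrent.futures
import logging

def parse(msg: str) -> list[str]:
    # Invented parser standing in for multi_modal_parser.parse.
    if msg == "bad":
        raise ValueError(f"cannot parse {msg!r}")
    return [msg.upper()]

messages = ["hello", "bad", "world"]
all_items: list[str] = []

with concurrent.futures.ThreadPoolExecutor(max_workers=30) as executor:
    futures = [executor.submit(parse, msg) for msg in messages]
    for future in futures:  # submission order, so output order matches input
        try:
            all_items.extend(future.result())
        except Exception:
            logging.exception("message parse failed; skipping")

print(all_items)  # ['HELLO', 'WORLD'] -- order preserved, one failure skipped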