diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index f6a016556..ce75f6dc5 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -108,10 +108,10 @@ def _create_chunk_item(chunk):
             )
             return split_item

-        # Use thread pool to parallel process chunks
+        # Use thread pool to parallel process chunks, but keep the original order
         with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
             futures = [executor.submit(_create_chunk_item, chunk) for chunk in chunks]
-            for future in concurrent.futures.as_completed(futures):
+            for future in futures:
                 split_item = future.result()
                 if split_item is not None:
                     split_items.append(split_item)
@@ -146,26 +146,33 @@ def _concat_multi_modal_memories(
                 parallel_chunking = True

         if parallel_chunking:
-            # parallel chunk large memory items
+            # parallel chunk large memory items, but keep the original order
             with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
-                future_to_item = {
-                    executor.submit(self._split_large_memory_item, item, max_tokens): item
-                    for item in all_memory_items
-                    if (item.memory or "") and self._count_tokens(item.memory) > max_tokens
-                }
-                processed_items.extend(
-                    [
-                        item
-                        for item in all_memory_items
-                        if not (
-                            (item.memory or "") and self._count_tokens(item.memory) > max_tokens
-                        )
-                    ]
-                )
-                # collect split items from futures
-                for future in concurrent.futures.as_completed(future_to_item):
-                    split_items = future.result()
-                    processed_items.extend(split_items)
+                # Create a list to hold futures with their original index
+                futures = []
+                for idx, item in enumerate(all_memory_items):
+                    if (item.memory or "") and self._count_tokens(item.memory) > max_tokens:
+                        future = executor.submit(self._split_large_memory_item, item, max_tokens)
+                        futures.append(
+                            (idx, future, True)
+                        )  # True indicates this item needs splitting
+                    else:
+                        futures.append((idx, item, False))  # False indicates no splitting needed
+
+                # Process results in original order
+                temp_results = [None] * len(all_memory_items)
+                for idx, future_or_item, needs_splitting in futures:
+                    if needs_splitting:
+                        # Wait for the future to complete and get the split items
+                        split_items = future_or_item.result()
+                        temp_results[idx] = split_items
+                    else:
+                        # No splitting needed, use the original item
+                        temp_results[idx] = [future_or_item]
+
+                # Flatten the results while preserving order
+                for items in temp_results:
+                    processed_items.extend(items)
         else:
             # serial chunk large memory items
             for item in all_memory_items:
@@ -831,8 +838,9 @@ def _process_multi_modal_data(
         if isinstance(scene_data_info, list):
            # Parse each message in the list
            all_memory_items = []
-            # Use thread pool to parse each message in parallel
+            # Use thread pool to parse each message in parallel, but keep the original order
             with ContextThreadPoolExecutor(max_workers=30) as executor:
+                # submit tasks and keep the original order
                 futures = [
                     executor.submit(
                         self.multi_modal_parser.parse,
@@ -844,7 +852,8 @@ def _process_multi_modal_data(
                     )
                     for msg in scene_data_info
                 ]
-                for future in concurrent.futures.as_completed(futures):
+                # collect results in original order
+                for future in futures:
                     try:
                         items = future.result()
                         all_memory_items.extend(items)
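
Review note (not part of the patch): all three hunks swap `concurrent.futures.as_completed(...)` for direct iteration over the submitted futures list. Calling `result()` on each future in submission order blocks until that task finishes, so results are consumed in input order while the pool still runs tasks concurrently; `as_completed()` instead yields futures in whatever order they finish. A minimal, self-contained sketch of the difference (names here are illustrative, not from the repo):

```python
# Illustrative sketch only -- these names are not from the repo.
import concurrent.futures
import random
import time


def slow_square(n: int) -> int:
    # Random sleep so completion order differs from submission order.
    time.sleep(random.uniform(0.0, 0.1))
    return n * n


inputs = list(range(8))

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(slow_square, n) for n in inputs]
    # Iterating the list directly blocks on each future in turn, so results
    # come back in submission order even though tasks run concurrently.
    in_order = [f.result() for f in futures]

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(slow_square, n) for n in inputs]
    # as_completed() yields whichever future finishes first.
    by_completion = [f.result() for f in concurrent.futures.as_completed(futures)]

print(in_order)       # always [0, 1, 4, 9, 16, 25, 36, 49]
print(by_completion)  # some permutation of the above
```

Wall-clock time is the same either way; only the order in which results are consumed changes, which is exactly the property the patch needs.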
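The second hunk (`_concat_multi_modal_memories`) has the extra complication that only oversized items go through the pool, yet split and pass-through items must come back interleaved in their original positions. A minimal sketch of that index-slot pattern, using hypothetical stand-ins (`split_large` for `_split_large_memory_item`, a plain length check for `_count_tokens`):

```python
# Illustrative sketch of the index-slot pattern -- names are hypothetical.
import concurrent.futures


def split_large(text: str, max_len: int) -> list[str]:
    # Stand-in for _split_large_memory_item: naive fixed-width chunking.
    return [text[i : i + max_len] for i in range(0, len(text), max_len)]


def process_in_order(items: list[str], max_len: int) -> list[str]:
    slots = [None] * len(items)  # one slot per input item, keyed by index
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        pending = []  # (index, future) pairs for items that need splitting
        for idx, item in enumerate(items):
            if len(item) > max_len:
                pending.append((idx, executor.submit(split_large, item, max_len)))
            else:
                slots[idx] = [item]  # small items pass through unchanged
        for idx, future in pending:
            slots[idx] = future.result()
    # Flatten while preserving the original item order.
    return [chunk for chunks in slots for chunk in chunks]


print(process_in_order(["hi", "a" * 10, "ok"], max_len=4))
# -> ['hi', 'aaaa', 'aaaa', 'aa', 'ok']
```

Filling pass-through slots immediately, rather than carrying `(idx, item, False)` triples through the loop as the patch does, is an equivalent simplification: every slot is written exactly once before the flatten step.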