diff --git a/tracklab/engine/offline.py b/tracklab/engine/offline.py index 6214adf2..77325763 100644 --- a/tracklab/engine/offline.py +++ b/tracklab/engine/offline.py @@ -1,6 +1,7 @@ import logging from tracklab.engine import TrackingEngine +from tracklab.engine.engine import merge_dataframes from tracklab.utils.cv2 import cv2_load_image log = logging.getLogger(__name__) @@ -19,7 +20,11 @@ def video_loop(self, tracker_state, video, video_id): model_names = self.module_names for model_name in model_names: if self.models[model_name].level == "video": - detections = self.models[model_name].process(detections, image_pred) + batch_detections = self.models[model_name].process(detections, image_pred) + if isinstance(batch_detections, tuple): + batch_detections, batch_metadata = batch_detections + image_pred = merge_dataframes(image_pred, batch_metadata) + detections = merge_dataframes(detections, batch_detections) continue self.datapipes[model_name].update(image_filepaths, image_pred, detections) self.callback(