From 7a640603e9c6e51e93ba4bbe4e7a9f3b722dba12 Mon Sep 17 00:00:00 2001
From: Anh
Date: Tue, 3 Jun 2025 01:00:04 +1200
Subject: [PATCH 1/2] Support Online Tracking

---
 tracklab/engine/video.py                      |  53 +++++---
 .../visualization/visualization_engine.py     | 119 ++++++++++++++----
 2 files changed, 132 insertions(+), 40 deletions(-)

diff --git a/tracklab/engine/video.py b/tracklab/engine/video.py
index dc8bdad7..9ec80870 100644
--- a/tracklab/engine/video.py
+++ b/tracklab/engine/video.py
@@ -51,14 +51,14 @@ def track_dataset(self):
         self.callback("on_dataset_track_start")
         self.callback(
             "on_video_loop_start",
-            video_metadata=pd.Series(name=self.video_filename),
+            video_metadata=pd.Series({"name": self.video_filename}),
             video_idx=0,
             index=0,
         )
         detections = self.video_loop()
         self.callback(
             "on_video_loop_end",
-            video_metadata=pd.Series(name=self.video_filename),
+            video_metadata=pd.Series({"name": self.video_filename}),
             video_idx=0,
             detections=detections,
         )
@@ -81,6 +81,11 @@ def video_loop(self):
         # print('in offline.py, model_names: ', model_names)
         frame_idx = -1
         detections = pd.DataFrame()
+
+        # Initialize module callbacks at the start
+        for model_name in model_names:
+            self.callback("on_module_start", task=model_name, dataloader=[])
+
         while video_cap.isOpened():
             frame_idx += 1
             ret, frame = video_cap.read()
@@ -89,10 +94,13 @@ def video_loop(self):
             image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             if not ret:
                 break
-            metadata = pd.Series({"id": frame_idx, "frame": frame_idx,
+            base_metadata = pd.Series({"id": frame_idx, "frame": frame_idx,
                                   "video_id": video_filename}, name=frame_idx)
             self.callback("on_image_loop_start",
-                          image_metadata=metadata, image_idx=frame_idx, index=frame_idx)
+                          image_metadata=base_metadata, image_idx=frame_idx, index=frame_idx)
+
+            image_metadata = pd.DataFrame([base_metadata])
+
             for model_name in model_names:
                 model = self.models[model_name]
                 if len(detections) > 0:
@@ -102,49 +110,64 @@ def video_loop(self):
                 if model.level == "video":
                     raise "Video-level not supported for online video tracking"
                 elif model.level == "image":
-                    batch = model.preprocess(image=image, detections=dets, metadata=metadata)
+                    batch = model.preprocess(image=image, detections=dets, metadata=image_metadata.iloc[0])
                     batch = type(model).collate_fn([(frame_idx, batch)])
-                    detections = self.default_step(batch, model_name, detections, metadata)
+                    detections, image_metadata = self.default_step(batch, model_name, detections, image_metadata)
                 elif model.level == "detection":
                     for idx, detection in dets.iterrows():
-                        batch = model.preprocess(image=image, detection=detection, metadata=metadata)
+                        batch = model.preprocess(image=image, detection=detection, metadata=image_metadata.iloc[0])
                         batch = type(model).collate_fn([(detection.name, batch)])
-                        detections = self.default_step(batch, model_name, detections, metadata)
+                        detections, image_metadata = self.default_step(batch, model_name, detections, image_metadata)
+
             self.callback("on_image_loop_end",
-                          image_metadata=metadata, image=image,
+                          image_metadata=image_metadata.iloc[0], image=image,
                           image_idx=frame_idx, detections=detections)
+        # Finalize module callbacks at the end
+        for model_name in model_names:
+            self.callback("on_module_end", task=model_name, detections=detections)
+
         return detections
 
-    def default_step(self, batch: Any, task: str, detections: pd.DataFrame, metadata, **kwargs):
+    def default_step(self, batch: Any, task: str, detections: pd.DataFrame, image_pred: pd.DataFrame, **kwargs):
         model = self.models[task]
         self.callback(f"on_module_step_start", task=task, batch=batch)
         idxs, batch = batch
         idxs = idxs.cpu() if isinstance(idxs, torch.Tensor) else idxs
         if model.level == "image":
-            log.info(f"step : {idxs}")
-            batch_metadatas = pd.DataFrame([metadata])
+            log.info(f"step : {idxs} --- task : {task}")
+            batch_metadatas = image_pred.loc[list(idxs)]  # self.img_metadatas.loc[idxs]
             if len(detections) > 0:
                 batch_input_detections = detections.loc[
                     np.isin(detections.image_id, batch_metadatas.index)
                 ]
             else:
                 batch_input_detections = detections
+
             batch_detections = self.models[task].process(
                 batch, batch_input_detections, batch_metadatas)
         else:
             batch_detections = detections.loc[idxs]
+            if not image_pred.empty:
+                batch_metadatas = image_pred.loc[np.isin(image_pred.index, batch_detections.image_id)]
+            else:
+                batch_metadatas = image_pred
             batch_detections = self.models[task].process(
                 batch=batch,
                 detections=batch_detections,
-                metadatas=None,
+                metadatas=batch_metadatas,
                 **kwargs,
             )
+
+        if isinstance(batch_detections, tuple):
+            batch_detections, batch_metadatas = batch_detections
+            image_pred = merge_dataframes(image_pred, batch_metadatas)
+
         detections = merge_dataframes(detections, batch_detections)
+
         self.callback(
             f"on_module_step_end", task=task, batch=batch, detections=detections
         )
-        return detections
-
+        return detections, image_pred
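Note on the new `default_step()` contract: it now returns a `(detections, image_pred)` tuple, and a module's `process()` may itself return `(detections, metadatas)`, which the engine merges back into `image_pred` via `merge_dataframes()`. A minimal sketch of a module relying on that branch (the class name, the extra column, and everything else below are illustrative assumptions, not code from the tracklab repository):

    import pandas as pd

    class FieldLineDetector:
        # Hypothetical image-level module; `level` is the only attribute
        # that video_loop()/default_step() inspect here.
        level = "image"

        def process(self, batch, detections: pd.DataFrame, metadatas: pd.DataFrame):
            # Enrich the per-frame metadata with an extra column.
            metadatas = metadatas.copy()
            metadatas["lines"] = [{} for _ in range(len(metadatas))]
            # Returning a tuple triggers the isinstance(batch_detections, tuple)
            # branch in default_step(), merging `metadatas` into image_pred.
            return detections, metadatas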
diff --git a/tracklab/visualization/visualization_engine.py b/tracklab/visualization/visualization_engine.py
index 0b96a5b4..66efd4e2 100644
--- a/tracklab/visualization/visualization_engine.py
+++ b/tracklab/visualization/visualization_engine.py
@@ -6,6 +6,8 @@
 
 import cv2
 import pandas as pd
+import numpy as np
+import platform
 
 from tracklab.callbacks import Progressbar, Callback
 from tracklab.visualization import Visualizer
@@ -23,6 +25,7 @@ class VisualizationEngine(Callback):
             `draw_detection`.
         save_images: whether to save the visualization as image files (.jpeg)
         save_videos: whether to save the visualization as video files (.mp4)
+        show_online: whether to show online tracking in realtime (only works if the pipeline doesn't involve a VideoLevelModule)
         process_n_videos: number of videos to visualize. Will visualize the first N videos.
         process_n_frames_by_video: number of frames per video to visualize. Will visualize frames every N/n frames (not first n frames)
@@ -32,6 +35,7 @@ def __init__(self,
                  visualizers: Dict[str, Visualizer],
                  save_images: bool = False,
                  save_videos: bool = False,
+                 show_online: bool = False,
                  video_fps: int = 25,
                  process_n_videos: Optional[int] = None,
                  process_n_frames_by_video: Optional[int] = None,
@@ -41,13 +45,17 @@ def __init__(self,
         self.save_dir = Path("visualization")
         self.save_images = save_images
         self.save_videos = save_videos
+        self.show_online = show_online
         self.video_fps = video_fps
         self.max_videos = process_n_videos
         self.max_frames = process_n_frames_by_video
+        self.windows = []
         for visualizer in visualizers.values():
             visualizer.post_init(**kwargs)
 
     def on_dataset_track_end(self, engine: "TrackingEngine"):
+        if self.show_online:
+            cv2.destroyAllWindows()
         if self.save_videos or self.save_images:
             log.info(f"Visualization output at : {self.save_dir.absolute()}")
 
@@ -58,31 +66,92 @@ def on_video_loop_end(self, engine, video_metadata, video_idx, detections,
         self.visualize(engine.tracker_state, video_idx, detections, image_pred, progress)
         progress.on_module_end(None, "vis", None)
 
-        """
-        #TODO implement the online visualization
-        previous code:
-        if self.cfg.show_online:
-            tracker_state = engine.tracker_state
-            if tracker_state.detections_gt is not None:
-                ground_truths = tracker_state.detections_gt[
-                    tracker_state.detections_gt.image_id == image_metadata.name
-                ]
-            else:
-                ground_truths = None
-            if len(detections) == 0:
-                image = image
-            else:
-                detections = detections[detections.image_id == image_metadata.name]
-                image = self.draw_frame(image_metadata,
-                                        detections, ground_truths, "inf", image=image)
-            if platform.system() == "Linux" and self.video_name not in self.windows:
-                self.windows.append(self.video_name)
-                cv2.namedWindow(str(self.video_name),
-                                cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
-                cv2.resizeWindow(str(self.video_name), image.shape[1], image.shape[0])
-            cv2.imshow(str(self.video_name), image)
-            cv2.waitKey(1)
-        """
+    def on_image_loop_end(self, engine, image_metadata, image, image_idx, detections):
+        """
+        Handle real-time display during online video tracking.
+        """
+        if not self.show_online:
+            return
+
+        try:
+            # Filter detections for current frame
+            frame_detections = (
+                detections[detections.image_id == image_metadata.name]
+                if len(detections) > 0
+                else pd.DataFrame()
+            )
+
+            # Get ground truth (usually None for online tracking)
+            ground_truths = pd.DataFrame()
+
+            # Create dummy image metadata for compatibility
+            image_pred = pd.Series(
+                {
+                    "lines": getattr(image_metadata, "lines", {}),
+                    "keypoints": getattr(image_metadata, "keypoints", {}),
+                    "file_path": f"frame_{image_idx:06d}.jpg",  # Dummy path
+                },
+                name=image_metadata.name,
+            )
+
+            image_gt = pd.Series(
+                {
+                    "frame": image_idx,
+                    "nframes": -1,  # Unknown total frames in online mode
+                },
+                name=image_metadata.name,
+            )
+
+            # Draw frame with all visualizers
+            display_image = self.draw_online_frame(
+                image_metadata,
+                image,
+                frame_detections,
+                ground_truths,
+                image_pred,
+                image_gt,
+                nframes=-1,
+            )
+
+            # Display the image
+            video_name = str(engine.video_filename)
+            if platform.system() == "Linux" and video_name not in self.windows:
+                self.windows.append(video_name)
+                cv2.namedWindow(video_name, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
+                cv2.resizeWindow(
+                    video_name, display_image.shape[1], display_image.shape[0]
+                )
+
+            # Convert RGB to BGR for OpenCV display
+            cv2.imshow(video_name, cv2.cvtColor(display_image, cv2.COLOR_RGB2BGR))
+            cv2.waitKey(1)  # Non-blocking wait
+
+        except Exception as e:
+            log.warning(f"Error in online visualization: {e}")
+
+    def draw_online_frame(
+        self,
+        image_metadata,
+        image,
+        detections_pred,
+        detections_gt,
+        image_pred,
+        image_gt,
+        nframes,
+    ):
+        """Draw frame using all configured visualizers."""
+        # Create a copy of the image to avoid modifying the original
+        image = np.copy(image)
+
+        for visualizer in self.visualizers.values():
+            try:
+                visualizer.draw_frame(
+                    image, detections_pred, detections_gt, image_pred, image_gt
+                )
+            except Exception as e:
+                log.warning(f"Visualizer {type(visualizer).__name__} raised error: {e}")
+
+        return final_patch(image)
 
     def visualize(self, tracker_state: TrackerState, video_id, detections, image_preds, progress=None):
         image_metadatas = tracker_state.image_metadatas[tracker_state.image_metadatas.video_id == video_id]
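The online display stays off by default. A hedged example of enabling it, assuming the `visualization` config group maps directly onto `VisualizationEngine`'s constructor arguments (the group's layout is not part of this patch, so the target and keys below are assumptions):

    # hypothetical visualization config excerpt
    visualization:
      _target_: tracklab.visualization.VisualizationEngine
      show_online: true    # draw each frame in an OpenCV window as it is tracked
      save_images: false
      save_videos: false

As the docstring notes, this only works for pipelines without a video-level module: `video_loop()` raises on `model.level == "video"` before any frame is displayed.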
From 0c59d39177c71fc8f59b43059b13167d71b1c1b9 Mon Sep 17 00:00:00 2001
From: Anh
Date: Thu, 10 Jul 2025 01:03:34 +1200
Subject: [PATCH 2/2] respect `use_rich` parameter for video engine

---
 tracklab/configs/engine/video.yaml | 1 +
 1 file changed, 1 insertion(+)
diff --git a/tracklab/configs/engine/video.yaml b/tracklab/configs/engine/video.yaml
index b5affaef..1f8f0596 100644
--- a/tracklab/configs/engine/video.yaml
+++ b/tracklab/configs/engine/video.yaml
@@ -6,4 +6,5 @@ num_workers: ${num_cores}
 callbacks:
   progress:
     _target_: tracklab.callbacks.Progressbar
+    use_rich: ${use_rich}
   vis: ${visualization}
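With this change the engine config resolves as follows (reconstructed from the hunk above; the lines before `num_workers` sit outside the diff context and are omitted):

    num_workers: ${num_cores}
    callbacks:
      progress:
        _target_: tracklab.callbacks.Progressbar
        use_rich: ${use_rich}
      vis: ${visualization}

`${use_rich}` is an OmegaConf-style interpolation, so a top-level `use_rich=false` override now reaches the video engine's `Progressbar` instead of being silently ignored.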