Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tracklab/configs/engine/video.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ num_workers: ${num_cores}
callbacks:
progress:
_target_: tracklab.callbacks.Progressbar
use_rich: ${use_rich}
vis: ${visualization}
53 changes: 38 additions & 15 deletions tracklab/engine/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def track_dataset(self):
self.callback("on_dataset_track_start")
self.callback(
"on_video_loop_start",
video_metadata=pd.Series(name=self.video_filename),
video_metadata=pd.Series({"name": self.video_filename}),
video_idx=0,
index=0,
)
detections = self.video_loop()
self.callback(
"on_video_loop_end",
video_metadata=pd.Series(name=self.video_filename),
video_metadata=pd.Series({"name": self.video_filename}),
video_idx=0,
detections=detections,
)
Expand All @@ -81,6 +81,11 @@ def video_loop(self):
# print('in offline.py, model_names: ', model_names)
frame_idx = -1
detections = pd.DataFrame()

# Initialize module callbacks at the start
for model_name in model_names:
self.callback("on_module_start", task=model_name, dataloader=[])

while video_cap.isOpened():
frame_idx += 1
ret, frame = video_cap.read()
Expand All @@ -89,10 +94,13 @@ def video_loop(self):
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
if not ret:
break
metadata = pd.Series({"id": frame_idx, "frame": frame_idx,
base_metadata = pd.Series({"id": frame_idx, "frame": frame_idx,
"video_id": video_filename}, name=frame_idx)
self.callback("on_image_loop_start",
image_metadata=metadata, image_idx=frame_idx, index=frame_idx)
image_metadata=base_metadata, image_idx=frame_idx, index=frame_idx)

image_metadata = pd.DataFrame([base_metadata])

for model_name in model_names:
model = self.models[model_name]
if len(detections) > 0:
Expand All @@ -102,49 +110,64 @@ def video_loop(self):
if model.level == "video":
raise "Video-level not supported for online video tracking"
elif model.level == "image":
batch = model.preprocess(image=image, detections=dets, metadata=metadata)
batch = model.preprocess(image=image, detections=dets, metadata=image_metadata.iloc[0])
batch = type(model).collate_fn([(frame_idx, batch)])
detections = self.default_step(batch, model_name, detections, metadata)
detections, image_metadata = self.default_step(batch, model_name, detections, image_metadata)
elif model.level == "detection":
for idx, detection in dets.iterrows():
batch = model.preprocess(image=image, detection=detection, metadata=metadata)
batch = model.preprocess(image=image, detection=detection, metadata=image_metadata.iloc[0])
batch = type(model).collate_fn([(detection.name, batch)])
detections = self.default_step(batch, model_name, detections, metadata)
detections, image_metadata = self.default_step(batch, model_name, detections, image_metadata)

self.callback("on_image_loop_end",
image_metadata=metadata, image=image,
image_metadata=image_metadata.iloc[0], image=image,
image_idx=frame_idx, detections=detections)

# Finalize module callbacks at the end
for model_name in model_names:
self.callback("on_module_end", task=model_name, detections=detections)

return detections

def default_step(self, batch: Any, task: str, detections: pd.DataFrame, metadata, **kwargs):
def default_step(self, batch: Any, task: str, detections: pd.DataFrame, image_pred: pd.DataFrame, **kwargs):
model = self.models[task]
self.callback(f"on_module_step_start", task=task, batch=batch)
idxs, batch = batch
idxs = idxs.cpu() if isinstance(idxs, torch.Tensor) else idxs
if model.level == "image":
log.info(f"step : {idxs}")
batch_metadatas = pd.DataFrame([metadata])
log.info(f"step : {idxs} --- task : {task}")
batch_metadatas = image_pred.loc[list(idxs)] # self.img_metadatas.loc[idxs]
if len(detections) > 0:
batch_input_detections = detections.loc[
np.isin(detections.image_id, batch_metadatas.index)
]
else:
batch_input_detections = detections

batch_detections = self.models[task].process(
batch,
batch_input_detections,
batch_metadatas)
else:
batch_detections = detections.loc[idxs]
if not image_pred.empty:
batch_metadatas = image_pred.loc[np.isin(image_pred.index, batch_detections.image_id)]
else:
batch_metadatas = image_pred
batch_detections = self.models[task].process(
batch=batch,
detections=batch_detections,
metadatas=None,
metadatas=batch_metadatas,
**kwargs,
)

if isinstance(batch_detections, tuple):
batch_detections, batch_metadatas = batch_detections
image_pred = merge_dataframes(image_pred, batch_metadatas)

detections = merge_dataframes(detections, batch_detections)

self.callback(
f"on_module_step_end", task=task, batch=batch, detections=detections
)
return detections

return detections, image_pred
119 changes: 94 additions & 25 deletions tracklab/visualization/visualization_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import cv2
import pandas as pd
import numpy as np
import platform

from tracklab.callbacks import Progressbar, Callback
from tracklab.visualization import Visualizer
Expand All @@ -23,6 +25,7 @@ class VisualizationEngine(Callback):
`draw_detection`.
save_images: whether to save the visualization as image files (.jpeg)
save_videos: whether to save the visualization as video files (.mp4)
show_online: whether to show online tracking in realtime (only work if the pipeline doesn't involve VideoLevelModule)
process_n_videos: number of videos to visualize. Will visualize the first N videos.
process_n_frames_by_video: number of frames per video to visualize. Will visualize
frames every N/n frames (not first n frames)
Expand All @@ -32,6 +35,7 @@ def __init__(self,
visualizers: Dict[str, Visualizer],
save_images: bool = False,
save_videos: bool = False,
show_online: bool = False,
video_fps: int = 25,
process_n_videos: Optional[int] = None,
process_n_frames_by_video: Optional[int] = None,
Expand All @@ -41,13 +45,17 @@ def __init__(self,
self.save_dir = Path("visualization")
self.save_images = save_images
self.save_videos = save_videos
self.show_online = show_online
self.video_fps = video_fps
self.max_videos = process_n_videos
self.max_frames = process_n_frames_by_video
self.windows = []
for visualizer in visualizers.values():
visualizer.post_init(**kwargs)

def on_dataset_track_end(self, engine: "TrackingEngine"):
if self.show_online:
cv2.destroyAllWindows()
if self.save_videos or self.save_images:
log.info(f"Visualization output at : {self.save_dir.absolute()}")

Expand All @@ -58,31 +66,92 @@ def on_video_loop_end(self, engine, video_metadata, video_idx, detections,
self.visualize(engine.tracker_state, video_idx, detections, image_pred, progress)
progress.on_module_end(None, "vis", None)

"""
#TODO implement the online visualization
previous code:
if self.cfg.show_online:
tracker_state = engine.tracker_state
if tracker_state.detections_gt is not None:
ground_truths = tracker_state.detections_gt[
tracker_state.detections_gt.image_id == image_metadata.name
]
else:
ground_truths = None
if len(detections) == 0:
image = image
else:
detections = detections[detections.image_id == image_metadata.name]
image = self.draw_frame(image_metadata,
detections, ground_truths, "inf", image=image)
if platform.system() == "Linux" and self.video_name not in self.windows:
self.windows.append(self.video_name)
cv2.namedWindow(str(self.video_name),
cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
cv2.resizeWindow(str(self.video_name), image.shape[1], image.shape[0])
cv2.imshow(str(self.video_name), image)
cv2.waitKey(1)
"""
def on_image_loop_end(self, engine, image_metadata, image, image_idx, detections):
"""
Handle real-time display during online video tracking.
"""
if not self.show_online:
return

try:
# Filter detections for current frame
frame_detections = (
detections[detections.image_id == image_metadata.name]
if len(detections) > 0
else pd.DataFrame()
)

# Get ground truth (usually None for online tracking)
ground_truths = pd.DataFrame()

# Create dummy image metadata for compatibility
image_pred = pd.Series(
{
"lines": getattr(image_metadata, "lines", {}),
"keypoints": getattr(image_metadata, "keypoints", {}),
"file_path": f"frame_{image_idx:06d}.jpg", # Dummy path
},
name=image_metadata.name,
)

image_gt = pd.Series(
{
"frame": image_idx,
"nframes": -1, # Unknown total frames in online mode
},
name=image_metadata.name,
)

# Draw frame with all visualizers
display_image = self.draw_online_frame(
image_metadata,
image,
frame_detections,
ground_truths,
image_pred,
image_gt,
nframes=-1,
)

# Display the image
video_name = str(engine.video_filename)
if platform.system() == "Linux" and video_name not in self.windows:
self.windows.append(video_name)
cv2.namedWindow(video_name, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
cv2.resizeWindow(
video_name, display_image.shape[1], display_image.shape[0]
)

# Convert RGB to BGR for OpenCV display
cv2.imshow(video_name, display_image)
cv2.waitKey(1) # Non-blocking wait

except Exception as e:
log.warning(f"Error in online visualization: {e}")

def draw_online_frame(
self,
image_metadata,
image,
detections_pred,
detections_gt,
image_pred,
image_gt,
nframes,
):
"""Draw frame using all configured visualizers."""
# Create a copy of the image to avoid modifying the original
image = np.copy(image)

for visualizer in self.visualizers.values():
try:
visualizer.draw_frame(
image, detections_pred, detections_gt, image_pred, image_gt
)
except Exception as e:
log.warning(f"Visualizer {type(visualizer).__name__} raised error: {e}")

return final_patch(image)

def visualize(self, tracker_state: TrackerState, video_id, detections, image_preds, progress=None):
image_metadatas = tracker_state.image_metadatas[tracker_state.image_metadatas.video_id == video_id]
Expand Down