diff --git a/tracklab/wrappers/bbox_detector/transformers_api.py b/tracklab/wrappers/bbox_detector/transformers_api.py index fcd79892..42fa0965 100644 --- a/tracklab/wrappers/bbox_detector/transformers_api.py +++ b/tracklab/wrappers/bbox_detector/transformers_api.py @@ -29,7 +29,10 @@ def preprocess(self, image, detections: pd.DataFrame, metadata: pd.Series) -> An @torch.no_grad() def process(self, batch: Any, detections: pd.DataFrame, metadatas: pd.DataFrame): - images = self.image_processor(batch, return_tensors="pt") + images = self.image_processor(batch, return_tensors="pt") # TODO: this should happen in the preprocess step + for key in images.keys(): + images[key] = images[key].to(self.device) + outputs = self.model(**images) results = self.image_processor.post_process_object_detection( outputs, target_sizes=[batch.shape[1:3]]*batch.shape[0], threshold=self.min_confidence @@ -42,7 +45,7 @@ def process(self, batch: Any, detections: pd.DataFrame, metadatas: pd.DataFrame) pd.Series( dict( image_id=metadatas["id"].values[i], - bbox_ltwh=ltrb_to_ltwh(box.numpy(), (batch.shape[2], batch.shape[1])), + bbox_ltwh=ltrb_to_ltwh(box.cpu().numpy(), (batch.shape[2], batch.shape[1])), bbox_conf=score.item(), video_id=metadatas["video_id"].values[i], category_id=1,