diff --git a/fastsam/predict.py b/fastsam/predict.py
index cc15128..727a57d 100644
--- a/fastsam/predict.py
+++ b/fastsam/predict.py
@@ -31,6 +31,7 @@ def postprocess(self, preds, img, orig_imgs):
         full_box = full_box.view(1, -1)
         critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
         if critical_iou_index.numel() != 0:
+            critical_iou_index = critical_iou_index[:1]
             full_box[0][4] = p[0][critical_iou_index][:,4]
             full_box[0][6:] = p[0][critical_iou_index][:,6:]
             p[0][critical_iou_index] = full_box
diff --git a/fastsam/utils.py b/fastsam/utils.py
index 33d37cd..7b550e5 100644
--- a/fastsam/utils.py
+++ b/fastsam/utils.py
@@ -44,7 +44,7 @@ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=Fals
         box1: (4, )
         boxes: (n, 4)
     Returns:
-        high_iou_indices: Indices of boxes with IoU > thres
+        high_iou_indices: Indices of boxes with IoU > thres, sorted by IoU in descending order
     '''
     boxes = adjust_bboxes_to_image_border(boxes, image_shape)
     # obtain coordinates for intersections
@@ -72,8 +72,9 @@ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=Fals
 
     # get indices of boxes with IoU > thres
     high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
+    sorted_high_iou_indices = high_iou_indices[torch.argsort(iou[high_iou_indices], descending=True)]
 
-    return high_iou_indices
+    return sorted_high_iou_indices
 
 
 def image_to_np_ndarray(image):
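
For reference, a small standalone sketch (not part of the patch) of the behaviour the two changes combine to produce: bbox_iou now returns the over-threshold indices ordered by IoU, descending, and postprocess keeps only the first (best-matching) index before the in-place assignments, which can otherwise hit a shape mismatch in the full_box assignments when several candidates clear the 0.9 threshold. The IoU values below are made up for illustration.

import torch

# Hypothetical IoUs of the predicted boxes against the full-image box.
iou = torch.tensor([0.95, 0.30, 0.97, 0.92])
iou_thres = 0.9

# Before the patch: indices of all boxes above the threshold, in positional order.
high_iou_indices = torch.nonzero(iou > iou_thres).flatten()   # tensor([0, 2, 3])

# After the patch: same indices, re-ordered so the highest IoU comes first.
sorted_high_iou_indices = high_iou_indices[torch.argsort(iou[high_iou_indices], descending=True)]  # tensor([2, 0, 3])

# predict.py then keeps only the single best match for the full_box assignments.
critical_iou_index = sorted_high_iou_indices[:1]              # tensor([2])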