diff --git a/src/model_api/models/__init__.py b/src/model_api/models/__init__.py index 2df1af31..414754b2 100644 --- a/src/model_api/models/__init__.py +++ b/src/model_api/models/__init__.py @@ -34,7 +34,7 @@ get_contours, ) from .visual_prompting import Prompt, SAMLearnableVisualPrompter, SAMVisualPrompter -from .yolo import YOLO, YOLOF, YOLOX, YoloV3ONNX, YoloV4, YOLOv5, YOLOv8 +from .yolo import YOLO, YOLO11, YOLOF, YOLOX, YoloV3ONNX, YoloV4, YOLOv5, YOLOv8 classification_models = [ "resnet-18-pytorch", @@ -92,6 +92,7 @@ "TopDownKeypointDetectionPipeline", "VisualPromptingResult", "YOLO", + "YOLO11", "YOLOF", "YOLOv3ONNX", "YOLOv4", diff --git a/src/model_api/models/yolo.py b/src/model_api/models/yolo.py index b232cb1b..f42b24f9 100644 --- a/src/model_api/models/yolo.py +++ b/src/model_api/models/yolo.py @@ -746,7 +746,7 @@ def __init__(self, inference_adapter, configuration, preload=False): out_shape = output.shape if len(out_shape) != 3: self.raise_error("the output must be of rank 3") - if self.labels and len(self.labels) + 4 != out_shape[1]: + if self.params.labels and len(self.params.labels) + 4 != out_shape[1]: self.raise_error("number of labels must be smaller than out_shape[1] by 4") @classmethod @@ -799,7 +799,7 @@ def postprocess(self, outputs, meta) -> DetectionResult: ) keep_top_k = 30000 iou_threshold = self.params.iou_threshold - if self.agnostic_nms: # type: ignore[attr-defined] + if self.params.agnostic_nms: boxes = boxes[ nms( boxes[:, 2], @@ -807,12 +807,12 @@ def postprocess(self, outputs, meta) -> DetectionResult: boxes[:, 4], boxes[:, 5], boxes[:, 1], - iou_threshold, # type: ignore[attr-defined] + iou_threshold, keep_top_k=keep_top_k, ) ] else: - boxes, _ = multiclass_nms(boxes, iou_threshold, keep_top_k) # type: ignore[attr-defined] + boxes, _ = multiclass_nms(boxes, iou_threshold, keep_top_k) inputImgWidth = meta["original_shape"][1] inputImgHeight = meta["original_shape"][0] resize_meta = ResizeMetadata.compute( @@ -853,3 +853,9 @@ class YOLOv8(YOLOv5): """YOLOv5 and YOLOv8 are identical in terms of inference""" __model__ = "YOLOv8" + + +class YOLO11(YOLOv5): + """YOLO11 uses the same inference approach as YOLOv5 and YOLOv8""" + + __model__ = "YOLO11"