Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/model_api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
get_contours,
)
from .visual_prompting import Prompt, SAMLearnableVisualPrompter, SAMVisualPrompter
from .yolo import YOLO, YOLOF, YOLOX, YoloV3ONNX, YoloV4, YOLOv5, YOLOv8
from .yolo import YOLO, YOLO11, YOLOF, YOLOX, YoloV3ONNX, YoloV4, YOLOv5, YOLOv8

classification_models = [
"resnet-18-pytorch",
Expand Down Expand Up @@ -92,6 +92,7 @@
"TopDownKeypointDetectionPipeline",
"VisualPromptingResult",
"YOLO",
"YOLO11",
"YOLOF",
"YOLOv3ONNX",
"YOLOv4",
Expand Down
14 changes: 10 additions & 4 deletions src/model_api/models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def __init__(self, inference_adapter, configuration, preload=False):
out_shape = output.shape
if len(out_shape) != 3:
self.raise_error("the output must be of rank 3")
if self.labels and len(self.labels) + 4 != out_shape[1]:
if self.params.labels and len(self.params.labels) + 4 != out_shape[1]:
self.raise_error("number of labels must be smaller than out_shape[1] by 4")

@classmethod
Expand Down Expand Up @@ -799,20 +799,20 @@ def postprocess(self, outputs, meta) -> DetectionResult:
)
keep_top_k = 30000
iou_threshold = self.params.iou_threshold
if self.agnostic_nms: # type: ignore[attr-defined]
if self.params.agnostic_nms:
boxes = boxes[
nms(
boxes[:, 2],
boxes[:, 3],
boxes[:, 4],
boxes[:, 5],
boxes[:, 1],
iou_threshold, # type: ignore[attr-defined]
iou_threshold,
keep_top_k=keep_top_k,
)
]
else:
boxes, _ = multiclass_nms(boxes, iou_threshold, keep_top_k) # type: ignore[attr-defined]
boxes, _ = multiclass_nms(boxes, iou_threshold, keep_top_k)
inputImgWidth = meta["original_shape"][1]
inputImgHeight = meta["original_shape"][0]
resize_meta = ResizeMetadata.compute(
Expand Down Expand Up @@ -853,3 +853,9 @@ class YOLOv8(YOLOv5):
"""YOLOv5 and YOLOv8 are identical in terms of inference"""

__model__ = "YOLOv8"


class YOLO11(YOLOv5):
"""YOLO11 uses the same inference approach as YOLOv5 and YOLOv8"""

__model__ = "YOLO11"
Loading