From d940e20703ca545d3f09de4d4508df5684140ffd Mon Sep 17 00:00:00 2001
From: Lukas Hennies
Date: Tue, 11 Feb 2025 11:59:32 +0100
Subject: [PATCH 1/7] fix: add psutil

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index b40e8ff..8c04939 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,7 +13,7 @@ pandas>=1.1.4
 seaborn>=0.11.0
 gradio==3.35.2
-
+psutil>=6.0.0
 
 
 # Ultralytics-----------------------------------
 # ultralytics == 8.0.120

From 50d75c4752fa9e66afa6130ad239c874dee9657d Mon Sep 17 00:00:00 2001
From: Lukas Hennies
Date: Tue, 11 Feb 2025 12:17:37 +0100
Subject: [PATCH 2/7] fix: add model components to allowed globals for load

---
 ultralytics/nn/tasks.py | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 3c2ba06..bf47de7 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -7,11 +7,18 @@
 import torch
 import torch.nn as nn
 
+from torch.nn.modules.batchnorm import BatchNorm2d
+from torch.nn import Sequential, Conv2d, MaxPool2d, Upsample, ConvTranspose2d
+from ultralytics.nn.modules.block import C2f, DFL, Bottleneck
+from torch.nn.modules.container import ModuleList
+from torch.nn.modules.activation import SiLU
+from ultralytics.nn.modules.conv import Concat
 from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
                                     Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
                                     Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
-                                    RTDETRDecoder, Segment)
-from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
+                                    RTDETRDecoder, Segment, Proto)
+from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load, \
+    IterableSimpleNamespace
 from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
 from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
 from ultralytics.yolo.utils.plotting import feature_visualization
@@ -575,7 +582,17 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
 
 
 def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
     """Loads a single model weights."""
-    ckpt, weight = torch_safe_load(weight)  # load ckpt
+    with torch.serialization.safe_globals([SegmentationModel, Sequential, Conv, Conv2,
+                                           Conv2d, BatchNorm2d, SiLU, C2f, ModuleList, Bottleneck, SPPF, MaxPool2d,
+                                           Upsample, Concat, Segment,
+                                           DFL, Proto, ConvTranspose2d, AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck,
+                                           BottleneckCSP, C2f, C3Ghost, C3x,
+                                           Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv,
+                                           DWConvTranspose2d,
+                                           Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
+                                           RTDETRDecoder, Segment, getattr, IterableSimpleNamespace
+                                           ]):
+        ckpt, weight = torch_safe_load(weight)  # load ckpt
     args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))}  # combine model and default args, preferring model args
     model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model
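Note on PATCH 2/7: PyTorch 2.6 flipped the default of torch.load to
weights_only=True, so unpickling a full model object now raises an
UnpicklingError unless every class the pickle references is allowlisted.
That is why the patch wraps torch_safe_load in the
torch.serialization.safe_globals context manager and enumerates every layer
type the checkpoint can contain; loading with weights_only=False would also
work, but it executes arbitrary pickle code, so the allowlist is the safer
fix. A minimal, self-contained sketch of the mechanism (TinyModel and
demo.pt are hypothetical stand-ins, not part of the patch):

    import torch
    from torch.serialization import safe_globals

    class TinyModel(torch.nn.Module):  # stand-in for the real model classes
        def __init__(self):
            super().__init__()
            self.w = torch.nn.Parameter(torch.zeros(3))

    # Pickling the module instance embeds the TinyModel class itself,
    # not just its tensors.
    torch.save({"model": TinyModel()}, "demo.pt")

    # Under weights_only=True (the default since torch 2.6) this load would
    # fail; allowlisting the class makes it succeed. Every nested module
    # class must be allowlisted too, hence the long list in the patch.
    with safe_globals([TinyModel]):
        ckpt = torch.load("demo.pt")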
From 836e140560aeffd719225e8d37e763af24b4c780 Mon Sep 17 00:00:00 2001
From: Lukas Hennies
Date: Tue, 11 Feb 2025 12:44:01 +0100
Subject: [PATCH 3/7] fix: model load for new pytorch

---
 fastsam/model.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/fastsam/model.py b/fastsam/model.py
index a68f88c..12be2ce 100644
--- a/fastsam/model.py
+++ b/fastsam/model.py
@@ -8,6 +8,7 @@
 model = FastSAM('last.pt')
 results = model.predict('ultralytics/assets/bus.jpg')
 """
+import traceback
 
 from ultralytics.yolo.cfg import get_cfg
 from ultralytics.yolo.engine.exporter import Exporter
@@ -50,6 +51,8 @@ def predict(self, source=None, stream=False, **kwargs):
         try:
             return self.predictor(source, stream=stream)
         except Exception as e:
+            LOGGER.error("Failed to predict with: %s", e)
+            LOGGER.error(traceback.format_exc())
             return None
 
     def train(self, **kwargs):

From cfdeb330c32912fa365d7aca37c15480a11e6bf3 Mon Sep 17 00:00:00 2001
From: Lukas Hennies
Date: Tue, 11 Feb 2025 12:44:41 +0100
Subject: [PATCH 4/7] fix: limit matplotlib due to tostring_rgb removal

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 8c04939..4316eb5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 # Base-----------------------------------
-matplotlib>=3.2.2
+matplotlib>=3.2.2, <3.10.0
 opencv-python>=4.6.0
 Pillow>=7.1.2
 PyYAML>=5.3.1

From 35bb347275509088bb3f43da39d8586f802be124 Mon Sep 17 00:00:00 2001
From: Lukas Hennies
Date: Tue, 11 Feb 2025 12:44:52 +0100
Subject: [PATCH 5/7] chore: ignore new outputs

---
 .gitignore | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 5cb5580..e370921 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,4 +6,5 @@
 weights
 build/
 *.egg-info/
-gradio_cached_examples
\ No newline at end of file
+gradio_cached_examples
+/output/
From 788cbf8adf57a3ee44ed8bf63e9e4f336e2763f1 Mon Sep 17 00:00:00 2001
From: Lukas Hennies
Date: Tue, 11 Feb 2025 13:10:31 +0100
Subject: [PATCH 6/7] feat: add folder processing

---
 Inference.py | 84 ++++++++++++++++++++++++++++++----------------------
 1 file changed, 48 insertions(+), 36 deletions(-)

diff --git a/Inference.py b/Inference.py
index 61b70fc..2b08002 100644
--- a/Inference.py
+++ b/Inference.py
@@ -1,5 +1,7 @@
 import argparse
+import pathlib
+
 from fastsam import FastSAM, FastSAMPrompt
 import ast
 import torch
 from PIL import Image
@@ -77,42 +79,52 @@
     args.point_prompt = ast.literal_eval(args.point_prompt)
     args.box_prompt = convert_box_xywh_to_xyxy(ast.literal_eval(args.box_prompt))
     args.point_label = ast.literal_eval(args.point_label)
-    input = Image.open(args.img_path)
-    input = input.convert("RGB")
-    everything_results = model(
-        input,
-        device=args.device,
-        retina_masks=args.retina,
-        imgsz=args.imgsz,
-        conf=args.conf,
-        iou=args.iou
-        )
-    bboxes = None
-    points = None
-    point_label = None
-    prompt_process = FastSAMPrompt(input, everything_results, device=args.device)
-    if args.box_prompt[0][2] != 0 and args.box_prompt[0][3] != 0:
-            ann = prompt_process.box_prompt(bboxes=args.box_prompt)
-            bboxes = args.box_prompt
-    elif args.text_prompt != None:
-        ann = prompt_process.text_prompt(text=args.text_prompt)
-    elif args.point_prompt[0] != [0, 0]:
-        ann = prompt_process.point_prompt(
-            points=args.point_prompt, pointlabel=args.point_label
-        )
-        points = args.point_prompt
-        point_label = args.point_label
-    else:
-        ann = prompt_process.everything_prompt()
-    prompt_process.plot(
-        annotations=ann,
-        output_path=args.output+args.img_path.split("/")[-1],
-        bboxes = bboxes,
-        points = points,
-        point_label = point_label,
-        withContours=args.withContours,
-        better_quality=args.better_quality,
-    )
+    img_path = pathlib.Path(args.img_path)
+    img_paths = []
+    # iterate through entire folder if specified
+    if img_path.exists() and img_path.is_file():
+        img_paths.append(img_path)
+    else:
+        img_paths.extend(img_path.glob("*.jpg"))
+        img_paths.extend(img_path.glob("*.png"))
+        img_paths.extend(img_path.glob("*.bmp"))
+    for img_path in img_paths:
+        input = Image.open(img_path)
+        input = input.convert("RGB")
+        everything_results = model(
+            input,
+            device=args.device,
+            retina_masks=args.retina,
+            imgsz=args.imgsz,
+            conf=args.conf,
+            iou=args.iou
+            )
+        bboxes = None
+        points = None
+        point_label = None
+        prompt_process = FastSAMPrompt(input, everything_results, device=args.device)
+        if args.box_prompt[0][2] != 0 and args.box_prompt[0][3] != 0:
+                ann = prompt_process.box_prompt(bboxes=args.box_prompt)
+                bboxes = args.box_prompt
+        elif args.text_prompt != None:
+            ann = prompt_process.text_prompt(text=args.text_prompt)
+        elif args.point_prompt[0] != [0, 0]:
+            ann = prompt_process.point_prompt(
+                points=args.point_prompt, pointlabel=args.point_label
+            )
+            points = args.point_prompt
+            point_label = args.point_label
+        else:
+            ann = prompt_process.everything_prompt()
+        prompt_process.plot(
+            annotations=ann,
+            output_path=args.output+img_path.name,
+            bboxes = bboxes,
+            points = points,
+            point_label = point_label,
+            withContours=args.withContours,
+            better_quality=args.better_quality,
+        )
 
 
 
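Note on PATCH 6/7: --img_path now accepts either a single image file or a
directory, in which case every *.jpg, *.png and *.bmp file in it is
processed (the glob is case-sensitive on case-sensitive filesystems, so
*.JPG files are skipped). Because output_path is built by plain string
concatenation, --output should end with a trailing slash. Hypothetical
invocations, assuming the default paths from parse_args():

    # single image, unchanged behaviour
    python Inference.py --img_path ./images/dogs.jpg --output ./output/

    # whole folder
    python Inference.py --img_path ./images/ --output ./output/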
From 7d5d096ba09aafd68727e953d11e2f5053760ef1 Mon Sep 17 00:00:00 2001
From: Lukas Hennies
Date: Tue, 11 Feb 2025 13:52:30 +0100
Subject: [PATCH 7/7] fix: enforce image resizing before inference

The requested img size was not applied correctly before, which resulted in
ludicrous VRAM usage on large images.
---
 Inference.py | 36 +++++++++++++++++++-----------------
 1 file changed, 19 insertions(+), 17 deletions(-)

diff --git a/Inference.py b/Inference.py
index 2b08002..0b6d5dc 100644
--- a/Inference.py
+++ b/Inference.py
@@ -44,7 +44,8 @@ def parse_args():
         default="[0]",
         help="[1,0] 0:background, 1:foreground",
     )
-    parser.add_argument("--box_prompt", type=str, default="[[0,0,0,0]]", help="[[x,y,w,h],[x2,y2,w2,h2]] support multiple boxes")
+    parser.add_argument("--box_prompt", type=str, default="[[0,0,0,0]]",
+                        help="[[x,y,w,h],[x2,y2,w2,h2]] support multiple boxes")
     parser.add_argument(
         "--better_quality",
         type=str,
@@ -85,27 +86,30 @@ def main(args):
     if img_path.exists() and img_path.is_file():
         img_paths.append(img_path)
     else:
-        img_paths.extend(img_path.glob("*.jpg"))
-        img_paths.extend(img_path.glob("*.png"))
-        img_paths.extend(img_path.glob("*.bmp"))
+        img_formats = ["*.jpg", "*.png", "*.bmp"]
+        for img_format in img_formats:
+            img_paths.extend(img_path.glob(img_format))
+
     for img_path in img_paths:
-        input = Image.open(img_path)
-        input = input.convert("RGB")
+        input_image = Image.open(img_path)
+        input_image = input_image.convert("RGB")
+        input_image = input_image.resize((args.imgsz, args.imgsz))
+
         everything_results = model(
-            input,
+            input_image,
             device=args.device,
             retina_masks=args.retina,
             imgsz=args.imgsz,
             conf=args.conf,
             iou=args.iou
-            )
+        )
         bboxes = None
         points = None
         point_label = None
-        prompt_process = FastSAMPrompt(input, everything_results, device=args.device)
+        prompt_process = FastSAMPrompt(input_image, everything_results, device=args.device)
         if args.box_prompt[0][2] != 0 and args.box_prompt[0][3] != 0:
-                ann = prompt_process.box_prompt(bboxes=args.box_prompt)
-                bboxes = args.box_prompt
+            ann = prompt_process.box_prompt(bboxes=args.box_prompt)
+            bboxes = args.box_prompt
         elif args.text_prompt != None:
             ann = prompt_process.text_prompt(text=args.text_prompt)
         elif args.point_prompt[0] != [0, 0]:
@@ -118,17 +122,15 @@ def main(args):
             ann = prompt_process.everything_prompt()
         prompt_process.plot(
            annotations=ann,
-            output_path=args.output+img_path.name,
-            bboxes = bboxes,
-            points = points,
-            point_label = point_label,
+            output_path=args.output + img_path.name,
+            bboxes=bboxes,
+            points=points,
+            point_label=point_label,
             withContours=args.withContours,
             better_quality=args.better_quality,
         )
 
-
-
 if __name__ == "__main__":
     args = parse_args()
     main(args)
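Note on PATCH 7/7: resizing the image to (imgsz, imgsz) before inference is
what actually bounds memory here; per the commit message, passing imgsz to
the model call alone did not constrain the decoded image. The fixed square
does distort non-square images, though. If that matters, an
aspect-preserving variant is a small change (a sketch only, not part of the
series; resize_keep_aspect is a hypothetical helper):

    from PIL import Image

    def resize_keep_aspect(img: Image.Image, imgsz: int) -> Image.Image:
        # Scale the longer side down to imgsz and keep the aspect ratio;
        # images that already fit are returned unchanged.
        w, h = img.size
        scale = imgsz / max(w, h)
        if scale >= 1:
            return img
        return img.resize((round(w * scale), round(h * scale)))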