3 changes: 2 additions & 1 deletion .gitignore
@@ -6,4 +6,5 @@
 weights
 build/
 *.egg-info/
-gradio_cached_examples
\ No newline at end of file
+gradio_cached_examples
+/output/
86 changes: 50 additions & 36 deletions Inference.py
@@ -1,5 +1,7 @@
 import argparse
-from fastsam import FastSAM, FastSAMPrompt
+import pathlib
+
+from fastsam import FastSAM, FastSAMPrompt
 import ast
 import torch
 from PIL import Image
@@ -42,7 +44,8 @@ def parse_args():
default="[0]",
help="[1,0] 0:background, 1:foreground",
)
parser.add_argument("--box_prompt", type=str, default="[[0,0,0,0]]", help="[[x,y,w,h],[x2,y2,w2,h2]] support multiple boxes")
parser.add_argument("--box_prompt", type=str, default="[[0,0,0,0]]",
help="[[x,y,w,h],[x2,y2,w2,h2]] support multiple boxes")
parser.add_argument(
"--better_quality",
type=str,
@@ -77,44 +80,55 @@ def main(args):
     args.point_prompt = ast.literal_eval(args.point_prompt)
     args.box_prompt = convert_box_xywh_to_xyxy(ast.literal_eval(args.box_prompt))
     args.point_label = ast.literal_eval(args.point_label)
-    input = Image.open(args.img_path)
-    input = input.convert("RGB")
-    everything_results = model(
-        input,
-        device=args.device,
-        retina_masks=args.retina,
-        imgsz=args.imgsz,
-        conf=args.conf,
-        iou=args.iou
-    )
-    bboxes = None
-    points = None
-    point_label = None
-    prompt_process = FastSAMPrompt(input, everything_results, device=args.device)
-    if args.box_prompt[0][2] != 0 and args.box_prompt[0][3] != 0:
-        ann = prompt_process.box_prompt(bboxes=args.box_prompt)
-        bboxes = args.box_prompt
-    elif args.text_prompt != None:
-        ann = prompt_process.text_prompt(text=args.text_prompt)
-    elif args.point_prompt[0] != [0, 0]:
-        ann = prompt_process.point_prompt(
-            points=args.point_prompt, pointlabel=args.point_label
-        )
-        points = args.point_prompt
-        point_label = args.point_label
-    else:
-        ann = prompt_process.everything_prompt()
-    prompt_process.plot(
-        annotations=ann,
-        output_path=args.output+args.img_path.split("/")[-1],
-        bboxes = bboxes,
-        points = points,
-        point_label = point_label,
-        withContours=args.withContours,
-        better_quality=args.better_quality,
-    )
+    img_path = pathlib.Path(args.img_path)
+    img_paths = []
+    # iterate through an entire folder if a directory is specified
+    if img_path.exists() and img_path.is_file():
+        img_paths.append(img_path)
+    else:
+        img_formats = ["*.jpg", "*.png", "*.bmp"]
+        for img_format in img_formats:
+            img_paths.extend(img_path.glob(img_format))
+
+    for img_path in img_paths:
+        input_image = Image.open(img_path)
+        input_image = input_image.convert("RGB")
+        input_image = input_image.resize((args.imgsz, args.imgsz))
+
+        everything_results = model(
+            input_image,
+            device=args.device,
+            retina_masks=args.retina,
+            imgsz=args.imgsz,
+            conf=args.conf,
+            iou=args.iou
+        )
+        bboxes = None
+        points = None
+        point_label = None
+        prompt_process = FastSAMPrompt(input_image, everything_results, device=args.device)
+        if args.box_prompt[0][2] != 0 and args.box_prompt[0][3] != 0:
+            ann = prompt_process.box_prompt(bboxes=args.box_prompt)
+            bboxes = args.box_prompt
+        elif args.text_prompt is not None:
+            ann = prompt_process.text_prompt(text=args.text_prompt)
+        elif args.point_prompt[0] != [0, 0]:
+            ann = prompt_process.point_prompt(
+                points=args.point_prompt, pointlabel=args.point_label
+            )
+            points = args.point_prompt
+            point_label = args.point_label
+        else:
+            ann = prompt_process.everything_prompt()
+        prompt_process.plot(
+            annotations=ann,
+            output_path=args.output + img_path.name,
+            bboxes=bboxes,
+            points=points,
+            point_label=point_label,
+            withContours=args.withContours,
+            better_quality=args.better_quality,
+        )


 if __name__ == "__main__":
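Reviewer note on the loop added above: `--img_path` now accepts either a single image or a directory, which is globbed non-recursively for `*.jpg`, `*.png`, and `*.bmp`. A path that does not exist also falls into the glob branch, where it typically yields an empty list rather than a clear error. A minimal sketch of the collection logic in isolation (the function name and type hints are mine, not part of the PR):

    import pathlib

    def collect_images(path_str: str) -> list[pathlib.Path]:
        """Collect image paths: a single file as-is, a folder via glob."""
        path = pathlib.Path(path_str)
        if path.is_file():
            # a single existing image is processed directly
            return [path]
        # otherwise treat the argument as a folder; one non-recursive
        # glob per supported extension, mirroring the diff above
        return [p for fmt in ("*.jpg", "*.png", "*.bmp") for p in path.glob(fmt)]

Raising a FileNotFoundError when the resulting list is empty would make mistyped paths easier to debug.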
3 changes: 3 additions & 0 deletions fastsam/model.py
@@ -8,6 +8,7 @@
     model = FastSAM('last.pt')
     results = model.predict('ultralytics/assets/bus.jpg')
     """
+import traceback

 from ultralytics.yolo.cfg import get_cfg
 from ultralytics.yolo.engine.exporter import Exporter
@@ -50,6 +51,8 @@ def predict(self, source=None, stream=False, **kwargs):
         try:
             return self.predictor(source, stream=stream)
         except Exception as e:
+            LOGGER.error("Failed to predict: %s", e)
+            LOGGER.error(traceback.format_exc())
             return None

     def train(self, **kwargs):
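Side note on the error handling added above: logging the message and `traceback.format_exc()` as two calls works, and the stdlib also offers a single-call form. A self-contained sketch, with a stdlib logger standing in for Ultralytics' `LOGGER` (an assumption made so the snippet runs on its own):

    import logging

    logging.basicConfig(level=logging.INFO)
    LOGGER = logging.getLogger(__name__)

    def safe_predict(predictor, source):
        try:
            return predictor(source)
        except Exception:
            # Logger.exception logs at ERROR level and appends the current
            # traceback automatically, so one call replaces the two above.
            LOGGER.exception("Failed to predict on source %r", source)
            return None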
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,5 +1,5 @@
 # Base-----------------------------------
-matplotlib>=3.2.2
+matplotlib>=3.2.2, <3.10.0
 opencv-python>=4.6.0
 Pillow>=7.1.2
 PyYAML>=5.3.1
@@ -13,7 +13,7 @@ pandas>=1.1.4
 seaborn>=0.11.0

 gradio==3.35.2
-
+psutil>=6.0.0
 # Ultralytics-----------------------------------
 # ultralytics == 8.0.120

23 changes: 20 additions & 3 deletions ultralytics/nn/tasks.py
@@ -7,11 +7,18 @@
 import torch
 import torch.nn as nn

+from torch.nn.modules.batchnorm import BatchNorm2d
+from torch.nn import Sequential, Conv2d, MaxPool2d, Upsample, ConvTranspose2d
+from ultralytics.nn.modules.block import C2f, DFL, Bottleneck
+from torch.nn.modules.container import ModuleList
+from torch.nn.modules.activation import SiLU
+from ultralytics.nn.modules.conv import Concat
 from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
                                     Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
                                     Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
-                                    RTDETRDecoder, Segment)
-from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
+                                    RTDETRDecoder, Segment, Proto)
+from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load,
+                                    IterableSimpleNamespace)
 from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
 from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
 from ultralytics.yolo.utils.plotting import feature_visualization
@@ -575,7 +582,17 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):

 def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
     """Loads a single model weights."""
-    ckpt, weight = torch_safe_load(weight)  # load ckpt
+    with torch.serialization.safe_globals([SegmentationModel, Sequential, Conv, Conv2,
+                                           Conv2d, BatchNorm2d, SiLU, C2f, ModuleList, Bottleneck, SPPF, MaxPool2d,
+                                           Upsample, Concat, Segment,
+                                           DFL, Proto, ConvTranspose2d, AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck,
+                                           BottleneckCSP, C2f, C3Ghost, C3x,
+                                           Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv,
+                                           DWConvTranspose2d,
+                                           Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
+                                           RTDETRDecoder, Segment, getattr, IterableSimpleNamespace
+                                           ]):
+        ckpt, weight = torch_safe_load(weight)  # load ckpt
     args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))}  # combine model and default args, preferring model args
     model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model

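Context for the `safe_globals` wrapper above: PyTorch 2.6 changed the `torch.load` default to `weights_only=True`, which refuses to unpickle any class that has not been explicitly allowlisted, so loading a full FastSAM checkpoint now requires naming every pickled module class. `torch.serialization.safe_globals` (available since PyTorch 2.4) scopes such an allowlist to a `with` block. A minimal sketch of the same pattern, with a placeholder checkpoint path:

    import torch
    from torch.nn import Conv2d, Sequential, SiLU

    # "model.pt" is a placeholder; a real checkpoint needs every class it
    # pickles in the allowlist, hence the much longer list in the diff above.
    with torch.serialization.safe_globals([Sequential, Conv2d, SiLU]):
        ckpt = torch.load("model.pt", map_location="cpu", weights_only=True)

`torch.serialization.add_safe_globals` is the process-wide alternative when a scoped `with` block is awkward.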