Skip to content

Commit

Permalink
Fix TFLite Segment inference (#13488)
Browse files Browse the repository at this point in the history
* Fix TFLite Segment inference

* Auto-format by https://ultralytics.com

---------

Co-authored-by: UltralyticsAssistant <[email protected]>
  • Loading branch information
Y-T-G and UltralyticsAssistant authored Jan 12, 2025
1 parent 86fd1ab commit 6420a1d
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 12 deletions.
3 changes: 3 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,9 @@ def forward(self, im, augment=False, visualize=False):
scale, zero_point = output["quantization"]
x = (x.astype(np.float32) - zero_point) * scale # re-scale
y.append(x)
if len(y) == 2: # segment with (det, proto) output order reversed
if len(y[1].shape) != 4:
y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32)
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels

Expand Down
6 changes: 3 additions & 3 deletions utils/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,9 @@ def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vi
self._new_video(videos[0]) # new video
else:
self.cap = None
assert self.nf > 0, (
f"No images or videos found in {p}. Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
)
assert (
self.nf > 0
), f"No images or videos found in {p}. Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"

def __iter__(self):
"""Initializes iterator by resetting count and returns the iterator object itself."""
Expand Down
6 changes: 3 additions & 3 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,9 @@ def check_file(file, suffix=""):
assert Path(file).exists() and Path(file).stat().st_size > 0, f"File download failed: {url}" # check
return file
elif file.startswith("clearml://"): # ClearML Dataset ID
assert "clearml" in sys.modules, (
"ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
)
assert (
"clearml" in sys.modules
), "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
return file
else: # search
files = []
Expand Down
8 changes: 5 additions & 3 deletions utils/loggers/clearml/clearml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ def construct_dataset(clearml_info_string):
with open(yaml_filenames[0]) as f:
dataset_definition = yaml.safe_load(f)

assert set(dataset_definition.keys()).issuperset({"train", "test", "val", "nc", "names"}), (
"The right keys were not found in the yaml file, make sure it at least has the following keys: ('train', 'test', 'val', 'nc', 'names')"
)
assert set(
dataset_definition.keys()
).issuperset(
{"train", "test", "val", "nc", "names"}
), "The right keys were not found in the yaml file, make sure it at least has the following keys: ('train', 'test', 'val', 'nc', 'names')"

data_dict = {
"train": (
Expand Down
6 changes: 3 additions & 3 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def select_device(device="", batch_size=0, newline=True):
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
elif device: # non-cpu device requested
os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(",", "")), (
f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
)
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(
device.replace(",", "")
), f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"

if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7
Expand Down

0 comments on commit 6420a1d

Please sign in to comment.