Skip to content

Commit 22b970e

Browse files
committed
ready for test
1 parent e367e30 commit 22b970e

File tree

2 files changed

+101
-36
lines changed

2 files changed

+101
-36
lines changed

supervision/detection/overlap_filter.py

+22-13
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from enum import Enum
24
from typing import List, Union
35

@@ -248,16 +250,23 @@ class OverlapFilter(Enum):
248250
NON_MAX_SUPPRESSION = "non_max_suppression"
249251
NON_MAX_MERGE = "non_max_merge"
250252

251-
252-
def validate_overlap_filter(
253-
strategy: Union[OverlapFilter, str],
254-
) -> OverlapFilter:
255-
if isinstance(strategy, str):
256-
try:
257-
strategy = OverlapFilter(strategy.lower())
258-
except ValueError:
259-
raise ValueError(
260-
f"Invalid strategy value: {strategy}. Must be one of "
261-
f"{[e.value for e in OverlapFilter]}"
262-
)
263-
return strategy
253+
@classmethod
254+
def list(cls):
255+
return list(map(lambda c: c.value, cls))
256+
257+
@classmethod
258+
def from_value(cls, value: Union[OverlapFilter, str]) -> OverlapFilter:
259+
if isinstance(value, cls):
260+
return value
261+
if isinstance(value, str):
262+
value = value.lower()
263+
try:
264+
return cls(value)
265+
except ValueError:
266+
raise ValueError(
267+
f"Invalid value: {value}. Must be one of {cls.list()}"
268+
)
269+
raise ValueError(
270+
f"Invalid value type: {type(value)}. Must be an instance of "
271+
f"{cls.__name__} or str."
272+
)

supervision/detection/tools/inference_slicer.py

+79-23
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66

77
from supervision.config import ORIENTED_BOX_COORDINATES
88
from supervision.detection.core import Detections
9-
from supervision.detection.overlap_filter import OverlapFilter, validate_overlap_filter
9+
from supervision.detection.overlap_filter import OverlapFilter
1010
from supervision.detection.utils import move_boxes, move_masks, move_oriented_boxes
1111
from supervision.utils.image import crop_image
12-
from supervision.utils.internal import SupervisionWarnings
12+
from supervision.utils.internal import SupervisionWarnings, warn_deprecated, \
13+
deprecated_parameter
1314

1415

1516
def move_detections(
@@ -56,9 +57,14 @@ class InferenceSlicer:
5657
Args:
5758
slice_wh (Tuple[int, int]): Dimensions of each slice in the format
5859
`(width, height)`.
59-
overlap_ratio_wh (Tuple[float, float]): Overlap ratio between consecutive
60-
slices in the format `(width_ratio, height_ratio)`.
61-
overlap_filter_strategy (Union[OverlapFilter, str]): Strategy for
60+
overlap_ratio_wh (Optional[Tuple[float, float]]): A tuple representing the
61+
desired overlap ratio for width and height between consecutive slices.
62+
Each value should be in the range [0, 1), where 0 means no overlap and
63+
a value close to 1 means high overlap.
64+
overlap_wh (Optional[Tuple[int, int]]): A tuple representing the desired
65+
overlap for width and height between consecutive slices. Each value
66+
should be greater than 0.
67+
overlap_filter (Union[OverlapFilter, str]): Strategy for
6268
filtering or merging overlapping detections in slices.
6369
iou_threshold (float): Intersection over Union (IoU) threshold
6470
used when filtering by overlap.
@@ -73,23 +79,39 @@ class InferenceSlicer:
7379
not a multiple of the slice's width or height minus the overlap.
7480
"""
7581

82+
@deprecated_parameter(
83+
old_parameter="overlap_filter_strategy",
84+
new_parameter="overlap_filter",
85+
map_function=lambda x: x,
86+
warning_message="`{old_parameter}` in `{function_name}` is deprecated and will "
87+
"be remove in `supervision-0.27.0`. Use '{new_parameter}' "
88+
"instead.",
89+
)
7690
def __init__(
7791
self,
7892
callback: Callable[[np.ndarray], Detections],
7993
slice_wh: Tuple[int, int] = (320, 320),
80-
overlap_ratio_wh: Tuple[float, float] = (0.2, 0.2),
81-
overlap_filter_strategy: Union[
94+
overlap_ratio_wh: Optional[Tuple[float, float]] = (0.2, 0.2),
95+
overlap_wh: Optional[Tuple[int, int]] = None,
96+
overlap_filter: Union[
8297
OverlapFilter, str
8398
] = OverlapFilter.NON_MAX_SUPPRESSION,
8499
iou_threshold: float = 0.5,
85100
thread_workers: int = 1,
86101
):
87-
overlap_filter_strategy = validate_overlap_filter(overlap_filter_strategy)
102+
if overlap_ratio_wh is None:
103+
warn_deprecated(
104+
"`overlap_ratio_wh` in `InferenceSlicer.__init__` is deprecated and "
105+
"will be remove in `supervision-0.27.0`. Use `overlap_wh` instead."
106+
)
88107

89-
self.slice_wh = slice_wh
108+
self._validate_overlap(overlap_ratio_wh, overlap_wh)
90109
self.overlap_ratio_wh = overlap_ratio_wh
110+
self.overlap_wh = overlap_wh
111+
112+
self.slice_wh = slice_wh
91113
self.iou_threshold = iou_threshold
92-
self.overlap_filter_strategy = overlap_filter_strategy
114+
self.overlap_filter = OverlapFilter.from_value(overlap_filter)
93115
self.callback = callback
94116
self.thread_workers = thread_workers
95117

@@ -144,15 +166,15 @@ def callback(image_slice: np.ndarray) -> sv.Detections:
144166
detections_list.append(future.result())
145167

146168
merged = Detections.merge(detections_list=detections_list)
147-
if self.overlap_filter_strategy == OverlapFilter.NONE:
169+
if self.overlap_filter == OverlapFilter.NONE:
148170
return merged
149-
elif self.overlap_filter_strategy == OverlapFilter.NON_MAX_SUPPRESSION:
171+
elif self.overlap_filter == OverlapFilter.NON_MAX_SUPPRESSION:
150172
return merged.with_nms(threshold=self.iou_threshold)
151-
elif self.overlap_filter_strategy == OverlapFilter.NON_MAX_MERGE:
173+
elif self.overlap_filter == OverlapFilter.NON_MAX_MERGE:
152174
return merged.with_nmm(threshold=self.iou_threshold)
153175
else:
154176
warnings.warn(
155-
f"Invalid overlap filter strategy: {self.overlap_filter_strategy}",
177+
f"Invalid overlap filter strategy: {self.overlap_filter}",
156178
category=SupervisionWarnings,
157179
)
158180
return merged
@@ -182,7 +204,8 @@ def _run_callback(self, image, offset) -> Detections:
182204
def _generate_offset(
183205
resolution_wh: Tuple[int, int],
184206
slice_wh: Tuple[int, int],
185-
overlap_ratio_wh: Tuple[float, float],
207+
overlap_ratio_wh: Optional[Tuple[float, float]],
208+
overlap_wh: Optional[Tuple[int, int]]
186209
) -> np.ndarray:
187210
"""
188211
Generate offset coordinates for slicing an image based on the given resolution,
@@ -193,10 +216,13 @@ def _generate_offset(
193216
of the image to be sliced.
194217
slice_wh (Tuple[int, int]): A tuple representing the desired width and
195218
height of each slice.
196-
overlap_ratio_wh (Tuple[float, float]): A tuple representing the desired
197-
overlap ratio for width and height between consecutive slices. Each
198-
value should be in the range [0, 1), where 0 means no overlap and a
199-
value close to 1 means high overlap.
219+
overlap_ratio_wh (Optional[Tuple[float, float]]): A tuple representing the
220+
desired overlap ratio for width and height between consecutive slices.
221+
Each value should be in the range [0, 1), where 0 means no overlap and
222+
a value close to 1 means high overlap.
223+
overlap_wh (Optional[Tuple[int, int]]): A tuple representing the desired
224+
overlap for width and height between consecutive slices. Each value
225+
should be greater than 0.
200226
201227
Returns:
202228
np.ndarray: An array of shape `(n, 4)` containing coordinates for each
@@ -211,10 +237,17 @@ def _generate_offset(
211237
"""
212238
slice_width, slice_height = slice_wh
213239
image_width, image_height = resolution_wh
214-
overlap_ratio_width, overlap_ratio_height = overlap_ratio_wh
215-
216-
width_stride = slice_width - int(overlap_ratio_width * slice_width)
217-
height_stride = slice_height - int(overlap_ratio_height * slice_height)
240+
overlap_width = (
241+
overlap_wh[0]
242+
if overlap_wh is not None
243+
else int(overlap_ratio_wh[0] * slice_width))
244+
overlap_height = (
245+
overlap_wh[1]
246+
if overlap_wh is not None
247+
else int(overlap_ratio_wh[1] * slice_height))
248+
249+
width_stride = slice_width - overlap_width
250+
height_stride = slice_height - overlap_height
218251

219252
ws = np.arange(0, image_width, width_stride)
220253
hs = np.arange(0, image_height, height_stride)
@@ -226,3 +259,26 @@ def _generate_offset(
226259
offsets = np.stack([xmin, ymin, xmax, ymax], axis=-1).reshape(-1, 4)
227260

228261
return offsets
262+
263+
@staticmethod
264+
def _validate_overlap(
265+
overlap_ratio_wh: Optional[Tuple[float, float]],
266+
overlap_wh: Optional[Tuple[int, int]]
267+
) -> None:
268+
if overlap_ratio_wh is not None and overlap_wh is not None:
269+
raise ValueError(
270+
"Both `overlap_ratio_wh` and `overlap_wh` cannot be provided. "
271+
"Please provide only one of them."
272+
)
273+
if overlap_ratio_wh is not None:
274+
if not (0 <= overlap_ratio_wh[0] < 1 and 0 <= overlap_ratio_wh[1] < 1):
275+
raise ValueError(
276+
"Overlap ratios must be in the range [0, 1). "
277+
f"Received: {overlap_ratio_wh}"
278+
)
279+
if overlap_wh is not None:
280+
if not (overlap_wh[0] > 0 and overlap_wh[1] > 0):
281+
raise ValueError(
282+
"Overlap values must be greater than 0. "
283+
f"Received: {overlap_wh}"
284+
)

0 commit comments

Comments
 (0)