6
6
7
7
from supervision .config import ORIENTED_BOX_COORDINATES
8
8
from supervision .detection .core import Detections
9
- from supervision .detection .overlap_filter import OverlapFilter , validate_overlap_filter
9
+ from supervision .detection .overlap_filter import OverlapFilter
10
10
from supervision .detection .utils import move_boxes , move_masks , move_oriented_boxes
11
11
from supervision .utils .image import crop_image
12
- from supervision .utils .internal import SupervisionWarnings
12
+ from supervision .utils .internal import SupervisionWarnings , warn_deprecated , \
13
+ deprecated_parameter
13
14
14
15
15
16
def move_detections (
@@ -56,9 +57,14 @@ class InferenceSlicer:
56
57
Args:
57
58
slice_wh (Tuple[int, int]): Dimensions of each slice in the format
58
59
`(width, height)`.
59
- overlap_ratio_wh (Tuple[float, float]): Overlap ratio between consecutive
60
- slices in the format `(width_ratio, height_ratio)`.
61
- overlap_filter_strategy (Union[OverlapFilter, str]): Strategy for
60
+ overlap_ratio_wh (Optional[Tuple[float, float]]): A tuple representing the
61
+ desired overlap ratio for width and height between consecutive slices.
62
+ Each value should be in the range [0, 1), where 0 means no overlap and
63
+ a value close to 1 means high overlap.
64
+ overlap_wh (Optional[Tuple[int, int]]): A tuple representing the desired
65
+ overlap for width and height between consecutive slices. Each value
66
+ should be greater than 0.
67
+ overlap_filter (Union[OverlapFilter, str]): Strategy for
62
68
filtering or merging overlapping detections in slices.
63
69
iou_threshold (float): Intersection over Union (IoU) threshold
64
70
used when filtering by overlap.
@@ -73,23 +79,39 @@ class InferenceSlicer:
73
79
not a multiple of the slice's width or height minus the overlap.
74
80
"""
75
81
82
+ @deprecated_parameter (
83
+ old_parameter = "overlap_filter_strategy" ,
84
+ new_parameter = "overlap_filter" ,
85
+ map_function = lambda x : x ,
86
+ warning_message = "`{old_parameter}` in `{function_name}` is deprecated and will "
87
+ "be remove in `supervision-0.27.0`. Use '{new_parameter}' "
88
+ "instead." ,
89
+ )
76
90
def __init__ (
77
91
self ,
78
92
callback : Callable [[np .ndarray ], Detections ],
79
93
slice_wh : Tuple [int , int ] = (320 , 320 ),
80
- overlap_ratio_wh : Tuple [float , float ] = (0.2 , 0.2 ),
81
- overlap_filter_strategy : Union [
94
+ overlap_ratio_wh : Optional [Tuple [float , float ]] = (0.2 , 0.2 ),
95
+ overlap_wh : Optional [Tuple [int , int ]] = None ,
96
+ overlap_filter : Union [
82
97
OverlapFilter , str
83
98
] = OverlapFilter .NON_MAX_SUPPRESSION ,
84
99
iou_threshold : float = 0.5 ,
85
100
thread_workers : int = 1 ,
86
101
):
87
- overlap_filter_strategy = validate_overlap_filter (overlap_filter_strategy )
102
+ if overlap_ratio_wh is None :
103
+ warn_deprecated (
104
+ "`overlap_ratio_wh` in `InferenceSlicer.__init__` is deprecated and "
105
+ "will be remove in `supervision-0.27.0`. Use `overlap_wh` instead."
106
+ )
88
107
89
- self .slice_wh = slice_wh
108
+ self ._validate_overlap ( overlap_ratio_wh , overlap_wh )
90
109
self .overlap_ratio_wh = overlap_ratio_wh
110
+ self .overlap_wh = overlap_wh
111
+
112
+ self .slice_wh = slice_wh
91
113
self .iou_threshold = iou_threshold
92
- self .overlap_filter_strategy = overlap_filter_strategy
114
+ self .overlap_filter = OverlapFilter . from_value ( overlap_filter )
93
115
self .callback = callback
94
116
self .thread_workers = thread_workers
95
117
@@ -144,15 +166,15 @@ def callback(image_slice: np.ndarray) -> sv.Detections:
144
166
detections_list .append (future .result ())
145
167
146
168
merged = Detections .merge (detections_list = detections_list )
147
- if self .overlap_filter_strategy == OverlapFilter .NONE :
169
+ if self .overlap_filter == OverlapFilter .NONE :
148
170
return merged
149
- elif self .overlap_filter_strategy == OverlapFilter .NON_MAX_SUPPRESSION :
171
+ elif self .overlap_filter == OverlapFilter .NON_MAX_SUPPRESSION :
150
172
return merged .with_nms (threshold = self .iou_threshold )
151
- elif self .overlap_filter_strategy == OverlapFilter .NON_MAX_MERGE :
173
+ elif self .overlap_filter == OverlapFilter .NON_MAX_MERGE :
152
174
return merged .with_nmm (threshold = self .iou_threshold )
153
175
else :
154
176
warnings .warn (
155
- f"Invalid overlap filter strategy: { self .overlap_filter_strategy } " ,
177
+ f"Invalid overlap filter strategy: { self .overlap_filter } " ,
156
178
category = SupervisionWarnings ,
157
179
)
158
180
return merged
@@ -182,7 +204,8 @@ def _run_callback(self, image, offset) -> Detections:
182
204
def _generate_offset (
183
205
resolution_wh : Tuple [int , int ],
184
206
slice_wh : Tuple [int , int ],
185
- overlap_ratio_wh : Tuple [float , float ],
207
+ overlap_ratio_wh : Optional [Tuple [float , float ]],
208
+ overlap_wh : Optional [Tuple [int , int ]]
186
209
) -> np .ndarray :
187
210
"""
188
211
Generate offset coordinates for slicing an image based on the given resolution,
@@ -193,10 +216,13 @@ def _generate_offset(
193
216
of the image to be sliced.
194
217
slice_wh (Tuple[int, int]): A tuple representing the desired width and
195
218
height of each slice.
196
- overlap_ratio_wh (Tuple[float, float]): A tuple representing the desired
197
- overlap ratio for width and height between consecutive slices. Each
198
- value should be in the range [0, 1), where 0 means no overlap and a
199
- value close to 1 means high overlap.
219
+ overlap_ratio_wh (Optional[Tuple[float, float]]): A tuple representing the
220
+ desired overlap ratio for width and height between consecutive slices.
221
+ Each value should be in the range [0, 1), where 0 means no overlap and
222
+ a value close to 1 means high overlap.
223
+ overlap_wh (Optional[Tuple[int, int]]): A tuple representing the desired
224
+ overlap for width and height between consecutive slices. Each value
225
+ should be greater than 0.
200
226
201
227
Returns:
202
228
np.ndarray: An array of shape `(n, 4)` containing coordinates for each
@@ -211,10 +237,17 @@ def _generate_offset(
211
237
"""
212
238
slice_width , slice_height = slice_wh
213
239
image_width , image_height = resolution_wh
214
- overlap_ratio_width , overlap_ratio_height = overlap_ratio_wh
215
-
216
- width_stride = slice_width - int (overlap_ratio_width * slice_width )
217
- height_stride = slice_height - int (overlap_ratio_height * slice_height )
240
+ overlap_width = (
241
+ overlap_wh [0 ]
242
+ if overlap_wh is not None
243
+ else int (overlap_ratio_wh [0 ] * slice_width ))
244
+ overlap_height = (
245
+ overlap_wh [1 ]
246
+ if overlap_wh is not None
247
+ else int (overlap_ratio_wh [1 ] * slice_height ))
248
+
249
+ width_stride = slice_width - overlap_width
250
+ height_stride = slice_height - overlap_height
218
251
219
252
ws = np .arange (0 , image_width , width_stride )
220
253
hs = np .arange (0 , image_height , height_stride )
@@ -226,3 +259,26 @@ def _generate_offset(
226
259
offsets = np .stack ([xmin , ymin , xmax , ymax ], axis = - 1 ).reshape (- 1 , 4 )
227
260
228
261
return offsets
262
+
263
+ @staticmethod
264
+ def _validate_overlap (
265
+ overlap_ratio_wh : Optional [Tuple [float , float ]],
266
+ overlap_wh : Optional [Tuple [int , int ]]
267
+ ) -> None :
268
+ if overlap_ratio_wh is not None and overlap_wh is not None :
269
+ raise ValueError (
270
+ "Both `overlap_ratio_wh` and `overlap_wh` cannot be provided. "
271
+ "Please provide only one of them."
272
+ )
273
+ if overlap_ratio_wh is not None :
274
+ if not (0 <= overlap_ratio_wh [0 ] < 1 and 0 <= overlap_ratio_wh [1 ] < 1 ):
275
+ raise ValueError (
276
+ "Overlap ratios must be in the range [0, 1). "
277
+ f"Received: { overlap_ratio_wh } "
278
+ )
279
+ if overlap_wh is not None :
280
+ if not (overlap_wh [0 ] > 0 and overlap_wh [1 ] > 0 ):
281
+ raise ValueError (
282
+ "Overlap values must be greater than 0. "
283
+ f"Received: { overlap_wh } "
284
+ )
0 commit comments