2
2
3
3
from abc import ABC , abstractmethod
4
4
from enum import Enum
5
- from typing import Any , Dict , Iterator , Tuple , Union
5
+ from typing import Any , Iterator , Set , Tuple
6
6
7
7
import numpy as np
8
8
import numpy .typing as npt
9
- from typing_extensions import Self
10
9
11
10
from supervision import config
12
11
from supervision .detection .core import Detections
13
- from supervision .metrics .utils import len0_like , pad_mask
12
+ from supervision .metrics .utils import pad_mask
14
13
15
14
CLASS_ID_NONE = -1
"""Used by metrics module as class ID, when none is present"""

CONFIDENCE_NONE = -1
"""Used by metrics module as confidence, when none is present"""
17
17
18
18
@@ -22,7 +22,7 @@ class Metric(ABC):
22
22
"""
23
23
24
24
@abstractmethod
25
- def update (self , * args , ** kwargs ) -> Self :
25
+ def update (self , * args , ** kwargs ) -> "Metric" :
26
26
"""
27
27
Add data to the metric, without computing the result.
28
28
Return the metric itself to allow method chaining.
@@ -78,171 +78,176 @@ def __init__(self, metric: Metric, target: MetricTarget):
78
78
super ().__init__ (f"Metric { metric } does not support target { target } " )
79
79
80
80
81
- class InternalMetricDataStore :
81
class MetricData:
    """
    A container for detection contents, decoupled from Detections.

    While a plain np.ndarray works for xyxy and obb, this approach solves
    the mask concatenation problem: masks from different updates may have
    different spatial sizes, so they are padded to a common shape before
    being stacked.

    Attributes are kept index-aligned: row ``i`` of ``data`` belongs to
    ``confidence[i]`` and ``class_id[i]``.
    """

    def __init__(self, metric_target: MetricTarget, class_agnostic: bool = False):
        """
        Create an empty container.

        Args:
            metric_target: Which part of the detections is stored
                (boxes, masks or oriented bounding boxes).
            class_agnostic: When True, every stored detection is given
                the class ID `CLASS_ID_NONE`.
        """
        self._metric_target = metric_target
        self._class_agnostic = class_agnostic
        self.confidence = np.array([], dtype=np.float32)
        self.class_id = np.array([], dtype=int)
        self.data: npt.NDArray = self._get_empty_data()

    def update(self, detections: Detections) -> None:
        """
        Add new detections to the store.

        Raises:
            ValueError: If the content has an invalid shape, or if the
                number of content rows, confidences and class IDs differ.
        """
        new_data = self._get_content(detections)
        self._validate_shape(new_data)
        confidence = self._get_confidence(detections)
        class_id = self._get_class_id(detections)

        # Validate the incoming batch *before* touching stored state, so a
        # failed update never leaves data / confidence / class_id misaligned.
        if len(class_id) != len(confidence) or len(class_id) != len(new_data):
            raise ValueError(
                f"Inconsistent data length: class_id={len(class_id)},"
                f" confidence={len(confidence)}, data={len(new_data)}"
            )

        if self._metric_target == MetricTarget.MASKS:
            self._append_mask(new_data)
        else:
            # _validate_shape already rejected unknown targets, so this is
            # either BOXES or ORIENTED_BOUNDING_BOXES.
            self._append_boxes(new_data)

        self._append_confidence(confidence)
        self._append_class_id(class_id)

    def get_classes(self) -> Set[int]:
        """Return all class IDs currently stored, as plain Python ints."""
        return set(self.class_id.tolist())

    def get_subset_by_class(self, class_id: int) -> "MetricData":
        """
        Return a new MetricData holding only the entries of one class.

        The result carries the data, confidence and class_id rows whose
        class ID equals `class_id`; it is empty if the class is absent.
        """
        mask = self.class_id == class_id
        # Propagate class_agnostic so the subset behaves like its parent
        # if it is updated further.
        new_data_obj = MetricData(self._metric_target, self._class_agnostic)
        new_data_obj.data = self.data[mask]
        new_data_obj.confidence = self.confidence[mask]
        new_data_obj.class_id = self.class_id[mask]
        return new_data_obj

    def __len__(self) -> int:
        """Return the number of stored detections."""
        return len(self.data)

    def _get_content(self, detections: Detections) -> npt.NDArray:
        """
        Return boxes, masks or oriented bounding boxes from the data.

        Falls back to an empty array of the target's layout when the
        detections carry no mask / no oriented box coordinates.
        """
        if self._metric_target == MetricTarget.BOXES:
            return detections.xyxy
        if self._metric_target == MetricTarget.MASKS:
            return (
                detections.mask
                if detections.mask is not None
                else self._get_empty_data()
            )
        if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
            obb = detections.data.get(
                config.ORIENTED_BOX_COORDINATES, self._get_empty_data()
            )
            # np.array converts the values; np.ndarray(obb, ...) would
            # interpret `obb` as a *shape* and allocate garbage or raise.
            return np.array(obb, dtype=np.float32)
        raise ValueError(f"Invalid metric target: {self._metric_target}")

    def _get_class_id(self, detections: Detections) -> npt.NDArray[np.int_]:
        """
        Return one class ID per detection.

        Uses `CLASS_ID_NONE` for every detection in class-agnostic mode or
        when the detections have no class IDs.
        """
        if self._class_agnostic or detections.class_id is None:
            return np.array([CLASS_ID_NONE] * len(detections), dtype=int)
        return detections.class_id

    def _get_confidence(self, detections: Detections) -> npt.NDArray[np.float32]:
        """
        Return one confidence per detection.

        Uses `CONFIDENCE_NONE` for every detection when the detections have
        no confidence values.
        """
        if detections.confidence is None:
            return np.full(len(detections), CONFIDENCE_NONE, dtype=np.float32)
        return detections.confidence

    def _append_class_id(self, new_class_id: npt.NDArray[np.int_]) -> None:
        """Append class IDs to the stored class IDs."""
        self.class_id = np.hstack((self.class_id, new_class_id))

    def _append_confidence(self, new_confidence: npt.NDArray[np.float32]) -> None:
        """Append confidences to the stored confidences."""
        self.confidence = np.hstack((self.confidence, new_confidence))

    def _append_boxes(self, new_boxes: npt.NDArray[np.float32]) -> None:
        """Stack new xyxy or obb boxes on top of stored boxes."""
        if self._metric_target not in [
            MetricTarget.BOXES,
            MetricTarget.ORIENTED_BOUNDING_BOXES,
        ]:
            raise ValueError("This method is only for box data.")
        self.data = np.vstack((self.data, new_boxes))

    def _append_mask(self, new_mask: npt.NDArray[np.bool_]) -> None:
        """Stack new mask onto stored masks. Expand the shapes if necessary."""
        if self._metric_target != MetricTarget.MASKS:
            raise ValueError("This method is only for mask data.")
        self._validate_shape(new_mask)

        # Masks are (N, H, W); grow both stored and new masks to the larger
        # spatial size on each axis so they can be stacked along N.
        new_height = max(self.data.shape[1], new_mask.shape[1])
        new_width = max(self.data.shape[2], new_mask.shape[2])

        data = pad_mask(self.data, (new_height, new_width))
        new_mask = pad_mask(new_mask, (new_height, new_width))

        self.data = np.vstack((data, new_mask))

    def _get_empty_data(self) -> npt.NDArray:
        """Return a zero-length array with the layout of the metric target."""
        if self._metric_target == MetricTarget.BOXES:
            return np.empty((0, 4), dtype=np.float32)
        if self._metric_target == MetricTarget.MASKS:
            return np.empty((0, 0, 0), dtype=bool)
        if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
            return np.empty((0, 8), dtype=np.float32)
        raise ValueError(f"Invalid metric target: {self._metric_target}")

    def _validate_shape(self, data: npt.NDArray) -> None:
        """
        Check that `data` has the layout expected for the metric target.

        Raises:
            ValueError: If the shape does not match, or the metric target
                is not recognised.
        """
        if self._metric_target == MetricTarget.BOXES:
            if len(data.shape) != 2 or data.shape[1] != 4:
                raise ValueError(f"Invalid xyxy shape: {data.shape}. Expected: (N, 4)")
        elif self._metric_target == MetricTarget.MASKS:
            if len(data.shape) != 3:
                raise ValueError(
                    f"Invalid mask shape: {data.shape}. Expected: (N, H, W)"
                )
        elif self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
            if len(data.shape) != 2 or data.shape[1] != 8:
                raise ValueError(f"Invalid obb shape: {data.shape}. Expected: (N, 8)")
        else:
            raise ValueError(f"Invalid metric target: {self._metric_target}")
221
216
222
- def _expand_mask_shape (self , data : npt .NDArray ) -> npt .NDArray :
223
- """Pad the stored and new data to the same shape."""
224
- if self ._metric_target != MetricTarget .MASKS :
225
- return data
226
217
227
- new_width = max ( self . _mask_shape [ 0 ], data . shape [ 1 ])
228
- new_height = max ( self . _mask_shape [ 1 ], data . shape [ 2 ])
229
- self . _mask_shape = ( new_width , new_height )
218
class InternalMetricDataStore:
    """
    Stores internal data for metrics.

    Keeps two MetricData containers (one per side of the comparison) and
    provides access to matching per-class subsets of both. The
    `class_agnostic` flag is forwarded to both containers.
    """

    def __init__(self, metric_target: MetricTarget, class_agnostic: bool = False):
        """
        Create an empty store.

        Args:
            metric_target: Which part of the detections is stored.
            class_agnostic: Forwarded to both underlying MetricData objects.
        """
        self._metric_target = metric_target
        self._class_agnostic = class_agnostic
        self._data_1: MetricData
        self._data_2: MetricData
        self.reset()

    def reset(self) -> None:
        """Discard all stored data and start with two empty containers."""
        self._data_1 = MetricData(self._metric_target, self._class_agnostic)
        self._data_2 = MetricData(self._metric_target, self._class_agnostic)

    def update(self, data_1: Detections, data_2: Detections) -> None:
        """
        Add new data to the store.

        Use sv.Detections.empty() if only one set of data is available.
        """
        self._data_1.update(data_1)
        self._data_2.update(data_2)

    def __getitem__(self, class_id: int) -> Tuple[MetricData, MetricData]:
        """Return the subsets of both containers that belong to `class_id`."""
        return (
            self._data_1.get_subset_by_class(class_id),
            self._data_2.get_subset_by_class(class_id),
        )

    def __iter__(self) -> Iterator[Tuple[int, MetricData, MetricData]]:
        """
        Yield `(class_id, subset_1, subset_2)` for every stored class.

        Iterates over the union of class IDs from both containers, in
        ascending order, so a class present on only one side is still
        visited (its counterpart subset is then empty).
        """
        class_ids = sorted(self._data_1.get_classes() | self._data_2.get_classes())
        for class_id in class_ids:
            subset_1, subset_2 = self[class_id]
            yield (class_id, subset_1, subset_2)
0 commit comments