Skip to content

Commit ab76a47

Browse files
committed
Refactored metrics store, removed typing extensions
* This commit is in a messy state - tests were not adapted to the new metrics.core
* MeanAveragePrecision is incomplete
1 parent 85e0a54 commit ab76a47

File tree

6 files changed

+474
-228
lines changed

6 files changed

+474
-228
lines changed

poetry.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ tqdm = { version = ">=4.62.3,<=4.66.5", optional = true }
5959
# pandas: picked lowest major version that supports Python 3.8
6060
pandas = { version = ">=2.0.0", optional = true }
6161
pandas-stubs = { version = ">=2.0.0.230412", optional = true }
62-
typing-extensions = "^4.12.2"
6362

6463
[tool.poetry.extras]
6564
desktop = ["opencv-python"]

supervision/metrics/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,11 @@
44
MetricTarget,
55
UnsupportedMetricTargetError,
66
)
7-
from supervision.metrics.intersection_over_union import IntersectionOverUnion
7+
from supervision.metrics.intersection_over_union import (
8+
IntersectionOverUnion,
9+
IntersectionOverUnionResult,
10+
)
11+
from supervision.metrics.mean_average_precision import (
12+
MeanAveragePrecision,
13+
MeanAveragePrecisionResult,
14+
)

supervision/metrics/core.py

Lines changed: 141 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22

33
from abc import ABC, abstractmethod
44
from enum import Enum
5-
from typing import Any, Dict, Iterator, Tuple, Union
5+
from typing import Any, Iterator, Set, Tuple
66

77
import numpy as np
88
import numpy.typing as npt
9-
from typing_extensions import Self
109

1110
from supervision import config
1211
from supervision.detection.core import Detections
13-
from supervision.metrics.utils import len0_like, pad_mask
12+
from supervision.metrics.utils import pad_mask
1413

1514
CLASS_ID_NONE = -1
15+
CONFIDENCE_NONE = -1
1616
"""Used by metrics module as class ID, when none is present"""
1717

1818

@@ -22,7 +22,7 @@ class Metric(ABC):
2222
"""
2323

2424
@abstractmethod
25-
def update(self, *args, **kwargs) -> Self:
25+
def update(self, *args, **kwargs) -> "Metric":
2626
"""
2727
Add data to the metric, without computing the result.
2828
Return the metric itself to allow method chaining.
@@ -78,171 +78,176 @@ def __init__(self, metric: Metric, target: MetricTarget):
7878
super().__init__(f"Metric {metric} does not support target {target}")
7979

8080

81-
class InternalMetricDataStore:
81+
class MetricData:
8282
"""
83-
Stores internal data of IntersectionOverUnion metric:
84-
* Stores the basic data: boxes, masks, or oriented bounding boxes
85-
* Validates data: ensures data types and shape are consistent
86-
* Provides iteration by class
87-
88-
Provides a class-agnostic mode, where all data is treated as a single class.
89-
Warning: numpy inputs are always considered as class-agnostic data.
90-
91-
Data here refers to content of Detections objects: boxes, masks,
92-
or oriented bounding boxes.
83+
A container for detection contents, decoupled from Detections.
84+
While an np.ndarray works for xyxy and obb, this approach solves
85+
the mask concatenation problem.
9386
"""
9487

95-
def __init__(self, metric_target: MetricTarget, class_agnostic: bool):
88+
def __init__(self, metric_target: MetricTarget, class_agnostic: bool = False):
9689
self._metric_target = metric_target
9790
self._class_agnostic = class_agnostic
98-
self._data_1: Dict[int, npt.NDArray]
99-
self._data_2: Dict[int, npt.NDArray]
100-
self._mask_shape: Tuple[int, int]
101-
self.reset()
91+
self.confidence = np.array([], dtype=np.float32)
92+
self.class_id = np.array([], dtype=int)
93+
self.data: npt.NDArray = self._get_empty_data()
10294

103-
def reset(self) -> None:
104-
self._data_1 = {}
105-
self._data_2 = {}
106-
self._mask_shape = (0, 0)
107-
108-
def update(
109-
self,
110-
data_1: Union[npt.NDArray, Detections],
111-
data_2: Union[npt.NDArray, Detections],
112-
) -> None:
113-
"""
114-
Add new data to the store.
95+
def update(self, detections: Detections):
96+
"""Add new detections to the store."""
97+
new_data = self._get_content(detections)
98+
self._validate_shape(new_data)
11599

116-
Use sv.Detections.empty() if only one set of data is available.
117-
"""
118-
content_1 = self._get_content(data_1)
119-
content_2 = self._get_content(data_2)
120-
self._validate_shape(content_1)
121-
self._validate_shape(content_2)
100+
if self._metric_target == MetricTarget.BOXES:
101+
self._append_boxes(new_data)
102+
elif self._metric_target == MetricTarget.MASKS:
103+
self._append_mask(new_data)
104+
elif self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
105+
self.data = np.vstack((self.data, new_data))
122106

123-
class_ids_1 = self._get_class_ids(data_1)
124-
class_ids_2 = self._get_class_ids(data_2)
125-
self._validate_class_ids(class_ids_1, class_ids_2)
107+
confidence = self._get_confidence(detections)
108+
self._append_confidence(confidence)
126109

127-
if self._metric_target == MetricTarget.MASKS:
128-
content_1 = self._expand_mask_shape(content_1)
129-
content_2 = self._expand_mask_shape(content_2)
130-
131-
for class_id in set(class_ids_1):
132-
content_of_class = content_1[class_ids_1 == class_id]
133-
stored_content_of_class = self._data_1.get(class_id, len0_like(content_1))
134-
self._data_1[class_id] = np.vstack(
135-
(stored_content_of_class, content_of_class)
136-
)
110+
class_id = self._get_class_id(detections)
111+
self._append_class_id(class_id)
137112

138-
for class_id in set(class_ids_2):
139-
content_of_class = content_2[class_ids_2 == class_id]
140-
stored_content_of_class = self._data_2.get(class_id, len0_like(content_2))
141-
self._data_2[class_id] = np.vstack(
142-
(stored_content_of_class, content_of_class)
113+
if len(self.class_id) != len(self.confidence) or len(self.class_id) != len(
114+
self.data
115+
):
116+
raise ValueError(
117+
f"Inconsistent data length: class_id={len(class_id)},"
118+
f" confidence={len(confidence)}, data={len(new_data)}"
143119
)
144120

145-
def __getitem__(self, class_id: int) -> Tuple[npt.NDArray, npt.NDArray]:
146-
return (
147-
self._data_1.get(class_id, self._make_empty()),
148-
self._data_2.get(class_id, self._make_empty()),
149-
)
121+
def get_classes(self) -> Set[int]:
122+
"""Return all class IDs."""
123+
return set(self.class_id)
150124

151-
def __iter__(
152-
self,
153-
) -> Iterator[Tuple[int, npt.NDArray, npt.NDArray]]:
154-
class_ids = sorted(set(self._data_1.keys()) | set(self._data_2.keys()))
155-
for class_id in class_ids:
156-
yield (
157-
class_id,
158-
*self[class_id],
159-
)
125+
def get_subset_by_class(self, class_id: int) -> MetricData:
126+
"""Return data, confidence and class_id for a specific class."""
127+
mask = self.class_id == class_id
128+
new_data_obj = MetricData(self._metric_target)
129+
new_data_obj.data = self.data[mask]
130+
new_data_obj.confidence = self.confidence[mask]
131+
new_data_obj.class_id = self.class_id[mask]
132+
return new_data_obj
160133

161-
def _get_content(self, data: Union[npt.NDArray, Detections]) -> npt.NDArray:
162-
"""Return boxes, masks or oriented bounding boxes from the data."""
163-
if not isinstance(data, (Detections, np.ndarray)):
164-
raise ValueError(
165-
f"Invalid data type: {type(data)}."
166-
f" Only Detections or np.ndarray are supported."
167-
)
168-
if isinstance(data, np.ndarray):
169-
return data
134+
def __len__(self) -> int:
135+
return len(self.data)
170136

137+
def _get_content(self, detections: Detections) -> npt.NDArray:
138+
"""Return boxes, masks or oriented bounding boxes from the data."""
171139
if self._metric_target == MetricTarget.BOXES:
172-
return data.xyxy
140+
return detections.xyxy
173141
if self._metric_target == MetricTarget.MASKS:
174142
return (
175-
data.mask if data.mask is not None else np.zeros((0, 0, 0), dtype=bool)
143+
detections.mask
144+
if detections.mask is not None
145+
else self._get_empty_data()
176146
)
177147
if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
178-
obb = data.data.get(
179-
config.ORIENTED_BOX_COORDINATES, np.zeros((0, 8), dtype=np.float32)
148+
obb = detections.data.get(
149+
config.ORIENTED_BOX_COORDINATES, self._get_empty_data()
180150
)
181-
return np.array(obb, dtype=np.float32)
151+
return np.ndarray(obb, dtype=np.float32)
182152
raise ValueError(f"Invalid metric target: {self._metric_target}")
183153

184-
def _get_class_ids(
185-
self, data: Union[npt.NDArray, Detections]
186-
) -> npt.NDArray[np.int_]:
187-
"""
188-
Return an array of class IDs from the data. Guaranteed to
189-
match the length of data.
190-
"""
191-
if (
192-
self._class_agnostic
193-
or isinstance(data, np.ndarray)
194-
or data.class_id is None
195-
):
196-
return np.array([CLASS_ID_NONE] * len(data), dtype=int)
197-
return data.class_id
198-
199-
def _validate_class_ids(
200-
self, class_id_1: npt.NDArray[np.int_], class_id_2: npt.NDArray[np.int_]
201-
) -> None:
202-
class_set = set(class_id_1) | set(class_id_2)
203-
if len(class_set) >= 2 and CLASS_ID_NONE in class_set:
204-
raise ValueError(
205-
"Metrics cannot mix data with class ID and data without class ID."
206-
)
154+
def _get_class_id(self, detections: Detections) -> npt.NDArray[np.int_]:
155+
if self._class_agnostic or detections.class_id is None:
156+
return np.array([CLASS_ID_NONE] * len(detections), dtype=int)
157+
return detections.class_id
158+
159+
def _get_confidence(self, detections: Detections) -> npt.NDArray[np.float32]:
160+
if detections.confidence is None:
161+
return np.full(len(detections), -1, dtype=np.float32)
162+
return detections.confidence
163+
164+
def _append_class_id(self, new_class_id: npt.NDArray[np.int_]) -> None:
165+
self.class_id = np.hstack((self.class_id, new_class_id))
166+
167+
def _append_confidence(self, new_confidence: npt.NDArray[np.float32]) -> None:
168+
self.confidence = np.hstack((self.confidence, new_confidence))
169+
170+
def _append_boxes(self, new_boxes: npt.NDArray[np.float32]) -> None:
171+
"""Stack new xyxy or obb boxes on top of stored boxes."""
172+
if self._metric_target not in [
173+
MetricTarget.BOXES,
174+
MetricTarget.ORIENTED_BOUNDING_BOXES,
175+
]:
176+
raise ValueError("This method is only for box data.")
177+
self.data = np.vstack((self.data, new_boxes))
178+
179+
def _append_mask(self, new_mask: npt.NDArray[np.bool_]) -> None:
180+
"""Stack new mask onto stored masks. Expand the shapes if necessary."""
181+
if self._metric_target != MetricTarget.MASKS:
182+
raise ValueError("This method is only for mask data.")
183+
self._validate_mask_shape(new_mask)
184+
185+
new_width = max(self.data.shape[1], new_mask.shape[1])
186+
new_height = max(self.data.shape[2], new_mask.shape[2])
187+
188+
data = pad_mask(self.data, (new_width, new_height))
189+
new_mask = pad_mask(new_mask, (new_width, new_height))
190+
191+
self.data = np.vstack((data, new_mask))
192+
193+
def _get_empty_data(self) -> npt.NDArray:
194+
if self._metric_target == MetricTarget.BOXES:
195+
return np.empty((0, 4), dtype=np.float32)
196+
if self._metric_target == MetricTarget.MASKS:
197+
return np.empty((0, 0, 0), dtype=bool)
198+
if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
199+
return np.empty((0, 8), dtype=np.float32)
200+
raise ValueError(f"Invalid metric target: {self._metric_target}")
207201

208202
def _validate_shape(self, data: npt.NDArray) -> None:
209-
shape = data.shape
210203
if self._metric_target == MetricTarget.BOXES:
211-
if len(shape) != 2 or shape[1] != 4:
212-
raise ValueError(f"Invalid xyxy shape: {shape}. Expected: (N, 4)")
204+
if len(data.shape) != 2 or data.shape[1] != 4:
205+
raise ValueError(f"Invalid xyxy shape: {data.shape}. Expected: (N, 4)")
213206
elif self._metric_target == MetricTarget.MASKS:
214-
if len(shape) != 3:
215-
raise ValueError(f"Invalid mask shape: {shape}. Expected: (N, H, W)")
207+
if len(data.shape) != 3:
208+
raise ValueError(
209+
f"Invalid mask shape: {data.shape}. Expected: (N, H, W)"
210+
)
216211
elif self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
217-
if len(shape) != 2 or shape[1] != 8:
218-
raise ValueError(f"Invalid obb shape: {shape}. Expected: (N, 8)")
212+
if len(data.shape) != 2 or data.shape[1] != 8:
213+
raise ValueError(f"Invalid obb shape: {data.shape}. Expected: (N, 8)")
219214
else:
220215
raise ValueError(f"Invalid metric target: {self._metric_target}")
221216

222-
def _expand_mask_shape(self, data: npt.NDArray) -> npt.NDArray:
223-
"""Pad the stored and new data to the same shape."""
224-
if self._metric_target != MetricTarget.MASKS:
225-
return data
226217

227-
new_width = max(self._mask_shape[0], data.shape[1])
228-
new_height = max(self._mask_shape[1], data.shape[2])
229-
self._mask_shape = (new_width, new_height)
218+
class InternalMetricDataStore:
219+
"""
220+
Stores internal data for metrics.
230221
231-
data = pad_mask(data, self._mask_shape)
222+
Provides a class-agnostic way to access it.
223+
"""
232224

233-
for class_id, prev_data in self._data_1.items():
234-
self._data_1[class_id] = pad_mask(prev_data, self._mask_shape)
235-
for class_id, prev_data in self._data_2.items():
236-
self._data_2[class_id] = pad_mask(prev_data, self._mask_shape)
225+
def __init__(self, metric_target: MetricTarget, class_agnostic: bool = False):
226+
self._metric_target = metric_target
227+
self._class_agnostic = class_agnostic
228+
self._data_1: MetricData
229+
self._data_2: MetricData
230+
self.reset()
237231

238-
return data
232+
def reset(self) -> None:
233+
self._data_1 = MetricData(self._metric_target, self._class_agnostic)
234+
self._data_2 = MetricData(self._metric_target, self._class_agnostic)
239235

240-
def _make_empty(self) -> npt.NDArray:
241-
"""Create an empty data object with the best-known shape for the target."""
242-
if self._metric_target == MetricTarget.BOXES:
243-
return np.empty((0, 4), dtype=np.float32)
244-
if self._metric_target == MetricTarget.MASKS:
245-
return np.empty((0, *self._mask_shape), dtype=bool)
246-
if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
247-
return np.empty((0, 8), dtype=np.float32)
248-
raise ValueError(f"Invalid metric target: {self._metric_target}")
236+
def update(self, data_1: Detections, data_2: Detections) -> None:
237+
"""
238+
Add new data to the store.
239+
240+
Use sv.Detections.empty() if only one set of data is available.
241+
"""
242+
self._data_1.update(data_1)
243+
self._data_2.update(data_2)
244+
245+
def __getitem__(self, class_id: int) -> Tuple[MetricData, MetricData]:
246+
return (
247+
self._data_1.get_subset_by_class(class_id),
248+
self._data_2.get_subset_by_class(class_id),
249+
)
250+
251+
def __iter__(self) -> Iterator[Tuple[int, MetricData, MetricData]]:
252+
for class_id in self._data_1.get_classes():
253+
yield class_id, *self[class_id]

0 commit comments

Comments
 (0)