From ab76a47eb840d87c95ab5d6f6ed9264d9eb577e4 Mon Sep 17 00:00:00 2001 From: LinasKo Date: Fri, 16 Aug 2024 09:58:49 +0300 Subject: [PATCH] Refactored metrics store, removed typing extensions * This commit is in a messy state - tests were not adapted to new metrics.core * MeanAveragePrecision is incomplete --- poetry.lock | 2 +- pyproject.toml | 1 - supervision/metrics/__init__.py | 9 +- supervision/metrics/core.py | 277 +++++++++--------- .../metrics/intersection_over_union.py | 176 ++++++----- supervision/metrics/mean_average_precision.py | 237 +++++++++++++++ 6 files changed, 474 insertions(+), 228 deletions(-) create mode 100644 supervision/metrics/mean_average_precision.py diff --git a/poetry.lock b/poetry.lock index f4e2f2844..ab9e05918 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4460,4 +4460,4 @@ metrics = ["pandas", "pandas-stubs"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "54f19bfad31db0e19784721c25480219901fae412dde440ab4d82d86f37243dd" +content-hash = "6caeff23222cc70a3e443b0b93a7b88cc8c7449fe497b5fd86ce92d513e99af6" diff --git a/pyproject.toml b/pyproject.toml index cc19b24b7..1145d8b3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,6 @@ tqdm = { version = ">=4.62.3,<=4.66.5", optional = true } # pandas: picked lowest major version that supports Python 3.8 pandas = { version = ">=2.0.0", optional = true } pandas-stubs = { version = ">=2.0.0.230412", optional = true } -typing-extensions = "^4.12.2" [tool.poetry.extras] desktop = ["opencv-python"] diff --git a/supervision/metrics/__init__.py b/supervision/metrics/__init__.py index 76d40729e..3369dc22a 100644 --- a/supervision/metrics/__init__.py +++ b/supervision/metrics/__init__.py @@ -4,4 +4,11 @@ MetricTarget, UnsupportedMetricTargetError, ) -from supervision.metrics.intersection_over_union import IntersectionOverUnion +from supervision.metrics.intersection_over_union import ( + IntersectionOverUnion, + IntersectionOverUnionResult, +) +from supervision.metrics.mean_average_precision import ( + MeanAveragePrecision, + MeanAveragePrecisionResult, +) diff --git a/supervision/metrics/core.py b/supervision/metrics/core.py index 67bf5fb5f..005387f99 100644 --- a/supervision/metrics/core.py +++ b/supervision/metrics/core.py @@ -2,17 +2,17 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, Iterator, Tuple, Union +from typing import Any, Iterator, Set, Tuple import numpy as np import numpy.typing as npt -from typing_extensions import Self from supervision import config from supervision.detection.core import Detections -from supervision.metrics.utils import len0_like, pad_mask +from supervision.metrics.utils import pad_mask CLASS_ID_NONE = -1 +CONFIDENCE_NONE = -1 """Used by metrics module as class ID, when none is present""" @@ -22,7 +22,7 @@ class Metric(ABC): """ @abstractmethod - def update(self, *args, **kwargs) -> Self: + def update(self, *args, **kwargs) -> "Metric": """ Add data to the metric, without computing the result. Return the metric itself to allow method chaining. @@ -78,171 +78,176 @@ def __init__(self, metric: Metric, target: MetricTarget): super().__init__(f"Metric {metric} does not support target {target}") -class InternalMetricDataStore: +class MetricData: """ - Stores internal data of IntersectionOverUnion metric: - * Stores the basic data: boxes, masks, or oriented bounding boxes - * Validates data: ensures data types and shape are consistent - * Provides iteration by class - - Provides a class-agnostic mode, where all data is treated as a single class. - Warning: numpy inputs are always considered as class-agnostic data. - - Data here refers to content of Detections objects: boxes, masks, - or oriented bounding boxes. + A container for detection contents, decouple from Detections. + While a np.ndarray work for xyxy and obb, this approach solves + the mask concatenation problem. """ - def __init__(self, metric_target: MetricTarget, class_agnostic: bool): + def __init__(self, metric_target: MetricTarget, class_agnostic: bool = False): self._metric_target = metric_target self._class_agnostic = class_agnostic - self._data_1: Dict[int, npt.NDArray] - self._data_2: Dict[int, npt.NDArray] - self._mask_shape: Tuple[int, int] - self.reset() + self.confidence = np.array([], dtype=np.float32) + self.class_id = np.array([], dtype=int) + self.data: npt.NDArray = self._get_empty_data() - def reset(self) -> None: - self._data_1 = {} - self._data_2 = {} - self._mask_shape = (0, 0) - - def update( - self, - data_1: Union[npt.NDArray, Detections], - data_2: Union[npt.NDArray, Detections], - ) -> None: - """ - Add new data to the store. + def update(self, detections: Detections): + """Add new detections to the store.""" + new_data = self._get_content(detections) + self._validate_shape(new_data) - Use sv.Detections.empty() if only one set of data is available. - """ - content_1 = self._get_content(data_1) - content_2 = self._get_content(data_2) - self._validate_shape(content_1) - self._validate_shape(content_2) + if self._metric_target == MetricTarget.BOXES: + self._append_boxes(new_data) + elif self._metric_target == MetricTarget.MASKS: + self._append_mask(new_data) + elif self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: + self.data = np.vstack((self.data, new_data)) - class_ids_1 = self._get_class_ids(data_1) - class_ids_2 = self._get_class_ids(data_2) - self._validate_class_ids(class_ids_1, class_ids_2) + confidence = self._get_confidence(detections) + self._append_confidence(confidence) - if self._metric_target == MetricTarget.MASKS: - content_1 = self._expand_mask_shape(content_1) - content_2 = self._expand_mask_shape(content_2) - - for class_id in set(class_ids_1): - content_of_class = content_1[class_ids_1 == class_id] - stored_content_of_class = self._data_1.get(class_id, len0_like(content_1)) - self._data_1[class_id] = np.vstack( - (stored_content_of_class, content_of_class) - ) + class_id = self._get_class_id(detections) + self._append_class_id(class_id) - for class_id in set(class_ids_2): - content_of_class = content_2[class_ids_2 == class_id] - stored_content_of_class = self._data_2.get(class_id, len0_like(content_2)) - self._data_2[class_id] = np.vstack( - (stored_content_of_class, content_of_class) + if len(self.class_id) != len(self.confidence) or len(self.class_id) != len( + self.data + ): + raise ValueError( + f"Inconsistent data length: class_id={len(class_id)}," + f" confidence={len(confidence)}, data={len(new_data)}" ) - def __getitem__(self, class_id: int) -> Tuple[npt.NDArray, npt.NDArray]: - return ( - self._data_1.get(class_id, self._make_empty()), - self._data_2.get(class_id, self._make_empty()), - ) + def get_classes(self) -> Set[int]: + """Return all class IDs.""" + return set(self.class_id) - def __iter__( - self, - ) -> Iterator[Tuple[int, npt.NDArray, npt.NDArray]]: - class_ids = sorted(set(self._data_1.keys()) | set(self._data_2.keys())) - for class_id in class_ids: - yield ( - class_id, - *self[class_id], - ) + def get_subset_by_class(self, class_id: int) -> MetricData: + """Return data, confidence and class_id for a specific class.""" + mask = self.class_id == class_id + new_data_obj = MetricData(self._metric_target) + new_data_obj.data = self.data[mask] + new_data_obj.confidence = self.confidence[mask] + new_data_obj.class_id = self.class_id[mask] + return new_data_obj - def _get_content(self, data: Union[npt.NDArray, Detections]) -> npt.NDArray: - """Return boxes, masks or oriented bounding boxes from the data.""" - if not isinstance(data, (Detections, np.ndarray)): - raise ValueError( - f"Invalid data type: {type(data)}." - f" Only Detections or np.ndarray are supported." - ) - if isinstance(data, np.ndarray): - return data + def __len__(self) -> int: + return len(self.data) + def _get_content(self, detections: Detections) -> npt.NDArray: + """Return boxes, masks or oriented bounding boxes from the data.""" if self._metric_target == MetricTarget.BOXES: - return data.xyxy + return detections.xyxy if self._metric_target == MetricTarget.MASKS: return ( - data.mask if data.mask is not None else np.zeros((0, 0, 0), dtype=bool) + detections.mask + if detections.mask is not None + else self._get_empty_data() ) if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: - obb = data.data.get( - config.ORIENTED_BOX_COORDINATES, np.zeros((0, 8), dtype=np.float32) + obb = detections.data.get( + config.ORIENTED_BOX_COORDINATES, self._get_empty_data() ) - return np.array(obb, dtype=np.float32) + return np.ndarray(obb, dtype=np.float32) raise ValueError(f"Invalid metric target: {self._metric_target}") - def _get_class_ids( - self, data: Union[npt.NDArray, Detections] - ) -> npt.NDArray[np.int_]: - """ - Return an array of class IDs from the data. Guaranteed to - match the length of data. - """ - if ( - self._class_agnostic - or isinstance(data, np.ndarray) - or data.class_id is None - ): - return np.array([CLASS_ID_NONE] * len(data), dtype=int) - return data.class_id - - def _validate_class_ids( - self, class_id_1: npt.NDArray[np.int_], class_id_2: npt.NDArray[np.int_] - ) -> None: - class_set = set(class_id_1) | set(class_id_2) - if len(class_set) >= 2 and CLASS_ID_NONE in class_set: - raise ValueError( - "Metrics cannot mix data with class ID and data without class ID." - ) + def _get_class_id(self, detections: Detections) -> npt.NDArray[np.int_]: + if self._class_agnostic or detections.class_id is None: + return np.array([CLASS_ID_NONE] * len(detections), dtype=int) + return detections.class_id + + def _get_confidence(self, detections: Detections) -> npt.NDArray[np.float32]: + if detections.confidence is None: + return np.full(len(detections), -1, dtype=np.float32) + return detections.confidence + + def _append_class_id(self, new_class_id: npt.NDArray[np.int_]) -> None: + self.class_id = np.hstack((self.class_id, new_class_id)) + + def _append_confidence(self, new_confidence: npt.NDArray[np.float32]) -> None: + self.confidence = np.hstack((self.confidence, new_confidence)) + + def _append_boxes(self, new_boxes: npt.NDArray[np.float32]) -> None: + """Stack new xyxy or obb boxes on top of stored boxes.""" + if self._metric_target not in [ + MetricTarget.BOXES, + MetricTarget.ORIENTED_BOUNDING_BOXES, + ]: + raise ValueError("This method is only for box data.") + self.data = np.vstack((self.data, new_boxes)) + + def _append_mask(self, new_mask: npt.NDArray[np.bool_]) -> None: + """Stack new mask onto stored masks. Expand the shapes if necessary.""" + if self._metric_target != MetricTarget.MASKS: + raise ValueError("This method is only for mask data.") + self._validate_mask_shape(new_mask) + + new_width = max(self.data.shape[1], new_mask.shape[1]) + new_height = max(self.data.shape[2], new_mask.shape[2]) + + data = pad_mask(self.data, (new_width, new_height)) + new_mask = pad_mask(new_mask, (new_width, new_height)) + + self.data = np.vstack((data, new_mask)) + + def _get_empty_data(self) -> npt.NDArray: + if self._metric_target == MetricTarget.BOXES: + return np.empty((0, 4), dtype=np.float32) + if self._metric_target == MetricTarget.MASKS: + return np.empty((0, 0, 0), dtype=bool) + if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: + return np.empty((0, 8), dtype=np.float32) + raise ValueError(f"Invalid metric target: {self._metric_target}") def _validate_shape(self, data: npt.NDArray) -> None: - shape = data.shape if self._metric_target == MetricTarget.BOXES: - if len(shape) != 2 or shape[1] != 4: - raise ValueError(f"Invalid xyxy shape: {shape}. Expected: (N, 4)") + if len(data.shape) != 2 or data.shape[1] != 4: + raise ValueError(f"Invalid xyxy shape: {data.shape}. Expected: (N, 4)") elif self._metric_target == MetricTarget.MASKS: - if len(shape) != 3: - raise ValueError(f"Invalid mask shape: {shape}. Expected: (N, H, W)") + if len(data.shape) != 3: + raise ValueError( + f"Invalid mask shape: {data.shape}. Expected: (N, H, W)" + ) elif self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: - if len(shape) != 2 or shape[1] != 8: - raise ValueError(f"Invalid obb shape: {shape}. Expected: (N, 8)") + if len(data.shape) != 2 or data.shape[1] != 8: + raise ValueError(f"Invalid obb shape: {data.shape}. Expected: (N, 8)") else: raise ValueError(f"Invalid metric target: {self._metric_target}") - def _expand_mask_shape(self, data: npt.NDArray) -> npt.NDArray: - """Pad the stored and new data to the same shape.""" - if self._metric_target != MetricTarget.MASKS: - return data - new_width = max(self._mask_shape[0], data.shape[1]) - new_height = max(self._mask_shape[1], data.shape[2]) - self._mask_shape = (new_width, new_height) +class InternalMetricDataStore: + """ + Stores internal data for metrics. - data = pad_mask(data, self._mask_shape) + Provides a class-agnostic way to access it. + """ - for class_id, prev_data in self._data_1.items(): - self._data_1[class_id] = pad_mask(prev_data, self._mask_shape) - for class_id, prev_data in self._data_2.items(): - self._data_2[class_id] = pad_mask(prev_data, self._mask_shape) + def __init__(self, metric_target: MetricTarget, class_agnostic: bool = False): + self._metric_target = metric_target + self._class_agnostic = class_agnostic + self._data_1: MetricData + self._data_2: MetricData + self.reset() - return data + def reset(self) -> None: + self._data_1 = MetricData(self._metric_target, self._class_agnostic) + self._data_2 = MetricData(self._metric_target, self._class_agnostic) - def _make_empty(self) -> npt.NDArray: - """Create an empty data object with the best-known shape for the target.""" - if self._metric_target == MetricTarget.BOXES: - return np.empty((0, 4), dtype=np.float32) - if self._metric_target == MetricTarget.MASKS: - return np.empty((0, *self._mask_shape), dtype=bool) - if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: - return np.empty((0, 8), dtype=np.float32) - raise ValueError(f"Invalid metric target: {self._metric_target}") + def update(self, data_1: Detections, data_2: Detections) -> None: + """ + Add new data to the store. + + Use sv.Detections.empty() if only one set of data is available. + """ + self._data_1.update(data_1) + self._data_2.update(data_2) + + def __getitem__(self, class_id: int) -> Tuple[MetricData, MetricData]: + return ( + self._data_1.get_subset_by_class(class_id), + self._data_2.get_subset_by_class(class_id), + ) + + def __iter__(self) -> Iterator[Tuple[int, MetricData, MetricData]]: + for class_id in self._data_1.get_classes(): + yield class_id, *self[class_id] diff --git a/supervision/metrics/intersection_over_union.py b/supervision/metrics/intersection_over_union.py index f1b29e800..6d64f223a 100644 --- a/supervision/metrics/intersection_over_union.py +++ b/supervision/metrics/intersection_over_union.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from itertools import zip_longest from typing import TYPE_CHECKING, Dict, List, Optional, Union @@ -5,7 +7,6 @@ import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt -from typing_extensions import Self from supervision.detection.core import Detections from supervision.detection.utils import box_iou_batch @@ -16,77 +17,6 @@ import pandas as pd -@dataclass -class IntersectionOverUnionResult: - ious: Dict[int, npt.NDArray[np.float32]] - metric_target: MetricTarget - - @property - def class_ids(self) -> List[int]: - return list(self.ious.keys()) - - def __getitem__(self, class_id: int) -> npt.NDArray[np.float32]: - return self.ious[class_id] - - def __iter__(self): - return iter(self.ious.items()) - - def to_pandas(self) -> Dict[int, "pd.DataFrame"]: - ensure_pandas_installed() - return {class_id: pd.DataFrame(iou) for class_id, iou in self.ious.items()} - - def plot(self, class_id=None): - """ - Visualize the IoU results. - - Args: - class_id (Optional[int]): The class ID to visualize. If not - provided, all classes will be visualized. - """ - if class_id: - self._plot_class(class_id) - else: - for cls in self.ious: - self._plot_class(cls) - - def _plot_class(self, class_id): - """ - Helper function to plot a single class IoU matrix or show - zero-sized information. - - Args: - class_id (int): The class ID to plot. - """ - iou_matrix = self.ious[class_id] - - if iou_matrix.size == 0: - print( - f"No data for class {class_id}, with result shape" - f" {iou_matrix.shape}. Skipping plot." - ) - else: - plt.figure(figsize=(6, 6)) - plt.matshow(iou_matrix, cmap="viridis", fignum=1) - plt.title(f"Class {class_id} IoU Matrix", pad=20) - plt.gca().xaxis.set_ticks_position("bottom") - plt.xlabel("Target Bounding Boxes") - plt.ylabel("Predicted Bounding Boxes") - plt.colorbar() - - for (i, j), val in np.ndenumerate(iou_matrix): - plt.text( - j, - i, - f"{val:.2f}", - ha="center", - va="center", - fontsize=8, - color="white" if val < 0.5 else "black", - ) - - plt.show() - - class IntersectionOverUnion(Metric): def __init__( self, @@ -101,12 +31,12 @@ def __init__( metric_target (MetricTarget): The type of detection data to use. class_agnostic (bool): Whether to treat all data as a single class. Defaults to `False`. - shared_data_store (Optional[InternalMetricDataStore]): If you are composing - multiple metrics, you can pass a data store here to share data. Objects - UP in the hierarchy are responsible for resetting the store and updating - data. + shared_data_store (Optional[InternalMetricDataStore]): If you have + a hierarchy of metrics, you can pass a data store to share it + between them, saving memory. The responsibility of updating + the store falls on the parent metric (that contain this one). """ - if metric_target in [MetricTarget.MASKS, MetricTarget.ORIENTED_BOUNDING_BOXES]: + if metric_target != MetricTarget.BOXES: raise NotImplementedError( f"Intersection over union is not implemented for {metric_target}." ) @@ -130,9 +60,10 @@ def update( self, data_1: Union[Detections, List[Detections]], data_2: Union[Detections, List[Detections]], - ) -> Self: + ) -> IntersectionOverUnion: """ Add data to the metric, without computing the result. + Should call all update methods of the shared data store. Args: data_1 (Union[Detection, List[Detections]]): The first set of data. @@ -143,6 +74,7 @@ def update( by calling the `compute` method. """ if self._is_store_shared: + # Should be updated by the parent metric return self if not isinstance(data_1, list): @@ -159,10 +91,9 @@ def _update( self, data_1: Detections, data_2: Detections, - ) -> Self: + ) -> None: assert not self._is_store_shared self._store.update(data_1, data_2) - return self def compute(self) -> IntersectionOverUnionResult: """ @@ -171,16 +102,83 @@ def compute(self) -> IntersectionOverUnionResult: Returns: Dict[int, npt.NDArray[np.float32]]: A dictionary with class IDs as keys. - If no class ID is provided, the key is the value CLASS_ID_NONE. + If no class ID is provided, the key is the value CLASS_ID_NONE. The values + are (N, M) arrays where N is the number of predictions and M is the number + of targets. """ ious = {} - for class_id, array_1, array_2 in self._store: - if self._metric_target == MetricTarget.BOXES: - iou = box_iou_batch(array_1, array_2) - else: - raise NotImplementedError( - f"IoU is not implemented for {self._metric_target}." - ) + for class_id, data_1, data_2 in self._store: + iou = box_iou_batch(data_1.data, data_2.data).transpose() ious[class_id] = iou - return IntersectionOverUnionResult(ious, self._metric_target) + + +@dataclass +class IntersectionOverUnionResult: + ious: Dict[int, npt.NDArray[np.float32]] + metric_target: MetricTarget + + @property + def class_ids(self) -> List[int]: + return list(self.ious.keys()) + + def __getitem__(self, class_id: int) -> npt.NDArray[np.float32]: + return self.ious[class_id] + + def __iter__(self): + return iter(self.ious.items()) + + def to_pandas(self) -> Dict[int, "pd.DataFrame"]: + ensure_pandas_installed() + return {class_id: pd.DataFrame(iou) for class_id, iou in self.ious.items()} + + def plot(self, class_id=None): + """ + Visualize the IoU results. + + Args: + class_id (Optional[int]): The class ID to visualize. If not + provided, all classes will be visualized. + """ + if class_id: + self._plot_class(class_id) + else: + for cls in self.ious: + self._plot_class(cls) + + def _plot_class(self, class_id): + """ + Helper function to plot a single class IoU matrix or show + zero-sized information. + + Args: + class_id (int): The class ID to plot. + """ + iou_matrix = self.ious[class_id] + + if iou_matrix.size == 0: + print( + f"No data for class {class_id}, with result shape" + f" {iou_matrix.shape}. Skipping plot." + ) + else: + plt.figure(figsize=(6, 6)) + plt.matshow(iou_matrix, cmap="viridis", fignum=1) + plt.title(f"Class {class_id} IoU Matrix", pad=20) + plt.gca().xaxis.set_ticks_position("bottom") + plt.xlabel("Target Bounding Boxes") + plt.ylabel("Predicted Bounding Boxes") + plt.colorbar() + + for (i, j), val in np.ndenumerate(iou_matrix): + plt.text( + j, + i, + f"{val:.2f}", + ha="center", + va="center", + fontsize=8, + color="white" if val < 0.5 else "black", + ) + + plt.show() diff --git a/supervision/metrics/mean_average_precision.py b/supervision/metrics/mean_average_precision.py new file mode 100644 index 000000000..ed486241c --- /dev/null +++ b/supervision/metrics/mean_average_precision.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +from dataclasses import dataclass +from itertools import zip_longest +from typing import Dict, List, Union + +import numpy as np +import numpy.typing as npt + +from supervision.detection.core import Detections +from supervision.metrics.core import ( + InternalMetricDataStore, + Metric, + MetricData, + MetricTarget, +) +from supervision.metrics.intersection_over_union import IntersectionOverUnion + + +class MeanAveragePrecision(Metric): + def __init__( + self, + metric_target: MetricTarget = MetricTarget.BOXES, + class_agnostic: bool = False, + iou_threshold: float = 0.25, + ): + self._metric_target = metric_target + if self._metric_target != MetricTarget.BOXES: + raise NotImplementedError( + f"mAP is not implemented for {self._metric_target}." + ) + + self._class_agnostic = class_agnostic + self._iou_threshold = iou_threshold + + self._store = InternalMetricDataStore(metric_target, class_agnostic) + self._iou_metric = IntersectionOverUnion(metric_target, class_agnostic) + + self.reset() + + def reset(self) -> None: + self._iou_metric.reset() + self._store.reset() + + def update( + self, + predictions: Union[Detections, List[Detections]], + targets: Union[Detections, List[Detections]], + ) -> MeanAveragePrecision: + if not isinstance(predictions, list): + predictions = [predictions] + if not isinstance(targets, list): + targets = [targets] + + for d1, d2 in zip_longest(predictions, targets, fillvalue=Detections.empty()): + self._update(d1, d2) + + return self + + def _update(self, predictions: Detections, targets: Detections) -> None: + self._store.update(predictions, targets) + self._iou_metric.update(predictions, targets) + + def compute(self) -> MeanAveragePrecisionResult: + ious = self._iou_metric.compute() + iou_thresholds = np.linspace(0.5, 0.95, 10) + + average_precisions: Dict[int, npt.NDArray] = {} + for class_id, prediction_data, target_data in self._store: + if len(target_data) == 0: + continue + + if len(prediction_data) == 0: + stats = ( + np.zeros((0, iou_thresholds.size), dtype=bool), + np.array([], dtype=np.float32), + np.array([], dtype=np.float32), + target_data.class_id, + ) + else: + ious_of_class = ious[class_id] + matches = self._match_predictions_to_targets( + prediction_data, target_data, ious_of_class, iou_thresholds + ) + stats = ( + matches, + prediction_data.confidence, + prediction_data.class_id, + target_data.class_id, + ) + + to_concat = [np.expand_dims(item, 0) for item in stats] + for x in to_concat: + print(x.shape) + # print(to_concat) + + # TODO: class_id size mismatch + # (1, 9, 10) + # (1, 9) + # (1, 9) + # (1, 12) + + concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)] + average_precisions[class_id] = self._average_precisions_per_class( + *concatenated_stats + ) + + return MeanAveragePrecisionResult(average_precisions) + + def _match_predictions_to_targets( + self, + prediction_data: MetricData, + target_data: MetricData, + predictions_iou: npt.NDArray[np.float32], + iou_thresholds: npt.NDArray[np.float32], + ) -> npt.NDArray[np.bool_]: + """ + Match predictions to targets based on IoU. + + Given N predictions, M targets and T IoU thresholds, + returns a boolean array (N, T), where each element is True + if the prediction is a true positive at the given IoU threshold. + + Assumes that predictions were already filtered by class. + """ + if set(prediction_data.class_id) != set(target_data.class_id): + raise ValueError( + f"Class IDs of predictions and targets" + f" do not match: {prediction_data.class_id}, {target_data.class_id}" + ) + + correct = np.zeros((len(prediction_data), len(iou_thresholds)), dtype=bool) + for i, iou_level in enumerate(iou_thresholds): + matched_indices = np.where((predictions_iou >= iou_level)) + + if matched_indices[0].shape[0]: + combined_indices = np.stack(matched_indices, axis=1) + iou_values = predictions_iou[matched_indices][:, None] + matches = np.hstack([combined_indices, iou_values]) + + if matched_indices[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + + _, unique_pred_idx = np.unique(matches[:, 1], return_index=True) + matches = matches[unique_pred_idx] + _, unique_target_idx = np.unique(matches[:, 0], return_index=True) + matches = matches[unique_target_idx] + + correct[matches[:, 1].astype(int), i] = True + + return correct + + def _average_precisions_per_class( + self, + matches: np.ndarray, + prediction_confidence: np.ndarray, + prediction_class_ids: np.ndarray, + true_class_ids: np.ndarray, + eps: float = 1e-16, + ) -> np.ndarray: + """ + Compute the average precision, given the recall and precision curves. + Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. + + Args: + matches (np.ndarray): True positives. + prediction_confidence (np.ndarray): Objectness value from 0-1. + prediction_class_ids (np.ndarray): Predicted object classes. + true_class_ids (np.ndarray): True object classes. + eps (float, optional): Small value to prevent division by zero. + + Returns: + np.ndarray: Average precision for different IoU levels. + """ + sorted_indices = np.argsort(-prediction_confidence) + matches = matches[sorted_indices] + prediction_class_ids = prediction_class_ids[sorted_indices] + + unique_classes, class_counts = np.unique(true_class_ids, return_counts=True) + num_classes = unique_classes.shape[0] + + average_precisions = np.zeros((num_classes, matches.shape[1])) + + for class_idx, class_id in enumerate(unique_classes): + is_class = prediction_class_ids == class_id + total_true = class_counts[class_idx] + total_prediction = is_class.sum() + + if total_prediction == 0 or total_true == 0: + continue + + false_positives = (1 - matches[is_class]).cumsum(0) + true_positives = matches[is_class].cumsum(0) + recall = true_positives / (total_true + eps) + precision = true_positives / (true_positives + false_positives) + + for iou_level_idx in range(matches.shape[1]): + average_precisions[class_idx, iou_level_idx] = ( + self._compute_average_precision( + recall[:, iou_level_idx], precision[:, iou_level_idx] + ) + ) + + return average_precisions + + @staticmethod + def _compute_average_precision(recall: np.ndarray, precision: np.ndarray) -> float: + """ + Compute the average precision using 101-point interpolation (COCO), given + the recall and precision curves. + + Args: + recall (np.ndarray): The recall curve. + precision (np.ndarray): The precision curve. + + Returns: + float: Average precision. + """ + assert len(recall) == len(precision) + + extended_recall = np.concatenate(([0.0], recall, [1.0])) + extended_precision = np.concatenate(([1.0], precision, [0.0])) + max_accumulated_precision = np.flip( + np.maximum.accumulate(np.flip(extended_precision)) + ) + interpolated_recall_levels = np.linspace(0, 1, 101) + interpolated_precision = np.interp( + interpolated_recall_levels, extended_recall, max_accumulated_precision + ) + average_precision = np.trapz(interpolated_precision, interpolated_recall_levels) + return average_precision + + +@dataclass +class MeanAveragePrecisionResult: + # TODO: continue here + average_precisions: Dict[int, npt.NDArray]