Skip to content

Commit

Permalink
Extract IoU aout of matcher
Browse files Browse the repository at this point in the history
  • Loading branch information
LinasKo committed Aug 23, 2024
1 parent 1b9b3f2 commit 8905768
Showing 1 changed file with 13 additions and 60 deletions.
73 changes: 13 additions & 60 deletions supervision/metrics/mean_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,6 @@ def _compute(
targets: np.ndarray,
target_classes: np.ndarray,
) -> MeanAveragePrecisionResult:
predictions = np.hstack(
[predictions, prediction_classes[:, None], prediction_confidence[:, None]]
)
targets = np.hstack([targets, target_classes[:, None]])

self._validate_input_tensors([predictions], [targets])
iou_thresholds = np.linspace(0.5, 0.95, 10)
stats = []

Expand All @@ -201,20 +195,21 @@ def _compute(
np.zeros((0, iou_thresholds.size), dtype=bool),
np.zeros((0,), dtype=np.float32),
np.zeros((0,), dtype=int),
targets[:, 4],
target_classes,
)
)

else:
iou = box_iou_batch(targets, predictions)
matches = self._match_detection_batch(
predictions, targets, iou_thresholds
prediction_classes, target_classes, iou, iou_thresholds
)
stats.append(
(
matches,
prediction_confidence,
prediction_classes,
targets[:, 4],
target_classes,
)
)

Expand Down Expand Up @@ -263,29 +258,17 @@ def compute_average_precision(recall: np.ndarray, precision: np.ndarray) -> floa

@staticmethod
def _match_detection_batch(
predictions: np.ndarray, targets: np.ndarray, iou_thresholds: np.ndarray
predictions_classes: np.ndarray,
target_classes: np.ndarray,
iou: np.ndarray,
iou_thresholds: np.ndarray,
) -> np.ndarray:
"""
Match predictions with target labels based on IoU levels.
Args:
predictions (np.ndarray): Batch prediction. Describes a single image and
has `shape = (M, 6)` where `M` is the number of detected objects.
Each row is expected to be in
`(x_min, y_min, x_max, y_max, class, conf)` format.
targets (np.ndarray): Batch target labels. Describes a single image and
has `shape = (N, 5)` where `N` is the number of ground-truth objects.
Each row is expected to be in
`(x_min, y_min, x_max, y_max, class)` format.
iou_thresholds (np.ndarray): Array contains different IoU thresholds.
Returns:
np.ndarray: Matched prediction with target labels result.
"""
num_predictions, num_iou_levels = predictions.shape[0], iou_thresholds.shape[0]
num_predictions, num_iou_levels = (
predictions_classes.shape[0],
iou_thresholds.shape[0],
)
correct = np.zeros((num_predictions, num_iou_levels), dtype=bool)
iou = box_iou_batch(targets[:, :4], predictions[:, :4])
correct_class = targets[:, 4:5] == predictions[:, 4]
correct_class = target_classes[:, None] == predictions_classes

for i, iou_level in enumerate(iou_thresholds):
matched_indices = np.where((iou >= iou_level) & correct_class)
Expand Down Expand Up @@ -360,36 +343,6 @@ def _average_precisions_per_class(

return average_precisions

@staticmethod
def _validate_input_tensors(
predictions: List[np.ndarray], targets: List[np.ndarray]
):
"""
Checks for shape consistency of input tensors.
"""
if len(predictions) != len(targets):
raise ValueError(
f"Number of predictions ({len(predictions)}) and"
f"targets ({len(targets)}) must be equal."
)
if len(predictions) > 0:
if not isinstance(predictions[0], np.ndarray) or not isinstance(
targets[0], np.ndarray
):
raise ValueError(
f"Predictions and targets must be lists of numpy arrays."
f"Got {type(predictions[0])} and {type(targets[0])} instead."
)
if predictions[0].shape[1] != 6:
raise ValueError(
f"Predictions must have shape (N, 6)."
f"Got {predictions[0].shape} instead."
)
if targets[0].shape[1] != 5:
raise ValueError(
f"Targets must have shape (N, 5). Got {targets[0].shape} instead."
)


@dataclass
class MeanAveragePrecisionResult:
Expand Down

0 comments on commit 8905768

Please sign in to comment.