diff --git a/ava_evaluation/metrics.py b/ava_evaluation/metrics.py index 47116fac..de95b644 100644 --- a/ava_evaluation/metrics.py +++ b/ava_evaluation/metrics.py @@ -38,7 +38,7 @@ def compute_precision_recall(scores, labels, num_gt): """ if ( not isinstance(labels, np.ndarray) - or labels.dtype != np.bool + or labels.dtype != bool or len(labels.shape) != 1 ): raise ValueError("labels must be single dimension bool numpy array")