From 270425645948c03dac6fb9fb9e8d53de1de17899 Mon Sep 17 00:00:00 2001 From: Dmitriy Jarosh Date: Thu, 28 Jul 2022 21:33:47 +0300 Subject: [PATCH] Fix docs and interface for iou, dice --- src/evops/metrics/DefaultBenchmark.py | 4 ++++ src/evops/metrics/MeanBenchmark.py | 2 +- src/evops/metrics/metrics.py | 29 ++++++++++----------------- 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/src/evops/metrics/DefaultBenchmark.py b/src/evops/metrics/DefaultBenchmark.py index 722deec..a1b3269 100644 --- a/src/evops/metrics/DefaultBenchmark.py +++ b/src/evops/metrics/DefaultBenchmark.py @@ -50,4 +50,8 @@ def __fScore( precision = __precision(pred_labels, gt_labels, tp_condition) recall = __recall(pred_labels, gt_labels, tp_condition) + # Prevent division by zero + if precision + recall == 0: + return 0.0 + return 2 * precision * recall / (precision + recall) diff --git a/src/evops/metrics/MeanBenchmark.py b/src/evops/metrics/MeanBenchmark.py index ede3b1f..7e11e42 100644 --- a/src/evops/metrics/MeanBenchmark.py +++ b/src/evops/metrics/MeanBenchmark.py @@ -23,7 +23,7 @@ def __mean( pred_labels: NDArray[Any, np.int32], gt_labels: NDArray[Any, np.int32], metric: Callable[ - [NDArray[(Any, 3), np.float64], NDArray[Any, np.int32], NDArray[Any, np.int32]], + [NDArray[Any, np.int32], NDArray[Any, np.int32]], np.float64, ], ) -> np.float64: diff --git a/src/evops/metrics/metrics.py b/src/evops/metrics/metrics.py index a9b7905..718ec98 100644 --- a/src/evops/metrics/metrics.py +++ b/src/evops/metrics/metrics.py @@ -29,33 +29,31 @@ def iou( - pred_labels: NDArray[Any, np.int32], - gt_labels: NDArray[Any, np.int32], + pred_indices: NDArray[Any, np.int32], + gt_indices: NDArray[Any, np.int32], ) -> np.float64: """ - :param pc_points: source point cloud :param pred_indices: indices of points that belong to one plane obtained as a result of segmentation :param gt_indices: indices of points belonging to the reference plane :return: iou metric value for plane """ - __iou_dice_mean_bechmark_asserts(pred_labels, gt_labels) + __iou_dice_mean_bechmark_asserts(pred_indices, gt_indices) - return __iou(pred_labels, gt_labels) + return __iou(pred_indices, gt_indices) def dice( - pred_labels: NDArray[Any, np.int32], - gt_labels: NDArray[Any, np.int32], + pred_indices: NDArray[Any, np.int32], + gt_indices: NDArray[Any, np.int32], ) -> np.float64: """ - :param pc_points: source point cloud - :param pred_labels: labels of points that belong to one planes obtained as a result of segmentation - :param gt_labels: labels of points belonging to the reference planes + :param pred_indices: labels of points that belong to one plane obtained as a result of segmentation + :param gt_indices: labels of points belonging to the reference plane :return: iou metric value for plane """ - __iou_dice_mean_bechmark_asserts(pred_labels, gt_labels) + __iou_dice_mean_bechmark_asserts(pred_indices, gt_indices) - return __dice(pred_labels, gt_labels) + return __dice(pred_indices, gt_indices) def precision( @@ -64,7 +62,6 @@ def precision( tp_condition: str, ) -> np.float64: """ - :param pc_points: source point cloud :param pred_labels: labels of points that belong to one planes obtained as a result of segmentation :param gt_labels: labels of points belonging to the reference planes :param tp_condition: helper function to calculate statistics: {'iou'} @@ -81,7 +78,6 @@ def recall( tp_condition: str, ) -> np.float64: """ - :param pc_points: source point cloud :param pred_labels: indices of points that belong to one plane obtained as a result of segmentation :param gt_labels: indices of points belonging to the reference plane :param tp_condition: helper function to calculate statistics: {'iou'} @@ -98,7 +94,6 @@ def fScore( tp_condition: str, ) -> np.float64: """ - :param pc_points: source point cloud :param pred_labels: indices of points that belong to one plane obtained as a result of segmentation :param gt_labels: indices of points belonging to the reference plane :param tp_condition: helper function to calculate statistics: {'iou'} @@ -113,12 +108,11 @@ def mean( pred_labels: NDArray[Any, np.int32], gt_labels: NDArray[Any, np.int32], metric: Callable[ - [NDArray[(Any, 3), np.float64], NDArray[Any, np.int32], NDArray[Any, np.int32]], + [NDArray[Any, np.int32], NDArray[Any, np.int32]], np.float64, ], ) -> np.float64: """ - :param pc_points: source point cloud :param pred_labels: labels of points obtained as a result of segmentation :param gt_labels: reference labels of point cloud :param metric: metric function for which you want to get the mean value @@ -135,7 +129,6 @@ def multi_value( overlap_threshold: np.float64 = 0.8, ) -> (np.float64, np.float64, np.float64, np.float64, np.float64, np.float64): """ - :param pc_points: source point cloud :param pred_labels: labels of points obtained as a result of segmentation :param gt_labels: reference labels of point cloud :param overlap_threshold: minimum value at which the planes are considered intersected