Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/evops/metrics/DefaultBenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,8 @@ def __fScore(
precision = __precision(pred_labels, gt_labels, tp_condition)
recall = __recall(pred_labels, gt_labels, tp_condition)

# Prevent division by zero
if precision + recall == 0:
return 0.0

return 2 * precision * recall / (precision + recall)
2 changes: 1 addition & 1 deletion src/evops/metrics/MeanBenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __mean(
pred_labels: NDArray[Any, np.int32],
gt_labels: NDArray[Any, np.int32],
metric: Callable[
[NDArray[(Any, 3), np.float64], NDArray[Any, np.int32], NDArray[Any, np.int32]],
[NDArray[Any, np.int32], NDArray[Any, np.int32]],
np.float64,
],
) -> np.float64:
Expand Down
29 changes: 11 additions & 18 deletions src/evops/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,31 @@


def iou(
pred_labels: NDArray[Any, np.int32],
gt_labels: NDArray[Any, np.int32],
pred_indices: NDArray[Any, np.int32],
gt_indices: NDArray[Any, np.int32],
) -> np.float64:
"""
:param pc_points: source point cloud
:param pred_indices: indices of points that belong to one plane obtained as a result of segmentation
:param gt_indices: indices of points belonging to the reference plane
:return: iou metric value for plane
"""
__iou_dice_mean_bechmark_asserts(pred_labels, gt_labels)
__iou_dice_mean_bechmark_asserts(pred_indices, gt_indices)

return __iou(pred_labels, gt_labels)
return __iou(pred_indices, gt_indices)


def dice(
pred_labels: NDArray[Any, np.int32],
gt_labels: NDArray[Any, np.int32],
pred_indices: NDArray[Any, np.int32],
gt_indices: NDArray[Any, np.int32],
) -> np.float64:
"""
:param pc_points: source point cloud
:param pred_labels: labels of points that belong to one planes obtained as a result of segmentation
:param gt_labels: labels of points belonging to the reference planes
:param pred_indices: labels of points that belong to one plane obtained as a result of segmentation
:param gt_indices: labels of points belonging to the reference plane
:return: iou metric value for plane
"""
__iou_dice_mean_bechmark_asserts(pred_labels, gt_labels)
__iou_dice_mean_bechmark_asserts(pred_indices, gt_indices)

return __dice(pred_labels, gt_labels)
return __dice(pred_indices, gt_indices)


def precision(
Expand All @@ -64,7 +62,6 @@ def precision(
tp_condition: str,
) -> np.float64:
"""
:param pc_points: source point cloud
:param pred_labels: labels of points that belong to one planes obtained as a result of segmentation
:param gt_labels: labels of points belonging to the reference planes
:param tp_condition: helper function to calculate statistics: {'iou'}
Expand All @@ -81,7 +78,6 @@ def recall(
tp_condition: str,
) -> np.float64:
"""
:param pc_points: source point cloud
:param pred_labels: indices of points that belong to one plane obtained as a result of segmentation
:param gt_labels: indices of points belonging to the reference plane
:param tp_condition: helper function to calculate statistics: {'iou'}
Expand All @@ -98,7 +94,6 @@ def fScore(
tp_condition: str,
) -> np.float64:
"""
:param pc_points: source point cloud
:param pred_labels: indices of points that belong to one plane obtained as a result of segmentation
:param gt_labels: indices of points belonging to the reference plane
:param tp_condition: helper function to calculate statistics: {'iou'}
Expand All @@ -113,12 +108,11 @@ def mean(
pred_labels: NDArray[Any, np.int32],
gt_labels: NDArray[Any, np.int32],
metric: Callable[
[NDArray[(Any, 3), np.float64], NDArray[Any, np.int32], NDArray[Any, np.int32]],
[NDArray[Any, np.int32], NDArray[Any, np.int32]],
np.float64,
],
) -> np.float64:
"""
:param pc_points: source point cloud
:param pred_labels: labels of points obtained as a result of segmentation
:param gt_labels: reference labels of point cloud
:param metric: metric function for which you want to get the mean value
Expand All @@ -135,7 +129,6 @@ def multi_value(
overlap_threshold: np.float64 = 0.8,
) -> (np.float64, np.float64, np.float64, np.float64, np.float64, np.float64):
"""
:param pc_points: source point cloud
:param pred_labels: labels of points obtained as a result of segmentation
:param gt_labels: reference labels of point cloud
:param overlap_threshold: minimum value at which the planes are considered intersected
Expand Down