From 20516595c4eec95563df52df3632fabd6847c77f Mon Sep 17 00:00:00 2001 From: tanemaki Date: Mon, 25 Nov 2024 23:36:19 +0900 Subject: [PATCH 1/2] * Fix the bug in F1AdaptiveThreshold which occurs only when there are no anomalous images in a validation set Signed-off-by: tanemaki --- src/anomalib/metrics/threshold/f1_adaptive_threshold.py | 5 +++++ tests/unit/metrics/test_adaptive_threshold.py | 2 ++ 2 files changed, 7 insertions(+) diff --git a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py index cb2ba1cd19..408d5956d6 100644 --- a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py +++ b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py @@ -75,6 +75,11 @@ def compute(self) -> torch.Tensor: # special case where recall is 1.0 even for the highest threshold. # In this case 'thresholds' will be scalar. self.value = thresholds + elif not any(1 in batch for batch in self.target): + # another special case where there are no anomalous image in the validation set. + # In this case, the adaptive threshold will take the value of the highest anomaly score observed in the + # normal validation images. + self.value = torch.max(thresholds) else: self.value = thresholds[torch.argmax(f1_score)] return self.value diff --git a/tests/unit/metrics/test_adaptive_threshold.py b/tests/unit/metrics/test_adaptive_threshold.py index 1eadab4e4d..fdeffb7d24 100644 --- a/tests/unit/metrics/test_adaptive_threshold.py +++ b/tests/unit/metrics/test_adaptive_threshold.py @@ -18,6 +18,8 @@ [ (torch.Tensor([0, 0, 0, 1, 1]), torch.Tensor([2.3, 1.6, 2.6, 7.9, 3.3]), 3.3), # standard case (torch.Tensor([1, 0, 0, 0]), torch.Tensor([4, 3, 2, 1]), 4), # 100% recall for all thresholds + (torch.Tensor([1, 1, 1, 1]), torch.Tensor([4, 3, 2, 1]), 1), # use minimum value when all images are anomalous + (torch.Tensor([0, 0, 0, 0]), torch.Tensor([4, 3, 2, 1]), 4), # use maximum value when all images are normal ], ) def test_adaptive_threshold(labels: torch.Tensor, preds: torch.Tensor, target_threshold: int | float) -> None: From a80ab46de4f50dd2013fbc258ce4a8d83075f675 Mon Sep 17 00:00:00 2001 From: tanemaki Date: Thu, 28 Nov 2024 11:41:42 +0900 Subject: [PATCH 2/2] * Avoid unnecessary computation of the PR curve when all targets are either normal or anomalous. * Add an explicit check in the metric implementation for the case where all images are anomalous. Signed-off-by: tanemaki --- .../threshold/f1_adaptive_threshold.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py index 408d5956d6..5a7a618402 100644 --- a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py +++ b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py @@ -6,6 +6,7 @@ import logging import torch +from torchmetrics.utilities.data import dim_zero_cat from anomalib.metrics.precision_recall_curve import BinaryPrecisionRecallCurve @@ -69,17 +70,29 @@ def compute(self) -> torch.Tensor: ) logging.warning(msg) + self.value = torch.max(dim_zero_cat(self.preds)) + + return self.value + + if not any(0 in batch for batch in self.target): + msg = ( + "The validation set does not contain any normal images. As a result, the adaptive threshold will " + "take the value of the lowest anomaly score observed in the anomalous validation images, which may " + "lead to poor predictions. For a more reliable adaptive threshold computation, please add some normal " + "images to the validation set." + ) + logging.warning(msg) + + self.value = torch.min(dim_zero_cat(self.preds)) + + return self.value + precision, recall, thresholds = super().compute() f1_score = (2 * precision * recall) / (precision + recall + 1e-10) if thresholds.dim() == 0: # special case where recall is 1.0 even for the highest threshold. # In this case 'thresholds' will be scalar. self.value = thresholds - elif not any(1 in batch for batch in self.target): - # another special case where there are no anomalous image in the validation set. - # In this case, the adaptive threshold will take the value of the highest anomaly score observed in the - # normal validation images. - self.value = torch.max(thresholds) else: self.value = thresholds[torch.argmax(f1_score)] return self.value