diff --git a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py index cb2ba1cd19..5a7a618402 100644 --- a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py +++ b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py @@ -6,6 +6,7 @@ import logging import torch +from torchmetrics.utilities.data import dim_zero_cat from anomalib.metrics.precision_recall_curve import BinaryPrecisionRecallCurve @@ -69,6 +70,23 @@ def compute(self) -> torch.Tensor: ) logging.warning(msg) + self.value = torch.max(dim_zero_cat(self.preds)) + + return self.value + + if not any(0 in batch for batch in self.target): + msg = ( + "The validation set does not contain any normal images. As a result, the adaptive threshold will " + "take the value of the lowest anomaly score observed in the anomalous validation images, which may " + "lead to poor predictions. For a more reliable adaptive threshold computation, please add some normal " + "images to the validation set."
+ ) + logging.warning(msg) + + self.value = torch.min(dim_zero_cat(self.preds)) + + return self.value + precision, recall, thresholds = super().compute() f1_score = (2 * precision * recall) / (precision + recall + 1e-10) if thresholds.dim() == 0: diff --git a/tests/unit/metrics/test_adaptive_threshold.py b/tests/unit/metrics/test_adaptive_threshold.py index 1eadab4e4d..fdeffb7d24 100644 --- a/tests/unit/metrics/test_adaptive_threshold.py +++ b/tests/unit/metrics/test_adaptive_threshold.py @@ -18,6 +18,8 @@ [ (torch.Tensor([0, 0, 0, 1, 1]), torch.Tensor([2.3, 1.6, 2.6, 7.9, 3.3]), 3.3), # standard case (torch.Tensor([1, 0, 0, 0]), torch.Tensor([4, 3, 2, 1]), 4), # 100% recall for all thresholds + (torch.Tensor([1, 1, 1, 1]), torch.Tensor([4, 3, 2, 1]), 1), # use minimum value when all images are anomalous + (torch.Tensor([0, 0, 0, 0]), torch.Tensor([4, 3, 2, 1]), 4), # use maximum value when all images are normal ], ) def test_adaptive_threshold(labels: torch.Tensor, preds: torch.Tensor, target_threshold: int | float) -> None: