From 501ed9d4612ee30586d7e9155c5a081e57208a6a Mon Sep 17 00:00:00 2001 From: Kumar Abhishek Date: Fri, 9 May 2025 20:46:33 -0700 Subject: [PATCH 1/2] Add MCCLoss implementation for binary image segmentation --- pytorch_toolbelt/losses/__init__.py | 1 + pytorch_toolbelt/losses/mcc.py | 59 +++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 pytorch_toolbelt/losses/mcc.py diff --git a/pytorch_toolbelt/losses/__init__.py b/pytorch_toolbelt/losses/__init__.py index 50af25fad..87163f2b3 100644 --- a/pytorch_toolbelt/losses/__init__.py +++ b/pytorch_toolbelt/losses/__init__.py @@ -14,3 +14,4 @@ from .wing_loss import * from .logcosh import * from .quality_focal_loss import * +from .mcc import * diff --git a/pytorch_toolbelt/losses/mcc.py b/pytorch_toolbelt/losses/mcc.py new file mode 100644 index 000000000..0f0321f97 --- /dev/null +++ b/pytorch_toolbelt/losses/mcc.py @@ -0,0 +1,59 @@ +from typing import Optional + +import torch +from torch import Tensor +from torch.nn.modules.loss import _Loss + +__all__ = ["MCCLoss"] + + +class MCCLoss(_Loss): + """ + Implementation of Matthews Correlation Coefficient (MCC) loss for image segmentation task. + It supports binary cases. + Reference: https://github.com/kakumarabhishek/MCC-Loss + Paper: https://doi.org/10.1109/ISBI48211.2021.9433782 + """ + + def __init__(self, eps: Optional[float] = 1e-7): + """ + Initializes the MCCLoss class. + + :param eps: Small epsilon for numerical stability + """ + super().__init__() + self.eps = eps + + def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor: + """ + Computes the Matthews Correlation Coefficient (MCC) loss. + MCC = (TP.TN - FP.FN) / sqrt((TP+FP) . (TP+FN) . (TN+FP) . (TN+FN)) + where TP, TN, FP, and FN are elements in the confusion matrix. + + :param y_pred: Predicted probabilities (logits) of shape (N, 1, H, W) + :param y_true: Ground truth labels of shape (N, 1, H, W) + :return: Computed MCC loss + """ + + batch_size = y_true.shape[0] + + y_true = y_true.view(batch_size, 1, -1) + y_pred = y_pred.view(batch_size, 1, -1) + + tp = torch.sum(torch.mul(y_pred, y_true)) + self.eps + tn = torch.sum(torch.mul((1 - y_pred), (1 - y_true))) + self.eps + fp = torch.sum(torch.mul(y_pred, (1 - y_true))) + self.eps + fn = torch.sum(torch.mul((1 - y_pred), y_true)) + self.eps + + numerator = torch.mul(tp, tn) - torch.mul(fp, fn) + denominator = torch.sqrt( + torch.add(tp, fp) + * torch.add(tp, fn) + * torch.add(tn, fp) + * torch.add(tn, fn) + ) + + mcc = torch.div(numerator.sum(), denominator.sum()) + loss = 1 - mcc + + return loss \ No newline at end of file From 0cc5c4d0370aca49383e5c5f5fdfcc45e8180b72 Mon Sep 17 00:00:00 2001 From: Kumar Abhishek Date: Sat, 17 May 2025 22:37:30 -0700 Subject: [PATCH 2/2] Update `mcc.py` and `test_losses.py` Add option to calculate loss from either logits or predictions. Add option to perform sample-wise or batch-wise reduction. Add tests for `MCCLoss`. 
---
 pytorch_toolbelt/losses/mcc.py | 110 ++++++++++++++++++++++++++-------
 tests/test_losses.py           |  58 +++++++++++++++++
 2 files changed, 146 insertions(+), 22 deletions(-)

diff --git a/pytorch_toolbelt/losses/mcc.py b/pytorch_toolbelt/losses/mcc.py
index 0f0321f97..e5b33d475 100644
--- a/pytorch_toolbelt/losses/mcc.py
+++ b/pytorch_toolbelt/losses/mcc.py
@@ -3,57 +3,123 @@
 import torch
 from torch import Tensor
 from torch.nn.modules.loss import _Loss
+import torch.nn.functional as F

 __all__ = ["MCCLoss"]


 class MCCLoss(_Loss):
     """
-    Implementation of Matthews Correlation Coefficient (MCC) loss for image segmentation task.
-    It supports binary cases.
+    Implementation of Matthews Correlation Coefficient (MCC) loss for image
+    segmentation tasks. It supports binary cases.
     Reference: https://github.com/kakumarabhishek/MCC-Loss
     Paper: https://doi.org/10.1109/ISBI48211.2021.9433782
     """

-    def __init__(self, eps: Optional[float] = 1e-7):
+    def __init__(
+        self,
+        from_logits: bool = False,
+        reduction: str = "batch",
+        eps: Optional[float] = 1e-7,
+    ):
         """
         Initializes the MCCLoss class.

-        :param eps: Small epsilon for numerical stability
+        :param from_logits: Flag to convert logits to probabilities.
+                            Default: False.
+                            If True, y_pred is assumed to be logits and is
+                            converted to probabilities via logsigmoid().exp().
+        :param reduction: Specifies the reduction to apply to the output:
+                          - 'sample': compute the loss per sample, then
+                            average over the batch
+                          - 'batch': compute the loss over the whole batch
+                          Default: 'batch'.
+        :param eps: Small epsilon for numerical stability.
         """
         super().__init__()
+        assert reduction in {"sample", "batch"}, (
+            "reduction must be 'sample' or 'batch'"
+        )
         self.eps = eps
+        self.from_logits = from_logits
+        self.reduction = reduction

     def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
         """
         Computes the Matthews Correlation Coefficient (MCC) loss.
-        MCC = (TP.TN - FP.FN) / sqrt((TP+FP) . (TP+FN) . (TN+FP) . (TN+FN))
+
+        MCC = (TP*TN - FP*FN) / sqrt((TP+FP) * (TP+FN) * (TN+FP) * (TN+FN)),
         where TP, TN, FP, and FN are elements in the confusion matrix.

-        :param y_pred: Predicted probabilities (logits) of shape (N, 1, H, W)
-        :param y_true: Ground truth labels of shape (N, 1, H, W)
+        :param y_pred: Predicted logits or probabilities.
+                       Shape: (N, 1, H, W) or (N, H, W).
+                       If `from_logits=True`, logits are converted to
+                       probabilities.
+        :param y_true: Ground truth labels.
+                       Shape: (N, 1, H, W) or (N, H, W).
+                       Values should be 0 or 1.
         :return: Computed MCC loss
         """

+        # Input validation.
+        assert y_pred.shape == y_true.shape, (
+            "y_pred and y_true must have the same shape, "
+            f"but got {y_pred.shape} and {y_true.shape}"
+        )
+
+        if not self.from_logits:
+            assert torch.all(y_pred >= 0) and torch.all(y_pred <= 1), (
+                "y_pred must be in [0, 1] range when from_logits=False"
+            )
+
+        # Ensure inputs are 4D: (N, 1, H, W)
+        if y_pred.ndim == 3:
+            y_pred = y_pred.unsqueeze(1)
+        if y_true.ndim == 3:
+            y_true = y_true.unsqueeze(1)
+        y_true = y_true.float()
+        y_pred = y_pred.float()
+
+        # Obtain the batch size.
        batch_size = y_true.shape[0]

-        y_true = y_true.view(batch_size, 1, -1)
-        y_pred = y_pred.view(batch_size, 1, -1)
+        # Flatten spatial dimensions.
+        y_true = y_true.view(batch_size, -1)
+        y_pred = y_pred.view(batch_size, -1)

-        tp = torch.sum(torch.mul(y_pred, y_true)) + self.eps
-        tn = torch.sum(torch.mul((1 - y_pred), (1 - y_true))) + self.eps
-        fp = torch.sum(torch.mul(y_pred, (1 - y_true))) + self.eps
-        fn = torch.sum(torch.mul((1 - y_pred), y_true)) + self.eps
+        # Convert logits to probabilities if needed.
+        if self.from_logits:
+            # Use logsigmoid to avoid numerical instability.
+            y_pred = F.logsigmoid(y_pred).exp()

-        numerator = torch.mul(tp, tn) - torch.mul(fp, fn)
-        denominator = torch.sqrt(
-            torch.add(tp, fp)
-            * torch.add(tp, fn)
-            * torch.add(tn, fp)
-            * torch.add(tn, fn)
-        )
+        # Choose the summation dimension for the confusion matrix terms.
+        if self.reduction == "sample":
+            # Sum over the sample dimension for sample-wise reduction.
+            dim = 1
+        else:
+            # Flatten all dimensions and sum once for batch-wise reduction.
+            y_true = y_true.view(-1)
+            y_pred = y_pred.view(-1)
+            dim = 0
+
+        tp = (y_pred * y_true).sum(dim=dim)
+        tn = ((1 - y_pred) * (1 - y_true)).sum(dim=dim)
+        fp = (y_pred * (1 - y_true)).sum(dim=dim)
+        fn = ((1 - y_pred) * y_true).sum(dim=dim)
+
+        # Special case: perfect predictions.
+        # FP and FN are all zero, so the MCC is 1 and the loss is 0.
+        if (fp == 0).all() and (fn == 0).all():
+            return tp.new_tensor(0.0)
+
+        # Compute the numerator and denominator of the MCC expression.
+        numerator = tp * tn - fp * fn
+        denominator = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
+        # Add epsilon to avoid division by zero.
+        denominator = torch.sqrt(denominator) + self.eps

-        mcc = torch.div(numerator.sum(), denominator.sum())
+        # Compute the MCC loss.
+        mcc = numerator / denominator
        loss = 1 - mcc

-        return loss
\ No newline at end of file
+        return loss.mean()
diff --git a/tests/test_losses.py b/tests/test_losses.py
index be20b3370..bc0279ef3 100644
--- a/tests/test_losses.py
+++ b/tests/test_losses.py
@@ -269,3 +269,61 @@ def test_bbce():
     y = torch.tensor([0, 1, 1, 1, 1]).float()
     loss = L.balanced_binary_cross_entropy_with_logits(x, y, gamma=1, reduction="none")
     print(loss)
+
+
+@torch.no_grad()
+def test_mcc_loss():
+    eps = 1e-5
+
+    # Ideal case - perfect predictions, all ones. Sample-wise reduction.
+    criterion = L.MCCLoss(from_logits=False, reduction="sample")
+    y_pred = torch.tensor([1.0, 1.0, 1.0]).view(1, 1, 1, -1)
+    y_true = torch.tensor([1, 1, 1]).view(1, 1, 1, -1)
+    loss = criterion(y_pred, y_true)
+    assert float(loss) == pytest.approx(0.0, abs=eps)
+
+    # Ideal case - perfect predictions, all zeros. Batch-wise reduction.
+    criterion = L.MCCLoss(from_logits=False, reduction="batch")
+    y_pred = torch.tensor([0.0, 0.0, 0.0]).view(1, 1, 1, -1)
+    y_true = torch.tensor([0, 0, 0]).view(1, 1, 1, -1)
+    loss = criterion(y_pred, y_true)
+    assert float(loss) == pytest.approx(0.0, abs=eps)
+
+    # Ideal case - perfect predictions with mixed values. Sample-wise reduction.
+    criterion = L.MCCLoss(from_logits=False, reduction="sample")
+    y_pred = torch.tensor([[1.0, 1.0], [0.0, 0.0]]).view(2, 1, 1, -1)
+    y_true = torch.tensor([[1, 1], [0, 0]]).view(2, 1, 1, -1)
+    loss = criterion(y_pred, y_true)
+    assert float(loss) == pytest.approx(0.0, abs=eps)
+
+    # Ideal case - perfect predictions with mixed values. Batch-wise reduction.
+    criterion = L.MCCLoss(from_logits=False, reduction="batch")
+    y_pred = torch.tensor([[1.0, 1.0], [0.0, 0.0]]).view(2, 1, 1, -1)
+    y_true = torch.tensor([[1, 1], [0, 0]]).view(2, 1, 1, -1)
+    loss = criterion(y_pred, y_true)
+    assert float(loss) == pytest.approx(0.0, abs=eps)
+
+    # Ideal case - perfect predictions with large-magnitude logits.
+    criterion_logits = L.MCCLoss(from_logits=True)
+    y_pred = torch.tensor([20.0, -20.0, 20.0]).view(1, 1, 1, -1)
+    y_true = torch.tensor([1, 0, 1]).view(1, 1, 1, -1)
+    loss = criterion_logits(y_pred, y_true)
+    assert float(loss) == pytest.approx(0.0, abs=eps)
+
+    # Random case - mixed predictions. Sample-wise reduction.
+    criterion = L.MCCLoss(from_logits=False, reduction="sample")
+    shape = (4, 1, 5, 5)
+    y_pred = torch.bernoulli(torch.rand(shape))
+    y_true = torch.bernoulli(torch.rand(shape))
+    loss = criterion(y_pred, y_true)
+    # MCC lies in [-1, 1], so the loss 1 - MCC lies in [0, 2].
+    assert 0.0 <= float(loss) <= 2.0
+
+    # Random case - mixed predictions. Batch-wise reduction.
+    criterion = L.MCCLoss(from_logits=False, reduction="batch")
+    shape = (4, 1, 5, 5)
+    y_pred = torch.bernoulli(torch.rand(shape))
+    y_true = torch.bernoulli(torch.rand(shape))
+    loss = criterion(y_pred, y_true)
+    # MCC lies in [-1, 1], so the loss 1 - MCC lies in [0, 2].
+    assert 0.0 <= float(loss) <= 2.0
\ No newline at end of file
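
For reviewers, a minimal usage sketch of the new loss (illustrative only, not part of the patch; the tensor shapes, values, and variable names below are made up):

    import torch
    from pytorch_toolbelt.losses import MCCLoss

    # Hypothetical batch: 8 single-channel 256x256 binary masks.
    logits = torch.randn(8, 1, 256, 256, requires_grad=True)
    targets = torch.randint(0, 2, (8, 1, 256, 256)).float()

    # from_logits=True applies the numerically stable sigmoid internally;
    # reduction="sample" computes one MCC per image and averages the losses.
    criterion = MCCLoss(from_logits=True, reduction="sample")
    loss = criterion(logits, targets)
    loss.backward()
    print(float(loss))  # Lies in [0, 2] since MCC is in [-1, 1].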