1 change: 1 addition & 0 deletions pytorch_toolbelt/losses/__init__.py
@@ -14,3 +14,4 @@
from .wing_loss import *
from .logcosh import *
from .quality_focal_loss import *
from .mcc import *
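
With this re-export in place (and `__all__ = ["MCCLoss"]` in mcc.py below), the new loss is importable directly from the package:

    from pytorch_toolbelt.losses import MCCLoss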
125 changes: 125 additions & 0 deletions pytorch_toolbelt/losses/mcc.py
@@ -0,0 +1,125 @@
import torch
from torch import Tensor
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F

__all__ = ["MCCLoss"]


class MCCLoss(_Loss):
"""
Implementation of Matthews Correlation Coefficient (MCC) loss for image
segmentation task. It supports binary cases.
Reference: https://github.com/kakumarabhishek/MCC-Loss
Paper: https://doi.org/10.1109/ISBI48211.2021.9433782
"""

def __init__(
self,
from_logits: bool = False,
reduction: str = "batch",
        eps: float = 1e-7,
):
"""
Initializes the MCCLoss class.

        :param from_logits: If True, `y_pred` is assumed to contain logits and
            is converted to probabilities with a numerically stable sigmoid
            (`F.logsigmoid(y_pred).exp()`). Default: False.
        :param reduction: Specifies how the confusion-matrix terms are reduced:
            - 'sample': compute the loss per sample, then average over the batch
            - 'batch': compute a single loss over the whole batch
            Default: 'batch'.
        :param eps: Small epsilon added to the denominator for numerical
            stability.
"""
super().__init__()
assert reduction in {"sample", "batch"}, (
"reduction must be 'sample' or 'batch'"
)
self.eps = eps
self.from_logits = from_logits
self.reduction = reduction

def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
"""
Computes the Matthews Correlation Coefficient (MCC) loss.

        MCC = (TP * TN - FP * FN) / sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)),
where TP, TN, FP, and FN are elements in the confusion matrix.

:param y_pred: Predicted logits or probabilities.
Shape: (N, 1, H, W) or (N, H, W).
If `from_logits=True`, logits are converted to
probabilities.
:param y_true: Ground truth labels.
Shape: (N, 1, H, W) or (N, H, W).
Values should be 0 or 1.
        :return: Scalar MCC loss, computed as `1 - MCC`.
"""
# Input validation.
assert y_pred.shape == y_true.shape, (
f"y_pred and y_true must have the same shape, "
f"but got {y_pred.shape} and {y_true.shape}"
)

if not self.from_logits:
assert torch.all(y_pred >= 0) and torch.all(y_pred <= 1), (
"y_pred must be in [0, 1] range when from_logits=False"
)

# Ensure inputs are 4D: (N, 1, H, W)
if y_pred.ndim == 3:
y_pred = y_pred.unsqueeze(1)
if y_true.ndim == 3:
y_true = y_true.unsqueeze(1)

y_true = y_true.float()
y_pred = y_pred.float()

# Obtain the batch size.
batch_size = y_true.shape[0]

# Flatten spatial dimensions
y_true = y_true.view(batch_size, -1)
y_pred = y_pred.view(batch_size, -1)

# Convert logits to probabilities if needed.
if self.from_logits:
# Use logsigmoid to avoid numerical instability.
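            # Note: logsigmoid(y_pred).exp() is mathematically identical to sigmoid(y_pred).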
y_pred = F.logsigmoid(y_pred).exp()

# Compute the terms of the confusion matrix.
if self.reduction == "sample":
# Sum over the sample dimension for sample-wise reduction.
dim = 1
else:
# Flatten all dimensions and sum once for batch-wise reduction.
y_true = y_true.view(-1)
y_pred = y_pred.view(-1)
dim = 0

tp = (y_pred * y_true).sum(dim=dim)
tn = ((1 - y_pred) * (1 - y_true)).sum(dim=dim)
fp = (y_pred * (1 - y_true)).sum(dim=dim)
fn = ((1 - y_pred) * y_true).sum(dim=dim)
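
        # Note: with probabilistic y_pred these are "soft" counts; they reduce
        # to the usual integer confusion-matrix entries when y_pred is binary.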

        # Special case: perfect predictions, i.e. FP and FN are both zero.
        # MCC is 1 by definition, so return a loss of 0. This also avoids a
        # zero denominator when the targets are all ones or all zeros.
if (fp == 0).all() and (fn == 0).all():
return tp.new_tensor(0.0)

# Compute the numerator and denominator of the MCC expression.
numerator = tp * tn - fp * fn
denominator = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
# Add epsilon to avoid division by zero.
denominator = torch.sqrt(denominator) + self.eps

# Compute the MCC loss.
mcc = numerator / denominator
loss = 1 - mcc

return loss.mean()
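
A minimal usage sketch for the class above, assuming a model that produces a single-channel logit map (the tensors below are stand-ins for real model outputs and masks):

    import torch
    from pytorch_toolbelt.losses import MCCLoss

    criterion = MCCLoss(from_logits=True, reduction="batch")
    logits = torch.randn(4, 1, 64, 64, requires_grad=True)  # e.g. model(images)
    masks = torch.randint(0, 2, (4, 1, 64, 64)).float()     # binary ground truth
    loss = criterion(logits, masks)                         # scalar, 1 - MCC
    loss.backward()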
58 changes: 58 additions & 0 deletions tests/test_losses.py
@@ -269,3 +269,61 @@ def test_bbce():
y = torch.tensor([0, 1, 1, 1, 1]).float()
loss = L.balanced_binary_cross_entropy_with_logits(x, y, gamma=1, reduction="none")
print(loss)


@torch.no_grad()
def test_mcc_loss():
eps = 1e-5

# Ideal case - perfect predictions, all ones. Sample-wise reduction.
criterion = L.MCCLoss(from_logits=False, reduction="sample")
y_pred = torch.tensor([1.0, 1.0, 1.0]).view(1, 1, 1, -1)
y_true = torch.tensor([1, 1, 1]).view(1, 1, 1, -1)
loss = criterion(y_pred, y_true)
assert float(loss) == pytest.approx(0.0, abs=eps)

# Ideal case - perfect predictions, all zeros. Batch-wise reduction.
criterion = L.MCCLoss(from_logits=False, reduction="batch")
y_pred = torch.tensor([0.0, 0.0, 0.0]).view(1, 1, 1, -1)
y_true = torch.tensor([0, 0, 0]).view(1, 1, 1, -1)
loss = criterion(y_pred, y_true)
assert float(loss) == pytest.approx(0.0, abs=eps)

# Ideal case - perfect predictions with mixed values. Sample-wise reduction.
criterion = L.MCCLoss(from_logits=False, reduction="sample")
y_pred = torch.tensor([[1.0, 1.0], [0.0, 0.0]]).view(2, 1, 1, -1)
y_true = torch.tensor([[1, 1], [0, 0]]).view(2, 1, 1, -1)
loss = criterion(y_pred, y_true)
assert float(loss) == pytest.approx(0.0, abs=eps)

# Ideal case - perfect predictions with mixed values. Batch-wise reduction.
criterion = L.MCCLoss(from_logits=False, reduction="batch")
y_pred = torch.tensor([[1.0, 1.0], [0.0, 0.0]]).view(2, 1, 1, -1)
y_true = torch.tensor([[1, 1], [0, 0]]).view(2, 1, 1, -1)
loss = criterion(y_pred, y_true)
assert float(loss) == pytest.approx(0.0, abs=eps)

# Ideal case - perfect predictions with logits.
criterion_logits = L.MCCLoss(from_logits=True)
y_pred = torch.tensor([10.0, -10.0, 10.0]).view(1, 1, 1, -1)
y_true = torch.tensor([1, 0, 1]).view(1, 1, 1, -1)
loss = criterion_logits(y_pred, y_true)
assert float(loss) == pytest.approx(0.0, abs=eps)

# Random case - mixed predictions. Sample-wise reduction.
criterion = L.MCCLoss(from_logits=False, reduction="sample")
shape = (4, 3, 5, 5)
y_pred = torch.bernoulli(torch.rand(shape))
y_true = torch.bernoulli(torch.rand(shape))
loss = criterion(y_pred, y_true)
    # MCC lies in [-1, 1], so the loss (1 - MCC) must lie in [0, 2].
assert 0.0 <= float(loss) <= 2.0

# Random case - mixed predictions. Batch-wise reduction.
criterion = L.MCCLoss(from_logits=False, reduction="batch")
shape = (4, 3, 5, 5)
y_pred = torch.bernoulli(torch.rand(shape))
y_true = torch.bernoulli(torch.rand(shape))
loss = criterion(y_pred, y_true)
    # MCC lies in [-1, 1], so the loss (1 - MCC) must lie in [0, 2].
assert 0.0 <= float(loss) <= 2.0
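
As a hand-worked sanity check of the formula (a sketch, not part of the PR): with y_pred = [1, 0, 0, 1] and y_true = [1, 1, 0, 1], the counts are TP=2, TN=1, FP=0, FN=1, so MCC = (2*1 - 0*1) / sqrt(2 * 3 * 1 * 2) = 2/sqrt(12) ≈ 0.577 and the loss is ≈ 0.423:

    import torch
    import pytorch_toolbelt.losses as L

    criterion = L.MCCLoss(from_logits=False, reduction="batch")
    y_pred = torch.tensor([1.0, 0.0, 0.0, 1.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([1.0, 1.0, 0.0, 1.0]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert abs(float(loss) - (1 - 2 / 12 ** 0.5)) < 1e-4  # loss ≈ 0.4226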