
Multi Label Precision numerical stability issues with MPS #2955

Open

dmnkf opened this issue Feb 15, 2025 · 3 comments
Labels
bug / fix (Something isn't working), duplicate (This issue or pull request already exists), help wanted (Extra attention is needed), v1.6.x

Comments


dmnkf commented Feb 15, 2025

🐛 Bug

I'm tracking metrics for a fairly complex ECG multi-label classifier (20 labels) and noticed that training on MPS breaks MultilabelPrecision in an odd way: the returned values fall far outside the valid [0, 1] range, e.g. -1.49272e+25. At first I suspected a numerical issue in the logits or an input-shape problem I wasn't handling properly, but after some further digging I found that the inconsistency happens specifically with the MPS device and is reproducible. I couldn't get to the bottom of what causes it, since the failures are intermittent, but it looks numerical.

Would love some guidance on this!
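
To isolate the device dependence, here is a minimal sketch (the compare_devices helper is mine, purely for illustration) that runs one identical update on CPU and on MPS and prints both results; the MPS value is the one that occasionally explodes on my machine:

import torch
import torchmetrics

def compare_devices(preds, target, num_labels=20):
    # Run the identical update/compute on both devices and collect the results.
    results = {}
    for device in ("cpu", "mps"):
        metric = torchmetrics.Precision(
            task="multilabel", num_labels=num_labels, average="macro", threshold=0.5
        ).to(device)
        metric.update(preds.to(device), target.to(device).long())
        results[device] = metric.compute().item()
    return results

preds = torch.cat([torch.full((100, 5), 0.9), torch.full((100, 15), 0.1)], dim=1)
target = torch.cat([torch.ones(100, 5), torch.zeros(100, 15)], dim=1)
print(compare_devices(preds, target))  # 'cpu' gives 0.25; 'mps' sometimes does not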

To Reproduce

The code snippet below includes manually calculated expected outputs. I hope my math isn't wrong, but either way, if I set the device to MPS and run the code a few times (some runs may pass), I eventually run into some very odd numerical issues on tests 3 and 4:

Code sample
import torch
import torchmetrics
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

def test_label_scenarios(device='cpu'):
    num_labels = 20

    # Test scenarios specifically for 20-label case
    test_cases = [
        {
            "name": "All labels correct",
            "preds": torch.full((100, 20), 0.7),
            "target": torch.ones((100, 20)),
            "expected": 1.0,
        },
        {
            "name": "All labels wrong",
            "preds": torch.full((100, 20), 0.7),
            "target": torch.zeros((100, 20)),
            "expected": 0.0,
        },
        {
            "name": "Exactly 5 labels correct per sample",
            "preds": torch.cat(
                [
                    torch.full((100, 5), 0.9),  # 5 correct predictions
                    torch.full((100, 15), 0.1),  # 15 never predicted positive
                ],
                dim=1,
            ),
            "target": torch.cat(
                [
                    torch.ones((100, 5)),
                    torch.zeros((100, 15))
                ],
                dim=1
            ),
            # For macro, 5 labels have precision=1, 15 labels have precision=0 => avg=0.25
            "expected": 0.25,
        },
        {
            "name": "5 correct + 5 false positives",
            "preds": torch.cat(
                [
                    torch.full((100, 5), 0.9),  # 5 always correct
                    torch.full((100, 5), 0.9),  # 5 always false positives
                    torch.full((100, 10), 0.1),
                ],
                dim=1,
            ),
            "target": torch.cat(
                [
                    torch.ones((100, 5)),   # true positives for the first 5 labels
                    torch.zeros((100, 15))  # labels 5-9 become false positives, 10-19 stay negative
                ],
                dim=1
            ),
            # For macro, first 5 => precision=1, next 5 => precision=0, last 10 => 0 => avg=0.25
            "expected": 0.25,
        },
        {
            "name": "All False Predictions",
            "preds": torch.ones((1000, 20)),  # All predictions=1
            "target": torch.zeros((1000, 20)),  # All labels=0
            "expected": 0.0,
            "checks": ["no_nan", "exact_zero"],
        },
        {
            "name": "Numerical instability test",
            "preds": torch.randn(5000, 20).sigmoid(),
            "target": torch.randint(0, 2, (5000, 20)),
            "expected": None,
        },
    ]

    for case in test_cases:
        print(f"\n=== {case['name']} ===")

        preds = case["preds"].to(device)
        target = case["target"].to(device).long()

        precision = torchmetrics.Precision(
            task="multilabel",
            num_labels=num_labels,
            average="macro",
            threshold=0.5
        ).to(device)

        try:
            # Simulate batch updates
            for _ in range(10):
                precision.update(preds, target)

            result = precision.compute()
            logger.debug(f"Computed precision: {result}")

            if case["expected"] is not None:
                assert torch.isclose(
                    result, torch.tensor(case["expected"], device=device),
                    atol=1e-7
                ), f"Expected {case['expected']}, got {result:.4f}"

            # Check for numerical stability
            assert not torch.isnan(result).any(), "NaN detected in precision"
            assert result.abs().max() < 1e6, f"Precision value exploded: {result}"

            print(f"✓ PASSED - Final Precision: {result:.4f}")
        except Exception as e:
            print(f"✗ FAILED: {str(e)}")
        finally:
            precision.reset()


if __name__ == "__main__":
    test_label_scenarios('mps')
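
For reference, the expected 0.25 in tests 3 and 4 is just the macro average of the per-label precision TP / (TP + FP), counting labels with no positive predictions as 0 (which I assume matches the torchmetrics default). A quick manual check on CPU, independent of the metric class:

import torch

def manual_macro_precision(preds, target, threshold=0.5):
    # Per-label precision TP / (TP + FP), with 0/0 counted as 0, averaged over labels.
    pred_pos = preds >= threshold
    target = target.bool()
    tp = (pred_pos & target).sum(dim=0).float()
    fp = (pred_pos & ~target).sum(dim=0).float()
    per_label = torch.where(tp + fp > 0, tp / (tp + fp), torch.zeros_like(tp))
    return per_label.mean()

# Test 4 layout: labels 0-4 are true positives, 5-9 false positives, 10-19 never predicted.
preds = torch.cat([torch.full((100, 10), 0.9), torch.full((100, 10), 0.1)], dim=1)
target = torch.cat([torch.ones(100, 5), torch.zeros(100, 15)], dim=1)
print(manual_macro_precision(preds, target))  # tensor(0.2500) = (5 * 1.0 + 15 * 0.0) / 20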

Expected behavior

Tests 3 and 4 should return this on MPS:

=== Exactly 5 labels correct per sample ===
DEBUG:__main__:Computed precision: 0.25
✓ PASSED - Final Precision: 0.2500

=== 5 correct + 5 false positives ===
DEBUG:__main__:Computed precision: 0.25
✓ PASSED - Final Precision: 0.2500

instead of something like:

=== Exactly 5 labels correct per sample ===
DEBUG:__main__:Computed precision: 16985032228864.0
✗ FAILED: Expected 0.25, got 16985032228864.0000

=== 5 correct + 5 false positives ===
DEBUG:__main__:Computed precision: -1.4927448875054142e+25
✗ FAILED: Expected 0.25, got -14927448875054142008066048.0000
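
A possible workaround (just a sketch, I have not verified it beyond this snippet) is to keep the metric on CPU and move the batch outputs over before each update, so compute() never touches MPS kernels:

import torch
import torchmetrics

metric = torchmetrics.Precision(
    task="multilabel", num_labels=20, average="macro", threshold=0.5
)  # deliberately left on CPU

# Pretend these are model outputs living on MPS:
preds = torch.full((100, 20), 0.7, device="mps")
target = torch.ones((100, 20), device="mps")

metric.update(preds.detach().cpu(), target.detach().cpu().long())
print(metric.compute())  # stays in [0, 1]; should be 1.0 for this input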

Environment

  • TorchMetrics version (if built from source, add commit SHA): 1.6.1
  • Python & PyTorch Version (e.g., 1.0): Python 3.12.5
  • Any other relevant information such as OS (e.g., Linux): macOS Sonoma 14.7.4, M3 Max (but should be reproducible on all Apple Silicon)
dmnkf added the bug / fix and help wanted labels on Feb 15, 2025

Hi! Thanks for your contribution! Great first issue!

Borda added the v1.6.x label on Feb 28, 2025
@SkafteNicki
Member

Most likely related to #1727

Borda added the duplicate label on Mar 21, 2025
@Borda
Member

Borda commented Mar 21, 2025

maybe worth raising some concerns with the Torch team?
cc: @lantiga
