Report nan in PR curve above max threshold #2873

Open · ytang137 opened this issue Dec 17, 2024 · 1 comment
Labels: enhancement (New feature or request)

ytang137 commented Dec 17, 2024

🚀 Feature

Currently, the PR curve metrics report (0, 0) at thresholds greater than the maximum predicted confidence. For example:

import torch
from torchmetrics.classification import BinaryPrecisionRecallCurve

preds = torch.tensor([0.3, 0.4, 0.5, 0.6, 0.7, 0.8])  # max confidence is 0.8
targets = torch.tensor([0, 1, 1, 0, 1, 0])
pr_curve = BinaryPrecisionRecallCurve(thresholds=10)
pr_curve(preds, targets)

This code returns the following result:

(tensor([0.5000, 0.5000, 0.5000, 0.6000, 0.5000, 0.3333, 0.5000, 0.0000, 0.0000, 0.0000, 1.0000]),
  tensor([1.0000, 1.0000, 1.0000, 1.0000, 0.6667, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000]),
  tensor([0.0000, 0.1111, 0.2222, 0.3333, 0.4444, 0.5556, 0.6667, 0.7778, 0.8889, 1.0000]))

Since there are no predicted confidences >= 0.8889, I think it makes more sense to return NaNs for precision and recall at those thresholds. This would make it possible to filter out (0, 0) points artificially introduced by the fixed threshold grid, as sketched below.
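
For illustration, here is a minimal filtering sketch assuming the proposed NaN behavior. The tensors hard-code the example output above, with the two grid points whose thresholds (0.8889 and 1.0) exceed the maximum predicted confidence of 0.8 replaced by NaN:

import torch

# Hypothetical output under the proposal: thresholds above max(preds)
# report NaN instead of (0, 0).
precision = torch.tensor([0.5, 0.5, 0.5, 0.6, 0.5, 0.3333, 0.5, 0.0,
                          float("nan"), float("nan"), 1.0])
recall = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.6667, 0.3333, 0.3333, 0.0,
                       float("nan"), float("nan"), 0.0])

# NaN points can be dropped unambiguously; genuine (0, 0) points remain.
mask = ~(torch.isnan(precision) | torch.isnan(recall))
precision, recall = precision[mask], recall[mask]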

Motivation

Make it possible to filter out (0, 0) points artificially introduced by the fixed threshold grid in PR curve metrics.

Pitch

Return NaNs for precision and recall at thresholds greater than the maximum predicted confidence.
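
A minimal sketch of the proposed semantics (my reading of the pitch, not the actual torchmetrics implementation; among other things it omits the final (1, 0) point that torchmetrics appends):

import torch

def binary_pr_curve_with_nans(preds, targets, num_thresholds):
    # Fixed-grid PR curve that reports NaN at thresholds greater than
    # the maximum predicted confidence, per the proposal above.
    thresholds = torch.linspace(0, 1, num_thresholds)
    precision = torch.full((num_thresholds,), float("nan"))
    recall = torch.full((num_thresholds,), float("nan"))
    for i, t in enumerate(thresholds):
        if t > preds.max():
            continue  # no predictions reach this threshold -> leave NaN
        pred_pos = preds >= t
        tp = (pred_pos & (targets == 1)).sum()
        fp = (pred_pos & (targets == 0)).sum()
        fn = (~pred_pos & (targets == 1)).sum()
        precision[i] = tp / (tp + fp)  # tp + fp >= 1 after the guard above
        recall[i] = tp / (tp + fn) if tp + fn > 0 else float("nan")
    return precision, recall, thresholds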

Alternatives

Additional context

ytang137 added the enhancement label on Dec 17, 2024
VDFaller commented Feb 27, 2025

If this happened, would it also prevent thresholds > 1 for PrecisionAtFixedRecall?

import torch
from torchmetrics.classification import BinaryPrecisionAtFixedRecall
from pytorch_lightning import seed_everything

seed_everything(42)
target = torch.randint(0, 1, (1000,))  # high is exclusive, so this is all zeros
preds = torch.rand(1000)
metric = BinaryPrecisionAtFixedRecall(min_recall=0.1)
precision, threshold = metric(preds, target)
print(threshold)  # tensor(1000000.)

So far I'm unable to reproduce it in a test case without an all-zeros target, but in reality it's happening on my dataset, which has multiple targets.
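
Assuming the tensor(1000000.) above is a sentinel the metric returns when no threshold attains the requested recall (a guess from the printed output, not confirmed behavior), one workaround sketch is to treat any returned threshold above 1 as invalid:

import torch
from torchmetrics.classification import BinaryPrecisionAtFixedRecall

metric = BinaryPrecisionAtFixedRecall(min_recall=0.1)
precision, threshold = metric(torch.rand(1000), torch.randint(0, 2, (1000,)))

# Valid confidence thresholds lie in [0, 1]; anything larger means no
# threshold achieved the requested recall.
if threshold > 1.0:
    threshold = torch.tensor(float("nan"))  # flag as unattainable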
