Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Major Development #15

Merged
merged 25 commits into from
Jan 2, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
implement Beyond AUROC \& Co.
edadaltocg committed Jun 7, 2023
commit 99a03323b5001e4f4e8ac48db693aad0b5def96a
49 changes: 49 additions & 0 deletions src/detectors/eval.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,27 @@ def fpr_at_fixed_tpr(fprs: np.ndarray, tprs: np.ndarray, thresholds: np.ndarray,
return float(fprs[idx]), float(tprs[idx]), float(thresholds[idx])


def fnr_at_fixed_tnr(fprs: np.ndarray, tprs: np.ndarray, thresholds: np.ndarray, tnr_level: float = 0.95):
"""Return the FNR at a fixed TNR level.

Args:
fprs (np.ndarray): False positive rates.
tprs (np.ndarray): True positive rates.
thresholds (np.ndarray): Thresholds.
tnr_level (float, optional): TNR level. Defaults to 0.95.

Returns:
Tuple[float, float, float]: FNR, TNR, threshold."""
tnrs = 1 - fprs
fnrs = 1 - tprs

if all(tnrs < tnr_level):
raise ValueError(f"No threshold allows for TNR at least {tnr_level}.")
idxs = [i for i, x in enumerate(tnrs) if x >= tnr_level]
idx = min(idxs)
return float(fnrs[idx]), float(tnrs[idx]), float(thresholds[idx])


def compute_detection_error(fpr: float, tpr: float, pos_ratio: float):
"""Compute the detection error.

@@ -65,6 +86,29 @@ def minimum_detection_error(fprs: np.ndarray, tprs: np.ndarray, pos_ratio: float
return detection_errors[idx]


def aufnr_aufpr_autc(fprs: np.ndarray, tprs: np.ndarray, thresholds: np.ndarray):
"""Compute the AUFNR, AUFPR, and AUTC metrics.

Args:
fprs (np.ndarray): False positive rates.
tprs (np.ndarray): True positive rates.
thresholds (np.ndarray): Thresholds.


Returns:
Tuple[float, float, float]: AUFNR, AUFPR, AUTC.

References:
[1] Humblot-Renaux et. al. Beyond AUROC \& Co. for Evaluating Out-of-Distribution Detection Performance. 2023.
"""
fnrs = 1 - tprs
sorted_idx = np.argsort(thresholds)
aufnr = sklearn.metrics.auc(thresholds[sorted_idx], fnrs[sorted_idx])
aufpr = sklearn.metrics.auc(thresholds[sorted_idx], fprs[sorted_idx])
autc = (aufnr + aufpr) / 2
return float(aufnr), float(aufpr), float(autc)


def get_ood_results(in_scores: Tensor, ood_scores: Tensor) -> Dict[str, float]:
"""Compute OOD detection metrics.

@@ -97,13 +141,18 @@ def get_ood_results(in_scores: Tensor, ood_scores: Tensor) -> Dict[str, float]:
pos_ratio = np.mean(_test_labels == 1)
detection_error = minimum_detection_error(fprs, tprs, pos_ratio)

aufnr, aufpr, autc = aufnr_aufpr_autc(fprs, tprs, thrs)

results = {
"fpr_at_0.95_tpr": fpr,
"tnr_at_0.95_tpr": 1 - fpr,
"detection_error": detection_error,
"auroc": auroc,
"aupr_in": aupr_in,
"aupr_out": aupr_out,
"aufnr": aufnr,
"aufpr": aufpr,
"autc": autc,
"thr": thr,
}
return results