Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/edadaltocg/detectors into…
Browse files Browse the repository at this point in the history
… dev
  • Loading branch information
edadaltocg committed Jun 19, 2024
2 parents eb9a811 + 64c9eee commit dac1ec3
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 1 deletion.
8 changes: 8 additions & 0 deletions docs/source/detectors.methods.rst
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,14 @@ detectors.methods.openmax module
:undoc-members:
:show-inheritance:

detectors.methods.plugin\_bb module
-----------------------------------

.. automodule:: detectors.methods.plugin_bb
:members:
:undoc-members:
:show-inheritance:

detectors.methods.pnml module
-----------------------------

Expand Down
8 changes: 8 additions & 0 deletions docs/source/detectors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ detectors.criterions module
:undoc-members:
:show-inheritance:

detectors.ensemble module
-------------------------

.. automodule:: detectors.ensemble
:members:
:undoc-members:
:show-inheritance:

detectors.eval module
---------------------

Expand Down
2 changes: 1 addition & 1 deletion examples/ood_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def main(args):
"igeood_logits",
"max_logits",
"react",
"dice ",
"dice",
"vim",
"gradnorm",
"mahalanobis",
Expand Down
107 changes: 107 additions & 0 deletions examples/sc_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,110 @@ def main(args):
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
_logger.info(json.dumps(args.__dict__, indent=2))
main(args)
import argparse
import json
import logging
import os

import timm
import timm.data
import torch

import detectors
from detectors.config import RESULTS_DIR

_logger = logging.getLogger(__name__)


def main(args):
print(f"Running {args.pipeline} pipeline on {args.model} model")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# create model
model = timm.create_model(args.model, pretrained=True)
model.to(device)
data_config = timm.data.resolve_data_config(model.default_cfg)
test_transform = timm.data.create_transform(**data_config)
_logger.info("Test transform: %s", test_transform)
# create pipeline
pipeline = detectors.create_pipeline(
args.pipeline,
batch_size=args.batch_size,
seed=args.seed,
transform=test_transform,
limit_fit=args.limit_fit,
num_workers=3,
)

if "vit" in args.model and "pooling_op_name" in args.method_kwargs:
args.method_kwargs["pooling_op_name"] = "getitem"

# create detector
method = detectors.create_detector(args.method, model=model, **args.method_kwargs)
# run pipeline
pipeline_results = pipeline.run(method, model)
# print results
print(pipeline.report(pipeline_results["results"]))

if not args.debug:
# save results to file
results = {
"model": args.model,
"method": args.method,
**pipeline_results["results"],
"method_kwargs": args.method_kwargs,
}
filename = os.path.join(RESULTS_DIR, args.pipeline, "results.csv")
detectors.utils.append_results_to_csv_file(results, filename)

scores = pipeline_results["scores"]
labels = pipeline_results["labels"]
idxs = pipeline_results["idxs"]

results = {
"model": args.model,
"in_dataset_name": pipeline.in_dataset_name,
"method": args.method,
"method_kwargs": args.method_kwargs,
"scores": scores.numpy().tolist(),
"labels": labels.numpy().tolist(),
"idx": idxs.numpy().tolist(),
}
filename = os.path.join(RESULTS_DIR, args.pipeline, "scores.csv")
detectors.utils.append_results_to_csv_file(results, filename)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--method", type=str, default="msp")
parser.add_argument("--method_kwargs", type=json.loads, default={}, help='{"temperature":1000, "eps":0.00014}')
parser.add_argument("--pipeline", type=str, default="sc_benchmark_cifar10")
parser.add_argument("--model", type=str, default="resnet18_cifar10")
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--limit_fit", type=float, default=1)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--debug", action="store_true")
parser.add_argument("--idx", type=int, default=None)
args = parser.parse_args()
methods = [
"msp",
"energy",
"kl_matching",
"igeood_logits",
"max_logits",
"react",
"dice",
"vim",
"gradnorm",
"mahalanobis",
"relative_mahalanobis",
"knn_euclides",
"maxcosine",
"odin",
"doctor",
]
if args.idx is not None:
args.method = methods[args.idx]

logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
_logger.info(json.dumps(args.__dict__, indent=2))
main(args)
196 changes: 196 additions & 0 deletions src/detectors/ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import numpy as np
from scipy import stats


def p_value_fn(test_statistic: np.ndarray, X: np.ndarray):
"""Compute the p-value of a test statistic given a sample X.
Args:
test_statistic (np.ndarray): test statistic (n,m)
X (np.ndarray): sample (N,m)
Returns:
np.ndarray: p-values (n,m)
"""
mult_factor_min = np.where(X.min(0) > 0, np.array(1 / len(X)), np.array(len(X)))
mult_factor_max = np.where(X.max(0) > 0, np.array(len(X)), np.array(1 / len(X)))
lower_bound = X.min(0) * mult_factor_min
upper_bound = X.max(0) * mult_factor_max
X = np.concatenate((lower_bound.reshape(1, -1), X, upper_bound.reshape(1, -1)), axis=0)
X = np.sort(X, axis=0)
y_ecdf = np.concatenate([np.arange(1, X.shape[0] + 1).reshape(-1, 1) / X.shape[0]] * X.shape[1], axis=1)
return np.maximum(
np.concatenate(list(map(lambda xx: np.interp(*xx).reshape(-1, 1), zip(test_statistic.T, X.T, y_ecdf.T))), 1),
1e-8,
)


# Fisher
def fisher_method(p_values: np.ndarray):
"""Combine p-values using Fisher's method
Args:
p_values (np.ndarray): p-values (n,m)
Returns:
np.ndarray (n,): combined p-values
"""
tau = -2 * np.sum(np.log(p_values), axis=1).reshape(-1, 1)
group_p_value = p_value_fn(tau, np.random.chisquare(2 * p_values.shape[1], (1000, 1)))
# or
# group_p_value = stats.chi2.cdf(tau, 2 * p_values.shape[1])
return group_p_value


# Fisher
def fisher_tau_method(p_values: np.ndarray):
"""Combine p-values using Fisher's method
Args:
p_values (np.ndarray): p-values (n,m)
Returns:
np.ndarray (n,): combined p-values
"""
tau = -2 * np.sum(np.log(p_values), axis=1)
return tau


# Stouffer
def stouffer_method(p_values: np.ndarray):
"""Combine p-values using Stouffer's method
Args:
p_values (np.ndarray): p-values (n,m)
Returns:
np.ndarray (n,): combined p-values
"""
z = np.sum(stats.norm.ppf(p_values), axis=1).reshape(-1, 1) / np.sqrt(p_values.shape[1])
group_p_value = p_value_fn(z, np.random.normal(size=(1000, 1)))
# or
# group_p_value = stats.norm.cdf(z)
return group_p_value


# Stouffer
def stouffer_tau_method(p_values: np.ndarray):
"""Combine p-values using Stouffer's method
Args:
p_values (np.ndarray): p-values (n,m)
Returns:
np.ndarray (n,): combined p-values
"""
z = np.sum(stats.norm.ppf(p_values), axis=1) / np.sqrt(p_values.shape[1])
# max that is not inf
max_not_inf = np.max(z[np.isfinite(z)])
min_not_inf = np.min(z[np.isfinite(z)])
# replace inf with max or min
z = np.where(np.isposinf(z), max_not_inf, z)
z = np.where(np.isneginf(z), min_not_inf, z)
return z


def tippet_tau_method(p_values: np.ndarray):
"""Combine p-values using Tippet's method
Args:
p_values (np.ndarray): p-values (n,m)
Returns:
np.ndarray (n,): combined p-values
"""
tau = np.min(p_values, axis=1)
return tau


def wilkinson_tau_method(p_values: np.ndarray):
"""Combine p-values using Wilkinson's method
Args:
p_values (np.ndarray): p-values (n,m)
Returns:
np.ndarray (n,): combined p-values
"""
tau = np.max(p_values, axis=1)
return tau


def edgington_tau_method(p_values: np.ndarray):
"""Combine p-values using Edington's method
Args:
p_values (np.ndarray): p-values (n,m)
Returns:
np.ndarray (n,): combined p-values
"""
tau = np.sum(p_values, axis=1)
return tau


def pearson_tau_method(p_values: np.ndarray):
"""Combine p-values using Pearson's method
Args:
p_values (np.ndarray): p-values (n,m)
Returns:
np.ndarray (n,): combined p-values
"""
tau = 2 * np.sum(np.log(1 - p_values + 1e-6), axis=1)
return tau


def simes_tau_method(p_values: np.ndarray):
"""Combine p-values using Simes' method
Args:
p_values (np.ndarray): p-values (n,m)
Returns:
np.ndarray (n,): combined p-values
"""
tau = np.min(np.sort(p_values, axis=1) / np.arange(1, p_values.shape[1] + 1) * p_values.shape[1], 1)
return tau


def geometric_mean_tau_method(p_values: np.ndarray):
"""Combine p-values using geometric mean
Args:
p_values (np.ndarray): p-values (n,m)
Returns:
np.ndarray (n,): combined p-values
"""
tau = np.prod(p_values, axis=1) ** (1 / p_values.shape[1])
return tau


def get_combine_p_values_fn(method_name: str):
method_name = method_name.lower()
if method_name == "fisher":
return fisher_tau_method
elif method_name == "stouffer":
return stouffer_tau_method
elif method_name == "tippet":
return tippet_tau_method
elif method_name == "wilkinson":
return wilkinson_tau_method
elif method_name == "edgington":
return edgington_tau_method
elif method_name == "pearson":
return pearson_tau_method
elif method_name == "simes":
return simes_tau_method
elif method_name == "geometric_mean":
return geometric_mean_tau_method
else:
raise NotImplementedError(f"method {method_name} not implemented")


ensemble_names = ["fisher", "stouffer", "tippet", "wilkinson", "edgington", "pearson", "simes", "geometric_mean"]
Empty file.

0 comments on commit dac1ec3

Please sign in to comment.