diff --git a/docs/source/detectors.methods.rst b/docs/source/detectors.methods.rst
index 7584a3a..01e4c5e 100644
--- a/docs/source/detectors.methods.rst
+++ b/docs/source/detectors.methods.rst
@@ -260,6 +260,14 @@ detectors.methods.openmax module
    :undoc-members:
    :show-inheritance:
 
+detectors.methods.plugin\_bb module
+-----------------------------------
+
+.. automodule:: detectors.methods.plugin_bb
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 detectors.methods.pnml module
 -----------------------------
diff --git a/docs/source/detectors.rst b/docs/source/detectors.rst
index 4a29e99..3618b6b 100644
--- a/docs/source/detectors.rst
+++ b/docs/source/detectors.rst
@@ -33,6 +33,14 @@ detectors.criterions module
    :undoc-members:
    :show-inheritance:
 
+detectors.ensemble module
+-------------------------
+
+.. automodule:: detectors.ensemble
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 detectors.eval module
 ---------------------
diff --git a/examples/ood_benchmark.py b/examples/ood_benchmark.py
index fd92222..2f87e37 100644
--- a/examples/ood_benchmark.py
+++ b/examples/ood_benchmark.py
@@ -94,7 +94,7 @@ def main(args):
         "igeood_logits",
         "max_logits",
         "react",
-        "dice ",
+        "dice",
         "vim",
         "gradnorm",
         "mahalanobis",
diff --git a/examples/sc_benchmark.py b/examples/sc_benchmark.py
index e9755e2..3b00ce2 100644
--- a/examples/sc_benchmark.py
+++ b/examples/sc_benchmark.py
@@ -106,3 +106,110 @@ def main(args):
     logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
     _logger.info(json.dumps(args.__dict__, indent=2))
     main(args)
+import argparse
+import json
+import logging
+import os
+
+import timm
+import timm.data
+import torch
+
+import detectors
+from detectors.config import RESULTS_DIR
+
+_logger = logging.getLogger(__name__)
+
+
+def main(args):
+    print(f"Running {args.pipeline} pipeline on {args.model} model")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # create model
+    model = timm.create_model(args.model, pretrained=True)
+    model.to(device)
+    data_config = timm.data.resolve_data_config(model.default_cfg)
+    test_transform = timm.data.create_transform(**data_config)
+    _logger.info("Test transform: %s", test_transform)
+    # create pipeline
+    pipeline = detectors.create_pipeline(
+        args.pipeline,
+        batch_size=args.batch_size,
+        seed=args.seed,
+        transform=test_transform,
+        limit_fit=args.limit_fit,
+        num_workers=3,
+    )
+
+    if "vit" in args.model and "pooling_op_name" in args.method_kwargs:
+        args.method_kwargs["pooling_op_name"] = "getitem"
+
+    # create detector
+    method = detectors.create_detector(args.method, model=model, **args.method_kwargs)
+    # run pipeline
+    pipeline_results = pipeline.run(method, model)
+    # print results
+    print(pipeline.report(pipeline_results["results"]))
+
+    if not args.debug:
+        # save results to file
+        results = {
+            "model": args.model,
+            "method": args.method,
+            **pipeline_results["results"],
+            "method_kwargs": args.method_kwargs,
+        }
+        filename = os.path.join(RESULTS_DIR, args.pipeline, "results.csv")
+        detectors.utils.append_results_to_csv_file(results, filename)
+
+        scores = pipeline_results["scores"]
+        labels = pipeline_results["labels"]
+        idxs = pipeline_results["idxs"]
+
+        results = {
+            "model": args.model,
+            "in_dataset_name": pipeline.in_dataset_name,
+            "method": args.method,
+            "method_kwargs": args.method_kwargs,
+            "scores": scores.numpy().tolist(),
+            "labels": labels.numpy().tolist(),
+            "idx": idxs.numpy().tolist(),
+        }
+        filename = os.path.join(RESULTS_DIR, args.pipeline, "scores.csv")
+        detectors.utils.append_results_to_csv_file(results, filename)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--method", type=str, default="msp")
+    parser.add_argument("--method_kwargs", type=json.loads, default={}, help='{"temperature":1000, "eps":0.00014}')
+    parser.add_argument("--pipeline", type=str, default="sc_benchmark_cifar10")
+    parser.add_argument("--model", type=str, default="resnet18_cifar10")
+    parser.add_argument("--batch_size", type=int, default=128)
+    parser.add_argument("--limit_fit", type=float, default=1)
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--idx", type=int, default=None)
+    args = parser.parse_args()
+    methods = [
+        "msp",
+        "energy",
+        "kl_matching",
+        "igeood_logits",
+        "max_logits",
+        "react",
+        "dice",
+        "vim",
+        "gradnorm",
+        "mahalanobis",
+        "relative_mahalanobis",
+        "knn_euclides",
+        "maxcosine",
+        "odin",
+        "doctor",
+    ]
+    if args.idx is not None:
+        args.method = methods[args.idx]
+
+    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
+    _logger.info(json.dumps(args.__dict__, indent=2))
+    main(args)
diff --git a/src/detectors/ensemble.py b/src/detectors/ensemble.py
new file mode 100644
index 0000000..7728826
--- /dev/null
+++ b/src/detectors/ensemble.py
@@ -0,0 +1,196 @@
+import numpy as np
+from scipy import stats
+
+
+def p_value_fn(test_statistic: np.ndarray, X: np.ndarray):
+    """Compute the p-value of a test statistic given a sample X.
+
+    Args:
+        test_statistic (np.ndarray): test statistic (n, m)
+        X (np.ndarray): sample (N, m)
+
+    Returns:
+        np.ndarray: p-values (n, m)
+    """
+    mult_factor_min = np.where(X.min(0) > 0, np.array(1 / len(X)), np.array(len(X)))
+    mult_factor_max = np.where(X.max(0) > 0, np.array(len(X)), np.array(1 / len(X)))
+    lower_bound = X.min(0) * mult_factor_min
+    upper_bound = X.max(0) * mult_factor_max
+    X = np.concatenate((lower_bound.reshape(1, -1), X, upper_bound.reshape(1, -1)), axis=0)
+    X = np.sort(X, axis=0)
+    y_ecdf = np.concatenate([np.arange(1, X.shape[0] + 1).reshape(-1, 1) / X.shape[0]] * X.shape[1], axis=1)
+    return np.maximum(
+        np.concatenate(list(map(lambda xx: np.interp(*xx).reshape(-1, 1), zip(test_statistic.T, X.T, y_ecdf.T))), 1),
+        1e-8,
+    )
+
+
+# Fisher
+def fisher_method(p_values: np.ndarray):
+    """Combine p-values using Fisher's method.
+
+    Args:
+        p_values (np.ndarray): p-values (n, m)
+
+    Returns:
+        np.ndarray (n, 1): combined p-values
+    """
+    tau = -2 * np.sum(np.log(p_values), axis=1).reshape(-1, 1)
+    group_p_value = p_value_fn(tau, np.random.chisquare(2 * p_values.shape[1], (1000, 1)))
+    # or
+    # group_p_value = stats.chi2.cdf(tau, 2 * p_values.shape[1])
+    return group_p_value
+
+
+# Fisher
+def fisher_tau_method(p_values: np.ndarray):
+    """Combine p-values using Fisher's method.
+
+    Args:
+        p_values (np.ndarray): p-values (n, m)
+
+    Returns:
+        np.ndarray (n,): combined statistic
+    """
+    tau = -2 * np.sum(np.log(p_values), axis=1)
+    return tau
+
+
+# Stouffer
+def stouffer_method(p_values: np.ndarray):
+    """Combine p-values using Stouffer's method.
+
+    Args:
+        p_values (np.ndarray): p-values (n, m)
+
+    Returns:
+        np.ndarray (n, 1): combined p-values
+    """
+    z = np.sum(stats.norm.ppf(p_values), axis=1).reshape(-1, 1) / np.sqrt(p_values.shape[1])
+    group_p_value = p_value_fn(z, np.random.normal(size=(1000, 1)))
+    # or
+    # group_p_value = stats.norm.cdf(z)
+    return group_p_value
+
+
+# Stouffer
+def stouffer_tau_method(p_values: np.ndarray):
+    """Combine p-values using Stouffer's method.
+
+    Args:
+        p_values (np.ndarray): p-values (n, m)
+
+    Returns:
+        np.ndarray (n,): combined statistic
+    """
+    z = np.sum(stats.norm.ppf(p_values), axis=1) / np.sqrt(p_values.shape[1])
+    # max that is not inf
+    max_not_inf = np.max(z[np.isfinite(z)])
+    min_not_inf = np.min(z[np.isfinite(z)])
+    # replace inf with max or min
+    z = np.where(np.isposinf(z), max_not_inf, z)
+    z = np.where(np.isneginf(z), min_not_inf, z)
+    return z
+
+
+def tippet_tau_method(p_values: np.ndarray):
+    """Combine p-values using Tippett's method.
+
+    Args:
+        p_values (np.ndarray): p-values (n, m)
+
+    Returns:
+        np.ndarray (n,): combined statistic
+    """
+    tau = np.min(p_values, axis=1)
+    return tau
+
+
+def wilkinson_tau_method(p_values: np.ndarray):
+    """Combine p-values using Wilkinson's method.
+
+    Args:
+        p_values (np.ndarray): p-values (n, m)
+
+    Returns:
+        np.ndarray (n,): combined statistic
+    """
+    tau = np.max(p_values, axis=1)
+    return tau
+
+
+def edgington_tau_method(p_values: np.ndarray):
+    """Combine p-values using Edgington's method.
+
+    Args:
+        p_values (np.ndarray): p-values (n, m)
+
+    Returns:
+        np.ndarray (n,): combined statistic
+    """
+    tau = np.sum(p_values, axis=1)
+    return tau
+
+
+def pearson_tau_method(p_values: np.ndarray):
+    """Combine p-values using Pearson's method.
+
+    Args:
+        p_values (np.ndarray): p-values (n, m)
+
+    Returns:
+        np.ndarray (n,): combined statistic
+    """
+    tau = 2 * np.sum(np.log(1 - p_values + 1e-6), axis=1)
+    return tau
+
+
+def simes_tau_method(p_values: np.ndarray):
+    """Combine p-values using Simes' method.
+
+    Args:
+        p_values (np.ndarray): p-values (n, m)
+
+    Returns:
+        np.ndarray (n,): combined statistic
+    """
+    tau = np.min(np.sort(p_values, axis=1) / np.arange(1, p_values.shape[1] + 1) * p_values.shape[1], 1)
+    return tau
+
+
+def geometric_mean_tau_method(p_values: np.ndarray):
+    """Combine p-values using the geometric mean.
+
+    Args:
+        p_values (np.ndarray): p-values (n, m)
+
+    Returns:
+        np.ndarray (n,): combined statistic
+    """
+    tau = np.prod(p_values, axis=1) ** (1 / p_values.shape[1])
+    return tau
+
+
+def get_combine_p_values_fn(method_name: str):
+    method_name = method_name.lower()
+    if method_name == "fisher":
+        return fisher_tau_method
+    elif method_name == "stouffer":
+        return stouffer_tau_method
+    elif method_name == "tippet":
+        return tippet_tau_method
+    elif method_name == "wilkinson":
+        return wilkinson_tau_method
+    elif method_name == "edgington":
+        return edgington_tau_method
+    elif method_name == "pearson":
+        return pearson_tau_method
+    elif method_name == "simes":
+        return simes_tau_method
+    elif method_name == "geometric_mean":
+        return geometric_mean_tau_method
+    else:
+        raise NotImplementedError(f"method {method_name} not implemented")
+
+
+ensemble_names = ["fisher", "stouffer", "tippet", "wilkinson", "edgington", "pearson", "simes", "geometric_mean"]
diff --git a/src/detectors/pipelines/misclassif.py b/src/detectors/pipelines/misclassif.py
new file mode 100644
index 0000000..e69de29
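
Note: a minimal usage sketch of the new detectors.ensemble module added above. The synthetic scores, their shapes, the seed, and the choice of the "fisher" combiner are illustrative assumptions, not part of the diff; in practice the arrays would hold the scores of several fitted detectors.

    import numpy as np

    from detectors.ensemble import get_combine_p_values_fn, p_value_fn

    rng = np.random.default_rng(42)
    fit_scores = rng.normal(size=(1000, 3))  # scores of 3 detectors on a fit/calibration set (N, m)
    test_scores = rng.normal(size=(10, 3))   # scores of the same detectors on test samples (n, m)

    # map each detector's raw scores to empirical p-values via the fitted ECDF
    p_values = p_value_fn(test_scores, fit_scores)  # (n, m), clipped below at 1e-8

    # combine the per-detector p-values into a single statistic per test sample
    combine_fn = get_combine_p_values_fn("fisher")  # any name in ensemble_names works
    combined = combine_fn(p_values)                 # shape (n,)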