From 1368042fa8737f4d29605f25ae1f69d44d4acb76 Mon Sep 17 00:00:00 2001 From: Eduardo Dadalto Date: Wed, 19 Jun 2024 23:00:37 +0200 Subject: [PATCH] Dev (#20) * download imagenet models * Improvements to pipelines * Code + docs cleaning * Combine and conquer score aggregation --- Makefile | 2 +- docs/source/detectors.aggregations.rst | 9 + docs/source/detectors.pipelines.rst | 8 - examples/compute_accuracy.py | 110 +++++++++ examples/list_resources.py | 1 + examples/sc_benchmark.py | 108 +++++++++ scripts/compute_accuracy_ssl.py | 95 ++++++++ scripts/download_models.py | 20 ++ scripts/parse_arxiv.py | 1 + scripts/push_model_to_hf_hub.py | 1 + setup.py | 2 +- .../aggregations/combine_and_conquer.py | 227 ++++++++++++++++++ src/detectors/config.py | 1 + src/detectors/data/300k_random_images.py | 4 + src/detectors/data/__init__.py | 4 +- src/detectors/data/imagenetlt.py | 1 + src/detectors/eval.py | 1 + src/detectors/methods/__init__.py | 1 + src/detectors/methods/ash.py | 0 src/detectors/methods/igeood_features.py | 137 +++++++++++ src/detectors/methods/masf.py | 0 src/detectors/methods/pnml.py | 0 src/detectors/methods/residual.py | 0 src/detectors/methods/she.py | 0 src/detectors/methods/templates.py | 1 + src/detectors/models/densenet.py | 1 + src/detectors/models/dino.py | 1 + src/detectors/models/resnet.py | 1 + src/detectors/models/vgg.py | 1 + src/detectors/models/vit.py | 1 + src/detectors/pipelines/__init__.py | 1 + src/detectors/pipelines/base.py | 1 + src/detectors/pipelines/covariate_drift.py | 107 +++++++-- src/detectors/pipelines/ood.py | 42 +--- src/detectors/pipelines/sc.py | 7 +- src/detectors/pipelines/scod.py | 24 ++ src/detectors/scod_eval.py | 7 - src/detectors/trainer.py | 1 + src/detectors/utils.py | 10 +- 39 files changed, 856 insertions(+), 83 deletions(-) create mode 100644 examples/compute_accuracy.py create mode 100644 scripts/compute_accuracy_ssl.py create mode 100644 scripts/download_models.py create mode 100644 src/detectors/aggregations/combine_and_conquer.py create mode 100644 src/detectors/data/300k_random_images.py create mode 100644 src/detectors/methods/ash.py create mode 100644 src/detectors/methods/igeood_features.py create mode 100644 src/detectors/methods/masf.py create mode 100644 src/detectors/methods/pnml.py create mode 100644 src/detectors/methods/residual.py create mode 100644 src/detectors/methods/she.py diff --git a/Makefile b/Makefile index e16a823..d130616 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ doc: cd docs && sphinx-apidoc -o ./source -f ../src/detectors -d 4 && make clean && make html test: - pytest -vv tests/models.py tests/methods.py tests/docstrings.py -s --cov \ + pytest -vv tests/methods.py -s --cov \ --cov-config=.coveragerc \ --cov-report xml \ --cov-report term-missing:skip-covered diff --git a/docs/source/detectors.aggregations.rst b/docs/source/detectors.aggregations.rst index 811199d..1ea272b 100644 --- a/docs/source/detectors.aggregations.rst +++ b/docs/source/detectors.aggregations.rst @@ -20,6 +20,15 @@ detectors.aggregations.basics module :undoc-members: :show-inheritance: + +detectors.aggregations.combine\_and\_conquer module +--------------------------------------------------- + +.. 
automodule:: detectors.aggregations.combine_and_conquer + :members: + :undoc-members: + :show-inheritance: + detectors.aggregations.cosine module ------------------------------------ diff --git a/docs/source/detectors.pipelines.rst b/docs/source/detectors.pipelines.rst index 9bba854..fec5ba2 100644 --- a/docs/source/detectors.pipelines.rst +++ b/docs/source/detectors.pipelines.rst @@ -28,14 +28,6 @@ detectors.pipelines.drift module :undoc-members: :show-inheritance: -detectors.pipelines.misclassif module -------------------------------------- - -.. automodule:: detectors.pipelines.misclassif - :members: - :undoc-members: - :show-inheritance: - detectors.pipelines.ood module ------------------------------ diff --git a/examples/compute_accuracy.py b/examples/compute_accuracy.py new file mode 100644 index 0000000..ab01cef --- /dev/null +++ b/examples/compute_accuracy.py @@ -0,0 +1,110 @@ +import argparse +import json +import logging +import os + +import timm +import timm.data +import torch +import torch.utils.data +from tqdm import tqdm + +import detectors +from detectors.config import RESULTS_DIR + +_logger = logging.getLogger(__name__) + + +def topk_accuracy(preds, labels, k=5): + topk = torch.topk(preds, k=k, dim=1) + topk_preds = topk.indices + topk_labels = labels.unsqueeze(1).expand_as(topk_preds) + return (topk_preds == topk_labels).any(dim=1).float().mean().item() + + +def main(args): + torch.manual_seed(42) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device == "cpu": + # try mps + device = "mps" + # create model + model = timm.create_model(args.model, pretrained=True) + model.to(device) + print(model.default_cfg) + data_config = timm.data.resolve_data_config(model.default_cfg) + test_transform = timm.data.create_transform(**data_config) + data_config["is_training"] = True + train_transform = timm.data.create_transform(**data_config, color_jitter=None) + + _logger.info("Test transform: %s", test_transform) + _logger.info("Train transform: %s", train_transform) + + dataset = detectors.create_dataset(args.dataset, split=args.split, transform=test_transform, download=True) + + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True + ) + model.eval() + x = torch.randn(1, 3, 224, 224) + x = x.to(device) + with torch.no_grad(): + y = model(x) + + num_classes = y.shape[1] + if args.dataset == "imagenet_r": + mask = dataset.imagenet_r_mask + else: + mask = range(num_classes) + + all_preds = torch.empty((len(dataset), num_classes), dtype=torch.float32) + all_labels = torch.empty(len(dataset), dtype=torch.long) + _logger.info(f"Shapes: {all_preds.shape}, {all_labels.shape}") + for i, batch in enumerate(tqdm(dataloader, total=len(dataloader))): + inputs, labels = batch + inputs = inputs.to(device) + # print(labels) + with torch.no_grad(): + outputs = model(inputs) + # print(outputs) + outputs = torch.softmax(outputs, dim=1) + all_preds[i * args.batch_size : (i + 1) * args.batch_size] = outputs.cpu() + all_labels[i * args.batch_size : (i + 1) * args.batch_size] = labels.cpu() + if args.debug: + _logger.info("Labels: %s", labels) + _logger.info("Predictions: %s", outputs.argmax(1)) + break + + top1 = topk_accuracy(all_preds[:, mask], all_labels, k=1) * 100 + top5 = topk_accuracy(all_preds[:, mask], all_labels, k=5) * 100 + _logger.info(torch.sum(torch.argmax(all_preds, dim=1) == all_labels) / len(all_labels)) + _logger.info(f"Top-1 accuracy: {top1:.4f}") + _logger.info(f"Top-5 
accuracy: {top5:.4f}") + + if not args.debug: + # save results to file + results = { + "model": args.model, + "dataset": args.dataset, + "split": args.split, + "top1_acc": top1, + "top5_acc": top5, + } + filename = os.path.join(RESULTS_DIR, "accuracy", "results.csv") + detectors.utils.append_results_to_csv_file(results, filename) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="resnet50.tv_in1k") + parser.add_argument("--dataset", type=str, default="imagenet1k") + parser.add_argument("--split", type=str, default="val") + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--num_workers", type=int, default=3) + parser.add_argument("--debug", action="store_true") + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + _logger.info(json.dumps(args.__dict__, indent=2)) + + main(args) diff --git a/examples/list_resources.py b/examples/list_resources.py index 3fe26f6..771fa21 100644 --- a/examples/list_resources.py +++ b/examples/list_resources.py @@ -1,4 +1,5 @@ """Example on how to list all resources available in `detectors` package.""" + import detectors if __name__ == "__main__": diff --git a/examples/sc_benchmark.py b/examples/sc_benchmark.py index d25331e..3b00ce2 100644 --- a/examples/sc_benchmark.py +++ b/examples/sc_benchmark.py @@ -13,6 +13,114 @@ _logger = logging.getLogger(__name__) +def main(args): + print(f"Running {args.pipeline} pipeline on {args.model} model") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # create model + model = timm.create_model(args.model, pretrained=True) + model.to(device) + data_config = timm.data.resolve_data_config(model.default_cfg) + test_transform = timm.data.create_transform(**data_config) + _logger.info("Test transform: %s", test_transform) + # create pipeline + pipeline = detectors.create_pipeline( + args.pipeline, + batch_size=args.batch_size, + seed=args.seed, + transform=test_transform, + limit_fit=args.limit_fit, + num_workers=3, + ) + + if "vit" in args.model and "pooling_op_name" in args.method_kwargs: + args.method_kwargs["pooling_op_name"] = "getitem" + + # create detector + method = detectors.create_detector(args.method, model=model, **args.method_kwargs) + # run pipeline + pipeline_results = pipeline.run(method, model) + # print results + print(pipeline.report(pipeline_results["results"])) + + if not args.debug: + # save results to file + results = { + "model": args.model, + "method": args.method, + **pipeline_results["results"], + "method_kwargs": args.method_kwargs, + } + filename = os.path.join(RESULTS_DIR, args.pipeline, "results.csv") + detectors.utils.append_results_to_csv_file(results, filename) + + scores = pipeline_results["scores"] + labels = pipeline_results["labels"] + idxs = pipeline_results["idxs"] + + results = { + "model": args.model, + "in_dataset_name": pipeline.in_dataset_name, + "method": args.method, + "method_kwargs": args.method_kwargs, + "scores": scores.numpy().tolist(), + "labels": labels.numpy().tolist(), + "idx": idxs.numpy().tolist(), + } + filename = os.path.join(RESULTS_DIR, args.pipeline, "scores.csv") + detectors.utils.append_results_to_csv_file(results, filename) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--method", type=str, default="msp") + parser.add_argument("--method_kwargs", type=json.loads, default={}, help='{"temperature":1000, "eps":0.00014}') + parser.add_argument("--pipeline", type=str, 
default="sc_benchmark_cifar10") + parser.add_argument("--model", type=str, default="resnet18_cifar10") + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--limit_fit", type=float, default=1) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--debug", action="store_true") + parser.add_argument("--idx", type=int, default=None) + args = parser.parse_args() + + methods = [ + "msp", + "energy", + "kl_matching", + "igeood_logits", + "max_logits", + "react", + "dice", + "vim", + "gradnorm", + "mahalanobis", + "relative_mahalanobis", + "knn_euclides", + "maxcosine", + "odin", + "doctor", + ] + if args.idx is not None: + args.method = methods[args.idx] + + logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO) + _logger.info(json.dumps(args.__dict__, indent=2)) + main(args) +import argparse +import json +import logging +import os + +import timm +import timm.data +import torch + +import detectors +from detectors.config import RESULTS_DIR + +_logger = logging.getLogger(__name__) + + def main(args): print(f"Running {args.pipeline} pipeline on {args.model} model") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/scripts/compute_accuracy_ssl.py b/scripts/compute_accuracy_ssl.py new file mode 100644 index 0000000..8336ce3 --- /dev/null +++ b/scripts/compute_accuracy_ssl.py @@ -0,0 +1,95 @@ +import argparse +import json +import logging +import time + +import accelerate +import numpy as np +import timm +import torch +import torch.utils.data +from sklearn.neighbors import KNeighborsClassifier +from timm.data import resolve_data_config +from timm.data.transforms_factory import create_transform +from tqdm import tqdm + +import detectors + +_logger = logging.getLogger(__name__) + + +@torch.no_grad() +def main(args): + if "supcon" in args.model or "simclr" in args.model: + args.ssl = True + accelerator = accelerate.Accelerator() + + model = timm.create_model(args.model, pretrained=True) + data_config = resolve_data_config(model.default_cfg) + transform = create_transform(**data_config) + _logger.info(transform) + + model.eval() + model = accelerator.prepare(model) + + dataset = detectors.create_dataset(args.dataset, split=args.split, transform=transform) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=accelerator.num_processes + ) + dataloader = accelerator.prepare(dataloader) + + inference_time = [] + all_outputs = [] + all_labels = [] + start_time = time.time() + progress_bar = tqdm(dataloader, desc="Inference", disable=not accelerator.is_local_main_process) + for x, labels in dataloader: + t1 = time.time() + outputs = model(x) + t2 = time.time() + + outputs, labels = accelerator.gather_for_metrics((outputs, labels)) + all_outputs.append(outputs.cpu()) + all_labels.append(labels.cpu()) + inference_time.append(t2 - t1) + progress_bar.update() + + progress_bar.close() + accelerator.wait_for_everyone() + + all_outputs = torch.cat(all_outputs, dim=0) + all_labels = torch.cat(all_labels, dim=0) + + if not args.ssl: + _, preds = torch.max(all_outputs, 1) + else: + features = all_outputs.cpu().numpy() + all_labels = all_labels.cpu().numpy() + estimator = KNeighborsClassifier(20, metric="cosine").fit(features, all_labels) + preds = estimator.predict(features) + preds = torch.from_numpy(preds) + all_labels = torch.from_numpy(all_labels) + + acc = torch.mean((preds.cpu() == all_labels.cpu()).float()).item() + + print(f"Total time: 
{time.time() - start_time:.2f} seconds") + print(f"Accuracy: {acc}") + print(f"Average inference time: {np.mean(inference_time)}") + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + + parser = argparse.ArgumentParser() + + parser.add_argument("--model", type=str, default="densenet121") + parser.add_argument("--dataset", type=str, default="imagenet1k") + parser.add_argument("--split", type=str, default="val") + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--ssl", action="store_true") + + args = parser.parse_args() + + _logger.info(json.dumps(args.__dict__, indent=2)) + + main(args) diff --git a/scripts/download_models.py b/scripts/download_models.py new file mode 100644 index 0000000..84d6a20 --- /dev/null +++ b/scripts/download_models.py @@ -0,0 +1,20 @@ +import timm + +models = [ + "resnet34.tv_in1k", + "resnet50.tv_in1k", + "resnet101.tv_in1k", + "vit_tiny_patch16_224.augreg_in21k_ft_in1k", + "vit_small_patch16_224.augreg_in21k_ft_in1k", + "vit_base_patch16_224.augreg_in21k_ft_in1k", + "vit_large_patch16_224.augreg_in21k_ft_in1k", + "densenet121.tv_in1k", + "vgg16.tv_in1k", + "mobilenetv3_small_100.lamb_in1k", + "mobilenetv3_large_100.ra_in1k", + "mobilenetv3_large_100.miil_in21k_ft_in1k", +] + +for model_name in models: + model = timm.create_model(model_name, pretrained=True) + print(f"Downloaded {model_name}") diff --git a/scripts/parse_arxiv.py b/scripts/parse_arxiv.py index 4fd9a71..0d4e920 100644 --- a/scripts/parse_arxiv.py +++ b/scripts/parse_arxiv.py @@ -1,6 +1,7 @@ """Requirements - feedparser installed: pip install feedparser """ + import argparse import json import logging diff --git a/scripts/push_model_to_hf_hub.py b/scripts/push_model_to_hf_hub.py index 3ef4448..dc6725c 100644 --- a/scripts/push_model_to_hf_hub.py +++ b/scripts/push_model_to_hf_hub.py @@ -4,6 +4,7 @@ - Jinja2 installed - Git LFS installed """ + import argparse import json import logging diff --git a/setup.py b/setup.py index a41cd4c..823670d 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ description="Detectors: a python package to benchmark generalized out-of-distribution detection methods.", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", - keywords="vision deep learning pytorch OOD", + keywords="computer vision, deep learning, pytorch, out-of-distribution detection, OOD", license="APACHE 2.0", url="https://github.com/edadaltocg/detectors", package_dir={"": "src"}, diff --git a/src/detectors/aggregations/combine_and_conquer.py b/src/detectors/aggregations/combine_and_conquer.py new file mode 100644 index 0000000..17172be --- /dev/null +++ b/src/detectors/aggregations/combine_and_conquer.py @@ -0,0 +1,227 @@ +import numpy as np +from scipy import stats + + +def p_value_fn(test_statistic: np.ndarray, X: np.ndarray, w=None): + """Compute the p-value of a test statistic given a sample X. 
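+    The p-value is computed by linearly interpolating the empirical CDF of the
+    reference sample ``X`` (padded with a finite lower and upper bound) at the value
+    of the test statistic, column by column. The optional ``w`` rescales each
+    column's ECDF before interpolation.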
+ + Args: + test_statistic (np.ndarray): test statistic (n,m) + X (np.ndarray): sample (N,m) + + Returns: + np.ndarray: p-values (n,m) + """ + if len(X.shape) == 1: + X = X.reshape(-1, 1) + if len(test_statistic.shape) == 1: + test_statistic = test_statistic.reshape(-1, 1) + mult_factor_min = np.where(X.min(0) > 0, np.array(1 / len(X)), np.array(len(X))) + mult_factor_max = np.where(X.max(0) > 0, np.array(len(X)), np.array(1 / len(X))) + lower_bound = X.min(0) * mult_factor_min + upper_bound = X.max(0) * mult_factor_max + X = np.concatenate((lower_bound.reshape(1, -1), X, upper_bound.reshape(1, -1)), axis=0) + X = np.sort(X, axis=0) + y_ecdf = np.concatenate([np.arange(1, X.shape[0] + 1).reshape(-1, 1) / X.shape[0]] * X.shape[1], axis=1) + if w is not None: + y_ecdf = y_ecdf * w.reshape(1, -1) + return np.concatenate( + list( + map( + lambda xx: np.interp(*xx).reshape(-1, 1), + zip(test_statistic.T, X.T, y_ecdf.T), + ) + ), + 1, + ) + + +def fisher_method(p_values: np.ndarray): + """Combine p-values using Fisher's method + + Args: + p_values (np.ndarray): p-values (n,m) + + Returns: + np.ndarray (n,): combined p-values + """ + tau = -2 * np.sum(np.log(p_values), axis=1).reshape(-1, 1) + group_p_value = p_value_fn(tau, np.random.chisquare(2 * p_values.shape[1], (1000, 1))) + # or + # group_p_value = stats.chi2.cdf(tau, 2 * p_values.shape[1]) + return group_p_value + + +def fisher_tau_method(p_values: np.ndarray): + """Combine p-values using Fisher's method + + Args: + p_values (np.ndarray): p-values (n,m) + + Returns: + np.ndarray (n,): combined p-values + """ + tau = -2 * np.sum(np.log(p_values), axis=1) + return tau + + +def stouffer_method(p_values: np.ndarray): + """Combine p-values using Stouffer's method + + Args: + p_values (np.ndarray): p-values (n,m) + + Returns: + np.ndarray (n,): combined p-values + """ + z = np.sum(stats.norm.ppf(p_values), axis=1).reshape(-1, 1) / np.sqrt(p_values.shape[1]) + group_p_value = p_value_fn(z, np.random.normal(size=(1000, 1))) + # or + # group_p_value = stats.norm.cdf(z) + return group_p_value + + +def stouffer_tau_method(p_values: np.ndarray): + """Combine p-values using Stouffer's method + + Args: + p_values (np.ndarray): p-values (n,m) + + Returns: + np.ndarray (n,): combined p-values + """ + z = np.sum(stats.norm.ppf(p_values), axis=1) / np.sqrt(p_values.shape[1]) + # max that is not inf + max_not_inf = np.max(z[np.isfinite(z)]) + min_not_inf = np.min(z[np.isfinite(z)]) + # replace inf with max or min + z = np.where(np.isposinf(z), max_not_inf, z) + z = np.where(np.isneginf(z), min_not_inf, z) + return z + + +def tippet_tau_method(p_values: np.ndarray): + """Combine p-values using Tippet's method + + Args: + p_values (np.ndarray): p-values (n,m) + + Returns: + np.ndarray (n,): combined p-values + """ + tau = np.min(p_values, axis=1) + return tau + + +def wilkinson_tau_method(p_values: np.ndarray): + """Combine p-values using Wilkinson's method + + Args: + p_values (np.ndarray): p-values (n,m) + + Returns: + np.ndarray (n,): combined p-values + """ + tau = np.max(p_values, axis=1) + return tau + + +def edgington_tau_method(p_values: np.ndarray): + """Combine p-values using Edington's method + + Args: + p_values (np.ndarray): p-values (n,m) + + Returns: + np.ndarray (n,): combined p-values + """ + tau = np.sum(p_values, axis=1) + return tau + + +def pearson_tau_method(p_values: np.ndarray): + """Combine p-values using Pearson's method + + Args: + p_values (np.ndarray): p-values (n,m) + + Returns: + np.ndarray (n,): combined p-values + """ + tau = 
2 * np.sum(np.log(1 - p_values + 1e-6), axis=1) + return tau + + +def simes_tau_method(p_values: np.ndarray): + """Combine p-values using Simes' method + + Args: + p_values (np.ndarray): p-values (n,m) + + Returns: + np.ndarray (n,): combined p-values + """ + tau = np.min( + np.sort(p_values, axis=1) / np.arange(1, p_values.shape[1] + 1) * p_values.shape[1], + 1, + ) + return tau + + +def geometric_mean_tau_method(p_values: np.ndarray): + """Combine p-values using geometric mean + + Args: + p_values (np.ndarray): p-values (n,m) + + Returns: + np.ndarray (n,): combined p-values + """ + tau = np.prod(p_values, axis=1) ** (1 / p_values.shape[1]) + return tau + + +def rho(p_values): + k = p_values.shape[1] + phi = stats.norm.ppf(p_values) + return 1 - (1 / (k - 1)) * np.sum((phi - np.mean(phi, axis=1, keepdims=True)) ** 2, axis=1) + + +def hartung(p_values, r): + k = p_values.shape[1] + t = stats.norm.ppf(p_values) + return np.sum(t, axis=1) / np.sqrt((1 - r) * k + r * k**2) + + +def get_combine_p_values_fn(method_name: str): + method_name = method_name.lower() + if method_name == "fisher": + return fisher_tau_method + elif method_name == "stouffer": + return stouffer_tau_method + elif method_name == "tippet": + return tippet_tau_method + elif method_name == "wilkinson": + return wilkinson_tau_method + elif method_name == "edgington": + return edgington_tau_method + elif method_name == "pearson": + return pearson_tau_method + elif method_name == "simes": + return simes_tau_method + elif method_name == "geometric_mean": + return geometric_mean_tau_method + else: + raise NotImplementedError(f"method {method_name} not implemented") + + +ensemble_names = [ + "fisher", + "stouffer", + "tippet", + "wilkinson", + "edgington", + "pearson", + "simes", + "geometric_mean", +] diff --git a/src/detectors/config.py b/src/detectors/config.py index bc2e4fb..823df86 100644 --- a/src/detectors/config.py +++ b/src/detectors/config.py @@ -10,6 +10,7 @@ - `CHECKPOINTS_DIR`: The directory where the checkpoints are stored. - `RESULTS_DIR`: The directory where the results are stored. """ + import os HOME = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) diff --git a/src/detectors/data/300k_random_images.py b/src/detectors/data/300k_random_images.py new file mode 100644 index 0000000..8e7733e --- /dev/null +++ b/src/detectors/data/300k_random_images.py @@ -0,0 +1,4 @@ +""" +https://people.eecs.berkeley.edu/~hendrycks/300K_random_images.npy +https://github.com/hendrycks/outlier-exposure +""" diff --git a/src/detectors/data/__init__.py b/src/detectors/data/__init__.py index db69ba9..a8c9772 100644 --- a/src/detectors/data/__init__.py +++ b/src/detectors/data/__init__.py @@ -1,6 +1,7 @@ """ Datasets module. """ + import logging import os from enum import Enum @@ -89,6 +90,7 @@ "cifar10_c": CIFAR10_C, "cifar100_c": CIFAR100_C, "imagenet": ImageNet, + "imagenet_train": ImageNet, "imagenet1k": ImageNet, "ilsvrc2012": ImageNet, "imagenet_c": ImageNetCnpz, @@ -173,7 +175,7 @@ def create_dataset( Dataset: Dataset object. 
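    Example (usage sketch; transform resolution follows ``examples/compute_accuracy.py``):

        import timm
        import timm.data
        import detectors

        model = timm.create_model("resnet50.tv_in1k", pretrained=True)
        data_config = timm.data.resolve_data_config(model.default_cfg)
        transform = timm.data.create_transform(**data_config)
        dataset = detectors.create_dataset("imagenet_train", split="train", transform=transform)
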
""" try: - if dataset_name in ["imagenet", "imagenet1k", "ilsvrc2012", "imagenet1k_lt", "imagenet_lt"]: + if dataset_name in ["imagenet", "imagenet1k", "ilsvrc2012", "imagenet1k_lt", "imagenet_lt", "imagenet_train"]: return datasets_registry[dataset_name](root=IMAGENET_ROOT, split=split, transform=transform, **kwargs) return datasets_registry[dataset_name](root=root, split=split, transform=transform, download=download, **kwargs) except KeyError as e: diff --git a/src/detectors/data/imagenetlt.py b/src/detectors/data/imagenetlt.py index 300552c..25d67d4 100644 --- a/src/detectors/data/imagenetlt.py +++ b/src/detectors/data/imagenetlt.py @@ -1,4 +1,5 @@ """From https://github.com/KaihuaTang/Long-Tailed-Recognition.pytorch/blob/master/classification/data/dataloader.py""" + import os import numpy as np diff --git a/src/detectors/eval.py b/src/detectors/eval.py index 70cf6bb..42333cc 100644 --- a/src/detectors/eval.py +++ b/src/detectors/eval.py @@ -1,6 +1,7 @@ """ Module containing evaluation metrics. """ + from typing import Dict, Union import numpy as np diff --git a/src/detectors/methods/__init__.py b/src/detectors/methods/__init__.py index b22f41a..0c80b49 100644 --- a/src/detectors/methods/__init__.py +++ b/src/detectors/methods/__init__.py @@ -1,6 +1,7 @@ """ Detection methods. """ + import logging import types from enum import Enum diff --git a/src/detectors/methods/ash.py b/src/detectors/methods/ash.py new file mode 100644 index 0000000..e69de29 diff --git a/src/detectors/methods/igeood_features.py b/src/detectors/methods/igeood_features.py new file mode 100644 index 0000000..f7d07b7 --- /dev/null +++ b/src/detectors/methods/igeood_features.py @@ -0,0 +1,137 @@ +from typing import Callable, List, Optional + +import torch +from torch import Tensor, nn + +from detectors.methods.templates import DetectorWithFeatureExtraction + + +def fr_distance_univariate_gaussian( + mu_1: torch.Tensor, sig_1: torch.Tensor, mu_2: torch.Tensor, sig_2: torch.Tensor +) -> torch.Tensor: + """Calculates the Fisher-Rao distance between univariate gaussian distributions in prallel. + + Args: + mu_1 (torch.Tensor): Tensor of dimension (N,*) containing the means of N different univariate gaussians + sig_1 (torch.Tensor): Standard deviations of univariate gaussian distributions + mu_2 (torch.Tensor): Means of the second univariate gaussian distributions + sig_2 (torch.Tensor): Standard deviation of the second univariate gaussian distributions + Returns: + torch.Tensor: Distance tensor of size (N,*) + """ + dim = len(mu_1.shape) + mu_1, mu_2 = mu_1.reshape(*mu_1.shape, 1), mu_2.reshape(*mu_2.shape, 1) + sig_1, sig_2 = sig_1.reshape(*sig_1.shape, 1), sig_2.reshape(*sig_2.shape, 1) + + sqrt_2 = torch.sqrt(torch.tensor(2.0, device=mu_1.device)) + a = torch.norm( + torch.cat((mu_1 / sqrt_2, sig_1), dim=dim) - torch.cat((mu_2 / sqrt_2, -1 * sig_2), dim=dim), p=2, dim=dim + ) + b = torch.norm( + torch.cat((mu_1 / sqrt_2, sig_1), dim=dim) - torch.cat((mu_2 / sqrt_2, sig_2), dim=dim), p=2, dim=dim + ) + + num = a + b + 1e-12 + den = a - b + 1e-12 + return sqrt_2 * torch.log(num / den) + + +def fr_distance_multivariate_gaussian( + x: torch.Tensor, y: torch.Tensor, cov_x: torch.Tensor, cov_y: torch.Tensor +) -> torch.Tensor: + num_examples_x = x.shape[0] + num_examples_y = y.shape[0] + # Replicate std dev. 
matrix to match the batch size + sig_x = torch.vstack([torch.sqrt(torch.diag(cov_x)).reshape(1, -1)] * num_examples_x) + sig_y = torch.vstack([torch.sqrt(torch.diag(cov_y)).reshape(1, -1)] * num_examples_y) + return torch.sqrt(torch.sum(fr_distance_univariate_gaussian(x, sig_x, y, sig_y) ** 2, dim=1)).reshape(-1, 1) + + +def _igeood_layer_score(x, mus, cov_x, cov_mus): + if type(mus) == dict: + mus = torch.vstack([mu.reshape(1, -1) for mu in mus.values()]) + else: + mus = [mus.reshape(1, -1)] + + stack = torch.hstack( + [fr_distance_multivariate_gaussian(x, mu.reshape(1, -1), cov_x, cov_mus).reshape(-1, 1) for mu in mus] + ) + return stack + + +def igeood_layer_score_min(x, mus, cov_x, cov_mus): + stack = _igeood_layer_score(x, mus, cov_x, cov_mus) + return torch.min(stack, dim=1)[0] + + +def class_cond_mus_cov_matrix(x: Tensor, targets: Tensor, device=torch.device("cpu")): + class_cond_mean = {} + centered_data_per_class = {} + unique_classes = sorted(torch.unique(targets.detach().cpu()).numpy().tolist()) + for c in unique_classes: + filt = targets == c + temp = x[filt].to(device) + class_cond_mean[c] = temp.mean(0, keepdim=True) + centered_data_per_class[c] = temp - class_cond_mean[c] + + centered_data_per_class = torch.vstack(list(centered_data_per_class.values())) + mus = torch.vstack(list(class_cond_mean.values())) + + cov_mat = torch.matmul(centered_data_per_class.T, centered_data_per_class) / centered_data_per_class.shape[0] + cov_mat = torch.diag(torch.diag(cov_mat)) + return mus, cov_mat + + +HYPERPARAMETERS = dict() + + +class IgeoodFeatures(DetectorWithFeatureExtraction): + """Igeood OOD detector. + + Args: + model (nn.Module): Model to be used to extract features + features_nodes (Optional[List[str]]): List of strings that represent the feature nodes. + Defaults to None. + all_blocks (bool, optional): If True, use all blocks of the model. Defaults to False. + last_layer (bool, optional): If True, use also the last layer of the model. Defaults to False. + pooling_op_name (str, optional): Pooling operation to be applied to the features. + Can be one of `max`, `avg`, `flatten`, `getitem`, `avg_or_getitem`, `max_or_getitem`, `none`. Defaults to `avg`. + aggregation_method_name (str, optional): Aggregation method to be applied to the features. Defaults to None. + mu_cov_est_fn (Callable, optional): Function to estimate the mean and covariance matrix of the features. 
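+
+    Example (construction sketch only; fitting and scoring are normally driven by an
+    OOD pipeline, as in ``examples/sc_benchmark.py``)::
+
+        model = timm.create_model("resnet18_cifar10", pretrained=True)
+        detector = IgeoodFeatures(model, all_blocks=True, pooling_op_name="avg")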
+ + References: + [1] https://arxiv.org/abs/2203.07798 + """ + + def __init__( + self, + model: nn.Module, + features_nodes: Optional[List[str]] = None, + all_blocks: bool = False, + last_layer: bool = False, + pooling_op_name: str = "avg_or_getitem", + aggregation_method_name: Optional[str] = "mean", + mu_cov_est_fn: Callable = class_cond_mus_cov_matrix, + **kwargs, + ): + super().__init__( + model, features_nodes, all_blocks, last_layer, pooling_op_name, aggregation_method_name, **kwargs + ) + self.mu_cov_est_fn = mu_cov_est_fn + + def _layer_score(self, x: Tensor, layer_name: Optional[str] = None, index: Optional[int] = None): + return igeood_layer_score_min( + x, + self.mus[layer_name].to(x.device), + self.cov_mats[layer_name].to(x.device), + self.cov_mats[layer_name].to(x.device), + ) + + def _fit_params(self) -> None: + self.mus = {} + self.cov_mats = {} + device = next(self.model.parameters()).device + for layer_name, layer_features in self.train_features.items(): + self.mus[layer_name], self.cov_mats[layer_name] = self.mu_cov_est_fn( + layer_features, self.train_targets, device=device + ) diff --git a/src/detectors/methods/masf.py b/src/detectors/methods/masf.py new file mode 100644 index 0000000..e69de29 diff --git a/src/detectors/methods/pnml.py b/src/detectors/methods/pnml.py new file mode 100644 index 0000000..e69de29 diff --git a/src/detectors/methods/residual.py b/src/detectors/methods/residual.py new file mode 100644 index 0000000..e69de29 diff --git a/src/detectors/methods/she.py b/src/detectors/methods/she.py new file mode 100644 index 0000000..e69de29 diff --git a/src/detectors/methods/templates.py b/src/detectors/methods/templates.py index 93acfcb..f598bfa 100644 --- a/src/detectors/methods/templates.py +++ b/src/detectors/methods/templates.py @@ -1,6 +1,7 @@ """ Generalized detection methods templates. 
""" + import logging from abc import ABC, abstractmethod from typing import Dict, List, Optional diff --git a/src/detectors/models/densenet.py b/src/detectors/models/densenet.py index 1539bf0..2968e21 100644 --- a/src/detectors/models/densenet.py +++ b/src/detectors/models/densenet.py @@ -1,4 +1,5 @@ """Densenet models for CIFAR10, CIFAR100 and SVHN datasets.""" + import timm import timm.models import torch diff --git a/src/detectors/models/dino.py b/src/detectors/models/dino.py index 6efec5d..06ace8c 100644 --- a/src/detectors/models/dino.py +++ b/src/detectors/models/dino.py @@ -3,6 +3,7 @@ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py https://github.com/facebookresearch/dino/ """ + import math from functools import partial diff --git a/src/detectors/models/resnet.py b/src/detectors/models/resnet.py index 80234a9..2ef3e5c 100644 --- a/src/detectors/models/resnet.py +++ b/src/detectors/models/resnet.py @@ -1,4 +1,5 @@ """ResNet models for CIFAR10, CIFAR100, and SVHN datasets.""" + import logging import timm diff --git a/src/detectors/models/vgg.py b/src/detectors/models/vgg.py index 6410364..a47621c 100644 --- a/src/detectors/models/vgg.py +++ b/src/detectors/models/vgg.py @@ -1,4 +1,5 @@ """VGG models for CIFAR10, CIFAR100 and SVHN datasets.""" + import timm import timm.models import torch diff --git a/src/detectors/models/vit.py b/src/detectors/models/vit.py index d8170f7..024dffd 100644 --- a/src/detectors/models/vit.py +++ b/src/detectors/models/vit.py @@ -1,4 +1,5 @@ """Finetuned ViT models for CIFAR10, CIFAR100, and SVHN datasets.""" + import timm import timm.models import torch diff --git a/src/detectors/pipelines/__init__.py b/src/detectors/pipelines/__init__.py index 6192a93..16a3be8 100644 --- a/src/detectors/pipelines/__init__.py +++ b/src/detectors/pipelines/__init__.py @@ -1,6 +1,7 @@ """ Pipeline module. 
""" + from enum import Enum from typing import Any, List, Optional, Tuple diff --git a/src/detectors/pipelines/base.py b/src/detectors/pipelines/base.py index 1ba4eee..713c616 100644 --- a/src/detectors/pipelines/base.py +++ b/src/detectors/pipelines/base.py @@ -1,4 +1,5 @@ """Base abstract pipeline class.""" + import logging from typing import Any, Dict diff --git a/src/detectors/pipelines/covariate_drift.py b/src/detectors/pipelines/covariate_drift.py index a7456bd..4012a72 100644 --- a/src/detectors/pipelines/covariate_drift.py +++ b/src/detectors/pipelines/covariate_drift.py @@ -78,7 +78,12 @@ def __init__( _indices = indices[torch.arange(self.splits[i + j + 1].item(), self.splits[i + j + 2].item())] out_datasets[corruption].append( torch.utils.data.Subset( - create_dataset(dataset_name + "_c", split=corruption, intensity=intensity, transform=transform), + create_dataset( + dataset_name + "_c", + split=corruption, + intensity=intensity, + transform=transform, + ), _indices, ) ) @@ -111,7 +116,7 @@ def __init__( self.setup() - def setup(self): + def setup(self, *args): test_dataset = torch.utils.data.ConcatDataset([self.in_dataset, self.out_dataset]) test_labels = torch.utils.data.TensorDataset( torch.cat([torch.zeros(len(self.in_dataset))] + [torch.ones(len(self.out_dataset))]).long() # type: ignore @@ -149,7 +154,9 @@ def preprocess(self, method: DetectorWrapper) -> DetectorWrapper: method.detector.model = self.accelerator.prepare(method.detector.model) progress_bar = tqdm( - range(len(self.fit_dataloader)), desc="Fitting", disable=not self.accelerator.is_local_main_process + range(len(self.fit_dataloader)), + desc="Fitting", + disable=not self.accelerator.is_local_main_process, ) method.start() for x, y in self.fit_dataloader: @@ -169,7 +176,9 @@ def run(self, method, model, **kwargs): test_preds = torch.empty(len(self.test_dataset), dtype=torch.long) idx = 0 progress_bar = tqdm( - range(len(self.test_dataloader)), desc="Inference", disable=not self.accelerator.is_local_main_process + range(len(self.test_dataloader)), + desc="Inference", + disable=not self.accelerator.is_local_main_process, ) for x, y, labels in self.test_dataloader: scores = method(x) @@ -326,12 +335,37 @@ def report(results: Dict[str, Any], subsample=10, warmup_size=2000, **kwargs): linewidth=3, label="begin test set", ) - ax2.plot(drift_labels.numpy(), alpha=0.5, c=mpl_colors[2], linestyle="--", label="drift", linewidth=3) - ax2.scatter(range(len(mistakes)), mistakes.numpy(), alpha=0.5, marker="*", c=mpl_colors[3], label="mistakes") - ax2.plot(moving_accuracy.numpy(), alpha=0.5, c=mpl_colors[3], label="moving accuracy", linewidth=2) + ax2.plot( + drift_labels.numpy(), + alpha=0.5, + c=mpl_colors[2], + linestyle="--", + label="drift", + linewidth=3, + ) + ax2.scatter( + range(len(mistakes)), + mistakes.numpy(), + alpha=0.5, + marker="*", + c=mpl_colors[3], + label="mistakes", + ) + ax2.plot( + moving_accuracy.numpy(), + alpha=0.5, + c=mpl_colors[3], + label="moving accuracy", + linewidth=2, + ) # plot reference accuracy ax2.axhline( - results["ref_accuracy"], linestyle=":", color="black", alpha=0.5, linewidth=3, label="drift accuracy ref" + results["ref_accuracy"], + linestyle=":", + color="black", + alpha=0.5, + linewidth=3, + label="drift accuracy ref", ) ax1.set_xlabel("Sample index") ax1.set_ylabel("Scores") @@ -346,17 +380,62 @@ def report(results: Dict[str, Any], subsample=10, warmup_size=2000, **kwargs): @register_pipeline("covariate_drift_cifar10") class 
OneCorruptionCovariateDriftCifar10Pipeline(CovariateDriftPipeline): - def __init__(self, transform, corruption: str, intensities: List[int], batch_size: int = 128, **kwargs) -> None: - super().__init__("cifar10", ["train", "test"], transform, [corruption], intensities, batch_size, **kwargs) + def __init__( + self, + transform, + corruption: str, + intensities: List[int], + batch_size: int = 128, + **kwargs, + ) -> None: + super().__init__( + "cifar10", + ["train", "test"], + transform, + [corruption], + intensities, + batch_size, + **kwargs, + ) @register_pipeline("covariate_drift_cifar100") class OneCorruptionCovariateDriftCifar100Pipeline(CovariateDriftPipeline): - def __init__(self, transform, corruption: str, intensities: List[int], batch_size: int = 128, **kwargs) -> None: - super().__init__("cifar100", ["train", "test"], transform, [corruption], intensities, batch_size, **kwargs) + def __init__( + self, + transform, + corruption: str, + intensities: List[int], + batch_size: int = 128, + **kwargs, + ) -> None: + super().__init__( + "cifar100", + ["train", "test"], + transform, + [corruption], + intensities, + batch_size, + **kwargs, + ) @register_pipeline("covariate_drift_imagenet") class OneCorruptionCovariateDriftImagenetPipeline(CovariateDriftPipeline): - def __init__(self, transform, corruption: str, intensities: List[int], batch_size: int = 128, **kwargs) -> None: - super().__init__("imagenet", ["train", "val"], transform, [corruption], intensities, batch_size, **kwargs) + def __init__( + self, + transform, + corruption: str, + intensities: List[int], + batch_size: int = 128, + **kwargs, + ) -> None: + super().__init__( + "imagenet", + ["train", "val"], + transform, + [corruption], + intensities, + batch_size, + **kwargs, + ) diff --git a/src/detectors/pipelines/ood.py b/src/detectors/pipelines/ood.py index 1f3bf97..d072cbb 100644 --- a/src/detectors/pipelines/ood.py +++ b/src/detectors/pipelines/ood.py @@ -1,6 +1,7 @@ """ OOD Pipelines. 
""" + import logging import time from typing import Any, Callable, Dict, List, Literal, Tuple, Union @@ -278,10 +279,6 @@ def _setup_datasets(self): self.out_dataset = torch.utils.data.ConcatDataset(list(self.out_datasets.values())) -# @register_pipeline("ood_benchmark_cifar10_scood") # TODO -# @register_pipeline("ood_benchmark_cifar10_mood") # TODO - - @register_pipeline("ood_benchmark_cifar100") class OODCifar100BenchmarkPipeline(OODBenchmarkPipeline): def __init__(self, transform: Callable, limit_fit=1.0, limit_run=1.0, batch_size=128, seed=42, **kwargs) -> None: @@ -423,41 +420,6 @@ def _setup_datasets(self): self.out_dataset = torch.utils.data.ConcatDataset(list(self.out_datasets.values())) -@register_pipeline("ood_benchmark_imagenet_all_2") -class OODImageNetBenchmarkPipelineAll(OODBenchmarkPipeline): - def __init__(self, transform: Callable, limit_fit=1.0, limit_run=1.0, batch_size=64, seed=42, **kwargs) -> None: - super().__init__( - "ilsvrc2012", - { - "inaturalist_clean": None, - "species_clean": None, - "places_clean": None, - "openimage_o_clean": None, - "ssb_easy": None, - "textures_clean": None, - "ninco": None, - "ssb_hard": None, - }, - limit_fit=limit_fit, - limit_run=limit_run, - transform=transform, - batch_size=batch_size, - seed=seed, - ) - - def _setup_datasets(self): - _logger.info("Loading In-distribution dataset...") - self.fit_dataset = create_dataset(self.in_dataset_name, split="train", transform=self.transform) - self.in_dataset = create_dataset(self.in_dataset_name, split="val", transform=self.transform) - - _logger.info("Loading OOD datasets...") - self.out_datasets = { - ds: create_dataset(ds, split=split, transform=self.transform, download=True) - for ds, split in self.out_datasets_names_splits.items() - } - self.out_dataset = torch.utils.data.ConcatDataset(list(self.out_datasets.values())) - - @register_pipeline("ood_benchmark_imagenet_far") class OODImageNetBenchmarkPipelineEasy(OODBenchmarkPipeline): def __init__(self, transform: Callable, limit_fit=1.0, limit_run=1.0, batch_size=64, seed=42, **kwargs) -> None: @@ -533,8 +495,6 @@ class OODValidationPipeline(OODBenchmarkPipeline): n_trials (int, optional): The number of trials to run. Defaults to 20. """ - # TODO: include prevent refit flag. 
- def run( self, method: DetectorWrapper, diff --git a/src/detectors/pipelines/sc.py b/src/detectors/pipelines/sc.py index 2f893c4..d80e3e5 100644 --- a/src/detectors/pipelines/sc.py +++ b/src/detectors/pipelines/sc.py @@ -73,7 +73,12 @@ def _setup_datasets(self): ) else: self.fit_dataset = None - self.in_dataset = create_dataset(self.in_dataset_name, split="test", transform=self.transform, download=True) + try: + self.in_dataset = create_dataset( + self.in_dataset_name, split="test", transform=self.transform, download=True + ) + except ValueError: + self.in_dataset = create_dataset(self.in_dataset_name, split="val", transform=self.transform, download=True) def _setup_dataloaders(self): if self.in_dataset is None: diff --git a/src/detectors/pipelines/scod.py b/src/detectors/pipelines/scod.py index e46c1f7..d65ce2a 100644 --- a/src/detectors/pipelines/scod.py +++ b/src/detectors/pipelines/scod.py @@ -281,3 +281,27 @@ def __init__(self, transform: Callable, limit_fit=0.0, limit_run=1.0, batch_size seed=seed, **kwargs ) + + +@register_pipeline("scod_benchmark_imagenet") +class SCODImagenetBenchmarkPipeline(SCODPipeline): + def __init__(self, transform: Callable, limit_fit=0.0, limit_run=1.0, batch_size=128, seed=42, **kwargs) -> None: + super().__init__( + "ilsvrc2012", + { + "inaturalist_clean": None, + "species_clean": None, + "places_clean": None, + "openimage_o_clean": None, + "ssb_easy": None, + "textures_clean": None, + "ninco": None, + "ssb_hard": None, + }, + transform=transform, + batch_size=batch_size, + limit_fit=limit_fit, + limit_run=limit_run, + seed=seed, + **kwargs + ) diff --git a/src/detectors/scod_eval.py b/src/detectors/scod_eval.py index ff3c441..2d1e0d6 100644 --- a/src/detectors/scod_eval.py +++ b/src/detectors/scod_eval.py @@ -41,12 +41,6 @@ def plugin_bb( r_bb == 1 => reject (misclassified or OOD) """ - # build rejector: need to compute c_in and c_out - # c_in = lbd * pi - # c_out = c_fn - lbd * (1 - pi) - # thr = 1 - 2 * c_in - c_out - # pi is asy to estimate - # lbd need to solve the lagragnian or solve Pr(r(x)=1) <= b_rej def r_bb_fn(c_in, c_out, thr): return (1 - c_in - c_out) * scores_sc + c_out * invert(scores_ood) < thr @@ -183,5 +177,4 @@ def timefn(fn, *args, **kwargs): if __name__ == "__main__": - # benchmark benchmark() diff --git a/src/detectors/trainer.py b/src/detectors/trainer.py index 41e40aa..d37c943 100644 --- a/src/detectors/trainer.py +++ b/src/detectors/trainer.py @@ -1,4 +1,5 @@ """Trainer for classification models.""" + import json import logging import os diff --git a/src/detectors/utils.py b/src/detectors/utils.py index 526f971..d644888 100644 --- a/src/detectors/utils.py +++ b/src/detectors/utils.py @@ -66,15 +66,7 @@ def sync_tensor_across_gpus(t: torch.Tensor) -> torch.Tensor: if group_size == 1: return t gather_t_tensor = [torch.zeros_like(t) for _ in range(group_size)] - dist.all_gather(gather_t_tensor, t) # this works with nccl backend when tensors need to be on gpu. - # for gloo and mpi backends, tensors need to be on cpu. also this works single machine with - # multiple gpus. for multiple nodes, you should use dist.all_gather_multigpu. both have the - # same definition... see [here](https://pytorch.org/docs/stable/distributed.html). - # somewhere in the same page, it was mentioned that dist.all_gather_multigpu is more for - # multi-nodes. still dont see the benefit of all_gather_multigpu. the provided working case in - # the doc is vague... 
-    # move tensors to cpu
-    # gather_t_tensor = [t.cpu() for t in gather_t_tensor]
+    dist.all_gather(gather_t_tensor, t)
     gather_t_tensor = torch.cat(gather_t_tensor, dim=0)
     return gather_t_tensor
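Usage sketch (illustrative note, not part of the diff above): the new ``combine_and_conquer`` aggregation maps each detector's score to an empirical p-value against a reference (fit) sample and then combines the per-detector p-values, e.g. with Fisher's statistic. The snippet below uses only functions introduced by this patch, with random data standing in for real detector scores, and assumes the usual convention that larger raw scores indicate in-distribution samples.

    import numpy as np

    from detectors.aggregations.combine_and_conquer import fisher_tau_method, p_value_fn

    rng = np.random.default_rng(42)
    # Rows are examples, columns are detectors; the fit scores act as the reference sample.
    fit_scores = rng.normal(size=(5000, 3))             # (N, m) scores on in-distribution fit data
    test_scores = rng.normal(loc=-0.5, size=(100, 3))   # (n, m) scores on the evaluation stream

    # Empirical p-value of each test score under the fit distribution, per detector.
    p_values = p_value_fn(test_scores, fit_scores)      # (n, m)

    # Fisher's statistic tau = -2 * sum(log p): larger tau means smaller p-values,
    # i.e. test scores that sit unusually low relative to the fit sample.
    tau = fisher_tau_method(p_values)                    # (n,)
    print(tau.shape, tau[:5])

The other combiners added in this patch (``stouffer_tau_method``, ``simes_tau_method``, ``geometric_mean_tau_method``, ...) share the same signature and can be looked up by name through ``get_combine_p_values_fn``; note that the direction and scale of the combined statistic differ across methods.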