Skip to content

Commit

Permalink
helpful scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
edadaltocg committed Jun 19, 2024
1 parent 20ffb04 commit eb9a811
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 0 deletions.
110 changes: 110 additions & 0 deletions examples/compute_accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import argparse
import json
import logging
import os

import timm
import timm.data
import torch
import torch.utils.data
from tqdm import tqdm

import detectors
from detectors.config import RESULTS_DIR

_logger = logging.getLogger(__name__)


def topk_accuracy(preds, labels, k=5):
    """Return the fraction of samples whose label is among the top-``k`` scores.

    Args:
        preds: 2-D tensor of per-class scores, one row per sample.
        labels: 1-D tensor of class indices, one entry per row of ``preds``.
        k: Number of highest-scoring classes that count as a hit. It is
            clamped to the number of columns in ``preds`` so the default
            ``k=5`` stays usable on problems with fewer than five classes
            (``torch.topk`` raises when ``k`` exceeds the dimension size).

    Returns:
        The accuracy as a Python float in ``[0, 1]``.
    """
    k = min(k, preds.shape[1])
    topk_preds = torch.topk(preds, k=k, dim=1).indices
    # Broadcast each label against its row of top-k indices; a row is a hit
    # when any of its k candidates equals the label.
    hits = (topk_preds == labels.unsqueeze(1)).any(dim=1)
    return hits.float().mean().item()


def main(args):
    """Evaluate a pretrained timm classifier and record top-1/top-5 accuracy.

    The model named by ``args.model`` is run over ``args.dataset`` /
    ``args.split`` using the model's own test-time transform. Softmax scores
    are collected on the CPU, accuracies are logged, and (unless
    ``args.debug`` is set) a result row is appended to
    ``RESULTS_DIR/accuracy/results.csv``.

    Args:
        args: Parsed command-line namespace with ``model``, ``dataset``,
            ``split``, ``batch_size``, ``num_workers`` and ``debug``.
    """
    torch.manual_seed(42)

    # Prefer CUDA, then Apple MPS, then CPU. The previous check compared a
    # ``torch.device`` against the string "cpu" and selected "mps" without
    # verifying that the backend is actually available.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # create model
    model = timm.create_model(args.model, pretrained=True)
    model.to(device)
    print(model.default_cfg)
    data_config = timm.data.resolve_data_config(model.default_cfg)
    # Read the input size before the config is mutated for the train transform.
    input_size = tuple(data_config.get("input_size", (3, 224, 224)))
    test_transform = timm.data.create_transform(**data_config)
    data_config["is_training"] = True
    # Only built so it can be logged next to the test transform.
    train_transform = timm.data.create_transform(**data_config, color_jitter=None)

    _logger.info("Test transform: %s", test_transform)
    _logger.info("Train transform: %s", train_transform)

    dataset = detectors.create_dataset(args.dataset, split=args.split, transform=test_transform, download=True)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        # Pinned host memory only helps host-to-CUDA copies.
        pin_memory=device.type == "cuda",
    )
    model.eval()

    # Probe the classifier width with a dummy forward pass at the model's
    # configured input resolution rather than a hard-coded 224x224.
    with torch.no_grad():
        probe = model(torch.randn(1, *input_size, device=device))
    num_classes = probe.shape[1]

    if args.dataset == "imagenet_r":
        # NOTE(review): presumably selects the ImageNet-R classes among the
        # model's outputs, with dataset labels indexing the masked columns —
        # confirm against the dataset implementation.
        mask = dataset.imagenet_r_mask
    else:
        mask = slice(None)

    all_preds = torch.empty((len(dataset), num_classes), dtype=torch.float32)
    all_labels = torch.empty(len(dataset), dtype=torch.long)
    _logger.info(f"Shapes: {all_preds.shape}, {all_labels.shape}")

    num_seen = 0
    for inputs, labels in tqdm(dataloader, total=len(dataloader)):
        inputs = inputs.to(device)
        with torch.no_grad():
            outputs = model(inputs)
        outputs = torch.softmax(outputs, dim=1)

        end = num_seen + len(labels)
        all_preds[num_seen:end] = outputs.cpu()
        all_labels[num_seen:end] = labels.cpu()
        num_seen = end

        if args.debug:
            _logger.info("Labels: %s", labels)
            _logger.info("Predictions: %s", outputs.argmax(1))
            break

    # Keep only rows that were actually written: in debug mode the loop stops
    # after one batch and the rest of the ``torch.empty`` buffers is
    # uninitialised memory that must not enter the metrics.
    all_preds = all_preds[:num_seen]
    all_labels = all_labels[:num_seen]

    masked_preds = all_preds[:, mask]
    top1 = topk_accuracy(masked_preds, all_labels, k=1) * 100
    top5 = topk_accuracy(masked_preds, all_labels, k=5) * 100
    # Independent argmax check, computed on the same masked scores as top-1.
    argmax_acc = (masked_preds.argmax(dim=1) == all_labels).float().mean().item()
    _logger.info("Argmax accuracy (sanity check): %.6f", argmax_acc)
    _logger.info(f"Top-1 accuracy: {top1:.4f}")
    _logger.info(f"Top-5 accuracy: {top5:.4f}")

    if not args.debug:
        # save results to file
        results = {
            "model": args.model,
            "dataset": args.dataset,
            "split": args.split,
            "top1_acc": top1,
            "top5_acc": top5,
        }
        filename = os.path.join(RESULTS_DIR, "accuracy", "results.csv")
        detectors.utils.append_results_to_csv_file(results, filename)


if __name__ == "__main__":
    # Command-line entry point: parse the evaluation options, echo them to
    # the log as JSON, then run the accuracy computation.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", default="resnet50.tv_in1k", type=str)
    cli.add_argument("--dataset", default="imagenet1k", type=str)
    cli.add_argument("--split", default="val", type=str)
    cli.add_argument("--batch_size", default=64, type=int)
    cli.add_argument("--num_workers", default=3, type=int)
    cli.add_argument("--debug", action="store_true")
    args = cli.parse_args()

    logging.basicConfig(level=logging.INFO)
    options_json = json.dumps(vars(args), indent=2)
    _logger.info(options_json)

    main(args)
1 change: 1 addition & 0 deletions examples/sc_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def main(args):
parser.add_argument("--debug", action="store_true")
parser.add_argument("--idx", type=int, default=None)
args = parser.parse_args()

methods = [
"msp",
"energy",
Expand Down
95 changes: 95 additions & 0 deletions scripts/compute_accuracy_ssl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import argparse
import json
import logging
import time

import accelerate
import numpy as np
import timm
import torch
import torch.utils.data
from sklearn.neighbors import KNeighborsClassifier
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from tqdm import tqdm

import detectors

_logger = logging.getLogger(__name__)


@torch.no_grad()
def main(args):
    """Measure the accuracy of a pretrained timm model, optionally via kNN.

    Model outputs for ``args.dataset`` / ``args.split`` are gathered across
    all ``accelerate`` processes. When ``args.ssl`` is false the prediction is
    the argmax of the outputs. When it is true (forced for model names
    containing "supcon" or "simclr") the outputs are treated as features and
    classified with a 20-nearest-neighbour cosine vote.

    The kNN vote is leave-one-out: a sample is never counted among its own
    neighbours. Fitting and predicting on the very same points would make
    every sample its own nearest neighbour and inflate the accuracy.

    Args:
        args: Parsed command-line namespace with ``model``, ``dataset``,
            ``split``, ``batch_size`` and ``ssl``.

    Raises:
        ValueError: In kNN mode, if fewer than two samples were collected.
    """
    if "supcon" in args.model or "simclr" in args.model:
        args.ssl = True
    accelerator = accelerate.Accelerator()

    model = timm.create_model(args.model, pretrained=True)
    data_config = resolve_data_config(model.default_cfg)
    transform = create_transform(**data_config)
    _logger.info(transform)

    model.eval()
    model = accelerator.prepare(model)

    dataset = detectors.create_dataset(args.dataset, split=args.split, transform=transform)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=accelerator.num_processes
    )
    dataloader = accelerator.prepare(dataloader)

    # CUDA kernels run asynchronously; without a synchronisation point the
    # timer would only capture kernel launch overhead.
    on_cuda = accelerator.device.type == "cuda"

    inference_time = []
    all_outputs = []
    all_labels = []
    start_time = time.time()
    progress_bar = tqdm(dataloader, desc="Inference", disable=not accelerator.is_local_main_process)
    for x, labels in progress_bar:
        if on_cuda:
            torch.cuda.synchronize(accelerator.device)
        t1 = time.perf_counter()
        outputs = model(x)
        if on_cuda:
            torch.cuda.synchronize(accelerator.device)
        t2 = time.perf_counter()

        outputs, labels = accelerator.gather_for_metrics((outputs, labels))
        all_outputs.append(outputs.cpu())
        all_labels.append(labels.cpu())
        inference_time.append(t2 - t1)

    progress_bar.close()
    accelerator.wait_for_everyone()

    # Everything below is already on the CPU.
    all_outputs = torch.cat(all_outputs, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    if not args.ssl:
        preds = all_outputs.argmax(dim=1)
    else:
        features = all_outputs.float().numpy()
        label_array = all_labels.numpy()
        if len(label_array) < 2:
            raise ValueError("kNN evaluation needs at least two samples")
        n_neighbors = min(20, len(label_array) - 1)
        estimator = KNeighborsClassifier(n_neighbors, metric="cosine").fit(features, label_array)
        # Called without a query set, kneighbors() returns the neighbours of
        # each indexed point and excludes the point itself.
        neighbor_idx = estimator.kneighbors(return_distance=False)
        neighbor_labels = label_array[neighbor_idx]
        # Uniform majority vote; ties resolve to the smallest label.
        votes = [np.bincount(row).argmax() for row in neighbor_labels]
        preds = torch.from_numpy(np.asarray(votes, dtype=label_array.dtype))

    acc = (preds == all_labels).float().mean().item()

    print(f"Total time: {time.time() - start_time:.2f} seconds")
    print(f"Accuracy: {acc}")
    print(f"Average inference time: {np.mean(inference_time)}")


if __name__ == "__main__":
    # Configure logging first so the option dump below is visible.
    logging.basicConfig(level=logging.INFO)

    # Command-line options for the (optionally kNN-based) accuracy script.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", default="densenet121", type=str)
    cli.add_argument("--dataset", default="imagenet1k", type=str)
    cli.add_argument("--split", default="val", type=str)
    cli.add_argument("--batch_size", default=128, type=int)
    cli.add_argument("--ssl", action="store_true")
    args = cli.parse_args()

    options_json = json.dumps(vars(args), indent=2)
    _logger.info(options_json)

    main(args)

0 comments on commit eb9a811

Please sign in to comment.