diff --git a/CTRAIN/data_loaders/data_loaders.py b/CTRAIN/data_loaders/data_loaders.py index b2ec0cf..6234e9d 100644 --- a/CTRAIN/data_loaders/data_loaders.py +++ b/CTRAIN/data_loaders/data_loaders.py @@ -59,7 +59,7 @@ def load_mnist(batch_size=64, normalise=True, train_transforms=[], val_split=Tru test_dataset = datasets.MNIST(root=data_root, train=False, transform=test_transform) if val_split: train_dataset, val_dataset = random_split(train_dataset, [0.8, 0.2]) - val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=2) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2) @@ -137,7 +137,7 @@ def load_cifar10(batch_size=64, normalise=True, train_transforms=[transforms.Ran test_dataset = datasets.CIFAR10(root=data_root, train=False, transform=test_transform, download=True) if val_split: train_dataset, val_dataset = random_split(train_dataset, [0.8, 0.2]) - val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=2) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2) val_loader.mean, val_loader.std = mean, std @@ -229,7 +229,7 @@ def load_gtsrb(batch_size=64, normalise=True, train_transforms=[transforms.Rando train_dataset = Subset(train_dataset_ori, train_ids) val_dataset = Subset(train_dataset_ori, val_ids) - val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) val_loader.mean, val_loader.std = mean, std else: @@ -371,7 +371,7 @@ def load_tinyimagenet(batch_size=64, normalise=True, train_transforms=[transform test_dataset = datasets.ImageFolder(root=data_root + '/tiny-imagenet-200/val/images', transform=test_transform) if val_split: train_dataset, val_dataset = random_split(train_dataset, [0.8, 0.2]) - val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=2) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2) val_loader.mean, val_loader.std = mean, std diff --git a/CTRAIN/eval/eval.py b/CTRAIN/eval/eval.py index a6da1e9..54074c7 100644 --- a/CTRAIN/eval/eval.py +++ b/CTRAIN/eval/eval.py @@ -320,15 +320,18 @@ def eval_adaptive(model, eps, data_loader, n_classes=10, test_samples=np.inf, de certified = torch.tensor([], device=device) total_images = 0 + ibp_data_loader = DataLoader(data_loader.dataset, batch_size=data_loader.batch_size, shuffle=False) + ibp_data_loader.max, ibp_data_loader.min, ibp_data_loader.std = data_loader.max, data_loader.min, data_loader.std + crown_data_loader = DataLoader(data_loader.dataset, batch_size=1, shuffle=False) crown_data_loader.max, crown_data_loader.min, crown_data_loader.std = data_loader.max, data_loader.min, data_loader.std - for batch_idx, (data, targets) in tqdm(enumerate(data_loader)): + for batch_idx, (data, targets) in tqdm(enumerate(ibp_data_loader)): certified_idx = torch.zeros(len(data), device=device, dtype=torch.bool) - ptb = PerturbationLpNorm(eps=eps, norm=np.inf, x_L=torch.clamp(data - eps, data_loader.min, data_loader.max).to(device), x_U=torch.clamp(data + eps, data_loader.min, data_loader.max).to(device)) + ptb = PerturbationLpNorm(eps=eps, norm=np.inf, x_L=torch.clamp(data - eps, ibp_data_loader.min, ibp_data_loader.max).to(device), x_U=torch.clamp(data + eps, ibp_data_loader.min, ibp_data_loader.max).to(device)) data, targets = data.to(device), targets.to(device) - if batch_idx * data_loader.batch_size >= test_samples: + if batch_idx * ibp_data_loader.batch_size >= test_samples: continue total_images += len(targets) @@ -346,7 +349,7 @@ def eval_adaptive(model, eps, data_loader, n_classes=10, test_samples=np.inf, de data = data.to('cpu') certified_idx = certified_idx.to("cpu") - ptb = PerturbationLpNorm(eps=eps, norm=np.inf, x_L=torch.clamp(data[~certified_idx] - eps, data_loader.min, data_loader.max).to(device), x_U=torch.clamp(data[~certified_idx] + eps, data_loader.min, data_loader.max).to(device)) + ptb = PerturbationLpNorm(eps=eps, norm=np.inf, x_L=torch.clamp(data[~certified_idx] - eps, ibp_data_loader.min, ibp_data_loader.max).to(device), x_U=torch.clamp(data[~certified_idx] + eps, ibp_data_loader.min, ibp_data_loader.max).to(device)) data = data.to(device) certified_idx = certified_idx.to(device)