From 1b48224603de0df6a3677cae41316156101943e8 Mon Sep 17 00:00:00 2001 From: brando90 Date: Sat, 5 Feb 2022 12:12:29 -0600 Subject: [PATCH 1/3] Update fc100_benchmark.py data augmentation for f100 ala rfs 2020 --- .../vision/benchmarks/fc100_benchmark.py | 54 +++++++++++++++++-- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/learn2learn/vision/benchmarks/fc100_benchmark.py b/learn2learn/vision/benchmarks/fc100_benchmark.py index 5ac5a714..2b8e9d6d 100644 --- a/learn2learn/vision/benchmarks/fc100_benchmark.py +++ b/learn2learn/vision/benchmarks/fc100_benchmark.py @@ -16,17 +16,63 @@ def fc100_tasksets( **kwargs, ): """Tasksets for FC100 benchmarks.""" - data_transform = tv.transforms.ToTensor() + if data_augmentation is None: + train_data_transforms = tv.transforms.ToTensor() + test_data_transforms = tv.transforms.ToTensor() + elif data_augmentation == 'normalize': + train_data_transforms = Compose([ + lambda x: x / 255.0, + ]) + test_data_transforms = train_data_transforms + elif data_augmentation == 'rfs2020': + """ + # original + if augment: + transform = transforms.Compose([ + lambda x: Image.fromarray(x), + transforms.RandomCrop(32, padding=4), + transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), + transforms.RandomHorizontalFlip(), + lambda x: np.asarray(x), + transforms.ToTensor(), + normalize_cifar100 + ]) + else: + transform = transforms.Compose([ + lambda x: Image.fromarray(x), + transforms.ToTensor(), + normalize_cifar100 + ]) + return transform + """ + mean = [0.5071, 0.4867, 0.4408] + std = [0.2675, 0.2565, 0.2761] + normalize = tv.transforms.Normalize(mean=mean, std=std) + train_data_transforms = Compose([ + ToPILImage(), + RandomCrop(32, padding=4), + ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), + RandomHorizontalFlip(), + ToTensor(), + normalize, + ]) + test_data_transforms = Compose([ + ToTensor(), + normalize, + ]) + else: + raise('Invalid data_augmentation argument.') + train_dataset = l2l.vision.datasets.FC100(root=root, - transform=data_transform, + transform=train_data_transforms, mode='train', download=True) valid_dataset = l2l.vision.datasets.FC100(root=root, - transform=data_transform, + transform=train_data_transforms, mode='validation', download=True) test_dataset = l2l.vision.datasets.FC100(root=root, - transform=data_transform, + transform=test_data_transforms, mode='test', download=True) if device is not None: From 12159f277ccbdd6c2ee748a484162d29927c043e Mon Sep 17 00:00:00 2001 From: brando90 Date: Sat, 5 Feb 2022 12:44:43 -0600 Subject: [PATCH 2/3] Update fc100_benchmark.py misisng data_augmentation=None, --- learn2learn/vision/benchmarks/fc100_benchmark.py | 1 + 1 file changed, 1 insertion(+) diff --git a/learn2learn/vision/benchmarks/fc100_benchmark.py b/learn2learn/vision/benchmarks/fc100_benchmark.py index 2b8e9d6d..6b9bcca2 100644 --- a/learn2learn/vision/benchmarks/fc100_benchmark.py +++ b/learn2learn/vision/benchmarks/fc100_benchmark.py @@ -12,6 +12,7 @@ def fc100_tasksets( test_ways=5, test_samples=10, root='~/data', + data_augmentation=None, device=None, **kwargs, ): From dcc2c0d83ca1b2c304cf2ae69a0b91edfe2098d9 Mon Sep 17 00:00:00 2001 From: brando90 Date: Sat, 5 Feb 2022 12:46:47 -0600 Subject: [PATCH 3/3] forgot from torchvision.transforms import Compose --- learn2learn/vision/benchmarks/fc100_benchmark.py | 1 + 1 file changed, 1 insertion(+) diff --git a/learn2learn/vision/benchmarks/fc100_benchmark.py b/learn2learn/vision/benchmarks/fc100_benchmark.py index 6b9bcca2..22fbbff5 100644 --- a/learn2learn/vision/benchmarks/fc100_benchmark.py +++ b/learn2learn/vision/benchmarks/fc100_benchmark.py @@ -5,6 +5,7 @@ from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels +from torchvision.transforms import Compose def fc100_tasksets( train_ways=5,