diff --git a/learn2learn/vision/benchmarks/fc100_benchmark.py b/learn2learn/vision/benchmarks/fc100_benchmark.py index 5ac5a714..22fbbff5 100644 --- a/learn2learn/vision/benchmarks/fc100_benchmark.py +++ b/learn2learn/vision/benchmarks/fc100_benchmark.py @@ -5,6 +5,7 @@ from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels +from torchvision.transforms import Compose def fc100_tasksets( train_ways=5, @@ -12,21 +13,68 @@ def fc100_tasksets( test_ways=5, test_samples=10, root='~/data', + data_augmentation=None, device=None, **kwargs, ): """Tasksets for FC100 benchmarks.""" - data_transform = tv.transforms.ToTensor() + if data_augmentation is None: + train_data_transforms = tv.transforms.ToTensor() + test_data_transforms = tv.transforms.ToTensor() + elif data_augmentation == 'normalize': + train_data_transforms = Compose([ + lambda x: x / 255.0, + ]) + test_data_transforms = train_data_transforms + elif data_augmentation == 'rfs2020': + """ + # original + if augment: + transform = transforms.Compose([ + lambda x: Image.fromarray(x), + transforms.RandomCrop(32, padding=4), + transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), + transforms.RandomHorizontalFlip(), + lambda x: np.asarray(x), + transforms.ToTensor(), + normalize_cifar100 + ]) + else: + transform = transforms.Compose([ + lambda x: Image.fromarray(x), + transforms.ToTensor(), + normalize_cifar100 + ]) + return transform + """ + mean = [0.5071, 0.4867, 0.4408] + std = [0.2675, 0.2565, 0.2761] + normalize = tv.transforms.Normalize(mean=mean, std=std) + train_data_transforms = Compose([ + ToPILImage(), + RandomCrop(32, padding=4), + ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), + RandomHorizontalFlip(), + ToTensor(), + normalize, + ]) + test_data_transforms = Compose([ + ToTensor(), + normalize, + ]) + else: + raise('Invalid data_augmentation argument.') + train_dataset = l2l.vision.datasets.FC100(root=root, - transform=data_transform, + transform=train_data_transforms, mode='train', download=True) valid_dataset = l2l.vision.datasets.FC100(root=root, - transform=data_transform, + transform=train_data_transforms, mode='validation', download=True) test_dataset = l2l.vision.datasets.FC100(root=root, - transform=data_transform, + transform=test_data_transforms, mode='test', download=True) if device is not None: