diff --git a/CHANGELOG.md b/CHANGELOG.md index ca845a18..c7f6407e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -* New vision example: MAML++. (@[DubiousCactus](https://github.com/DubiousCactus)) +* New BatchNorm layer with per-step running statistics and weights & biases from MAML++. (@[Théo Morales](https://github.com/DubiousCactus)) +* New vision example: MAML++. (@[Théo Morales](https://github.com/DubiousCactus)) * Add tutorial: "Demystifying Task Transforms", ([Varad Pimpalkhute](https://github.com/nightlessbaron/)) ### Changed diff --git a/examples/vision/mamlpp/cnn4_bnrs.py b/examples/vision/mamlpp/cnn4_bnrs.py deleted file mode 100644 index 28f9666f..00000000 --- a/examples/vision/mamlpp/cnn4_bnrs.py +++ /dev/null @@ -1,321 +0,0 @@ -#! /usr/bin/env python3 -# -*- coding: utf-8 -*- -# vim:fenc=utf-8 -# - -""" -CNN4 extended with Batch-Norm Running Statistics. -""" - -import torch -import torch.nn.functional as F - -from copy import deepcopy -from learn2learn.vision.models.cnn4 import maml_init_, fc_init_ - - -class MetaBatchNormLayer(torch.nn.Module): - """ - An extension of Pytorch's BatchNorm layer, with the Per-Step Batch Normalisation Running - Statistics and Per-Step Batch Normalisation Weights and Biases improvements proposed in - MAML++ by Antoniou et al. It is adapted from the original Pytorch implementation at - https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch, - with heavy refactoring and a bug fix - (https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/issues/42). - """ - - def __init__( - self, - num_features, - eps=1e-5, - momentum=0.1, - affine=True, - meta_batch_norm=True, - adaptation_steps: int = 1, - ): - super(MetaBatchNormLayer, self).__init__() - self.num_features = num_features - self.eps = eps - self.affine = affine - self.meta_batch_norm = meta_batch_norm - self.num_features = num_features - self.running_mean = torch.nn.Parameter( - torch.zeros(adaptation_steps, num_features), requires_grad=False - ) - self.running_var = torch.nn.Parameter( - torch.ones(adaptation_steps, num_features), requires_grad=False - ) - self.bias = torch.nn.Parameter( - torch.zeros(adaptation_steps, num_features), requires_grad=True - ) - self.weight = torch.nn.Parameter( - torch.ones(adaptation_steps, num_features), requires_grad=True - ) - self.backup_running_mean = torch.zeros(self.running_mean.shape) - self.backup_running_var = torch.ones(self.running_var.shape) - self.momentum = momentum - - def forward( - self, - input, - step, - ): - """ - :param input: input data batch, size either can be any. - :param step: The current inner loop step being taken. This is used when to learn per step params and - collecting per step batch statistics. - :return: The result of the batch norm operation. - """ - assert ( - step < self.running_mean.shape[0] - ), f"Running forward with step={step} when initialised with {self.running_mean.shape[0]} steps!" - return F.batch_norm( - input, - self.running_mean[step], - self.running_var[step], - self.weight[step], - self.bias[step], - training=True, - momentum=self.momentum, - eps=self.eps, - ) - - def backup_stats(self): - self.backup_running_mean.data = deepcopy(self.running_mean.data) - self.backup_running_var.data = deepcopy(self.running_var.data) - - def restore_backup_stats(self): - """ - Resets batch statistics to their backup values which are collected after each forward pass. 
- """ - self.running_mean = torch.nn.Parameter( - self.backup_running_mean, requires_grad=False - ) - self.running_var = torch.nn.Parameter( - self.backup_running_var, requires_grad=False - ) - - def extra_repr(self): - return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}".format( - **self.__dict__ - ) - - -class LinearBlock_BNRS(torch.nn.Module): - def __init__(self, input_size, output_size, adaptation_steps): - super(LinearBlock_BNRS, self).__init__() - self.relu = torch.nn.ReLU() - self.normalize = MetaBatchNormLayer( - output_size, - affine=True, - momentum=0.999, - eps=1e-3, - adaptation_steps=adaptation_steps, - ) - self.linear = torch.nn.Linear(input_size, output_size) - fc_init_(self.linear) - - def forward(self, x, step): - x = self.linear(x) - x = self.normalize(x, step) - x = self.relu(x) - return x - - -class ConvBlock_BNRS(torch.nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - max_pool=True, - max_pool_factor=1.0, - adaptation_steps=1, - ): - super(ConvBlock_BNRS, self).__init__() - stride = (int(2 * max_pool_factor), int(2 * max_pool_factor)) - if max_pool: - self.max_pool = torch.nn.MaxPool2d( - kernel_size=stride, - stride=stride, - ceil_mode=False, - ) - stride = (1, 1) - else: - self.max_pool = lambda x: x - self.normalize = MetaBatchNormLayer( - out_channels, - affine=True, - adaptation_steps=adaptation_steps, - # eps=1e-3, - # momentum=0.999, - ) - torch.nn.init.uniform_(self.normalize.weight) - self.relu = torch.nn.ReLU() - - self.conv = torch.nn.Conv2d( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=1, - bias=True, - ) - maml_init_(self.conv) - - def forward(self, x, step): - x = self.conv(x) - x = self.normalize(x, step) - x = self.relu(x) - x = self.max_pool(x) - return x - - -class ConvBase_BNRS(torch.nn.Sequential): - - # NOTE: - # Omniglot: hidden=64, channels=1, no max_pool - # MiniImagenet: hidden=32, channels=3, max_pool - - def __init__( - self, hidden=64, channels=1, max_pool=False, layers=4, max_pool_factor=1.0, - adaptation_steps=1 - ): - core = [ - ConvBlock_BNRS( - channels, - hidden, - (3, 3), - max_pool=max_pool, - max_pool_factor=max_pool_factor, - adaptation_steps=adaptation_steps - ), - ] - for _ in range(layers - 1): - core.append( - ConvBlock_BNRS( - hidden, - hidden, - kernel_size=(3, 3), - max_pool=max_pool, - max_pool_factor=max_pool_factor, - adaptation_steps=adaptation_steps - ) - ) - super(ConvBase_BNRS, self).__init__(*core) - - def forward(self, x, step): - for module in self: - x = module(x, step) - return x - - -class CNN4Backbone_BNRS(ConvBase_BNRS): - def __init__( - self, - hidden_size=64, - layers=4, - channels=3, - max_pool=True, - max_pool_factor=None, - adaptation_steps=1, - ): - if max_pool_factor is None: - max_pool_factor = 4 // layers - super(CNN4Backbone_BNRS, self).__init__( - hidden=hidden_size, - layers=layers, - channels=channels, - max_pool=max_pool, - max_pool_factor=max_pool_factor, - adaptation_steps=adaptation_steps - ) - - def forward(self, x, step): - x = super(CNN4Backbone_BNRS, self).forward(x, step) - x = x.reshape(x.size(0), -1) - return x - - -class CNN4_BNRS(torch.nn.Module): - """ - - [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models/cnn4.py) - - **Description** - - The convolutional network commonly used for MiniImagenet, as described by Ravi et Larochelle, 2017. - - This network assumes inputs of shapes (3, 84, 84). - - Instantiate `CNN4Backbone` if you only need the feature extractor. 
- - **References** - - 1. Ravi and Larochelle. 2017. “Optimization as a Model for Few-Shot Learning.” ICLR. - - **Arguments** - - * **output_size** (int) - The dimensionality of the network's output. - * **hidden_size** (int, *optional*, default=64) - The dimensionality of the hidden representation. - * **layers** (int, *optional*, default=4) - The number of convolutional layers. - * **channels** (int, *optional*, default=3) - The number of channels in input. - * **max_pool** (bool, *optional*, default=True) - Whether ConvBlocks use max-pooling. - * **embedding_size** (int, *optional*, default=None) - Size of feature embedding. - Defaults to 25 * hidden_size (for mini-Imagenet). - - **Example** - ~~~python - model = CNN4(output_size=20, hidden_size=128, layers=3) - ~~~ - """ - - def __init__( - self, - output_size, - hidden_size=64, - layers=4, - channels=3, - max_pool=True, - embedding_size=None, - adaptation_steps=1, - ): - super(CNN4_BNRS, self).__init__() - if embedding_size is None: - embedding_size = 25 * hidden_size - self.features = CNN4Backbone_BNRS( - hidden_size=hidden_size, - channels=channels, - max_pool=max_pool, - layers=layers, - max_pool_factor=4 // layers, - adaptation_steps=adaptation_steps, - ) - self.classifier = torch.nn.Linear( - embedding_size, - output_size, - bias=True, - ) - maml_init_(self.classifier) - self.hidden_size = hidden_size - - def backup_stats(self): - """ - Backup stored batch statistics before running a validation epoch. - """ - for layer in self.features.modules(): - if type(layer) is MetaBatchNormLayer: - layer.backup_stats() - - def restore_backup_stats(self): - """ - Reset stored batch statistics from the stored backup. - """ - for layer in self.features.modules(): - if type(layer) is MetaBatchNormLayer: - layer.restore_backup_stats() - - def forward(self, x, step): - x = self.features(x, step) - x = self.classifier(x) - return x diff --git a/examples/vision/mamlpp/maml++_miniimagenet.py b/examples/vision/mamlpp/maml++_miniimagenet.py index 78085bf9..8e7c9678 100755 --- a/examples/vision/mamlpp/maml++_miniimagenet.py +++ b/examples/vision/mamlpp/maml++_miniimagenet.py @@ -20,7 +20,7 @@ from typing import Tuple from tqdm import tqdm -from examples.vision.mamlpp.cnn4_bnrs import CNN4_BNRS +from learn2learn.vision.models.cnn4_metabatchnorm import CNN4_MetaBatchNorm from examples.vision.mamlpp.MAMLpp import MAMLpp @@ -72,7 +72,7 @@ def __init__( ) # Model - self._model = CNN4_BNRS(ways, adaptation_steps=steps) + self._model = CNN4_MetaBatchNorm(ways, steps) if self._use_cuda: self._model.cuda() @@ -147,19 +147,19 @@ def _training_step( # Adapt the model on the support set for step in range(self._steps): # forward + backward + optimize - pred = learner(s_inputs, step) + pred = learner(s_inputs) support_loss = self._inner_criterion(pred, s_labels) learner.adapt(support_loss, first_order=not second_order, step=step) # Multi-Step Loss if msl: - q_pred = learner(q_inputs, step) + q_pred = learner(q_inputs) query_loss += self._step_weights[step] * self._inner_criterion( q_pred, q_labels ) # Evaluate the adapted model on the query set if not msl: - q_pred = learner(q_inputs, self._steps-1) + q_pred = learner(q_inputs, inference=True) query_loss = self._inner_criterion(q_pred, q_labels) acc = accuracy(q_pred, q_labels).detach() @@ -180,12 +180,12 @@ def _testing_step( # Adapt the model on the support set for step in range(self._steps): # forward + backward + optimize - pred = learner(s_inputs, step) + pred = learner(s_inputs) support_loss = 
self._inner_criterion(pred, s_labels)
             learner.adapt(support_loss, step=step)
 
         # Evaluate the adapted model on the query set
-        q_pred = learner(q_inputs, self._steps-1)
+        q_pred = learner(q_inputs, inference=True)
         query_loss = self._inner_criterion(q_pred, q_labels).detach()
         acc = accuracy(q_pred, q_labels)
 
diff --git a/learn2learn/nn/__init__.py b/learn2learn/nn/__init__.py
index 54ea3de3..cf8638b5 100644
--- a/learn2learn/nn/__init__.py
+++ b/learn2learn/nn/__init__.py
@@ -8,3 +8,4 @@
 from .misc import *
 from .protonet import PrototypicalClassifier
 from .metaoptnet import SVClassifier
+from .metabatchnorm import MetaBatchNorm
diff --git a/learn2learn/nn/metabatchnorm.py b/learn2learn/nn/metabatchnorm.py
new file mode 100644
index 00000000..4f4967c5
--- /dev/null
+++ b/learn2learn/nn/metabatchnorm.py
@@ -0,0 +1,136 @@
+#! /usr/bin/env python3
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+#
+
+"""
+BatchNorm layer augmented with Per-Step Batch Normalisation Running Statistics and Per-Step Batch
+Normalisation Weights and Biases, as proposed in MAML++ by Antoniou et al.
+"""
+
+import torch
+import torch.nn.functional as F
+
+from copy import deepcopy
+
+
+class MetaBatchNorm(torch.nn.Module):
+    """
+    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/nn/metabatchnorm.py)
+
+    **Description**
+
+    An extension of PyTorch's BatchNorm layer, with the Per-Step Batch Normalisation Running
+    Statistics and Per-Step Batch Normalisation Weights and Biases improvements proposed in
+    "How to train your MAML".
+    It is adapted from the original PyTorch implementation at
+    https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch,
+    with heavy refactoring and a bug fix
+    (https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/issues/42).
+
+    **Arguments**
+
+    * **num_features** (int) - number of input features.
+    * **adaptation_steps** (int) - number of inner-loop adaptation steps.
+    * **eps** (float, *optional*, default=1e-5) - a value added to the denominator for numerical
+    stability.
+    * **momentum** (float, *optional*, default=0.1) - the value used for the running_mean and
+    running_var computation. Can be set to None for a cumulative moving average (i.e. simple
+    average).
+    * **affine** (bool, *optional*, default=True) - when set to True, this module has learnable
+    affine parameters.
+
+    **References**
+
+    1. Antoniou et al. 2019. "How to train your MAML." ICLR.
+
+    **Example**
+
+    ~~~python
+    batch_norm = MetaBatchNorm(100, 5)
+    input = torch.randn(20, 100, 35, 45)
+    for _ in range(5):
+        output = batch_norm(input)  # the layer advances its own step counter
+    ~~~
+    """
+
+    def __init__(
+        self,
+        num_features,
+        adaptation_steps,
+        eps=1e-5,
+        momentum=0.1,
+        affine=True,
+    ):
+        super(MetaBatchNorm, self).__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.affine = affine
+        self.running_mean = torch.nn.Parameter(
+            torch.zeros(adaptation_steps, num_features), requires_grad=False
+        )
+        self.running_var = torch.nn.Parameter(
+            torch.ones(adaptation_steps, num_features), requires_grad=False
+        )
+        self.bias = torch.nn.Parameter(
+            torch.zeros(adaptation_steps, num_features), requires_grad=True
+        )
+        self.weight = torch.nn.Parameter(
+            torch.ones(adaptation_steps, num_features), requires_grad=True
+        )
+        self.backup_running_mean = torch.zeros(self.running_mean.shape)
+        self.backup_running_var = torch.ones(self.running_var.shape)
+        self.momentum = momentum
+        self._steps = adaptation_steps
+        self._current_step = 0
+
+    def forward(
+        self,
+        input,
+        inference=False,
+    ):
+        """
+        **Arguments**
+
+        * **input** (tensor) - input data batch; it can be of any size.
+        * **inference** (bool, *optional*, default=False) - when set to `True`, uses the final
+        step's parameters and running statistics. When set to `False`, automatically infers the
+        current adaptation step.
+        """
+        step = self._current_step if not inference else self._steps - 1
+        output = F.batch_norm(
+            input,
+            self.running_mean[step],
+            self.running_var[step],
+            self.weight[step],
+            self.bias[step],
+            training=True,
+            momentum=self.momentum,
+            eps=self.eps,
+        )
+        if not inference:
+            self._current_step = (
+                self._current_step + 1 if self._current_step < (self._steps - 1) else 0
+            )
+        return output
+
+    def backup_stats(self):
+        """
+        Backup the current per-step running statistics (e.g. before a validation epoch).
+        """
+        self.backup_running_mean.data = deepcopy(self.running_mean.data)
+        self.backup_running_var.data = deepcopy(self.running_var.data)
+
+    def restore_backup_stats(self):
+        """
+        Resets the batch statistics to the backup values saved by `backup_stats()`.
+        """
+        self.running_mean = torch.nn.Parameter(
+            self.backup_running_mean, requires_grad=False
+        )
+        self.running_var = torch.nn.Parameter(
+            self.backup_running_var, requires_grad=False
+        )
+
+    def extra_repr(self):
+        return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}".format(
+            **self.__dict__
+        )
diff --git a/learn2learn/vision/models/__init__.py b/learn2learn/vision/models/__init__.py
index 54cdeec4..7c3edb35 100644
--- a/learn2learn/vision/models/__init__.py
+++ b/learn2learn/vision/models/__init__.py
@@ -31,6 +31,14 @@ def forward(self, x):
     CNN4Backbone,
 )
 
+from .cnn4_metabatchnorm import (
+    LinearBlock_MetaBatchNorm,
+    ConvBlock_MetaBatchNorm,
+    ConvBase_MetaBatchNorm,
+    CNN4Backbone_MetaBatchNorm,
+    CNN4_MetaBatchNorm,
+)
+
 from .resnet12 import ResNet12, ResNet12Backbone
 from .wrn28 import WRN28, WRN28Backbone
 
@@ -49,6 +57,11 @@ def forward(self, x):
     'ResNet12Backbone',
     'WRN28',
     'WRN28Backbone',
+    'LinearBlock_MetaBatchNorm',
+    'ConvBlock_MetaBatchNorm',
+    'ConvBase_MetaBatchNorm',
+    'CNN4Backbone_MetaBatchNorm',
+    'CNN4_MetaBatchNorm',
 ]
 
 _BACKBONE_URLS = {
diff --git a/learn2learn/vision/models/cnn4_metabatchnorm.py b/learn2learn/vision/models/cnn4_metabatchnorm.py
new file mode 100644
index 00000000..1ea6f923
--- /dev/null
+++ b/learn2learn/vision/models/cnn4_metabatchnorm.py
@@ -0,0 +1,251 @@
+#! /usr/bin/env python3
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+#
+
+"""
+CNN4 using MetaBatchNorm layers, which accumulate per-step batch-normalisation running
+statistics and learn per-step weight and bias parameters.
+"""
+
+import torch
+
+from learn2learn.nn.metabatchnorm import MetaBatchNorm
+from learn2learn.vision.models.cnn4 import maml_init_, fc_init_
+
+
+class LinearBlock_MetaBatchNorm(torch.nn.Module):
+    def __init__(self, input_size, output_size, adaptation_steps):
+        super(LinearBlock_MetaBatchNorm, self).__init__()
+        self.relu = torch.nn.ReLU()
+        self.normalize = MetaBatchNorm(
+            output_size,
+            adaptation_steps,
+            affine=True,
+            momentum=0.999,
+            eps=1e-3,
+        )
+        self.linear = torch.nn.Linear(input_size, output_size)
+        fc_init_(self.linear)
+
+    def forward(self, x, inference=False):
+        x = self.linear(x)
+        x = self.normalize(x, inference=inference)
+        x = self.relu(x)
+        return x
+
+
+class ConvBlock_MetaBatchNorm(torch.nn.Module):
+    def __init__(
+        self,
+        adaptation_steps,
+        in_channels,
+        out_channels,
+        kernel_size,
+        max_pool=True,
+        max_pool_factor=1.0,
+    ):
+        super(ConvBlock_MetaBatchNorm, self).__init__()
+        stride = (int(2 * max_pool_factor), int(2 * max_pool_factor))
+        if max_pool:
+            self.max_pool = torch.nn.MaxPool2d(
+                kernel_size=stride,
+                stride=stride,
+                ceil_mode=False,
+            )
+            stride = (1, 1)
+        else:
+            self.max_pool = lambda x: x
+        self.normalize = MetaBatchNorm(
+            out_channels,
+            adaptation_steps,
+            affine=True,
+            # eps=1e-3,
+            # momentum=0.999,
+        )
+        torch.nn.init.uniform_(self.normalize.weight)
+        self.relu = torch.nn.ReLU()
+
+        self.conv = torch.nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=1,
+            bias=True,
+        )
+        maml_init_(self.conv)
+
+    def forward(self, x, inference=False):
+        x = self.conv(x)
+        x = self.normalize(x, inference=inference)
+        x = self.relu(x)
+        x = self.max_pool(x)
+        return x
+
+
+class ConvBase_MetaBatchNorm(torch.nn.Sequential):
+
+    # NOTE:
+    #     Omniglot: hidden=64, channels=1, no max_pool
+    #     MiniImagenet: hidden=32, channels=3, max_pool
+
+    def __init__(
+        self,
+        adaptation_steps,
+        hidden=64,
+        channels=1,
+        max_pool=False,
+        layers=4,
+        max_pool_factor=1.0,
+    ):
+        core = [
+            ConvBlock_MetaBatchNorm(
+                adaptation_steps,
+                channels,
+                hidden,
+                (3, 3),
+                max_pool=max_pool,
+                max_pool_factor=max_pool_factor,
+            ),
+        ]
+        for _ in range(layers - 1):
+            core.append(
+                ConvBlock_MetaBatchNorm(
+                    adaptation_steps,
+                    hidden,
+                    hidden,
+                    kernel_size=(3, 3),
+                    max_pool=max_pool,
+                    max_pool_factor=max_pool_factor,
+                )
+            )
+        super(ConvBase_MetaBatchNorm, self).__init__(*core)
+
+    def forward(self, x, inference=False):
+        for module in self:
+            x = module(x, inference=inference)
+        return x
+
+
+class CNN4Backbone_MetaBatchNorm(ConvBase_MetaBatchNorm):
+    def __init__(
+        self,
+        adaptation_steps,
+        hidden_size=64,
+        layers=4,
+        channels=3,
+        max_pool=True,
+        max_pool_factor=None,
+    ):
+        if max_pool_factor is None:
+            max_pool_factor = 4 // layers
+        super(CNN4Backbone_MetaBatchNorm, self).__init__(
+            adaptation_steps,
+            hidden=hidden_size,
+            layers=layers,
+            channels=channels,
+            max_pool=max_pool,
+            max_pool_factor=max_pool_factor,
+        )
+
+    def forward(self, x, inference=False):
+        x = super(CNN4Backbone_MetaBatchNorm, self).forward(x, inference=inference)
+        x = x.reshape(x.size(0), -1)
+        return x
+
+
+class CNN4_MetaBatchNorm(torch.nn.Module):
+    """
+
+    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models/cnn4_metabatchnorm.py)
+
+    **Description**
+
+    The convolutional network commonly used for MiniImagenet, as described by Ravi and Larochelle,
+    2017, using the MetaBatchNorm layer proposed by Antoniou et al. 2019.
+
+    This network assumes inputs of shapes (3, 84, 84).
+
+    Instantiate `CNN4Backbone_MetaBatchNorm` if you only need the feature extractor.
+
+    **References**
+
+    1. Ravi and Larochelle. 2017. “Optimization as a Model for Few-Shot Learning.” ICLR.
+    2. Antoniou et al. 2019. “How to train your MAML.” ICLR.
+
+    **Arguments**
+
+    * **output_size** (int) - The dimensionality of the network's output.
+    * **adaptation_steps** (int) - Number of inner-loop adaptation steps.
+    * **hidden_size** (int, *optional*, default=64) - The dimensionality of the hidden
+    representation.
+    * **layers** (int, *optional*, default=4) - The number of convolutional layers.
+    * **channels** (int, *optional*, default=3) - The number of channels in input.
+    * **max_pool** (bool, *optional*, default=True) - Whether ConvBlocks use max-pooling.
+    * **embedding_size** (int, *optional*, default=None) - Size of feature embedding.
+    Defaults to 25 * hidden_size (for mini-Imagenet).
+
+    **Example**
+    ~~~python
+    model = CNN4_MetaBatchNorm(output_size=20, adaptation_steps=5, hidden_size=128, layers=3)
+    ~~~
+    """
+
+    def __init__(
+        self,
+        output_size,
+        adaptation_steps,
+        hidden_size=64,
+        layers=4,
+        channels=3,
+        max_pool=True,
+        embedding_size=None,
+    ):
+        super(CNN4_MetaBatchNorm, self).__init__()
+        if embedding_size is None:
+            embedding_size = 25 * hidden_size
+        self.features = CNN4Backbone_MetaBatchNorm(
+            adaptation_steps,
+            hidden_size=hidden_size,
+            channels=channels,
+            max_pool=max_pool,
+            layers=layers,
+            max_pool_factor=4 // layers,
+        )
+        self.classifier = torch.nn.Linear(
+            embedding_size,
+            output_size,
+            bias=True,
+        )
+        maml_init_(self.classifier)
+        self.hidden_size = hidden_size
+
+    def backup_stats(self):
+        """
+        Backup stored batch statistics before running a validation epoch.
+        """
+        for layer in self.features.modules():
+            if type(layer) is MetaBatchNorm:
+                layer.backup_stats()
+
+    def restore_backup_stats(self):
+        """
+        Reset stored batch statistics from the stored backup.
+        """
+        for layer in self.features.modules():
+            if type(layer) is MetaBatchNorm:
+                layer.restore_backup_stats()
+
+    def forward(self, x, inference=False):
+        """
+        **Arguments**
+
+        * **x** (tensor) - input data batch; it can be of any size.
+        * **inference** (bool, *optional*, default=False) - when set to `True`, uses the final
+        step's parameters and running statistics. When set to `False`, automatically infers the
+        current adaptation step.
+        """
+        x = self.features(x, inference=inference)
+        x = self.classifier(x)
+        return x
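
---

For reviewers trying the patch out, here is a minimal usage sketch. It exercises only the API introduced by this diff (`CNN4_MetaBatchNorm`, the `inference` flag, `backup_stats()` / `restore_backup_stats()`); the tensor shapes, the 5-way/5-step setup, and the adaptation plumbing elided in comments are illustrative assumptions, not part of the patch.

~~~python
import torch
from learn2learn.vision.models import CNN4_MetaBatchNorm

steps, ways = 5, 5  # assumed task setup, matching the MAML++ example's defaults
model = CNN4_MetaBatchNorm(ways, adaptation_steps=steps)

support = torch.randn(2 * ways, 3, 84, 84)  # mini-ImageNet-sized inputs
query = torch.randn(2 * ways, 3, 84, 84)

model.backup_stats()  # snapshot per-step running statistics before this task
for step in range(steps):
    # Each forward pass uses (and updates) the BatchNorm slot for the current
    # step; the layers advance their internal step counters automatically.
    support_pred = model(support)
    # ... compute the support loss and adapt a cloned learner here (e.g. MAMLpp) ...
query_pred = model(query, inference=True)  # final step's weights and statistics
model.restore_backup_stats()  # discard statistics accumulated on this task
~~~

The design choice visible throughout the diff is that callers no longer thread a `step` argument through `forward()`: the layer tracks its own step counter, which is why the example script's `learner(q_inputs, self._steps-1)` calls become `learner(q_inputs, inference=True)`.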