Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAML++: BNRS #327

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -10,7 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

* New vision example: MAML++. (@[DubiousCactus](https://github.com/DubiousCactus))
* New BatchNorm layer with per-step running statistics and weights & biases from MAML++. (@[Théo Morales](https://github.com/DubiousCactus))
* New vision example: MAML++. (@[Théo Morales](https://github.com/DubiousCactus))
* Add tutorial: "Demystifying Task Transforms", ([Varad Pimpalkhute](https://github.com/nightlessbaron/))

### Changed
321 changes: 0 additions & 321 deletions examples/vision/mamlpp/cnn4_bnrs.py

This file was deleted.

14 changes: 7 additions & 7 deletions examples/vision/mamlpp/maml++_miniimagenet.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@
from typing import Tuple
from tqdm import tqdm

from examples.vision.mamlpp.cnn4_bnrs import CNN4_BNRS
from learn2learn.vision.models.cnn4_metabatchnorm import CNN4_MetaBatchNorm
from examples.vision.mamlpp.MAMLpp import MAMLpp


@@ -72,7 +72,7 @@ def __init__(
)

# Model
self._model = CNN4_BNRS(ways, adaptation_steps=steps)
self._model = CNN4_MetaBatchNorm(ways, steps)
if self._use_cuda:
self._model.cuda()

@@ -147,19 +147,19 @@ def _training_step(
# Adapt the model on the support set
for step in range(self._steps):
# forward + backward + optimize
pred = learner(s_inputs, step)
pred = learner(s_inputs)
support_loss = self._inner_criterion(pred, s_labels)
learner.adapt(support_loss, first_order=not second_order, step=step)
# Multi-Step Loss
if msl:
q_pred = learner(q_inputs, step)
q_pred = learner(q_inputs)
query_loss += self._step_weights[step] * self._inner_criterion(
q_pred, q_labels
)

# Evaluate the adapted model on the query set
if not msl:
q_pred = learner(q_inputs, self._steps-1)
q_pred = learner(q_inputs, inference=True)
query_loss = self._inner_criterion(q_pred, q_labels)
acc = accuracy(q_pred, q_labels).detach()

@@ -180,12 +180,12 @@ def _testing_step(
# Adapt the model on the support set
for step in range(self._steps):
# forward + backward + optimize
pred = learner(s_inputs, step)
pred = learner(s_inputs)
support_loss = self._inner_criterion(pred, s_labels)
learner.adapt(support_loss, step=step)

# Evaluate the adapted model on the query set
q_pred = learner(q_inputs, self._steps-1)
q_pred = learner(q_inputs, inference=True)
query_loss = self._inner_criterion(q_pred, q_labels).detach()
acc = accuracy(q_pred, q_labels)

1 change: 1 addition & 0 deletions learn2learn/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -8,3 +8,4 @@
from .misc import *
from .protonet import PrototypicalClassifier
from .metaoptnet import SVClassifier
from .metabatchnorm import MetaBatchNorm
136 changes: 136 additions & 0 deletions learn2learn/nn/metabatchnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#

"""
BatchNorm layer augmented with Per-Step Batch Normalisation Running Statistics and Per-Step Batch
Normalisation Weights and Biases, as proposed in MAML++ by Antobiou et al.
"""

import torch
import torch.nn.functional as F

from copy import deepcopy


class MetaBatchNorm(torch.nn.Module):
"""
[[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/nn/metabatchnorm.py)
**Description**
An extension of Pytorch's BatchNorm layer, with the Per-Step Batch Normalisation Running
Statistics and Per-Step Batch Normalisation Weights and Biases improvements proposed in
"How to train your MAML".
It is adapted from the original Pytorch implementation at
https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch,
with heavy refactoring and a bug fix
(https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/issues/42).
**Arguments**
* **num_features** (int) - number of input features.
* **adaptation_steps** (int) - number of inner-loop adaptation steps.
* **eps** (float, *optional*, default=1e-5) - a value added to the denominator for numerical
stability.
* **momentum** (float, *optional*, default=0.1) - the value used for the running_mean and
running_var computation. Can be set to None for cumulative moving average (i.e. simple
average).
* **affine** (bool, *optional*, default=True) - a boolean value that when set to True, this
module has learnable affine parameters.
**References**
1. Antoniou et al. 2019. "How to train your MAML." ICLR.
**Example**
~~~python
batch_norm = MetaBatchNorm(100, 5)
input = torch.randn(20, 100, 35, 45)
for step in range(5):
output = batch_norm(input, step)
~~~
"""

def __init__(
self,
num_features,
adaptation_steps,
eps=1e-5,
momentum=0.1,
affine=True,
):
super(MetaBatchNorm, self).__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
self.num_features = num_features
self.running_mean = torch.nn.Parameter(
torch.zeros(adaptation_steps, num_features), requires_grad=False
)
self.running_var = torch.nn.Parameter(
torch.ones(adaptation_steps, num_features), requires_grad=False
)
self.bias = torch.nn.Parameter(
torch.zeros(adaptation_steps, num_features), requires_grad=True
)
self.weight = torch.nn.Parameter(
torch.ones(adaptation_steps, num_features), requires_grad=True
)
self.backup_running_mean = torch.zeros(self.running_mean.shape)
self.backup_running_var = torch.ones(self.running_var.shape)
self.momentum = momentum
self._steps = adaptation_steps
self._current_step = 0

def forward(
self,
input,
inference=False,
):
"""
**Arguments**
* **input** (tensor) - Input data batch, size either can be any.
* **inferencep** (bool, *optional*, default=False) - when set to `True`, uses the final
step's parameters and running statistics. When set to `False`, automatically infers the
current adaptation step.
"""
step = self._current_step if not inference else self._steps - 1
output = F.batch_norm(
input,
self.running_mean[step],
self.running_var[step],
self.weight[step],
self.bias[step],
training=True,
momentum=self.momentum,
eps=self.eps,
)
if not inference:
self._current_step = (
self._current_step + 1 if self._current_step < (self._steps - 1) else 0
)
return output

def backup_stats(self):
self.backup_running_mean.data = deepcopy(self.running_mean.data)
self.backup_running_var.data = deepcopy(self.running_var.data)

def restore_backup_stats(self):
"""
Resets batch statistics to their backup values which are collected after each forward pass.
"""
self.running_mean = torch.nn.Parameter(
self.backup_running_mean, requires_grad=False
)
self.running_var = torch.nn.Parameter(
self.backup_running_var, requires_grad=False
)

def extra_repr(self):
return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}".format(
**self.__dict__
)
13 changes: 13 additions & 0 deletions learn2learn/vision/models/__init__.py
Original file line number Diff line number Diff line change
@@ -31,6 +31,14 @@ def forward(self, x):
CNN4Backbone,
)

from .cnn4_metabatchnorm import (
LinearBlock_MetaBatchNorm,
ConvBlock_MetaBatchNorm,
ConvBase_MetaBatchNorm,
CNN4Backbone_MetaBatchNorm,
CNN4_MetaBatchNorm,
)

from .resnet12 import ResNet12, ResNet12Backbone
from .wrn28 import WRN28, WRN28Backbone

@@ -49,6 +57,11 @@ def forward(self, x):
'ResNet12Backbone',
'WRN28',
'WRN28Backbone',
'LinearBlock_MetaBatchNorm',
'ConvBlock_MetaBatchNorm',
'ConvBase_MetaBatchNorm',
'CNN4Backbone_MetaBatchNorm',
'CNN4_MetaBatchNorm',
]

_BACKBONE_URLS = {
251 changes: 251 additions & 0 deletions learn2learn/vision/models/cnn4_metabatchnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#

"""
CNN4 using a MetaBatchNorm layer allowing to accumulate per-step running statistics and use
per-step bias and variance parameters.
"""

import torch

from learn2learn.nn.metabatchnorm import MetaBatchNorm
from learn2learn.vision.models.cnn4 import maml_init_, fc_init_


class LinearBlock_MetaBatchNorm(torch.nn.Module):
def __init__(self, input_size, output_size, adaptation_steps):
super(LinearBlock_MetaBatchNorm, self).__init__()
self.relu = torch.nn.ReLU()
self.normalize = MetaBatchNorm(
output_size,
adaptation_steps,
affine=True,
momentum=0.999,
eps=1e-3,
)
self.linear = torch.nn.Linear(input_size, output_size)
fc_init_(self.linear)

def forward(self, x, inference=False):
x = self.linear(x)
x = self.normalize(x, inference=inference)
x = self.relu(x)
return x


class ConvBlock_MetaBatchNorm(torch.nn.Module):
def __init__(
self,
adaptation_steps,
in_channels,
out_channels,
kernel_size,
max_pool=True,
max_pool_factor=1.0,
):
super(ConvBlock_MetaBatchNorm, self).__init__()
stride = (int(2 * max_pool_factor), int(2 * max_pool_factor))
if max_pool:
self.max_pool = torch.nn.MaxPool2d(
kernel_size=stride,
stride=stride,
ceil_mode=False,
)
stride = (1, 1)
else:
self.max_pool = lambda x: x
self.normalize = MetaBatchNorm(
out_channels,
adaptation_steps,
affine=True,
# eps=1e-3,
# momentum=0.999,
)
torch.nn.init.uniform_(self.normalize.weight)
self.relu = torch.nn.ReLU()

self.conv = torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=1,
bias=True,
)
maml_init_(self.conv)

def forward(self, x, inference=False):
x = self.conv(x)
x = self.normalize(x, inference=inference)
x = self.relu(x)
x = self.max_pool(x)
return x


class ConvBase_MetaBatchNorm(torch.nn.Sequential):

# NOTE:
# Omniglot: hidden=64, channels=1, no max_pool
# MiniImagenet: hidden=32, channels=3, max_pool

def __init__(
self,
adaptation_steps,
hidden=64,
channels=1,
max_pool=False,
layers=4,
max_pool_factor=1.0,
):
core = [
ConvBlock_MetaBatchNorm(
adaptation_steps,
channels,
hidden,
(3, 3),
max_pool=max_pool,
max_pool_factor=max_pool_factor,
),
]
for _ in range(layers - 1):
core.append(
ConvBlock_MetaBatchNorm(
adaptation_steps,
hidden,
hidden,
kernel_size=(3, 3),
max_pool=max_pool,
max_pool_factor=max_pool_factor,
)
)
super(ConvBase_MetaBatchNorm, self).__init__(*core)

def forward(self, x, inference=False):
for module in self:
x = module(x, inference=inference)
return x


class CNN4Backbone_MetaBatchNorm(ConvBase_MetaBatchNorm):
def __init__(
self,
adaptation_steps,
hidden_size=64,
layers=4,
channels=3,
max_pool=True,
max_pool_factor=None,
):
if max_pool_factor is None:
max_pool_factor = 4 // layers
super(CNN4Backbone_MetaBatchNorm, self).__init__(
adaptation_steps,
hidden=hidden_size,
layers=layers,
channels=channels,
max_pool=max_pool,
max_pool_factor=max_pool_factor,
)

def forward(self, x, inference=False):
x = super(CNN4Backbone_MetaBatchNorm, self).forward(x, inference=inference)
x = x.reshape(x.size(0), -1)
return x


class CNN4_MetaBatchNorm(torch.nn.Module):
"""
[[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models/cnn4.py)
**Description**
The convolutional network commonly used for MiniImagenet, as described by Ravi et Larochelle,
2017, using the MetaBatchNorm layer proposed by Antoniou et al. 2019.
This network assumes inputs of shapes (3, 84, 84).
Instantiate `CNN4Backbone_MetaBatchNorm` if you only need the feature extractor.
**References**
1. Ravi and Larochelle. 2017. “Optimization as a Model for Few-Shot Learning.” ICLR.
2. Antoniou et al. 2019. “How to train your MAML.“ ICLR.
**Arguments**
* **output_size** (int) - The dimensionality of the network's output.
* **adaptation_steps** (int) - Number of inner-loop adaptation steps.
* **hidden_size** (int, *optional*, default=64) - The dimensionality of the hidden
representation.
* **layers** (int, *optional*, default=4) - The number of convolutional layers.
* **channels** (int, *optional*, default=3) - The number of channels in input.
* **max_pool** (bool, *optional*, default=True) - Whether ConvBlocks use max-pooling.
* **embedding_size** (int, *optional*, default=None) - Size of feature embedding.
Defaults to 25 * hidden_size (for mini-Imagenet).
**Example**
~~~python
model = CNN4(output_size=20, adaptation_steps=5, hidden_size=128, layers=3)
~~~
"""

def __init__(
self,
output_size,
adaptation_steps,
hidden_size=64,
layers=4,
channels=3,
max_pool=True,
embedding_size=None,
):
super(CNN4_MetaBatchNorm, self).__init__()
if embedding_size is None:
embedding_size = 25 * hidden_size
self.features = CNN4Backbone_MetaBatchNorm(
adaptation_steps,
hidden_size=hidden_size,
channels=channels,
max_pool=max_pool,
layers=layers,
max_pool_factor=4 // layers,
)
self.classifier = torch.nn.Linear(
embedding_size,
output_size,
bias=True,
)
maml_init_(self.classifier)
self.hidden_size = hidden_size

def backup_stats(self):
"""
Backup stored batch statistics before running a validation epoch.
"""
for layer in self.features.modules():
if type(layer) is MetaBatchNorm:
layer.backup_stats()

def restore_backup_stats(self):
"""
Reset stored batch statistics from the stored backup.
"""
for layer in self.features.modules():
if type(layer) is MetaBatchNorm:
layer.restore_backup_stats()

def forward(self, x, inference=False):
"""
**Arguments**
* **input** (tensor) - Input data batch, size either can be any.
* **inferencep** (bool, *optional*, default=False) - when set to `True`, uses the final
step's parameters and running statistics. When set to `False`, automatically infers the
current adaptation step.
"""
x = self.features(x, inference=inference)
x = self.classifier(x)
return x