diff --git a/CHANGELOG.md b/CHANGELOG.md
index ca845a18..c7f6407e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,7 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
-* New vision example: MAML++. (@[DubiousCactus](https://github.com/DubiousCactus))
+* New `MetaBatchNorm` layer with per-step running statistics and per-step weights and biases, as proposed in MAML++. (@[Théo Morales](https://github.com/DubiousCactus))
+* New vision example: MAML++. (@[Théo Morales](https://github.com/DubiousCactus))
 * Add tutorial: "Demystifying Task Transforms", ([Varad Pimpalkhute](https://github.com/nightlessbaron/))
 
 ### Changed
diff --git a/examples/vision/mamlpp/cnn4_bnrs.py b/examples/vision/mamlpp/cnn4_bnrs.py
deleted file mode 100644
index 28f9666f..00000000
--- a/examples/vision/mamlpp/cnn4_bnrs.py
+++ /dev/null
@@ -1,321 +0,0 @@
-#! /usr/bin/env python3
-# -*- coding: utf-8 -*-
-# vim:fenc=utf-8
-#
-
-"""
-CNN4 extended with Batch-Norm Running Statistics.
-"""
-
-import torch
-import torch.nn.functional as F
-
-from copy import deepcopy
-from learn2learn.vision.models.cnn4 import maml_init_, fc_init_
-
-
-class MetaBatchNormLayer(torch.nn.Module):
-    """
-    An extension of Pytorch's BatchNorm layer, with the Per-Step Batch Normalisation Running
-    Statistics and Per-Step Batch Normalisation Weights and Biases improvements proposed in
-    MAML++ by Antoniou et al. It is adapted from the original Pytorch implementation at
-    https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch,
-    with heavy refactoring and a bug fix
-    (https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/issues/42).
-    """
-
-    def __init__(
-        self,
-        num_features,
-        eps=1e-5,
-        momentum=0.1,
-        affine=True,
-        meta_batch_norm=True,
-        adaptation_steps: int = 1,
-    ):
-        super(MetaBatchNormLayer, self).__init__()
-        self.num_features = num_features
-        self.eps = eps
-        self.affine = affine
-        self.meta_batch_norm = meta_batch_norm
-        self.num_features = num_features
-        self.running_mean = torch.nn.Parameter(
-            torch.zeros(adaptation_steps, num_features), requires_grad=False
-        )
-        self.running_var = torch.nn.Parameter(
-            torch.ones(adaptation_steps, num_features), requires_grad=False
-        )
-        self.bias = torch.nn.Parameter(
-            torch.zeros(adaptation_steps, num_features), requires_grad=True
-        )
-        self.weight = torch.nn.Parameter(
-            torch.ones(adaptation_steps, num_features), requires_grad=True
-        )
-        self.backup_running_mean = torch.zeros(self.running_mean.shape)
-        self.backup_running_var = torch.ones(self.running_var.shape)
-        self.momentum = momentum
-
-    def forward(
-        self,
-        input,
-        step,
-    ):
-        """
-        :param input: input data batch, size either can be any.
-        :param step: The current inner loop step being taken. This is used when to learn per step params and
-         collecting per step batch statistics.
-        :return: The result of the batch norm operation.
-        """
-        assert (
-            step < self.running_mean.shape[0]
-        ), f"Running forward with step={step} when initialised with {self.running_mean.shape[0]} steps!"
-        return F.batch_norm(
-            input,
-            self.running_mean[step],
-            self.running_var[step],
-            self.weight[step],
-            self.bias[step],
-            training=True,
-            momentum=self.momentum,
-            eps=self.eps,
-        )
-
-    def backup_stats(self):
-        self.backup_running_mean.data = deepcopy(self.running_mean.data)
-        self.backup_running_var.data = deepcopy(self.running_var.data)
-
-    def restore_backup_stats(self):
-        """
-        Resets batch statistics to their backup values which are collected after each forward pass.
-        """
-        self.running_mean = torch.nn.Parameter(
-            self.backup_running_mean, requires_grad=False
-        )
-        self.running_var = torch.nn.Parameter(
-            self.backup_running_var, requires_grad=False
-        )
-
-    def extra_repr(self):
-        return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}".format(
-            **self.__dict__
-        )
-
-
-class LinearBlock_BNRS(torch.nn.Module):
-    def __init__(self, input_size, output_size, adaptation_steps):
-        super(LinearBlock_BNRS, self).__init__()
-        self.relu = torch.nn.ReLU()
-        self.normalize = MetaBatchNormLayer(
-            output_size,
-            affine=True,
-            momentum=0.999,
-            eps=1e-3,
-            adaptation_steps=adaptation_steps,
-        )
-        self.linear = torch.nn.Linear(input_size, output_size)
-        fc_init_(self.linear)
-
-    def forward(self, x, step):
-        x = self.linear(x)
-        x = self.normalize(x, step)
-        x = self.relu(x)
-        return x
-
-
-class ConvBlock_BNRS(torch.nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        max_pool=True,
-        max_pool_factor=1.0,
-        adaptation_steps=1,
-    ):
-        super(ConvBlock_BNRS, self).__init__()
-        stride = (int(2 * max_pool_factor), int(2 * max_pool_factor))
-        if max_pool:
-            self.max_pool = torch.nn.MaxPool2d(
-                kernel_size=stride,
-                stride=stride,
-                ceil_mode=False,
-            )
-            stride = (1, 1)
-        else:
-            self.max_pool = lambda x: x
-        self.normalize = MetaBatchNormLayer(
-            out_channels,
-            affine=True,
-            adaptation_steps=adaptation_steps,
-            # eps=1e-3,
-            # momentum=0.999,
-        )
-        torch.nn.init.uniform_(self.normalize.weight)
-        self.relu = torch.nn.ReLU()
-
-        self.conv = torch.nn.Conv2d(
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride=stride,
-            padding=1,
-            bias=True,
-        )
-        maml_init_(self.conv)
-
-    def forward(self, x, step):
-        x = self.conv(x)
-        x = self.normalize(x, step)
-        x = self.relu(x)
-        x = self.max_pool(x)
-        return x
-
-
-class ConvBase_BNRS(torch.nn.Sequential):
-
-    # NOTE:
-    #     Omniglot: hidden=64, channels=1, no max_pool
-    #     MiniImagenet: hidden=32, channels=3, max_pool
-
-    def __init__(
-        self, hidden=64, channels=1, max_pool=False, layers=4, max_pool_factor=1.0,
-        adaptation_steps=1
-    ):
-        core = [
-            ConvBlock_BNRS(
-                channels,
-                hidden,
-                (3, 3),
-                max_pool=max_pool,
-                max_pool_factor=max_pool_factor,
-                adaptation_steps=adaptation_steps
-            ),
-        ]
-        for _ in range(layers - 1):
-            core.append(
-                ConvBlock_BNRS(
-                    hidden,
-                    hidden,
-                    kernel_size=(3, 3),
-                    max_pool=max_pool,
-                    max_pool_factor=max_pool_factor,
-                    adaptation_steps=adaptation_steps
-                )
-            )
-        super(ConvBase_BNRS, self).__init__(*core)
-
-    def forward(self, x, step):
-        for module in self:
-            x = module(x, step)
-        return x
-
-
-class CNN4Backbone_BNRS(ConvBase_BNRS):
-    def __init__(
-        self,
-        hidden_size=64,
-        layers=4,
-        channels=3,
-        max_pool=True,
-        max_pool_factor=None,
-        adaptation_steps=1,
-    ):
-        if max_pool_factor is None:
-            max_pool_factor = 4 // layers
-        super(CNN4Backbone_BNRS, self).__init__(
-            hidden=hidden_size,
-            layers=layers,
-            channels=channels,
-            max_pool=max_pool,
-            max_pool_factor=max_pool_factor,
-            adaptation_steps=adaptation_steps
-        )
-
-    def forward(self, x, step):
-        x = super(CNN4Backbone_BNRS, self).forward(x, step)
-        x = x.reshape(x.size(0), -1)
-        return x
-
-
-class CNN4_BNRS(torch.nn.Module):
-    """
-
-    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models/cnn4.py)
-
-    **Description**
-
-    The convolutional network commonly used for MiniImagenet, as described by Ravi et Larochelle, 2017.
-
-    This network assumes inputs of shapes (3, 84, 84).
-
-    Instantiate `CNN4Backbone` if you only need the feature extractor.
-
-    **References**
-
-    1. Ravi and Larochelle. 2017. “Optimization as a Model for Few-Shot Learning.” ICLR.
-
-    **Arguments**
-
-    * **output_size** (int) - The dimensionality of the network's output.
-    * **hidden_size** (int, *optional*, default=64) - The dimensionality of the hidden representation.
-    * **layers** (int, *optional*, default=4) - The number of convolutional layers.
-    * **channels** (int, *optional*, default=3) - The number of channels in input.
-    * **max_pool** (bool, *optional*, default=True) - Whether ConvBlocks use max-pooling.
-    * **embedding_size** (int, *optional*, default=None) - Size of feature embedding.
-        Defaults to 25 * hidden_size (for mini-Imagenet).
-
-    **Example**
-    ~~~python
-    model = CNN4(output_size=20, hidden_size=128, layers=3)
-    ~~~
-    """
-
-    def __init__(
-        self,
-        output_size,
-        hidden_size=64,
-        layers=4,
-        channels=3,
-        max_pool=True,
-        embedding_size=None,
-        adaptation_steps=1,
-    ):
-        super(CNN4_BNRS, self).__init__()
-        if embedding_size is None:
-            embedding_size = 25 * hidden_size
-        self.features = CNN4Backbone_BNRS(
-            hidden_size=hidden_size,
-            channels=channels,
-            max_pool=max_pool,
-            layers=layers,
-            max_pool_factor=4 // layers,
-            adaptation_steps=adaptation_steps,
-        )
-        self.classifier = torch.nn.Linear(
-            embedding_size,
-            output_size,
-            bias=True,
-        )
-        maml_init_(self.classifier)
-        self.hidden_size = hidden_size
-
-    def backup_stats(self):
-        """
-        Backup stored batch statistics before running a validation epoch.
-        """
-        for layer in self.features.modules():
-            if type(layer) is MetaBatchNormLayer:
-                layer.backup_stats()
-
-    def restore_backup_stats(self):
-        """
-        Reset stored batch statistics from the stored backup.
-        """
-        for layer in self.features.modules():
-            if type(layer) is MetaBatchNormLayer:
-                layer.restore_backup_stats()
-
-    def forward(self, x, step):
-        x = self.features(x, step)
-        x = self.classifier(x)
-        return x
diff --git a/examples/vision/mamlpp/maml++_miniimagenet.py b/examples/vision/mamlpp/maml++_miniimagenet.py
index 78085bf9..8e7c9678 100755
--- a/examples/vision/mamlpp/maml++_miniimagenet.py
+++ b/examples/vision/mamlpp/maml++_miniimagenet.py
@@ -20,7 +20,7 @@
 from typing import Tuple
 from tqdm import tqdm
 
-from examples.vision.mamlpp.cnn4_bnrs import CNN4_BNRS
+from learn2learn.vision.models.cnn4_metabatchnorm import CNN4_MetaBatchNorm
 from examples.vision.mamlpp.MAMLpp import MAMLpp
 
 
@@ -72,7 +72,7 @@ def __init__(
         )
 
         # Model
-        self._model = CNN4_BNRS(ways, adaptation_steps=steps)
+        self._model = CNN4_MetaBatchNorm(ways, steps)
         if self._use_cuda:
             self._model.cuda()
 
@@ -147,19 +147,19 @@ def _training_step(
         # Adapt the model on the support set
         for step in range(self._steps):
             # forward + backward + optimize
-            pred = learner(s_inputs, step)
+            pred = learner(s_inputs)
             support_loss = self._inner_criterion(pred, s_labels)
             learner.adapt(support_loss, first_order=not second_order, step=step)
             # Multi-Step Loss
             if msl:
-                q_pred = learner(q_inputs, step)
+                q_pred = learner(q_inputs)
                 query_loss += self._step_weights[step] * self._inner_criterion(
                     q_pred, q_labels
                 )
 
         # Evaluate the adapted model on the query set
         if not msl:
-            q_pred = learner(q_inputs, self._steps-1)
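+            # inference=True evaluates with the final adaptation step's BN parameters and statistics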
+            q_pred = learner(q_inputs, inference=True)
             query_loss = self._inner_criterion(q_pred, q_labels)
         acc = accuracy(q_pred, q_labels).detach()
 
@@ -180,12 +180,12 @@ def _testing_step(
         # Adapt the model on the support set
         for step in range(self._steps):
             # forward + backward + optimize
-            pred = learner(s_inputs, step)
+            pred = learner(s_inputs)
             support_loss = self._inner_criterion(pred, s_labels)
             learner.adapt(support_loss, step=step)
 
         # Evaluate the adapted model on the query set
-        q_pred = learner(q_inputs, self._steps-1)
+        q_pred = learner(q_inputs, inference=True)
         query_loss = self._inner_criterion(q_pred, q_labels).detach()
         acc = accuracy(q_pred, q_labels)
 
diff --git a/learn2learn/nn/__init__.py b/learn2learn/nn/__init__.py
index 54ea3de3..cf8638b5 100644
--- a/learn2learn/nn/__init__.py
+++ b/learn2learn/nn/__init__.py
@@ -8,3 +8,4 @@
 from .misc import *
 from .protonet import PrototypicalClassifier
 from .metaoptnet import SVClassifier
+from .metabatchnorm import MetaBatchNorm
diff --git a/learn2learn/nn/metabatchnorm.py b/learn2learn/nn/metabatchnorm.py
new file mode 100644
index 00000000..4f4967c5
--- /dev/null
+++ b/learn2learn/nn/metabatchnorm.py
@@ -0,0 +1,136 @@
+#! /usr/bin/env python3
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+#
+
+"""
+BatchNorm layer augmented with Per-Step Batch Normalisation Running Statistics and Per-Step Batch
+Normalisation Weights and Biases, as proposed in MAML++ by Antoniou et al.
+"""
+
+import torch
+import torch.nn.functional as F
+
+from copy import deepcopy
+
+
+class MetaBatchNorm(torch.nn.Module):
+    """
+    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/nn/metabatchnorm.py)
+
+    **Description**
+
+    An extension of Pytorch's BatchNorm layer, with the Per-Step Batch Normalisation Running
+    Statistics and Per-Step Batch Normalisation Weights and Biases improvements proposed in
+    "How to train your MAML".
+    It is adapted from the original Pytorch implementation at
+    https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch,
+    with heavy refactoring and a bug fix
+    (https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/issues/42).
+
+    **Arguments**
+
+    * **num_features** (int) - number of input features.
+    * **adaptation_steps** (int) - number of inner-loop adaptation steps.
+    * **eps** (float, *optional*, default=1e-5) - a value added to the denominator for numerical
+    stability.
+    * **momentum** (float, *optional*, default=0.1) - the value used for the running_mean and
+    running_var computation. Can be set to None for cumulative moving average (i.e. simple
+    average).
+    * **affine** (bool, *optional*, default=True) - a boolean value that when set to True, this
+    module has learnable affine parameters.
+
+    **References**
+
+    1. Antoniou et al. 2019. "How to train your MAML." ICLR.
+
+    **Example**
+
+    ~~~python
+    batch_norm = MetaBatchNorm(100, 5)
+    input = torch.randn(20, 100, 35, 45)
+    for _ in range(5):
+        output = batch_norm(input)
+    ~~~
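+
+    A minimal sketch of using the backup/restore helpers around a validation epoch
+    (the surrounding meta-learning loop is assumed, not part of this module):
+
+    ~~~python
+    batch_norm.backup_stats()              # save the accumulated running statistics
+    for _ in range(5):
+        output = batch_norm(input)         # adaptation on validation data
+    batch_norm.restore_backup_stats()      # discard statistics collected during validation
+    ~~~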
+    """
+
+    def __init__(
+        self,
+        num_features,
+        adaptation_steps,
+        eps=1e-5,
+        momentum=0.1,
+        affine=True,
+    ):
+        super(MetaBatchNorm, self).__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.affine = affine
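+        # One row per inner-loop adaptation step: each step keeps its own running
+        # statistics and its own affine weight and bias (MAML++ per-step batch norm).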
+        self.running_mean = torch.nn.Parameter(
+            torch.zeros(adaptation_steps, num_features), requires_grad=False
+        )
+        self.running_var = torch.nn.Parameter(
+            torch.ones(adaptation_steps, num_features), requires_grad=False
+        )
+        self.bias = torch.nn.Parameter(
+            torch.zeros(adaptation_steps, num_features), requires_grad=True
+        )
+        self.weight = torch.nn.Parameter(
+            torch.ones(adaptation_steps, num_features), requires_grad=True
+        )
+        self.backup_running_mean = torch.zeros(self.running_mean.shape)
+        self.backup_running_var = torch.ones(self.running_var.shape)
+        self.momentum = momentum
+        self._steps = adaptation_steps
+        self._current_step = 0
+
+    def forward(
+        self,
+        input,
+        inference=False,
+    ):
+        """
+        **Arguments**
+
+        * **input** (tensor) - Input data batch; any shape accepted by batch normalisation.
+        * **inference** (bool, *optional*, default=False) - When set to `True`, uses the final
+        step's parameters and running statistics. When set to `False`, automatically tracks the
+        current adaptation step.
+        """
+        step = self._current_step if not inference else self._steps - 1
+        output = F.batch_norm(
+            input,
+            self.running_mean[step],
+            self.running_var[step],
+            self.weight[step],
+            self.bias[step],
+            training=True,
+            momentum=self.momentum,
+            eps=self.eps,
+        )
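+        # Advance the step counter so the next call uses the next step's statistics,
+        # wrapping back to 0 after the final adaptation step.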
+        if not inference:
+            self._current_step = (
+                self._current_step + 1 if self._current_step < (self._steps - 1) else 0
+            )
+        return output
+
+    def backup_stats(self):
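+        """
+        Saves a copy of the current running statistics, typically before a validation epoch.
+        """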
+        self.backup_running_mean.data = deepcopy(self.running_mean.data)
+        self.backup_running_var.data = deepcopy(self.running_var.data)
+
+    def restore_backup_stats(self):
+        """
+        Resets batch statistics to the backup values saved by the last call to `backup_stats()`.
+        """
+        self.running_mean = torch.nn.Parameter(
+            self.backup_running_mean, requires_grad=False
+        )
+        self.running_var = torch.nn.Parameter(
+            self.backup_running_var, requires_grad=False
+        )
+
+    def extra_repr(self):
+        return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}".format(
+            **self.__dict__
+        )
diff --git a/learn2learn/vision/models/__init__.py b/learn2learn/vision/models/__init__.py
index 54cdeec4..7c3edb35 100644
--- a/learn2learn/vision/models/__init__.py
+++ b/learn2learn/vision/models/__init__.py
@@ -31,6 +31,14 @@ def forward(self, x):
     CNN4Backbone,
 )
 
+from .cnn4_metabatchnorm import (
+    LinearBlock_MetaBatchNorm,
+    ConvBlock_MetaBatchNorm,
+    ConvBase_MetaBatchNorm,
+    CNN4Backbone_MetaBatchNorm,
+    CNN4_MetaBatchNorm,
+)
+
 from .resnet12 import ResNet12, ResNet12Backbone
 from .wrn28 import WRN28, WRN28Backbone
 
@@ -49,6 +57,11 @@ def forward(self, x):
     'ResNet12Backbone',
     'WRN28',
     'WRN28Backbone',
+    'LinearBlock_MetaBatchNorm',
+    'ConvBlock_MetaBatchNorm',
+    'ConvBase_MetaBatchNorm',
+    'CNN4Backbone_MetaBatchNorm',
+    'CNN4_MetaBatchNorm',
 ]
 
 _BACKBONE_URLS = {
diff --git a/learn2learn/vision/models/cnn4_metabatchnorm.py b/learn2learn/vision/models/cnn4_metabatchnorm.py
new file mode 100644
index 00000000..1ea6f923
--- /dev/null
+++ b/learn2learn/vision/models/cnn4_metabatchnorm.py
@@ -0,0 +1,251 @@
+#! /usr/bin/env python3
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+#
+
+"""
+CNN4 with MetaBatchNorm layers, which accumulate per-step running statistics and learn
+per-step weight and bias parameters.
+"""
+
+import torch
+
+from learn2learn.nn.metabatchnorm import MetaBatchNorm
+from learn2learn.vision.models.cnn4 import maml_init_, fc_init_
+
+
+class LinearBlock_MetaBatchNorm(torch.nn.Module):
+    def __init__(self, input_size, output_size, adaptation_steps):
+        super(LinearBlock_MetaBatchNorm, self).__init__()
+        self.relu = torch.nn.ReLU()
+        self.normalize = MetaBatchNorm(
+            output_size,
+            adaptation_steps,
+            affine=True,
+            momentum=0.999,
+            eps=1e-3,
+        )
+        self.linear = torch.nn.Linear(input_size, output_size)
+        fc_init_(self.linear)
+
+    def forward(self, x, inference=False):
+        x = self.linear(x)
+        x = self.normalize(x, inference=inference)
+        x = self.relu(x)
+        return x
+
+
+class ConvBlock_MetaBatchNorm(torch.nn.Module):
+    def __init__(
+        self,
+        adaptation_steps,
+        in_channels,
+        out_channels,
+        kernel_size,
+        max_pool=True,
+        max_pool_factor=1.0,
+    ):
+        super(ConvBlock_MetaBatchNorm, self).__init__()
+        stride = (int(2 * max_pool_factor), int(2 * max_pool_factor))
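+        # With max-pooling enabled, the pooling layer performs the downsampling,
+        # so the convolution's stride is reset to (1, 1) below.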
+        if max_pool:
+            self.max_pool = torch.nn.MaxPool2d(
+                kernel_size=stride,
+                stride=stride,
+                ceil_mode=False,
+            )
+            stride = (1, 1)
+        else:
+            self.max_pool = lambda x: x
+        self.normalize = MetaBatchNorm(
+            out_channels,
+            adaptation_steps,
+            affine=True,
+            # eps=1e-3,
+            # momentum=0.999,
+        )
+        torch.nn.init.uniform_(self.normalize.weight)
+        self.relu = torch.nn.ReLU()
+
+        self.conv = torch.nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=1,
+            bias=True,
+        )
+        maml_init_(self.conv)
+
+    def forward(self, x, inference=False):
+        x = self.conv(x)
+        x = self.normalize(x, inference=inference)
+        x = self.relu(x)
+        x = self.max_pool(x)
+        return x
+
+
+class ConvBase_MetaBatchNorm(torch.nn.Sequential):
+
+    # NOTE:
+    #     Omniglot: hidden=64, channels=1, no max_pool
+    #     MiniImagenet: hidden=32, channels=3, max_pool
+
+    def __init__(
+        self,
+        adaptation_steps,
+        hidden=64,
+        channels=1,
+        max_pool=False,
+        layers=4,
+        max_pool_factor=1.0,
+    ):
+        core = [
+            ConvBlock_MetaBatchNorm(
+                adaptation_steps,
+                channels,
+                hidden,
+                (3, 3),
+                max_pool=max_pool,
+                max_pool_factor=max_pool_factor,
+            ),
+        ]
+        for _ in range(layers - 1):
+            core.append(
+                ConvBlock_MetaBatchNorm(
+                    adaptation_steps,
+                    hidden,
+                    hidden,
+                    kernel_size=(3, 3),
+                    max_pool=max_pool,
+                    max_pool_factor=max_pool_factor,
+                )
+            )
+        super(ConvBase_MetaBatchNorm, self).__init__(*core)
+
+    def forward(self, x, inference=False):
+        for module in self:
+            x = module(x, inference=inference)
+        return x
+
+
+class CNN4Backbone_MetaBatchNorm(ConvBase_MetaBatchNorm):
+    def __init__(
+        self,
+        adaptation_steps,
+        hidden_size=64,
+        layers=4,
+        channels=3,
+        max_pool=True,
+        max_pool_factor=None,
+    ):
+        if max_pool_factor is None:
+            max_pool_factor = 4 // layers
+        super(CNN4Backbone_MetaBatchNorm, self).__init__(
+            adaptation_steps,
+            hidden=hidden_size,
+            layers=layers,
+            channels=channels,
+            max_pool=max_pool,
+            max_pool_factor=max_pool_factor,
+        )
+
+    def forward(self, x, inference=False):
+        x = super(CNN4Backbone_MetaBatchNorm, self).forward(x, inference=inference)
+        x = x.reshape(x.size(0), -1)
+        return x
+
+
+class CNN4_MetaBatchNorm(torch.nn.Module):
+    """
+
+    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models/cnn4_metabatchnorm.py)
+
+    **Description**
+
+    The convolutional network commonly used for MiniImagenet, as described by Ravi and Larochelle,
+    2017, using the MetaBatchNorm layer proposed by Antoniou et al. 2019.
+
+    This network assumes inputs of shapes (3, 84, 84).
+
+    Instantiate `CNN4Backbone_MetaBatchNorm` if you only need the feature extractor.
+
+    **References**
+
+    1. Ravi and Larochelle. 2017. “Optimization as a Model for Few-Shot Learning.” ICLR.
+    2. Antoniou et al. 2019. “How to train your MAML.” ICLR.
+
+    **Arguments**
+
+    * **output_size** (int) - The dimensionality of the network's output.
+    * **adaptation_steps** (int) - Number of inner-loop adaptation steps.
+    * **hidden_size** (int, *optional*, default=64) - The dimensionality of the hidden
+    representation.
+    * **layers** (int, *optional*, default=4) - The number of convolutional layers.
+    * **channels** (int, *optional*, default=3) - The number of channels in input.
+    * **max_pool** (bool, *optional*, default=True) - Whether ConvBlocks use max-pooling.
+    * **embedding_size** (int, *optional*, default=None) - Size of feature embedding.
+        Defaults to 25 * hidden_size (for mini-Imagenet).
+
+    **Example**
+    ~~~python
+    model = CNN4_MetaBatchNorm(output_size=20, adaptation_steps=5, hidden_size=128, layers=3)
+    ~~~
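+
+    A minimal sketch of the intended inner-loop usage (task data and adaptation logic
+    come from the surrounding meta-learning code and are assumed here):
+
+    ~~~python
+    model = CNN4_MetaBatchNorm(output_size=5, adaptation_steps=3)
+    x = torch.randn(10, 3, 84, 84)
+    for _ in range(3):
+        logits = model(x)                 # each call advances the per-step BN counter
+    logits = model(x, inference=True)     # evaluate with the final step's statistics
+    ~~~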
+    """
+
+    def __init__(
+        self,
+        output_size,
+        adaptation_steps,
+        hidden_size=64,
+        layers=4,
+        channels=3,
+        max_pool=True,
+        embedding_size=None,
+    ):
+        super(CNN4_MetaBatchNorm, self).__init__()
+        if embedding_size is None:
+            embedding_size = 25 * hidden_size
+        self.features = CNN4Backbone_MetaBatchNorm(
+            adaptation_steps,
+            hidden_size=hidden_size,
+            channels=channels,
+            max_pool=max_pool,
+            layers=layers,
+            max_pool_factor=4 // layers,
+        )
+        self.classifier = torch.nn.Linear(
+            embedding_size,
+            output_size,
+            bias=True,
+        )
+        maml_init_(self.classifier)
+        self.hidden_size = hidden_size
+
+    def backup_stats(self):
+        """
+        Backup stored batch statistics before running a validation epoch.
+        """
+        for layer in self.features.modules():
+            if type(layer) is MetaBatchNorm:
+                layer.backup_stats()
+
+    def restore_backup_stats(self):
+        """
+        Reset stored batch statistics from the stored backup.
+        """
+        for layer in self.features.modules():
+            if type(layer) is MetaBatchNorm:
+                layer.restore_backup_stats()
+
+    def forward(self, x, inference=False):
+        """
+        **Arguments**
+
+        * **x** (tensor) - Input data batch, of shape (N, 3, 84, 84).
+        * **inference** (bool, *optional*, default=False) - When set to `True`, each MetaBatchNorm
+        layer uses its final step's parameters and running statistics. When set to `False`, the
+        layers automatically track the current adaptation step.
+        """
+        x = self.features(x, inference=inference)
+        x = self.classifier(x)
+        return x