diff --git a/CHANGELOG.md b/CHANGELOG.md
index ca845a18..2bae22a4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,7 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
-* New vision example: MAML++. (@[DubiousCactus](https://github.com/DubiousCactus))
+* New GBML transforms: Per-Step & Per-Layer Per-Step learning rates from MAML++. ([Theo Morales](https://github.com/DubiousCactus))
+* New vision example: MAML++. ([Theo Morales](https://github.com/DubiousCactus))
 * Add tutorial: "Demystifying Task Transforms", ([Varad Pimpalkhute](https://github.com/nightlessbaron/))
 
 ### Changed
diff --git a/examples/vision/mamlpp/MAMLpp.py b/examples/vision/mamlpp/MAMLpp.py
deleted file mode 100644
index 533193c9..00000000
--- a/examples/vision/mamlpp/MAMLpp.py
+++ /dev/null
@@ -1,305 +0,0 @@
-#! /usr/bin/env python3
-# -*- coding: utf-8 -*-
-# vim:fenc=utf-8
-#
-
-"""
-MAML++ wrapper.
-"""
-
-import torch
-import traceback
-
-from torch.autograd import grad
-
-from learn2learn.algorithms.base_learner import BaseLearner
-from learn2learn.utils import clone_module, update_module, clone_named_parameters
-
-
-def maml_pp_update(model, step=None, lrs=None, grads=None):
-    """
-
-    **Description**
-
-    Performs a MAML++ update on model using grads and lrs.
-    The function re-routes the Python object, thus avoiding in-place
-    operations.
-
-    NOTE: The model itself is updated in-place (no deepcopy), but the
-          parameters' tensors are not.
-
-    **Arguments**
-
-    * **model** (Module) - The model to update.
-    * **lrs** (list) - The meta-learned learning rates used to update the model.
-    * **grads** (list, *optional*, default=None) - A list of gradients for each layer
-        of the model. If None, will use the gradients in .grad attributes.
-
-    **Example**
-    ~~~python
-    maml_pp = l2l.algorithms.MAMLpp(Model(), lr=1.0)
-    lslr = torch.nn.ParameterDict()
-    for layer_name, layer in model.named_modules():
-        # If the layer has learnable parameters
-        if (
-            len(
-                [
-                    name
-                    for name, param in layer.named_parameters(recurse=False)
-                    if param.requires_grad
-                ]
-            )
-            > 0
-        ):
-            lslr[layer_name.replace(".", "-")] = torch.nn.Parameter(
-                data=torch.ones(adaptation_steps) * init_lr,
-                requires_grad=True,
-            )
-    model = maml_pp.clone() # The next two lines essentially implement model.adapt(loss)
-    for inner_step in range(5):
-        loss = criterion(model(x), y)
-        grads = autograd.grad(loss, model.parameters(), create_graph=True)
-        maml_pp_update(model, inner_step, lrs=lslr, grads=grads)
-    ~~~
-    """
-    if grads is not None and lrs is not None:
-        params = list(model.parameters())
-        if not len(grads) == len(list(params)):
-            msg = "WARNING:maml_update(): Parameters and gradients have different length. ("
-            msg += str(len(params)) + " vs " + str(len(grads)) + ")"
-            print(msg)
-        # TODO: Why doesn't this work?? I can't assign p.grad when zipping like this... Is this
-        # because I'm using a tuple?
-        # for named_param, g in zip(
-            # [(k, v) for k, l in model.named_parameters() for v in l], grads
-        # ):
-            # p_name, p = named_param
-        it = 0
-        for name, p in model.named_parameters():
-            if grads[it] is not None:
-                lr = None
-                layer_name = name[: name.rfind(".")].replace(
-                    ".", "-"
-                )  # Extract the layer name from the named parameter
-                lr = lrs[layer_name][step]
-                assert (
-                    lr is not None
-                ), f"Parameter {name} does not have a learning rate in LSLR dict!"
-                p.grad = grads[it]
-                p._lr = lr
-            it += 1
-
-    # Update the params
-    for param_key in model._parameters:
-        p = model._parameters[param_key]
-        if p is not None and p.grad is not None:
-            model._parameters[param_key] = p - p._lr * p.grad
-            p.grad = None
-            p._lr = None
-
-    # Second, handle the buffers if necessary
-    for buffer_key in model._buffers:
-        buff = model._buffers[buffer_key]
-        if buff is not None and buff.grad is not None and buff._lr is not None:
-            model._buffers[buffer_key] = buff - buff._lr * buff.grad
-            buff.grad = None
-            buff._lr = None
-
-    # Then, recurse for each submodule
-    for module_key in model._modules:
-        model._modules[module_key] = maml_pp_update(model._modules[module_key])
-    return model
-
-
-class MAMLpp(BaseLearner):
-    """
-    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/maml.py)
-
-    **Description**
-
-    High-level implementation of *Model-Agnostic Meta-Learning*.
-
-    This class wraps an arbitrary nn.Module and augments it with `clone()` and `adapt()`
-    methods.
-
-    For the first-order version of MAML (i.e. FOMAML), set the `first_order` flag to `True`
-    upon initialization.
-
-    **Arguments**
-
-    * **model** (Module) - Module to be wrapped.
-    * **lr** (float) - Fast adaptation learning rate.
-    * **lslr** (bool) - Whether to use Per-Layer Per-Step Learning Rates and Gradient Directions
-        (LSLR) or not.
-    * **lrs** (list of Parameters, *optional*, default=None) - If not None, overrides `lr`, and uses the list
-        as learning rates for fast-adaptation.
-    * **first_order** (bool, *optional*, default=False) - Whether to use the first-order
-        approximation of MAML. (FOMAML)
-    * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation
-        of unused parameters. Defaults to `allow_nograd`.
-    * **allow_nograd** (bool, *optional*, default=False) - Whether to allow adaptation with
-        parameters that have `requires_grad = False`.
-
-    **References**
-
-    1. Finn et al. 2017. "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks."
-
-    **Example**
-
-    ~~~python
-    linear = l2l.algorithms.MAML(nn.Linear(20, 10), lr=0.01)
-    clone = linear.clone()
-    error = loss(clone(X), y)
-    clone.adapt(error)
-    error = loss(clone(X), y)
-    error.backward()
-    ~~~
-    """
-
-    def __init__(
-        self,
-        model,
-        lr,
-        lrs=None,
-        adaptation_steps=1,
-        first_order=False,
-        allow_unused=None,
-        allow_nograd=False,
-    ):
-        super().__init__()
-        self.module = model
-        self.lr = lr
-        if lrs is None:
-            lrs = self._init_lslr_parameters(model, adaptation_steps, lr)
-        self.lrs = lrs
-        self.first_order = first_order
-        self.allow_nograd = allow_nograd
-        if allow_unused is None:
-            allow_unused = allow_nograd
-        self.allow_unused = allow_unused
-
-    def _init_lslr_parameters(
-        self, model: torch.nn.Module, adaptation_steps: int, init_lr: float
-    ) -> torch.nn.ParameterDict:
-        lslr = torch.nn.ParameterDict()
-        for layer_name, layer in model.named_modules():
-            # If the layer has learnable parameters
-            if (
-                len(
-                    [
-                        name
-                        for name, param in layer.named_parameters(recurse=False)
-                        if param.requires_grad
-                    ]
-                )
-                > 0
-            ):
-                lslr[layer_name.replace(".", "-")] = torch.nn.Parameter(
-                    data=torch.ones(adaptation_steps) * init_lr,
-                    requires_grad=True,
-                )
-        return lslr
-
-    def forward(self, *args, **kwargs):
-        return self.module(*args, **kwargs)
-
-    def adapt(self, loss, step=None, first_order=None, allow_unused=None, allow_nograd=None):
-        """
-        **Description**
-
-        Takes a gradient step on the loss and updates the cloned parameters in place.
-
-        **Arguments**
-
-        * **loss** (Tensor) - Loss to minimize upon update.
-        * **step** (int) - Current inner loop step. Used to fetch the corresponding learning rate.
-        * **first_order** (bool, *optional*, default=None) - Whether to use first- or
-            second-order updates. Defaults to self.first_order.
-        * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation
-            of unused parameters. Defaults to self.allow_unused.
-        * **allow_nograd** (bool, *optional*, default=None) - Whether to allow adaptation with
-            parameters that have `requires_grad = False`. Defaults to self.allow_nograd.
-        """
-        if first_order is None:
-            first_order = self.first_order
-        if allow_unused is None:
-            allow_unused = self.allow_unused
-        if allow_nograd is None:
-            allow_nograd = self.allow_nograd
-        second_order = not first_order
-
-        gradients = []
-        if allow_nograd:
-            # Compute relevant gradients
-            diff_params = [p for p in self.module.parameters() if p.requires_grad]
-            grad_params = grad(
-                loss,
-                diff_params,
-                retain_graph=second_order,
-                create_graph=second_order,
-                allow_unused=allow_unused,
-            )
-            grad_counter = 0
-
-            # Handles gradients for non-differentiable parameters
-            for param in self.module.parameters():
-                if param.requires_grad:
-                    gradient = grad_params[grad_counter]
-                    grad_counter += 1
-                else:
-                    gradient = None
-                gradients.append(gradient)
-        else:
-            try:
-                gradients = grad(
-                    loss,
-                    self.module.parameters(),
-                    retain_graph=second_order,
-                    create_graph=second_order,
-                    allow_unused=allow_unused,
-                )
-            except RuntimeError:
-                traceback.print_exc()
-                print(
-                    "learn2learn: Maybe try with allow_nograd=True and/or allow_unused=True ?"
-                )
-
-        # Update the module
-        assert step is not None, "step cannot be None when using LSLR!"
-        self.module = maml_pp_update(self.module, step, lrs=self.lrs, grads=gradients)
-
-    def clone(self, first_order=None, allow_unused=None, allow_nograd=None):
-        """
-        **Description**
-
-        Returns a `MAMLpp`-wrapped copy of the module whose parameters and buffers
-        are `torch.clone`d from the original module.
-
-        This implies that back-propagating losses on the cloned module will
-        populate the buffers of the original module.
-        For more information, refer to learn2learn.clone_module().
-
-        **Arguments**
-
-        * **first_order** (bool, *optional*, default=None) - Whether the clone uses first-
-            or second-order updates. Defaults to self.first_order.
-        * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation
-        of unused parameters. Defaults to self.allow_unused.
-        * **allow_nograd** (bool, *optional*, default=False) - Whether to allow adaptation with
-            parameters that have `requires_grad = False`. Defaults to self.allow_nograd.
-
-        """
-        if first_order is None:
-            first_order = self.first_order
-        if allow_unused is None:
-            allow_unused = self.allow_unused
-        if allow_nograd is None:
-            allow_nograd = self.allow_nograd
-        return MAMLpp(
-            clone_module(self.module),
-            lr=self.lr,
-            lrs=clone_named_parameters(self.lrs),
-            first_order=first_order,
-            allow_unused=allow_unused,
-            allow_nograd=allow_nograd,
-        )
diff --git a/examples/vision/mamlpp/maml++_miniimagenet.py b/examples/vision/mamlpp/maml++_miniimagenet.py
index 78085bf9..a84f2445 100755
--- a/examples/vision/mamlpp/maml++_miniimagenet.py
+++ b/examples/vision/mamlpp/maml++_miniimagenet.py
@@ -19,9 +19,9 @@
 from collections import namedtuple
 from typing import Tuple
 from tqdm import tqdm
+from learn2learn.optim.transforms.layer_step_lr_transform import PerLayerPerStepLRTransform
 
-from examples.vision.mamlpp.cnn4_bnrs import CNN4_BNRS
-from examples.vision.mamlpp.MAMLpp import MAMLpp
+from learn2learn.vision.models.cnn4 import CNN4
 
 
 MetaBatch = namedtuple("MetaBatch", "support query")
@@ -38,9 +38,9 @@ def accuracy(predictions, targets):
 class MAMLppTrainer:
     def __init__(
         self,
-        ways=5,
-        k_shots=10,
-        n_queries=30,
+        ways=20,
+        k_shots=5,
+        n_queries=5,
         steps=5,
         msl_epochs=25,
         DA_epochs=50,
@@ -52,6 +52,7 @@ def __init__(
         if self._use_cuda and torch.cuda.device_count():
             torch.cuda.manual_seed(seed)
             self._device = torch.device("cuda")
+        print(f"[*] Using device: {self._device}")
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
@@ -64,15 +65,16 @@ def __init__(
             self._test_tasks,
         ) = l2l.vision.benchmarks.get_tasksets(
             "mini-imagenet",
-            train_samples=k_shots,
+            train_samples=k_shots+n_queries,
             train_ways=ways,
-            test_samples=n_queries,
+            test_samples=k_shots+n_queries,
             test_ways=ways,
             root="~/data",
         )
+        print("[*] Done.")
 
         # Model
-        self._model = CNN4_BNRS(ways, adaptation_steps=steps)
+        self._model = CNN4(ways) # TODO: Change config for miniImageNet (32 filters ?)
         if self._use_cuda:
             self._model.cuda()
 
@@ -109,9 +111,9 @@ def _split_batch(self, batch: tuple) -> MetaBatch:
         Separate data batch into adaptation/evalutation sets.
         """
         images, labels = batch
-        batch_size = self._k_shots + self._n_queries
-        assert batch_size <= images.shape[0], "K+N are greater than the batch size!"
-        indices = torch.randperm(batch_size)
+        task_size = self._k_shots + self._n_queries
+        assert task_size <= images.shape[0], "K+N is greater than the batch size!"
+        indices = torch.randperm(task_size)
         support_indices = indices[: self._k_shots]
         query_indices = indices[self._k_shots :]
         return MetaBatch(
@@ -125,7 +127,7 @@ def _split_batch(self, batch: tuple) -> MetaBatch:
     def _training_step(
         self,
         batch: MetaBatch,
-        learner: MAMLpp,
+        learner: torch.nn.Module,
         msl: bool = True,
         epoch: int = 0,
     ) -> Tuple[torch.Tensor, float]:
@@ -147,26 +149,26 @@ def _training_step(
         # Adapt the model on the support set
         for step in range(self._steps):
             # forward + backward + optimize
-            pred = learner(s_inputs, step)
+            pred = learner(s_inputs)
             support_loss = self._inner_criterion(pred, s_labels)
-            learner.adapt(support_loss, first_order=not second_order, step=step)
+            learner.adapt(support_loss, first_order=not second_order)
             # Multi-Step Loss
             if msl:
-                q_pred = learner(q_inputs, step)
+                q_pred = learner(q_inputs)
                 query_loss += self._step_weights[step] * self._inner_criterion(
                     q_pred, q_labels
                 )
 
         # Evaluate the adapted model on the query set
         if not msl:
-            q_pred = learner(q_inputs, self._steps-1)
+            q_pred = learner(q_inputs)
             query_loss = self._inner_criterion(q_pred, q_labels)
         acc = accuracy(q_pred, q_labels).detach()
 
         return query_loss, acc
 
     def _testing_step(
-        self, batch: MetaBatch, learner: MAMLpp
+        self, batch: MetaBatch, learner: torch.nn.Module
     ) -> Tuple[torch.Tensor, float]:
         s_inputs, s_labels = batch.support
         q_inputs, q_labels = batch.query
@@ -180,12 +182,12 @@ def _testing_step(
         # Adapt the model on the support set
         for step in range(self._steps):
             # forward + backward + optimize
-            pred = learner(s_inputs, step)
+            pred = learner(s_inputs)
             support_loss = self._inner_criterion(pred, s_labels)
-            learner.adapt(support_loss, step=step)
+            learner.adapt(support_loss)
 
         # Evaluate the adapted model on the query set
-        q_pred = learner(q_inputs, self._steps-1)
+        q_pred = learner(q_inputs)
         query_loss = self._inner_criterion(q_pred, q_labels).detach()
         acc = accuracy(q_pred, q_labels)
 
@@ -195,23 +197,29 @@ def train(
         self,
         meta_lr=0.001,
         fast_lr=0.01,
-        meta_bsz=5,
-        epochs=100,
+        meta_bsz=16,
+        epochs=1,
         val_interval=1,
     ):
         print("[*] Training...")
-        maml = MAMLpp(
+        transform = PerLayerPerStepLRTransform(fast_lr, self._steps, self._model, ["conv"])
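+        # Only layers named "conv" get a learned per-step LR (see the layer_names argument of PerLayerPerStepLRTransform).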
+        # Setting adapt_transform=True means that the transform will be updated in
+        # the *adapt* function, which is not what we want. We want it to compute gradients during
+        # eval_loss.backward() only, so that it's updated in opt.step().
+        mamlpp = l2l.algorithms.GBML(
             self._model,
-            lr=fast_lr, # Initialisation LR for all layers and steps
-            adaptation_steps=self._steps, # For LSLR
-            first_order=False,
-            allow_nograd=True, # For the parameters of the MetaBatchNorm layers
+            transform,
+            lr=1.0,
+            allow_nograd=True,
+            adapt_transform=False,
+            pass_param_names=True,
         )
-        opt = torch.optim.AdamW(maml.parameters(), meta_lr, betas=(0, 0.999))
+        opt = torch.optim.AdamW(mamlpp.parameters(), meta_lr, betas=(0.9, 0.99))
 
         iter_per_epoch = (
             train_samples // (meta_bsz * (self._k_shots + self._n_queries))
         ) + 1
+        print(f"[*] Training with {iter_per_epoch} iterations/epoch with {train_samples} total training samples")
         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
             opt,
             T_max=epochs * iter_per_epoch,
@@ -228,7 +236,7 @@ def train(
                     meta_batch = self._split_batch(self._train_tasks.sample())
                     meta_loss, meta_acc = self._training_step(
                         meta_batch,
-                        maml.clone(),
+                        mamlpp.clone(),
                         msl=(epoch < self._msl_epochs),
                         epoch=epoch,
                     )
@@ -241,7 +249,7 @@ def train(
 
                 # Average the accumulated gradients and optimize
                 with torch.no_grad():
-                    for p in maml.parameters():
+                    for p in mamlpp.parameters():
                         # Remember the MetaBatchNorm layer has parameters that don't require grad!
                         if p.requires_grad:
                             p.grad.data.mul_(1.0 / meta_bsz)
@@ -259,49 +267,50 @@ def train(
 
             # ======= Validation ========
             if (epoch + 1) % val_interval == 0:
-                # Backup the BatchNorm layers' running statistics
-                maml.backup_stats()
-
                 # Compute the meta-validation loss
                 # TODO: Go through the entire validation set, which shouldn't be shuffled, and
                 # which tasks should not be continuously resampled from!
+                # This may be done in the get_tasksets() method actually...
                 meta_val_losses, meta_val_accs = [], []
                 for _ in tqdm(range(val_samples // tasks)):
                     meta_batch = self._split_batch(self._valid_tasks.sample())
-                    loss, acc = self._testing_step(meta_batch, maml.clone())
+                    loss, acc = self._testing_step(meta_batch, mamlpp.clone())
                     meta_val_losses.append(loss)
                     meta_val_accs.append(acc)
                 meta_val_loss = float(torch.Tensor(meta_val_losses).mean().item())
                 meta_val_acc = float(torch.Tensor(meta_val_accs).mean().item())
                 print(f"Meta-validation Loss: {meta_val_loss:.6f}")
                 print(f"Meta-validation Accuracy: {meta_val_acc:.6f}")
-                # Restore the BatchNorm layers' running statistics
-                maml.restore_backup_stats()
             print("============================================")
 
-        return self._model.state_dict()
+        return self._model.state_dict(), transform.state_dict()
 
     def test(
         self,
         model_state_dict,
+        transform_state_dict,
         meta_lr=0.001,
         fast_lr=0.01,
         meta_bsz=5,
     ):
         self._model.load_state_dict(model_state_dict)
-        maml = MAMLpp(
+        transform = PerLayerPerStepLRTransform(fast_lr, self._steps, self._model, ["conv"])
+        transform.load_state_dict(transform_state_dict)
+        # Setting adapt_transform=True means that the transform will be updated in
+        # the *adapt* function, which is not what we want. We want it to compute gradients during
+        # eval_loss.backward() only, so that it's updated in opt.step().
+        mamlpp = l2l.algorithms.GBML(
             self._model,
-            lr=fast_lr,
-            adaptation_steps=self._steps,
-            first_order=False,
-            allow_nograd=True,
+            transform,
+            lr=1.0,
+            adapt_transform=False,
+            pass_param_names=True,
         )
-        opt = torch.optim.AdamW(maml.parameters(), meta_lr, betas=(0, 0.999))
 
         meta_losses, meta_accs = [], []
         for _ in tqdm(range(test_samples // tasks)):
             meta_batch = self._split_batch(self._test_tasks.sample())
-            loss, acc = self._testing_step(meta_batch, maml.clone())
+            loss, acc = self._testing_step(meta_batch, mamlpp.clone())
             meta_losses.append(loss)
             meta_accs.append(acc)
         loss = float(torch.Tensor(meta_losses).mean().item())
@@ -312,5 +321,6 @@ def test(
 
 if __name__ == "__main__":
     mamlPlusPlus = MAMLppTrainer()
-    model = mamlPlusPlus.train()
-    mamlPlusPlus.test(model)
+    model_state_dict, transform_state_dict = mamlPlusPlus.train()
+    mamlPlusPlus.test(model_state_dict, transform_state_dict)
+
diff --git a/learn2learn/algorithms/gbml.py b/learn2learn/algorithms/gbml.py
index 24d230f1..a1bf61bb 100644
--- a/learn2learn/algorithms/gbml.py
+++ b/learn2learn/algorithms/gbml.py
@@ -27,6 +27,8 @@ class GBML(torch.nn.Module):
     * **lr** (float) - Fast adaptation learning rate.
     * **adapt_transform** (bool, *optional*, default=False) - Whether to update the transform's
         parameters during fast-adaptation.
+    * **pass_param_names** (bool, *optional*, default=False) - Whether to pass the parameters'
+        names to the transform.
     * **first_order** (bool, *optional*, default=False) - Whether to use the first-order
         approximation.
     * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation
@@ -73,6 +75,7 @@ def __init__(
             transform,
             lr=1.0,
             adapt_transform=False,
+            pass_param_names=False,
             first_order=False,
             allow_unused=False,
             allow_nograd=False,
@@ -90,8 +93,9 @@ def __init__(
             self.compute_update = kwargs.get('compute_update')
         else:
             self.compute_update = l2l.optim.ParameterUpdate(
-                    parameters=self.module.parameters(),
+                    parameters=self.module.named_parameters(),
                     transform=transform,
+                    pass_param_names=pass_param_names,
             )
         self.diff_sgd = l2l.optim.DifferentiableSGD(lr=self.lr)
         # Whether the module params have already been updated with the
diff --git a/learn2learn/algorithms/maml.py b/learn2learn/algorithms/maml.py
index 08fa2638..70876011 100644
--- a/learn2learn/algorithms/maml.py
+++ b/learn2learn/algorithms/maml.py
@@ -102,6 +102,7 @@ def __init__(self,
         if allow_unused is None:
             allow_unused = allow_nograd
         self.allow_unused = allow_unused
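+        # Stored as an attribute so that subclasses can swap in a custom update rule.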
+        self.update_func = maml_update
 
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
@@ -166,7 +167,7 @@ def adapt(self,
                 print('learn2learn: Maybe try with allow_nograd=True and/or allow_unused=True ?')
 
         # Update the module
-        self.module = maml_update(self.module, self.lr, gradients)
+        self.module = self.update_func(self.module, self.lr, gradients)
 
     def clone(self, first_order=None, allow_unused=None, allow_nograd=None):
         """
diff --git a/learn2learn/algorithms/meta_sgd.py b/learn2learn/algorithms/meta_sgd.py
index bf0002e5..9c2373ea 100644
--- a/learn2learn/algorithms/meta_sgd.py
+++ b/learn2learn/algorithms/meta_sgd.py
@@ -109,6 +109,7 @@ def __init__(self, model, lr=1.0, first_order=False, lrs=None):
             lrs = nn.ParameterList([nn.Parameter(lr) for lr in lrs])
         self.lrs = lrs
         self.first_order = first_order
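+        # Stored as an attribute so that the update rule can be overridden.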
+        self.update_func = meta_sgd_update
 
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
@@ -138,7 +139,7 @@ def adapt(self, loss, first_order=None):
                          self.module.parameters(),
                          retain_graph=second_order,
                          create_graph=second_order)
-        self.module = meta_sgd_update(self.module, self.lrs, gradients)
+        self.module = self.update_func(self.module, self.lrs, gradients)
 
 
 if __name__ == '__main__':
diff --git a/learn2learn/optim/parameter_update.py b/learn2learn/optim/parameter_update.py
index 3ad46e83..bf642577 100644
--- a/learn2learn/optim/parameter_update.py
+++ b/learn2learn/optim/parameter_update.py
@@ -26,6 +26,8 @@ class ParameterUpdate(torch.nn.Module):
     * **parameters** (list) - Parameters of the model to update.
     * **transform** (callable) - A callable that returns an instantiated
         transform given a parameter.
+    * **pass_param_names** (bool, *optional*, default=False) - Whether to pass the parameters'
+        names to the transform.
 
     **Example**
     ~~~python
@@ -47,13 +49,13 @@ class ParameterUpdate(torch.nn.Module):
     ~~~
     """
 
-    def __init__(self, parameters, transform):
+    def __init__(self, parameters, transform, pass_param_names=False):
         super(ParameterUpdate, self).__init__()
         transforms_indices = []
         transform_modules = []
         module_counter = 0
-        for param in parameters:
-            t = transform(param)
+        for name, param in parameters:
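+            # With pass_param_names=True the transform also receives the parameter's name, so it can map parameters to layers.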
+            t = transform(param) if not pass_param_names else transform(name, param)
             if t is None:
                 idx = None
             elif isinstance(t, torch.nn.Module):
diff --git a/learn2learn/optim/transforms/__init__.py b/learn2learn/optim/transforms/__init__.py
index 8d62b45d..1a6c52ca 100644
--- a/learn2learn/optim/transforms/__init__.py
+++ b/learn2learn/optim/transforms/__init__.py
@@ -7,6 +7,7 @@
 gradient descent, allow you to learn optimization functions from data.
 """
 
+from .layer_step_lr_transform import PerStepLRTransform, PerLayerPerStepLRTransform
 from .module_transform import ModuleTransform, ReshapedTransform
 from .kronecker_transform import KroneckerTransform
 from .transform_dictionary import TransformDictionary
diff --git a/learn2learn/optim/transforms/layer_step_lr_transform.py b/learn2learn/optim/transforms/layer_step_lr_transform.py
new file mode 100644
index 00000000..6d450e82
--- /dev/null
+++ b/learn2learn/optim/transforms/layer_step_lr_transform.py
@@ -0,0 +1,271 @@
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+
+"""
+Per-Step and Per-Layer Per-Step Learning Rate transforms for the GBML algorithm.
+"""
+
+from typing import Any, Dict, Optional
+import learn2learn as l2l
+import numpy as np
+import random
+import torch
+
+
+class PerStepLR(torch.nn.Module):
+    def __init__(self, init_lr: float, steps: int):
+        super().__init__()
+        self.lrs = torch.nn.Parameter(
+            data=torch.ones(steps) * init_lr,
+            requires_grad=True,
+        )
+        self._current_step = 0
+        self._steps = steps
+
+    def forward(self, grad):
+        # The update is positive because it is applied as `grad.mul(-self.lr)` in
+        # DifferentiableSGD of the GBML, where lr=1.
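+        # Scale the gradient by the learning rate learned for the current adaptation step.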
+        updates = self.lrs[self._current_step] * grad
+        self._current_step = (
+            self._current_step + 1 if self._current_step < (self._steps - 1) else 0
+        )  # wrap back to 0 after the last step so the index never overflows
+        return updates
+
+    def __str__(self):
+        return str(self.lrs)
+
+
+class PerLayerPerStepLRTransform:
+    """
+    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/optim/transforms/layer_step_lr_transform.py)
+
+    **Description**
+
+    The PerLayerPerStepLRTransform creates a per-step transform for each layer of a given module.
+
+    This can be used with the GBML algorithm to reproduce the *LSLR* improvement of MAML++ proposed by
+    Antoniou et al.
+
+    **Arguments**
+
+    * **init_lr** (float) - The initial learning rate for each adaptation step and layer.
+    * **steps** (int) - The number of adaptation steps.
+    * **model** (torch.nn.Module) - The module being updated with the learning rates. This is
+        needed to define the learning rates for each layer.
+    * **layer_names** (List[str], *optional*, default=None) - If not None, only layers whose name's
+        last component (e.g. "conv" in "features.0.conv") matches one of the list elements,
+        case-insensitively, will have a per-step learning rate; otherwise, all layers will have one.
+        Specifying layer names can avoid spending extra learnable parameters on redundant layers,
+        such as a "BatchNorm" layer followed by a "Conv" layer.
+
+    **Example**
+    ~~~python
+    model = torch.nn.Sequential(
+        torch.nn.Linear(128, 24), torch.nn.Linear(24, 16), torch.nn.Linear(16, 10)
+    )
+    transform = PerLayerPerStepLRTransform(1e-3, N_STEPS, model)  # optionally pass layer_names, e.g. ["conv"]
+    metamodel = l2l.algorithms.GBML(
+        model,
+        transform,
+        allow_nograd=True,
+        lr=0.001,
+        adapt_transform=False,
+        pass_param_names=True,  # This is needed for this transform to find the module's layers
+    )
+    opt = torch.optim.Adam(metamodel.parameters(), lr=1.0)
+    ~~~
+    """
+
+    def __init__(self, init_lr, steps, model, layer_names=None):
+        self._lslr = {}
+        for layer_name, layer in model.named_modules():
+            # If the layer has learnable parameters
+            if (
+                len(
+                    [
+                        name
+                        for name, param in layer.named_parameters(recurse=False)
+                        if param.requires_grad
+                    ]
+                )
+                > 0
+            ):
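+                # Match on the last component of the module path, e.g. "features.0.conv" -> "conv".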
+                if layer_names is None or layer_name.lower().split(".")[-1] in [
+                    name.lower() for name in layer_names
+                ]:
+                    self._lslr[layer_name] = PerStepLR(init_lr, steps)
+                else:
+                    self._lslr[layer_name] = None
+
+    def load_state_dict(self, lr_state_dicts: Dict[str, Dict[str, Any]]):
+        assert (
+            type(lr_state_dicts) is dict
+        ), "Argument lr_state_dicts must be a dictionary!"
+        for layer_name, state_dict in lr_state_dicts.items():
+            if self._lslr[layer_name] is not None:
+                self._lslr[layer_name].load_state_dict(state_dict)
+
+    def state_dict(self) -> Dict[str, Optional[Dict[str, Any]]]:
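+        # Map each layer name to its PerStepLR state_dict, or None for layers without a learned LR.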
+        return {
+            layer_name: (pslr.state_dict() if pslr is not None else None) for layer_name, pslr in self._lslr.items()
+        }
+
+    def __call__(self, name, param):
+        name = name[: name.rfind(".")]  # Extract the layer name from the named parameter
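+        # e.g. "features.0.conv.weight" -> "features.0.conv"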
+        assert name in self._lslr, f"No matching LR found for layer {name}!"
+        return self._lslr[name]
+
+    def __str__(self):
+        string = ""
+        for layer, lslr in self._lslr.items():
+            string += f"Layer {layer}: {lslr}\n"
+        return string
+
+
+class PerStepLRTransform:
+    """
+    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/optim/transforms/layer_step_lr_transform.py)
+
+    **Description**
+
+    The PerStepLRTransform creates a per-step transform for inner-loop-based algorithms.
+
+    This can be used with the GBML algorithm to reproduce the *LSLR* improvement of MAML++ proposed by
+    Antoniou et al., with the same learning rates for all layers.
+
+    **Arguments**
+
+    * **init_lr** (float) - The initial learning rate for each adaptation step.
+    * **steps** (int) - The number of adaptation steps.
+
+    **Example**
+    ~~~python
+    model = torch.nn.Linear(128, 10)
+    transform = PerStepLRTransform(1e-3, N_STEPS)
+    metamodel = l2l.algorithms.GBML(
+        model,
+        transform,
+        allow_nograd=True,
+        lr=0.001,
+        adapt_transform=False,
+    )
+    opt = torch.optim.Adam(metamodel.parameters(), lr=1.0)
+    ~~~
+    """
+
+    def __init__(self, init_lr, steps):
+        self._obj = PerStepLR(init_lr, steps)
+
+    def __call__(self, param):
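+        # Every parameter shares the same PerStepLR module: one learned LR per adaptation step.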
+        return self._obj
+
+    def __str__(self):
+        return str(self._obj)
+
+    def parameters(self):
+        return self._obj.parameters()
+
+    def load_state_dict(self, state_dict: Dict[str, Any]):
+        self._obj.load_state_dict(state_dict)
+
+    def state_dict(self) -> Dict[str, Any]:
+        return self._obj.state_dict()
+
+
+if __name__ == "__main__":
+
+    random.seed(1234)
+    np.random.seed(1234)
+    torch.manual_seed(1234)
+
+    N_DIMS = 32
+    N_SAMPLES = 128
+    N_STEPS = 5
+    device = torch.device("cpu")
+
+    print("[*] Testing per-step LR with one linear layer")
+    model = torch.nn.Linear(N_DIMS, 10)
+    transform = PerStepLRTransform(1e-3, N_STEPS)
+    # Setting adapt_transform=True means that the transform will be updated in
+    # the *adapt* function, which is not what we want. We want it to compute gradients during
+    # eval_loss.backward() only, so that it's updated in opt.step().
+    metamodel = l2l.algorithms.GBML(
+        model,
+        transform,
+        lr=1.0,
+        adapt_transform=False,
+        allow_nograd=True,
+    )
+    opt = torch.optim.Adam(metamodel.parameters(), lr=1.0)
+    print("\nPre-learning")
+    print("Transform parameters: ", transform)
+    for name, p in metamodel.named_parameters():
+        print(name, ":", p.norm())
+
+    for task in range(10):
+        opt.zero_grad()
+        learner = metamodel.clone()
+        X = torch.randn(N_SAMPLES, N_DIMS)
+
+        # fast adapt
+        for step in range(N_STEPS):
+            adapt_loss = learner(X).norm(2)
+            learner.adapt(adapt_loss)
+
+        # meta-learn
+        eval_loss = learner(X).norm(2)
+        eval_loss.backward()
+        opt.step()
+
+    print("\nPost-learning")
+    print("Transform parameters: ", transform)
+    for name, p in metamodel.named_parameters():
+        print(name, ":", p.norm())
+
+    print("Transform state_dict: ", transform.state_dict())
+
+    print("\n\n--------------------------")
+    print("[*] Testing per-layer per-step LR with three linear layers")
+    model = torch.nn.Sequential(
+        torch.nn.Linear(N_DIMS, 24), torch.nn.Linear(24, 16), torch.nn.Linear(16, 10)
+    )
+    transform = PerLayerPerStepLRTransform(1e-3, N_STEPS, model)
+    # Setting adapt_transform=True means that the transform will be updated in
+    # the *adapt* function, which is not what we want. We want it to compute gradients during
+    # eval_loss.backward() only, so that it's updated in opt.step().
+    metamodel = l2l.algorithms.GBML(
+        model,
+        transform,
+        lr=1.0,
+        adapt_transform=False,
+        pass_param_names=True,
+        allow_nograd=True,
+    )
+    opt = torch.optim.Adam(metamodel.parameters(), lr=1.0)
+    print("\nPre-learning")
+    print("Transform parameters: ", transform)
+    for name, p in metamodel.named_parameters():
+        print(name, ":", p.norm())
+
+    for task in range(10):
+        opt.zero_grad()
+        learner = metamodel.clone()
+        X = torch.randn(N_SAMPLES, N_DIMS)
+
+        # fast adapt
+        for step in range(N_STEPS):
+            adapt_loss = learner(X).norm(2)
+            learner.adapt(adapt_loss)
+
+        # meta-learn
+        eval_loss = learner(X).norm(2)
+        eval_loss.backward()
+        opt.step()
+
+    print("\nPost-learning")
+    print("Transform parameters: ", transform)
+    for name, p in metamodel.named_parameters():
+        print(name, ":", p.norm())
+
+    print("Transform state_dict: ", transform.state_dict())