diff --git a/CHANGELOG.md b/CHANGELOG.md index ca845a18..2bae22a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -* New vision example: MAML++. (@[DubiousCactus](https://github.com/DubiousCactus)) +* New GBML transforms: Per-Step & Per-Layer Per-Step learning rates from MAML++. ([Theo Morales](https://github.com/DubiousCactus)) +* New vision example: MAML++. ([Theo Morales](https://github.com/DubiousCactus)) * Add tutorial: "Demystifying Task Transforms", ([Varad Pimpalkhute](https://github.com/nightlessbaron/)) ### Changed diff --git a/examples/vision/mamlpp/MAMLpp.py b/examples/vision/mamlpp/MAMLpp.py deleted file mode 100644 index 533193c9..00000000 --- a/examples/vision/mamlpp/MAMLpp.py +++ /dev/null @@ -1,305 +0,0 @@ -#! /usr/bin/env python3 -# -*- coding: utf-8 -*- -# vim:fenc=utf-8 -# - -""" -MAML++ wrapper. -""" - -import torch -import traceback - -from torch.autograd import grad - -from learn2learn.algorithms.base_learner import BaseLearner -from learn2learn.utils import clone_module, update_module, clone_named_parameters - - -def maml_pp_update(model, step=None, lrs=None, grads=None): - """ - - **Description** - - Performs a MAML++ update on model using grads and lrs. - The function re-routes the Python object, thus avoiding in-place - operations. - - NOTE: The model itself is updated in-place (no deepcopy), but the - parameters' tensors are not. - - **Arguments** - - * **model** (Module) - The model to update. - * **lrs** (list) - The meta-learned learning rates used to update the model. - * **grads** (list, *optional*, default=None) - A list of gradients for each layer - of the model. If None, will use the gradients in .grad attributes. - - **Example** - ~~~python - maml_pp = l2l.algorithms.MAMLpp(Model(), lr=1.0) - lslr = torch.nn.ParameterDict() - for layer_name, layer in model.named_modules(): - # If the layer has learnable parameters - if ( - len( - [ - name - for name, param in layer.named_parameters(recurse=False) - if param.requires_grad - ] - ) - > 0 - ): - lslr[layer_name.replace(".", "-")] = torch.nn.Parameter( - data=torch.ones(adaptation_steps) * init_lr, - requires_grad=True, - ) - model = maml_pp.clone() # The next two lines essentially implement model.adapt(loss) - for inner_step in range(5): - loss = criterion(model(x), y) - grads = autograd.grad(loss, model.parameters(), create_graph=True) - maml_pp_update(model, inner_step, lrs=lslr, grads=grads) - ~~~ - """ - if grads is not None and lrs is not None: - params = list(model.parameters()) - if not len(grads) == len(list(params)): - msg = "WARNING:maml_update(): Parameters and gradients have different length. (" - msg += str(len(params)) + " vs " + str(len(grads)) + ")" - print(msg) - # TODO: Why doesn't this work?? I can't assign p.grad when zipping like this... Is this - # because I'm using a tuple? - # for named_param, g in zip( - # [(k, v) for k, l in model.named_parameters() for v in l], grads - # ): - # p_name, p = named_param - it = 0 - for name, p in model.named_parameters(): - if grads[it] is not None: - lr = None - layer_name = name[: name.rfind(".")].replace( - ".", "-" - ) # Extract the layer name from the named parameter - lr = lrs[layer_name][step] - assert ( - lr is not None - ), f"Parameter {name} does not have a learning rate in LSLR dict!" 
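For context, here is a minimal standalone sketch (with illustrative layer and parameter names) of the lookup the deleted wrapper performs above: each parameter name is truncated to its parent layer and used as a key into a ParameterDict of per-step learning rates. The PerLayerPerStepLRTransform introduced later in this diff reimplements the same mapping through parameter names.

~~~python
# Illustrative sketch of the per-layer, per-step LR lookup (names are made up).
import torch

adaptation_steps, init_lr = 5, 0.01
lslr = torch.nn.ParameterDict({
    # ParameterDict keys cannot contain ".", hence the "-" substitution.
    "features-0-conv": torch.nn.Parameter(torch.ones(adaptation_steps) * init_lr),
    "classifier": torch.nn.Parameter(torch.ones(adaptation_steps) * init_lr),
})

def lr_for(param_name: str, step: int) -> torch.Tensor:
    # "features.0.conv.weight" -> "features-0-conv"
    layer_key = param_name[: param_name.rfind(".")].replace(".", "-")
    return lslr[layer_key][step]

print(lr_for("features.0.conv.weight", step=2))  # one scalar LR per layer and step
~~~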
- p.grad = grads[it] - p._lr = lr - it += 1 - - # Update the params - for param_key in model._parameters: - p = model._parameters[param_key] - if p is not None and p.grad is not None: - model._parameters[param_key] = p - p._lr * p.grad - p.grad = None - p._lr = None - - # Second, handle the buffers if necessary - for buffer_key in model._buffers: - buff = model._buffers[buffer_key] - if buff is not None and buff.grad is not None and buff._lr is not None: - model._buffers[buffer_key] = buff - buff._lr * buff.grad - buff.grad = None - buff._lr = None - - # Then, recurse for each submodule - for module_key in model._modules: - model._modules[module_key] = maml_pp_update(model._modules[module_key]) - return model - - -class MAMLpp(BaseLearner): - """ - [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/maml.py) - - **Description** - - High-level implementation of *Model-Agnostic Meta-Learning*. - - This class wraps an arbitrary nn.Module and augments it with `clone()` and `adapt()` - methods. - - For the first-order version of MAML (i.e. FOMAML), set the `first_order` flag to `True` - upon initialization. - - **Arguments** - - * **model** (Module) - Module to be wrapped. - * **lr** (float) - Fast adaptation learning rate. - * **lslr** (bool) - Whether to use Per-Layer Per-Step Learning Rates and Gradient Directions - (LSLR) or not. - * **lrs** (list of Parameters, *optional*, default=None) - If not None, overrides `lr`, and uses the list - as learning rates for fast-adaptation. - * **first_order** (bool, *optional*, default=False) - Whether to use the first-order - approximation of MAML. (FOMAML) - * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation - of unused parameters. Defaults to `allow_nograd`. - * **allow_nograd** (bool, *optional*, default=False) - Whether to allow adaptation with - parameters that have `requires_grad = False`. - - **References** - - 1. Finn et al. 2017. "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks." - - **Example** - - ~~~python - linear = l2l.algorithms.MAML(nn.Linear(20, 10), lr=0.01) - clone = linear.clone() - error = loss(clone(X), y) - clone.adapt(error) - error = loss(clone(X), y) - error.backward() - ~~~ - """ - - def __init__( - self, - model, - lr, - lrs=None, - adaptation_steps=1, - first_order=False, - allow_unused=None, - allow_nograd=False, - ): - super().__init__() - self.module = model - self.lr = lr - if lrs is None: - lrs = self._init_lslr_parameters(model, adaptation_steps, lr) - self.lrs = lrs - self.first_order = first_order - self.allow_nograd = allow_nograd - if allow_unused is None: - allow_unused = allow_nograd - self.allow_unused = allow_unused - - def _init_lslr_parameters( - self, model: torch.nn.Module, adaptation_steps: int, init_lr: float - ) -> torch.nn.ParameterDict: - lslr = torch.nn.ParameterDict() - for layer_name, layer in model.named_modules(): - # If the layer has learnable parameters - if ( - len( - [ - name - for name, param in layer.named_parameters(recurse=False) - if param.requires_grad - ] - ) - > 0 - ): - lslr[layer_name.replace(".", "-")] = torch.nn.Parameter( - data=torch.ones(adaptation_steps) * init_lr, - requires_grad=True, - ) - return lslr - - def forward(self, *args, **kwargs): - return self.module(*args, **kwargs) - - def adapt(self, loss, step=None, first_order=None, allow_unused=None, allow_nograd=None): - """ - **Description** - - Takes a gradient step on the loss and updates the cloned parameters in place. 
- - **Arguments** - - * **loss** (Tensor) - Loss to minimize upon update. - * **step** (int) - Current inner loop step. Used to fetch the corresponding learning rate. - * **first_order** (bool, *optional*, default=None) - Whether to use first- or - second-order updates. Defaults to self.first_order. - * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation - of unused parameters. Defaults to self.allow_unused. - * **allow_nograd** (bool, *optional*, default=None) - Whether to allow adaptation with - parameters that have `requires_grad = False`. Defaults to self.allow_nograd. - """ - if first_order is None: - first_order = self.first_order - if allow_unused is None: - allow_unused = self.allow_unused - if allow_nograd is None: - allow_nograd = self.allow_nograd - second_order = not first_order - - gradients = [] - if allow_nograd: - # Compute relevant gradients - diff_params = [p for p in self.module.parameters() if p.requires_grad] - grad_params = grad( - loss, - diff_params, - retain_graph=second_order, - create_graph=second_order, - allow_unused=allow_unused, - ) - grad_counter = 0 - - # Handles gradients for non-differentiable parameters - for param in self.module.parameters(): - if param.requires_grad: - gradient = grad_params[grad_counter] - grad_counter += 1 - else: - gradient = None - gradients.append(gradient) - else: - try: - gradients = grad( - loss, - self.module.parameters(), - retain_graph=second_order, - create_graph=second_order, - allow_unused=allow_unused, - ) - except RuntimeError: - traceback.print_exc() - print( - "learn2learn: Maybe try with allow_nograd=True and/or allow_unused=True ?" - ) - - # Update the module - assert step is not None, "step cannot be None when using LSLR!" - self.module = maml_pp_update(self.module, step, lrs=self.lrs, grads=gradients) - - def clone(self, first_order=None, allow_unused=None, allow_nograd=None): - """ - **Description** - - Returns a `MAMLpp`-wrapped copy of the module whose parameters and buffers - are `torch.clone`d from the original module. - - This implies that back-propagating losses on the cloned module will - populate the buffers of the original module. - For more information, refer to learn2learn.clone_module(). - - **Arguments** - - * **first_order** (bool, *optional*, default=None) - Whether the clone uses first- - or second-order updates. Defaults to self.first_order. - * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation - of unused parameters. Defaults to self.allow_unused. - * **allow_nograd** (bool, *optional*, default=False) - Whether to allow adaptation with - parameters that have `requires_grad = False`. Defaults to self.allow_nograd. 
- - """ - if first_order is None: - first_order = self.first_order - if allow_unused is None: - allow_unused = self.allow_unused - if allow_nograd is None: - allow_nograd = self.allow_nograd - return MAMLpp( - clone_module(self.module), - lr=self.lr, - lrs=clone_named_parameters(self.lrs), - first_order=first_order, - allow_unused=allow_unused, - allow_nograd=allow_nograd, - ) diff --git a/examples/vision/mamlpp/maml++_miniimagenet.py b/examples/vision/mamlpp/maml++_miniimagenet.py index 78085bf9..a84f2445 100755 --- a/examples/vision/mamlpp/maml++_miniimagenet.py +++ b/examples/vision/mamlpp/maml++_miniimagenet.py @@ -19,9 +19,9 @@ from collections import namedtuple from typing import Tuple from tqdm import tqdm +from learn2learn.optim.transforms.layer_step_lr_transform import PerLayerPerStepLRTransform -from examples.vision.mamlpp.cnn4_bnrs import CNN4_BNRS -from examples.vision.mamlpp.MAMLpp import MAMLpp +from learn2learn.vision.models.cnn4 import CNN4 MetaBatch = namedtuple("MetaBatch", "support query") @@ -38,9 +38,9 @@ def accuracy(predictions, targets): class MAMLppTrainer: def __init__( self, - ways=5, - k_shots=10, - n_queries=30, + ways=20, + k_shots=5, + n_queries=5, steps=5, msl_epochs=25, DA_epochs=50, @@ -52,6 +52,7 @@ def __init__( if self._use_cuda and torch.cuda.device_count(): torch.cuda.manual_seed(seed) self._device = torch.device("cuda") + print(f"[*] Using device: {self._device}") random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) @@ -64,15 +65,16 @@ def __init__( self._test_tasks, ) = l2l.vision.benchmarks.get_tasksets( "mini-imagenet", - train_samples=k_shots, + train_samples=k_shots+n_queries, train_ways=ways, - test_samples=n_queries, + test_samples=k_shots+n_queries, test_ways=ways, root="~/data", ) + print("[*] Done.") # Model - self._model = CNN4_BNRS(ways, adaptation_steps=steps) + self._model = CNN4(ways) # TODO: Change config for miniImageNet (32 filters ?) if self._use_cuda: self._model.cuda() @@ -109,9 +111,9 @@ def _split_batch(self, batch: tuple) -> MetaBatch: Separate data batch into adaptation/evalutation sets. """ images, labels = batch - batch_size = self._k_shots + self._n_queries - assert batch_size <= images.shape[0], "K+N are greater than the batch size!" - indices = torch.randperm(batch_size) + task_size = self._k_shots + self._n_queries + assert task_size <= images.shape[0], "K+N are smaller than the batch size!" 
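A standalone illustration of the split performed here (tensor shapes are made up; the real batch comes from the mini-ImageNet taskset): the indices are shuffled once, the first k_shots go to the support set and the remaining n_queries to the query set.

~~~python
# Standalone sketch of the support/query split; shapes are illustrative.
import torch
from collections import namedtuple

MetaBatch = namedtuple("MetaBatch", "support query")

k_shots, n_queries = 5, 5
task_size = k_shots + n_queries
images = torch.randn(task_size, 3, 84, 84)
labels = torch.randint(0, 20, (task_size,))

indices = torch.randperm(task_size)                  # shuffle once, then slice
support_idx, query_idx = indices[:k_shots], indices[k_shots:]
meta_batch = MetaBatch(
    (images[support_idx], labels[support_idx]),      # adaptation (support) set
    (images[query_idx], labels[query_idx]),          # evaluation (query) set
)
print(meta_batch.support[0].shape, meta_batch.query[0].shape)
~~~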
+ indices = torch.randperm(task_size) support_indices = indices[: self._k_shots] query_indices = indices[self._k_shots :] return MetaBatch( @@ -125,7 +127,7 @@ def _split_batch(self, batch: tuple) -> MetaBatch: def _training_step( self, batch: MetaBatch, - learner: MAMLpp, + learner: torch.nn.Module, msl: bool = True, epoch: int = 0, ) -> Tuple[torch.Tensor, float]: @@ -147,26 +149,26 @@ def _training_step( # Adapt the model on the support set for step in range(self._steps): # forward + backward + optimize - pred = learner(s_inputs, step) + pred = learner(s_inputs) support_loss = self._inner_criterion(pred, s_labels) - learner.adapt(support_loss, first_order=not second_order, step=step) + learner.adapt(support_loss, first_order=not second_order) # Multi-Step Loss if msl: - q_pred = learner(q_inputs, step) + q_pred = learner(q_inputs) query_loss += self._step_weights[step] * self._inner_criterion( q_pred, q_labels ) # Evaluate the adapted model on the query set if not msl: - q_pred = learner(q_inputs, self._steps-1) + q_pred = learner(q_inputs, inference=True) query_loss = self._inner_criterion(q_pred, q_labels) acc = accuracy(q_pred, q_labels).detach() return query_loss, acc def _testing_step( - self, batch: MetaBatch, learner: MAMLpp + self, batch: MetaBatch, learner: torch.nn.Module ) -> Tuple[torch.Tensor, float]: s_inputs, s_labels = batch.support q_inputs, q_labels = batch.query @@ -180,12 +182,12 @@ def _testing_step( # Adapt the model on the support set for step in range(self._steps): # forward + backward + optimize - pred = learner(s_inputs, step) + pred = learner(s_inputs) support_loss = self._inner_criterion(pred, s_labels) - learner.adapt(support_loss, step=step) + learner.adapt(support_loss) # Evaluate the adapted model on the query set - q_pred = learner(q_inputs, self._steps-1) + q_pred = learner(q_inputs, inference=True) query_loss = self._inner_criterion(q_pred, q_labels).detach() acc = accuracy(q_pred, q_labels) @@ -195,23 +197,29 @@ def train( self, meta_lr=0.001, fast_lr=0.01, - meta_bsz=5, - epochs=100, + meta_bsz=16, + epochs=1, val_interval=1, ): print("[*] Training...") - maml = MAMLpp( + transform = PerLayerPerStepLRTransform(fast_lr, self._steps, self._model, ["conv"]) + # Setting adapt_transform=True means that the transform will be updated in + # the *adapt* function, which is not what we want. We want it to compute gradients during + # eval_loss.backward() only, so that it's updated in opt.step(). 
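A note on the `lr=1.0` passed to GBML below: the transform's output already carries the learned step size, and GBML's differentiable SGD then applies `p <- p - lr * update`, so the outer `lr` is left at 1.0. A small sketch of that arithmetic (values are illustrative):

~~~python
# Why lr=1.0: the per-step transform scales the gradient itself, and the GBML
# update rule p <- p - lr * update is applied with lr = 1.0. Values are illustrative.
import torch

p = torch.tensor(2.0)          # a parameter
grad = torch.tensor(0.5)       # its gradient at the current inner step
step_lr = torch.tensor(0.01)   # the learned LR for this layer and step

update = step_lr * grad        # what PerStepLR.forward returns
p_new = p - 1.0 * update       # what DifferentiableSGD applies
assert torch.isclose(p_new, torch.tensor(2.0 - 0.01 * 0.5))
~~~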
+ mamlpp = l2l.algorithms.GBML( self._model, - lr=fast_lr, # Initialisation LR for all layers and steps - adaptation_steps=self._steps, # For LSLR - first_order=False, - allow_nograd=True, # For the parameters of the MetaBatchNorm layers + transform, + lr=1.0, + allow_nograd=True, + adapt_transform=False, + pass_param_names=True, ) - opt = torch.optim.AdamW(maml.parameters(), meta_lr, betas=(0, 0.999)) + opt = torch.optim.AdamW(mamlpp.parameters(), meta_lr, betas=(0.9, 0.99)) iter_per_epoch = ( train_samples // (meta_bsz * (self._k_shots + self._n_queries)) ) + 1 + print(f"[*] Training with {iter_per_epoch} iterations/epoch with {train_samples} total training samples") scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( opt, T_max=epochs * iter_per_epoch, @@ -228,7 +236,7 @@ def train( meta_batch = self._split_batch(self._train_tasks.sample()) meta_loss, meta_acc = self._training_step( meta_batch, - maml.clone(), + mamlpp.clone(), msl=(epoch < self._msl_epochs), epoch=epoch, ) @@ -241,7 +249,7 @@ def train( # Average the accumulated gradients and optimize with torch.no_grad(): - for p in maml.parameters(): + for p in mamlpp.parameters(): # Remember the MetaBatchNorm layer has parameters that don't require grad! if p.requires_grad: p.grad.data.mul_(1.0 / meta_bsz) @@ -259,49 +267,50 @@ def train( # ======= Validation ======== if (epoch + 1) % val_interval == 0: - # Backup the BatchNorm layers' running statistics - maml.backup_stats() - # Compute the meta-validation loss # TODO: Go through the entire validation set, which shouldn't be shuffled, and # which tasks should not be continuously resampled from! + # This may be done in the get_tasksets() method actually... meta_val_losses, meta_val_accs = [], [] for _ in tqdm(range(val_samples // tasks)): meta_batch = self._split_batch(self._valid_tasks.sample()) - loss, acc = self._testing_step(meta_batch, maml.clone()) + loss, acc = self._testing_step(meta_batch, mamlpp.clone()) meta_val_losses.append(loss) meta_val_accs.append(acc) meta_val_loss = float(torch.Tensor(meta_val_losses).mean().item()) meta_val_acc = float(torch.Tensor(meta_val_accs).mean().item()) print(f"Meta-validation Loss: {meta_val_loss:.6f}") print(f"Meta-validation Accuracy: {meta_val_acc:.6f}") - # Restore the BatchNorm layers' running statistics - maml.restore_backup_stats() print("============================================") - return self._model.state_dict() + return self._model.state_dict(), transform.state_dict() def test( self, model_state_dict, + trasnform_state_dict, meta_lr=0.001, fast_lr=0.01, meta_bsz=5, ): self._model.load_state_dict(model_state_dict) - maml = MAMLpp( + transform = PerLayerPerStepLRTransform(fast_lr, self._steps, self._model, ["conv"]) + transform.load_state_dict(trasnform_state_dict) + # Setting adapt_transform=True means that the transform will be updated in + # the *adapt* function, which is not what we want. We want it to compute gradients during + # eval_loss.backward() only, so that it's updated in opt.step(). 
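Since the learned per-layer per-step LRs live in the transform rather than in the model, a checkpoint needs both state dicts, which is why `test()` receives the transform's state dict alongside the model's. A minimal save/restore sketch (the path, LR, step count, and layer filter are illustrative):

~~~python
# Hedged sketch: checkpoint the model together with its learned per-step LRs.
import torch
import learn2learn as l2l
from learn2learn.optim.transforms import PerLayerPerStepLRTransform

steps = 5
model = l2l.vision.models.CNN4(output_size=20)
transform = PerLayerPerStepLRTransform(0.01, steps, model, ["conv"])

# ... meta-training happens here ...

torch.save(
    {"model": model.state_dict(), "lslr": transform.state_dict()},
    "mamlpp_checkpoint.pth",
)

# Later: rebuild the same model/transform pair and restore both.
checkpoint = torch.load("mamlpp_checkpoint.pth")
model.load_state_dict(checkpoint["model"])
transform.load_state_dict(checkpoint["lslr"])
~~~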
+ mamlpp = l2l.algorithms.GBML( self._model, - lr=fast_lr, - adaptation_steps=self._steps, - first_order=False, - allow_nograd=True, + transform, + lr=1.0, + adapt_transform=False, + pass_param_names=True, ) - opt = torch.optim.AdamW(maml.parameters(), meta_lr, betas=(0, 0.999)) meta_losses, meta_accs = [], [] for _ in tqdm(range(test_samples // tasks)): meta_batch = self._split_batch(self._test_tasks.sample()) - loss, acc = self._testing_step(meta_batch, maml.clone()) + loss, acc = self._testing_step(meta_batch, mamlpp.clone()) meta_losses.append(loss) meta_accs.append(acc) loss = float(torch.Tensor(meta_losses).mean().item()) @@ -312,5 +321,6 @@ def test( if __name__ == "__main__": mamlPlusPlus = MAMLppTrainer() - model = mamlPlusPlus.train() - mamlPlusPlus.test(model) + model_state_dict, transform_state_dict = mamlPlusPlus.train() + mamlPlusPlus.test(model_state_dict, transform_state_dict) + diff --git a/learn2learn/algorithms/gbml.py b/learn2learn/algorithms/gbml.py index 24d230f1..a1bf61bb 100644 --- a/learn2learn/algorithms/gbml.py +++ b/learn2learn/algorithms/gbml.py @@ -27,6 +27,8 @@ class GBML(torch.nn.Module): * **lr** (float) - Fast adaptation learning rate. * **adapt_transform** (bool, *optional*, default=False) - Whether to update the transform's parameters during fast-adaptation. + * **pass_param_names** (bool, *optional*, default=False) - Whether to pass the parameters' + names to the transform. * **first_order** (bool, *optional*, default=False) - Whether to use the first-order approximation. * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation @@ -73,6 +75,7 @@ def __init__( transform, lr=1.0, adapt_transform=False, + pass_param_names=False, first_order=False, allow_unused=False, allow_nograd=False, @@ -90,8 +93,9 @@ def __init__( self.compute_update = kwargs.get('compute_update') else: self.compute_update = l2l.optim.ParameterUpdate( - parameters=self.module.parameters(), + parameters=self.module.named_parameters(), transform=transform, + pass_param_names=pass_param_names, ) self.diff_sgd = l2l.optim.DifferentiableSGD(lr=self.lr) # Whether the module params have already been updated with the diff --git a/learn2learn/algorithms/maml.py b/learn2learn/algorithms/maml.py index 08fa2638..70876011 100644 --- a/learn2learn/algorithms/maml.py +++ b/learn2learn/algorithms/maml.py @@ -102,6 +102,7 @@ def __init__(self, if allow_unused is None: allow_unused = allow_nograd self.allow_unused = allow_unused + self.update_func = maml_update def forward(self, *args, **kwargs): return self.module(*args, **kwargs) @@ -166,7 +167,7 @@ def adapt(self, print('learn2learn: Maybe try with allow_nograd=True and/or allow_unused=True ?') # Update the module - self.module = maml_update(self.module, self.lr, gradients) + self.module = self.update_func(self.module, self.lr, gradients) def clone(self, first_order=None, allow_unused=None, allow_nograd=None): """ diff --git a/learn2learn/algorithms/meta_sgd.py b/learn2learn/algorithms/meta_sgd.py index bf0002e5..9c2373ea 100644 --- a/learn2learn/algorithms/meta_sgd.py +++ b/learn2learn/algorithms/meta_sgd.py @@ -109,6 +109,7 @@ def __init__(self, model, lr=1.0, first_order=False, lrs=None): lrs = nn.ParameterList([nn.Parameter(lr) for lr in lrs]) self.lrs = lrs self.first_order = first_order + self.update_func = meta_sgd_update def forward(self, *args, **kwargs): return self.module(*args, **kwargs) @@ -138,7 +139,7 @@ def adapt(self, loss, first_order=None): self.module.parameters(), retain_graph=second_order, 
create_graph=second_order) - self.module = meta_sgd_update(self.module, self.lrs, gradients) + self.module = self.update_func(self.module, self.lrs, gradients) if __name__ == '__main__': diff --git a/learn2learn/optim/parameter_update.py b/learn2learn/optim/parameter_update.py index 3ad46e83..bf642577 100644 --- a/learn2learn/optim/parameter_update.py +++ b/learn2learn/optim/parameter_update.py @@ -26,6 +26,8 @@ class ParameterUpdate(torch.nn.Module): * **parameters** (list) - Parameters of the model to update. * **transform** (callable) - A callable that returns an instantiated transform given a parameter. + * **pass_param_names** (bool, *optional*, default=False) - Whether to pass the parameters' + names to the transform. **Example** ~~~python @@ -47,13 +49,13 @@ class ParameterUpdate(torch.nn.Module): ~~~ """ - def __init__(self, parameters, transform): + def __init__(self, parameters, transform, pass_param_names=False): super(ParameterUpdate, self).__init__() transforms_indices = [] transform_modules = [] module_counter = 0 - for param in parameters: - t = transform(param) + for name, param in parameters: + t = transform(param) if not pass_param_names else transform(name, param) if t is None: idx = None elif isinstance(t, torch.nn.Module): diff --git a/learn2learn/optim/transforms/__init__.py b/learn2learn/optim/transforms/__init__.py index 8d62b45d..1a6c52ca 100644 --- a/learn2learn/optim/transforms/__init__.py +++ b/learn2learn/optim/transforms/__init__.py @@ -7,6 +7,7 @@ gradient descent, allow you to learn optimization functions from data. """ +from .layer_step_lr_transform import PerStepLRTransform, PerLayerPerStepLRTransform from .module_transform import ModuleTransform, ReshapedTransform from .kronecker_transform import KroneckerTransform from .transform_dictionary import TransformDictionary diff --git a/learn2learn/optim/transforms/layer_step_lr_transform.py b/learn2learn/optim/transforms/layer_step_lr_transform.py new file mode 100644 index 00000000..6d450e82 --- /dev/null +++ b/learn2learn/optim/transforms/layer_step_lr_transform.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- +# vim:fenc=utf-8 + +""" +Per-Layer and Per-Layer Per-Step Learning Rate transforms for the GBML algorithm. +""" + +from typing import Any, Dict, Optional +import learn2learn as l2l +import numpy as np +import random +import torch + + +class PerStepLR(torch.nn.Module): + def __init__(self, init_lr: float, steps: int): + super().__init__() + self.lrs = torch.nn.Parameter( + data=torch.ones(steps) * init_lr, + requires_grad=True, + ) + self._current_step = 0 + self._steps = steps + + def forward(self, grad): + # The update is positive because it is applied as `grad.mul(-self.lr)` in + # DifferentiableSGD of the GBML, where lr=1. + updates = self.lrs[self._current_step] * grad + self._current_step = ( + self._current_step + 1 if self._current_step < (self._steps - 1) else 0 + ) # avoids overflow + return updates + + def __str__(self): + return str(self.lrs) + + +class PerLayerPerStepLRTransform: + """ + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/optim/transforms/layer_step_lr_transform.py) + + **Description** + + The PerLayerPerStepLRTransform creates a per-step transform for each layer of a given module. + + This can be used with the GBML algorithm to reproduce the *LSLR* improvement of MAML++ proposed by + Antoniou et al. + + **Arguments** + + * **init_lr** (float) - The initial learning rate for each adaptation step and layer. 
+ * **steps** (int) - The number of adaptation steps. + * **model** (torch.nn.Module) - The module being updated with the learning rates. This is + needed to define the learning rates for each layer. + * **layer_names** (List[str], *optional*, default=None) - If not None, only layers named with + one of the list elements will have a per-step learning rate. Otherwise, all layers will have + one. It may be more efficient to specify the layer names as to avoid redundant layers + introducing extra parameters, such as a "BatchNorm" layer followed by a "Conv" layer. + + **Example** + ~~~python + model = torch.nn.Sequential( + torch.nn.Linear(128, 24), torch.nn.Linear(24, 16), torch.nn.Linear(16, 10) + ) + transform = PerLayerPerStepLRTransform(1e-3, N_STEPS, model, ["conv", "linear"]) + metamodel = l2l.algorithms.GBML( + model, + transform, + allow_nograd=True, + lr=0.001, + adapt_transform=False, + pass_param_names=True, # This is needed for this transform to find the module's layers + ) + opt = torch.optim.Adam(metamodel.parameters(), lr=1.0) + ~~~ + """ + + def __init__(self, init_lr, steps, model, layer_names=None): + self._lslr = {} + for layer_name, layer in model.named_modules(): + # If the layer has learnable parameters + if ( + len( + [ + name + for name, param in layer.named_parameters(recurse=False) + if param.requires_grad + ] + ) + > 0 + ): + if layer_names is None or layer_name.lower().split(".")[-1] in [ + name.lower() for name in layer_names + ]: + self._lslr[layer_name] = PerStepLR(init_lr, steps) + else: + self._lslr[layer_name] = None + + def load_state_dict(self, lr_state_dicts: Dict[str, Dict[str, Any]]): + assert ( + type(lr_state_dicts) is dict + ), "Argument lr_state_dicts must be a dictionary!" + for layer_name, state_dict in lr_state_dicts.items(): + if self._lslr[layer_name] is not None: + self._lslr[layer_name].load_state_dict(state_dict) + + def state_dict(self) -> Dict[str, Optional[Dict[str, Any]]]: + return { + layer_name: (pslr.state_dict() if pslr is not None else None) for layer_name, pslr in self._lslr.items() + } + + def __call__(self, name, param): + name = name[ + : name.rfind(".") + ] # Extract the layer name from the named parameter + assert name in self._lslr, "No matching LR found for layer." + return self._lslr[name] + + def __str__(self): + string = "" + for layer, lslr in self._lslr.items(): + string += f"Layer {layer}: {lslr}\n" + return string + + +class PerStepLRTransform: + """ + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/optim/transforms/layer_step_lr_transform.py) + + **Description** + + The PerStepLRTransform creates a per-step transform for inner-loop-based algorithms. + + This can be used with the GBML algorithm to reproduce the *LSLR* improvement of MAML++ proposed by + Antoniou et al, with the same learning rates for all layers. + + **Arguments** + + * **init_lr** (float) - The initial learning rate for each adaptation step. + * **steps** (int) - The number of adaptation steps. 
+ + **Example** + ~~~python + model = torch.nn.Linear(128, 10) + transform = PerStepLRTransform(1e-3, N_STEPS) + metamodel = l2l.algorithms.GBML( + model, + transform, + allow_nograd=True, + lr=0.001, + adapt_transform=False, + ) + opt = torch.optim.Adam(metamodel.parameters(), lr=1.0) + ~~~ + """ + + def __init__(self, init_lr, steps): + self._obj = PerStepLR(init_lr, steps) + + def __call__(self, param): + return self._obj + + def __str__(self): + return str(self._obj) + + def parameters(self): + return self._obj.parameters() + + def load_state_dict(self, state_dict: Dict[str, Any]): + self._obj.load_state_dict(state_dict) + + def state_dict(self) -> Dict[str, Any]: + return self._obj.state_dict() + + +if __name__ == "__main__": + + random.seed(1234) + np.random.seed(1234) + torch.manual_seed(1234) + + N_DIMS = 32 + N_SAMPLES = 128 + N_STEPS = 5 + device = torch.device("cpu") + + print("[*] Testing per-step LR with one linear layer") + model = torch.nn.Linear(N_DIMS, 10) + transform = PerStepLRTransform(1e-3, N_STEPS) + # Setting adapt_transform=True means that the transform will be updated in + # the *adapt* function, which is not what we want. We want it to compute gradients during + # eval_loss.backward() only, so that it's updated in opt.step(). + metamodel = l2l.algorithms.GBML( + model, + transform, + lr=1.0, + adapt_transform=False, + allow_nograd=True, + ) + opt = torch.optim.Adam(metamodel.parameters(), lr=1.0) + print("\nPre-learning") + print("Transform parameters: ", transform) + for name, p in metamodel.named_parameters(): + print(name, ":", p.norm()) + + for task in range(10): + opt.zero_grad() + learner = metamodel.clone() + X = torch.randn(N_SAMPLES, N_DIMS) + + # fast adapt + for step in range(N_STEPS): + adapt_loss = learner(X).norm(2) + learner.adapt(adapt_loss) + + # meta-learn + eval_loss = learner(X).norm(2) + eval_loss.backward() + opt.step() + + print("\nPost-learning") + print("Transform parameters: ", transform) + for name, p in metamodel.named_parameters(): + print(name, ":", p.norm()) + + print("Transform state_dict: ", transform.state_dict()) + + print("\n\n--------------------------") + print("[*] Testing per-layer per-step LR with three linear layers") + model = torch.nn.Sequential( + torch.nn.Linear(N_DIMS, 24), torch.nn.Linear(24, 16), torch.nn.Linear(16, 10) + ) + transform = PerLayerPerStepLRTransform(1e-3, N_STEPS, model) + # Setting adapt_transform=True means that the transform will be updated in + # the *adapt* function, which is not what we want. We want it to compute gradients during + # eval_loss.backward() only, so that it's updated in opt.step(). + metamodel = l2l.algorithms.GBML( + model, + transform, + lr=1.0, + adapt_transform=False, + pass_param_names=True, + allow_nograd=True, + ) + opt = torch.optim.Adam(metamodel.parameters(), lr=1.0) + print("\nPre-learning") + print("Transform parameters: ", transform) + for name, p in metamodel.named_parameters(): + print(name, ":", p.norm()) + + for task in range(10): + opt.zero_grad() + learner = metamodel.clone() + X = torch.randn(N_SAMPLES, N_DIMS) + + # fast adapt + for step in range(N_STEPS): + adapt_loss = learner(X).norm(2) + learner.adapt(adapt_loss) + + # meta-learn + eval_loss = learner(X).norm(2) + eval_loss.backward() + opt.step() + + print("\nPost-learning") + print("Transform parameters: ", transform) + for name, p in metamodel.named_parameters(): + print(name, ":", p.norm()) + + print("Transform state_dict: ", transform.state_dict())
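Putting the new pieces together, below is a hedged end-to-end sketch of a MAML++-style inner/outer loop that combines GBML, PerLayerPerStepLRTransform, and a multi-step (MSL) query loss. The toy model, random data, and uniform loss weights are illustrative placeholders, not the mini-ImageNet setup from the example script.

~~~python
# Hedged end-to-end sketch: GBML + PerLayerPerStepLRTransform with a multi-step loss.
# The toy model, random data, and uniform MSL weights are illustrative placeholders.
import torch
import learn2learn as l2l
from learn2learn.optim.transforms import PerLayerPerStepLRTransform

N_STEPS, N_WAYS, N_TASKS = 5, 5, 4
model = torch.nn.Sequential(
    torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, N_WAYS)
)
transform = PerLayerPerStepLRTransform(0.01, N_STEPS, model)  # one LR per layer and step
metamodel = l2l.algorithms.GBML(
    model,
    transform,
    lr=1.0,                   # the transform carries the actual step sizes
    adapt_transform=False,    # LRs are meta-learned in opt.step(), not in adapt()
    pass_param_names=True,    # lets the transform match parameters to layers
    allow_nograd=True,
)
opt = torch.optim.AdamW(metamodel.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
msl_weights = torch.full((N_STEPS,), 1.0 / N_STEPS)  # uniform here; MAML++ anneals these

for iteration in range(10):                # outer (meta) loop
    opt.zero_grad()
    meta_loss = 0.0
    for task in range(N_TASKS):
        learner = metamodel.clone()
        xs, ys = torch.randn(10, 32), torch.randint(0, N_WAYS, (10,))  # support set
        xq, yq = torch.randn(10, 32), torch.randint(0, N_WAYS, (10,))  # query set
        query_loss = 0.0
        for step in range(N_STEPS):        # inner (adaptation) loop
            learner.adapt(criterion(learner(xs), ys))  # uses this step's learned LRs
            query_loss = query_loss + msl_weights[step] * criterion(learner(xq), yq)
        meta_loss = meta_loss + query_loss / N_TASKS
    meta_loss.backward()                   # reaches the weights *and* the per-step LRs
    opt.step()
~~~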