diff --git a/avalanche/models/dynamic_optimizers.py b/avalanche/models/dynamic_optimizers.py index e48241ee9..33d6eb72d 100644 --- a/avalanche/models/dynamic_optimizers.py +++ b/avalanche/models/dynamic_optimizers.py @@ -16,14 +16,220 @@ """ from collections import defaultdict +import numpy as np -def compare_keys(old_dict, new_dict): - not_in_new = list(set(old_dict.keys()) - set(new_dict.keys())) - in_both = list(set(old_dict.keys()) & set(new_dict.keys())) - not_in_old = list(set(new_dict.keys()) - set(old_dict.keys())) - return not_in_new, in_both, not_in_old +from avalanche._annotations import deprecated +colors = { + "END": "\033[0m", + 0: "\033[32m", + 1: "\033[33m", + 2: "\033[34m", + 3: "\033[35m", + 4: "\033[36m", +} +colors[None] = colors["END"] + +def _map_optimized_params(optimizer, parameters, old_params=None): + """ + Establishes a mapping between a list of named parameters and the parameters + that are in the optimizer, additionally, + returns the lists of: + + returns: + new_parameters: Names of new parameters in the provided "parameters" argument, + that are not in the old parameters + changed_parameters: Names and indexes of parameters that have changed (grown, shrink) + not_found_in_parameters: List of indexes of optimizer parameters + that are not found in the provided parameters + """ + + if old_params is None: + old_params = {} + + group_mapping = defaultdict(dict) + new_parameters = [] + + found_indexes = [] + changed_parameters = [] + for group in optimizer.param_groups: + params = group["params"] + found_indexes.append(np.zeros(len(params))) + + for n, p in parameters.items(): + gidx = None + pidx = None + + # Find param in optimizer + found = False + + if n in old_params: + search_id = id(old_params[n]) + else: + search_id = id(p) + + for group_idx, group in enumerate(optimizer.param_groups): + params = group["params"] + for param_idx, po in enumerate(params): + if id(po) == search_id: + gidx = group_idx + pidx = param_idx + found = True + # Update found indexes + assert found_indexes[group_idx][param_idx] == 0 + found_indexes[group_idx][param_idx] = 1 + break + if found: + break + + if not found: + new_parameters.append(n) + + if search_id != id(p): + if found: + changed_parameters.append((n, gidx, pidx)) + + if len(optimizer.param_groups) > 1: + group_mapping[n] = gidx + else: + group_mapping[n] = 0 + + not_found_in_parameters = [np.where(arr == 0)[0] for arr in found_indexes] + + return ( + group_mapping, + changed_parameters, + new_parameters, + not_found_in_parameters, + ) + + +def _build_tree_from_name_groups(name_groups): + root = _TreeNode("") # Root node + node_mapping = {} + + # Iterate through each string in the list + for name, group in name_groups.items(): + components = name.split(".") + current_node = root + + # Traverse the tree and construct nodes for each component + for component in components: + if component not in current_node.children: + current_node.children[component] = _TreeNode( + component, parent=current_node + ) + current_node = current_node.children[component] + + # Update the groups for the leaf node + if group is not None: + current_node.groups |= set([group]) + current_node.update_groups_upwards() # Inform parent about the groups + + # Update leaf node mapping dict + node_mapping[name] = current_node + + # This will resolve nodes without group + root.update_groups_downwards() + return root, node_mapping + + +def _print_group_information(node, prefix=""): + # Print the groups for the current node + + if len(node.groups) == 1: + 
pstring = ( + colors[list(node.groups)[0]] + + f"{prefix}{node.global_name()}: {node.groups}" + + colors["END"] + ) + print(pstring) + else: + print(f"{prefix}{node.global_name()}: {node.groups}") + + # Recursively print group information for children nodes + for child_name, child_node in node.children.items(): + _print_group_information(child_node, prefix + " ") + + +class _ParameterGroupStructure: + """ + Structure used for the resolution of unknown parameter groups, + stores parameters as a tree and propagates parameter groups from leaves of + the same hierarchical level + """ + + def __init__(self, name_groups, verbose=False): + # Here we rebuild the tree + self.root, self.node_mapping = _build_tree_from_name_groups(name_groups) + if verbose: + _print_group_information(self.root) + + def __getitem__(self, name): + return self.node_mapping[name] + + +class _TreeNode: + def __init__(self, name, parent=None): + self.name = name + self.children = {} + self.groups = set() # Set of groups (represented by index) this node belongs to + self.parent = parent # Reference to the parent node + if parent: + # Inform the parent about the new child node + parent.add_child(self) + + def add_child(self, child): + self.children[child.name] = child + + def update_groups_upwards(self): + if self.parent: + if self.groups != {None}: + self.parent.groups |= ( + self.groups + ) # Update parent's groups with the child's groups + self.parent.update_groups_upwards() # Propagate the group update to the parent + + def update_groups_downwards(self, new_groups=None): + # If you are a node with no groups, absorb + if len(self.groups) == 0 and new_groups is not None: + self.groups = self.groups.union(new_groups) + + # Then transmit + if len(self.groups) > 0: + for key, children in self.children.items(): + children.update_groups_downwards(self.groups) + + def global_name(self, initial_name=None): + """ + Returns global node name + """ + if initial_name is None: + initial_name = self.name + elif self.name != "": + initial_name = ".".join([self.name, initial_name]) + + if self.parent: + return self.parent.global_name(initial_name) + else: + return initial_name + + @property + def single_group(self): + if len(self.groups) == 0: + raise AttributeError( + f"Could not identify group for this node {self.global_name()}" + ) + elif len(self.groups) > 1: + raise AttributeError( + f"No unique group found for this node {self.global_name()}" + ) + else: + return list(self.groups)[0] + + +@deprecated(0.6, "update_optimizer with optimized_params=None is now used instead") def reset_optimizer(optimizer, model): """Reset the optimizer to update the list of learnable parameters. @@ -53,7 +259,14 @@ def reset_optimizer(optimizer, model): return optimized_param_id -def update_optimizer(optimizer, new_params, optimized_params, reset_state=False): +def update_optimizer( + optimizer, + new_params, + optimized_params=None, + reset_state=False, + remove_params=False, + verbose=False, +): """Update the optimizer by adding new parameters, removing removed parameters, and adding new parameters to the optimizer, for instance after model has been adapted @@ -64,72 +277,69 @@ def update_optimizer(optimizer, new_params, optimized_params, reset_state=False) :param new_params: Dict (name, param) of new parameters :param optimized_params: Dict (name, param) of - currently optimized parameters (returned by reset_optimizer) - :param reset_state: Wheter to reset the optimizer's state (i.e momentum). - Defaults to False. 
+ currently optimized parameters + :param reset_state: Whether to reset the optimizer's state (i.e momentum). + Defaults to False. + :param remove_params: Whether to remove parameters that were in the optimizer + but are not found in new parameters. For safety reasons, + defaults to False. + :param verbose: If True, prints information about inferred + parameter groups for new params + :return: Dict (name, param) of optimized parameters """ - not_in_new, in_both, not_in_old = compare_keys(optimized_params, new_params) + ( + group_mapping, + changed_parameters, + new_parameters, + not_found_in_parameters, + ) = _map_optimized_params(optimizer, new_params, old_params=optimized_params) + # Change reference to already existing parameters # i.e growing IncrementalClassifier - for key in in_both: - old_p_hash = optimized_params[key] - new_p = new_params[key] + for name, group_idx, param_idx in changed_parameters: + group = optimizer.param_groups[group_idx] + old_p = optimized_params[name] + new_p = new_params[name] # Look for old parameter id in current optimizer - found = False - for group in optimizer.param_groups: - for i, curr_p in enumerate(group["params"]): - if id(curr_p) == id(old_p_hash): - found = True - if id(curr_p) != id(new_p): - group["params"][i] = new_p - optimized_params[key] = new_p - optimizer.state[new_p] = {} - break - if not found: - raise Exception( - f"Parameter {key} expected but " "not found in the optimizer" - ) + group["params"][param_idx] = new_p + if old_p in optimizer.state: + optimizer.state.pop(old_p) + optimizer.state[new_p] = {} # Remove parameters that are not here anymore # This should not happend in most use case - keys_to_remove = [] - for key in not_in_new: - old_p_hash = optimized_params[key] - found = False - for i, group in enumerate(optimizer.param_groups): - keys_to_remove.append([]) - for j, curr_p in enumerate(group["params"]): - if id(curr_p) == id(old_p_hash): - found = True - keys_to_remove[i].append((j, curr_p)) - optimized_params.pop(key) - break - if not found: - raise Exception( - f"Parameter {key} expected but " "not found in the optimizer" - ) - - for i, idx_list in enumerate(keys_to_remove): - for j, p in sorted(idx_list, key=lambda x: x[0], reverse=True): - del optimizer.param_groups[i]["params"][j] - if p in optimizer.state: - optimizer.state.pop(p) + if remove_params: + for group_idx, idx_list in enumerate(not_found_in_parameters): + for j in sorted(idx_list, key=lambda x: x, reverse=True): + p = optimizer.param_groups[group_idx]["params"][j] + optimizer.param_groups[group_idx]["params"].pop(j) + if p in optimizer.state: + optimizer.state.pop(p) + del p # Add newly added parameters (i.e Multitask, PNN) - # by default, add to param groups 0 - for key in not_in_old: + + param_structure = _ParameterGroupStructure(group_mapping, verbose=verbose) + + # New parameters + for key in new_parameters: new_p = new_params[key] - optimizer.param_groups[0]["params"].append(new_p) + group = param_structure[key].single_group + optimizer.param_groups[group]["params"].append(new_p) optimized_params[key] = new_p optimizer.state[new_p] = {} if reset_state: optimizer.state = defaultdict(dict) - return optimized_params + return new_params +@deprecated( + 0.6, + "parameters have to be added manually to the optimizer in an existing or a new parameter group", +) def add_new_params_to_optimizer(optimizer, new_params): """Add new parameters to the trainable parameters. 
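The reworked update_optimizer above can also be driven directly, outside of a strategy. The following is a minimal, illustrative sketch (not part of the patch): it assumes an MTSimpleMLP whose head parameters contain "classifier" in their names, a two-group SGD optimizer like the one in the new example script, and that model adaptation later grows the multi-head classifier. The first call passes optimized_params=None to record the (name, param) mapping; later calls pass that mapping back, may prune stale optimizer entries with remove_params=True, and may print the inferred parameter-group assignment with verbose=True.

    import torch

    from avalanche.models import MTSimpleMLP
    from avalanche.models.dynamic_optimizers import update_optimizer

    model = MTSimpleMLP()
    heads = [p for n, p in model.named_parameters() if "classifier" in n]
    backbone = [p for n, p in model.named_parameters() if "classifier" not in n]
    optimizer = torch.optim.SGD(
        [{"params": heads, "lr": 0.1}, {"params": backbone, "lr": 0.01}]
    )

    # First call: no previous (name, param) mapping yet.
    optimized_params = update_optimizer(
        optimizer, dict(model.named_parameters()), optimized_params=None
    )

    # ... model adaptation (e.g. avalanche_model_adaptation) may grow the
    # multi-head classifier here, adding parameters the optimizer has
    # never seen ...

    optimized_params = update_optimizer(
        optimizer,
        dict(model.named_parameters()),
        optimized_params=optimized_params,
        remove_params=True,  # drop optimizer entries no longer in the model
        verbose=True,        # print the inferred parameter-group assignment
    )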
diff --git a/avalanche/training/templates/observation_type/batch_observation.py b/avalanche/training/templates/observation_type/batch_observation.py index 374d71c45..a663e2d42 100644 --- a/avalanche/training/templates/observation_type/batch_observation.py +++ b/avalanche/training/templates/observation_type/batch_observation.py @@ -36,32 +36,42 @@ def model_adaptation(self, model=None): return model.to(self.device) - def make_optimizer(self, reset_optimizer_state=False, **kwargs): + def make_optimizer( + self, + reset_optimizer_state=False, + remove_params=False, + verbose_optimizer=False, + **kwargs + ): """Optimizer initialization. Called before each training experience to configure the optimizer. :param reset_optimizer_state: bool, whether to reset the state of the optimizer, defaults to False + :param remove_params: bool, whether to remove parameters that + are in the optimizer but not found in the current model + :param verbose_optimizer: bool, print optimized parameters + along with their parameter group Warnings: - The first time this function is called for a given strategy it will reset the optimizer to gather the (name, param) correspondance of the optimized parameters - all the model parameters will be put in the + all of the model parameters will be put in the optimizer, regardless of what parameters are initially put in the optimizer. + """ - if self.optimized_param_id is None: - self.optimized_param_id = reset_optimizer(self.optimizer, self.model) - else: - self.optimized_param_id = update_optimizer( - self.optimizer, - dict(self.model.named_parameters()), - self.optimized_param_id, - reset_state=reset_optimizer_state, - ) + self.optimized_param_id = update_optimizer( + self.optimizer, + dict(self.model.named_parameters()), + self.optimized_param_id, + reset_state=reset_optimizer_state, + remove_params=remove_params, + verbose=verbose_optimizer, + ) def check_model_and_optimizer(self, reset_optimizer_state=False, **kwargs): # If strategy has access to the task boundaries, and the current diff --git a/examples/optimizer_param_groups.py b/examples/optimizer_param_groups.py new file mode 100644 index 000000000..cc152e643 --- /dev/null +++ b/examples/optimizer_param_groups.py @@ -0,0 +1,97 @@ +################################################################################ +# Copyright (c) 2021 ContinualAI. # +# Copyrights licensed under the MIT License. # +# See the accompanying LICENSE file for terms. # +# # +# Date: 25-03-2024 # +# Author(s): Albin Soutif # +# E-mail: contact@continualai.org # +# Website: avalanche.continualai.org # +################################################################################ + +""" +This example trains a Multi-head model on Split MNIST with Elastic Weight +Consolidation. Each experience has a different task label, which is used at test +time to select the appropriate head. Additionally, it assigns different parameter groups +to the classifier and the backbone, assigning lower learning rate to +the backbone than to the classifier. 
When the multihead classifier grows, +new parameters are automatically assigned to the corresponding parameter group +""" + +import argparse +import torch +from torch.nn import CrossEntropyLoss +from torch.optim import SGD + +from avalanche.benchmarks.classic import SplitMNIST +from avalanche.models import MTSimpleMLP +from avalanche.training.supervised import EWC +from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics +from avalanche.logging import InteractiveLogger +from avalanche.training.plugins import EvaluationPlugin + + +def main(args): + # Config + device = torch.device( + f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu" + ) + # model + model = MTSimpleMLP() + + # CL Benchmark Creation + benchmark = SplitMNIST(n_experiences=5, return_task_id=True) + train_stream = benchmark.train_stream + test_stream = benchmark.test_stream + + # Prepare for training & testing + g1 = {"params": [], "lr": 0.1} + g2 = {"params": [], "lr": 0.01} + + for n, p in model.named_parameters(): + if "classifier" in n: + g1["params"].append(p) + else: + g2["params"].append(p) + + optimizer = SGD([g1, g2]) + criterion = CrossEntropyLoss() + + # choose some metrics and evaluation method + interactive_logger = InteractiveLogger() + + eval_plugin = EvaluationPlugin( + accuracy_metrics(minibatch=False, epoch=True, experience=True, stream=True), + forgetting_metrics(experience=True), + loggers=[interactive_logger], + ) + + # Choose a CL strategy + strategy = EWC( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=128, + train_epochs=3, + eval_mb_size=128, + device=device, + evaluator=eval_plugin, + ewc_lambda=0.4, + ) + + # train and test loop + for train_task in train_stream: + strategy.train(train_task, num_workers=4, verbose_optimizer=True) + strategy.eval(test_stream, num_workers=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--cuda", + type=int, + default=0, + help="Select zero-indexed cuda device. -1 to use CPU.", + ) + args = parser.parse_args() + main(args) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index b7b2793bf..86c5f5854 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -83,183 +83,6 @@ def test_get_model(self): self.assertIsInstance(model, pytorchcv.models.resnet.ResNet) -class DynamicOptimizersTests(unittest.TestCase): - if "USE_GPU" in os.environ: - use_gpu = os.environ["USE_GPU"].lower() in ["true"] - else: - use_gpu = False - - print("Test on GPU:", use_gpu) - - if use_gpu: - device = "cuda" - else: - device = "cpu" - - def setUp(self): - common_setups() - - def _iterate_optimizers(self, model, *optimizers): - for opt_class in optimizers: - if opt_class == "SGDmom": - yield torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) - if opt_class == "SGD": - yield torch.optim.SGD(model.parameters(), lr=0.1) - if opt_class == "Adam": - yield torch.optim.Adam(model.parameters(), lr=0.001) - if opt_class == "AdamW": - yield torch.optim.AdamW(model.parameters(), lr=0.001) - - def _is_param_in_optimizer(self, param, optimizer): - for group in optimizer.param_groups: - for curr_p in group["params"]: - if hash(curr_p) == hash(param): - return True - return False - - def load_benchmark(self, use_task_labels=False): - """ - Returns a NC benchmark from a fake dataset of 10 classes, 5 experiences, - 2 classes per experience. - - :param fast_test: if True loads fake data, MNIST otherwise. 
- """ - return get_fast_benchmark(use_task_labels=use_task_labels) - - def init_scenario(self, multi_task=False): - model = self.get_model(multi_task=multi_task) - criterion = CrossEntropyLoss() - benchmark = self.load_benchmark(use_task_labels=multi_task) - return model, criterion, benchmark - - def test_optimizer_update(self): - model = SimpleMLP() - optimizer = SGD(model.parameters(), lr=1e-3) - strategy = Naive(model, optimizer) - - # check add_param_group - p = torch.nn.Parameter(torch.zeros(10, 10)) - add_new_params_to_optimizer(optimizer, p) - assert self._is_param_in_optimizer(p, strategy.optimizer) - - # check new_param is in optimizer - # check old_param is NOT in optimizer - p_new = torch.nn.Parameter(torch.zeros(10, 10)) - optimized = update_optimizer(optimizer, {"new_param": p_new}, {"old_param": p}) - self.assertTrue("new_param" in optimized) - self.assertFalse("old_param" in optimized) - self.assertTrue(self._is_param_in_optimizer(p_new, strategy.optimizer)) - self.assertFalse(self._is_param_in_optimizer(p, strategy.optimizer)) - - def test_optimizers(self): - # SIT scenario - model, criterion, benchmark = self.init_scenario(multi_task=True) - for optimizer in self._iterate_optimizers( - model, "SGDmom", "Adam", "SGD", "AdamW" - ): - strategy = Naive( - model=model, - optimizer=optimizer, - criterion=criterion, - train_mb_size=64, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) - self._test_optimizer(strategy) - - # Needs torch 2.0 ? - def test_checkpointing(self): - model, criterion, benchmark = self.init_scenario(multi_task=True) - optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9) - strategy = Naive( - model=model, - optimizer=optimizer, - criterion=criterion, - train_mb_size=64, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) - experience_0 = benchmark.train_stream[0] - strategy.train(experience_0) - old_state = copy.deepcopy(strategy.optimizer.state) - save_checkpoint(strategy, "./checkpoint.pt") - - del strategy - - model, criterion, benchmark = self.init_scenario(multi_task=True) - optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9) - strategy = Naive( - model=model, - optimizer=optimizer, - criterion=criterion, - train_mb_size=64, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) - strategy, exp_counter = maybe_load_checkpoint( - strategy, "./checkpoint.pt", strategy.device - ) - - # Check that the state has been well serialized - self.assertEqual(len(strategy.optimizer.state), len(old_state)) - for (key_new, value_new_dict), (key_old, value_old_dict) in zip( - strategy.optimizer.state.items(), old_state.items() - ): - self.assertTrue(torch.equal(key_new, key_old)) - - value_new = value_new_dict["momentum_buffer"] - value_old = value_old_dict["momentum_buffer"] - - # Empty state - if len(value_new) == 0 or len(value_old) == 0: - self.assertTrue(len(value_new) == len(value_old)) - else: - self.assertTrue(torch.equal(value_new, value_old)) - - experience_1 = benchmark.train_stream[1] - strategy.train(experience_1) - os.remove("./checkpoint.pt") - - def test_mh_classifier(self): - model, criterion, benchmark = self.init_scenario(multi_task=True) - optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9) - strategy = Naive( - model=model, - optimizer=optimizer, - criterion=criterion, - train_mb_size=64, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) - strategy.train(benchmark.train_stream) - - def _test_optimizer(self, strategy): - # Add a parameter - module = torch.nn.Linear(10, 10) - param1 = 
list(module.parameters())[0] - strategy.make_optimizer() - self.assertFalse(self._is_param_in_optimizer(param1, strategy.optimizer)) - strategy.model.add_module("new_module", module) - strategy.make_optimizer() - self.assertTrue(self._is_param_in_optimizer(param1, strategy.optimizer)) - # Remove a parameter - del strategy.model.new_module - - strategy.make_optimizer() - self.assertFalse(self._is_param_in_optimizer(param1, strategy.optimizer)) - - def get_model(self, multi_task=False): - if multi_task: - model = MTSimpleMLP(input_size=6, hidden_size=10) - else: - model = SimpleMLP(input_size=6, hidden_size=10) - return model - - class DynamicModelsTests(unittest.TestCase): def setUp(self): common_setups() diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 000000000..1c0fd7b63 --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,473 @@ +#!/usr/bin/env python3 +import copy +import os +import sys +import tempfile +import unittest + +import numpy as np +import pytorchcv.models.pyramidnet_cifar +import torch +import torch.nn.functional as F +from tests.benchmarks.utils.test_avalanche_classification_dataset import get_mbatch +from tests.unit_tests_utils import common_setups, get_fast_benchmark, load_benchmark +from torch.nn import CrossEntropyLoss +from torch.optim import SGD +from torch.utils.data import DataLoader + +from avalanche.checkpointing import maybe_load_checkpoint, save_checkpoint +from avalanche.logging import TextLogger +from avalanche.models import ( + IncrementalClassifier, + MTSimpleMLP, + MultiHeadClassifier, + SimpleMLP, +) +from avalanche.models.cosine_layer import CosineLinear, SplitCosineLinear +from avalanche.models.dynamic_optimizers import ( + add_new_params_to_optimizer, + update_optimizer, +) +from avalanche.models.pytorchcv_wrapper import densenet, get_model, pyramidnet, resnet +from avalanche.models.utils import avalanche_model_adaptation +from avalanche.training.supervised import Naive + + +class TorchWrapper(torch.nn.Module): + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + + def forward(self, *args): + return self.backbone(*args) + + +class DynamicOptimizersTests(unittest.TestCase): + if "USE_GPU" in os.environ: + use_gpu = os.environ["USE_GPU"].lower() in ["true"] + else: + use_gpu = False + + print("Test on GPU:", use_gpu) + + if use_gpu: + device = "cuda" + else: + device = "cpu" + + def setUp(self): + common_setups() + + def _iterate_optimizers(self, model, *optimizers): + for opt_class in optimizers: + if opt_class == "SGDmom": + yield torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + if opt_class == "SGD": + yield torch.optim.SGD(model.parameters(), lr=0.1) + if opt_class == "Adam": + yield torch.optim.Adam(model.parameters(), lr=0.001) + if opt_class == "AdamW": + yield torch.optim.AdamW(model.parameters(), lr=0.001) + + def _is_param_in_optimizer(self, param, optimizer): + for group in optimizer.param_groups: + for curr_p in group["params"]: + if hash(curr_p) == hash(param): + return True + return False + + def _is_param_in_optimizer_group(self, param, optimizer): + for group_idx, group in enumerate(optimizer.param_groups): + for curr_p in group["params"]: + if hash(curr_p) == hash(param): + return group_idx + return None + + def load_benchmark(self, use_task_labels=False): + return get_fast_benchmark(use_task_labels=use_task_labels) + + def init_scenario(self, multi_task=False): + model = self.get_model(multi_task=multi_task) + criterion = CrossEntropyLoss() + benchmark = 
self.load_benchmark(use_task_labels=multi_task) + return model, criterion, benchmark + + def test_optimizer_update(self): + model = SimpleMLP() + optimizer = SGD(model.parameters(), lr=1e-3) + strategy = Naive(model=model, optimizer=optimizer) + + # check add new parameter + p = torch.nn.Parameter(torch.zeros(10, 10)) + optimizer.param_groups[0]["params"].append(p) + assert self._is_param_in_optimizer(p, strategy.optimizer) + + # check new_param is in optimizer + # check old_param is NOT in optimizer + p_new = torch.nn.Parameter(torch.zeros(10, 10)) + + # Here we cannot know what parameter group but there is only one so it should work + new_parameters = {"new_param": p_new} + new_parameters.update(dict(model.named_parameters())) + optimized = update_optimizer( + optimizer, new_parameters, {"old_param": p}, remove_params=True + ) + self.assertTrue("new_param" in optimized) + self.assertFalse("old_param" in optimized) + self.assertTrue(self._is_param_in_optimizer(p_new, strategy.optimizer)) + self.assertFalse(self._is_param_in_optimizer(p, strategy.optimizer)) + + def test_optimizers(self): + """ + Run a series of tests on various pytorch optimizers + """ + + # SIT scenario + model, criterion, benchmark = self.init_scenario(multi_task=True) + for optimizer in self._iterate_optimizers( + model, "SGDmom", "Adam", "SGD", "AdamW" + ): + strategy = Naive( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=64, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) + self._test_optimizer(strategy) + self._test_optimizer_state(strategy) + + def test_optimizer_groups_clf_til(self): + """ + Tests the automatic assignation of new + MultiHead parameters to the optimizer + """ + model, criterion, benchmark = self.init_scenario(multi_task=True) + + g1 = [] + g2 = [] + for n, p in model.named_parameters(): + if "classifier" in n: + g1.append(p) + else: + g2.append(p) + + optimizer = SGD([{"params": g1, "lr": 0.1}, {"params": g2, "lr": 0.05}]) + + strategy = Naive( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=64, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) + + for experience in benchmark.train_stream: + strategy.train(experience) + + for n, p in model.named_parameters(): + assert self._is_param_in_optimizer(p, strategy.optimizer) + if "classifier" in n: + self.assertEqual( + self._is_param_in_optimizer_group(p, strategy.optimizer), 0 + ) + else: + self.assertEqual( + self._is_param_in_optimizer_group(p, strategy.optimizer), 1 + ) + + def test_optimizer_groups_clf_cil(self): + """ + Tests the automatic assignation of new + IncrementalClassifier parameters to the optimizer + """ + model, criterion, benchmark = self.init_scenario(multi_task=False) + + g1 = [] + g2 = [] + for n, p in model.named_parameters(): + if "classifier" in n: + g1.append(p) + else: + g2.append(p) + + optimizer = SGD([{"params": g1, "lr": 0.1}, {"params": g2, "lr": 0.05}]) + + strategy = Naive( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=64, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) + + for experience in benchmark.train_stream: + strategy.train(experience) + + for n, p in model.named_parameters(): + assert self._is_param_in_optimizer(p, strategy.optimizer) + if "classifier" in n: + self.assertEqual( + self._is_param_in_optimizer_group(p, strategy.optimizer), 0 + ) + else: + self.assertEqual( + self._is_param_in_optimizer_group(p, strategy.optimizer), 1 + ) + + def 
test_optimizer_groups_manual_addition(self): + """ + Tests the manual addition of a new parameter group + mixed with existing MultiHeadClassifier + """ + model, criterion, benchmark = self.init_scenario(multi_task=True) + + g1 = [] + g2 = [] + for n, p in model.named_parameters(): + if "classifier" in n: + g1.append(p) + else: + g2.append(p) + + optimizer = SGD([{"params": g1, "lr": 0.1}, {"params": g2, "lr": 0.05}]) + + strategy = Naive( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=64, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) + + experience_0 = benchmark.train_stream[0] + experience_1 = benchmark.train_stream[1] + + strategy.train(experience_0) + + # Add some new parameter and assign it manually to param group + model.new_module1 = torch.nn.Linear(10, 10) + model.new_module2 = torch.nn.Linear(10, 10) + strategy.optimizer.param_groups[1]["params"] += list( + model.new_module1.parameters() + ) + strategy.optimizer.add_param_group( + {"params": list(model.new_module2.parameters()), "lr": 0.001} + ) + + # Also add one but to a new param group + + strategy.train(experience_1) + + for n, p in model.named_parameters(): + assert self._is_param_in_optimizer(p, strategy.optimizer) + if "classifier" in n: + self.assertEqual( + self._is_param_in_optimizer_group(p, strategy.optimizer), 0 + ) + elif "new_module2" in n: + self.assertEqual( + self._is_param_in_optimizer_group(p, strategy.optimizer), 2 + ) + else: + self.assertEqual( + self._is_param_in_optimizer_group(p, strategy.optimizer), 1 + ) + + def test_optimizer_groups_rename(self): + """ + Tests the correct reassignation to + existing parameter groups after + parameter renaming + """ + model, criterion, benchmark = self.init_scenario(multi_task=False) + + g1 = [] + g2 = [] + for n, p in model.named_parameters(): + if "classifier" in n: + g1.append(p) + else: + g2.append(p) + + optimizer = SGD([{"params": g1, "lr": 0.1}, {"params": g2, "lr": 0.05}]) + + strategy = Naive( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=64, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) + + strategy.make_optimizer() + + # Check parameter groups + for n, p in model.named_parameters(): + assert self._is_param_in_optimizer(p, strategy.optimizer) + if "classifier" in n: + self.assertEqual( + self._is_param_in_optimizer_group(p, strategy.optimizer), 0 + ) + else: + self.assertEqual( + self._is_param_in_optimizer_group(p, strategy.optimizer), 1 + ) + + # Rename parameters + strategy.model = TorchWrapper(strategy.model) + + strategy.make_optimizer() + + # Check parameter groups are still the same + for n, p in model.named_parameters(): + assert self._is_param_in_optimizer(p, strategy.optimizer) + if "classifier" in n: + self.assertEqual( + self._is_param_in_optimizer_group(p, strategy.optimizer), 0 + ) + else: + self.assertEqual( + self._is_param_in_optimizer_group(p, strategy.optimizer), 1 + ) + + # Needs torch 2.0 ? 
+ def test_checkpointing(self): + model, criterion, benchmark = self.init_scenario(multi_task=True) + optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9) + strategy = Naive( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=64, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) + experience_0 = benchmark.train_stream[0] + strategy.train(experience_0) + old_state = copy.deepcopy(strategy.optimizer.state) + save_checkpoint(strategy, "./checkpoint.pt") + + del strategy + + model, criterion, benchmark = self.init_scenario(multi_task=True) + optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9) + strategy = Naive( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=64, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) + strategy, exp_counter = maybe_load_checkpoint( + strategy, "./checkpoint.pt", strategy.device + ) + + # Check that the state has been well serialized + self.assertEqual(len(strategy.optimizer.state), len(old_state)) + for (key_new, value_new_dict), (key_old, value_old_dict) in zip( + strategy.optimizer.state.items(), old_state.items() + ): + self.assertTrue(torch.equal(key_new, key_old)) + + value_new = value_new_dict["momentum_buffer"] + value_old = value_old_dict["momentum_buffer"] + + # Empty state + if len(value_new) == 0 or len(value_old) == 0: + self.assertTrue(len(value_new) == len(value_old)) + else: + self.assertTrue(torch.equal(value_new, value_old)) + + experience_1 = benchmark.train_stream[1] + strategy.train(experience_1) + os.remove("./checkpoint.pt") + + def _test_optimizer(self, strategy): + # Add a parameter + module = torch.nn.Linear(10, 10) + param1 = list(module.parameters())[0] + strategy.make_optimizer() + self.assertFalse(self._is_param_in_optimizer(param1, strategy.optimizer)) + strategy.model.add_module("new_module", module) + strategy.make_optimizer() + self.assertTrue(self._is_param_in_optimizer(param1, strategy.optimizer)) + # Remove a parameter + del strategy.model.new_module + + strategy.make_optimizer(remove_params=False) + self.assertTrue(self._is_param_in_optimizer(param1, strategy.optimizer)) + + strategy.make_optimizer(remove_params=True) + self.assertFalse(self._is_param_in_optimizer(param1, strategy.optimizer)) + + def _test_optimizer_state(self, strategy): + # Add Two modules + module1 = torch.nn.Linear(10, 10) + module2 = torch.nn.Linear(10, 10) + param1 = list(module1.parameters())[0] + param2 = list(module2.parameters())[0] + strategy.model.add_module("new_module1", module1) + strategy.model.add_module("new_module2", module2) + + strategy.make_optimizer(remove_params=True) + + self.assertTrue(self._is_param_in_optimizer(param1, strategy.optimizer)) + self.assertTrue(self._is_param_in_optimizer(param2, strategy.optimizer)) + + # Make an operation + self._optimizer_op(strategy.optimizer, module1.weight + module2.weight) + + if len(strategy.optimizer.state) > 0: + assert param1 in strategy.optimizer.state + assert param2 in strategy.optimizer.state + + # Remove one module + del strategy.model.new_module1 + + strategy.make_optimizer(remove_params=True) + + # Make an operation + self._optimizer_op(strategy.optimizer, module1.weight + module2.weight) + + if len(strategy.optimizer.state) > 0: + assert param1 not in strategy.optimizer.state + assert param2 in strategy.optimizer.state + + # Change one module size + strategy.model.new_module2 = torch.nn.Linear(10, 5) + strategy.make_optimizer(remove_params=True) + + # Make an operation + 
self._optimizer_op(strategy.optimizer, module1.weight + module2.weight) + + if len(strategy.optimizer.state) > 0: + assert param1 not in strategy.optimizer.state + assert param2 not in strategy.optimizer.state + + def _optimizer_op(self, optimizer, param): + optimizer.zero_grad() + loss = torch.mean(param) + loss.backward() + optimizer.step() + + def get_model(self, multi_task=False): + if multi_task: + model = MTSimpleMLP(input_size=6, hidden_size=10) + else: + model = SimpleMLP(input_size=6, hidden_size=10) + model.classifier = IncrementalClassifier(10, 1) + return model
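For reference, the group-inference mechanism introduced in dynamic_optimizers.py can be exercised on its own. The sketch below is illustrative only: it relies on the private _ParameterGroupStructure helper added by this patch, and the parameter names are made up to mimic a MultiHeadClassifier. Leaves whose group is unknown (None) inherit the group propagated through the dotted-name tree from their siblings, which is how newly created heads end up in the same parameter group as the existing classifier parameters.

    from avalanche.models.dynamic_optimizers import _ParameterGroupStructure

    # Known mapping of parameter names to optimizer param-group indexes;
    # the second head was just created by adaptation, so its group is None.
    name_groups = {
        "features.0.weight": 1,
        "features.0.bias": 1,
        "classifier.classifiers.0.classifier.weight": 0,
        "classifier.classifiers.0.classifier.bias": 0,
        "classifier.classifiers.1.classifier.weight": None,
        "classifier.classifiers.1.classifier.bias": None,
    }

    structure = _ParameterGroupStructure(name_groups, verbose=True)

    # The new head sits under "classifier", whose known leaves all belong
    # to group 0, so it is resolved to group 0 as well.
    assert structure["classifier.classifiers.1.classifier.weight"].single_group == 0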