Add support for multiple param groups #1609

Merged
merged 9 commits on Mar 27, 2024
316 changes: 263 additions & 53 deletions avalanche/models/dynamic_optimizers.py
@@ -16,14 +16,220 @@
"""
from collections import defaultdict

import numpy as np

from avalanche._annotations import deprecated

colors = {
"END": "\033[0m",
0: "\033[32m",
1: "\033[33m",
2: "\033[34m",
3: "\033[35m",
4: "\033[36m",
}
colors[None] = colors["END"]


def _map_optimized_params(optimizer, parameters, old_params=None):
"""
Establishes a mapping between a list of named parameters and the parameters
that are in the optimizer, additionally,
returns the lists of:

returns:
new_parameters: Names of new parameters in the provided "parameters" argument,
that are not in the old parameters
changed_parameters: Names and indexes of parameters that have changed (grown, shrink)
not_found_in_parameters: List of indexes of optimizer parameters
that are not found in the provided parameters
"""

if old_params is None:
old_params = {}

group_mapping = defaultdict(dict)
new_parameters = []

found_indexes = []
changed_parameters = []
for group in optimizer.param_groups:
params = group["params"]
found_indexes.append(np.zeros(len(params)))

for n, p in parameters.items():
gidx = None
pidx = None

# Find param in optimizer
found = False

if n in old_params:
search_id = id(old_params[n])
else:
search_id = id(p)

for group_idx, group in enumerate(optimizer.param_groups):
params = group["params"]
for param_idx, po in enumerate(params):
if id(po) == search_id:
gidx = group_idx
pidx = param_idx
found = True
# Update found indexes
assert found_indexes[group_idx][param_idx] == 0
found_indexes[group_idx][param_idx] = 1
break
if found:
break

if not found:
new_parameters.append(n)

if search_id != id(p):
if found:
changed_parameters.append((n, gidx, pidx))

if len(optimizer.param_groups) > 1:
group_mapping[n] = gidx
else:
group_mapping[n] = 0

not_found_in_parameters = [np.where(arr == 0)[0] for arr in found_indexes]

return (
group_mapping,
changed_parameters,
new_parameters,
not_found_in_parameters,
)
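For illustration, a minimal sketch of what this mapping looks like for an optimizer that already holds every parameter of a small two-part model split across two parameter groups (the model and learning rates below are assumptions made for the example, not part of the PR):

import torch
from torch.optim import SGD

model = torch.nn.ModuleDict(
    {"backbone": torch.nn.Linear(4, 4), "head": torch.nn.Linear(4, 2)}
)
optimizer = SGD(
    [
        {"params": model["backbone"].parameters()},
        {"params": model["head"].parameters(), "lr": 0.1},
    ],
    lr=0.01,
)
mapping, changed, new, missing = _map_optimized_params(
    optimizer, dict(model.named_parameters())
)
# mapping -> {"backbone.weight": 0, "backbone.bias": 0,
#             "head.weight": 1, "head.bias": 1}
# changed, new -> [] (no parameter was replaced or added)
# missing -> one empty index array per parameter group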


def _build_tree_from_name_groups(name_groups):
root = _TreeNode("") # Root node
node_mapping = {}

# Iterate through each string in the list
for name, group in name_groups.items():
components = name.split(".")
current_node = root

# Traverse the tree and construct nodes for each component
for component in components:
if component not in current_node.children:
current_node.children[component] = _TreeNode(
component, parent=current_node
)
current_node = current_node.children[component]

# Update the groups for the leaf node
if group is not None:
current_node.groups |= set([group])
current_node.update_groups_upwards() # Inform parent about the groups

# Update leaf node mapping dict
node_mapping[name] = current_node

# This will resolve nodes without group
root.update_groups_downwards()
return root, node_mapping
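As a small illustration (the dotted names below are hypothetical), every internal node ends up holding the union of the groups seen in the leaves beneath it, which is what later allows group inference for new leaves:

root, nodes = _build_tree_from_name_groups(
    {"features.conv.weight": 0, "features.conv.bias": 0, "head.weight": 1}
)
assert nodes["features.conv.weight"].global_name() == "features.conv.weight"
assert root.children["features"].groups == {0}
assert root.groups == {0, 1}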


def _print_group_information(node, prefix=""):
# Print the groups for the current node

if len(node.groups) == 1:
pstring = (
colors[list(node.groups)[0]]
+ f"{prefix}{node.global_name()}: {node.groups}"
+ colors["END"]
)
print(pstring)
else:
print(f"{prefix}{node.global_name()}: {node.groups}")

# Recursively print group information for children nodes
for child_name, child_node in node.children.items():
_print_group_information(child_node, prefix + " ")


class _ParameterGroupStructure:
"""
Structure used for the resolution of unknown parameter groups,
stores parameters as a tree and propagates parameter groups from leaves of
the same hierarchical level
"""

def __init__(self, name_groups, verbose=False):
# Here we rebuild the tree
self.root, self.node_mapping = _build_tree_from_name_groups(name_groups)
if verbose:
_print_group_information(self.root)

def __getitem__(self, name):
return self.node_mapping[name]
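A brief usage sketch (parameter names are hypothetical): constructing the structure with verbose=True prints each node's dotted name together with its resolved group set, color-coded when a single group was identified, and indexing by name returns the corresponding tree node:

structure = _ParameterGroupStructure(
    {"backbone.weight": 0, "classifier.weight": 1}, verbose=True
)
assert structure["classifier.weight"].groups == {1}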


class _TreeNode:
def __init__(self, name, parent=None):
self.name = name
self.children = {}
self.groups = set() # Set of groups (represented by index) this node belongs to
self.parent = parent # Reference to the parent node
if parent:
# Inform the parent about the new child node
parent.add_child(self)

def add_child(self, child):
self.children[child.name] = child

def update_groups_upwards(self):
if self.parent:
if self.groups != {None}:
self.parent.groups |= (
self.groups
) # Update parent's groups with the child's groups
self.parent.update_groups_upwards() # Propagate the group update to the parent

def update_groups_downwards(self, new_groups=None):
# If this node has no group yet, absorb the groups passed down by the parent
if len(self.groups) == 0 and new_groups is not None:
self.groups = self.groups.union(new_groups)

# Then propagate the (possibly inherited) groups to the children
if len(self.groups) > 0:
for key, children in self.children.items():
children.update_groups_downwards(self.groups)

def global_name(self, initial_name=None):
"""
Returns global node name
"""
if initial_name is None:
initial_name = self.name
elif self.name != "":
initial_name = ".".join([self.name, initial_name])

if self.parent:
return self.parent.global_name(initial_name)
else:
return initial_name

@property
def single_group(self):
if len(self.groups) == 0:
raise AttributeError(
f"Could not identify group for this node {self.global_name()}"
)
elif len(self.groups) > 1:
raise AttributeError(
f"No unique group found for this node {self.global_name()}"
)
else:
return list(self.groups)[0]
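This property is what the group inference ultimately relies on when placing a brand-new parameter, and it fails loudly when no unique group can be determined. A small sketch with hypothetical names:

root, nodes = _build_tree_from_name_groups(
    {"encoder.weight": 0, "decoder.weight": 1, "decoder.new_bias": None}
)
# "decoder.new_bias" had no known group, but inherits {1} from "decoder"
assert nodes["decoder.new_bias"].single_group == 1
# The root mixes groups {0, 1}, so asking it for a single group raises
# an AttributeError ("No unique group found ...")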


@deprecated(0.6, "update_optimizer with optimized_params=None is now used instead")
def reset_optimizer(optimizer, model):
"""Reset the optimizer to update the list of learnable parameters.

@@ -53,7 +259,14 @@ def reset_optimizer(optimizer, model):
return optimized_param_id


def update_optimizer(
optimizer,
new_params,
optimized_params=None,
reset_state=False,
remove_params=False,
verbose=False,
):
"""Update the optimizer by adding new parameters,
removing removed parameters, and adding new parameters
to the optimizer, for instance after model has been adapted
@@ -64,72 +277,69 @@ def update_optimizer(optimizer, new_params, optimized_params, reset_state=False)

:param new_params: Dict (name, param) of new parameters
:param optimized_params: Dict (name, param) of
currently optimized parameters
:param reset_state: Whether to reset the optimizer's state (i.e., momentum).
Defaults to False.
:param remove_params: Whether to remove parameters that were in the optimizer
but are not found in new parameters. For safety reasons,
defaults to False.
:param verbose: If True, prints information about inferred
parameter groups for new params

:return: Dict (name, param) of optimized parameters
"""
(
group_mapping,
changed_parameters,
new_parameters,
not_found_in_parameters,
) = _map_optimized_params(optimizer, new_params, old_params=optimized_params)

# Change reference to already existing parameters
# e.g., a growing IncrementalClassifier
for name, group_idx, param_idx in changed_parameters:
group = optimizer.param_groups[group_idx]
old_p = optimized_params[name]
new_p = new_params[name]
group["params"][param_idx] = new_p
if old_p in optimizer.state:
optimizer.state.pop(old_p)
optimizer.state[new_p] = {}

# Remove parameters that are not here anymore
# This should not happen in most use cases
if remove_params:
for group_idx, idx_list in enumerate(not_found_in_parameters):
for j in sorted(idx_list, key=lambda x: x, reverse=True):
p = optimizer.param_groups[group_idx]["params"][j]
optimizer.param_groups[group_idx]["params"].pop(j)
if p in optimizer.state:
optimizer.state.pop(p)
del p

param_structure = _ParameterGroupStructure(group_mapping, verbose=verbose)

# Add newly added parameters (e.g., Multitask, PNN)
# to the parameter group inferred from the module name hierarchy
for key in new_parameters:
new_p = new_params[key]
group = param_structure[key].single_group
optimizer.param_groups[group]["params"].append(new_p)
optimized_params[key] = new_p
optimizer.state[new_p] = {}

if reset_state:
optimizer.state = defaultdict(dict)

return new_params
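To make the intended call pattern concrete, here is a minimal usage sketch, assuming a plain torch model whose adaptation step replaces its output layer (the Linear models below stand in for an Avalanche module such as a growing IncrementalClassifier; they are illustrative assumptions, not part of this PR):

import torch
from torch.optim import SGD

model = torch.nn.Linear(8, 2)
optimizer = SGD(model.parameters(), lr=0.01)

# First call: discover the parameters currently held by the optimizer.
optimized = update_optimizer(optimizer, dict(model.named_parameters()))

# Model adaptation replaces the output layer with a larger one.
model = torch.nn.Linear(8, 5)

# Second call: re-sync the optimizer. Changed tensors replace the old ones in
# their original group, parameters that disappeared are dropped because
# remove_params=True, and brand-new parameters would be appended to the group
# inferred from their position in the module hierarchy.
optimized = update_optimizer(
    optimizer,
    dict(model.named_parameters()),
    optimized_params=optimized,
    remove_params=True,
)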


@deprecated(
0.6,
"parameters have to be added manually to the optimizer in an existing or a new parameter group",
)
def add_new_params_to_optimizer(optimizer, new_params):
"""Add new parameters to the trainable parameters.
