Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updatable objects #1633

Merged
merged 19 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions avalanche/_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,12 @@ def decorator(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
    # Use the "once" filter so the deprecation message is emitted a single
    # time per call site instead of on every invocation (the previous
    # "always" + reset-to-"default" dance spammed the warning and clobbered
    # the user's own warning filters).
    warnings.simplefilter("once", DeprecationWarning)
    warnings.warn(
        msg.format(name=func.__name__, version=version, reason=reason),
        category=DeprecationWarning,
        stacklevel=2,  # point the warning at the caller, not this wrapper
    )
    return func(*args, **kwargs)

return wrapper
Expand Down
14 changes: 11 additions & 3 deletions avalanche/benchmarks/scenarios/online.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(
shuffle: bool = True,
drop_last: bool = False,
access_task_boundaries: bool = False,
seed: int = None,
) -> None:
"""Returns a lazy stream generated by splitting an experience into
smaller ones.
Expand All @@ -181,7 +182,8 @@ def __init__(
:param experience_size: The experience size (number of instances).
:param shuffle: If True, instances will be shuffled before splitting.
:param drop_last: If True, the last mini-experience will be dropped if
not of size `experience_size`
not of size `experience_size`.
:param seed: random seed for shuffling the data if `shuffle == True`.
:return: The list of datasets that will be used to create the
mini-experiences.
"""
Expand All @@ -190,10 +192,12 @@ def __init__(
self.shuffle = shuffle
self.drop_last = drop_last
self.access_task_boundaries = access_task_boundaries
self.seed = seed

# we need to fix the seed because repeated calls to the generator
# must return the same order every time.
self.seed = random.randint(0, 2**32 - 1)
if seed is None:
self.seed = random.randint(0, 2**32 - 1)

def __iter__(self) -> Generator[OnlineCLExperience, None, None]:
exp_dataset = self.experience.dataset
Expand Down Expand Up @@ -250,13 +254,15 @@ def _default_online_split(
access_task_boundaries: bool,
exp: DatasetExperience,
size: int,
seed: int,
):
return FixedSizeExperienceSplitter(
experience=exp,
experience_size=size,
shuffle=shuffle,
drop_last=drop_last,
access_task_boundaries=access_task_boundaries,
seed=seed,
)


Expand All @@ -272,6 +278,7 @@ def split_online_stream(
]
] = None,
access_task_boundaries: bool = False,
seed: int = None,
) -> CLStream[DatasetExperience[TCLDataset]]:
"""Split a stream of large batches to create an online stream of small
mini-batches.
Expand Down Expand Up @@ -300,6 +307,7 @@ def split_online_stream(
A good starting to understand the mechanism is to look at the
implementation of the standard splitting function
:func:`fixed_size_experience_split`.
:param seed: random seed used for shuffling by the default splitter.
:return: A lazy online stream with experiences of size `experience_size`.
"""

Expand All @@ -308,7 +316,7 @@ def split_online_stream(
# However, MyPy does not understand what a partial is -_-
def default_online_split_wrapper(e, e_sz):
    # Closure (instead of functools.partial) because MyPy cannot
    # type-check partials. Forwards the shared splitting options plus the
    # seed so repeated iterations of the lazy stream reproduce the same
    # shuffling order.
    return _default_online_split(
        shuffle, drop_last, access_task_boundaries, e, e_sz, seed=seed
    )

split_strategy = default_online_split_wrapper
Expand Down
1 change: 1 addition & 0 deletions avalanche/benchmarks/scenarios/validation_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def random_validation_split_strategy_wrapper(data):

# don't drop classes-timeline for compatibility with old API
e0 = next(iter(train_stream))

if hasattr(e0, "dataset") and hasattr(e0.dataset, "targets"):
train_stream = with_classes_timeline(train_stream)
valid_stream = with_classes_timeline(valid_stream)
Expand Down
142 changes: 141 additions & 1 deletion avalanche/core.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,153 @@
# TODO: doc
from abc import ABC
from typing import Any, TypeVar, Generic
from typing import Any, TypeVar, Generic, Protocol, runtime_checkable
from typing import TYPE_CHECKING

from avalanche.benchmarks import CLExperience

if TYPE_CHECKING:
from avalanche.training.templates.base import BaseTemplate

Template = TypeVar("Template", bound="BaseTemplate")


class Agent:
    """Avalanche Continual Learning Agent.

    The agent stores the state needed by continual learning training methods,
    such as optimizers, models, and regularization losses.
    You can add any objects as attributes dynamically:

    .. code-block::

        agent = Agent()
        agent.replay = ReservoirSamplingBuffer(max_size=200)
        agent.loss = MaskedCrossEntropy()
        agent.reg_loss = LearningWithoutForgetting(alpha=1, temperature=2)
        agent.model = my_model
        agent.opt = SGD(agent.model.parameters(), lr=0.001)
        agent.scheduler = ExponentialLR(agent.opt, gamma=0.999)

    Many CL objects need to perform some operation before or after training
    on each experience. This is supported via the `Adaptable` Protocol,
    which requires the `pre_adapt` and `post_adapt` methods.
    To call the pre/post adaptation you can implement your training loop
    like in the following example:

    .. code-block::

        def train(agent, exp):
            agent.pre_adapt(exp)
            # do training here
            agent.post_adapt(exp)

    Objects that implement the `Adaptable` Protocol will be called by the
    Agent.

    You can also add additional functionality to the adaptation phases with
    hooks. For example:

    .. code-block::

        agent.add_pre_hooks(
            lambda a, e: update_optimizer(
                a.opt,
                new_params={},
                optimized_params=dict(a.model.named_parameters()),
            )
        )
        # we update the lr scheduler after each experience (not every epoch!)
        agent.add_post_hooks(lambda a, e: a.scheduler.step())
    """

    def __init__(self, verbose: bool = False):
        """Init.

        :param verbose: If True, print every time an adaptable object or hook
            is called during the adaptation. Useful for debugging.
        """
        # Objects (attributes) that expose pre_adapt/post_adapt.
        self._updatable_objects = []
        self.verbose = verbose
        # User-registered callables run after the adaptable objects.
        self._pre_hooks = []
        self._post_hooks = []

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        # Automatically track any attribute that implements (part of) the
        # Adaptable protocol so adaptation calls can be forwarded to it.
        if hasattr(value, "pre_adapt") or hasattr(value, "post_adapt"):
            self._updatable_objects.append(value)
            if self.verbose:
                print("Added updatable object ", value)

    def pre_adapt(self, exp):
        """Pre-adaptation.

        Remember to call this before training on a new experience.

        :param exp: current experience
        """
        # Adaptable objects first, then user hooks.
        for uo in self._updatable_objects:
            if hasattr(uo, "pre_adapt"):
                uo.pre_adapt(self, exp)
                if self.verbose:
                    print("pre_adapt ", uo)
        for foo in self._pre_hooks:
            if self.verbose:
                print("pre_adapt hook ", foo)
            foo(self, exp)

    def post_adapt(self, exp):
        """Post-adaptation.

        Remember to call this after training on a new experience.

        :param exp: current experience
        """
        # Adaptable objects first, then user hooks.
        for uo in self._updatable_objects:
            if hasattr(uo, "post_adapt"):
                uo.post_adapt(self, exp)
                if self.verbose:
                    print("post_adapt ", uo)
        for foo in self._post_hooks:
            if self.verbose:
                print("post_adapt hook ", foo)
            foo(self, exp)

    def add_pre_hooks(self, foo):
        """Add a pre-adaptation hook.

        Hooks take two arguments: `<agent, experience>`.

        :param foo: the hook function
        """
        self._pre_hooks.append(foo)

    def add_post_hooks(self, foo):
        """Add a post-adaptation hook.

        Hooks take two arguments: `<agent, experience>`.

        :param foo: the hook function
        """
        self._post_hooks.append(foo)


class Adaptable(Protocol):
    """Adaptable objects Protocol.

    This class documents the Adaptable objects API, but it is not necessary
    for an object to inherit from it, since the `Agent` searches for the
    methods dynamically.

    Adaptable objects are objects that require their `pre_adapt` and
    `post_adapt` methods to be run before (and after, respectively) training
    on each experience.

    Adaptable objects may implement only the method they need, since the
    `Agent` looks the methods up dynamically and calls each one only if it
    is implemented.
    """

    def pre_adapt(self, agent: Agent, exp: CLExperience):
        pass

    def post_adapt(self, agent: Agent, exp: CLExperience):
        pass


class BasePlugin(Generic[Template], ABC):
"""ABC for BaseTemplate plugins.

Expand Down
Loading
Loading