From 973349f0afe21a6cf17f02b3bd8821426068b38f Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 18 Dec 2024 15:29:45 +0000 Subject: [PATCH 01/11] Rollout Schedulers --- .../training/diagnostics/mlflow/logger.py | 6 +- .../training/schedulers/rollout/__init__.py | 167 ++++++++ .../training/schedulers/rollout/indexed.py | 172 +++++++++ .../training/schedulers/rollout/randomise.py | 364 ++++++++++++++++++ .../training/schedulers/rollout/stepped.py | 155 ++++++++ src/anemoi/training/train/forecaster.py | 27 +- src/anemoi/training/train/train.py | 2 +- tests/schedulers/__init__.py | 8 + tests/schedulers/rollout/__init__.py | 8 + 9 files changed, 891 insertions(+), 18 deletions(-) create mode 100644 src/anemoi/training/schedulers/rollout/__init__.py create mode 100644 src/anemoi/training/schedulers/rollout/indexed.py create mode 100644 src/anemoi/training/schedulers/rollout/randomise.py create mode 100644 src/anemoi/training/schedulers/rollout/stepped.py create mode 100644 tests/schedulers/__init__.py create mode 100644 tests/schedulers/rollout/__init__.py diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py index 03a4b6de..0f6deeb8 100644 --- a/src/anemoi/training/diagnostics/mlflow/logger.py +++ b/src/anemoi/training/diagnostics/mlflow/logger.py @@ -196,12 +196,12 @@ def _log_collector(self) -> None: log_capture_time_counter = 0 def _store_buffered_logs(self) -> None: - _buffer_size = self._io_buffer.tell() - if not _buffer_size: + buffer_size = self._io_buffer.tell() + if not buffer_size: return self._io_buffer.seek(0) # read and reset the buffer - data = self._io_buffer.read(_buffer_size) + data = self._io_buffer.read(buffer_size) self._io_buffer.seek(0) # handle the buffered data and store # split lines and keep \n at the end of each line diff --git a/src/anemoi/training/schedulers/rollout/__init__.py b/src/anemoi/training/schedulers/rollout/__init__.py new file mode 100644 index 00000000..9da65b2e --- /dev/null +++ b/src/anemoi/training/schedulers/rollout/__init__.py @@ -0,0 +1,167 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod +from typing import Literal + + +class RolloutScheduler(ABC): + """ + `RolloutScheduler` is an abstract base class for rollout schedulers. + + A rollout scheduler is an object that manages the rollout of a training loop. + + ```python + RollSched = RolloutScheduler() + + for epoch in range(20): + for step in range(100): + y = model(x, rollout = RollSched.rollout) + + RollSched.step() + RollSched.step_epoch() + ``` + """ + + _epoch: int = 0 + _step: int = 0 + + @property + @abstractmethod + def rollout(self) -> int: + """Get the current rollout value.""" + error_msg = "`rollout` property not implemented by parent class." + raise NotImplementedError(error_msg) + + @property + @abstractmethod + def maximum_rollout(self) -> int: + """Get maximum rollout possible.""" + error_msg = "`maximum_rollout` property not implemented by parent class." 
+ raise NotImplementedError(error_msg) + + @property + def current_maximum(self) -> int: + """Get the current maximum rollout value.""" + return self.rollout + + def __int__(self) -> int: + return int(self.rollout) + + def rollout_at(self, step: int | None = None, epoch: int | None = None) -> int: + """ + Get the rollout at a specific step and epoch. + + Parameters + ---------- + step : int, optional + Step value to override with, by default None + epoch : int, optional + Epoch value to override with, by default None + + Returns + ------- + int + Rollout value at the specified step and epoch. + """ + step_ = self._step + epoch_ = self._epoch + + self._step = step if step is not None else step_ + self._epoch = epoch if epoch is not None else epoch_ + + rollout = self.rollout + + self._step = step_ + self._epoch = epoch_ + + return rollout + + def step(self, count: int = 1, /) -> None: + """Step the scheduler by a count.""" + self._step += count + + def step_epoch(self, count: int = 1, /) -> None: + """Step the scheduler by a count of epochs.""" + self._epoch += count + + def count(self, every_n: int, step_type: Literal["step", "epoch"]) -> int: + """ + Get the count of steps or epochs. + + Parameters + ---------- + every_n : int + Every n steps or epochs. + step_type : _type_, optional + Which to count, by default Literal['step', 'epoch'] + + Returns + ------- + int + Count of steps or epochs. + + Raises + ------ + ValueError + If the step_type is not 'step' or 'epoch'. + """ + if step_type == "epoch": + return (self._epoch - 1) // every_n + if step_type == "step": + return self._step // every_n + + error_msg = "Invalid `step_type`. Must be 'epoch' or 'step'." + raise ValueError(error_msg) + + @abstractmethod + def description(self) -> str: + """Description of the rollout scheduler.""" + error_msg = "`description` method not implemented by parent class." + raise NotImplementedError(error_msg) + + +class Static(RolloutScheduler): + """`Static` is a rollout scheduler that always returns the same rollout value.""" + + def __init__(self, rollout_value: int): + """ + `Static` is a rollout scheduler that always returns the same rollout value. + + Parameters + ---------- + rollout_value : int + Rollout value to return. + + Example + ------- + ```python + from anemoi.training.schedulers.rollout import Static + RollSched = Static(rollout_value = 5) + RollSched.rollout_at(epoch = 1) + # 5 + RollSched.rollout_at(epoch = 5) + # 5 + ``` + """ + self._rollout_value = rollout_value + + @property + def rollout(self) -> int: + return self._rollout_value + + @property + def maximum_rollout(self) -> int: + return self._rollout_value + + def description(self) -> str: + return f"Static rollout value of {self._rollout_value}." diff --git a/src/anemoi/training/schedulers/rollout/indexed.py b/src/anemoi/training/schedulers/rollout/indexed.py new file mode 100644 index 00000000..307d76f8 --- /dev/null +++ b/src/anemoi/training/schedulers/rollout/indexed.py @@ -0,0 +1,172 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+
+from typing import Any
+from typing import Literal
+
+from anemoi.training.schedulers.rollout import RolloutScheduler
+
+
+def get_closest_key(dictionary: dict[int, Any], key: int) -> int:
+    """
+    Get the closest int key in a dictionary to a given key.
+
+    Where the closest key is the one with the smallest absolute difference
+    and the key is less than or equal to the given key.
+
+    Parameters
+    ----------
+    dictionary : dict[int, Any]
+        Dictionary to search.
+    key : int
+        Key to search for.
+
+    Returns
+    -------
+    int
+        Closest key in the dictionary.
+    """
+    return min(dictionary.keys(), key=lambda x: abs(x - key) if x <= key else float("inf"))
+
+
+class PositionalIndexed(RolloutScheduler):
+    """
+    `PositionalIndexed` retrieves the rollout value from a list of rollouts based on the current epoch or step.
+
+    Once the list is exhausted, the rollout will remain at the last value.
+    """
+
+    def __init__(
+        self,
+        rollouts: list[int],
+        num_times_per_element: int = 1,
+        step_type: Literal["step", "epoch"] = "epoch",
+    ):
+        """
+        `PositionalIndexed` retrieves the rollout value from a list of rollouts based on the current epoch or step.
+
+        Once the list is exhausted, the rollout will remain at the last value.
+
+        Parameters
+        ----------
+        rollouts : list[int]
+            List of rollout values.
+        num_times_per_element : int, optional
+            Number of times to remain at an element, by default 1
+        step_type : Literal['step', 'epoch'], optional
+            Type of step, either 'epoch' or 'step'.
+            by default 'epoch'.
+
+        Example
+        -------
+        ```python
+        from anemoi.training.schedulers.rollout.indexed import PositionalIndexed
+
+        RollSched = PositionalIndexed(rollouts = [1, 2, 3, 4], num_times_per_element = 2, step_type = 'epoch')
+        RollSched.rollout_at(epoch = 1)
+        # 1
+        RollSched.rollout_at(epoch = 2)
+        # 1
+        RollSched.rollout_at(epoch = 3)
+        # 2
+        ```
+        """
+        super().__init__()
+        self._rollouts = rollouts
+        self._num_times_per_element = num_times_per_element
+        self._step_type = step_type
+
+    @property
+    def rollout(self) -> int:
+        count = self.count(self._num_times_per_element, self._step_type)
+        return self._rollouts[min(len(self._rollouts) - 1, count)]
+
+    @property
+    def maximum_rollout(self) -> int:
+        return max(self._rollouts)
+
+
+class EpochPositionalIndexed(PositionalIndexed):
+    """Epoch based PositionalIndexed."""
+
+    def __init__(self, rollouts: list[int]):
+        super().__init__(rollouts, step_type="epoch")
+
+
+class StepPositionalIndexed(PositionalIndexed):
+    """Step based PositionalIndexed."""
+
+    def __init__(self, rollouts: list[int]):
+        super().__init__(rollouts, step_type="step")
+
+
+class Lookup(RolloutScheduler):
+    """
+    `Lookup` retrieves the rollout value from a dictionary of rollouts based on the current epoch or step.
+
+    It will return the closest key that is less than or equal to the current epoch or step.
+    """
+
+    def __init__(self, rollouts: dict[int, int], step_type: Literal["step", "epoch"] = "epoch"):
+        """
+        `Lookup` retrieves the rollout value from a dictionary of rollouts based on the current epoch or step.
+
+        It will return the closest key that is less than or equal to the current epoch or step.
+
+        Parameters
+        ----------
+        rollouts : dict[int, int]
+            Dictionary of rollouts.
+        step_type : Literal['step', 'epoch'], optional
+            Type of step, either 'epoch' or 'step'.
+            by default 'epoch'
+
+        Example
+        -------
+        ```python
+        from anemoi.training.schedulers.rollout.indexed import Lookup
+
+        RollSched = Lookup(rollouts = {0: 1, 5: 2, 10: 3}, step_type = 'epoch')
+        RollSched.rollout_at(epoch = 1)
+        # 1
+        RollSched.rollout_at(epoch = 5)
+        # 2
+        ```
+        """
+        super().__init__()
+        self._rollouts = rollouts
+        self._step_type = step_type
+
+    @property
+    def rollout(self) -> int:
+        if self._step_type == "epoch":
+            return self._rollouts.get(get_closest_key(self._rollouts, self._epoch), 1)
+        if self._step_type == "step":
+            return self._rollouts.get(get_closest_key(self._rollouts, self._step), 1)
+
+        error_msg = "Invalid step_type. Must be 'epoch' or 'step'."
+        raise ValueError(error_msg)
+
+    @property
+    def maximum_rollout(self) -> int:
+        return max(self._rollouts.values())
+
+
+class EpochLookup(Lookup):
+    """Epoch based Lookup."""
+
+    def __init__(self, rollouts: dict[int, int]):
+        super().__init__(rollouts, step_type="epoch")
+
+
+class StepLookup(Lookup):
+    """Step based Lookup."""
+
+    def __init__(self, rollouts: dict[int, int]):
+        super().__init__(rollouts, step_type="step")
diff --git a/src/anemoi/training/schedulers/rollout/randomise.py b/src/anemoi/training/schedulers/rollout/randomise.py
new file mode 100644
index 00000000..ec0eef71
--- /dev/null
+++ b/src/anemoi/training/schedulers/rollout/randomise.py
@@ -0,0 +1,364 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+# ruff: noqa: S608
+
+from __future__ import annotations
+
+from typing import Literal
+
+import numpy as np
+import pytorch_lightning as pl
+
+from anemoi.training.schedulers.rollout import RolloutScheduler
+from anemoi.training.schedulers.rollout.indexed import get_closest_key
+from anemoi.training.utils.seeding import get_base_seed
+
+
+class BaseRandom(RolloutScheduler):
+    """BaseRandom Scheduler."""
+
+    def __init__(self):
+        """
+        Initialise the base random rollout scheduler.
+
+        Set the seed from the environment variable `ANEMOI_BASE_SEED` if it exists, otherwise default to 42.
+        """
+        super().__init__()
+
+        try:
+            seed = get_base_seed()
+        except AssertionError:
+            seed = 42
+
+        rnd_seed = pl.seed_everything(seed, workers=True)
+        self.rng = np.random.default_rng(rnd_seed)
+
+    def broadcast(self, value: int) -> None:
+        """
+        Broadcast the rollout value to all processes.
+
+        Parameters
+        ----------
+        value : int
+            Value to broadcast.
+        """
+        # TODO(Harrison Cook): Need to broadcast the rollout to all processes
+
+    def _randomly_pick(self, rollouts: list[int]) -> int:
+        """
+        Randomly pick from a list of rollouts.
+
+        Parameters
+        ----------
+        rollouts : list[int]
+            Rollouts to choose from.
+
+        Returns
+        -------
+        int
+            Randomly selected rollout.
+        """
+        rollout = self.rng.choice(rollouts)
+        self.broadcast(rollout)
+        return rollout
+
+
+class RandomList(BaseRandom):
+    """`RandomList` is a rollout scheduler that randomly selects a rollout from a list of values."""
+
+    def __init__(self, rollouts: list[int]):
+        """
+        RandomList is a rollout scheduler that randomly selects a rollout from a list of values.
+
+        Parameters
+        ----------
+        rollouts : list[int]
+            List of rollouts to choose from.
+
+        Example
+        -------
+        ```python
+        from anemoi.training.schedulers.rollout.randomise import RandomList
+
+        RollSched = RandomList(rollouts = [1, 2, 3, 4, 5])
+        RollSched.rollout_at(epoch = 1)
+        # any value in the list
+        RollSched.rollout_at(epoch = 2)
+        # any value in the list
+        ```
+        """
+        super().__init__()
+        self._rollouts = rollouts
+
+    @property
+    def rollout(self) -> int:
+        return self._randomly_pick(self._rollouts)
+
+    @property
+    def maximum_rollout(self) -> int:
+        return max(self._rollouts)
+
+    @property
+    def current_maximum(self) -> int:
+        return self.maximum_rollout
+
+    def description(self) -> str:
+        return f"Randomly select a rollout from {self._rollouts}"
+
+
+class RandomRange(RandomList):
+    """`RandomRange` is a rollout scheduler that randomly selects a rollout from a range of values."""
+
+    def __init__(self, minimum: int = 1, maximum: int = 1, step: int = 1):
+        """
+        RandomRange is a rollout scheduler that randomly selects a rollout from a range of values.
+
+        Parameters
+        ----------
+        minimum : int, optional
+            Minimum rollout to choose from, by default 1
+        maximum : int, optional
+            Maximum rollout to choose from, by default 1
+        step : int, optional
+            Step size for the range, by default 1
+
+        Example
+        -------
+        ```python
+        from anemoi.training.schedulers.rollout.randomise import RandomRange
+
+        RollSched = RandomRange(minimum = 1, maximum = 5)
+        RollSched.rollout_at(epoch = 1)
+        # any value between 1 and 5
+        RollSched.rollout_at(epoch = 2)
+        # any value between 1 and 5
+        ```
+        """
+        super().__init__(list(range(minimum, maximum + 1, step)))
+
+    def description(self) -> str:
+        range_step = self._rollouts[1] - self._rollouts[0] if len(self._rollouts) > 1 else 1
+        return (
+            "Randomly select a rollout from the "
+            f"{range(min(self._rollouts), max(self._rollouts) + 1, range_step)}"
+        )
+
+
+class IncreasingRandom(BaseRandom):
+    """IncreasingRandom is a rollout scheduler that randomly selects a rollout from an increasing range of values."""
+
+    def __init__(
+        self,
+        minimum: int = 1,
+        maximum: int = 1,
+        range_step: int = 1,
+        every_n: int = 1,
+        increment: int | dict[int, int] = 1,
+        step_type: Literal["step", "epoch"] = "epoch",
+    ):
+        """
+        `IncreasingRandom` is a rollout scheduler that randomly selects a rollout from an increasing range of values.
+
+        Parameters
+        ----------
+        minimum : int, optional
+            Minimum rollout to choose from, by default 1
+        maximum : int, optional
+            Maximum rollout to choose from, can be -1 for no maximum,
+            by default 1.
+        range_step : int, optional
+            Step size for the range, by default 1
+        every_n : int, optional
+            Number of steps or epochs to step the rollout value.
+            If `every_n` is 0, the rollout will stay at `minimum`.
+        increment : int | dict[int, int], optional
+            Value to increment the rollout by every `every_n` steps or epochs, by default 1
+        step_type : Literal['step', 'epoch'], optional
+            Type of step, either 'epoch' or 'step'.
+            by default 'epoch'.
+
+        Example
+        -------
+        ```python
+        from anemoi.training.schedulers.rollout.randomise import IncreasingRandom
+
+        RollSched = IncreasingRandom(minimum = 1, maximum = 10, range_step = 1, every_n = 1)
+        RollSched.rollout_at(epoch = 1)
+        # any value between 1 and 1
+        RollSched.rollout_at(epoch = 2)
+        # any value between 1 and 2
+        ```
+        """
+        super().__init__()
+
+        if maximum <= -1:
+            maximum = float("inf")
+
+        self._minimum = minimum
+        self._maximum = maximum
+        self._range_step = range_step
+        self._every_n = every_n
+        self._increment = increment
+        self._step_type = step_type
+
+    @property
+    def rollout(self) -> int:
+        if self._every_n == 0:
+            return self._minimum
+
+        count_of_n = self.count(self._every_n, self._step_type)
+
+        if isinstance(self._increment, int):
+            maximum_value = self._minimum + self._increment * count_of_n
+        else:
+            sum_of_increments = [
+                self._increment.get(get_closest_key(self._increment, i + 1)) for i in range(count_of_n)
+            ]
+            maximum_value = self._minimum + sum(sum_of_increments)
+
+        rollouts = range(self._minimum, maximum_value + 1, self._range_step)
+
+        return self._randomly_pick(rollouts)
+
+    @property
+    def maximum_rollout(self) -> int:
+        return self._maximum
+
+    @property
+    def current_maximum(self) -> int:
+        if self._every_n == 0 or not isinstance(self._increment, int):
+            return self._maximum
+        return min(self._maximum, self._minimum + self._increment * self.count(self._every_n, self._step_type))
+
+    def description(self) -> str:
+        return (
+            f"Randomly select a rollout from an increasing range, starting at {self._minimum} "
+            f"and increasing by {self._increment} every {self._every_n} {self._step_type}(s), "
+            f"up to a maximum of {self._maximum}"
+        )
+
+
+class EpochIncreasingRandom(IncreasingRandom):
+    """
+    `EpochIncreasingRandom` is a rollout scheduler that randomly selects a rollout from an increasing range of values.
+
+    The maximum is incremented every n epochs.
+    """
+
+    def __init__(
+        self,
+        minimum: int = 1,
+        maximum: int = 1,
+        range_step: int = 1,
+        every_n_epochs: int = 1,
+        increment: int | dict[int, int] = 1,
+    ):
+        """
+        EpochIncreasingRandom is a rollout scheduler that randomly selects a rollout from an increasing range of values.
+
+        The maximum is incremented every n epochs.
+
+        Parameters
+        ----------
+        minimum : int, optional
+            Minimum rollout to choose from, by default 1
+        maximum : int, optional
+            Maximum rollout to choose from, can be -1 for no maximum,
+            by default 1.
+        range_step : int, optional
+            Step size for the range, by default 1
+        every_n_epochs : int, optional
+            Number of epochs to step the rollout value.
+            If `every_n_epochs` is 0, the rollout will stay at `minimum`.
+        increment : int | dict[int, int], optional
+            Value to increment the rollout by `every_n_epochs`, by default 1
+
+        Example
+        -------
+        ```python
+        from anemoi.training.schedulers.rollout.randomise import EpochIncreasingRandom
+
+        RollSched = EpochIncreasingRandom(minimum = 1, maximum = 10, range_step = 1, every_n_epochs = 1, increment = 1)
+        RollSched.rollout_at(epoch = 1)
+        # any value between 1 and 1
+        RollSched.rollout_at(epoch = 2)
+        # any value between 1 and 2
+
+        RollSched = EpochIncreasingRandom(
+            minimum = 1, maximum = 10, range_step = 1,
+            every_n_epochs = 1, increment = {0: 0, 10: 1}
+        )
+        RollSched.rollout_at(epoch = 1)
+        # any value between 1 and 1
+        RollSched.rollout_at(epoch = 9)
+        # any value between 1 and 1
+        RollSched.rollout_at(epoch = 10)
+        # any value between 1 and 2, and then increments of 1
+        ```
+        """
+        super().__init__(minimum, maximum, range_step, every_n_epochs, increment, step_type="epoch")
+
+
+class StepIncreasingRandom(IncreasingRandom):
+    """
+    `StepIncreasingRandom` is a rollout scheduler that randomly selects a rollout from an increasing range of values.
+
+    The maximum is incremented every n steps.
+    """
+
+    def __init__(
+        self,
+        minimum: int = 1,
+        maximum: int = 1,
+        range_step: int = 1,
+        every_n_steps: int = 1,
+        increment: int | dict[int, int] = 1,
+    ):
+        """
+        `StepIncreasingRandom` is a rollout scheduler that randomly selects a rollout from an increasing range of values.
+
+        The maximum is incremented every n steps.
+
+        Parameters
+        ----------
+        minimum : int, optional
+            Minimum rollout to choose from, by default 1
+        maximum : int, optional
+            Maximum rollout to choose from, can be -1 for no maximum,
+            by default 1.
+        range_step : int, optional
+            Step size for the range, by default 1
+        every_n_steps : int, optional
+            Number of steps to step the rollout value.
+            If `every_n_steps` is 0, the rollout will stay at `minimum`.
+        increment : int | dict[int, int], optional
+            Value to increment the rollout by `every_n_steps`, by default 1
+
+        Example
+        -------
+        ```python
+        from anemoi.training.schedulers.rollout.randomise import StepIncreasingRandom
+
+        RollSched = StepIncreasingRandom(minimum = 1, maximum = 10, range_step = 1, every_n_steps = 1, increment = 1)
+        RollSched.rollout_at(step = 1)
+        # any value between 1 and 1
+        RollSched.rollout_at(step = 2)
+        # any value between 1 and 2
+
+        RollSched = StepIncreasingRandom(
+            minimum = 1, maximum = 10, range_step = 1,
+            every_n_steps = 1, increment = {0: 0, 10: 1}
+        )
+        RollSched.rollout_at(step = 1)
+        # any value between 1 and 1
+        RollSched.rollout_at(step = 9)
+        # any value between 1 and 1
+        RollSched.rollout_at(step = 10)
+        # any value between 1 and 2, and then increments of 1
+        ```
+        """
+        super().__init__(minimum, maximum, range_step, every_n_steps, increment, step_type="step")
diff --git a/src/anemoi/training/schedulers/rollout/stepped.py b/src/anemoi/training/schedulers/rollout/stepped.py
new file mode 100644
index 00000000..f07425e1
--- /dev/null
+++ b/src/anemoi/training/schedulers/rollout/stepped.py
@@ -0,0 +1,155 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+from __future__ import annotations + +from typing import Literal + +from anemoi.training.schedulers.rollout import RolloutScheduler +from anemoi.training.schedulers.rollout.indexed import get_closest_key + + +class Stepped(RolloutScheduler): + """`Stepped` is a base rollout scheduler that steps the rollout value at the end of each n steps or epochs.""" + + def __init__( + self, + minimum: int, + maximum: int, + every_n: int, + increment: int | dict[int, int], + step_type: Literal["step", "epoch"] = "epoch", + ): + """ + `SteppedRollout` is a base rollout scheduler that steps the rollout value at the end of each n steps or epochs. + + Parameters + ---------- + minimum : int + Minimum rollout value. + maximum : int + Maximum rollout value. + Can be -1 for no maximum. + every_n : int + Number of steps or epochs to step the rollout value. + If `every_n` is 0, the rollout will stay at `minimum`. + increment : int | dict[int, int], optional + Value to increment the rollout by. + Can be an int or dictionary, where the keys represent the value of `step_type` + and the values represent the increment. + Will round down to the closest key. + i.e. {0: 1, 10: 2} will increment by 1 until 10, then by 2. + by default 1. + step_type : Literal['step', 'epoch'], optional + Type of step, either 'epoch' or 'step'. + by default 'epoch'. + + Example + ------- + ```python + from anemoi.training.schedulers.rollout.stepped import Stepped + + RollSched = Stepped(minimum = 1, maximum = 10, every_n = 5, increment = 1) + RollSched.rollout_at(epoch = 2) + # 1 + RollSched.rollout_at(epoch = 5) + # 2 + + RollSched = Stepped(minimum = 1, maximum = 10, every_n = 5, increment = 2) + RollSched.rollout_at(epoch = 2) + # 1 + RollSched.rollout_at(epoch = 5) + # 3 + + RollSched = Stepped(minimum = 1, maximum = 10, every_n = 1, increment = {0: 0, 10: 1}) + RollSched.rollout_at(epoch = 2) + # 1 + RollSched.rollout_at(epoch = 9) + # 1 + RollSched.rollout_at(epoch = 10) + # 2, and then increments of 1 + ``` + """ + super().__init__() + + if maximum <= -1: + maximum = float("inf") + + self._minimum = minimum + self._maximum = maximum + self._every_n = every_n + self._increment = increment + self._step_type = step_type + + @property + def rollout(self) -> int: + if self._every_n == 0: + return self._minimum + + count_of_n = self.count(self._every_n, self._step_type) + + if isinstance(self._increment, int): + return min(self._maximum, self._minimum + self._increment * count_of_n) + + sum_of_increments = [ + self._increment.get(get_closest_key(self._increment, i + 1 if self._step_type == "epoch" else i)) + for i in range(count_of_n) + ] + return min(self._maximum, self._minimum + sum(sum_of_increments)) + + @property + def maximum_rollout(self) -> int: + return self._maximum + + def description(self) -> str: + return ( + "Stepped rollout scheduler stepping between" + f"{self._minimum} and {self._maximum} by {self._increment} for {self._every_n} {self._step_type}s." + ) + + +class EpochStepped(Stepped): + """`EpochStepped` is a rollout scheduler that steps the rollout value at the end of each n epochs.""" + + def __init__(self, minimum: int, maximum: int, every_n_epochs: int = 1, increment: int = 1): + """ + `EpochStepped` is a rollout scheduler that steps the rollout value at the end of each n epochs. + + Parameters + ---------- + minimum : int + The minimum value for the scheduler. + maximum : int + The maximum value for the scheduler. 
+        every_n_epochs : int, optional
+            The number of epochs after which the value is incremented, by default 1.
+        increment : int, optional
+            The amount by which the value is incremented, by default 1.
+        """
+        super().__init__(minimum, maximum, every_n_epochs, increment, step_type="epoch")
+
+
+class StepStepped(Stepped):
+    """`StepStepped` is a rollout scheduler that steps the rollout value at the end of each n steps."""
+
+    def __init__(self, minimum: int, maximum: int, every_n_steps: int = 1000, increment: int = 1):
+        """
+        `StepStepped` is a rollout scheduler that steps the rollout value at the end of each n steps.
+
+        Parameters
+        ----------
+        minimum : int
+            The minimum value for the scheduler.
+        maximum : int
+            The maximum value for the scheduler.
+        every_n_steps : int, optional
+            The number of steps after which the value is incremented, by default 1000.
+        increment : int, optional
+            The amount by which the value is incremented, by default 1.
+        """
+        super().__init__(minimum, maximum, every_n_steps, increment, step_type="step")
diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py
index 0059d90a..4351fb35 100644
--- a/src/anemoi/training/train/forecaster.py
+++ b/src/anemoi/training/train/forecaster.py
@@ -12,6 +12,7 @@
 from collections import defaultdict
 from collections.abc import Generator
 from collections.abc import Mapping
+from typing import TYPE_CHECKING
 from typing import Optional
 from typing import Union
 
@@ -38,6 +39,9 @@
 
 LOGGER = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from anemoi.training.schedulers.rollout import RolloutScheduler
+
 
 class GraphForecaster(pl.LightningModule):
     """Graph neural network forecaster for PyTorch Lightning."""
@@ -146,18 +150,15 @@ def __init__(
         self.warmup_t = getattr(config.training.lr, "warmup_t", 1000)
         self.lr_iterations = config.training.lr.iterations
         self.lr_min = config.training.lr.min
-        self.rollout = config.training.rollout.start
-        self.rollout_epoch_increment = config.training.rollout.epoch_increment
-        self.rollout_max = config.training.rollout.max
+
+        self.rollout: RolloutScheduler = instantiate(config.training.rollout)
 
         self.use_zero_optimizer = config.training.zero_optimizer
 
         self.model_comm_group = None
         self.reader_groups = None
 
-        LOGGER.debug("Rollout window length: %d", self.rollout)
-        LOGGER.debug("Rollout increase every : %d epochs", self.rollout_epoch_increment)
-        LOGGER.debug("Rollout max : %d", self.rollout_max)
+        LOGGER.debug("Rollout config: %s", self.rollout.description())
         LOGGER.debug("Multistep: %d", self.multi_step)
 
         # lazy init model and reader group info, will be set by the DDPGroupStrategy:
@@ -451,7 +452,7 @@ def rollout_step(
         )
         assert batch.shape[1] >= rollout + self.multi_step, msg
 
-        for rollout_step in range(rollout or self.rollout):
+        for rollout_step in range(rollout or int(self.rollout)):
             # prediction at rollout step rollout_step, shape = (bs, latlon, nvar)
             y_pred = self(x)
 
@@ -485,7 +486,7 @@ def _step(
 
         for loss_next, metrics_next, y_preds_next in self.rollout_step(
             batch,
-            rollout=self.rollout,
+            rollout=int(self.rollout),
             training_mode=True,
             validation_mode=validation_mode,
         ):
@@ -493,7 +494,8 @@
             metrics.update(metrics_next)
             y_preds.extend(y_preds_next)
 
-        loss *= 1.0 / self.rollout
+        loss *= 1.0 / int(self.rollout)
+        self.rollout.step()
         return loss, metrics, y_preds
 
     def allgather_batch(self, batch: torch.Tensor) -> torch.Tensor:
@@ -619,7 +621,7 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
         )
         self.log(
"rollout", - float(self.rollout), + int(self.rollout), on_step=True, logger=self.logger_enabled, rank_zero_only=True, @@ -642,10 +644,7 @@ def lr_scheduler_step(self, scheduler: CosineLRScheduler, metric: None = None) - scheduler.step(epoch=self.trainer.global_step) def on_train_epoch_end(self) -> None: - if self.rollout_epoch_increment > 0 and self.current_epoch % self.rollout_epoch_increment == 0: - self.rollout += 1 - LOGGER.debug("Rollout window length: %d", self.rollout) - self.rollout = min(self.rollout, self.rollout_max) + self.rollout.step_epoch() def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: """ diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 694fb2da..c638bc11 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -328,7 +328,7 @@ def _log_information(self) -> None: "Effective learning rate: %.3e", int(total_number_of_model_instances) * self.config.training.lr.rate, ) - LOGGER.debug("Rollout window length: %d", self.config.training.rollout.start) + LOGGER.debug("Rollout config: %d", self.config.training.rollout) if self.config.training.max_epochs is not None and self.config.training.max_steps not in (None, -1): LOGGER.info( diff --git a/tests/schedulers/__init__.py b/tests/schedulers/__init__.py new file mode 100644 index 00000000..c167afa2 --- /dev/null +++ b/tests/schedulers/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. diff --git a/tests/schedulers/rollout/__init__.py b/tests/schedulers/rollout/__init__.py new file mode 100644 index 00000000..c167afa2 --- /dev/null +++ b/tests/schedulers/rollout/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
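
Patch 01 above replaces the three static rollout config values (`start`, `epoch_increment`, `max`) with a scheduler object that `GraphForecaster` instantiates from config and advances explicitly. A minimal standalone sketch of that contract, using the `Static` scheduler introduced in this patch (the loop bounds and the model call are illustrative placeholders):

```python
from anemoi.training.schedulers.rollout import Static

scheduler = Static(rollout_value=3)

for epoch in range(2):               # placeholder epoch count
    for step in range(100):          # placeholder steps per epoch
        rollout = int(scheduler)     # __int__ delegates to the `rollout` property
        # ... run `rollout` autoregressive model steps here ...
        scheduler.step()             # advance the step counter, as GraphForecaster._step does
    scheduler.step_epoch()           # advance the epoch counter, as on_train_epoch_end does

# Query the schedule at an arbitrary point without mutating the internal counters:
assert scheduler.rollout_at(step=50, epoch=1) == 3
```
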
From fcf1f1fe7347d450d21994cc9fc8896a4e11f0a7 Mon Sep 17 00:00:00 2001
From: Harrison Cook
Date: Wed, 18 Dec 2024 16:53:48 +0000
Subject: [PATCH 02/11] Incrementer - Allow for complex incrementing setup

---
 .../training/config/training/default.yaml    |  13 +-
 .../training/schedulers/rollout/__init__.py  |  26 ++--
 .../training/schedulers/rollout/indexed.py   |   8 +-
 .../training/schedulers/rollout/randomise.py |  36 +++---
 .../training/schedulers/rollout/stepped.py   | 117 +++++++++++++++---
 5 files changed, 150 insertions(+), 50 deletions(-)

diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml
index 6c915eb5..27ff7e22 100644
--- a/src/anemoi/training/config/training/default.yaml
+++ b/src/anemoi/training/config/training/default.yaml
@@ -86,11 +86,16 @@ scale_validation_metrics:
 
 # length of the "rollout" window (see Keisler's paper)
 rollout:
-  start: 1
+  _target_: anemoi.training.schedulers.rollout.stepped.EpochStepped
+  minimum: 1
+  maximum: 12
   # increase rollout every n epochs
-  epoch_increment: 0
-  # maximum rollout to use
-  max: 1
+  every_n_epochs: 1
+  # increment
+  increment:
+    step:
+      0: 0
+      200000: 1
 
 # Set max_epochs or max_steps. Training stops at the first limit reached.
 max_epochs: null
diff --git a/src/anemoi/training/schedulers/rollout/__init__.py b/src/anemoi/training/schedulers/rollout/__init__.py
index 9da65b2e..e4e9ec06 100644
--- a/src/anemoi/training/schedulers/rollout/__init__.py
+++ b/src/anemoi/training/schedulers/rollout/__init__.py
@@ -94,16 +94,16 @@ def step_epoch(self, count: int = 1, /) -> None:
         """Step the scheduler by a count of epochs."""
         self._epoch += count
 
-    def count(self, every_n: int, step_type: Literal["step", "epoch"]) -> int:
+    def count(self, n_epochs: int | None = None, n_steps: int | None = None) -> int:
         """
         Get the count of steps or epochs.
 
         Parameters
         ----------
-        every_n : int
-            Every n steps or epochs.
-        step_type : _type_, optional
-            Which to count, by default Literal['step', 'epoch']
+        n_epochs : int | None, optional
+            Number of epochs to count, by default None
+        n_steps : int | None, optional
+            Number of steps to count, by default None
 
         Returns
         -------
@@ -113,15 +113,17 @@
         Raises
         ------
         ValueError
-            If the step_type is not 'step' or 'epoch'.
+            If both `n_epochs` and `n_steps` are given, or if neither are given.
         """
-        if step_type == "epoch":
-            return (self._epoch - 1) // every_n
-        if step_type == "step":
-            return self._step // every_n
+        if n_epochs is not None and n_steps is not None or n_epochs is None and n_steps is None:
+            error_msg = "Only one of `n_epochs` or `n_steps` can be given."
+            raise ValueError(error_msg)
+
+        if n_epochs is not None:
+            return self._epoch // n_epochs
+        if n_steps is not None:
+            return self._step // n_steps
 
-        error_msg = "Invalid `step_type`. Must be 'epoch' or 'step'."
-        raise ValueError(error_msg)
+
 
     @abstractmethod
     def description(self) -> str:
diff --git a/src/anemoi/training/schedulers/rollout/indexed.py b/src/anemoi/training/schedulers/rollout/indexed.py
index 307d76f8..782aa27a 100644
--- a/src/anemoi/training/schedulers/rollout/indexed.py
+++ b/src/anemoi/training/schedulers/rollout/indexed.py
@@ -84,7 +84,13 @@ def __init__(
 
     @property
     def rollout(self) -> int:
-        count = self.count(self._num_times_per_element, self._step_type)
+        if self._step_type == "epoch":
+            count = self.count(n_epochs=self._num_times_per_element)
+        elif self._step_type == "step":
+            count = self.count(n_steps=self._num_times_per_element)
+        else:
+            error_msg = "Invalid step_type. Must be 'epoch' or 'step'."
+            raise ValueError(error_msg)
         return self._rollouts[min(len(self._rollouts) - 1, count)]
 
     @property
diff --git a/src/anemoi/training/schedulers/rollout/randomise.py b/src/anemoi/training/schedulers/rollout/randomise.py
index ec0eef71..efa0efaa 100644
--- a/src/anemoi/training/schedulers/rollout/randomise.py
+++ b/src/anemoi/training/schedulers/rollout/randomise.py
@@ -16,9 +16,9 @@
 import numpy as np
 import pytorch_lightning as pl
 
-from anemoi.training.schedulers.rollout import RolloutScheduler
-from anemoi.training.schedulers.rollout.indexed import get_closest_key
 from anemoi.training.utils.seeding import get_base_seed
+from anemoi.training.schedulers.rollout import RolloutScheduler
+from anemoi.training.schedulers.rollout.stepped import BaseIncrementingRolloutScheduler, VALID_INCREMENT_TYPE, VALID_STEP_TYPES
 
 
 class BaseRandom(RolloutScheduler):
@@ -150,7 +150,7 @@ def description(self) -> str:
     )
 
 
-class IncreasingRandom(BaseRandom):
+class IncreasingRandom(BaseIncrementingRolloutScheduler, BaseRandom):
     """IncreasingRandom is a rollout scheduler that randomly selects a rollout from an increasing range of values."""
 
     def __init__(
@@ -159,8 +159,9 @@
         maximum: int = 1,
         range_step: int = 1,
         every_n: int = 1,
-        increment: int | dict[int, int] = 1,
-        step_type: Literal["step", "epoch"] = "epoch",
+        increment: VALID_INCREMENT_TYPE = 1,
+        *,
+        step_type: VALID_STEP_TYPES = "epoch",
     ):
         """
         `IncreasingRandom` is a rollout scheduler that randomly selects a rollout from an increasing range of values.
@@ -177,7 +178,7 @@
         every_n : int, optional
             Number of steps or epochs to step the rollout value.
             If `every_n` is 0, the rollout will stay at `minimum`.
-        increment : int | dict[int, int], optional
+        increment : int | dict[int, int] | dict[Literal['step', 'epoch'], dict[int, int]], optional
             Value to increment the rollout by every `every_n` steps or epochs, by default 1
         step_type : Literal['step', 'epoch'], optional
             Type of step, either 'epoch' or 'step'.
@@ -195,7 +196,7 @@ def __init__( # any value between 1 and 2 ``` """ - super().__init__() + super().__init__(every_n = every_n, increment = increment, step_type = step_type) if maximum <= -1: maximum = float("inf") @@ -203,26 +204,23 @@ def __init__( self._minimum = minimum self._maximum = maximum self._range_step = range_step - self._every_n = every_n - self._increment = increment - self._step_type = step_type @property def rollout(self) -> int: if self._every_n == 0: return self._minimum - count_of_n = self.count(self._every_n, self._step_type) + # count_of_n = self.count(self._every_n, self._step_type) - if isinstance(self._increment, int): - maximum_value = self._minimum + self._increment * count_of_n - else: - sum_of_increments = [ - self._increment.get(get_closest_key(self._increment, i + 1)) for i in range(count_of_n) - ] - maximum_value = self._minimum + sum(sum_of_increments) + # if isinstance(self._increment, int): + # maximum_value = self._minimum + self._increment * count_of_n + # else: + # sum_of_increments = [ + # self._increment.get(get_closest_key(self._increment, i + 1)) for i in range(count_of_n) + # ] + # maximum_value = self._minimum + sum(sum_of_increments) - rollouts = range(self._minimum, maximum_value + 1, self._range_step) + rollouts = range(self._minimum, self._minimum + self.total_increment, self._range_step) return self._randomly_pick(rollouts) diff --git a/src/anemoi/training/schedulers/rollout/stepped.py b/src/anemoi/training/schedulers/rollout/stepped.py index f07425e1..3a417395 100644 --- a/src/anemoi/training/schedulers/rollout/stepped.py +++ b/src/anemoi/training/schedulers/rollout/stepped.py @@ -14,7 +14,79 @@ from anemoi.training.schedulers.rollout.indexed import get_closest_key -class Stepped(RolloutScheduler): +VALID_STEP_TYPE = ["step", "epoch"] +VALID_STEP_TYPES = Literal["step", "epoch"] + +VALID_INCREMENT_TYPE = int | dict[int, int] | dict[VALID_STEP_TYPES, dict[int, int]] + +class BaseIncrementingRolloutScheduler(RolloutScheduler): + """Base class for schedulers that have an incrementing value.""" + _increment_value = 0 + + def __init__(self, every_n: int, step_type: VALID_STEP_TYPES, increment: VALID_INCREMENT_TYPE = 1): + super().__init__() + + if step_type not in VALID_STEP_TYPE: + error_msg = "Step type must be either 'step' or 'epoch'." + raise ValueError(error_msg) + + if isinstance(increment, dict): + if not len(increment) == 1: + error_msg = ( + "Increment dictionary cannot be empty, nor can it contain more then one entry." + "\nIt should either be a dictionary of ints or contain a single key of 'step' or 'epoch'." + ) + raise ValueError(error_msg) + + self._every_n = every_n + self._step_type = step_type + self._increment = increment + + + @property + def total_increment(self) -> int: + return self._increment_value + + def _get_current_increment(self): + if isinstance(self._increment, int): + return self._increment + + if isinstance(list(self._increment.keys())[0], int): + current_value = self._step if self._step_type == 'step' else self._epoch + return get_closest_key(self._increment, current_value) + + elif isinstance(list(self._increment.keys())[0], str): + step_type = list(self._increment.keys())[0] + if step_type not in ['step', 'epoch']: + error_msg = "Increment dictionary keys must be either 'step' or 'epoch'." 
+ raise ValueError(error_msg) + + current_value = self._step if step_type == 'step' else self._epoch + increment_dict = self._increment[step_type] + return increment_dict.get(get_closest_key(increment_dict, current_value), 0) + else: + error_msg = "Increment dictionary keys must be either int or str." + raise ValueError(error_msg) + + + def step(self, count = 1): + super().step(count) + if self._every_n == 0: + return + + if self._step_type == 'step' and self._step % self._every_n == 0: + self._increment_value += self._get_current_increment() + + + def step_epoch(self, count = 1): + super().step_epoch(count) + if self._every_n == 0: + return + + if self._step_type == 'epoch' and self._epoch % self._every_n == 0: + self._increment_value += self._get_current_increment() + +class Stepped(BaseIncrementingRolloutScheduler): """`Stepped` is a base rollout scheduler that steps the rollout value at the end of each n steps or epochs.""" def __init__( @@ -22,8 +94,9 @@ def __init__( minimum: int, maximum: int, every_n: int, - increment: int | dict[int, int], - step_type: Literal["step", "epoch"] = "epoch", + increment: VALID_INCREMENT_TYPE = 1, + *, + step_type: VALID_STEP_TYPES = "epoch", ): """ `SteppedRollout` is a base rollout scheduler that steps the rollout value at the end of each n steps or epochs. @@ -38,7 +111,7 @@ def __init__( every_n : int Number of steps or epochs to step the rollout value. If `every_n` is 0, the rollout will stay at `minimum`. - increment : int | dict[int, int], optional + increment : int | dict[int, int] | dict[Literal['step', 'epoch'], dict[int, int]], optional Value to increment the rollout by. Can be an int or dictionary, where the keys represent the value of `step_type` and the values represent the increment. @@ -73,21 +146,27 @@ def __init__( # 1 RollSched.rollout_at(epoch = 10) # 2, and then increments of 1 + + RollSched = Stepped(minimum = 1, maximum = 10, every_n = 1, step_type = 'epoch', increment = {'step':{0: 0, 1000: 1}}) + RollSched.rollout_at(epoch = 2) + # 1 + RollSched.rollout_at(epoch = 2, step = 1000) + # 2 + ``` """ - super().__init__() + super().__init__(every_n=every_n, step_type=step_type, increment=increment) if maximum <= -1: maximum = float("inf") self._minimum = minimum self._maximum = maximum - self._every_n = every_n - self._increment = increment - self._step_type = step_type @property def rollout(self) -> int: + return min(self._maximum, self._minimum + self.total_increment) + if self._every_n == 0: return self._minimum @@ -116,7 +195,7 @@ def description(self) -> str: class EpochStepped(Stepped): """`EpochStepped` is a rollout scheduler that steps the rollout value at the end of each n epochs.""" - def __init__(self, minimum: int, maximum: int, every_n_epochs: int = 1, increment: int = 1): + def __init__(self, minimum: int, maximum: int, every_n_epochs: int = 1, increment: VALID_INCREMENT_TYPE = 1): """ `EpochStepped` is a rollout scheduler that steps the rollout value at the end of each n epochs. @@ -128,8 +207,13 @@ def __init__(self, minimum: int, maximum: int, every_n_epochs: int = 1, incremen The maximum value for the scheduler. every_n_epochs : int, optional The number of epochs after which the value is incremented, by default 1. - increment : int, optional - The amount by which the value is incremented, by default 1. + increment : int | dict[int, int] | dict[Literal['step', 'epoch'], dict[int, int]], optional + Value to increment the rollout by. 
+ Can be an int or dictionary, where the keys represent the value of `step_type` + and the values represent the increment. + Will round down to the closest key. + i.e. {0: 1, 10: 2} will increment by 1 until 10, then by 2. + by default 1. """ super().__init__(minimum, maximum, every_n_epochs, increment, step_type="epoch") @@ -137,7 +221,7 @@ def __init__(self, minimum: int, maximum: int, every_n_epochs: int = 1, incremen class StepStepped(Stepped): """`StepStepped` is a rollout scheduler that steps the rollout value at the end of each n steps.""" - def __init__(self, minimum: int, maximum: int, every_n_steps: int = 1000, increment: int = 1): + def __init__(self, minimum: int, maximum: int, every_n_steps: int = 1000, increment: VALID_INCREMENT_TYPE = 1): """ `StepStepped` is a rollout scheduler that steps the rollout value at the end of each n steps. @@ -149,7 +233,12 @@ def __init__(self, minimum: int, maximum: int, every_n_steps: int = 1000, increm The maximum value for the scheduler. every_n_steps : int, optional The number of steps after which the value is incremented, by default 1000. - increment : int, optional - The amount by which the value is incremented, by default 1. + increment : int | dict[int, int] | dict[Literal['step', 'epoch'], dict[int, int]], optional + Value to increment the rollout by. + Can be an int or dictionary, where the keys represent the value of `step_type` + and the values represent the increment. + Will round down to the closest key. + i.e. {0: 1, 10: 2} will increment by 1 until 10, then by 2. + by default 1. """ super().__init__(minimum, maximum, every_n_steps, increment, step_type="step") From a712c4864422830e69fcffc0732d0b0ab57ccbfc Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 18 Dec 2024 17:39:20 +0000 Subject: [PATCH 03/11] Improve incrementor - Calculation based not step based --- .../training/config/training/default.yaml | 4 +- .../training/schedulers/rollout/randomise.py | 24 +-- .../training/schedulers/rollout/stepped.py | 154 ++++++++++-------- 3 files changed, 94 insertions(+), 88 deletions(-) diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index 27ff7e22..7eb79077 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -91,11 +91,11 @@ rollout: maximum: 12 # increase rollout every n epochs every_n_epochs: 1 - # increment + # Control the incrementing of the rollout window increment: step: 0: 0 - 200000: 1 + 200000: 1 # After 200k steps, increment by 1 every 1 epoch # Set max_epochs or max_steps. Training stops at the first limit reached. 
max_epochs: null diff --git a/src/anemoi/training/schedulers/rollout/randomise.py b/src/anemoi/training/schedulers/rollout/randomise.py index efa0efaa..901b0696 100644 --- a/src/anemoi/training/schedulers/rollout/randomise.py +++ b/src/anemoi/training/schedulers/rollout/randomise.py @@ -11,14 +11,14 @@ from __future__ import annotations -from typing import Literal - import numpy as np import pytorch_lightning as pl -from anemoi.training.utils.seeding import get_base_seed from anemoi.training.schedulers.rollout import RolloutScheduler -from anemoi.training.schedulers.rollout.stepped import BaseIncrementingRolloutScheduler, VALID_INCREMENT_TYPE, VALID_STEP_TYPES +from anemoi.training.schedulers.rollout.stepped import VALID_INCREMENT_TYPE +from anemoi.training.schedulers.rollout.stepped import VALID_STEP_TYPES +from anemoi.training.schedulers.rollout.stepped import IncrementMixin +from anemoi.training.utils.seeding import get_base_seed class BaseRandom(RolloutScheduler): @@ -150,7 +150,7 @@ def description(self) -> str: ) -class IncreasingRandom(BaseIncrementingRolloutScheduler, BaseRandom): +class IncreasingRandom(IncrementMixin, BaseRandom): """IncreasingRandom is a rollout scheduler that randomly selects a rollout from an increasing range of values.""" def __init__( @@ -196,7 +196,7 @@ def __init__( # any value between 1 and 2 ``` """ - super().__init__(every_n = every_n, increment = increment, step_type = step_type) + super().__init__(every_n=every_n, increment=increment, step_type=step_type) if maximum <= -1: maximum = float("inf") @@ -210,17 +210,7 @@ def rollout(self) -> int: if self._every_n == 0: return self._minimum - # count_of_n = self.count(self._every_n, self._step_type) - - # if isinstance(self._increment, int): - # maximum_value = self._minimum + self._increment * count_of_n - # else: - # sum_of_increments = [ - # self._increment.get(get_closest_key(self._increment, i + 1)) for i in range(count_of_n) - # ] - # maximum_value = self._minimum + sum(sum_of_increments) - - rollouts = range(self._minimum, self._minimum + self.total_increment, self._range_step) + rollouts = range(self._minimum, self._minimum + self.increment(self._step, self._epoch), self._range_step) return self._randomly_pick(rollouts) diff --git a/src/anemoi/training/schedulers/rollout/stepped.py b/src/anemoi/training/schedulers/rollout/stepped.py index 3a417395..6192fa30 100644 --- a/src/anemoi/training/schedulers/rollout/stepped.py +++ b/src/anemoi/training/schedulers/rollout/stepped.py @@ -13,15 +13,14 @@ from anemoi.training.schedulers.rollout import RolloutScheduler from anemoi.training.schedulers.rollout.indexed import get_closest_key - VALID_STEP_TYPE = ["step", "epoch"] VALID_STEP_TYPES = Literal["step", "epoch"] VALID_INCREMENT_TYPE = int | dict[int, int] | dict[VALID_STEP_TYPES, dict[int, int]] -class BaseIncrementingRolloutScheduler(RolloutScheduler): - """Base class for schedulers that have an incrementing value.""" - _increment_value = 0 + +class IncrementMixin: + """Mixin class for schedulers that have an incrementing value based on the steps and epochs.""" def __init__(self, every_n: int, step_type: VALID_STEP_TYPES, increment: VALID_INCREMENT_TYPE = 1): super().__init__() @@ -30,63 +29,91 @@ def __init__(self, every_n: int, step_type: VALID_STEP_TYPES, increment: VALID_I error_msg = "Step type must be either 'step' or 'epoch'." 
raise ValueError(error_msg) - if isinstance(increment, dict): - if not len(increment) == 1: - error_msg = ( - "Increment dictionary cannot be empty, nor can it contain more then one entry." - "\nIt should either be a dictionary of ints or contain a single key of 'step' or 'epoch'." - ) - raise ValueError(error_msg) + if isinstance(increment, dict) and len(increment) == 0: + error_msg = ( + "Increment dictionary cannot be empty." + "\nIt should either be a dictionary of ints or contain a single key of 'step' or 'epoch'." + ) + raise ValueError(error_msg) self._every_n = every_n self._step_type = step_type self._increment = increment - - @property - def total_increment(self) -> int: - return self._increment_value + def increment(self, step: int, epoch: int) -> int: + """ + Get the increment value for a particular step or epoch. + + Relies on the number of steps per epochs to calculate the increment + when the step_type of the increment is different from the stepper step_type. + - def _get_current_increment(self): + Parameters + ---------- + step : int + Step number. + epoch : int + Epoch number. + + Returns + ------- + int + Increment value. + + Raises + ------ + ValueError + If cannot parse the `increment` value given at init. + """ if isinstance(self._increment, int): return self._increment - if isinstance(list(self._increment.keys())[0], int): - current_value = self._step if self._step_type == 'step' else self._epoch - return get_closest_key(self._increment, current_value) - - elif isinstance(list(self._increment.keys())[0], str): - step_type = list(self._increment.keys())[0] - if step_type not in ['step', 'epoch']: + count = (step // self._every_n if self._step_type == "step" else epoch // self._every_n) + 1 + + if isinstance(next(iter(self._increment.keys())), int): + return sum( + (self._increment.get(get_closest_key(self._increment, i * self._every_n), 0) for i in range(count)), + ) + + if isinstance(next(iter(self._increment.keys())), str): + increment_step_type = next(iter(self._increment.keys())) + if increment_step_type not in ["step", "epoch"]: error_msg = "Increment dictionary keys must be either 'step' or 'epoch'." raise ValueError(error_msg) - - current_value = self._step if step_type == 'step' else self._epoch - increment_dict = self._increment[step_type] - return increment_dict.get(get_closest_key(increment_dict, current_value), 0) - else: - error_msg = "Increment dictionary keys must be either int or str." 
- raise ValueError(error_msg) - - - def step(self, count = 1): - super().step(count) - if self._every_n == 0: - return - - if self._step_type == 'step' and self._step % self._every_n == 0: - self._increment_value += self._get_current_increment() - - - def step_epoch(self, count = 1): - super().step_epoch(count) - if self._every_n == 0: - return - - if self._step_type == 'epoch' and self._epoch % self._every_n == 0: - self._increment_value += self._get_current_increment() - -class Stepped(BaseIncrementingRolloutScheduler): + + increment_dict = self._increment[increment_step_type] + + if increment_step_type == self._step_type: + return sum( + (increment_dict.get(get_closest_key(increment_dict, i * self._every_n), 0) for i in range(count)), + ) + + if epoch == 0 or step == 0: + return 0 + + num_steps_per_epoch = step / epoch + if increment_step_type == "step" and self._step_type == "epoch": + return sum( + increment_dict.get( + get_closest_key(increment_dict, (i * self._every_n) * num_steps_per_epoch), + 0, + ) + for i in range(count) + ) + if increment_step_type == "epoch" and self._step_type == "step": + return sum( + increment_dict.get( + get_closest_key(increment_dict, (i * self._every_n) // num_steps_per_epoch), + 0, + ) + for i in range(count) + ) + + error_msg = "Increment dictionary keys must be either int or a single str." + raise TypeError(error_msg) + + +class Stepped(RolloutScheduler, IncrementMixin): """`Stepped` is a base rollout scheduler that steps the rollout value at the end of each n steps or epochs.""" def __init__( @@ -147,8 +174,11 @@ def __init__( RollSched.rollout_at(epoch = 10) # 2, and then increments of 1 - RollSched = Stepped(minimum = 1, maximum = 10, every_n = 1, step_type = 'epoch', increment = {'step':{0: 0, 1000: 1}}) - RollSched.rollout_at(epoch = 2) + RollSched = Stepped( + minimum = 1, maximum = 10, every_n = 1, + step_type = 'epoch', increment = {'step':{0: 0, 1000: 1}} + ) + RollSched.rollout_at(epoch = 1, step = 500 ) # 1 RollSched.rollout_at(epoch = 2, step = 1000) # 2 @@ -165,21 +195,7 @@ def __init__( @property def rollout(self) -> int: - return min(self._maximum, self._minimum + self.total_increment) - - if self._every_n == 0: - return self._minimum - - count_of_n = self.count(self._every_n, self._step_type) - - if isinstance(self._increment, int): - return min(self._maximum, self._minimum + self._increment * count_of_n) - - sum_of_increments = [ - self._increment.get(get_closest_key(self._increment, i + 1 if self._step_type == "epoch" else i)) - for i in range(count_of_n) - ] - return min(self._maximum, self._minimum + sum(sum_of_increments)) + return min(self._maximum, self._minimum + self.increment(self._step, self._epoch)) @property def maximum_rollout(self) -> int: @@ -187,8 +203,8 @@ def maximum_rollout(self) -> int: def description(self) -> str: return ( - "Stepped rollout scheduler stepping between" - f"{self._minimum} and {self._maximum} by {self._increment} for {self._every_n} {self._step_type}s." + "Stepped rollout scheduler stepping between " + f"{self._minimum} and {self._maximum} by {self._increment} for every {self._every_n} {self._step_type}/s." 
) From 72e0bf9e1c32349bd64d95fdcbbeb40f8b621fb9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Dec 2024 17:41:36 +0000 Subject: [PATCH 04/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/training/schedulers/rollout/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/anemoi/training/schedulers/rollout/__init__.py b/src/anemoi/training/schedulers/rollout/__init__.py index e4e9ec06..20bce2e3 100644 --- a/src/anemoi/training/schedulers/rollout/__init__.py +++ b/src/anemoi/training/schedulers/rollout/__init__.py @@ -115,7 +115,7 @@ def count(self, n_epochs: int | None = None, n_steps: int | None = None) -> int: ValueError If both `n_epochs` and `n_steps` are given, or if neither are given. """ - if n_epochs is not None and n_steps is not None or n_epochs is None and n_steps is None: + if (n_epochs is not None and n_steps is not None) or (n_epochs is None and n_steps is None): error_msg = "Only one of `n_epochs` or `n_steps` can be given." raise ValueError(error_msg) @@ -124,7 +124,6 @@ def count(self, n_epochs: int | None = None, n_steps: int | None = None) -> int: if n_steps is not None: return self._step // n_steps - @abstractmethod def description(self) -> str: """Description of the rollout scheduler.""" From c199c0e1306c3eb31690c7194a88016019adc36a Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 18 Dec 2024 17:44:13 +0000 Subject: [PATCH 05/11] Precommit fixes --- src/anemoi/training/schedulers/rollout/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/anemoi/training/schedulers/rollout/__init__.py b/src/anemoi/training/schedulers/rollout/__init__.py index 20bce2e3..464ed71f 100644 --- a/src/anemoi/training/schedulers/rollout/__init__.py +++ b/src/anemoi/training/schedulers/rollout/__init__.py @@ -11,7 +11,6 @@ from abc import ABC from abc import abstractmethod -from typing import Literal class RolloutScheduler(ABC): @@ -121,8 +120,7 @@ def count(self, n_epochs: int | None = None, n_steps: int | None = None) -> int: if n_epochs is not None: return self._epoch // n_epochs - if n_steps is not None: - return self._step // n_steps + return self._step // n_steps @abstractmethod def description(self) -> str: From 69a5d9a03f8c7560986b9e269f2bfd21a9d51c75 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 18 Dec 2024 17:45:44 +0000 Subject: [PATCH 06/11] Add changelog entry --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 286fd915..e35e40b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,8 @@ Keep it human-readable, your future self will thank you! 
- Add supporting arrrays (numpy) to checkpoint - Support for masking out unconnected nodes in LAM [#171](https://github.com/ecmwf/anemoi-training/pull/171) - Improved validation metrics, allow 'all' to be scaled [#202](https://github.com/ecmwf/anemoi-training/pull/202) +- Rollout Schedulers [#206](https://github.com/ecmwf/anemoi-training/pull/206) + ### Changed From d5a0ff9c20b0560da9bb540a1e63a4e9015dcc79 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Thu, 19 Dec 2024 14:12:29 +0000 Subject: [PATCH 07/11] Seed random every time and remove -1 for inf --- .../training/schedulers/rollout/randomise.py | 14 ++++++++------ src/anemoi/training/schedulers/rollout/stepped.py | 6 ++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/anemoi/training/schedulers/rollout/randomise.py b/src/anemoi/training/schedulers/rollout/randomise.py index 901b0696..0b0b31b4 100644 --- a/src/anemoi/training/schedulers/rollout/randomise.py +++ b/src/anemoi/training/schedulers/rollout/randomise.py @@ -37,8 +37,11 @@ def __init__(self): except AssertionError: seed = 42 - rnd_seed = pl.seed_everything(seed, workers=True) - self.rng = np.random.default_rng(rnd_seed) + self._rnd_seed = pl.seed_everything(seed, workers=True) + + @property + def rng(self): + return np.random.default_rng(hash((self._rnd_seed, self._epoch, self._step))) def broadcast(self, value: int) -> None: """ @@ -171,7 +174,7 @@ def __init__( minimum : int, optional Minimum rollout to choose from, by default 1 maximum : int, optional - Maximum rollout to choose from, can be -1 for no maximum, + Maximum rollout to choose from, by default 1. range_step : int, optional Step size for the range, by default 1 @@ -198,9 +201,6 @@ def __init__( """ super().__init__(every_n=every_n, increment=increment, step_type=step_type) - if maximum <= -1: - maximum = float("inf") - self._minimum = minimum self._maximum = maximum self._range_step = range_step @@ -216,6 +216,8 @@ def rollout(self) -> int: @property def maximum_rollout(self) -> int: + if self._every_n == 0: + return self._minimum return self._maximum @property diff --git a/src/anemoi/training/schedulers/rollout/stepped.py b/src/anemoi/training/schedulers/rollout/stepped.py index 6192fa30..debcfdc4 100644 --- a/src/anemoi/training/schedulers/rollout/stepped.py +++ b/src/anemoi/training/schedulers/rollout/stepped.py @@ -134,7 +134,6 @@ def __init__( Minimum rollout value. maximum : int Maximum rollout value. - Can be -1 for no maximum. every_n : int Number of steps or epochs to step the rollout value. If `every_n` is 0, the rollout will stay at `minimum`. 
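Note on the seeding change above: `rng` is now re-derived from the base seed and the current `(epoch, step)` pair on every access, so the same position in training always draws the same rollout, including after a checkpoint restart. A minimal standalone sketch of the idea (illustrative only, not the module's code; `np.random.default_rng` rejects negative seeds, hence the `abs`):

```python
import numpy as np


def pick_rollout(base_seed: int, epoch: int, step: int, choices: list[int]) -> int:
    # Re-derive the generator from (seed, epoch, step): hashing a tuple of
    # ints is stable across processes, so the pick is reproducible after a
    # restart. abs() is needed because default_rng rejects negative seeds.
    rng = np.random.default_rng(abs(hash((base_seed, epoch, step))))
    return int(rng.choice(choices))


# The same (seed, epoch, step) always yields the same rollout.
assert pick_rollout(42, epoch=3, step=120, choices=[1, 2, 3]) == pick_rollout(
    42, epoch=3, step=120, choices=[1, 2, 3]
)
```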
@@ -187,9 +186,6 @@ def __init__( """ super().__init__(every_n=every_n, step_type=step_type, increment=increment) - if maximum <= -1: - maximum = float("inf") - self._minimum = minimum self._maximum = maximum @@ -199,6 +195,8 @@ def rollout(self) -> int: @property def maximum_rollout(self) -> int: + if self._every_n == 0: + return self._minimum return self._maximum def description(self) -> str: From 249144e5c4ea3222e2d5eddec80025ce3acd5a5d Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 20 Dec 2024 10:39:41 +0000 Subject: [PATCH 08/11] MIGRATION COMMIT --- .../training/config/training/default.yaml | 4 +- src/anemoi/training/data/datamodule.py | 26 +++-- src/anemoi/training/data/dataset.py | 15 ++- .../diagnostics/callbacks/__init__.py | 3 + .../training/diagnostics/callbacks/rollout.py | 95 +++++++++++++++++++ .../training/schedulers/rollout/__init__.py | 94 +++++++++++++++--- .../training/schedulers/rollout/indexed.py | 4 + .../training/schedulers/rollout/randomise.py | 13 ++- .../training/schedulers/rollout/stepped.py | 19 ++-- src/anemoi/training/train/forecaster.py | 32 ++++++- src/anemoi/training/train/train.py | 20 ++++ tests/diagnostics/test_callbacks.py | 2 +- tests/schedulers/rollout/test_indexed.py | 11 +++ tests/schedulers/rollout/test_random.py | 8 ++ tests/schedulers/rollout/test_rollout.py | 8 ++ tests/schedulers/rollout/test_stepped.py | 8 ++ 16 files changed, 325 insertions(+), 37 deletions(-) create mode 100644 src/anemoi/training/diagnostics/callbacks/rollout.py create mode 100644 tests/schedulers/rollout/test_indexed.py create mode 100644 tests/schedulers/rollout/test_random.py create mode 100644 tests/schedulers/rollout/test_rollout.py create mode 100644 tests/schedulers/rollout/test_stepped.py diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index 7eb79077..1d107473 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -85,8 +85,10 @@ scale_validation_metrics: # length of the "rollout" window (see Keisler's paper) +# Dataloader rollout counter can only be updated at the end of each epoch +# So updates during an epoch will only be reflected at the end of said epoch. 
rollout: - _target_: anemoi.training.schedulers.stepped.EpochStepped + _target_: anemoi.training.schedulers.rollout.stepped.EpochStepped minimum: 1 maximum: 12 # increase rollout every n epochs diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index e0502acd..8f89a026 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -32,6 +32,7 @@ from torch_geometric.data import HeteroData from anemoi.training.data.grid_indices import BaseGridIndices + from anemoi.training.schedulers.rollout import RolloutScheduler class AnemoiDatasetsDataModule(pl.LightningDataModule): @@ -52,11 +53,8 @@ def __init__(self, config: DictConfig, graph_data: HeteroData) -> None: self.graph_data = graph_data # Set the maximum rollout to be expected - self.rollout = ( - self.config.training.rollout.max - if self.config.training.rollout.epoch_increment > 0 - else self.config.training.rollout.start - ) + rollout_scheduler: RolloutScheduler = instantiate(self.config.training.rollout) + self.starting_rollout = rollout_scheduler.current_maximum # Set the training end date if not specified if self.config.dataloader.training.end is None: @@ -129,7 +127,7 @@ def ds_train(self) -> NativeGridDataset: @cached_property def ds_valid(self) -> NativeGridDataset: - r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1)) + r = max(self.starting_rollout, self.config.dataloader.get("validation_rollout", 1)) if not self.config.dataloader.training.end < self.config.dataloader.validation.start: LOGGER.warning( @@ -160,6 +158,20 @@ def ds_test(self) -> NativeGridDataset: label="test", ) + def update_rollout(self, rollout: int) -> None: + """ + Update the rollout values in the datamodule. + + Parameters + ---------- + rollout : int + Rollout value + """ + for ds in [self.ds_train, self.ds_test]: + ds.update_rollout(rollout) + + self.ds_valid.update_rollout(max(rollout, self.config.dataloader.get("validation_rollout", 1))) + def _get_dataset( self, data_reader: Callable, @@ -168,7 +180,7 @@ def _get_dataset( label: str = "generic", ) -> NativeGridDataset: - r = max(rollout, self.rollout) + r = max(rollout, self.starting_rollout) # Compute effective batch size effective_bs = ( diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index 431f0227..e004bddb 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -13,7 +13,7 @@ import os import random from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self from typing import Callable import numpy as np @@ -138,6 +138,17 @@ def valid_date_indices(self) -> np.ndarray: """ return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement) + def update_rollout(self, rollout: int) -> None: + """Update the rollout window.""" + if self.rollout == rollout: + return + + self.rollout = rollout + LOGGER.warning(f"Updating rollout of {self.label} dataset to {self.rollout}") + + if hasattr(self, "valid_date_indices"): + del self.valid_date_indices + def set_comm_group_info( self, global_rank: int, @@ -273,7 +284,7 @@ def __iter__(self) -> torch.Tensor: self.model_comm_group_rank, shuffled_chunk_indices[:10], ) - + LOGGER.warning(f"Rollout in dataset: {self.label} is {self.rollout}") for i in shuffled_chunk_indices: start = i - (self.multi_step - 1) * self.timeincrement end = i + (self.rollout + 1) * self.timeincrement diff --git 
a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index 65a19ce1..16fbcd2a 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -24,6 +24,7 @@ from anemoi.training.diagnostics.callbacks.optimiser import StochasticWeightAveraging from anemoi.training.diagnostics.callbacks.provenance import ParentUUIDCallback from anemoi.training.diagnostics.callbacks.sanity import CheckVariableOrder +from anemoi.training.diagnostics.callbacks.rollout import UpdateRollout if TYPE_CHECKING: from pytorch_lightning.callbacks import Callback @@ -198,10 +199,12 @@ def get_callbacks(config: DictConfig) -> list[Callback]: # Parent UUID callback # Check variable order callback + # UpdateRollout trainer_callbacks.extend( ( ParentUUIDCallback(config), CheckVariableOrder(), + UpdateRollout(), ), ) diff --git a/src/anemoi/training/diagnostics/callbacks/rollout.py b/src/anemoi/training/diagnostics/callbacks/rollout.py new file mode 100644 index 00000000..b294bc04 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/rollout.py @@ -0,0 +1,95 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import pytorch_lightning as pl + +LOGGER = logging.getLogger(__name__) + + +class UpdateRollout(pl.callbacks.Callback): + """Update Rollout values in datamodule.""" + + def __init__(self) -> None: + super().__init__() + + def _update_rollout(self, trainer, pl_module, epoch: int | None = None, step: int | None = None) -> None: + rollsched = pl_module.rollout + with rollsched.at(epoch=epoch, step=step): + rollout = rollsched.current_maximum + + trainer.datamodule.update_rollout(rollout) + + def on_load_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: dict) -> None: + """ + Update the rollout values in the datamodule when loading a checkpoint. + + Parameters + ---------- + trainer : pl.Trainer + Pytorch Lightning trainer + pl_module : pl.LightningModule + Model + checkpoint : dict + Checkpoint dictionary + """ + LOGGER.warning('Updating rollout values from checkpoint.') + self._update_rollout(trainer, pl_module, epoch = checkpoint['epoch'], step = checkpoint['global_step']) + + # def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + # """ + # Update the rollout values in the datamodule when starting fitting. + + # Parameters + # ---------- + # trainer : pl.Trainer + # Pytorch Lightning trainer + # pl_module : pl.LightningModule + # Model + # """ + # LOGGER.warning('Updating rollout values when fit starts.') + # self._update_rollout(trainer, pl_module) + + # def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: + # """ + # Update the rollout values in the datamodule when setting up the trainer. 
+
+    #     Parameters
+    #     ----------
+    #     trainer : pl.Trainer
+    #         Pytorch Lightning trainer
+    #     pl_module : pl.LightningModule
+    #         Model
+    #     stage : str
+    #         Stage of the training
+    #     """
+    #     LOGGER.warning('Updating rollout values from setup.')
+    #     self._update_rollout(trainer, pl_module)
+
+
+    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *a) -> None:
+        """
+        Update the rollout values in the datamodule every validation epoch.
+
+        Parameters
+        ----------
+        trainer : pl.Trainer
+            Pytorch Lightning trainer
+        pl_module : pl.LightningModule
+            Model
+        """
+        if trainer.sanity_checking:
+            return
+
+        LOGGER.warning('Updating rollout values from validation epoch end.')
+
+        # Offset of 1 needed as the epoch counter does not increment
+        # until after the epoch ends.
+        self._update_rollout(trainer, pl_module, epoch = trainer.current_epoch + 1)
diff --git a/src/anemoi/training/schedulers/rollout/__init__.py b/src/anemoi/training/schedulers/rollout/__init__.py
index 464ed71f..c6373851 100644
--- a/src/anemoi/training/schedulers/rollout/__init__.py
+++ b/src/anemoi/training/schedulers/rollout/__init__.py
@@ -29,6 +29,12 @@ class RolloutScheduler(ABC):
     RollSched.step()
     RollSched.step_epoch()
     ```
+
+    The rollout value must be calculable given the epoch and the step,
+    accessible within subclasses by the `_epoch` and `_step` attributes.
+
+    Override the `rollout` property to implement the rollout calculation,
+    and the `maximum_rollout` property to provide the maximum rollout possible.
     """
 
     _epoch: int = 0
@@ -50,12 +56,54 @@ def maximum_rollout(self) -> int:
 
     @property
     def current_maximum(self) -> int:
-        """Get the current maximum rollout value."""
+        """Get the current maximum rollout value.
+
+        Allows the dataloader to fetch only the data necessary.
+        In most cases this is just the current rollout.
+        """
         return self.rollout
 
     def __int__(self) -> int:
+        """Get rollout value as int"""
+        return int(self.rollout)
+
+    def __index__(self) -> int:
+        """Get rollout value as index"""
         return int(self.rollout)
+
+    def at(self, step: int | None = None, epoch: int | None = None) -> FrozenStateRecord:
+        """
+        Temporarily hold the scheduler at a specific step and epoch.
+
+        Parameters
+        ----------
+        step : int, optional
+            Step value to override with, by default None
+        epoch : int, optional
+            Epoch value to override with, by default None
+
+        Returns
+        -------
+        FrozenStateRecord
+            Record of the prior state.
+        """
+        prior_step = self._step
+        prior_epoch = self._epoch
+        class FrozenStateRecord:
+            """Freeze the state of the RolloutScheduler. Any changes will be reverted on exit."""
+
+            def __enter__(self):
+                pass
+
+            def __exit__(context_self, *a):  # noqa: N805
+                self._step = prior_step
+                self._epoch = prior_epoch
+
+        self._step = step if step is not None else prior_step
+        self._epoch = epoch if epoch is not None else prior_epoch
+        return FrozenStateRecord()
+
     def rollout_at(self, step: int | None = None, epoch: int | None = None) -> int:
         """
         Get the rollout at a specific step and epoch.
@@ -72,18 +120,8 @@ def rollout_at(self, step: int | None = None, epoch: int | None = None) -> int:
         int
             Rollout value at the specified step and epoch.
""" - step_ = self._step - epoch_ = self._epoch - - self._step = step if step is not None else step_ - self._epoch = epoch if epoch is not None else epoch_ - - rollout = self.rollout - - self._step = step_ - self._epoch = epoch_ - - return rollout + with self.at(step, epoch): + return self.rollout def step(self, count: int = 1, /) -> None: """Step the scheduler by a count.""" @@ -93,7 +131,21 @@ def step_epoch(self, count: int = 1, /) -> None: """Step the scheduler by a count of epochs.""" self._epoch += count - def count(self, n_epochs: int | None = None, n_steps: int | None = None) -> int: + def sync(self, step: int = None, epoch: int = None): + """ + Sync state of the Rollout Scheduler + + Parameters + ---------- + step : int, optional + Override for step, by default None + epoch : int, optional + Override for epoch, by default None + """ + self._step = step if step is not None else self._step + self._epoch = epoch if epoch is not None else self._epoch + + def count(self, n_steps: int | None = None, n_epochs: int | None = None) -> int: """ Get the count of steps or epochs. @@ -127,6 +179,20 @@ def description(self) -> str: """Description of the rollout scheduler.""" error_msg = "`description` method not implemented by parent class." raise NotImplementedError(error_msg) + + # Mathematical operations + def __add__(self, other: int) -> int: + return self.rollout + other + def __radd__(self, other: int) -> int: + return other + self.rollout + def __sub__(self, other: int) -> int: + return self.rollout - other + def __rsub__(self, other: int) -> int: + return other - self.rollout + def __mul__(self, other: int) -> int: + return self.rollout * other + def __rmul__(self, other: int) -> int: + return other * self.rollout class Static(RolloutScheduler): diff --git a/src/anemoi/training/schedulers/rollout/indexed.py b/src/anemoi/training/schedulers/rollout/indexed.py index 782aa27a..08eb43e2 100644 --- a/src/anemoi/training/schedulers/rollout/indexed.py +++ b/src/anemoi/training/schedulers/rollout/indexed.py @@ -10,6 +10,8 @@ from typing import Any from typing import Literal +import warnings + from anemoi.training.schedulers.rollout import RolloutScheduler @@ -109,6 +111,7 @@ class StepPositionalIndexed(PositionalIndexed): """Step based PositionalIndexed.""" def __init__(self, rollouts: list[int]): + warnings.warn(f"Pytorch Lightning datamodules can only be refreshed at the end of an epoch, adjusting the rollout during an epoch will likely fail.", UserWarning) super().__init__(rollouts, step_type="step") @@ -175,4 +178,5 @@ class StepLookup(Lookup): """Step based Lookup.""" def __init__(self, rollouts: dict[int, int]): + warnings.warn(f"Pytorch Lightning datamodules can only be refreshed at the end of an epoch, adjusting the rollout during an epoch will likely fail.", UserWarning) super().__init__(rollouts, step_type="step") diff --git a/src/anemoi/training/schedulers/rollout/randomise.py b/src/anemoi/training/schedulers/rollout/randomise.py index 0b0b31b4..9aff3b7e 100644 --- a/src/anemoi/training/schedulers/rollout/randomise.py +++ b/src/anemoi/training/schedulers/rollout/randomise.py @@ -11,6 +11,8 @@ from __future__ import annotations +import warnings + import numpy as np import pytorch_lightning as pl @@ -175,6 +177,7 @@ def __init__( Minimum rollout to choose from, by default 1 maximum : int, optional Maximum rollout to choose from, + Can be -1 for no maximum, by default 1. 
range_step : int, optional Step size for the range, by default 1 @@ -202,6 +205,10 @@ def __init__( super().__init__(every_n=every_n, increment=increment, step_type=step_type) self._minimum = minimum + + if maximum == -1: + maximum = float("inf") + self._maximum = maximum self._range_step = range_step @@ -210,7 +217,7 @@ def rollout(self) -> int: if self._every_n == 0: return self._minimum - rollouts = range(self._minimum, self._minimum + self.increment(self._step, self._epoch), self._range_step) + rollouts = range(self._minimum, self.current_maximum, self._range_step) return self._randomly_pick(rollouts) @@ -222,7 +229,7 @@ def maximum_rollout(self) -> int: @property def current_maximum(self) -> int: - return self._minimum + ((self._epoch // self._every_n_epochs) * self._step) + return min(self._maximum, self._minimum + self.increment(self._step, self._epoch)) def description(self) -> str: return ( @@ -351,4 +358,6 @@ def __init__( # any value between 1 and 2, and then increments of 1 ``` """ + warnings.warn(f"Pytorch Lightning datamodules can only be refreshed at the end of an epoch, adjusting the rollout during an epoch will likely fail.", UserWarning) + super().__init__(minimum, maximum, range_step, every_n_steps, increment, step_type="step") diff --git a/src/anemoi/training/schedulers/rollout/stepped.py b/src/anemoi/training/schedulers/rollout/stepped.py index debcfdc4..de182105 100644 --- a/src/anemoi/training/schedulers/rollout/stepped.py +++ b/src/anemoi/training/schedulers/rollout/stepped.py @@ -9,6 +9,7 @@ from __future__ import annotations from typing import Literal +import warnings from anemoi.training.schedulers.rollout import RolloutScheduler from anemoi.training.schedulers.rollout.indexed import get_closest_key @@ -65,14 +66,15 @@ def increment(self, step: int, epoch: int) -> int: ValueError If cannot parse the `increment` value given at init. """ - if isinstance(self._increment, int): - return self._increment - count = (step // self._every_n if self._step_type == "step" else epoch // self._every_n) + 1 + count = (step // self._every_n if self._step_type == "step" else epoch // self._every_n) + + if isinstance(self._increment, int): + return self._increment * count if isinstance(next(iter(self._increment.keys())), int): return sum( - (self._increment.get(get_closest_key(self._increment, i * self._every_n), 0) for i in range(count)), + (self._increment.get(get_closest_key(self._increment, i * self._every_n), 0) for i in range(count + 1)), ) if isinstance(next(iter(self._increment.keys())), str): @@ -85,7 +87,7 @@ def increment(self, step: int, epoch: int) -> int: if increment_step_type == self._step_type: return sum( - (increment_dict.get(get_closest_key(increment_dict, i * self._every_n), 0) for i in range(count)), + (increment_dict.get(get_closest_key(increment_dict, i * self._every_n), 0) for i in range(count + 1)), ) if epoch == 0 or step == 0: @@ -106,7 +108,7 @@ def increment(self, step: int, epoch: int) -> int: get_closest_key(increment_dict, (i * self._every_n) // num_steps_per_epoch), 0, ) - for i in range(count) + for i in range(count + 1) ) error_msg = "Increment dictionary keys must be either int or a single str." @@ -134,6 +136,7 @@ def __init__( Minimum rollout value. maximum : int Maximum rollout value. + Can be -1 to indicate no maximum. every_n : int Number of steps or epochs to step the rollout value. If `every_n` is 0, the rollout will stay at `minimum`. 
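The increment-dictionary semantics above are easiest to see worked through. Here is a hedged, self-contained sketch (the helper names are hypothetical, not the module's API) reproducing the intended arithmetic for `increment = {0: 0, 2: 1, 4: 2}` with `every_n = 1`:

```python
def closest_key(d: dict[int, int], key: int) -> int:
    # Largest dictionary key <= key, or -1 if none exists
    # (mirrors the behaviour described for get_closest_key).
    candidates = [k for k in d if k <= key]
    return max(candidates) if candidates else -1


def rollout_at_epoch(epoch: int, minimum: int = 1, maximum: int = 10) -> int:
    increment = {0: 0, 2: 1, 4: 2}
    every_n = 1
    count = epoch // every_n
    # Sum the increment active at each boundary crossed so far (count + 1 terms).
    total = sum(increment.get(closest_key(increment, i * every_n), 0) for i in range(count + 1))
    return min(maximum, minimum + total)


# No growth before epoch 2, +1 per epoch from 2, +2 per epoch from 4:
assert [rollout_at_epoch(e) for e in range(6)] == [1, 1, 2, 3, 5, 7]
```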
@@ -186,6 +189,9 @@ def __init__( """ super().__init__(every_n=every_n, step_type=step_type, increment=increment) + if maximum == -1: + maximum = float("inf") + self._minimum = minimum self._maximum = maximum @@ -255,4 +261,5 @@ def __init__(self, minimum: int, maximum: int, every_n_steps: int = 1000, increm i.e. {0: 1, 10: 2} will increment by 1 until 10, then by 2. by default 1. """ + warnings.warn(f"Pytorch Lightning datamodules can only be refreshed at the end of an epoch, adjusting the rollout during an epoch will likely fail.", UserWarning) super().__init__(minimum, maximum, every_n_steps, increment, step_type="step") diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 4351fb35..f7658a4c 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -40,7 +40,7 @@ LOGGER = logging.getLogger(__name__) if TYPE_CHECKING: - from anemoi.training.training.schedulers.rollout import RolloutScheduler + from anemoi.training.schedulers.rollout import RolloutScheduler class GraphForecaster(pl.LightningModule): @@ -484,6 +484,7 @@ def _step( metrics = {} y_preds = [] + # print('Rollout', int(self.rollout)) for loss_next, metrics_next, y_preds_next in self.rollout_step( batch, rollout=int(self.rollout), @@ -495,7 +496,6 @@ def _step( y_preds.extend(y_preds_next) loss *= 1.0 / int(self.rollout) - self.rollout.step() return loss, metrics, y_preds def allgather_batch(self, batch: torch.Tensor) -> torch.Tensor: @@ -607,6 +607,28 @@ def calculate_val_metrics( return metrics + def on_train_start(self): + # Sync the rollout at the start of training + print("Rollout at start of training", int(self.rollout), self.rollout._epoch, self.rollout._step) + self.rollout.sync(step = self.global_step, epoch = self.current_epoch) + + def on_load_checkpoint(self, checkpoint: dict): + # Sync the rollout at the start of training + print("Rollout at on_load_checkpoint", int(self.rollout), self.rollout._epoch, self.rollout._step) + self.rollout.sync(step = checkpoint["global_step"], epoch = checkpoint["epoch"]) + + def on_train_epoch_start(self): + self.rollout.sync(step = self.global_step, epoch = self.current_epoch) + LOGGER.warning(f"Rollout at start of training, {int(self.rollout)}, {self.rollout._epoch}, {self.rollout._step}") + + def on_validation_epoch_start(self): + LOGGER.warning(f"Rollout at start of validation, {int(self.rollout)}, {self.rollout._epoch}, {self.rollout._step}") + + def on_validation_epoch_end(self) -> None: + # if not self.trainer.sanity_checking: + # self.rollout_epoch_step() + LOGGER.warning(f"Rollout at end of validation, {int(self.rollout)}, {self.rollout._epoch}, {self.rollout._step}") + def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: train_loss, _, _ = self._step(batch, batch_idx) self.log( @@ -627,6 +649,7 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: rank_zero_only=True, sync_dist=False, ) + self.rollout.step() return train_loss def lr_scheduler_step(self, scheduler: CosineLRScheduler, metric: None = None) -> None: @@ -643,8 +666,6 @@ def lr_scheduler_step(self, scheduler: CosineLRScheduler, metric: None = None) - del metric scheduler.step(epoch=self.trainer.global_step) - def on_train_epoch_end(self) -> None: - self.rollout.step_epoch() def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: """ @@ -661,6 +682,7 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: ------- None """ + with 
torch.no_grad():
             val_loss, metrics, y_preds = self._step(batch, batch_idx, validation_mode=True)
 
@@ -689,6 +711,8 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None:
 
         return val_loss, y_preds
 
+
+
     def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[dict]]:
         if self.use_zero_optimizer:
             optimizer = ZeroRedundancyOptimizer(
diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py
index c638bc11..e4f06a98 100644
--- a/src/anemoi/training/train/train.py
+++ b/src/anemoi/training/train/train.py
@@ -376,6 +376,25 @@ def strategy(self) -> DDPGroupStrategy:
             self.config.dataloader.get("read_group_size", self.config.hardware.num_gpus_per_model),
             static_graph=not self.config.training.accum_grad_batches > 1,
         )
+
+    @cached_property
+    def _need_to_reload_dataloaders(self) -> bool:
+        """Determines if the dataloaders need to be reloaded.
+
+        If the model's rollout scheduler is already at its maximum,
+        the dataloaders do not need to be reloaded.
+
+        Returns
+        -------
+        bool
+            True if the dataloaders need to be reloaded, False otherwise.
+        """
+        rollsched = self.model.rollout
+
+        if rollsched.current_maximum == rollsched.maximum_rollout:
+            return False
+        LOGGER.info("Dataloaders will be reloaded every epoch to support dynamic rollout.")
+        return True
 
     def train(self) -> None:
         """Training entry point."""
@@ -405,6 +424,7 @@ def train(self) -> None:
             use_distributed_sampler=False,
             profiler=self.profiler,
             enable_progress_bar=self.config.diagnostics.enable_progress_bar,
+            reload_dataloaders_every_n_epochs=self._need_to_reload_dataloaders,
         )
 
         LOGGER.debug("Starting training..")
diff --git a/tests/diagnostics/test_callbacks.py b/tests/diagnostics/test_callbacks.py
index 58ea6440..efea0b79 100644
--- a/tests/diagnostics/test_callbacks.py
+++ b/tests/diagnostics/test_callbacks.py
@@ -14,7 +14,7 @@
 
 from anemoi.training.diagnostics.callbacks import get_callbacks
 
-NUM_FIXED_CALLBACKS = 2  # ParentUUIDCallback, CheckVariableOrder
+NUM_FIXED_CALLBACKS = 3  # ParentUUIDCallback, CheckVariableOrder, UpdateRollout
 
 default_config = """
 diagnostics:
diff --git a/tests/schedulers/rollout/test_indexed.py b/tests/schedulers/rollout/test_indexed.py
new file mode 100644
index 00000000..dfc43718
--- /dev/null
+++ b/tests/schedulers/rollout/test_indexed.py
@@ -0,0 +1,11 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+from anemoi.training.schedulers.rollout.indexed import PositionalIndexed, Lookup
\ No newline at end of file
diff --git a/tests/schedulers/rollout/test_random.py b/tests/schedulers/rollout/test_random.py
new file mode 100644
index 00000000..c167afa2
--- /dev/null
+++ b/tests/schedulers/rollout/test_random.py
@@ -0,0 +1,8 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
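For orientation, the update path these changes wire together is: scheduler (on the LightningModule) -> `UpdateRollout` callback at epoch boundaries -> `datamodule.update_rollout` -> datasets, which drop their cached valid indices. A rough sketch of that control flow with simplified stand-in classes (hypothetical names, not the package's API):

```python
class ToyDataset:
    def __init__(self) -> None:
        self.rollout = 1
        self._cached_indices = None

    def update_rollout(self, rollout: int) -> None:
        if rollout == self.rollout:
            return
        self.rollout = rollout
        self._cached_indices = None  # recomputed lazily, like valid_date_indices


class ToyScheduler:
    def rollout_at(self, epoch: int) -> int:
        return min(12, 1 + epoch)  # a stand-in schedule


dataset = ToyDataset()
scheduler = ToyScheduler()

# At the end of validation epoch e, propagate the value for epoch e + 1,
# since the trainer's epoch counter has not yet incremented.
for epoch in range(3):
    dataset.update_rollout(scheduler.rollout_at(epoch + 1))
    print(epoch, dataset.rollout)
```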
diff --git a/tests/schedulers/rollout/test_rollout.py b/tests/schedulers/rollout/test_rollout.py new file mode 100644 index 00000000..c167afa2 --- /dev/null +++ b/tests/schedulers/rollout/test_rollout.py @@ -0,0 +1,8 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. diff --git a/tests/schedulers/rollout/test_stepped.py b/tests/schedulers/rollout/test_stepped.py new file mode 100644 index 00000000..c167afa2 --- /dev/null +++ b/tests/schedulers/rollout/test_stepped.py @@ -0,0 +1,8 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. From 433362a61cf67afe5ccf95da80a76ede4a81bf56 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 20 Dec 2024 15:11:16 +0000 Subject: [PATCH 09/11] Update warnings --- training/src/anemoi/training/data/dataset.py | 4 +- .../training/diagnostics/callbacks/rollout.py | 37 +------------------ .../training/schedulers/rollout/__init__.py | 2 +- 3 files changed, 5 insertions(+), 38 deletions(-) diff --git a/training/src/anemoi/training/data/dataset.py b/training/src/anemoi/training/data/dataset.py index e004bddb..315829d5 100644 --- a/training/src/anemoi/training/data/dataset.py +++ b/training/src/anemoi/training/data/dataset.py @@ -144,7 +144,7 @@ def update_rollout(self, rollout: int) -> None: return self.rollout = rollout - LOGGER.warning(f"Updating rollout of {self.label} dataset to {self.rollout}") + LOGGER.debug(f"Updating rollout of {self.label} dataset to {self.rollout}") if hasattr(self, "valid_date_indices"): del self.valid_date_indices @@ -284,7 +284,7 @@ def __iter__(self) -> torch.Tensor: self.model_comm_group_rank, shuffled_chunk_indices[:10], ) - LOGGER.warning(f"Rollout in dataset: {self.label} is {self.rollout}") + for i in shuffled_chunk_indices: start = i - (self.multi_step - 1) * self.timeincrement end = i + (self.rollout + 1) * self.timeincrement diff --git a/training/src/anemoi/training/diagnostics/callbacks/rollout.py b/training/src/anemoi/training/diagnostics/callbacks/rollout.py index b294bc04..28fb1dd2 100644 --- a/training/src/anemoi/training/diagnostics/callbacks/rollout.py +++ b/training/src/anemoi/training/diagnostics/callbacks/rollout.py @@ -25,7 +25,8 @@ def _update_rollout(self, trainer, pl_module, epoch: int | None = None, step: in with rollsched.at(epoch=epoch, step=step): rollout = rollsched.current_maximum - trainer.datamodule.update_rollout(rollout) + LOGGER.debug("Propagating rollout value %s to datamodule", rollout) + trainer.datamodule.update_rollout(rollout = rollout) def on_load_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: dict) -> None: """ @@ -40,40 +41,8 @@ def on_load_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint : dict Checkpoint dictionary """ - LOGGER.warning('Updating rollout values from checkpoint.') self._update_rollout(trainer, pl_module, epoch = 
checkpoint['epoch'], step = checkpoint['global_step']) - # def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - # """ - # Update the rollout values in the datamodule when starting fitting. - - # Parameters - # ---------- - # trainer : pl.Trainer - # Pytorch Lightning trainer - # pl_module : pl.LightningModule - # Model - # """ - # LOGGER.warning('Updating rollout values when fit starts.') - # self._update_rollout(trainer, pl_module) - - # def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: - # """ - # Update the rollout values in the datamodule when setting up the trainer. - - # Parameters - # ---------- - # trainer : pl.Trainer - # Pytorch Lightning trainer - # pl_module : pl.LightningModule - # Model - # stage : str - # Stage of the training - # """ - # LOGGER.warning('Updating rollout values from setup.') - # self._update_rollout(trainer, pl_module) - - def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *a) -> None: """ Update the rollout values in the datamodule every validation epoch. @@ -88,8 +57,6 @@ def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningMo if trainer.sanity_checking: return - LOGGER.warning('Updating rollout values from validation epoch end.') - # Offset of 1 needed as the epoch counter does not increment # until after the epoch ends. self._update_rollout(trainer, pl_module, epoch = trainer.current_epoch + 1) diff --git a/training/src/anemoi/training/schedulers/rollout/__init__.py b/training/src/anemoi/training/schedulers/rollout/__init__.py index c6373851..378f2d98 100644 --- a/training/src/anemoi/training/schedulers/rollout/__init__.py +++ b/training/src/anemoi/training/schedulers/rollout/__init__.py @@ -159,7 +159,7 @@ def count(self, n_steps: int | None = None, n_epochs: int | None = None) -> int: Returns ------- int - Count of steps or epochs. + Count of steps or epochs, rounded down. Raises ------ From 3ac7dcdea5dc20265e7a9b700d3aa50479f33829 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 20 Dec 2024 15:11:50 +0000 Subject: [PATCH 10/11] Add tests --- .../training/schedulers/rollout/indexed.py | 14 +- .../training/schedulers/rollout/randomise.py | 8 +- .../src/anemoi/training/train/forecaster.py | 18 +-- .../diagnostics/callbacks/test_rollout.py | 88 ++++++++++++ training/tests/schedulers/rollout/__init__.py | 3 + .../tests/schedulers/rollout/test_indexed.py | 43 +++++- .../tests/schedulers/rollout/test_random.py | 130 ++++++++++++++++++ .../tests/schedulers/rollout/test_rollout.py | 50 +++++++ .../tests/schedulers/rollout/test_stepped.py | 62 +++++++++ 9 files changed, 396 insertions(+), 20 deletions(-) create mode 100644 training/tests/diagnostics/callbacks/test_rollout.py diff --git a/training/src/anemoi/training/schedulers/rollout/indexed.py b/training/src/anemoi/training/schedulers/rollout/indexed.py index 08eb43e2..e26ac9ce 100644 --- a/training/src/anemoi/training/schedulers/rollout/indexed.py +++ b/training/src/anemoi/training/schedulers/rollout/indexed.py @@ -22,6 +22,8 @@ def get_closest_key(dictionary: dict[int, Any], key: int) -> int: Where the closest key is the one with the smallest absolute difference and the key is less than or equal to the given key. + If no lower key is found, returns -1. + Parameters ---------- dictionary : dict[int, Any] @@ -34,7 +36,10 @@ def get_closest_key(dictionary: dict[int, Any], key: int) -> int: int Closest key in the dictionary. 
""" - return min(dictionary.keys(), key=lambda x: abs(x - key) if x <= key else float("inf")) + lowest_key = min(dictionary.keys(), key=lambda x: abs(x - key) if x <= key else float("inf")) + if key < lowest_key: + return -1 + return lowest_key class PositionalIndexed(RolloutScheduler): @@ -99,6 +104,8 @@ def rollout(self) -> int: def maximum_rollout(self) -> int: return max(self._rollouts) + def description(self): + return f"PositionalIndexed with rollouts {self._rollouts} and num_times_per_{self._step_type} {self._num_times_per_element} ." class EpochPositionalIndexed(PositionalIndexed): """Epoch based PositionalIndexed.""" @@ -128,6 +135,8 @@ def __init__(self, rollouts: dict[int, int], step_type: Literal["step", "epoch"] It will return the closest key that is less than or equal to the current epoch or step. + If there is no key lower then the index, defaults to 1. + Parameters ---------- rollouts : dict[int, int] @@ -165,6 +174,9 @@ def rollout(self) -> int: @property def maximum_rollout(self) -> int: return max(self._rollouts.values()) + + def description(self): + return f"Lookup with rollouts {self._rollouts} based on {self._step_type}." class EpochLookup(Lookup): diff --git a/training/src/anemoi/training/schedulers/rollout/randomise.py b/training/src/anemoi/training/schedulers/rollout/randomise.py index 9aff3b7e..2ae0fc0b 100644 --- a/training/src/anemoi/training/schedulers/rollout/randomise.py +++ b/training/src/anemoi/training/schedulers/rollout/randomise.py @@ -43,7 +43,7 @@ def __init__(self): @property def rng(self): - return np.random.default_rng(hash((self._rnd_seed, self._epoch, self._step))) + return np.random.default_rng(abs(hash((self._rnd_seed, self._epoch, self._step)))) def broadcast(self, value: int) -> None: """ @@ -130,7 +130,7 @@ def __init__(self, minimum: int = 1, maximum: int = 1, step: int = 1): minimum : int, optional Minimum rollout to choose from, by default 1 maximum : int, optional - Maximum rollout to choose from, by default 1 + Maximum rollout to choose from, inclusive., by default 1 step : int, optional Step size for the range, by default 1 @@ -146,7 +146,7 @@ def __init__(self, minimum: int = 1, maximum: int = 1, step: int = 1): # any value between 1 and 5 ``` """ - super().__init__(list(range(minimum, maximum + 1, step))) + super().__init__(range(minimum, maximum + 1, step)) def description(self) -> str: return ( @@ -217,7 +217,7 @@ def rollout(self) -> int: if self._every_n == 0: return self._minimum - rollouts = range(self._minimum, self.current_maximum, self._range_step) + rollouts = range(self._minimum, self.current_maximum + 1, self._range_step) return self._randomly_pick(rollouts) diff --git a/training/src/anemoi/training/train/forecaster.py b/training/src/anemoi/training/train/forecaster.py index 88f2bb61..0c69f470 100644 --- a/training/src/anemoi/training/train/forecaster.py +++ b/training/src/anemoi/training/train/forecaster.py @@ -609,25 +609,18 @@ def calculate_val_metrics( def on_train_start(self): # Sync the rollout at the start of training - print("Rollout at start of training", int(self.rollout), self.rollout._epoch, self.rollout._step) self.rollout.sync(step = self.global_step, epoch = self.current_epoch) def on_load_checkpoint(self, checkpoint: dict): - # Sync the rollout at the start of training - print("Rollout at on_load_checkpoint", int(self.rollout), self.rollout._epoch, self.rollout._step) + # Sync the rollout on the load of a checkpoint self.rollout.sync(step = checkpoint["global_step"], epoch = checkpoint["epoch"]) def 
on_train_epoch_start(self): + # Sync the rollout at the start of each epoch + # Cannot use stepping due to inconsistent behaviour with Pytorch Lightning self.rollout.sync(step = self.global_step, epoch = self.current_epoch) - LOGGER.warning(f"Rollout at start of training, {int(self.rollout)}, {self.rollout._epoch}, {self.rollout._step}") - - def on_validation_epoch_start(self): - LOGGER.warning(f"Rollout at start of validation, {int(self.rollout)}, {self.rollout._epoch}, {self.rollout._step}") + LOGGER.debug(f"Rollout at start of training epoch {self.current_epoch}: {int(self.rollout)}.") - def on_validation_epoch_end(self) -> None: - # if not self.trainer.sanity_checking: - # self.rollout_epoch_step() - LOGGER.warning(f"Rollout at end of validation, {int(self.rollout)}, {self.rollout._epoch}, {self.rollout._step}") def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: train_loss, _, _ = self._step(batch, batch_idx) @@ -666,7 +659,6 @@ def lr_scheduler_step(self, scheduler: CosineLRScheduler, metric: None = None) - del metric scheduler.step(epoch=self.trainer.global_step) - def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: """ Calculate the loss over a validation batch using the training loss function. @@ -711,8 +703,6 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: return val_loss, y_preds - - def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[dict]]: if self.use_zero_optimizer: optimizer = ZeroRedundancyOptimizer( diff --git a/training/tests/diagnostics/callbacks/test_rollout.py b/training/tests/diagnostics/callbacks/test_rollout.py new file mode 100644 index 00000000..cc09d5fe --- /dev/null +++ b/training/tests/diagnostics/callbacks/test_rollout.py @@ -0,0 +1,88 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +from typing import Any + +import pytest + +from anemoi.training.diagnostics.callbacks.rollout import UpdateRollout +from anemoi.training.train.train import AnemoiTrainer +from anemoi.training.train.forecaster import GraphForecaster + +from anemoi.training.schedulers.rollout import RolloutScheduler + +class DebugScheduler(RolloutScheduler): + @property + def rollout(self): + return self._epoch + + @property + def maximum_rollout(self): + return self._epoch + + def description(self): + return "DebugScheduler" + +@pytest.fixture +def fake_trainer(mocker: Any) -> AnemoiTrainer: + trainer = mocker.Mock(spec=AnemoiTrainer) + + trainer.datamodule.update_rollout = mocker.patch('anemoi.training.data.datamodule.AnemoiDatasetsDataModule.update_rollout') + return trainer + +@pytest.fixture +def fake_forecaster(mocker: Any) -> GraphForecaster: + model = mocker.Mock(spec=GraphForecaster) + + model.rollout = DebugScheduler() + return model + +@pytest.fixture +def checkpoint(mocker: Any) -> dict[str, int]: + return {"epoch": 10, "global_step":100} + + +@pytest.fixture +def callback() -> UpdateRollout: + callback = UpdateRollout() + assert callback is not None + assert hasattr(callback, "on_load_checkpoint") + assert hasattr(callback, "on_validation_epoch_end") + + return callback + + +def test_on_load_checkpoint( + fake_trainer: AnemoiTrainer, + fake_forecaster: GraphForecaster, + callback: UpdateRollout, + checkpoint: dict, +) -> None: + callback.on_load_checkpoint(fake_trainer, fake_forecaster, checkpoint) + spy = fake_trainer.datamodule.update_rollout + + spy.assert_called_once_with(rollout = checkpoint["epoch"]) + + +def test_on_validation_epoch_sanity(fake_trainer: AnemoiTrainer, fake_forecaster: GraphForecaster, callback: UpdateRollout) -> None: + fake_trainer.current_epoch = 10 + fake_trainer.sanity_checking = True + spy = fake_trainer.datamodule.update_rollout + + callback.on_validation_epoch_end(fake_trainer, fake_forecaster, None) + spy.assert_not_called() + +def test_on_validation_epoch(fake_trainer: AnemoiTrainer, fake_forecaster: GraphForecaster, callback: UpdateRollout) -> None: + fake_trainer.current_epoch = 10 + spy = fake_trainer.datamodule.update_rollout + fake_trainer.sanity_checking = False + + callback.on_validation_epoch_end(fake_trainer, fake_forecaster, None) + + spy.assert_called_once_with(rollout = 11) #Offset 1 \ No newline at end of file diff --git a/training/tests/schedulers/rollout/__init__.py b/training/tests/schedulers/rollout/__init__.py index c167afa2..1ef0eb8d 100644 --- a/training/tests/schedulers/rollout/__init__.py +++ b/training/tests/schedulers/rollout/__init__.py @@ -6,3 +6,6 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + + + diff --git a/training/tests/schedulers/rollout/test_indexed.py b/training/tests/schedulers/rollout/test_indexed.py index dfc43718..669aefc8 100644 --- a/training/tests/schedulers/rollout/test_indexed.py +++ b/training/tests/schedulers/rollout/test_indexed.py @@ -8,4 +8,45 @@ # nor does it submit to any jurisdiction. 
-from anemoi.training.schedulers.rollout.indexed import PositionalIndexed, Lookup \ No newline at end of file +import pytest + +from anemoi.training.schedulers.rollout.indexed import PositionalIndexed, Lookup + +@pytest.mark.parametrize( + "rollouts, num_times_per_element, test_epoch, expected", + [ + ([1, 2, 3], 1, 0, 1), + ([1, 2, 3], 1, 1, 2), + ([1, 2, 3], 2, 0, 1), + ([1, 2, 3], 2, 2, 2), + ([4, 5, 6], 1, 0, 4), + ([4, 5, 6], 1, 0, 4), + ([4, 5, 6], 2, 0, 4), + ([4, 5, 6], 2, 2, 5), + ] +) +def test_positional(rollouts: list[int], num_times_per_element: int, test_epoch: int, expected: int): + sched = PositionalIndexed(rollouts, num_times_per_element) + assert sched.rollout == rollouts[0] + assert sched.maximum_rollout == max(rollouts) + + assert sched.rollout_at(epoch = test_epoch) == expected + +@pytest.mark.parametrize( + "rollouts, test_epoch, expected", + [ + ({0:1, 1:2, 2:3}, 0, 1), + ({0:1, 1:2, 2:3}, 1, 2), + ({1:1, 5:2}, 1, 1), + ({1:1, 5:2}, 4, 1), + ({1:1, 5:2}, 5, 2), + ({1:1, 5:2}, 10, 2), + ({5:2}, 1, 1), + + ] +) +def test_lookup(rollouts: dict[int, int], test_epoch: int, expected: int): + sched = Lookup(rollouts) + assert sched.maximum_rollout == max(rollouts.values()) + + assert sched.rollout_at(epoch = test_epoch) == expected diff --git a/training/tests/schedulers/rollout/test_random.py b/training/tests/schedulers/rollout/test_random.py index c167afa2..7d5ff6e7 100644 --- a/training/tests/schedulers/rollout/test_random.py +++ b/training/tests/schedulers/rollout/test_random.py @@ -6,3 +6,133 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + +from typing import Any + +import pytest +from unittest.mock import patch + +from anemoi.training.schedulers.rollout.randomise import RandomList, RandomRange, IncreasingRandom, BaseRandom + + +def test_determism(): + sched = RandomList([1, 2, 3]) + sched_1 = RandomList([1, 2, 3]) + + sched.rollout # Force a retrieval to try and break the determinism + + for i in range(100): + sched.sync(epoch = i) + sched_1.sync(epoch = i) + + assert sched.rollout == sched_1.rollout + +@pytest.mark.parametrize( + "rollouts", + [ + [1, 2, 3], + [1, 2, 3, 4], + [1, 2, 3, 4, 5], + [16, 2, 3, 4, 5], + ] +) +@patch("anemoi.training.schedulers.rollout.randomise.BaseRandom._randomly_pick", wraps = RandomList([0])._randomly_pick) +def test_random_list(pick_mock: Any, rollouts: list[int]): + sched = RandomList(rollouts) + assert sched.rollout in rollouts + assert sched.maximum_rollout == max(rollouts) + + pick_mock.assert_called_once_with(rollouts) + +@pytest.mark.parametrize( + "minimum, maximum, step", + [ + (1, 10, 1), + (1, 10, 2), + (1, 10, 3), + ] +) +@patch("anemoi.training.schedulers.rollout.randomise.BaseRandom._randomly_pick", wraps = RandomList([0])._randomly_pick) +def test_random_range(pick_mock: Any, minimum: int, maximum: int, step: int): + sched = RandomRange(minimum, maximum, step) + assert sched.rollout in range(minimum, maximum + 1, step) + assert sched.maximum_rollout == max(range(minimum, maximum + 1, step)) + + pick_mock.assert_called_once_with(range(minimum, maximum + 1, step)) + + +@pytest.mark.parametrize( + "minimum, maximum, step, every_n, epoch_test, expected_max", + [ + (1, 10, 1, 1, 0, 1), + (1, 10, 1, 1, 1, 2), + (1, 10, 1, 1, 2, 3), + (1, 10, 1, 1, 10, 10), + (1, 10, 1, 1, 100, 10), + (1, 10, 1, 2, 2, 2), + (1, 10, 1, 2, 4, 3), + (1, 10, 2, 2, 4, 3), + + ] +) 
+@patch("anemoi.training.schedulers.rollout.randomise.BaseRandom._randomly_pick", wraps = RandomList([0])._randomly_pick) +def test_increasing_random_increment(pick_mock: Any, minimum: int, maximum: int, step: int, every_n: int, epoch_test: int, expected_max: int): + sched = IncreasingRandom(minimum, maximum, step, every_n, 1) + + sched.sync(epoch = epoch_test) + + assert sched.current_maximum == expected_max + assert sched.rollout in list(range(minimum, expected_max + 1, step)) + assert sched.maximum_rollout == maximum + + pick_mock.assert_called_once_with(range(minimum, expected_max + 1, step)) + + +@pytest.mark.parametrize( + "minimum, maximum, step, every_n, increment, epoch_test, expected_max", + [ + (1, 10, 1, 1, {0:0, 2:1, 4:2,}, 0, 1), + (1, 10, 1, 1, {0:0, 2:1, 4:2,}, 1, 1), + (1, 10, 1, 1, {0:0, 2:1, 4:2,}, 2, 2), + (1, 10, 1, 1, {0:0, 2:1, 4:2,}, 3, 3), + (1, 10, 1, 1, {0:0, 2:1, 4:2,}, 4, 5), + (1, 10, 1, 1, {0:0, 2:1, 4:2,}, 5, 7), + (1, 10, 1, 1, {0:0, 2:1, 3:0, 4:2,}, 4, 4), + (1, 10, 1, 1, {0:0, 2:1, 3:0, 4:2,}, 5, 6), + (1, 10, 1, 1, {0:0, 2:1, 3:0, 4:2,}, 1000, 10), + (1, 10, 2, 1, {0:0, 2:1, 3:0, 4:2,}, 1000, 10), + ] +) +@patch("anemoi.training.schedulers.rollout.randomise.BaseRandom._randomly_pick", wraps = RandomList([0])._randomly_pick) +def test_increasing_random_complex_increment(pick_mock: Any, minimum: int, maximum: int, step: int, every_n: int, increment: dict[int, int], epoch_test: int, expected_max: int): + + sched = IncreasingRandom(minimum, maximum, step, every_n, increment=increment) + + sched.sync(epoch = epoch_test) + assert sched.rollout in list(range(minimum, expected_max + 1, step)) + assert sched.current_maximum == expected_max + pick_mock.assert_called_with(range(minimum, expected_max + 1, step)) + + +@pytest.mark.parametrize( + "minimum, maximum, step, every_n, increment, epoch_test, expected_max", + [ + (1, 10, 1, 2, {0:0, 2:1, 4:2,}, 0, 1), + (1, 10, 1, 2, {0:0, 2:1, 4:2,}, 1, 1), + (1, 10, 1, 2, {0:0, 2:1, 4:2,}, 2, 2), + (1, 10, 1, 2, {0:0, 2:1, 4:2,}, 3, 2), + (1, 10, 1, 2, {0:0, 2:1, 4:2,}, 4, 4), + (1, 10, 1, 2, {0:0, 2:1, 4:2,}, 5, 4), + (1, 10, 1, 2, {0:0, 2:1, 4:2,}, 6, 6), + ] +) +@patch("anemoi.training.schedulers.rollout.randomise.BaseRandom._randomly_pick", wraps = RandomList([0])._randomly_pick) +def test_increasing_random_complex_increment_every_not_1(pick_mock: Any, minimum: int, maximum: int, step: int, every_n: int, increment: dict[int, int], epoch_test: int, expected_max: int): + + sched = IncreasingRandom(minimum, maximum, step, every_n, increment=increment) + + sched.sync(epoch = epoch_test) + assert sched.rollout in list(range(minimum, expected_max + 1, step)) + assert sched.current_maximum == expected_max + pick_mock.assert_called_with(range(minimum, expected_max + 1, step)) + diff --git a/training/tests/schedulers/rollout/test_rollout.py b/training/tests/schedulers/rollout/test_rollout.py index c167afa2..a50ff4c0 100644 --- a/training/tests/schedulers/rollout/test_rollout.py +++ b/training/tests/schedulers/rollout/test_rollout.py @@ -6,3 +6,53 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
+ +from anemoi.training.schedulers.rollout import RolloutScheduler, Static + +class DebugScheduler(RolloutScheduler): + @property + def rollout(self): + return self._epoch + + @property + def maximum_rollout(self): + return self._epoch + + def description(self): + return "DebugScheduler" + +def test_static(): + sched = Static(1) + assert sched.rollout == 1 + assert sched.maximum_rollout == 1 + assert sched.current_maximum == 1 + + +def test_at(): + sched = DebugScheduler() + + with sched.at(epoch = 1): + assert sched.rollout == 1 + assert sched.maximum_rollout == 1 + assert sched.current_maximum == 1 + + assert sched.rollout == 0 + +def test_sync(): + sched = DebugScheduler() + sched.sync(epoch = 10) + assert sched.rollout == 10 + + +def test_count(): + sched = DebugScheduler() + sched.sync(epoch = 10) + assert sched.count(n_epochs=5) == 2 + assert sched.count(n_epochs=3) == 3 + +def test_int_conversion(): + sched = DebugScheduler() + sched.sync(epoch = 10) + assert int(sched) == 10 + + diff --git a/training/tests/schedulers/rollout/test_stepped.py b/training/tests/schedulers/rollout/test_stepped.py index c167afa2..477e9814 100644 --- a/training/tests/schedulers/rollout/test_stepped.py +++ b/training/tests/schedulers/rollout/test_stepped.py @@ -6,3 +6,65 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + +import pytest + +from anemoi.training.schedulers.rollout.stepped import Stepped + +@pytest.mark.parametrize( + "minimum, maximum, every_n, increment, epoch_test, expected_value", + [ + # Increment of 1 and every_n of 1 + (1, 10, 1, 1, 0, 1), + (1, 10, 1, 1, 1, 2), + (1, 10, 1, 1, 5, 6), + (1, 10, 1, 1, 6, 7), + (1, 10, 1, 1, 8, 9), + (1, 10, 1, 1, 9, 10), + (1, 10, 1, 1, 10, 10), + (1, 10, 1, 1, 11, 10), + (1, 10, 1, 1, 1000, 10), + + # Increment of 2 and every_n of 1 + (1, 10, 1, 2, 1, 3), + (1, 10, 1, 2, 2, 5), + (1, 10, 1, 2, 4, 9), + (1, 10, 1, 2, 5, 10), + + # Increment of 1 and every_n of 2 + (1, 10, 2, 1, 0, 1), + (1, 10, 2, 1, 1, 1), + (1, 10, 2, 1, 2, 2), + + ] +) +def test_stepped(minimum: int, maximum: int, every_n: int, increment: int, epoch_test: int, expected_value: int): + sched = Stepped(minimum, maximum, every_n, increment=increment) + + sched.sync(epoch = epoch_test) + assert sched.rollout == expected_value + assert sched.current_maximum == expected_value + + + +@pytest.mark.parametrize( + "minimum, maximum, every_n, increment, epoch_test, expected_value", + [ + (1, 10, 1, {0:0, 2:1, 4:2,}, 0, 1), + (1, 10, 1, {0:0, 2:1, 4:2,}, 1, 1), + (1, 10, 1, {0:0, 2:1, 4:2,}, 2, 2), + (1, 10, 1, {0:0, 2:1, 4:2,}, 3, 3), + (1, 10, 1, {0:0, 2:1, 4:2,}, 4, 5), + (1, 10, 1, {0:0, 2:1, 4:2,}, 5, 7), + (1, 10, 1, {0:0, 2:1, 3:0, 4:2,}, 4, 4), + (1, 10, 1, {0:0, 2:1, 3:0, 4:2,}, 5, 6), + (1, 10, 1, {0:0, 2:1, 3:0, 4:2,}, 1000, 10), + ] +) +def test_stepped_complex_increment(minimum: int, maximum: int, every_n: int, increment: dict[int, int], epoch_test: int, expected_value: int): + + sched = Stepped(minimum, maximum, every_n, increment=increment) + + sched.sync(epoch = epoch_test) + assert sched.rollout == expected_value + assert sched.current_maximum == expected_value From 71a9e08235b75eed5eed2c037293684200c25294 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 20 Dec 2024 15:35:40 +0000 Subject: [PATCH 11/11] pre-commit --- .../src/anemoi/training/data/datamodule.py | 2 +- training/src/anemoi/training/data/dataset.py | 9 +- 
.../diagnostics/callbacks/__init__.py | 2 +- .../training/diagnostics/callbacks/plot.py | 1 - .../training/diagnostics/callbacks/rollout.py | 18 ++- .../training/schedulers/rollout/__init__.py | 31 ++-- .../training/schedulers/rollout/indexed.py | 27 +++- .../training/schedulers/rollout/randomise.py | 9 +- .../training/schedulers/rollout/stepped.py | 16 +- .../src/anemoi/training/train/forecaster.py | 20 +-- training/src/anemoi/training/train/train.py | 11 +- .../diagnostics/callbacks/test_rollout.py | 46 ++++-- training/tests/schedulers/rollout/__init__.py | 3 - .../tests/schedulers/rollout/test_indexed.py | 46 +++--- .../tests/schedulers/rollout/test_random.py | 142 ++++++++++-------- .../tests/schedulers/rollout/test_rollout.py | 39 ++--- .../tests/schedulers/rollout/test_stepped.py | 78 +++++++--- 17 files changed, 294 insertions(+), 206 deletions(-) diff --git a/training/src/anemoi/training/data/datamodule.py b/training/src/anemoi/training/data/datamodule.py index 54be3dd7..7fee5f6d 100644 --- a/training/src/anemoi/training/data/datamodule.py +++ b/training/src/anemoi/training/data/datamodule.py @@ -166,7 +166,7 @@ def update_rollout(self, rollout: int) -> None: ---------- rollout : int Rollout value - """ + """ for ds in [self.ds_train, self.ds_test]: ds.update_rollout(rollout) diff --git a/training/src/anemoi/training/data/dataset.py b/training/src/anemoi/training/data/dataset.py index 315829d5..c42da39c 100644 --- a/training/src/anemoi/training/data/dataset.py +++ b/training/src/anemoi/training/data/dataset.py @@ -13,7 +13,7 @@ import os import random from functools import cached_property -from typing import TYPE_CHECKING, Self +from typing import TYPE_CHECKING from typing import Callable import numpy as np @@ -144,7 +144,7 @@ def update_rollout(self, rollout: int) -> None: return self.rollout = rollout - LOGGER.debug(f"Updating rollout of {self.label} dataset to {self.rollout}") + LOGGER.debug("Updating rollout of %s dataset to %d", self.label, self.rollout) if hasattr(self, "valid_date_indices"): del self.valid_date_indices @@ -239,10 +239,7 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: sanity_rnd = self.rng.random(1) LOGGER.debug( - ( - "Worker %d (%s, pid %d, glob. rank %d, model comm group %d, " - "group_rank %d, base_seed %d), sanity rnd %f" - ), + ("Worker %d (%s, pid %d, glob. 
rank %d, model comm group %d, group_rank %d, base_seed %d), sanity rnd %f"), worker_id, self.label, os.getpid(), diff --git a/training/src/anemoi/training/diagnostics/callbacks/__init__.py b/training/src/anemoi/training/diagnostics/callbacks/__init__.py index 16fbcd2a..16e490aa 100644 --- a/training/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/training/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -23,8 +23,8 @@ from anemoi.training.diagnostics.callbacks.optimiser import LearningRateMonitor from anemoi.training.diagnostics.callbacks.optimiser import StochasticWeightAveraging from anemoi.training.diagnostics.callbacks.provenance import ParentUUIDCallback -from anemoi.training.diagnostics.callbacks.sanity import CheckVariableOrder from anemoi.training.diagnostics.callbacks.rollout import UpdateRollout +from anemoi.training.diagnostics.callbacks.sanity import CheckVariableOrder if TYPE_CHECKING: from pytorch_lightning.callbacks import Callback diff --git a/training/src/anemoi/training/diagnostics/callbacks/plot.py b/training/src/anemoi/training/diagnostics/callbacks/plot.py index aebba10e..4f3168e2 100644 --- a/training/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/training/src/anemoi/training/diagnostics/callbacks/plot.py @@ -1164,7 +1164,6 @@ def _plot( data, output_tensor = self.process(pl_module, outputs, batch) for rollout_step in range(pl_module.rollout): - # Build dictionary of inidicies and parameters to be plotted diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic diff --git a/training/src/anemoi/training/diagnostics/callbacks/rollout.py b/training/src/anemoi/training/diagnostics/callbacks/rollout.py index 28fb1dd2..44cf6f02 100644 --- a/training/src/anemoi/training/diagnostics/callbacks/rollout.py +++ b/training/src/anemoi/training/diagnostics/callbacks/rollout.py @@ -7,6 +7,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from __future__ import annotations + import logging import pytorch_lightning as pl @@ -20,13 +22,19 @@ class UpdateRollout(pl.callbacks.Callback): def __init__(self) -> None: super().__init__() - def _update_rollout(self, trainer, pl_module, epoch: int | None = None, step: int | None = None) -> None: + def _update_rollout( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + epoch: int | None = None, + step: int | None = None, + ) -> None: rollsched = pl_module.rollout with rollsched.at(epoch=epoch, step=step): rollout = rollsched.current_maximum LOGGER.debug("Propagating rollout value %s to datamodule", rollout) - trainer.datamodule.update_rollout(rollout = rollout) + trainer.datamodule.update_rollout(rollout=rollout) def on_load_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: dict) -> None: """ @@ -41,9 +49,9 @@ def on_load_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint : dict Checkpoint dictionary """ - self._update_rollout(trainer, pl_module, epoch = checkpoint['epoch'], step = checkpoint['global_step']) + self._update_rollout(trainer, pl_module, epoch=checkpoint["epoch"], step=checkpoint["global_step"]) - def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *a) -> None: + def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *_) -> None: """ Update the rollout values in the datamodule every validation epoch. 
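
The callback above never mutates the scheduler to answer a question: `at(...)` freezes the live step/epoch and restores it when the context exits. A minimal sketch of that pattern, using the `Stepped` scheduler from this series (assuming only that the package is importable):

```python
from anemoi.training.schedulers.rollout.stepped import Stepped

# Rollout grows by 1 per epoch, from a minimum of 1 up to a cap of 10.
sched = Stepped(1, 10, 1, increment=1)

# Peek at epoch 3 without disturbing the live state, the same pattern
# _update_rollout uses to query the rollout for the datamodule.
with sched.at(epoch=3):
    upcoming = sched.current_maximum  # 4 = minimum 1 + 3 increments

assert sched.rollout == 1  # live state is restored on exit
```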
@@ -59,4 +67,4 @@ def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningMo
 
         # Offset of 1 needed as the epoch counter does not increment
         # until after the epoch ends.
-        self._update_rollout(trainer, pl_module, epoch = trainer.current_epoch + 1)
+        self._update_rollout(trainer, pl_module, epoch=trainer.current_epoch + 1)
diff --git a/training/src/anemoi/training/schedulers/rollout/__init__.py b/training/src/anemoi/training/schedulers/rollout/__init__.py
index 378f2d98..e54d2041 100644
--- a/training/src/anemoi/training/schedulers/rollout/__init__.py
+++ b/training/src/anemoi/training/schedulers/rollout/__init__.py
@@ -57,21 +57,21 @@ def maximum_rollout(self) -> int:
     @property
     def current_maximum(self) -> int:
         """Get the current maximum rollout value.
-        
+
         Allows the dataloader to only get the data necessary.
         In most cases this is just the current rollout.
         """
         return self.rollout
 
     def __int__(self) -> int:
-        """Get rollout value as int"""
+        """Get rollout value as int."""
         return int(self.rollout)
-    
+
     def __index__(self) -> int:
-        """Get rollout value as index"""
+        """Get rollout value as index."""
         return int(self.rollout)
-    
-    def at(self, step: int | None = None, epoch: int | None = None) -> FrozenStateRecord:
+
+    def at(self, step: int | None = None, epoch: int | None = None) -> FrozenStateRecord:  # noqa: F821
         """
         Temporarily hold the scheduler at a specific step and epoch.
@@ -103,7 +103,7 @@ def __exit__(context_self, *a):  # noqa: N805
             self._step = step if step is not None else prior_step
             self._epoch = epoch if epoch is not None else prior_epoch
         return FrozenStateRecord()
-    
+
     def rollout_at(self, step: int | None = None, epoch: int | None = None) -> int:
         """
         Get the rollout at a specific step and epoch.
@@ -131,9 +131,9 @@ def step_epoch(self, count: int = 1, /) -> None:
         """Step the scheduler by a count of epochs."""
         self._epoch += count
 
-    def sync(self, step: int = None, epoch: int = None):
+    def sync(self, step: int | None = None, epoch: int | None = None) -> None:
         """
-        Sync state of the Rollout Scheduler
+        Sync state of the Rollout Scheduler.
 
         Parameters
         ----------
@@ -151,10 +151,10 @@ def count(self, n_steps: int | None = None, n_epochs: int | None = None) -> int:
 
         Parameters
         ----------
-        n_epochs : int | None, optional
-            Number of epochs to count, by default None
         n_steps : int | None, optional
             Number of steps to count, by default None
+        n_epochs : int | None, optional
+            Number of epochs to count, by default None
 
         Returns
         -------
@@ -179,18 +179,23 @@ def description(self) -> str:
         """Description of the rollout scheduler."""
         error_msg = "`description` method not implemented by parent class."
raise NotImplementedError(error_msg) - - # Mathematical operations + + # Mathematical operations def __add__(self, other: int) -> int: return self.rollout + other + def __radd__(self, other: int) -> int: return other + self.rollout + def __sub__(self, other: int) -> int: return self.rollout - other + def __rsub__(self, other: int) -> int: return other - self.rollout + def __mul__(self, other: int) -> int: return self.rollout * other + def __rmul__(self, other: int) -> int: return other * self.rollout diff --git a/training/src/anemoi/training/schedulers/rollout/indexed.py b/training/src/anemoi/training/schedulers/rollout/indexed.py index e26ac9ce..75132cd7 100644 --- a/training/src/anemoi/training/schedulers/rollout/indexed.py +++ b/training/src/anemoi/training/schedulers/rollout/indexed.py @@ -7,11 +7,10 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import warnings from typing import Any from typing import Literal -import warnings - from anemoi.training.schedulers.rollout import RolloutScheduler @@ -104,8 +103,12 @@ def rollout(self) -> int: def maximum_rollout(self) -> int: return max(self._rollouts) - def description(self): - return f"PositionalIndexed with rollouts {self._rollouts} and num_times_per_{self._step_type} {self._num_times_per_element} ." + def description(self) -> str: + return ( + f"PositionalIndexed with rollouts {self._rollouts} and num_times_per_{self._step_type} " + f"{self._num_times_per_element}." + ) + class EpochPositionalIndexed(PositionalIndexed): """Epoch based PositionalIndexed.""" @@ -118,7 +121,11 @@ class StepPositionalIndexed(PositionalIndexed): """Step based PositionalIndexed.""" def __init__(self, rollouts: list[int]): - warnings.warn(f"Pytorch Lightning datamodules can only be refreshed at the end of an epoch, adjusting the rollout during an epoch will likely fail.", UserWarning) + warnings.warn( + "Pytorch Lightning datamodules can only be refreshed at the end of an epoch, " + "adjusting the rollout during an epoch will likely fail.", + UserWarning, + ) super().__init__(rollouts, step_type="step") @@ -174,8 +181,8 @@ def rollout(self) -> int: @property def maximum_rollout(self) -> int: return max(self._rollouts.values()) - - def description(self): + + def description(self) -> str: return f"Lookup with rollouts {self._rollouts} based on {self._step_type}." 
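
`Lookup` resolves the current epoch (or step) to the closest key at or below it via `get_closest_key`, and the rollout value then holds until the next threshold. A minimal sketch of that behaviour, consistent with the `test_lookup` cases later in this series (assuming the package is importable):

```python
from anemoi.training.schedulers.rollout.indexed import Lookup

# Rollout becomes 2 once epoch 5 is reached; epochs below the smallest
# key fall back to a rollout of 1.
sched = Lookup({1: 1, 5: 2})

assert sched.rollout_at(epoch=4) == 1   # closest key at or below 4 is 1
assert sched.rollout_at(epoch=5) == 2   # threshold reached
assert sched.rollout_at(epoch=10) == 2  # holds past the last threshold
assert sched.maximum_rollout == 2       # max of the dict's values
```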
@@ -190,5 +197,9 @@ class StepLookup(Lookup):
     """Step based Lookup."""
 
     def __init__(self, rollouts: dict[int, int]):
-        warnings.warn(f"Pytorch Lightning datamodules can only be refreshed at the end of an epoch, adjusting the rollout during an epoch will likely fail.", UserWarning)
+        warnings.warn(
+            "Pytorch Lightning datamodules can only be refreshed at the end of an epoch, "
+            "adjusting the rollout during an epoch will likely fail.",
+            UserWarning,
+        )
         super().__init__(rollouts, step_type="step")
diff --git a/training/src/anemoi/training/schedulers/rollout/randomise.py b/training/src/anemoi/training/schedulers/rollout/randomise.py
index 2ae0fc0b..e4261cde 100644
--- a/training/src/anemoi/training/schedulers/rollout/randomise.py
+++ b/training/src/anemoi/training/schedulers/rollout/randomise.py
@@ -42,7 +42,8 @@ def __init__(self):
         self._rnd_seed = pl.seed_everything(seed, workers=True)
 
     @property
-    def rng(self):
+    def rng(self) -> np.random.Generator:
+        """Get the `np.random.Generator`, seeded from the base seed, epoch, and step."""
         return np.random.default_rng(abs(hash((self._rnd_seed, self._epoch, self._step))))
 
     def broadcast(self, value: int) -> None:
@@ -358,6 +359,10 @@ def __init__(
         # any value between 1 and 2, and then increments of 1
         ```
         """
-        warnings.warn(f"Pytorch Lightning datamodules can only be refreshed at the end of an epoch, adjusting the rollout during an epoch will likely fail.", UserWarning)
+        warnings.warn(
+            "Pytorch Lightning datamodules can only be refreshed at the end of an epoch, "
+            "adjusting the rollout during an epoch will likely fail.",
+            UserWarning,
+        )
 
         super().__init__(minimum, maximum, range_step, every_n_steps, increment, step_type="step")
diff --git a/training/src/anemoi/training/schedulers/rollout/stepped.py b/training/src/anemoi/training/schedulers/rollout/stepped.py
index de182105..66b682dc 100644
--- a/training/src/anemoi/training/schedulers/rollout/stepped.py
+++ b/training/src/anemoi/training/schedulers/rollout/stepped.py
@@ -8,8 +8,8 @@
 # nor does it submit to any jurisdiction.
 
 from __future__ import annotations
-from typing import Literal
 import warnings
+from typing import Literal
 
 from anemoi.training.schedulers.rollout import RolloutScheduler
 from anemoi.training.schedulers.rollout.indexed import get_closest_key
@@ -66,8 +66,7 @@ def increment(self, step: int, epoch: int) -> int:
         ValueError
             If the `increment` value given at init cannot be parsed.
         """
-
-        count = (step // self._every_n if self._step_type == "step" else epoch // self._every_n)
+        count = step // self._every_n if self._step_type == "step" else epoch // self._every_n
 
         if isinstance(self._increment, int):
             return self._increment * count
@@ -87,7 +86,10 @@ def increment(self, step: int, epoch: int) -> int:
 
         if increment_step_type == self._step_type:
             return sum(
-                (increment_dict.get(get_closest_key(increment_dict, i * self._every_n), 0) for i in range(count + 1)),
+                (
+                    increment_dict.get(get_closest_key(increment_dict, i * self._every_n), 0)
+                    for i in range(count + 1)
+                ),
             )
 
         if epoch == 0 or step == 0:
@@ -261,5 +263,9 @@ def __init__(self, minimum: int, maximum: int, every_n_steps: int = 1000, increm
             i.e. {0: 1, 10: 2} will increment by 1 until 10, then by 2.
             by default 1.
""" - warnings.warn(f"Pytorch Lightning datamodules can only be refreshed at the end of an epoch, adjusting the rollout during an epoch will likely fail.", UserWarning) + warnings.warn( + "Pytorch Lightning datamodules can only be refreshed at the end of an epoch, " + "adjusting the rollout during an epoch will likely fail.", + UserWarning, + ) super().__init__(minimum, maximum, every_n_steps, increment, step_type="step") diff --git a/training/src/anemoi/training/train/forecaster.py b/training/src/anemoi/training/train/forecaster.py index 0c69f470..754d9b84 100644 --- a/training/src/anemoi/training/train/forecaster.py +++ b/training/src/anemoi/training/train/forecaster.py @@ -263,7 +263,6 @@ def training_weights_for_imputed_variables( @staticmethod def get_val_metric_ranges(config: DictConfig, data_indices: IndexCollection) -> tuple[dict, dict]: - metric_ranges = defaultdict(list) metric_ranges_validation = defaultdict(list) @@ -484,7 +483,6 @@ def _step( metrics = {} y_preds = [] - # print('Rollout', int(self.rollout)) for loss_next, metrics_next, y_preds_next in self.rollout_step( batch, rollout=int(self.rollout), @@ -607,20 +605,19 @@ def calculate_val_metrics( return metrics - def on_train_start(self): + def on_train_start(self) -> None: # Sync the rollout at the start of training - self.rollout.sync(step = self.global_step, epoch = self.current_epoch) + self.rollout.sync(step=self.global_step, epoch=self.current_epoch) - def on_load_checkpoint(self, checkpoint: dict): + def on_load_checkpoint(self, checkpoint: dict) -> None: # Sync the rollout on the load of a checkpoint - self.rollout.sync(step = checkpoint["global_step"], epoch = checkpoint["epoch"]) + self.rollout.sync(step=checkpoint["global_step"], epoch=checkpoint["epoch"]) - def on_train_epoch_start(self): + def on_train_epoch_start(self) -> None: # Sync the rollout at the start of each epoch - # Cannot use stepping due to inconsistent behaviour with Pytorch Lightning - self.rollout.sync(step = self.global_step, epoch = self.current_epoch) - LOGGER.debug(f"Rollout at start of training epoch {self.current_epoch}: {int(self.rollout)}.") - + # Cannot use stepping due to inconsistent behaviour with Pytorch Lightning + self.rollout.sync(step=self.global_step, epoch=self.current_epoch) + LOGGER.debug("Rollout at start of training epoch %d: %d.", self.current_epoch, int(self.rollout)) def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: train_loss, _, _ = self._step(batch, batch_idx) @@ -674,7 +671,6 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: ------- None """ - with torch.no_grad(): val_loss, metrics, y_preds = self._step(batch, batch_idx, validation_mode=True) diff --git a/training/src/anemoi/training/train/train.py b/training/src/anemoi/training/train/train.py index 394f8c20..2973feeb 100644 --- a/training/src/anemoi/training/train/train.py +++ b/training/src/anemoi/training/train/train.py @@ -376,21 +376,16 @@ def strategy(self) -> DDPGroupStrategy: self.config.dataloader.get("read_group_size", self.config.hardware.num_gpus_per_model), static_graph=not self.config.training.accum_grad_batches > 1, ) - + @cached_property def _need_to_reload_dataloaders(self) -> bool: """Determines if the dataloaders need to be reloaded. - If the model's rollout scheduler is already at it's maximum, + If the model's rollout scheduler is already at it's maximum, the dataloaders do not need to be reloaded. - - Returns - ------- - bool - True if the dataloaders need to be reloaded, False otherwise. 
""" rollsched = self.model.rollout - + if rollsched.current_maximum == rollsched.maximum_rollout: return False LOGGER.info("Dataloaders will be reloaded every epoch to support dynamic rollout.") diff --git a/training/tests/diagnostics/callbacks/test_rollout.py b/training/tests/diagnostics/callbacks/test_rollout.py index cc09d5fe..56de369c 100644 --- a/training/tests/diagnostics/callbacks/test_rollout.py +++ b/training/tests/diagnostics/callbacks/test_rollout.py @@ -12,30 +12,34 @@ import pytest from anemoi.training.diagnostics.callbacks.rollout import UpdateRollout -from anemoi.training.train.train import AnemoiTrainer +from anemoi.training.schedulers.rollout import RolloutScheduler from anemoi.training.train.forecaster import GraphForecaster +from anemoi.training.train.train import AnemoiTrainer -from anemoi.training.schedulers.rollout import RolloutScheduler class DebugScheduler(RolloutScheduler): @property - def rollout(self): + def rollout(self) -> int: return self._epoch - + @property - def maximum_rollout(self): + def maximum_rollout(self) -> int: return self._epoch - - def description(self): + + def description(self) -> str: return "DebugScheduler" + @pytest.fixture def fake_trainer(mocker: Any) -> AnemoiTrainer: trainer = mocker.Mock(spec=AnemoiTrainer) - trainer.datamodule.update_rollout = mocker.patch('anemoi.training.data.datamodule.AnemoiDatasetsDataModule.update_rollout') + trainer.datamodule.update_rollout = mocker.patch( + "anemoi.training.data.datamodule.AnemoiDatasetsDataModule.update_rollout", + ) return trainer + @pytest.fixture def fake_forecaster(mocker: Any) -> GraphForecaster: model = mocker.Mock(spec=GraphForecaster) @@ -43,9 +47,10 @@ def fake_forecaster(mocker: Any) -> GraphForecaster: model.rollout = DebugScheduler() return model + @pytest.fixture -def checkpoint(mocker: Any) -> dict[str, int]: - return {"epoch": 10, "global_step":100} +def checkpoint() -> dict[str, int]: + return {"epoch": 10, "global_step": 100} @pytest.fixture @@ -66,23 +71,32 @@ def test_on_load_checkpoint( ) -> None: callback.on_load_checkpoint(fake_trainer, fake_forecaster, checkpoint) spy = fake_trainer.datamodule.update_rollout - - spy.assert_called_once_with(rollout = checkpoint["epoch"]) + + spy.assert_called_once_with(rollout=checkpoint["epoch"]) -def test_on_validation_epoch_sanity(fake_trainer: AnemoiTrainer, fake_forecaster: GraphForecaster, callback: UpdateRollout) -> None: +def test_on_validation_epoch_sanity( + fake_trainer: AnemoiTrainer, + fake_forecaster: GraphForecaster, + callback: UpdateRollout, +) -> None: fake_trainer.current_epoch = 10 fake_trainer.sanity_checking = True spy = fake_trainer.datamodule.update_rollout - + callback.on_validation_epoch_end(fake_trainer, fake_forecaster, None) spy.assert_not_called() -def test_on_validation_epoch(fake_trainer: AnemoiTrainer, fake_forecaster: GraphForecaster, callback: UpdateRollout) -> None: + +def test_on_validation_epoch( + fake_trainer: AnemoiTrainer, + fake_forecaster: GraphForecaster, + callback: UpdateRollout, +) -> None: fake_trainer.current_epoch = 10 spy = fake_trainer.datamodule.update_rollout fake_trainer.sanity_checking = False callback.on_validation_epoch_end(fake_trainer, fake_forecaster, None) - spy.assert_called_once_with(rollout = 11) #Offset 1 \ No newline at end of file + spy.assert_called_once_with(rollout=11) # Offset 1 diff --git a/training/tests/schedulers/rollout/__init__.py b/training/tests/schedulers/rollout/__init__.py index 1ef0eb8d..c167afa2 100644 --- 
a/training/tests/schedulers/rollout/__init__.py +++ b/training/tests/schedulers/rollout/__init__.py @@ -6,6 +6,3 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. - - - diff --git a/training/tests/schedulers/rollout/test_indexed.py b/training/tests/schedulers/rollout/test_indexed.py index 669aefc8..b7ce9530 100644 --- a/training/tests/schedulers/rollout/test_indexed.py +++ b/training/tests/schedulers/rollout/test_indexed.py @@ -10,43 +10,45 @@ import pytest -from anemoi.training.schedulers.rollout.indexed import PositionalIndexed, Lookup +from anemoi.training.schedulers.rollout.indexed import Lookup +from anemoi.training.schedulers.rollout.indexed import PositionalIndexed + @pytest.mark.parametrize( - "rollouts, num_times_per_element, test_epoch, expected", + ("rollouts", "num_times_per_element", "test_epoch", "expected"), [ - ([1, 2, 3], 1, 0, 1), - ([1, 2, 3], 1, 1, 2), - ([1, 2, 3], 2, 0, 1), - ([1, 2, 3], 2, 2, 2), - ([4, 5, 6], 1, 0, 4), + ([1, 2, 3], 1, 0, 1), + ([1, 2, 3], 1, 1, 2), + ([1, 2, 3], 2, 0, 1), + ([1, 2, 3], 2, 2, 2), ([4, 5, 6], 1, 0, 4), + ([4, 5, 6], 1, 1, 5), ([4, 5, 6], 2, 0, 4), ([4, 5, 6], 2, 2, 5), - ] + ], ) -def test_positional(rollouts: list[int], num_times_per_element: int, test_epoch: int, expected: int): +def test_positional(rollouts: list[int], num_times_per_element: int, test_epoch: int, expected: int) -> None: sched = PositionalIndexed(rollouts, num_times_per_element) assert sched.rollout == rollouts[0] assert sched.maximum_rollout == max(rollouts) - assert sched.rollout_at(epoch = test_epoch) == expected + assert sched.rollout_at(epoch=test_epoch) == expected + @pytest.mark.parametrize( - "rollouts, test_epoch, expected", + ("rollouts", "test_epoch", "expected"), [ - ({0:1, 1:2, 2:3}, 0, 1), - ({0:1, 1:2, 2:3}, 1, 2), - ({1:1, 5:2}, 1, 1), - ({1:1, 5:2}, 4, 1), - ({1:1, 5:2}, 5, 2), - ({1:1, 5:2}, 10, 2), - ({5:2}, 1, 1), - - ] + ({0: 1, 1: 2, 2: 3}, 0, 1), + ({0: 1, 1: 2, 2: 3}, 1, 2), + ({1: 1, 5: 2}, 1, 1), + ({1: 1, 5: 2}, 4, 1), + ({1: 1, 5: 2}, 5, 2), + ({1: 1, 5: 2}, 10, 2), + ({5: 2}, 1, 1), + ], ) -def test_lookup(rollouts: dict[int, int], test_epoch: int, expected: int): +def test_lookup(rollouts: dict[int, int], test_epoch: int, expected: int) -> None: sched = Lookup(rollouts) assert sched.maximum_rollout == max(rollouts.values()) - assert sched.rollout_at(epoch = test_epoch) == expected + assert sched.rollout_at(epoch=test_epoch) == expected diff --git a/training/tests/schedulers/rollout/test_random.py b/training/tests/schedulers/rollout/test_random.py index 7d5ff6e7..cc6d436f 100644 --- a/training/tests/schedulers/rollout/test_random.py +++ b/training/tests/schedulers/rollout/test_random.py @@ -8,25 +8,28 @@ # nor does it submit to any jurisdiction. 
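
# A note on determinism: the randomised schedulers draw from a generator
# seeded off the seed captured at construction together with the current
# epoch and step (see the `rng` property added in randomise.py), so two
# schedulers built under the same seed and synced to the same epoch must
# pick identical rollout values. `test_determinism` below pins this down.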
 from typing import Any
+from unittest.mock import patch
 
 import pytest
-from unittest.mock import patch
 
-from anemoi.training.schedulers.rollout.randomise import RandomList, RandomRange, IncreasingRandom, BaseRandom
+from anemoi.training.schedulers.rollout.randomise import IncreasingRandom
+from anemoi.training.schedulers.rollout.randomise import RandomList
+from anemoi.training.schedulers.rollout.randomise import RandomRange
 
 
-def test_determism():
+def test_determinism() -> None:
     sched = RandomList([1, 2, 3])
     sched_1 = RandomList([1, 2, 3])
 
-    sched.rollout # Force a retrieval to try and break the determinism
+    sched.rollout  # Force a retrieval to try and break the determinism
 
     for i in range(100):
-        sched.sync(epoch = i)
-        sched_1.sync(epoch = i)
+        sched.sync(epoch=i)
+        sched_1.sync(epoch=i)
         assert sched.rollout == sched_1.rollout
 
+
 @pytest.mark.parametrize(
     "rollouts",
     [
@@ -34,26 +37,27 @@
         [1, 2, 3, 4],
         [1, 2, 3, 4, 5],
         [16, 2, 3, 4, 5],
-    ]
+    ],
 )
-@patch("anemoi.training.schedulers.rollout.randomise.BaseRandom._randomly_pick", wraps = RandomList([0])._randomly_pick)
-def test_random_list(pick_mock: Any, rollouts: list[int]):
+@patch("anemoi.training.schedulers.rollout.randomise.BaseRandom._randomly_pick", wraps=RandomList([0])._randomly_pick)
+def test_random_list(pick_mock: Any, rollouts: list[int]) -> None:
     sched = RandomList(rollouts)
     assert sched.rollout in rollouts
     assert sched.maximum_rollout == max(rollouts)
     pick_mock.assert_called_once_with(rollouts)
 
+
 @pytest.mark.parametrize(
-    "minimum, maximum, step",
+    ("minimum", "maximum", "step"),
     [
         (1, 10, 1),
         (1, 10, 2),
         (1, 10, 3),
-    ]
+    ],
 )
-@patch("anemoi.training.schedulers.rollout.randomise.BaseRandom._randomly_pick", wraps = RandomList([0])._randomly_pick)
-def test_random_range(pick_mock: Any, minimum: int, maximum: int, step: int):
+@patch("anemoi.training.schedulers.rollout.randomise.BaseRandom._randomly_pick", wraps=RandomList([0])._randomly_pick)
+def test_random_range(pick_mock: Any, minimum: int, maximum: int, step: int) -> None:
     sched = RandomRange(minimum, maximum, step)
     assert sched.rollout in range(minimum, maximum + 1, step)
     assert sched.maximum_rollout == max(range(minimum, maximum + 1, step))
@@ -62,7 +66,7 @@ def test_random_range(pick_mock: Any, minimum: int, maximum: int, step: int):
 
 @pytest.mark.parametrize(
-    "minimum, maximum, step, every_n, epoch_test, expected_max",
+    ("minimum", "maximum", "step", "every_n", "epoch_test", "expected_max"),
     [
         (1, 10, 1, 1, 0, 1),
         (1, 10, 1, 1, 1, 2),
@@ -72,67 +76,79 @@ def test_random_range(pick_mock: Any, minimum: int, maximum: int, step: int):
         (1, 10, 1, 2, 2, 2),
         (1, 10, 1, 2, 4, 3),
         (1, 10, 2, 2, 4, 3),
-
-    ]
+    ],
 )
-@patch("anemoi.training.schedulers.rollout.randomise.BaseRandom._randomly_pick", wraps = RandomList([0])._randomly_pick)
-def test_increasing_random_increment(pick_mock: Any, minimum: int, maximum: int, step: int, every_n: int, epoch_test: int, expected_max: int):
+@patch("anemoi.training.schedulers.rollout.randomise.BaseRandom._randomly_pick", wraps=RandomList([0])._randomly_pick)
+def test_increasing_random_increment(
+    pick_mock: Any,
+    minimum: int,
+    maximum: int,
+    step: int,
+    every_n: int,
+    epoch_test: int,
+    expected_max: int,
+) -> None:
     sched = IncreasingRandom(minimum, maximum, step, every_n, 1)
 
-    sched.sync(epoch = epoch_test)
+    sched.sync(epoch=epoch_test)
 
     assert sched.current_maximum == expected_max
     assert sched.rollout in list(range(minimum, expected_max + 1, step))
     assert sched.maximum_rollout == maximum
-
-
pick_mock.assert_called_once_with(range(minimum, expected_max + 1, step)) - -@pytest.mark.parametrize( - "minimum, maximum, step, every_n, increment, epoch_test, expected_max", - [ - (1, 10, 1, 1, {0:0, 2:1, 4:2,}, 0, 1), - (1, 10, 1, 1, {0:0, 2:1, 4:2,}, 1, 1), - (1, 10, 1, 1, {0:0, 2:1, 4:2,}, 2, 2), - (1, 10, 1, 1, {0:0, 2:1, 4:2,}, 3, 3), - (1, 10, 1, 1, {0:0, 2:1, 4:2,}, 4, 5), - (1, 10, 1, 1, {0:0, 2:1, 4:2,}, 5, 7), - (1, 10, 1, 1, {0:0, 2:1, 3:0, 4:2,}, 4, 4), - (1, 10, 1, 1, {0:0, 2:1, 3:0, 4:2,}, 5, 6), - (1, 10, 1, 1, {0:0, 2:1, 3:0, 4:2,}, 1000, 10), - (1, 10, 2, 1, {0:0, 2:1, 3:0, 4:2,}, 1000, 10), - ] -) -@patch("anemoi.training.schedulers.rollout.randomise.BaseRandom._randomly_pick", wraps = RandomList([0])._randomly_pick) -def test_increasing_random_complex_increment(pick_mock: Any, minimum: int, maximum: int, step: int, every_n: int, increment: dict[int, int], epoch_test: int, expected_max: int): + pick_mock.assert_called_once_with(range(minimum, expected_max + 1, step)) - sched = IncreasingRandom(minimum, maximum, step, every_n, increment=increment) - sched.sync(epoch = epoch_test) - assert sched.rollout in list(range(minimum, expected_max + 1, step)) - assert sched.current_maximum == expected_max - pick_mock.assert_called_with(range(minimum, expected_max + 1, step)) +INCREMENT_DICT = { + 0: 0, + 2: 1, + 4: 2, +} +INCREMENT_DICT_1 = { + 0: 0, + 2: 1, + 3: 0, + 4: 2, +} + +COMPLEX_INCREMENT_TESTS_EVERY_N_1 = [ + (1, INCREMENT_DICT, 0, 1), + (1, INCREMENT_DICT, 1, 1), + (1, INCREMENT_DICT, 2, 2), + (1, INCREMENT_DICT, 3, 3), + (1, INCREMENT_DICT, 4, 5), + (1, INCREMENT_DICT, 5, 7), + (1, INCREMENT_DICT_1, 4, 4), + (1, INCREMENT_DICT_1, 5, 6), + (1, INCREMENT_DICT_1, 1000, 10), +] + +COMPLEX_INCREMENT_TESTS_EVERY_N_2 = [ + (2, INCREMENT_DICT, 0, 1), + (2, INCREMENT_DICT, 1, 1), + (2, INCREMENT_DICT, 2, 2), + (2, INCREMENT_DICT, 3, 2), + (2, INCREMENT_DICT, 4, 4), + (2, INCREMENT_DICT, 5, 4), + (2, INCREMENT_DICT, 6, 6), +] @pytest.mark.parametrize( - "minimum, maximum, step, every_n, increment, epoch_test, expected_max", - [ - (1, 10, 1, 2, {0:0, 2:1, 4:2,}, 0, 1), - (1, 10, 1, 2, {0:0, 2:1, 4:2,}, 1, 1), - (1, 10, 1, 2, {0:0, 2:1, 4:2,}, 2, 2), - (1, 10, 1, 2, {0:0, 2:1, 4:2,}, 3, 2), - (1, 10, 1, 2, {0:0, 2:1, 4:2,}, 4, 4), - (1, 10, 1, 2, {0:0, 2:1, 4:2,}, 5, 4), - (1, 10, 1, 2, {0:0, 2:1, 4:2,}, 6, 6), - ] + ("every_n", "increment", "epoch_test", "expected_max"), + [*COMPLEX_INCREMENT_TESTS_EVERY_N_1, *COMPLEX_INCREMENT_TESTS_EVERY_N_2], ) -@patch("anemoi.training.schedulers.rollout.randomise.BaseRandom._randomly_pick", wraps = RandomList([0])._randomly_pick) -def test_increasing_random_complex_increment_every_not_1(pick_mock: Any, minimum: int, maximum: int, step: int, every_n: int, increment: dict[int, int], epoch_test: int, expected_max: int): - - sched = IncreasingRandom(minimum, maximum, step, every_n, increment=increment) - - sched.sync(epoch = epoch_test) - assert sched.rollout in list(range(minimum, expected_max + 1, step)) +@patch("anemoi.training.schedulers.rollout.randomise.BaseRandom._randomly_pick", wraps=RandomList([0])._randomly_pick) +def test_increasing_random_complex_increment( + pick_mock: Any, + every_n: int, + increment: dict[int, int], + epoch_test: int, + expected_max: int, +) -> None: + sched = IncreasingRandom(1, 10, 1, every_n, increment=increment) + + sched.sync(epoch=epoch_test) + assert sched.rollout in list(range(1, expected_max + 1, 1)) assert sched.current_maximum == expected_max - pick_mock.assert_called_with(range(minimum, expected_max + 
1, step))
diff --git a/training/tests/schedulers/rollout/test_rollout.py b/training/tests/schedulers/rollout/test_rollout.py
index a50ff4c0..739675b5 100644
--- a/training/tests/schedulers/rollout/test_rollout.py
+++ b/training/tests/schedulers/rollout/test_rollout.py
@@ -7,52 +7,55 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
-from anemoi.training.schedulers.rollout import RolloutScheduler, Static
+from anemoi.training.schedulers.rollout import RolloutScheduler
+from anemoi.training.schedulers.rollout import Static
+
 
 class DebugScheduler(RolloutScheduler):
     @property
-    def rollout(self):
+    def rollout(self) -> int:
         return self._epoch
-    
+
     @property
-    def maximum_rollout(self):
+    def maximum_rollout(self) -> int:
         return self._epoch
-    
-    def description(self):
+
+    def description(self) -> str:
         return "DebugScheduler"
 
-def test_static():
+
+def test_static() -> None:
     sched = Static(1)
     assert sched.rollout == 1
     assert sched.maximum_rollout == 1
     assert sched.current_maximum == 1
 
 
-def test_at():
+def test_at() -> None:
     sched = DebugScheduler()
 
-    with sched.at(epoch = 1):
+    with sched.at(epoch=1):
         assert sched.rollout == 1
         assert sched.maximum_rollout == 1
         assert sched.current_maximum == 1
-    
+
     assert sched.rollout == 0
 
-def test_sync():
+
+def test_sync() -> None:
     sched = DebugScheduler()
-    sched.sync(epoch = 10)
+    sched.sync(epoch=10)
     assert sched.rollout == 10
 
 
-def test_count():
+def test_count() -> None:
     sched = DebugScheduler()
-    sched.sync(epoch = 10)
+    sched.sync(epoch=10)
     assert sched.count(n_epochs=5) == 2
     assert sched.count(n_epochs=3) == 3
 
-def test_int_conversion():
+
+def test_int_conversion() -> None:
     sched = DebugScheduler()
-    sched.sync(epoch = 10)
+    sched.sync(epoch=10)
     assert int(sched) == 10
-    
-
diff --git a/training/tests/schedulers/rollout/test_stepped.py b/training/tests/schedulers/rollout/test_stepped.py
index 477e9814..8202b918 100644
--- a/training/tests/schedulers/rollout/test_stepped.py
+++ b/training/tests/schedulers/rollout/test_stepped.py
@@ -11,8 +11,9 @@
 
 from anemoi.training.schedulers.rollout.stepped import Stepped
 
+
 @pytest.mark.parametrize(
-    "minimum, maximum, every_n, increment, epoch_test, expected_value",
+    ("minimum", "maximum", "every_n", "increment", "epoch_test", "expected_value"),
     [
         # Increment of 1 and every_n of 1
         (1, 10, 1, 1, 0, 1),
@@ -24,47 +25,80 @@
         (1, 10, 1, 1, 10, 10),
         (1, 10, 1, 1, 11, 10),
         (1, 10, 1, 1, 1000, 10),
-
         # Increment of 2 and every_n of 1
         (1, 10, 1, 2, 1, 3),
         (1, 10, 1, 2, 2, 5),
         (1, 10, 1, 2, 4, 9),
         (1, 10, 1, 2, 5, 10),
-
         # Increment of 1 and every_n of 2
         (1, 10, 2, 1, 0, 1),
         (1, 10, 2, 1, 1, 1),
         (1, 10, 2, 1, 2, 2),
-
-    ]
+    ],
 )
-def test_stepped(minimum: int, maximum: int, every_n: int, increment: int, epoch_test: int, expected_value: int):
+def test_stepped(
+    minimum: int,
+    maximum: int,
+    every_n: int,
+    increment: int,
+    epoch_test: int,
+    expected_value: int,
+) -> None:
     sched = Stepped(minimum, maximum, every_n, increment=increment)
 
-    sched.sync(epoch = epoch_test)
+    sched.sync(epoch=epoch_test)
     assert sched.rollout == expected_value
     assert sched.current_maximum == expected_value
 
+
+INCREMENT_DICT = {
+    0: 0,
+    2: 1,
+    4: 2,
+}
+INCREMENT_DICT_1 = {
+    0: 0,
+    2: 1,
+    3: 0,
+    4: 2,
+}
+
+COMPLEX_INCREMENT_TESTS_EVERY_N_1 = [
+    (1, INCREMENT_DICT, 0, 1),
+    (1, INCREMENT_DICT, 1, 1),
+    (1, INCREMENT_DICT, 2, 2),
+    (1, INCREMENT_DICT, 3, 3),
+    (1, INCREMENT_DICT, 4, 5),
+    (1, INCREMENT_DICT, 5, 7),
+    (1, INCREMENT_DICT_1, 4, 4),
+    (1, INCREMENT_DICT_1, 5, 6),
+    (1, INCREMENT_DICT_1, 1000, 10),
+]
+
+COMPLEX_INCREMENT_TESTS_EVERY_N_2 = [
+    (2, INCREMENT_DICT, 0, 1),
+    (2, INCREMENT_DICT, 1, 1),
+    (2, INCREMENT_DICT, 2, 2),
+    (2, INCREMENT_DICT, 3, 2),
+    (2, INCREMENT_DICT, 4, 4),
+    (2, INCREMENT_DICT, 5, 4),
+    (2, INCREMENT_DICT, 6, 6),
+]
 
 
 @pytest.mark.parametrize(
-    "minimum, maximum, every_n, increment, epoch_test, expected_value",
-    [
-        (1, 10, 1, {0:0, 2:1, 4:2,}, 0, 1),
-        (1, 10, 1, {0:0, 2:1, 4:2,}, 1, 1),
-        (1, 10, 1, {0:0, 2:1, 4:2,}, 2, 2),
-        (1, 10, 1, {0:0, 2:1, 4:2,}, 3, 3),
-        (1, 10, 1, {0:0, 2:1, 4:2,}, 4, 5),
-        (1, 10, 1, {0:0, 2:1, 4:2,}, 5, 7),
-        (1, 10, 1, {0:0, 2:1, 3:0, 4:2,}, 4, 4),
-        (1, 10, 1, {0:0, 2:1, 3:0, 4:2,}, 5, 6),
-        (1, 10, 1, {0:0, 2:1, 3:0, 4:2,}, 1000, 10),
-    ]
+    ("every_n", "increment", "epoch_test", "expected_value"),
+    [*COMPLEX_INCREMENT_TESTS_EVERY_N_1, *COMPLEX_INCREMENT_TESTS_EVERY_N_2],
 )
-def test_stepped_complex_increment(minimum: int, maximum: int, every_n: int, increment: dict[int, int], epoch_test: int, expected_value: int):
+def test_stepped_complex_increment(
+    every_n: int,
+    increment: dict[int, int],
+    epoch_test: int,
+    expected_value: int,
+) -> None:
 
-    sched = Stepped(minimum, maximum, every_n, increment=increment)
+    sched = Stepped(1, 10, every_n, increment=increment)
 
-    sched.sync(epoch = epoch_test)
+    sched.sync(epoch=epoch_test)
     assert sched.rollout == expected_value
     assert sched.current_maximum == expected_value
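
Taken together, these tables fix the contract of `Stepped`: the rollout starts at `minimum`, grows by `increment` every `every_n` epochs (dict increments are resolved through `get_closest_key`), and saturates at `maximum`. A short usage sketch consistent with the cases above (assuming only that the package is importable):

```python
from anemoi.training.schedulers.rollout.stepped import Stepped

# Grow the rollout by 2 every epoch, from 1 up to a cap of 10.
sched = Stepped(1, 10, 1, increment=2)

sched.sync(epoch=2)
assert sched.rollout == 5         # 1 + 2 epochs * increment 2
assert sched.current_maximum == 5

sched.sync(epoch=5)
assert sched.rollout == 10        # 1 + 5 * 2 = 11, saturated at maximum
assert int(sched) == 10           # __int__ mirrors .rollout

# A dict increment changes the growth rate at epoch thresholds.
sched = Stepped(1, 10, 1, increment={0: 0, 2: 1, 4: 2})
sched.sync(epoch=5)
assert sched.rollout == 7         # +0 +0 +1 +1 +2 +2 over epochs 0..5
```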