
feat(training,rollout)!: Rollout Schedulers #46

Draft
wants to merge 14 commits into base: develop
2 changes: 2 additions & 0 deletions training/CHANGELOG.md
@@ -35,6 +35,8 @@ Keep it human-readable, your future self will thank you!
- Add supporting arrays (numpy) to checkpoint
- Support for masking out unconnected nodes in LAM [#171](https://github.com/ecmwf/anemoi-training/pull/171)
- Improved validation metrics, allow 'all' to be scaled [#202](https://github.com/ecmwf/anemoi-training/pull/202)
- Rollout Schedulers [#206](https://github.com/ecmwf/anemoi-training/pull/206)


### Changed

15 changes: 11 additions & 4 deletions training/src/anemoi/training/config/training/default.yaml
@@ -85,12 +85,19 @@ scale_validation_metrics:


 # length of the "rollout" window (see Keisler's paper)
+# Dataloader rollout counter can only be updated at the end of each epoch
+# So updates during an epoch will only be reflected at the end of said epoch.
 rollout:
-  start: 1
+  _target_: anemoi.training.schedulers.rollout.stepped.EpochStepped
+  minimum: 1
+  maximum: 12
   # increase rollout every n epochs
-  epoch_increment: 0
-  # maximum rollout to use
-  max: 1
+  every_n_epochs: 1
+  # Control the incrementing of the rollout window
+  increment:
+    step:
+      0: 0
+      200000: 1 # After 200k steps, increment by 1 every 1 epoch
Contributor @anaprietonem commented on Jan 10, 2025:
I am probably just being slow, but how does this interact with limit_batches? What would be the difference between doing the above and the 'old configuration' with a limit_batches of 200000?

Member Author replied:
limit_batches ends the training, whereas this will continue on and then begin updating the rollout.


# Set max_epochs or max_steps. Training stops at the first limit reached.
max_epochs: null
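To make the new schema concrete, below is a minimal sketch of how an epoch-stepped rollout scheduler could combine `minimum`, `maximum`, `every_n_epochs` and the step-keyed `increment` map into a `current_maximum`. It only illustrates the semantics implied by the config above; it is not the actual `anemoi.training.schedulers.rollout.stepped.EpochStepped` implementation, whose internals may differ.

```python
# Hypothetical sketch of the configured semantics; not the anemoi-training code.
class EpochSteppedSketch:
    def __init__(self, minimum: int, maximum: int, every_n_epochs: int, increment: dict[int, int]) -> None:
        self.minimum = minimum              # rollout to start from
        self.maximum = maximum              # hard cap on the rollout
        self.every_n_epochs = every_n_epochs
        self.increment = increment          # {global_step_threshold: increment}, e.g. {0: 0, 200000: 1}
        self._epoch = 0
        self._step = 0
        self._rollout = minimum

    def _increment_for(self, step: int) -> int:
        # Use the increment attached to the largest threshold not exceeding `step`.
        thresholds = [t for t in sorted(self.increment) if t <= step]
        return self.increment[thresholds[-1]] if thresholds else 0

    @property
    def current_maximum(self) -> int:
        return self._rollout

    def step_epoch(self, global_step: int) -> None:
        # Called once per finished epoch with the current optimiser step.
        self._epoch += 1
        self._step = global_step
        if self._epoch % self.every_n_epochs == 0:
            self._rollout = min(self._rollout + self._increment_for(self._step), self.maximum)
```

With the values shown above, such a scheduler keeps the rollout at 1 until roughly 200k optimiser steps have been taken, then grows it by one per epoch until it reaches 12.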
26 changes: 19 additions & 7 deletions training/src/anemoi/training/data/datamodule.py
@@ -32,6 +32,7 @@
from torch_geometric.data import HeteroData

from anemoi.training.data.grid_indices import BaseGridIndices
from anemoi.training.schedulers.rollout import RolloutScheduler


class AnemoiDatasetsDataModule(pl.LightningDataModule):
@@ -52,11 +53,8 @@ def __init__(self, config: DictConfig, graph_data: HeteroData) -> None:
self.graph_data = graph_data

         # Set the maximum rollout to be expected
-        self.rollout = (
-            self.config.training.rollout.max
-            if self.config.training.rollout.epoch_increment > 0
-            else self.config.training.rollout.start
-        )
+        rollout_scheduler: RolloutScheduler = instantiate(self.config.training.rollout)
+        self.starting_rollout = rollout_scheduler.current_maximum

# Set the training end date if not specified
if self.config.dataloader.training.end is None:
@@ -129,7 +127,7 @@ def ds_train(self) -> NativeGridDataset:

@cached_property
def ds_valid(self) -> NativeGridDataset:
-        r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))
+        r = max(self.starting_rollout, self.config.dataloader.get("validation_rollout", 1))

if not self.config.dataloader.training.end < self.config.dataloader.validation.start:
LOGGER.warning(
@@ -160,6 +158,20 @@ def ds_test(self) -> NativeGridDataset:
label="test",
)

def update_rollout(self, rollout: int) -> None:
"""
Update the rollout values in the datamodule.

Parameters
----------
rollout : int
Rollout value
"""
for ds in [self.ds_train, self.ds_test]:
ds.update_rollout(rollout)

self.ds_valid.update_rollout(max(rollout, self.config.dataloader.get("validation_rollout", 1)))

def _get_dataset(
self,
data_reader: Callable,
@@ -168,7 +180,7 @@ def _get_dataset(
label: str = "generic",
) -> NativeGridDataset:

-        r = max(rollout, self.rollout)
+        r = max(rollout, self.starting_rollout)

# Compute effective batch size
effective_bs = (
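Because the rollout config now carries a `_target_`, the datamodule can build the scheduler with Hydra's `instantiate`, as done in `__init__` above. A standalone sketch of that call follows; the class path and the `current_maximum` property come from this PR, but the exact constructor arguments are assumptions taken from the config:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Mirror of the training.rollout section of default.yaml (values assumed).
cfg = OmegaConf.create(
    {
        "_target_": "anemoi.training.schedulers.rollout.stepped.EpochStepped",
        "minimum": 1,
        "maximum": 12,
        "every_n_epochs": 1,
        "increment": {"step": {0: 0, 200000: 1}},
    },
)

# instantiate() imports the _target_ class and calls it with the remaining keys.
scheduler = instantiate(cfg)
starting_rollout = scheduler.current_maximum  # used to size the datasets up front
```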
16 changes: 12 additions & 4 deletions training/src/anemoi/training/data/dataset.py
@@ -138,6 +138,17 @@ def valid_date_indices(self) -> np.ndarray:
"""
return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement)

def update_rollout(self, rollout: int) -> None:
"""Update the rollout window."""
if self.rollout == rollout:
return

self.rollout = rollout
LOGGER.debug("Updating rollout of %s dataset to %d", self.label, self.rollout)

if hasattr(self, "valid_date_indices"):
del self.valid_date_indices

def set_comm_group_info(
self,
global_rank: int,
@@ -228,10 +239,7 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:
sanity_rnd = self.rng.random(1)

         LOGGER.debug(
-            (
-                "Worker %d (%s, pid %d, glob. rank %d, model comm group %d, "
-                "group_rank %d, base_seed %d), sanity rnd %f"
-            ),
+            ("Worker %d (%s, pid %d, glob. rank %d, model comm group %d, group_rank %d, base_seed %d), sanity rnd %f"),
worker_id,
self.label,
os.getpid(),
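`valid_date_indices` is cached on the dataset, which is why the new `update_rollout` above deletes it rather than recomputing in place: dropping the cached value forces the index list to be rebuilt for the new rollout on the next access. A small self-contained sketch of that invalidation pattern, using made-up data rather than the anemoi dataset:

```python
from functools import cached_property


class WindowedDataset:
    def __init__(self, n_samples: int, rollout: int) -> None:
        self.n_samples = n_samples
        self.rollout = rollout

    @cached_property
    def valid_date_indices(self) -> list[int]:
        # Expensive in the real dataset, so cache it; fewer start indices fit a longer rollout.
        return list(range(self.n_samples - self.rollout))

    def update_rollout(self, rollout: int) -> None:
        if self.rollout == rollout:
            return
        self.rollout = rollout
        # Dropping the cached value forces recomputation on next access.
        if hasattr(self, "valid_date_indices"):
            del self.valid_date_indices


ds = WindowedDataset(n_samples=10, rollout=1)
print(len(ds.valid_date_indices))  # 9
ds.update_rollout(4)
print(len(ds.valid_date_indices))  # 6, recomputed with the new rollout
```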
@@ -23,6 +23,7 @@
from anemoi.training.diagnostics.callbacks.optimiser import LearningRateMonitor
from anemoi.training.diagnostics.callbacks.optimiser import StochasticWeightAveraging
from anemoi.training.diagnostics.callbacks.provenance import ParentUUIDCallback
from anemoi.training.diagnostics.callbacks.rollout import UpdateRollout
from anemoi.training.diagnostics.callbacks.sanity import CheckVariableOrder

if TYPE_CHECKING:
Expand Down Expand Up @@ -198,10 +199,12 @@ def get_callbacks(config: DictConfig) -> list[Callback]:

# Parent UUID callback
# Check variable order callback
# UpdateRollout
trainer_callbacks.extend(
(
ParentUUIDCallback(config),
CheckVariableOrder(),
UpdateRollout(),
),
)

1 change: 0 additions & 1 deletion training/src/anemoi/training/diagnostics/callbacks/plot.py
@@ -1164,7 +1164,6 @@ def _plot(
data, output_tensor = self.process(pl_module, outputs, batch)

for rollout_step in range(pl_module.rollout):

# Build dictionary of indices and parameters to be plotted
diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic

70 changes: 70 additions & 0 deletions training/src/anemoi/training/diagnostics/callbacks/rollout.py
@@ -0,0 +1,70 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from __future__ import annotations

import logging

import pytorch_lightning as pl

LOGGER = logging.getLogger(__name__)


class UpdateRollout(pl.callbacks.Callback):
"""Update Rollout values in datamodule."""

def __init__(self) -> None:
super().__init__()

def _update_rollout(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
epoch: int | None = None,
step: int | None = None,
) -> None:
rollsched = pl_module.rollout
with rollsched.at(epoch=epoch, step=step):
rollout = rollsched.current_maximum

LOGGER.debug("Propagating rollout value %s to datamodule", rollout)
trainer.datamodule.update_rollout(rollout=rollout)

def on_load_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: dict) -> None:
"""
Update the rollout values in the datamodule when loading a checkpoint.

Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
pl_module : pl.LightningModule
Model
checkpoint : dict
Checkpoint dictionary
"""
self._update_rollout(trainer, pl_module, epoch=checkpoint["epoch"], step=checkpoint["global_step"])

def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *_) -> None:
Contributor commented:
If someone sets the limit_batches for validation to 0 to skip validation, this hook wouldn't be triggered?

Member Author replied:
Good point, I'll need to take a look.

"""
Update the rollout values in the datamodule every validation epoch.

Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
pl_module : pl.LightningModule
Model
"""
if trainer.sanity_checking:
return

# Offset of 1 needed as the epoch counter does not increment
# until after the epoch ends.
self._update_rollout(trainer, pl_module, epoch=trainer.current_epoch + 1)
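The callback relies on the scheduler's `at()` context manager to evaluate the schedule at an arbitrary epoch/step (for example the one stored in a checkpoint) without disturbing its live counters. Below is a rough sketch of that pattern under the assumption that the scheduler tracks `_epoch` and `_step` internally; the real `RolloutScheduler` API may differ:

```python
from __future__ import annotations

from contextlib import contextmanager


class SchedulerStateSketch:
    """Toy stand-in for a rollout scheduler with an `at()` context manager."""

    def __init__(self) -> None:
        self._epoch = 0
        self._step = 0

    @property
    def current_maximum(self) -> int:
        # Placeholder schedule: one extra rollout step every 10 epochs.
        return 1 + self._epoch // 10

    @contextmanager
    def at(self, epoch: int | None = None, step: int | None = None):
        # Temporarily override the counters, then restore them on exit.
        saved = (self._epoch, self._step)
        if epoch is not None:
            self._epoch = epoch
        if step is not None:
            self._step = step
        try:
            yield self
        finally:
            self._epoch, self._step = saved


sched = SchedulerStateSketch()
with sched.at(epoch=25):
    print(sched.current_maximum)  # 3, as if we were at epoch 25
print(sched.current_maximum)      # back to 1, live state untouched
```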
6 changes: 3 additions & 3 deletions training/src/anemoi/training/diagnostics/mlflow/logger.py
@@ -196,12 +196,12 @@ def _log_collector(self) -> None:
log_capture_time_counter = 0

     def _store_buffered_logs(self) -> None:
-        _buffer_size = self._io_buffer.tell()
-        if not _buffer_size:
+        buffer_size = self._io_buffer.tell()
+        if not buffer_size:
             return
         self._io_buffer.seek(0)
         # read and reset the buffer
-        data = self._io_buffer.read(_buffer_size)
+        data = self._io_buffer.read(buffer_size)
self._io_buffer.seek(0)
# handle the buffered data and store
# split lines and keep \n at the end of each line
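For reference, the tell/seek/read dance above is the usual way to drain an in-memory buffer without recreating it. A small standalone illustration with `io.StringIO`, using hypothetical log lines rather than the mlflow logger itself:

```python
import io

buffer = io.StringIO()
buffer.write("epoch 1 done\n")
buffer.write("epoch 2 done\n")

size = buffer.tell()          # how much has been written so far
if size:
    buffer.seek(0)            # rewind to the start
    data = buffer.read(size)  # read exactly what was buffered
    buffer.seek(0)            # rewind again so new writes overwrite the old content
    print(data.splitlines(keepends=True))  # ['epoch 1 done\n', 'epoch 2 done\n']
```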