This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Feature: Temporal interpolation #168

Draft · wants to merge 6 commits into base: develop
54 changes: 40 additions & 14 deletions src/anemoi/training/data/datamodule.py
@@ -12,6 +12,7 @@
import os
from functools import cached_property
from typing import Callable
import numpy as np

import pytorch_lightning as pl
from anemoi.datasets.data import open_dataset
@@ -68,13 +69,6 @@ def __init__(self, config: DictConfig) -> None:
self.model_comm_group_rank,
)

# Set the maximum rollout to be expected
self.rollout = (
self.config.training.rollout.max
if self.config.training.rollout.epoch_increment > 0
else self.config.training.rollout.start
)

# Set the training end date if not specified
if self.config.dataloader.training.end is None:
LOGGER.info(
@@ -102,6 +96,40 @@ def metadata(self) -> dict:
@cached_property
def data_indices(self) -> IndexCollection:
return IndexCollection(self.config, self.ds_train.name_to_index)

@cached_property
def relative_date_indices(self) -> list:
"""Determine a list of relative time indices to load for each batch"""
if hasattr(self.config.training, "explicit_times"):
return sorted(set(self.config.training.explicit_times.input + self.config.training.explicit_times.target))

else:  # use the old default of multistep, timeincrement and rollout
# Use the maximum rollout to be expected
rollout = (
self.config.training.rollout.max
if self.config.training.rollout.epoch_increment > 0
else self.config.training.rollout.start
)  # NOTE: for gradual rollout, the dates for the maximum rollout are always fetched, but this was already the case in datamodule.py

multi_step = self.config.training.multistep_input
return [self.timeincrement * mstep for mstep in range(multi_step + rollout)]

def add_model_run_ids(self, data_reader):
"""Determine the model run id of each time index of the data and add to a data_reader object
NOTE/TODO: This is only relevant when training on non-analysis and should be replaced with
a property of the dataset stored in data_reader.
Until then, assumes regular interval of changed model runs
"""
if not hasattr(self.config.dataloader, "model_run_info"):
data_reader.model_run_ids = None
return data_reader

mr_start = np.datetime64(self.config.dataloader.model_run_info.start)
mr_len = self.config.dataloader.model_run_info.length # model run length in number of date indices
assert max(self.relative_date_indices) <= mr_len, f"Requested data length {max(self.relative_date_indices)} longer than model run length {mr_len}"

data_reader.model_run_ids = (data_reader.dates - mr_start)//np.timedelta64(mr_len*frequency_to_seconds(self.config.data.frequency), 's')
return data_reader

@cached_property
def timeincrement(self) -> int:
@@ -140,7 +168,8 @@ def ds_train(self) -> NativeGridDataset:

@cached_property
def ds_valid(self) -> NativeGridDataset:
r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))
#r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))
# NOTE: temporarily left unimplemented until I figure out how best to do this with the new time_indices object

assert self.config.dataloader.training.end < self.config.dataloader.validation.start, (
f"Training end date {self.config.dataloader.training.end} is not before"
@@ -149,7 +178,7 @@ def ds_valid(self) -> NativeGridDataset:
return self._get_dataset(
open_dataset(OmegaConf.to_container(self.config.dataloader.validation, resolve=True)),
shuffle=False,
rollout=r,
# rollout=r,  # NOTE: see the note above
label="validation",
)

@@ -173,15 +202,12 @@ def _get_dataset(
self,
data_reader: Callable,
shuffle: bool = True,
rollout: int = 1,
label: str = "generic",
) -> NativeGridDataset:
r = max(rollout, self.rollout)
data_reader = self.add_model_run_ids(data_reader) # NOTE: Temporary
data = NativeGridDataset(
data_reader=data_reader,
rollout=r,
multistep=self.config.training.multistep_input,
timeincrement=self.timeincrement,
relative_date_indices=self.relative_date_indices,
model_comm_group_rank=self.model_comm_group_rank,
model_comm_group_id=self.model_comm_group_id,
model_comm_num_groups=self.model_comm_num_groups,
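For reference, a minimal standalone sketch of how the relative_date_indices property above resolves under the two configuration paths; the concrete values (multistep 2, rollout 3, boundaries 0 and 6) are illustrative, not taken from the PR.

def relative_date_indices_sketch(explicit_times=None, multistep=2, rollout=3, timeincrement=1):
    """Return the relative time indices loaded for one sample."""
    if explicit_times is not None:
        # Interpolation setup: union of boundary (input) and target times.
        return sorted(set(explicit_times["input"] + explicit_times["target"]))
    # Legacy forecaster setup: multistep inputs followed by rollout targets.
    return [timeincrement * step for step in range(multistep + rollout)]

# Forecaster-style config: multistep=2, rollout=3, timeincrement=1 -> [0, 1, 2, 3, 4]
print(relative_date_indices_sketch(multistep=2, rollout=3, timeincrement=1))

# Interpolator-style config: boundaries [0, 6], targets [1..5] -> [0, 1, 2, 3, 4, 5, 6]
print(relative_date_indices_sketch(explicit_times={"input": [0, 6], "target": [1, 2, 3, 4, 5]}))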
31 changes: 9 additions & 22 deletions src/anemoi/training/data/dataset.py
@@ -33,9 +33,7 @@ class NativeGridDataset(IterableDataset):
def __init__(
self,
data_reader: Callable,
rollout: int = 1,
multistep: int = 1,
timeincrement: int = 1,
relative_date_indices: list = [0,1,2],
model_comm_group_rank: int = 0,
model_comm_group_id: int = 0,
model_comm_num_groups: int = 1,
@@ -48,12 +46,8 @@ def __init__(
----------
data_reader : Callable
user function that opens and returns the zarr array data
rollout : int, optional
length of rollout window, by default 12
timeincrement : int, optional
time increment between samples, by default 1
multistep : int, optional
collate (t-1, ... t - multistep) into the input state vector, by default 1
relative_date_indices : list
list of time indices to load from the data relative to the current sample i in __iter__
model_comm_group_rank : int, optional
process rank in the torch.distributed group (important when running on multiple GPUs), by default 0
model_comm_group_id: int, optional
@@ -70,9 +64,6 @@ def __init__(

self.data = data_reader

self.rollout = rollout
self.timeincrement = timeincrement

# lazy init
self.n_samples_per_epoch_total: int = 0
self.n_samples_per_epoch_per_worker: int = 0
@@ -89,11 +80,12 @@ def __init__(
self.shuffle = shuffle

# Data dimensions
self.multi_step = multistep
assert self.multi_step > 0, "Multistep value must be greater than zero."
self.ensemble_dim: int = 2
self.ensemble_size = self.data.shape[self.ensemble_dim]

# relative indices of the dates to extract
self.relative_date_indices = relative_date_indices

@cached_property
def statistics(self) -> dict:
"""Return dataset statistics."""
@@ -126,7 +118,7 @@ def valid_date_indices(self) -> np.ndarray:
dataset length minus rollout minus additional multistep inputs
(if time_increment is 1).
"""
return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement)
return get_usable_indices(self.data.missing, len(self.data), np.array(self.relative_date_indices, dtype=np.int64), self.data.model_run_ids)

def per_worker_init(self, n_workers: int, worker_id: int) -> None:
"""Called by worker_init_func on each copy of dataset.
@@ -230,10 +222,7 @@ def __iter__(self) -> torch.Tensor:
)

for i in shuffled_chunk_indices:
start = i - (self.multi_step - 1) * self.timeincrement
end = i + (self.rollout + 1) * self.timeincrement

x = self.data[start : end : self.timeincrement]
x = self.data[np.array(self.relative_date_indices) + i]  # NOTE: this requires an update to anemoi-datasets
x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
self.ensemble_dim = 1

@@ -243,9 +232,7 @@ def __repr__(self) -> str:
return f"""
{super().__repr__()}
Dataset: {self.data}
Rollout: {self.rollout}
Multistep: {self.multi_step}
Timeincrement: {self.timeincrement}
Relative dates: {self.relative_date_indices}
"""


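The updated valid_date_indices call above hands get_usable_indices the relative index array and the model run ids, but the helper itself is outside this diff. A rough sketch of what the updated helper could look like, assuming a start index is usable only when every requested date lies inside the series, is not missing, and falls within a single model run:

import numpy as np

def get_usable_indices_sketch(missing_indices, series_length, relative_indices, model_run_ids=None):
    """Sketch: start indices i such that every i + relative index is usable."""
    missing = missing_indices or set()
    usable = []
    for i in range(series_length):
        wanted = relative_indices + i
        if wanted.max() >= series_length:
            continue  # would run past the end of the dataset
        if any(int(w) in missing for w in wanted):
            continue  # at least one requested date is missing
        # All requested dates must come from the same model run, when run ids are known.
        if model_run_ids is not None and len(set(model_run_ids[wanted])) > 1:
            continue
        usable.append(i)
    return np.array(usable, dtype=np.int64)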
122 changes: 122 additions & 0 deletions src/anemoi/training/train/interpolator.py
@@ -0,0 +1,122 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging
import math
import os
from collections import defaultdict
from collections.abc import Generator
from collections.abc import Mapping
from typing import Optional
from typing import Union
from operator import itemgetter

import numpy as np
import pytorch_lightning as pl
import torch
from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.interface import AnemoiModelInterface
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from omegaconf import DictConfig
from omegaconf import OmegaConf
from timm.scheduler import CosineLRScheduler
from torch.distributed.distributed_c10d import ProcessGroup
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData

from anemoi.training.losses.utils import grad_scaler
from anemoi.training.losses.weightedloss import BaseWeightedLoss
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.masks import Boolean1DMask
from anemoi.training.utils.masks import NoOutputMask

from anemoi.training.train.forecaster import GraphForecaster

LOGGER = logging.getLogger(__name__)

class GraphInterpolator(GraphForecaster):
"""Graph neural network interpolator for PyTorch Lightning."""

Comment on lines +46 to +48
Member

I like this work on the Interpolator. It's a good example of why the GraphForecaster class needs some work and should be broken into a proper class structure.
What are your thoughts on which components are reusable and, conversely, which parts would typically be overridden?

Contributor Author


There's a mix of both, as well as some components that are needed only for the forecaster and some only for the interpolator.

Reusable

  • All of the init function, except for rollout and multistep.
  • All of the instantiable objects: loss, metrics, the model, etc.
  • The scheduler and optimizers, which should maybe become an instantiated object anyway.
  • The training/validation_step functions
  • calculate_val_metrics: by reusing the rollout_step label as interp_step instead.

Overridden

  • _step and forward

Only for the forecaster/interpolator

  • advance_input and rollout_step
  • target forcings (although these could also be useful for the forecaster)

To avoid inheriting unused components with the Interpolator, we could consider using a framework class containing only the common components between the forecaster and interpolator, then have both inherit this class. However, that might be a bit too much when there are only two options thus far.
In fact, the forecaster can be seen as a special case of the interpolator, since the boundary can be specified as the multistep input, and the target can be any time, including the future. If I add rollout functionality to the interpolator and make the target forcings optional, I think it should be able to do anything the forecaster can.

In my opinion, it would be the best approach to merge the two this way. It also enables the option to train a combined forecaster/interpolator, instead of having two separate models.
Do you agree with merging the two, or should I make a base framework class for both to inherit, or just keep them as is?

Member


I think I would lean towards making a base framework class. There are other use cases coming down the pipeline that would need this.
Although I am intrigued by the idea of having a class that can do both together.
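To make the base-framework idea concrete, a rough sketch of how the split could look; the class and method names below are hypothetical and not part of this PR:

import pytorch_lightning as pl
import torch

class BaseGraphModule(pl.LightningModule):
    """Hypothetical shared base: init, loss/metric/model instantiation, optimizers, training/validation steps."""

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        loss, _, _ = self._step(batch, batch_idx)
        return loss

    def _step(self, batch, batch_idx, validation_mode=False):
        raise NotImplementedError  # task-specific: forecasting vs interpolation

class GraphForecaster(BaseGraphModule):
    """Adds advance_input/rollout_step and a rollout-based _step."""

class GraphInterpolator(BaseGraphModule):
    """Adds target forcings and a _step over explicit target times."""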

def __init__(
self,
*,
config: DictConfig,
graph_data: HeteroData,
statistics: dict,
data_indices: IndexCollection,
metadata: dict,
) -> None:
"""Initialize graph neural network interpolator.

Parameters
----------
config : DictConfig
Job configuration
graph_data : HeteroData
Graph object
statistics : dict
Statistics of the training data
data_indices : IndexCollection
Indices of the training data,
metadata : dict
Provenance information

"""
super().__init__(config=config, graph_data=graph_data, statistics=statistics, data_indices=data_indices, metadata=metadata)
self.target_forcing_indices = itemgetter(*config.training.target_forcing.data)(data_indices.data.input.name_to_index)
if isinstance(self.target_forcing_indices, int):
self.target_forcing_indices = [self.target_forcing_indices]
self.boundary_times = config.training.explicit_times.input
self.interp_times = config.training.explicit_times.target
sorted_indices = sorted(set(self.boundary_times + self.interp_times))
self.imap = {data_index: batch_index for batch_index, data_index in enumerate(sorted_indices)}


def _step(
self,
batch: torch.Tensor,
batch_idx: int,
validation_mode: bool = False,
) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]:

del batch_idx
loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
metrics = {}
y_preds = []

batch = self.model.pre_processors(batch)
x_bound = batch[:, itemgetter(*self.boundary_times)(self.imap)][..., self.data_indices.data.input.full] # (bs, time, ens, latlon, nvar)

tfi = self.target_forcing_indices
target_forcing = torch.empty(batch.shape[0], batch.shape[2], batch.shape[3], len(tfi) + 1, device=self.device, dtype=batch.dtype)
for interp_step in self.interp_times:
# get the forcing information for the target interpolation time:
target_forcing[..., :len(tfi)] = batch[:, self.imap[interp_step], :, :, tfi]
target_forcing[..., -1] = (interp_step - self.boundary_times[1])/(self.boundary_times[1] - self.boundary_times[0])
# TODO: make the fractional time one of a config-given set of arbitrary custom forcing functions.

y_pred = self(x_bound, target_forcing)
y = batch[:, self.imap[interp_step], :, :, self.data_indices.data.output.full]

loss += checkpoint(self.loss, y_pred, y, use_reentrant=False)

metrics_next = {}
if validation_mode:
metrics_next = self.calculate_val_metrics(y_pred, y, interp_step - 1)  # expects a rollout step but can be repurposed here
metrics.update(metrics_next)
y_preds.extend(y_pred)

loss *= 1.0 / len(self.interp_times)
return loss, metrics, y_preds

def forward(self, x: torch.Tensor, target_forcing: torch.Tensor) -> torch.Tensor:
return self.model(x, target_forcing, self.model_comm_group)
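To make the index mapping and the fractional-time forcing above concrete, a small worked example; the time values (boundaries 0 and 6) are illustrative, not from the PR's config:

# imap maps data time indices to positions in the loaded batch, exactly as built in __init__.
boundary_times = [0, 6]
interp_times = [1, 2, 3, 4, 5]
sorted_indices = sorted(set(boundary_times + interp_times))
imap = {data_index: batch_index for batch_index, data_index in enumerate(sorted_indices)}
print(imap)  # {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6}

# The last target-forcing channel in _step, computed as written in the diff:
for interp_step in interp_times:
    frac = (interp_step - boundary_times[1]) / (boundary_times[1] - boundary_times[0])
    print(interp_step, frac)  # runs from -0.833... at step 1 to -0.166... at step 5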
11 changes: 7 additions & 4 deletions src/anemoi/training/train/train.py
@@ -16,6 +16,7 @@
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING
import importlib

import hydra
import numpy as np
@@ -33,7 +34,6 @@
from anemoi.training.diagnostics.logger import get_tensorboard_logger
from anemoi.training.diagnostics.logger import get_wandb_logger
from anemoi.training.distributed.strategy import DDPGroupStrategy
from anemoi.training.train.forecaster import GraphForecaster
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.seeding import get_base_seed

@@ -135,7 +135,7 @@ def graph_data(self) -> HeteroData:
)

@cached_property
def model(self) -> GraphForecaster:
def model(self) -> pl.LightningModule:
"""Provide the model instance."""
kwargs = {
"config": self.config,
@@ -144,10 +144,13 @@ def model(self) -> GraphForecaster:
"metadata": self.metadata,
"statistics": self.datamodule.statistics,
}
train_module = importlib.import_module(getattr(self.config.training, "train_module", "anemoi.training.train.forecaster"))
train_func = getattr(train_module, getattr(self.config.training, "train_function", "GraphForecaster"))
# NOTE: instantiate would be preferable, but I run into issues with "config" also being the name of the first parameter of instantiate itself.
if self.load_weights_only:
LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs)
return GraphForecaster(**kwargs)
return train_func.load_from_checkpoint(self.last_checkpoint, **kwargs)
return train_func(**kwargs)

Comment on lines +147 to 154
Member

@HCookie HCookie Nov 27, 2024


I agree that the instantiate would be preferable. If we were to delay the instantiation of the model within the Forecaster, it may be possible to mimic a hydra instantiate call.

The delay will be necessary to support loading weights only.

model = instantiate({'_target_': self.config.get('forecaster'), **kwargs})
if self.load_weights_only:
    LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
    return train_func.load_from_checkpoint(self.last_checkpoint, **kwargs)
return model

Contributor Author

@Magnus-SI Magnus-SI Nov 28, 2024


Yes, when adding _recursive_=False as an argument as well, that works to instantiate the model. However, after an epoch is complete I get "TypeError: Object of type DictConfig is not JSON serializable" during the saving of metadata for the checkpoint. That should be fixable though.
As for loading weights only, it seems https://github.com/ecmwf/anemoi-training/tree/feature/ckpo_loading_skip_mismatched moves this to train.py, so the model can be instantiated beforehand without problem. I will wait until this reaches develop and pull it into this branch, then add the instantiation.
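For context, a rough sketch of the delayed-instantiation idea discussed in this thread, folding the kwargs into the target dict (as suggested above) and passing _recursive_=False; the training.forecaster config key and the helper name are hypothetical, not part of this PR:

from hydra.utils import instantiate
from omegaconf import DictConfig

def build_training_module(config: DictConfig, checkpoint_path=None, **kwargs):
    """Sketch only: build the Lightning module via hydra, optionally restoring weights."""
    # Folding kwargs into the dict avoids the clash with instantiate's own first parameter,
    # and _recursive_=False leaves nested DictConfigs (such as config itself) untouched.
    model = instantiate(
        {"_target_": config.training.forecaster, "config": config, **kwargs},
        _recursive_=False,
    )
    if checkpoint_path is not None:
        # Weights-only restore, mirroring the existing load_from_checkpoint branch.
        model = type(model).load_from_checkpoint(checkpoint_path, config=config, **kwargs)
    return model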

@rank_zero_only
def _get_mlflow_run_id(self) -> str: