This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Feature: Temporal interpolation #168

Draft · wants to merge 6 commits into base: develop
Changes from 3 commits
35 changes: 21 additions & 14 deletions src/anemoi/training/data/datamodule.py
@@ -68,13 +68,6 @@ def __init__(self, config: DictConfig) -> None:
self.model_comm_group_rank,
)

# Set the maximum rollout to be expected
self.rollout = (
self.config.training.rollout.max
if self.config.training.rollout.epoch_increment > 0
else self.config.training.rollout.start
)

# Set the training end date if not specified
if self.config.dataloader.training.end is None:
LOGGER.info(
@@ -102,6 +95,23 @@ def metadata(self) -> dict:
@cached_property
def data_indices(self) -> IndexCollection:
return IndexCollection(self.config, self.ds_train.name_to_index)

@cached_property
def relative_date_indices(self) -> list:
"""Determine a list of relative time indices to load for each batch"""
if hasattr(self.config.training, "explicit_times"):
return sorted(set(self.config.training.explicit_times.input + self.config.training.explicit_times.target))

else:  # fall back to the old default behaviour based on multistep, timeincrement and rollout
# Use the maximum rollout to be expected
rollout = (
self.config.training.rollout.max
if self.config.training.rollout.epoch_increment > 0
else self.config.training.rollout.start
)  # NOTE: for gradual rollout, dates for the maximum rollout are always fetched, but this was already the behaviour in datamodule.py

multi_step = self.config.training.multistep_input
return [self.timeincrement * mstep for mstep in range(multi_step + rollout)]
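
For illustration, a minimal sketch (with made-up config values, not defaults) of what the two branches above produce:

# Hypothetical values, for illustration only.
explicit_input = [0, 6]              # config.training.explicit_times.input
explicit_target = [1, 2, 3, 4, 5]    # config.training.explicit_times.target
print(sorted(set(explicit_input + explicit_target)))   # [0, 1, 2, 3, 4, 5, 6]

multistep_input, rollout, timeincrement = 2, 3, 1      # fallback branch
print([timeincrement * s for s in range(multistep_input + rollout)])  # [0, 1, 2, 3, 4]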

@cached_property
def timeincrement(self) -> int:
Expand Down Expand Up @@ -140,7 +150,8 @@ def ds_train(self) -> NativeGridDataset:

@cached_property
def ds_valid(self) -> NativeGridDataset:
r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))
# r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))
# NOTE: temporarily left unimplemented until I figure out how best to do this with the new time_indices object

assert self.config.dataloader.training.end < self.config.dataloader.validation.start, (
f"Training end date {self.config.dataloader.training.end} is not before"
@@ -149,7 +160,7 @@ def ds_valid(self) -> NativeGridDataset:
return self._get_dataset(
open_dataset(OmegaConf.to_container(self.config.dataloader.validation, resolve=True)),
shuffle=False,
rollout=r,
#rollout=r, #NOTE: see the above
label="validation",
)

@@ -173,15 +184,11 @@ def _get_dataset(
self,
data_reader: Callable,
shuffle: bool = True,
rollout: int = 1,
label: str = "generic",
) -> NativeGridDataset:
r = max(rollout, self.rollout)
data = NativeGridDataset(
data_reader=data_reader,
rollout=r,
multistep=self.config.training.multistep_input,
timeincrement=self.timeincrement,
relative_date_indices=self.relative_date_indices,
model_comm_group_rank=self.model_comm_group_rank,
model_comm_group_id=self.model_comm_group_id,
model_comm_num_groups=self.model_comm_num_groups,
33 changes: 11 additions & 22 deletions src/anemoi/training/data/dataset.py
@@ -33,9 +33,7 @@ class NativeGridDataset(IterableDataset):
def __init__(
self,
data_reader: Callable,
rollout: int = 1,
multistep: int = 1,
timeincrement: int = 1,
relative_date_indices: list = [0,1,2],
model_comm_group_rank: int = 0,
model_comm_group_id: int = 0,
model_comm_num_groups: int = 1,
@@ -48,12 +46,8 @@ def __init__(
----------
data_reader : Callable
user function that opens and returns the zarr array data
rollout : int, optional
length of rollout window, by default 12
timeincrement : int, optional
time increment between samples, by default 1
multistep : int, optional
collate (t-1, ... t - multistep) into the input state vector, by default 1
relative_date_indices : list
list of time indices to load from the data relative to the current sample i in __iter__
model_comm_group_rank : int, optional
process rank in the torch.distributed group (important when running on multiple GPUs), by default 0
model_comm_group_id: int, optional
@@ -70,9 +64,6 @@

self.data = data_reader

self.rollout = rollout
self.timeincrement = timeincrement

# lazy init
self.n_samples_per_epoch_total: int = 0
self.n_samples_per_epoch_per_worker: int = 0
@@ -89,11 +80,12 @@ def __init__(
self.shuffle = shuffle

# Data dimensions
self.multi_step = multistep
assert self.multi_step > 0, "Multistep value must be greater than zero."
self.ensemble_dim: int = 2
self.ensemble_size = self.data.shape[self.ensemble_dim]

# relative index of dates to extract
self.relative_date_indices = relative_date_indices #np.array(date_indices, dtype = np.int32)

@cached_property
def statistics(self) -> dict:
"""Return dataset statistics."""
@@ -126,7 +118,7 @@ def valid_date_indices(self) -> np.ndarray:
dataset length minus rollout minus additional multistep inputs
(if time_increment is 1).
"""
return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement)
return get_usable_indices(self.data.missing, len(self.data), np.array(self.relative_date_indices, dtype=np.int32))

def per_worker_init(self, n_workers: int, worker_id: int) -> None:
"""Called by worker_init_func on each copy of dataset.
@@ -230,10 +222,9 @@ def __iter__(self) -> torch.Tensor:
)

for i in shuffled_chunk_indices:
start = i - (self.multi_step - 1) * self.timeincrement
end = i + (self.rollout + 1) * self.timeincrement

x = self.data[start : end : self.timeincrement]
# TODO: self.data[relative_date_indices + i] is intended here, but array indices do not seem to be supported in
# anemoi-datasets, and I couldn't get a tuple of indices without a regular structure to work either.
x = self.data[slice(self.relative_date_indices[0] + i, i + self.relative_date_indices[-1] + 1, 1)]
x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
self.ensemble_dim = 1
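
One possible way to resolve the TODO above, sketched here and untested against anemoi-datasets: keep the single contiguous read, then select the requested offsets in memory, so that non-contiguous relative_date_indices would also be supported.

rel = np.asarray(self.relative_date_indices)
window = self.data[slice(rel[0] + i, rel[-1] + i + 1)]  # one contiguous read, as above
x = window[rel - rel[0]]  # in-memory fancy indexing picks out only the requested dates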

@@ -243,9 +234,7 @@ def __repr__(self) -> str:
return f"""
{super().__repr__()}
Dataset: {self.data}
Rollout: {self.rollout}
Multistep: {self.multi_step}
Timeincrement: {self.timeincrement}
Relative dates: {self.relative_date_indices}
"""


118 changes: 118 additions & 0 deletions src/anemoi/training/train/interpolator.py
@@ -0,0 +1,118 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging
import math
import os
from collections import defaultdict
from collections.abc import Generator
from collections.abc import Mapping
from typing import Optional
from typing import Union
from operator import itemgetter

import numpy as np
import pytorch_lightning as pl
import torch
from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.interface import AnemoiModelInterface
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from omegaconf import DictConfig
from omegaconf import OmegaConf
from timm.scheduler import CosineLRScheduler
from torch.distributed.distributed_c10d import ProcessGroup
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData

from anemoi.training.losses.utils import grad_scaler
from anemoi.training.losses.weightedloss import BaseWeightedLoss
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.masks import Boolean1DMask
from anemoi.training.utils.masks import NoOutputMask

from anemoi.training.train.forecaster import GraphForecaster

LOGGER = logging.getLogger(__name__)

class GraphInterpolator(GraphForecaster):
"""Graph neural network interpolator for PyTorch Lightning."""

Comment on lines +46 to +48

Member:

I like this work on the Interpolator. It's a good example that the GraphForecaster class needs some work and should be broken into a proper class structure.
What are your thoughts on which components are reusable and, in contrast, which parts typically need to be overridden?

Contributor Author:

There's a mix of both, as well as some components that are needed only for the forecaster and some only for the interpolator.

Reusable

  • All of the init function, except for rollout and multistep.
  • All of the instantiable objects: loss, metrics, the model, etc.
  • The scheduler and optimizers, which should maybe become instantiated objects anyway.
  • The training/validation_step functions.
  • calculate_val_metrics: by reusing the rollout_step label as interp_step instead.

Overridden

  • _step and forward

Only for the forecaster/interpolator

  • advance_input and rollout_step
  • target forcings (although these could also be useful for the forecaster)

To avoid inheriting unused components with the Interpolator, we could consider a framework class containing only the components common to the forecaster and interpolator, and have both inherit from it. However, that might be a bit too much when there are only two options so far.
In fact, the forecaster can be seen as a special case of the interpolator, since the boundary can be specified as the multistep input and the target can be any time, including the future. If I implement rollout functionality in the interpolator and make the target forcings optional, I think it should be able to do anything the forecaster can.

In my opinion, merging the two this way would be the best approach. It also enables training a combined forecaster/interpolator, instead of having two separate models.
Do you agree with merging the two, or should I make a base framework class for both to inherit, or just keep them as is?

Member:

I think I would lean towards making a base framework class. There are other use cases coming down the pipeline that would need this.
Although I am intrigued by the idea of having a class that can do both together.
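
A minimal sketch of what such a base framework class could look like; the base-class name and the exact split of responsibilities are hypothetical, not existing anemoi-training classes:

import pytorch_lightning as pl


class GraphTrainModule(pl.LightningModule):
    """Hypothetical shared base: model/loss/metric setup, optimizers, training/validation steps."""

    def _step(self, batch, batch_idx, validation_mode=False):
        # each task-specific subclass decides how a batch maps to a loss
        raise NotImplementedError


class GraphForecaster(GraphTrainModule):
    """Keeps the rollout-specific pieces: advance_input, rollout_step."""


class GraphInterpolator(GraphTrainModule):
    """Keeps the interpolation-specific pieces: boundary/target times, target forcings."""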

def __init__(
self,
*,
config: DictConfig,
graph_data: HeteroData,
statistics: dict,
data_indices: IndexCollection,
metadata: dict,
) -> None:
"""Initialize graph neural network interpolator.

Parameters
----------
config : DictConfig
Job configuration
graph_data : HeteroData
Graph object
statistics : dict
Statistics of the training data
data_indices : IndexCollection
Indices of the training data,
metadata : dict
Provenance information

"""
super().__init__(config=config, graph_data=graph_data, statistics=statistics, data_indices=data_indices, metadata=metadata)
self.target_forcing_indices = itemgetter(*config.training.target_forcing.data)(data_indices.data.input.name_to_index)
self.boundary_times = config.training.explicit_times.input
self.interp_times = config.training.explicit_times.target


def _step(
self,
batch: torch.Tensor,
batch_idx: int,
validation_mode: bool = False,
) -> tuple[torch.Tensor, Mapping[str, torch.Tensor], list[torch.Tensor]]:

del batch_idx
loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
metrics = {}
y_preds = []

batch = self.model.pre_processors(batch)
x_bound = batch[:, self.boundary_times][..., self.data_indices.data.input.full] # (bs, time, ens, latlon, nvar)

tfi = self.target_forcing_indices
target_forcing = torch.empty(batch.shape[0], batch.shape[2], batch.shape[3], len(tfi) + 1, device=self.device, dtype=batch.dtype)
for interp_step in self.interp_times:
# get the forcing information for the target interpolation time:
target_forcing[..., :len(tfi)] = batch[:, interp_step, :, :, tfi]
target_forcing[..., -1] = (interp_step - self.boundary_times[0]) / (self.boundary_times[1] - self.boundary_times[0])
# TODO: make the fractional time one of a configurable set of arbitrary custom forcing functions.

y_pred = self(x_bound, target_forcing)
y = batch[:, interp_step, :, :, self.data_indices.data.output.full]

loss += checkpoint(self.loss, y_pred, y, use_reentrant=False)

metrics_next = {}
if validation_mode:
metrics_next = self.calculate_val_metrics(y_pred, y, interp_step - 1)  # expects a rollout step, but can be repurposed here
metrics.update(metrics_next)
y_preds.extend(y_pred)

loss *= 1.0 / len(self.interp_times)
return loss, metrics, y_preds
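
As a quick numeric check of the fractional-time forcing above (example values assumed, not defaults):

boundary_times, interp_step = [0, 6], 2  # assumed example values
fraction = (interp_step - boundary_times[0]) / (boundary_times[1] - boundary_times[0])  # = 1/3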

def forward(self, x: torch.Tensor, target_forcing: torch.Tensor) -> torch.Tensor:
return self.model(x, target_forcing, self.model_comm_group)
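
For completeness, a hedged sketch of the training-config entries this class (together with the train.py change below) reads. The layout is inferred from the attribute accesses in this PR, not taken from a shipped config, and the forcing variable name is a placeholder:

from omegaconf import OmegaConf

training = OmegaConf.create({  # corresponds to the `training` section of the job config
    "train_module": "anemoi.training.train.interpolator",
    "train_function": "GraphInterpolator",
    "explicit_times": {"input": [0, 6], "target": [1, 2, 3, 4, 5]},
    "target_forcing": {"data": ["insolation"]},  # placeholder variable name
})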
11 changes: 7 additions & 4 deletions src/anemoi/training/train/train.py
@@ -16,6 +16,7 @@
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING
import importlib

import hydra
import numpy as np
@@ -33,7 +34,6 @@
from anemoi.training.diagnostics.logger import get_tensorboard_logger
from anemoi.training.diagnostics.logger import get_wandb_logger
from anemoi.training.distributed.strategy import DDPGroupStrategy
from anemoi.training.train.forecaster import GraphForecaster
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.seeding import get_base_seed

@@ -135,7 +135,7 @@ def graph_data(self) -> HeteroData:
)

@cached_property
def model(self) -> GraphForecaster:
def model(self) -> pl.LightningModule:
"""Provide the model instance."""
kwargs = {
"config": self.config,
@@ -144,10 +144,13 @@ def model(self) -> GraphForecaster:
"metadata": self.metadata,
"statistics": self.datamodule.statistics,
}
train_module = importlib.import_module(getattr(self.config.training, "train_module", "anemoi.training.train.forecaster"))
train_func = getattr(train_module, getattr(self.config.training, "train_function", "GraphForecaster"))
#NOTE: instantiate would be preferable, but I run into issues with "config" being the first kwarg of instantiate itself.
if self.load_weights_only:
LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs)
return GraphForecaster(**kwargs)
return train_func.load_from_checkpoint(self.last_checkpoint, **kwargs)
return train_func(**kwargs)

Comment on lines +147 to 154

Member (@HCookie, Nov 27, 2024):

I agree that instantiate would be preferable. If we were to delay the instantiation of the model within the Forecaster, it may be possible to mimic a hydra instantiate call.

The delay will be necessary to support loading weights only.

model = instantiate({'_target_': self.config.get('forecaster'), **kwargs})
if self.load_weights_only:
    LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
    return train_func.load_from_checkpoint(self.last_checkpoint, **kwargs)
return model

Contributor Author (@Magnus-SI, Nov 28, 2024):

Yes, with recursive = False added as an argument as well, that works to instantiate the model. However, after an epoch completes I get "TypeError: Object of type DictConfig is not JSON serializable" while saving the checkpoint metadata. That should be fixable though.
As for loading weights only, it seems https://github.com/ecmwf/anemoi-training/tree/feature/ckpo_loading_skip_mismatched moves this to train.py, so the model can be instantiated beforehand without problem. I will wait until this reaches develop and pull it to this branch, then add the instantiation.
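
For reference, a sketch combining the reviewer's dict-style target with the `_recursive_=False` override mentioned above; untested, and `target_path` is a placeholder for the configured class path:

from hydra.utils import instantiate

# kwargs is the dict built in `model` above; folding it into the target dict
# avoids the clash with instantiate's own first `config` parameter.
model = instantiate({"_target_": target_path, **kwargs}, _recursive_=False)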

@rank_zero_only
def _get_mlflow_run_id(self) -> str:
22 changes: 6 additions & 16 deletions src/anemoi/training/utils/usable_indices.py
@@ -16,9 +16,7 @@
def get_usable_indices(
missing_indices: set[int] | None,
series_length: int,
rollout: int,
multistep: int,
timeincrement: int = 1,
relative_indices: np.ndarray,
) -> np.ndarray:
"""Get the usable indices of a series whit missing indices.

@@ -28,32 +26,24 @@ def get_usable_indices(
Dataset to be used.
series_length : int
Length of the series.
rollout : int
Number of steps to roll out.
multistep : int
Number of previous indices to include as predictors.
timeincrement : int
Time increment, by default 1.
relative_indices : np.ndarray
Array of relative indices requested at each index i.

Returns
-------
usable_indices : np.array
Array of usable indices.
"""
prev_invalid_dates = (multistep - 1) * timeincrement
next_invalid_dates = rollout * timeincrement

usable_indices = np.arange(series_length) # set of all indices

if missing_indices is None:
missing_indices = set()

missing_indices |= {-1, series_length} # to filter initial and final indices
missing_indices |= set(range(series_length, series_length + max(relative_indices) + 1))  # filter indices larger than the series length

# Missing indices
for i in missing_indices:
usable_indices = usable_indices[
(usable_indices < i - next_invalid_dates) + (usable_indices > i + prev_invalid_dates)
]
rel_missing = i - relative_indices  # start indices whose relative indices would hit this missing index
usable_indices = usable_indices[np.all(usable_indices != rel_missing[:, np.newaxis], axis=0)]

return usable_indices
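
A small worked example of the new masking logic (values chosen for illustration):

import numpy as np

# length-10 series with index 4 missing; each sample needs offsets [0, 1, 2]
print(get_usable_indices({4}, 10, np.array([0, 1, 2])))
# -> [0 1 5 6 7]: starts 2-4 would touch the missing date, 8-9 would run past the end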