Feature: Temporal interpolation #168
base: develop
@@ -0,0 +1,118 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging
import math
import os
from collections import defaultdict
from collections.abc import Generator
from collections.abc import Mapping
from typing import Optional
from typing import Union
from operator import itemgetter

import numpy as np
import pytorch_lightning as pl
import torch
from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.interface import AnemoiModelInterface
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from omegaconf import DictConfig
from omegaconf import OmegaConf
from timm.scheduler import CosineLRScheduler
from torch.distributed.distributed_c10d import ProcessGroup
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData

from anemoi.training.losses.utils import grad_scaler
from anemoi.training.losses.weightedloss import BaseWeightedLoss
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.masks import Boolean1DMask
from anemoi.training.utils.masks import NoOutputMask

from anemoi.training.train.forecaster import GraphForecaster

LOGGER = logging.getLogger(__name__)

class GraphInterpolator(GraphForecaster):
    """Graph neural network interpolator for PyTorch Lightning."""

    def __init__(
        self,
        *,
        config: DictConfig,
        graph_data: HeteroData,
        statistics: dict,
        data_indices: IndexCollection,
        metadata: dict,
    ) -> None:
        """Initialize graph neural network interpolator.

        Parameters
        ----------
        config : DictConfig
            Job configuration
        graph_data : HeteroData
            Graph object
        statistics : dict
            Statistics of the training data
        data_indices : IndexCollection
            Indices of the training data
        metadata : dict
            Provenance information

        """
        super().__init__(
            config=config,
            graph_data=graph_data,
            statistics=statistics,
            data_indices=data_indices,
            metadata=metadata,
        )
        # Data indices of the variables providing forcing information at the target time
        self.target_forcing_indices = itemgetter(*config.training.target_forcing.data)(
            data_indices.data.input.name_to_index,
        )
        # Time indices of the boundary states (input) and of the states to interpolate (target)
        self.boundary_times = config.training.explicit_times.input
        self.interp_times = config.training.explicit_times.target

    def _step(
        self,
        batch: torch.Tensor,
        batch_idx: int,
        validation_mode: bool = False,
    ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor], list[torch.Tensor]]:

        del batch_idx
        loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
        metrics = {}
        y_preds = []

        batch = self.model.pre_processors(batch)
        # Boundary states used as model input: (bs, time, ens, latlon, nvar)
        x_bound = batch[:, self.boundary_times][..., self.data_indices.data.input.full]

        tfi = self.target_forcing_indices
        target_forcing = torch.empty(
            batch.shape[0],
            batch.shape[2],
            batch.shape[3],
            len(tfi) + 1,
            device=self.device,
            dtype=batch.dtype,
        )
        for interp_step in self.interp_times:
            # Get the forcing information for the target interpolation time
            target_forcing[..., : len(tfi)] = batch[:, interp_step, :, :, tfi]
            # Last channel: fractional position of the target time between the two boundary times
            target_forcing[..., -1] = (interp_step - self.boundary_times[0]) / (
                self.boundary_times[1] - self.boundary_times[0]
            )
            # TODO: make the fractional time one of a config-given set of arbitrary custom forcing functions.

            y_pred = self(x_bound, target_forcing)
            y = batch[:, interp_step, :, :, self.data_indices.data.output.full]

            loss += checkpoint(self.loss, y_pred, y, use_reentrant=False)

            metrics_next = {}
            if validation_mode:
                # calculate_val_metrics expects a rollout step, but can be repurposed here.
                metrics_next = self.calculate_val_metrics(y_pred, y, interp_step - 1)
            metrics.update(metrics_next)
            y_preds.extend(y_pred)

        loss *= 1.0 / len(self.interp_times)
        return loss, metrics, y_preds

    def forward(self, x: torch.Tensor, target_forcing: torch.Tensor) -> torch.Tensor:
        return self.model(x, target_forcing, self.model_comm_group)
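For reference, this is a minimal sketch of the training-config entries the interpolator reads in __init__. The field names follow the code above; the variable names and time indices are illustrative assumptions, not values from this PR:

from omegaconf import OmegaConf

# Hypothetical excerpt of the job configuration consumed by GraphInterpolator.
example_cfg = OmegaConf.create(
    {
        "training": {
            # variables providing forcing information at the target time (names assumed)
            "target_forcing": {"data": ["insolation", "cos_julian_day"]},
            # time indices of the boundary states (input) and of the times to interpolate (target)
            "explicit_times": {"input": [0, 6], "target": [1, 2, 3, 4, 5]},
        },
    },
)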
@@ -16,6 +16,7 @@
 from functools import cached_property
 from pathlib import Path
 from typing import TYPE_CHECKING
+import importlib

 import hydra
 import numpy as np
@@ -33,7 +34,6 @@
 from anemoi.training.diagnostics.logger import get_tensorboard_logger
 from anemoi.training.diagnostics.logger import get_wandb_logger
 from anemoi.training.distributed.strategy import DDPGroupStrategy
-from anemoi.training.train.forecaster import GraphForecaster
 from anemoi.training.utils.jsonify import map_config_to_primitives
 from anemoi.training.utils.seeding import get_base_seed
@@ -135,7 +135,7 @@ def graph_data(self) -> HeteroData:
         )

     @cached_property
-    def model(self) -> GraphForecaster:
+    def model(self) -> pl.LightningModule:
         """Provide the model instance."""
         kwargs = {
             "config": self.config,
@@ -144,10 +144,13 @@ def model(self) -> GraphForecaster:
             "metadata": self.metadata,
             "statistics": self.datamodule.statistics,
         }
+        train_module = importlib.import_module(getattr(self.config.training, "train_module", "anemoi.training.train.forecaster"))
+        train_func = getattr(train_module, getattr(self.config.training, "train_function", "GraphForecaster"))
+        # NOTE: instantiate would be preferable, but I run into issues with "config" being the first kwarg of instantiate itself.
         if self.load_weights_only:
             LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
-            return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs)
-        return GraphForecaster(**kwargs)
+            return train_func.load_from_checkpoint(self.last_checkpoint, **kwargs)
+        return train_func(**kwargs)
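For illustration, a hypothetical config override that would select the interpolator through this mechanism; the module path and class name below are assumptions based on this PR, not confirmed values:

from omegaconf import OmegaConf

# Hypothetical training-config override; with these keys absent, the getattr
# defaults above fall back to anemoi.training.train.forecaster.GraphForecaster.
overrides = OmegaConf.create(
    {
        "training": {
            "train_module": "anemoi.training.train.interpolator",  # assumed module path of the new file
            "train_function": "GraphInterpolator",
        },
    },
)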
Comment on lines +147 to 154

I agree that the instantiate would be preferable. If we were to delay the instantiation of the model, something like the following could work; the delay will be necessary to support loading weights only:

model = instantiate({"_target_": self.config.get("forecaster"), **kwargs})
if self.load_weights_only:
    LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
    return train_func.load_from_checkpoint(self.last_checkpoint, **kwargs)
return model

Yes, when adding _recursive_=False as an argument as well, that works to instantiate the model. However, after an epoch is complete I get "TypeError: Object of type DictConfig is not JSON serializable" during saving of the metadata for the checkpoint. That should be fixable though.
     @rank_zero_only
     def _get_mlflow_run_id(self) -> str:
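One possible way to address the DictConfig serialisation error mentioned in the thread above (a suggestion, not part of this PR) is to convert the config to plain Python containers before it is written into the checkpoint metadata; the map_config_to_primitives helper imported in the interpolator module serves a similar purpose:

from omegaconf import DictConfig, OmegaConf

def config_to_primitives(config: DictConfig) -> dict:
    # Resolve interpolations and convert nested DictConfig/ListConfig objects
    # into plain dicts and lists so they can be JSON-serialised.
    return OmegaConf.to_container(config, resolve=True)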
I like this work on the Interpolator. It's a good example that the GraphForecaster class needs some work and should be broken into a proper class structure. What are your thoughts on which components are reusable and, conversely, which parts would typically be overridden?
There's a mix of both, as well as some components that are needed only for the forecaster and some only for the interpolator.
Reusable
Overwritten
Only for the forecaster/interpolator
To avoid inheriting unused components with the Interpolator, we could consider using a framework class containing only the common components between the forecaster and interpolator, then have both inherit this class. However, that might be a bit too much when there are only two options thus far.
In fact, the forecaster can be seen as a special case of the interpolator, since the boundary can be specified as the multistep input, and the target can be any time, including the future. If I add rollout functionality to the interpolator and make the target forcings optional, I think it should be able to do anything the forecaster can.
In my opinion, merging the two this way would be the best approach. It also enables the option to train a combined forecaster/interpolator, instead of having two separate models.
Do you agree with merging the two, or should I make a base framework class for both to inherit, or just keep them as is?
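For concreteness, a rough sketch of what such a shared base class could look like; the class and method names here are hypothetical and only meant to illustrate the split being discussed:

from abc import abstractmethod

import pytorch_lightning as pl
import torch


class BaseGraphModule(pl.LightningModule):
    """Hypothetical common base: model construction, loss and metric setup,
    optimizers, communication groups and logging would live here."""

    @abstractmethod
    def _step(
        self,
        batch: torch.Tensor,
        batch_idx: int,
        validation_mode: bool = False,
    ) -> tuple:
        """Each task defines how a single training/validation step consumes the batch."""


class GraphForecaster(BaseGraphModule):
    """Implements _step as a rollout loop over forecast lead times."""


class GraphInterpolator(BaseGraphModule):
    """Implements _step as a loop over target times between the boundary states."""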
I think I would lean towards making a base framework class. There are other use cases coming down the pipeline that would need this.
Although I am intrigued by the idea of having a class that can do both together.