Feature: Temporal interpolation #168
base: develop
Conversation
```diff
+ train_module = importlib.import_module(getattr(self.config.training, "train_module", "anemoi.training.train.forecaster"))
+ train_func = getattr(train_module, getattr(self.config.training, "train_function", "GraphForecaster"))
+ # NOTE: instantiate would be preferable, but I run into issues with "config" being the first kwarg of instantiate itself.
  if self.load_weights_only:
      LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
-     return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs)
- return GraphForecaster(**kwargs)
+     return train_func.load_from_checkpoint(self.last_checkpoint, **kwargs)
+ return train_func(**kwargs)
```
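For reference, a stand-alone sketch of the dynamic lookup above. The config keys training.train_module and training.train_function are taken from the diff; the example values in the comment at the end are hypothetical:

```python
import importlib

def resolve_train_class(config):
    # Resolve the Lightning module class named in the training config,
    # falling back to the default GraphForecaster (as in the diff above).
    module_name = getattr(config.training, "train_module", "anemoi.training.train.forecaster")
    class_name = getattr(config.training, "train_function", "GraphForecaster")
    return getattr(importlib.import_module(module_name), class_name)

# e.g. pointing the trainer at the interpolator instead (values hypothetical):
#   training.train_module:   anemoi.training.train.interpolator
#   training.train_function: GraphInterpolator
```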
I agree that instantiate would be preferable. If we were to delay the instantiation of the model within the Forecaster, it may be possible to mimic a hydra instantiate call.
The delay will be necessary to support loading weights only:
```python
model = instantiate({"_target_": self.config.get("forecaster"), **kwargs})
if self.load_weights_only:
    LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
    return train_func.load_from_checkpoint(self.last_checkpoint, **kwargs)
return model
```
Yes, when also adding _recursive_=False as an argument, that works to instantiate the model. However, after an epoch completes I get "TypeError: Object of type DictConfig is not JSON serializable" during saving of the checkpoint metadata. That should be fixable though.
As for loading weights only, it seems https://github.com/ecmwf/anemoi-training/tree/feature/ckpo_loading_skip_mismatched moves this to train.py, so the model can be instantiated beforehand without problems. I will wait until that reaches develop and pull it into this branch, then add the instantiation.
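Putting the two comments together, a sketch of the working call; the "forecaster" config key is carried over from the snippet above, and this is illustrative rather than the final implementation:

```python
from hydra.utils import instantiate

# Placing the target class path inside the config dict avoids kwargs like
# "config" colliding with instantiate's own first parameter, and
# _recursive_=False stops Hydra from eagerly instantiating nested
# DictConfig values, which is what made the call succeed.
model = instantiate(
    {"_target_": self.config.get("forecaster"), **kwargs},
    _recursive_=False,
)
```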
```python
class GraphInterpolator(GraphForecaster):
    """Graph neural network interpolator for PyTorch Lightning."""
```
I like this work on the Interpolator. It's a good example that the GraphForecaster class needs some work and should be broken into a proper class structure.
What are your thoughts on which components are reusable and, conversely, which parts are typically overridden?
There's a mix of both, as well as some components that are needed only for the forecaster and some only for the interpolator.
Reusable
- All of the init function, except for rollout and multistep.
- All of the instantiable objects: loss, metrics, the model, etc.
- The scheduler and optimizers, which should maybe become instantiated objects anyway.
- The training/validation_step functions
- calculate_val_metrics: by reusing the rollout_step label as interp_step instead.
Overridden
- _step and forward
Only for the forecaster/interpolator
- advance_input and rollout_step
- target forcings (although these could also be useful for the forecaster)
To avoid inheriting unused components with the Interpolator, we could consider a framework class containing only the components common to the forecaster and interpolator, and have both inherit from it (see the sketch below). However, that might be a bit much when there are only two options so far.
In fact, the forecaster can be seen as a special case of the interpolator, since the boundary can be specified as the multistep input and the target can be any time, including the future. If I implement rollout functionality in the interpolator and make the target forcings optional, I think it should be able to do anything the forecaster can.
In my opinion, merging the two this way would be the best approach. It also enables training a combined forecaster/interpolator instead of two separate models.
Do you agree with merging the two, or should I make a base framework class for both to inherit, or just keep them as is?
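For illustration, a minimal sketch of what such a framework class could look like; all names here are hypothetical, not the actual API:

```python
import pytorch_lightning as pl

class GraphModuleBase(pl.LightningModule):
    """Shared components: instantiated objects, optimizers, train/val steps."""

    def training_step(self, batch, batch_idx):
        loss, _ = self._step(batch, batch_idx)
        return loss

    def _step(self, batch, batch_idx):
        # Task-specific: rollout for the forecaster, per-target-time
        # prediction for the interpolator.
        raise NotImplementedError

class GraphForecaster(GraphModuleBase):
    def _step(self, batch, batch_idx):
        ...  # autoregressive rollout using advance_input / rollout_step

class GraphInterpolator(GraphModuleBase):
    def _step(self, batch, batch_idx):
        ...  # predict each target time from boundary inputs + target forcings
```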
I think I would lean towards making a base framework class; there are other use cases coming down the pipeline that would need this.
Although I am intrigued by the idea of having a class that can do both together.
Adds temporal interpolation functionality to anemoi. The idea is that a 6 or 12 hour forecaster might yield better predictions days out than a 1 hour forecaster, since it has to make fewer auto-regressive steps. To still produce hourly predictions, we can use the information available from the forecaster, e.g. hours 12 and 18 as input to predict hours 13-17. These predictions are made individually, assisted by some information about the target time as input.
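As a rough illustration of the scheme (tensor shapes and the target-time encoding are assumptions, not the actual implementation):

```python
import torch

def interpolate(model, x_bound, t0=12, t1=18, target_times=(13, 14, 15, 16, 17)):
    # x_bound: states at the two boundary hours, shape (batch, 2, grid, vars).
    # Each intermediate hour is predicted independently, conditioned on a
    # normalized encoding of the target time between the boundaries.
    preds = []
    for t in target_times:
        frac = torch.tensor((t - t0) / (t1 - t0))  # target forcing in (0, 1)
        preds.append(model(x_bound, target_forcing=frac))
    return torch.stack(preds, dim=1)
```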
This is a work in progress; parts of the implementation can be found on the corresponding branch of anemoi-models.
Implemented
To do
Questions
Although a simple interpolation setup, such as using hours 0 and 6 to predict hours 1-5, only requires a regular range from 0 to 6, irregular ranges would enable more complex setups for both the forecaster and interpolator.
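For example (key names hypothetical), a regular setup versus an irregular one:

```python
# Regular: boundaries at 0 and 6 h, hourly targets in between.
regular = {"boundary_times": [0, 6], "target_times": [1, 2, 3, 4, 5]}

# Irregular: targets need not be evenly spaced, and (per the discussion
# above) could even lie outside the boundaries, i.e. be forecast.
irregular = {"boundary_times": [0, 6], "target_times": [2, 5, 9]}
```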