From 7e238e45a5376f4966b994272d93d0747add0c28 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Wed, 9 Oct 2024 15:25:03 +0200 Subject: [PATCH 01/43] Introduced resume flag and checkpoint loading for transfer learning, removed metadata saving in checkpoints due to corruption error on big models, fixed logging to work in the transfer leanring setting --- .../diagnostics/callbacks/__init__.py | 4 +-- .../training/diagnostics/mlflow/logger.py | 4 ++- src/anemoi/training/train/train.py | 25 ++++++++++++++----- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index f2195b5f..f47db036 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -997,7 +997,7 @@ def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: s torch.save(model, inference_checkpoint_filepath) - save_metadata(inference_checkpoint_filepath, metadata) + # save_metadata(inference_checkpoint_filepath, metadata) model.config = save_config model.metadata = tmp_metadata @@ -1016,7 +1016,7 @@ def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: s from weakref import proxy # save metadata for the training checkpoint in the same format as inference - save_metadata(lightning_checkpoint_filepath, metadata) + # save_metadata(lightning_checkpoint_filepath, metadata) # notify loggers for logger in trainer.loggers: diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py index 7854c172..90978e76 100644 --- a/src/anemoi/training/diagnostics/mlflow/logger.py +++ b/src/anemoi/training/diagnostics/mlflow/logger.py @@ -70,7 +70,7 @@ def get_mlflow_run_params(config: OmegaConf, tracking_uri: str) -> tuple[str | N if len(sys.argv) > 1: # add the arguments to the command tag tags["command"] = tags["command"] + " " + " ".join(sys.argv[1:]) - if config.training.run_id or config.training.fork_run_id: + if (config.training.run_id or config.training.fork_run_id) and config.training.resume: "Either run_id or fork_run_id must be provided to resume a run." 
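The resume condition just above (and the matching `start_from_checkpoint` change in `train.py` later in this patch) couples warm starting to an explicit `config.training.resume` flag rather than to the mere presence of a run id. A minimal standalone sketch of that gating, with an illustrative dataclass standing in for the Hydra config (names other than `run_id`, `fork_run_id` and `resume` are assumptions):

    from __future__ import annotations

    from dataclasses import dataclass


    @dataclass
    class TrainingFlags:
        run_id: str | None = None
        fork_run_id: str | None = None
        resume: bool = False


    def should_resume(flags: TrainingFlags) -> bool:
        # Presence of a run id alone no longer triggers a resume; the explicit
        # resume flag has to be set as well.
        return (bool(flags.run_id) or bool(flags.fork_run_id)) and flags.resume


    assert not should_resume(TrainingFlags(run_id="a1b2"))
    assert should_resume(TrainingFlags(fork_run_id="a1b2", resume=True))
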
import mlflow @@ -85,11 +85,13 @@ def get_mlflow_run_params(config: OmegaConf, tracking_uri: str) -> tuple[str | N run_name = mlflow_client.get_run(parent_run_id).info.run_name tags["mlflow.parentRunId"] = parent_run_id tags["resumedRun"] = "True" # tags can't take boolean values + elif config.training.run_id and not config.diagnostics.log.mlflow.on_resume_create_child: run_id = config.training.run_id run_name = mlflow_client.get_run(run_id).info.run_name mlflow_client.update_run(run_id=run_id, status="RUNNING") tags["resumedRun"] = "True" + else: parent_run_id = config.training.fork_run_id tags["forkedRun"] = "True" diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index f48b9467..00209a94 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -62,7 +62,7 @@ def __init__(self, config: DictConfig) -> None: self.config = config # Default to not warm-starting from a checkpoint - self.start_from_checkpoint = bool(self.config.training.run_id) or bool(self.config.training.fork_run_id) + self.start_from_checkpoint = (bool(self.config.training.run_id) or bool(self.config.training.fork_run_id)) and self.config.training.resume self.load_weights_only = config.training.load_weights_only self.parent_uuid = None @@ -141,7 +141,8 @@ def model(self) -> GraphForecaster: } if self.load_weights_only: LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) - return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs) + return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) + return GraphForecaster(**kwargs) @rank_zero_only @@ -187,12 +188,19 @@ def last_checkpoint(self) -> str | None: """Path to the last checkpoint.""" if not self.start_from_checkpoint: return None - + checkpoint = Path( self.config.hardware.paths.checkpoints.parent, - self.config.training.fork_run_id or self.run_id, - self.config.hardware.files.warm_start or "last.ckpt", + self.config.training.fork_run_id, # or self.run_id, + self.config.hardware.files.warm_start or "transfer.ckpt", ) + # Transfer learning or continue training + if not Path(checkpoint).exists(): + checkpoint = Path( + self.config.hardware.paths.checkpoints.parent, + self.config.training.fork_run_id, # or self.run_id, + self.config.hardware.files.warm_start or "last.ckpt", + ) # Check if the last checkpoint exists if Path(checkpoint).exists(): @@ -313,6 +321,9 @@ def strategy(self) -> DDPGroupStrategy: def train(self) -> None: """Training entry point.""" + + print('Setting up trainer..') + trainer = pl.Trainer( accelerator=self.accelerator, callbacks=self.callbacks, @@ -328,7 +339,7 @@ def train(self) -> None: # run a fixed no of batches per epoch (helpful when debugging) limit_train_batches=self.config.dataloader.limit_batches.training, limit_val_batches=self.config.dataloader.limit_batches.validation, - num_sanity_val_steps=4, + num_sanity_val_steps=0, accumulate_grad_batches=self.config.training.accum_grad_batches, gradient_clip_val=self.config.training.gradient_clip.val, gradient_clip_algorithm=self.config.training.gradient_clip.algorithm, @@ -338,6 +349,8 @@ def train(self) -> None: enable_progress_bar=self.config.diagnostics.enable_progress_bar, ) + print('Starting training..') + trainer.fit( self.model, datamodule=self.datamodule, From 08671d734f596af743660af19a5436d23df92e6c Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Thu, 10 Oct 2024 16:33:21 +0200 Subject: [PATCH 02/43] Added len of dataset computed dynamically --- 
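The patch below derives the dataset length from an effective batch size, i.e. the per-GPU batch size scaled by the number of data-parallel model instances. A minimal sketch of that arithmetic, kept independent of the anemoi classes (the helper names are illustrative):

    def effective_batch_size(batch_size: int, gpus_per_node: int,
                             num_nodes: int, gpus_per_model: int) -> int:
        # Each model instance spans `gpus_per_model` GPUs, so the number of
        # samples consumed per optimiser step is the per-instance batch size
        # times the number of model instances running in parallel.
        return batch_size * gpus_per_node * num_nodes // gpus_per_model


    def dataset_len(num_valid_indices: int, eff_bs: int) -> int:
        # Length in batches rather than samples, matching the __len__ added below.
        return num_valid_indices // eff_bs


    assert effective_batch_size(2, 4, 2, 1) == 16
    assert dataset_len(1000, 16) == 62
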
src/anemoi/training/data/datamodule.py | 10 ++++++++++ src/anemoi/training/data/dataset.py | 7 +++++++ 2 files changed, 17 insertions(+) diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 1e119892..8d86dd58 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -95,6 +95,7 @@ def __init__(self, config: DictConfig) -> None: self.config.dataloader.validation.start - 1, ) self.config.dataloader.training.end = self.config.dataloader.validation.start - 1 + def _check_resolution(self, resolution: str) -> None: assert ( @@ -162,6 +163,13 @@ def _get_dataset( label: str = "generic", ) -> NativeGridDataset: r = max(rollout, self.rollout) + + # Compute effective batch size + effective_bs = self.config.dataloader.batch_size[label] *\ + self.config.hardware.num_gpus_per_node *\ + self.config.hardware.num_nodes //\ + self.config.hardware.num_gpus_per_model + data = NativeGridDataset( data_reader=data_reader, rollout=r, @@ -172,7 +180,9 @@ def _get_dataset( model_comm_num_groups=self.model_comm_num_groups, shuffle=shuffle, label=label, + effective_bs=effective_bs ) + self._check_resolution(data.resolution) return data diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index e2aa12bd..2e2bacc1 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -11,6 +11,7 @@ import random from functools import cached_property from typing import Callable +from omegaconf import DictConfig import numpy as np import torch @@ -38,6 +39,7 @@ def __init__( model_comm_num_groups: int = 1, shuffle: bool = True, label: str = "generic", + effective_bs: int = 1 ) -> None: """Initialize (part of) the dataset state. @@ -64,6 +66,7 @@ def __init__( """ self.label = label + self.effective_bs = effective_bs self.data = data_reader @@ -244,6 +247,10 @@ def __repr__(self) -> str: Multistep: {self.multi_step} Timeincrement: {self.timeincrement} """ + + def __len__(self) -> int: + """Estimate the total number of samples based on valid indices.""" + return len(self.valid_date_indices) // self.effective_bs def worker_init_func(worker_id: int) -> None: From e2bd86804aec2501e627240613904cd791bb2464 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 22 Oct 2024 11:50:55 +0200 Subject: [PATCH 03/43] debugging validation --- src/anemoi/training/data/datamodule.py | 5 +++-- src/anemoi/training/train/forecaster.py | 9 ++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 8d86dd58..11c707b8 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -162,10 +162,11 @@ def _get_dataset( rollout: int = 1, label: str = "generic", ) -> NativeGridDataset: + r = max(rollout, self.rollout) - # Compute effective batch size - effective_bs = self.config.dataloader.batch_size[label] *\ + # Compute effective batch size + effective_bs = self.config.dataloader.batch_size['training'] *\ self.config.hardware.num_gpus_per_node *\ self.config.hardware.num_nodes //\ self.config.hardware.num_gpus_per_model diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index ff1acfd7..3d14a570 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -75,6 +75,8 @@ def __init__( config=DotDict(map_config_to_primitives(OmegaConf.to_container(config, resolve=True))), ) + self.model = 
torch.compile(self.model) + self.data_indices = data_indices self.save_hyperparameters() @@ -321,8 +323,11 @@ def on_train_epoch_end(self) -> None: self.rollout = min(self.rollout, self.rollout_max) def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: + print('I am doing validation!!!') with torch.no_grad(): val_loss, metrics, y_preds = self._step(batch, batch_idx, validation_mode=True) + print('Done step..') + print('Logging..') self.log( "val_wmse", val_loss, @@ -333,7 +338,8 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: batch_size=batch.shape[0], sync_dist=True, ) - for mname, mvalue in metrics.items(): + for i, (mname, mvalue) in enumerate(metrics.items()): + print(i) self.log( "val_" + mname, mvalue, @@ -344,6 +350,7 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: batch_size=batch.shape[0], sync_dist=True, ) + print('Done') return val_loss, y_preds def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[dict]]: From 544dddc2363e04cc81ffcd262ab036b9458d8461 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Thu, 24 Oct 2024 09:59:36 +0200 Subject: [PATCH 04/43] Small changes --- src/anemoi/training/train/forecaster.py | 21 ++++++++++----------- src/anemoi/training/train/train.py | 3 ++- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 3d14a570..57765f0c 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -97,12 +97,12 @@ def __init__( self.loss.register_full_backward_hook(grad_scaler, prepend=False) self.multi_step = config.training.multistep_input - self.lr = ( - config.hardware.num_nodes - * config.hardware.num_gpus_per_node - * config.training.lr.rate - / config.hardware.num_gpus_per_model - ) + print('config.hardware.num_nodes', config.hardware.num_nodes, type(config.hardware.num_nodes)) + print('config.hardware.num_gpus_per_node', config.hardware.num_gpus_per_node, type(config.hardware.num_gpus_per_node)) + print('config.training.lr.rate', config.training.lr.rate, type(config.training.lr.rate)) + print('config.hardware.num_gpus_per_model', config.hardware.num_gpus_per_model, type(config.hardware.num_gpus_per_model)) + self.lr = config.hardware.num_nodes * config.hardware.num_gpus_per_node * config.training.lr.rate / config.hardware.num_gpus_per_model + self.lr_iterations = config.training.lr.iterations self.lr_min = config.training.lr.min self.rollout = config.training.rollout.start @@ -323,11 +323,10 @@ def on_train_epoch_end(self) -> None: self.rollout = min(self.rollout, self.rollout_max) def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: - print('I am doing validation!!!') + with torch.no_grad(): val_loss, metrics, y_preds = self._step(batch, batch_idx, validation_mode=True) - print('Done step..') - print('Logging..') + self.log( "val_wmse", val_loss, @@ -338,8 +337,8 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: batch_size=batch.shape[0], sync_dist=True, ) + for i, (mname, mvalue) in enumerate(metrics.items()): - print(i) self.log( "val_" + mname, mvalue, @@ -350,7 +349,7 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: batch_size=batch.shape[0], sync_dist=True, ) - print('Done') + return val_loss, y_preds def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[dict]]: diff --git a/src/anemoi/training/train/train.py 
b/src/anemoi/training/train/train.py index 00209a94..62bb9d88 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -287,11 +287,12 @@ def _log_information(self) -> None: LOGGER.debug("Total number of auxiliary variables: %d", len(self.config.data.forcing)) # Log learning rate multiplier when running single-node, multi-GPU and/or multi-node - total_number_of_model_instances = ( + total_number_of_model_instances = int( self.config.hardware.num_nodes * self.config.hardware.num_gpus_per_node / self.config.hardware.num_gpus_per_model ) + LOGGER.debug( "Total GPU count / model group size: %d - NB: the learning rate will be scaled by this factor!", total_number_of_model_instances, From a85619d818c9c5a059035b7ad5a1c45685983f56 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Fri, 25 Oct 2024 11:21:10 +0200 Subject: [PATCH 05/43] Removed prints --- src/anemoi/training/data/datamodule.py | 1 - src/anemoi/training/train/forecaster.py | 4 ---- 2 files changed, 5 deletions(-) diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 85edbff7..8340c5a4 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -186,7 +186,6 @@ def _get_dataset( label=label, effective_bs=effective_bs ) - self._check_resolution(data.resolution) return data diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 57765f0c..a45c2bab 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -97,10 +97,6 @@ def __init__( self.loss.register_full_backward_hook(grad_scaler, prepend=False) self.multi_step = config.training.multistep_input - print('config.hardware.num_nodes', config.hardware.num_nodes, type(config.hardware.num_nodes)) - print('config.hardware.num_gpus_per_node', config.hardware.num_gpus_per_node, type(config.hardware.num_gpus_per_node)) - print('config.training.lr.rate', config.training.lr.rate, type(config.training.lr.rate)) - print('config.hardware.num_gpus_per_model', config.hardware.num_gpus_per_model, type(config.hardware.num_gpus_per_model)) self.lr = config.hardware.num_nodes * config.hardware.num_gpus_per_node * config.training.lr.rate / config.hardware.num_gpus_per_model self.lr_iterations = config.training.lr.iterations From c8ce0b04176355c2dfb6b72be12feec007de6d00 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Mon, 18 Nov 2024 13:12:13 +0100 Subject: [PATCH 06/43] Not working --- src/anemoi/training/data/datamodule.py | 2 +- .../training/diagnostics/callbacks/__init__.py | 10 ++++++++++ src/anemoi/training/train/forecaster.py | 17 ++++++++++++++--- src/anemoi/training/train/train.py | 10 +++++++--- 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 8ee96221..293ef483 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -189,7 +189,7 @@ def _get_dataset( label=label, effective_bs=effective_bs ) - self._check_resolution(data.resolution) + # self._check_resolution(data.resolution) return data def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader: diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index 6ef942bd..13ed4a87 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -683,12 +683,22 @@ 
def _plot( local_rank = pl_module.local_rank batch = pl_module.model.pre_processors(batch, in_place=False) + nan_locations = pl_module.model.pre_processors.processors['imputer'].nan_locations + nan_locations = nan_locations[ + self.sample_idx, + pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, + ..., + pl_module.data_indices.internal_data.output.full, + ] + self.post_processors.processors['imputer'].set_nan_locations(nan_locations) + input_tensor = batch[ self.sample_idx, pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, ..., pl_module.data_indices.internal_data.output.full, ].cpu() + data = self.post_processors(input_tensor) output_tensor = self.post_processors( diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index f4599572..936a3bed 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -78,8 +78,6 @@ def __init__( config=DotDict(map_config_to_primitives(OmegaConf.to_container(config, resolve=True))), ) - self.model = torch.compile(self.model) - self.data_indices = data_indices self.save_hyperparameters() @@ -248,7 +246,8 @@ def _step( y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.internal_data.output.full] # y includes the auxiliary variables, so we must leave those out when computing the loss - loss += checkpoint(self.loss, y_pred, y, use_reentrant=False) + tmp_loss = checkpoint(self.loss, y_pred, y, use_reentrant=False) + loss += tmp_loss x = self.advance_input(x, y_pred, batch, rollout_step) @@ -275,9 +274,21 @@ def calculate_val_metrics( ) -> tuple[dict, list]: metrics = {} y_preds = [] + + # Added to impute nans + nan_locations = torch.isnan(y[..., self.data_indices.internal_data.output.full]) + self.model.post_processors.processors['imputer'].set_nan_locations(nan_locations) + + print("y in val") + print(y.shape, nan_locations.shape) + y_postprocessed = self.model.post_processors(y, in_place=False) y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False) for mkey, indices in self.metric_ranges_validation.items(): + print(indices, y_pred_postprocessed.shape, y_postprocessed.shape) + for idx in indices: + print("trying: ", idx, y_pred_postprocessed[..., idx].shape) + print(y_postprocessed[..., indices].shape) metrics[f"{mkey}_{rollout_step + 1}"] = self.metrics( y_pred_postprocessed[..., indices], y_postprocessed[..., indices], diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 5badab6a..d3167be8 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -42,7 +42,6 @@ LOGGER = logging.getLogger(__name__) - class AnemoiTrainer: """Utility class for training the model.""" @@ -83,6 +82,9 @@ def datamodule(self) -> AnemoiDatasetsDataModule: """DataModule instance and DataSets.""" datamodule = AnemoiDatasetsDataModule(self.config) self.config.data.num_features = len(datamodule.ds_train.data.variables) + LOGGER.info( + f"Data has {len(datamodule.ds_train.data.variables)} variables: {datamodule.ds_train.data.variables}" + ) return datamodule @cached_property @@ -198,8 +200,9 @@ def last_checkpoint(self) -> str | None: checkpoint = Path( self.config.hardware.paths.checkpoints.parent, fork_id or self.lineage_run, - self.config.hardware.files.warm_start or "last.ckpt", + self.config.hardware.files.warm_start or "transfer.ckpt" or "last.ckpt", ) + # Check if the last checkpoint exists if Path(checkpoint).exists(): LOGGER.info("Resuming training 
from last checkpoint: %s", checkpoint) @@ -386,7 +389,8 @@ def train(self) -> None: @hydra.main(version_base=None, config_path="../config", config_name="config") def main(config: DictConfig) -> None: - AnemoiTrainer(config).train() + trainer = AnemoiTrainer(config) + trainer.train() if __name__ == "__main__": From 135eac5c7036bcecb43afc1cffeb760ac922049a Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Mon, 18 Nov 2024 17:04:28 +0100 Subject: [PATCH 07/43] small changes --- .../config/model/graphtransformer.yaml | 38 +++++++++---------- src/anemoi/training/data/datamodule.py | 2 +- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/anemoi/training/config/model/graphtransformer.yaml b/src/anemoi/training/config/model/graphtransformer.yaml index 5c2e819a..71ffc9d0 100644 --- a/src/anemoi/training/config/model/graphtransformer.yaml +++ b/src/anemoi/training/config/model/graphtransformer.yaml @@ -52,25 +52,25 @@ attributes: node_loss_weight: area_weight -# Bounding configuration -bounding: #These are applied in order +# # Bounding configuration +# bounding: #These are applied in order - # Bound tp (total precipitation) with a Relu bounding layer - # ensuring a range of [0, infinity) to avoid negative precipitation values. - - _target_: anemoi.models.layers.bounding.ReluBounding #[0, infinity) - variables: - - tp +# # Bound tp (total precipitation) with a Relu bounding layer +# # ensuring a range of [0, infinity) to avoid negative precipitation values. +# - _target_: anemoi.models.layers.bounding.ReluBounding #[0, infinity) +# variables: +# - tp - # [OPTIONAL] Bound cp (convective precipitation) as a fraction of tp. - # This guarantees that cp is physically consistent with tp by restricting cp - # to a fraction of tp [0 to 1]. Uncomment the lines below to apply. - # NOTE: If this bounding strategy is used, the normalization of cp must be - # changed to "std" normalization, and the "cp" statistics should be remapped - # to those of tp to ensure consistency. +# # [OPTIONAL] Bound cp (convective precipitation) as a fraction of tp. +# # This guarantees that cp is physically consistent with tp by restricting cp +# # to a fraction of tp [0 to 1]. Uncomment the lines below to apply. +# # NOTE: If this bounding strategy is used, the normalization of cp must be +# # changed to "std" normalization, and the "cp" statistics should be remapped +# # to those of tp to ensure consistency. 
- # - _target_: anemoi.models.layers.bounding.FractionBounding # fraction of tp - # variables: - # - cp - # min_val: 0 - # max_val: 1 - # total_var: tp +# # - _target_: anemoi.models.layers.bounding.FractionBounding # fraction of tp +# # variables: +# # - cp +# # min_val: 0 +# # max_val: 1 +# # total_var: tp diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 303266fc..a5443c40 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -188,7 +188,7 @@ def _get_dataset( shuffle=shuffle, label=label, ) - self._check_resolution(data.resolution) + # self._check_resolution(data.resolution) return data def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader: From db2a14fa18eabad213363c79424fd721520655af Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 26 Nov 2024 11:56:49 +0100 Subject: [PATCH 08/43] Imputer changes --- src/anemoi/training/train/forecaster.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 277c1602..c0684b80 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -437,31 +437,19 @@ def calculate_val_metrics( validation metrics and predictions """ metrics = {} -<<<<<<< HEAD y_preds = [] # Added to impute nans nan_locations = torch.isnan(y[..., self.data_indices.internal_data.output.full]) self.model.post_processors.processors['imputer'].set_nan_locations(nan_locations) - print("y in val") - print(y.shape, nan_locations.shape) - y_postprocessed = self.model.post_processors(y, in_place=False) y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False) for mkey, indices in self.metric_ranges_validation.items(): - print(indices, y_pred_postprocessed.shape, y_postprocessed.shape) - for idx in indices: - print("trying: ", idx, y_pred_postprocessed[..., idx].shape) - print(y_postprocessed[..., indices].shape) metrics[f"{mkey}_{rollout_step + 1}"] = self.metrics( y_pred_postprocessed[..., indices], y_postprocessed[..., indices], ) -======= - y_postprocessed = self.model.post_processors(y, in_place=False) - y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False) ->>>>>>> develop for metric in self.metrics: metric_name = getattr(metric, "name", metric.__class__.__name__.lower()) From 57f9026b1d7744a0ce96556d00680ef16f15b165 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 26 Nov 2024 15:01:20 +0100 Subject: [PATCH 09/43] Added sanification of checkpoint, effective batch size, git pre commit --- .../training/config/training/default.yaml | 1 + src/anemoi/training/data/datamodule.py | 21 +++++----- src/anemoi/training/data/dataset.py | 8 ++-- src/anemoi/training/train/forecaster.py | 21 ++++++---- src/anemoi/training/train/train.py | 41 +++++++++++++------ src/anemoi/training/utils/checkpoint.py | 28 +++++++++++++ 6 files changed, 84 insertions(+), 36 deletions(-) diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index b471034e..9c962dcf 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -2,6 +2,7 @@ run_id: null fork_run_id: null load_weights_only: null # only load model weights, do not restore optimiser states etc. 
+transfer_learning: null # activate to perform transfer learning # run in deterministic mode ; slows down deterministic: False diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index be66d8ec..7a8a0b6d 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -82,7 +82,6 @@ def __init__(self, config: DictConfig) -> None: self.config.dataloader.validation.start - 1, ) self.config.dataloader.training.end = self.config.dataloader.validation.start - 1 - if not self.config.dataloader.get("pin_memory", True): LOGGER.info("Data loader memory pinning disabled.") @@ -177,15 +176,17 @@ def _get_dataset( rollout: int = 1, label: str = "generic", ) -> NativeGridDataset: - + r = max(rollout, self.rollout) - # Compute effective batch size - effective_bs = self.config.dataloader.batch_size['training'] *\ - self.config.hardware.num_gpus_per_node *\ - self.config.hardware.num_nodes //\ - self.config.hardware.num_gpus_per_model - + # Compute effective batch size + effective_bs = ( + self.config.dataloader.batch_size["training"] + * self.config.hardware.num_gpus_per_node + * self.config.hardware.num_nodes + // self.config.hardware.num_gpus_per_model + ) + data = NativeGridDataset( data_reader=data_reader, rollout=r, @@ -196,9 +197,9 @@ def _get_dataset( model_comm_num_groups=self.model_comm_num_groups, shuffle=shuffle, label=label, - effective_bs=effective_bs + effective_bs=effective_bs, ) - # self._check_resolution(data.resolution) + self._check_resolution(data.resolution) return data def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader: diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index cdf1e97b..a8075994 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -14,7 +14,6 @@ import random from functools import cached_property from typing import Callable -from omegaconf import DictConfig import numpy as np import torch @@ -42,7 +41,7 @@ def __init__( model_comm_num_groups: int = 1, shuffle: bool = True, label: str = "generic", - effective_bs: int = 1 + effective_bs: int = 1, ) -> None: """Initialize (part of) the dataset state. 
@@ -66,7 +65,8 @@ def __init__( Shuffle batches, by default True label : str, optional label for the dataset, by default "generic" - + effective_bs : int, default 1 + effective batch size useful to compute the lenght of the dataset """ self.label = label self.effective_bs = effective_bs @@ -250,7 +250,7 @@ def __repr__(self) -> str: Multistep: {self.multi_step} Timeincrement: {self.timeincrement} """ - + def __len__(self) -> int: """Estimate the total number of samples based on valid indices.""" return len(self.valid_date_indices) // self.effective_bs diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index c0684b80..8df49bc2 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -120,8 +120,13 @@ def __init__( self.loss.register_full_backward_hook(grad_scaler, prepend=False) self.multi_step = config.training.multistep_input - self.lr = config.hardware.num_nodes * config.hardware.num_gpus_per_node * config.training.lr.rate / config.hardware.num_gpus_per_model - + self.lr = ( + config.hardware.num_nodes + * config.hardware.num_gpus_per_node + * config.training.lr.rate + / config.hardware.num_gpus_per_model + ) + self.lr_iterations = config.training.lr.iterations self.lr_min = config.training.lr.min self.rollout = config.training.rollout.start @@ -376,8 +381,7 @@ def rollout_step( y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.internal_data.output.full] # y includes the auxiliary variables, so we must leave those out when computing the loss - tmp_loss = checkpoint(self.loss, y_pred, y, use_reentrant=False) - loss += tmp_loss + loss = checkpoint(self.loss, y_pred, y, use_reentrant=False) if training_mode else None x = self.advance_input(x, y_pred, batch, rollout_step) @@ -437,11 +441,10 @@ def calculate_val_metrics( validation metrics and predictions """ metrics = {} - y_preds = [] - + # Added to impute nans nan_locations = torch.isnan(y[..., self.data_indices.internal_data.output.full]) - self.model.post_processors.processors['imputer'].set_nan_locations(nan_locations) + self.model.post_processors.processors["imputer"].set_nan_locations(nan_locations) y_postprocessed = self.model.post_processors(y, in_place=False) y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False) @@ -514,7 +517,7 @@ def on_train_epoch_end(self) -> None: self.rollout = min(self.rollout, self.rollout_max) def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: - + with torch.no_grad(): val_loss, metrics, y_preds = self._step(batch, batch_idx, validation_mode=True) @@ -529,7 +532,7 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: sync_dist=True, ) - for i, (mname, mvalue) in enumerate(metrics.items()): + for mname, mvalue in metrics.items(): self.log( "val_" + mname, mvalue, diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 8c3cbe34..a01ab9b6 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -34,6 +34,7 @@ from anemoi.training.diagnostics.logger import get_wandb_logger from anemoi.training.distributed.strategy import DDPGroupStrategy from anemoi.training.train.forecaster import GraphForecaster +from anemoi.training.utils.checkpoint import sanify_checkpoint from anemoi.training.utils.jsonify import map_config_to_primitives from anemoi.training.utils.seeding import get_base_seed @@ -42,6 +43,7 @@ LOGGER = logging.getLogger(__name__) + class AnemoiTrainer: """Utility class 
for training the model.""" @@ -61,8 +63,13 @@ def __init__(self, config: DictConfig) -> None: OmegaConf.resolve(config) self.config = config - # Default to not warm-starting from a checkpoint - self.start_from_checkpoint = (bool(self.config.training.run_id) or bool(self.config.training.fork_run_id)) and self.config.training.resume + # Set Transfer Learning based on the other if not provided + if self.config.training.transfer_learning is None: + self.config.training.transfer_learning = ( + bool(self.config.training.run_id) or bool(self.config.training.fork_run_id) + ) and self.load_weights_only + + self.start_from_checkpoint = bool(self.config.training.run_id) or bool(self.config.training.fork_run_id) self.load_weights_only = config.training.load_weights_only self.parent_uuid = None @@ -82,9 +89,7 @@ def datamodule(self) -> AnemoiDatasetsDataModule: """DataModule instance and DataSets.""" datamodule = AnemoiDatasetsDataModule(self.config) self.config.data.num_features = len(datamodule.ds_train.data.variables) - LOGGER.info( - f"Data has {len(datamodule.ds_train.data.variables)} variables: {datamodule.ds_train.data.variables}" - ) + LOGGER.info("Data has ", len(datamodule.ds_train.data.variables), " variables: ", datamodule.ds_train.data.variables) return datamodule @cached_property @@ -146,11 +151,22 @@ def model(self) -> GraphForecaster: "metadata": self.metadata, "statistics": self.datamodule.statistics, } + + model = GraphForecaster(**kwargs) + if self.load_weights_only: + # Sanify the checkpoint for transfer learning + if self.config.training.transfer_learning: + save_path = Path( + self.config.hardware.paths.checkpoints.parent, + (self.fork_run_server2server or self.config.training.fork_run_id) or self.lineage_run, + ) + self.last_checkpoint = sanify_checkpoint(model, self.last_checkpoint, save_path) + LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) - return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) + return model.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) - return GraphForecaster(**kwargs) + return model @rank_zero_only def _get_mlflow_run_id(self) -> str: @@ -200,9 +216,9 @@ def last_checkpoint(self) -> str | None: checkpoint = Path( self.config.hardware.paths.checkpoints.parent, fork_id or self.lineage_run, - self.config.hardware.files.warm_start or "transfer.ckpt" or "last.ckpt", + self.config.hardware.files.warm_start or "last.ckpt", ) - + # Check if the last checkpoint exists if Path(checkpoint).exists(): LOGGER.info("Resuming training from last checkpoint: %s", checkpoint) @@ -297,7 +313,7 @@ def _log_information(self) -> None: total_number_of_model_instances = int( self.config.hardware.num_nodes * self.config.hardware.num_gpus_per_node - / self.config.hardware.num_gpus_per_model + / self.config.hardware.num_gpus_per_model, ) LOGGER.debug( @@ -355,8 +371,7 @@ def strategy(self) -> DDPGroupStrategy: def train(self) -> None: """Training entry point.""" - - print('Setting up trainer..') + LOGGER.debug("Setting up trainer..") trainer = pl.Trainer( accelerator=self.accelerator, @@ -384,7 +399,7 @@ def train(self) -> None: enable_progress_bar=self.config.diagnostics.enable_progress_bar, ) - print('Starting training..') + LOGGER.debug("Starting training..") trainer.fit( self.model, diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index ddb5a1c8..92759f6d 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py 
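In the trainer changes above, `transfer_learning` defaults to true exactly when a parent run is given and only its weights are to be loaded. A small sketch of the resulting decision table, under that assumption (the function and mode names are illustrative, not part of the codebase):

    def startup_mode(run_id: str | None, fork_run_id: str | None,
                     load_weights_only: bool, transfer_learning: bool | None) -> str:
        has_parent = bool(run_id) or bool(fork_run_id)
        if transfer_learning is None:
            # Default introduced by this patch: infer transfer learning when a
            # parent run is given but only its weights should be restored.
            transfer_learning = has_parent and load_weights_only
        if not has_parent:
            return "scratch"
        if transfer_learning:
            return "transfer_learning"   # sanify checkpoint, drop mismatched layers
        if load_weights_only:
            return "weights_only"        # load_from_checkpoint(strict=False)
        return "resume"                  # full resume incl. optimiser state


    assert startup_mode(None, None, False, None) == "scratch"
    assert startup_mode(None, "abc", True, None) == "transfer_learning"
    assert startup_mode("abc", None, False, None) == "resume"
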
@@ -10,6 +10,7 @@ from __future__ import annotations +import logging from pathlib import Path import torch @@ -17,6 +18,8 @@ from anemoi.training.train.forecaster import GraphForecaster +LOGGER = logging.getLogger(__name__) + def load_and_prepare_model(lightning_checkpoint_path: str) -> tuple[torch.nn.Module, dict]: """Load the lightning checkpoint and extract the pytorch model and its metadata. @@ -65,3 +68,28 @@ def save_inference_checkpoint(model: torch.nn.Module, metadata: dict, save_path: torch.save(model, inference_filepath) save_metadata(inference_filepath, metadata) return inference_filepath + + +def sanify_checkpoint(model: torch.nn.Module, ckpt_path: Path | str, save_path: Path | str) -> Path: + + # Load the checkpoint + checkpoint = torch.load(ckpt_path, map_location=model.device) + + # Filter out layers with size mismatch + state_dict = checkpoint["state_dict"] + + model_state_dict = model.state_dict() + + for key in state_dict.copy(): + if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape: + LOGGER.debug("Skipping loading parameter: {}, checkpoint shape: {}, model shape: {}".format( + key, state_dict[key].shape, model_state_dict[key].shape + ) + ) + del state_dict[key] # Remove the mismatched key + + new_ckpt_path = Path(save_path, "transfer.ckpt") + LOGGER.info("Saved modified checkpoint at", new_ckpt_path) + torch.save(checkpoint, new_ckpt_path) + + return new_ckpt_path From 039c16f965a8db771ecf98196e776a4c985bc164 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 26 Nov 2024 15:03:08 +0100 Subject: [PATCH 10/43] gpc --- src/anemoi/training/train/train.py | 7 ++++++- src/anemoi/training/utils/checkpoint.py | 7 ++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index a01ab9b6..51a520b2 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -89,7 +89,12 @@ def datamodule(self) -> AnemoiDatasetsDataModule: """DataModule instance and DataSets.""" datamodule = AnemoiDatasetsDataModule(self.config) self.config.data.num_features = len(datamodule.ds_train.data.variables) - LOGGER.info("Data has ", len(datamodule.ds_train.data.variables), " variables: ", datamodule.ds_train.data.variables) + LOGGER.info( + "Data has ", + len(datamodule.ds_train.data.variables), + " variables: ", + datamodule.ds_train.data.variables, + ) return datamodule @cached_property diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index 92759f6d..ba117df8 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -82,9 +82,10 @@ def sanify_checkpoint(model: torch.nn.Module, ckpt_path: Path | str, save_path: for key in state_dict.copy(): if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape: - LOGGER.debug("Skipping loading parameter: {}, checkpoint shape: {}, model shape: {}".format( - key, state_dict[key].shape, model_state_dict[key].shape - ) + LOGGER.debug( + "Skipping loading parameter: ", key, + ", checkpoint shape: ", state_dict[key].shape, + ", model shape: ", model_state_dict[key].shape, ) del state_dict[key] # Remove the mismatched key From 463c6a9f8476efe0a5d2d3d67e7183a5eaee5b81 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 26 Nov 2024 15:08:51 +0100 Subject: [PATCH 11/43] gpc --- src/anemoi/training/utils/checkpoint.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git 
a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index ba117df8..e5f4d7e5 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -83,9 +83,12 @@ def sanify_checkpoint(model: torch.nn.Module, ckpt_path: Path | str, save_path: for key in state_dict.copy(): if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape: LOGGER.debug( - "Skipping loading parameter: ", key, - ", checkpoint shape: ", state_dict[key].shape, - ", model shape: ", model_state_dict[key].shape, + "Skipping loading parameter: ", + key, + ", checkpoint shape: ", + state_dict[key].shape, + ", model shape: ", + model_state_dict[key].shape, ) del state_dict[key] # Remove the mismatched key From c6d751905fc81cf6307e92c75e2513be26198573 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Wed, 27 Nov 2024 16:53:42 +0100 Subject: [PATCH 12/43] New implementation: do not store modified checkpoint, load it directly after changing it --- .../config/model/graphtransformer.yaml | 38 +++++++++---------- src/anemoi/training/train/forecaster.py | 14 +++---- src/anemoi/training/train/train.py | 18 ++++----- src/anemoi/training/utils/checkpoint.py | 10 ++--- 4 files changed, 39 insertions(+), 41 deletions(-) diff --git a/src/anemoi/training/config/model/graphtransformer.yaml b/src/anemoi/training/config/model/graphtransformer.yaml index 71ffc9d0..5c2e819a 100644 --- a/src/anemoi/training/config/model/graphtransformer.yaml +++ b/src/anemoi/training/config/model/graphtransformer.yaml @@ -52,25 +52,25 @@ attributes: node_loss_weight: area_weight -# # Bounding configuration -# bounding: #These are applied in order +# Bounding configuration +bounding: #These are applied in order -# # Bound tp (total precipitation) with a Relu bounding layer -# # ensuring a range of [0, infinity) to avoid negative precipitation values. -# - _target_: anemoi.models.layers.bounding.ReluBounding #[0, infinity) -# variables: -# - tp + # Bound tp (total precipitation) with a Relu bounding layer + # ensuring a range of [0, infinity) to avoid negative precipitation values. + - _target_: anemoi.models.layers.bounding.ReluBounding #[0, infinity) + variables: + - tp -# # [OPTIONAL] Bound cp (convective precipitation) as a fraction of tp. -# # This guarantees that cp is physically consistent with tp by restricting cp -# # to a fraction of tp [0 to 1]. Uncomment the lines below to apply. -# # NOTE: If this bounding strategy is used, the normalization of cp must be -# # changed to "std" normalization, and the "cp" statistics should be remapped -# # to those of tp to ensure consistency. + # [OPTIONAL] Bound cp (convective precipitation) as a fraction of tp. + # This guarantees that cp is physically consistent with tp by restricting cp + # to a fraction of tp [0 to 1]. Uncomment the lines below to apply. + # NOTE: If this bounding strategy is used, the normalization of cp must be + # changed to "std" normalization, and the "cp" statistics should be remapped + # to those of tp to ensure consistency. 
-# # - _target_: anemoi.models.layers.bounding.FractionBounding # fraction of tp -# # variables: -# # - cp -# # min_val: 0 -# # max_val: 1 -# # total_var: tp + # - _target_: anemoi.models.layers.bounding.FractionBounding # fraction of tp + # variables: + # - cp + # min_val: 0 + # max_val: 1 + # total_var: tp diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 38023db5..4ec5a963 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -529,18 +529,15 @@ def calculate_val_metrics( validation metrics and predictions """ metrics = {} - # Added to impute nans + print(self.data_indices.internal_data.output.full) + print(y.shape) + print(y[..., self.data_indices.internal_data.output.full].shape) nan_locations = torch.isnan(y[..., self.data_indices.internal_data.output.full]) + print(nan_locations) self.model.post_processors.processors["imputer"].set_nan_locations(nan_locations) - y_postprocessed = self.model.post_processors(y, in_place=False) y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False) - for mkey, indices in self.metric_ranges_validation.items(): - metrics[f"{mkey}_{rollout_step + 1}"] = self.metrics( - y_pred_postprocessed[..., indices], - y_postprocessed[..., indices], - ) for metric in self.metrics: metric_name = getattr(metric, "name", metric.__class__.__name__.lower()) @@ -554,6 +551,9 @@ def calculate_val_metrics( continue for mkey, indices in self.val_metric_ranges.items(): + print(indices) + print(y_pred_postprocessed) + indices.to(y_postprocessed.device) metrics[f"{metric_name}/{mkey}/{rollout_step + 1}"] = metric( y_pred_postprocessed[..., indices], y_postprocessed[..., indices], diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 2964411d..b72a2389 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -33,7 +33,7 @@ from anemoi.training.diagnostics.logger import get_wandb_logger from anemoi.training.distributed.strategy import DDPGroupStrategy from anemoi.training.train.forecaster import GraphForecaster -from anemoi.training.utils.checkpoint import sanify_checkpoint +from anemoi.training.utils.checkpoint import transfer_learning_loading from anemoi.training.utils.jsonify import map_config_to_primitives from anemoi.training.utils.seeding import get_base_seed @@ -89,10 +89,8 @@ def datamodule(self) -> AnemoiDatasetsDataModule: datamodule = AnemoiDatasetsDataModule(self.config) self.config.data.num_features = len(datamodule.ds_train.data.variables) LOGGER.info( - "Data has ", - len(datamodule.ds_train.data.variables), - " variables: ", - datamodule.ds_train.data.variables, + "Data has {} variables: {}".format(len(datamodule.ds_train.data.variables), + datamodule.ds_train.data.variables), ) return datamodule @@ -165,9 +163,12 @@ def model(self) -> GraphForecaster: self.config.hardware.paths.checkpoints.parent, (self.fork_run_server2server or self.config.training.fork_run_id) or self.lineage_run, ) - self.last_checkpoint = sanify_checkpoint(model, self.last_checkpoint, save_path) + LOGGER.info("Learning weights with Transfer Learning from %s", self.last_checkpoint) + + return transfer_learning_loading(model, self.last_checkpoint) LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) + return model.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) return model @@ -314,7 +315,7 @@ def _log_information(self) -> None: LOGGER.debug("Total number of auxiliary 
variables: %d", len(self.config.data.forcing)) # Log learning rate multiplier when running single-node, multi-GPU and/or multi-node - total_number_of_model_instances = int( + total_number_of_model_instances = ( self.config.hardware.num_nodes * self.config.hardware.num_gpus_per_node / self.config.hardware.num_gpus_per_model, @@ -420,8 +421,7 @@ def train(self) -> None: @hydra.main(version_base=None, config_path="../config", config_name="config") def main(config: DictConfig) -> None: - trainer = AnemoiTrainer(config) - trainer.train() + AnemoiTrainer(config).train() if __name__ == "__main__": diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index e5f4d7e5..935cae4f 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -14,6 +14,7 @@ from pathlib import Path import torch +import torch.nn as nn from anemoi.utils.checkpoints import save_metadata from anemoi.training.train.forecaster import GraphForecaster @@ -70,7 +71,7 @@ def save_inference_checkpoint(model: torch.nn.Module, metadata: dict, save_path: return inference_filepath -def sanify_checkpoint(model: torch.nn.Module, ckpt_path: Path | str, save_path: Path | str) -> Path: +def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> nn.Module: # Load the checkpoint checkpoint = torch.load(ckpt_path, map_location=model.device) @@ -92,8 +93,5 @@ def sanify_checkpoint(model: torch.nn.Module, ckpt_path: Path | str, save_path: ) del state_dict[key] # Remove the mismatched key - new_ckpt_path = Path(save_path, "transfer.ckpt") - LOGGER.info("Saved modified checkpoint at", new_ckpt_path) - torch.save(checkpoint, new_ckpt_path) - - return new_ckpt_path + # Load the filtered state_dict into the model + return model.load_state_dict(state_dict, strict=False) From bca035523d1138f8562ec5ee01b662f55c9cfa36 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Thu, 28 Nov 2024 15:39:50 +0100 Subject: [PATCH 13/43] Added logging --- src/anemoi/training/train/forecaster.py | 15 ++++----------- src/anemoi/training/train/train.py | 24 ++++++++++-------------- 2 files changed, 14 insertions(+), 25 deletions(-) diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 4ec5a963..92685854 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -530,11 +530,7 @@ def calculate_val_metrics( """ metrics = {} # Added to impute nans - print(self.data_indices.internal_data.output.full) - print(y.shape) - print(y[..., self.data_indices.internal_data.output.full].shape) - nan_locations = torch.isnan(y[..., self.data_indices.internal_data.output.full]) - print(nan_locations) + nan_locations = torch.isnan(y) self.model.post_processors.processors["imputer"].set_nan_locations(nan_locations) y_postprocessed = self.model.post_processors(y, in_place=False) y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False) @@ -551,13 +547,10 @@ def calculate_val_metrics( continue for mkey, indices in self.val_metric_ranges.items(): - print(indices) - print(y_pred_postprocessed) - indices.to(y_postprocessed.device) metrics[f"{metric_name}/{mkey}/{rollout_step + 1}"] = metric( - y_pred_postprocessed[..., indices], - y_postprocessed[..., indices], - scalar_indices=[..., indices] if -1 in metric.scalar else None, + y_pred_postprocessed[...,], + y_postprocessed[...], + scalar_indices=[...] 
if -1 in metric.scalar else None, ) return metrics diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index b72a2389..4a8d8c1e 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -41,6 +41,7 @@ from torch_geometric.data import HeteroData LOGGER = logging.getLogger(__name__) +LOGGER.setLevel(logging.DEBUG) # Change DEBUG to INFO, WARNING, etc., as needed class AnemoiTrainer: @@ -66,10 +67,10 @@ def __init__(self, config: DictConfig) -> None: if self.config.training.transfer_learning is None: self.config.training.transfer_learning = ( bool(self.config.training.run_id) or bool(self.config.training.fork_run_id) - ) and self.load_weights_only + ) and self.config.training.load_weights_only self.start_from_checkpoint = bool(self.config.training.run_id) or bool(self.config.training.fork_run_id) - self.load_weights_only = config.training.load_weights_only + self.load_weights_only = self.config.training.load_weights_only self.parent_uuid = None self.config.training.run_id = self.run_id @@ -138,6 +139,8 @@ def graph_data(self) -> HeteroData: from anemoi.graphs.create import GraphCreator + LOGGER.info("Generating graph data from scratch..") + return GraphCreator(config=self.config.graph).create( save_path=graph_filename, overwrite=self.config.graph.overwrite, @@ -159,19 +162,16 @@ def model(self) -> GraphForecaster: if self.load_weights_only: # Sanify the checkpoint for transfer learning if self.config.training.transfer_learning: - save_path = Path( - self.config.hardware.paths.checkpoints.parent, - (self.fork_run_server2server or self.config.training.fork_run_id) or self.lineage_run, - ) LOGGER.info("Learning weights with Transfer Learning from %s", self.last_checkpoint) - return transfer_learning_loading(model, self.last_checkpoint) LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) return model.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) - return model + else: + LOGGER.info("Model initialised from scratch.") + return model @rank_zero_only def _get_mlflow_run_id(self) -> str: @@ -315,17 +315,13 @@ def _log_information(self) -> None: LOGGER.debug("Total number of auxiliary variables: %d", len(self.config.data.forcing)) # Log learning rate multiplier when running single-node, multi-GPU and/or multi-node - total_number_of_model_instances = ( - self.config.hardware.num_nodes - * self.config.hardware.num_gpus_per_node - / self.config.hardware.num_gpus_per_model, - ) + total_number_of_model_instances = self.config.hardware.num_nodes * self.config.hardware.num_gpus_per_node / self.config.hardware.num_gpus_per_model LOGGER.debug( "Total GPU count / model group size: %d - NB: the learning rate will be scaled by this factor!", total_number_of_model_instances, ) - LOGGER.debug("Effective learning rate: %.3e", total_number_of_model_instances * self.config.training.lr.rate) + LOGGER.debug("Effective learning rate: %.3e", int(total_number_of_model_instances) * self.config.training.lr.rate) LOGGER.debug("Rollout window length: %d", self.config.training.rollout.start) if self.config.training.max_epochs is not None and self.config.training.max_steps not in (None, -1): From 7894cc032c9f169f478a46ed8b3f30cc00566f4b Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Fri, 29 Nov 2024 17:23:20 +0100 Subject: [PATCH 14/43] Transfer learning working: implemented checkpoint cleaning with large models --- src/anemoi/training/utils/checkpoint.py | 36 ++++++++++++++++--------- 1 file changed, 24 insertions(+), 
12 deletions(-) diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index 935cae4f..980cdffb 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -20,7 +20,7 @@ from anemoi.training.train.forecaster import GraphForecaster LOGGER = logging.getLogger(__name__) - +LOGGER.setLevel("DEBUG") def load_and_prepare_model(lightning_checkpoint_path: str) -> tuple[torch.nn.Module, dict]: """Load the lightning checkpoint and extract the pytorch model and its metadata. @@ -74,7 +74,25 @@ def save_inference_checkpoint(model: torch.nn.Module, metadata: dict, save_path: def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> nn.Module: # Load the checkpoint - checkpoint = torch.load(ckpt_path, map_location=model.device) + try: + checkpoint = torch.load(ckpt_path, map_location=model.device) + + # TODO: this is a patch for issue #57 + except RuntimeError: + LOGGER.debug("Need to remove metadata from the checkpoint file due to issue #57..") + import subprocess + file_to_delete = "archive/anemoi-metadata/ai-models.json" + # Construct and execute the command + command = ["zip", "-d", ckpt_path, file_to_delete] + result = subprocess.run(command, capture_output=True, text=True) + + # Check the result + if result.returncode == 0: + LOGGER.debug("File successfully removed from the zip archive.") + else: + LOGGER.debug("Error occurred: {}".format(result.stderr)) + + checkpoint = torch.load(ckpt_path, map_location=model.device) # Filter out layers with size mismatch state_dict = checkpoint["state_dict"] @@ -83,15 +101,9 @@ def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> for key in state_dict.copy(): if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape: - LOGGER.debug( - "Skipping loading parameter: ", - key, - ", checkpoint shape: ", - state_dict[key].shape, - ", model shape: ", - model_state_dict[key].shape, - ) + LOGGER.debug("Skipping loading parameter: {}, checkpoint shape: {}, model shape: {}".format(str(key), str(state_dict[key].shape), str(model_state_dict[key].shape))) del state_dict[key] # Remove the mismatched key - # Load the filtered state_dict into the model - return model.load_state_dict(state_dict, strict=False) + # Load the filtered st-ate_dict into the model + model.load_state_dict(state_dict, strict=False) + return model \ No newline at end of file From eff45396646bccd4e7a259215539337865d43d8a Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 3 Dec 2024 09:39:36 +0100 Subject: [PATCH 15/43] Reverted some changes concerning imputer issues --- src/anemoi/training/train/forecaster.py | 15 ++++++++++----- src/anemoi/training/train/train.py | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 92685854..35bace07 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -530,8 +530,8 @@ def calculate_val_metrics( """ metrics = {} # Added to impute nans - nan_locations = torch.isnan(y) - self.model.post_processors.processors["imputer"].set_nan_locations(nan_locations) + # nan_locations = torch.isnan(y) + # self.model.post_processors.processors["imputer"].set_nan_locations(nan_locations) y_postprocessed = self.model.post_processors(y, in_place=False) y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False) @@ -547,10 +547,15 @@ def calculate_val_metrics( continue for 
mkey, indices in self.val_metric_ranges.items(): + # metrics[f"{metric_name}/{mkey}/{rollout_step + 1}"] = metric( + # y_pred_postprocessed[...,], + # y_postprocessed[...], + # scalar_indices=[...] if -1 in metric.scalar else None, + # ) metrics[f"{metric_name}/{mkey}/{rollout_step + 1}"] = metric( - y_pred_postprocessed[...,], - y_postprocessed[...], - scalar_indices=[...] if -1 in metric.scalar else None, + y_pred_postprocessed[..., indices], + y_postprocessed[..., indices], + scalar_indices=[..., indices] if -1 in metric.scalar else None, ) return metrics diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index b8f897ef..ec3f8a09 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -162,7 +162,7 @@ def model(self) -> GraphForecaster: if self.load_weights_only: # Sanify the checkpoint for transfer learning if self.config.training.transfer_learning: - LOGGER.info("Learning weights with Transfer Learning from %s", self.last_checkpoint) + LOGGER.info("Loading weights with Transfer Learning from %s", self.last_checkpoint) return transfer_learning_loading(model, self.last_checkpoint) LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) From c1f854f111da7537573dfa320aad159422e7c63f Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 3 Dec 2024 09:39:51 +0100 Subject: [PATCH 16/43] Reverted some changes concerning imputer issues --- src/anemoi/training/train/forecaster.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 35bace07..a58fe327 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -529,9 +529,6 @@ def calculate_val_metrics( validation metrics and predictions """ metrics = {} - # Added to impute nans - # nan_locations = torch.isnan(y) - # self.model.post_processors.processors["imputer"].set_nan_locations(nan_locations) y_postprocessed = self.model.post_processors(y, in_place=False) y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False) @@ -547,11 +544,6 @@ def calculate_val_metrics( continue for mkey, indices in self.val_metric_ranges.items(): - # metrics[f"{metric_name}/{mkey}/{rollout_step + 1}"] = metric( - # y_pred_postprocessed[...,], - # y_postprocessed[...], - # scalar_indices=[...] 
if -1 in metric.scalar else None, - # ) metrics[f"{metric_name}/{mkey}/{rollout_step + 1}"] = metric( y_pred_postprocessed[..., indices], y_postprocessed[..., indices], From 338387d934ee8bca0f5dcc4ba9484346b9000aea Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 3 Dec 2024 09:57:57 +0100 Subject: [PATCH 17/43] Cleaned code for final review --- src/anemoi/training/train/train.py | 22 +++++----- src/anemoi/training/utils/checkpoint.py | 53 ++++++++++++++++++------- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index ec3f8a09..157dd7cd 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -90,10 +90,8 @@ def datamodule(self) -> AnemoiDatasetsDataModule: """DataModule instance and DataSets.""" datamodule = AnemoiDatasetsDataModule(self.config) self.config.data.num_features = len(datamodule.ds_train.data.variables) - LOGGER.info( - "Data has {} variables: {}".format(len(datamodule.ds_train.data.variables), - datamodule.ds_train.data.variables), - ) + LOGGER.info("Number of data variables: ", len(datamodule.ds_train.data.variables)) + LOGGER.debug("Variables: ", datamodule.ds_train.data.variables) return datamodule @cached_property @@ -169,9 +167,8 @@ def model(self) -> GraphForecaster: return model.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) - else: - LOGGER.info("Model initialised from scratch.") - return model + LOGGER.info("Model initialised from scratch.") + return model @rank_zero_only def _get_mlflow_run_id(self) -> str: @@ -315,13 +312,20 @@ def _log_information(self) -> None: LOGGER.debug("Total number of auxiliary variables: %d", len(self.config.data.forcing)) # Log learning rate multiplier when running single-node, multi-GPU and/or multi-node - total_number_of_model_instances = self.config.hardware.num_nodes * self.config.hardware.num_gpus_per_node / self.config.hardware.num_gpus_per_model + total_number_of_model_instances = ( + self.config.hardware.num_nodes + * self.config.hardware.num_gpus_per_node + / self.config.hardware.num_gpus_per_model + ) LOGGER.debug( "Total GPU count / model group size: %d - NB: the learning rate will be scaled by this factor!", total_number_of_model_instances, ) - LOGGER.debug("Effective learning rate: %.3e", int(total_number_of_model_instances) * self.config.training.lr.rate) + LOGGER.debug( + "Effective learning rate: %.3e", + int(total_number_of_model_instances) * self.config.training.lr.rate, + ) LOGGER.debug("Rollout window length: %d", self.config.training.rollout.start) if self.config.training.max_epochs is not None and self.config.training.max_steps not in (None, -1): diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index 980cdffb..9ed015b0 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -7,10 +7,10 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
- from __future__ import annotations import logging +import zipfile from pathlib import Path import torch @@ -20,7 +20,36 @@ from anemoi.training.train.forecaster import GraphForecaster LOGGER = logging.getLogger(__name__) -LOGGER.setLevel("DEBUG") + + +def remove_file_from_zip(zip_path: Path | str, file_to_remove: Path | str) -> None: + try: + temp_zip_path = f"{zip_path}.temp" + file_removed = False + + # Open the existing ZIP file and create a new one + with zipfile.ZipFile(zip_path, "r") as src_zip, zipfile.ZipFile(temp_zip_path, "w") as dest_zip: + for item in src_zip.infolist(): + if item.filename != file_to_remove: + dest_zip.writestr(item, src_zip.read(item.filename)) + else: + file_removed = True + + # Replace the old ZIP file with the new one + Path.replace(temp_zip_path, zip_path) + + # Check the result + if file_removed: + LOGGER.debug("File successfully removed from the zip archive.") + else: + LOGGER.debug("File not found in the zip archive.") + + except FileNotFoundError: + LOGGER.exception("Error occurred while modifying the zip archive.") + # Clean up the temporary file in case of an error + if Path.exists(temp_zip_path): + Path.unlink(temp_zip_path) + def load_and_prepare_model(lightning_checkpoint_path: str) -> tuple[torch.nn.Module, dict]: """Load the lightning checkpoint and extract the pytorch model and its metadata. @@ -77,20 +106,13 @@ def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> try: checkpoint = torch.load(ckpt_path, map_location=model.device) - # TODO: this is a patch for issue #57 + # TODO @icedoom888: this is a patch for issue #57 except RuntimeError: LOGGER.debug("Need to remove metadata from the checkpoint file due to issue #57..") - import subprocess + file_to_delete = "archive/anemoi-metadata/ai-models.json" # Construct and execute the command - command = ["zip", "-d", ckpt_path, file_to_delete] - result = subprocess.run(command, capture_output=True, text=True) - - # Check the result - if result.returncode == 0: - LOGGER.debug("File successfully removed from the zip archive.") - else: - LOGGER.debug("Error occurred: {}".format(result.stderr)) + remove_file_from_zip(ckpt_path, file_to_delete) checkpoint = torch.load(ckpt_path, map_location=model.device) @@ -101,9 +123,12 @@ def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> for key in state_dict.copy(): if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape: - LOGGER.debug("Skipping loading parameter: {}, checkpoint shape: {}, model shape: {}".format(str(key), str(state_dict[key].shape), str(model_state_dict[key].shape))) + LOGGER.debug("Skipping loading parameter: ", key) + LOGGER.debug("Checkpoint shape: ", state_dict[key].shape) + LOGGER.debug("Model shape: ", model_state_dict[key].shape) + del state_dict[key] # Remove the mismatched key # Load the filtered st-ate_dict into the model model.load_state_dict(state_dict, strict=False) - return model \ No newline at end of file + return model From f739bf466ff9ec319639c2da38d43fd0879265b9 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 3 Dec 2024 10:08:45 +0100 Subject: [PATCH 18/43] Changed changelog and assigned TODO correctly --- CHANGELOG.md | 9 +++++++++ src/anemoi/training/utils/checkpoint.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a949c1d5..4298d655 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,15 @@ Please add your functional changes to the appropriate section in the PR. 
Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.1...HEAD) +### Fixed +- Patched issue [#57] to load checkpoints of large models. + +### Added +- Introduce variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting. +- TRANSFER LEARNING: enabled new functionality. You can now load checkpoints from different models and different training runs. +- Effective batch size: `(config.dataloader.batch_size["training"] * config.hardware.num_gpus_per_node * config.hardware.num_nodes) // config.hardware.num_gpus_per_model`. + Used for experiment reproducibility across different computing configurations. + ## [0.3.1 - AIFS v0.3 Compatibility](https://github.com/ecmwf/anemoi-training/compare/0.3.0...0.3.1) - 2024-11-28 diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index 9ed015b0..656dc221 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -106,7 +106,7 @@ def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> try: checkpoint = torch.load(ckpt_path, map_location=model.device) - # TODO @icedoom888: this is a patch for issue #57 + # TODO @anaprietonem: this is a patch for issue #57 except RuntimeError: LOGGER.debug("Need to remove metadata from the checkpoint file due to issue #57..") From 7fd9a92a82a575b24a0bb07b74c4104dc79d8352 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 3 Dec 2024 10:11:07 +0100 Subject: [PATCH 19/43] Changed changelog and assigned TODO correctly --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4298d655..55997f24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,12 +10,12 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.1...HEAD) ### Fixed -- Patched issue [#57] to load checkpoints of large models. +- Patched issue [#57] to load checkpoints of large models. ### Added - Introduce variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting. - TRANSFER LEARNING: enabled new functionality. You can now load checkpoints from different models and different training runs. -- Effective batch size: `(config.dataloader.batch_size["training"] * config.hardware.num_gpus_per_node * config.hardware.num_nodes) // config.hardware.num_gpus_per_model`. +- Effective batch size: `(config.dataloader.batch_size["training"] * config.hardware.num_gpus_per_node * config.hardware.num_nodes) // config.hardware.num_gpus_per_model`. Used for experiment reproducibility across different computing configurations. 
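
A minimal sketch of the effective batch size and learning-rate scaling described in the changelog entry above, using illustrative hardware values; the variable names simply mirror the config keys referenced in the changelog and in the `_log_information` hunk earlier in this series, and are not part of the patches themselves:

    # Illustrative values only -- substitute your own config settings.
    batch_size_training = 2   # config.dataloader.batch_size["training"]
    num_gpus_per_node = 4     # config.hardware.num_gpus_per_node
    num_nodes = 2             # config.hardware.num_nodes
    num_gpus_per_model = 1    # config.hardware.num_gpus_per_model
    base_lr = 6.25e-5         # config.training.lr.rate (illustrative)

    # Effective batch size across all model instances, as given in the changelog.
    effective_batch_size = (batch_size_training * num_gpus_per_node * num_nodes) // num_gpus_per_model

    # Number of model instances (total GPU count divided by the model-group size),
    # which is the factor the learning rate is scaled by.
    total_model_instances = num_nodes * num_gpus_per_node / num_gpus_per_model
    effective_lr = int(total_model_instances) * base_lr

    print(effective_batch_size, effective_lr)  # 16 0.0005
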
From 1ac34d8a19176249c9def6c40f7271639b686c78 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 3 Dec 2024 13:13:44 +0100 Subject: [PATCH 20/43] Addressed review: copy checkpoint before removing metadata file --- src/anemoi/training/train/train.py | 4 +-- src/anemoi/training/utils/checkpoint.py | 39 +++++++++++++++---------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 157dd7cd..70542433 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -90,8 +90,8 @@ def datamodule(self) -> AnemoiDatasetsDataModule: """DataModule instance and DataSets.""" datamodule = AnemoiDatasetsDataModule(self.config) self.config.data.num_features = len(datamodule.ds_train.data.variables) - LOGGER.info("Number of data variables: ", len(datamodule.ds_train.data.variables)) - LOGGER.debug("Variables: ", datamodule.ds_train.data.variables) + LOGGER.info("Number of data variables: " + str(len(datamodule.ds_train.data.variables))) + LOGGER.debug("Variables: " + str(datamodule.ds_train.data.variables)) return datamodule @cached_property diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index 656dc221..046c3fba 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -22,33 +22,43 @@ LOGGER = logging.getLogger(__name__) -def remove_file_from_zip(zip_path: Path | str, file_to_remove: Path | str) -> None: +def create_new_zip_path(zip_path: Path | str) -> Path: + # Convert the path to a Path object + zip_path = Path(zip_path) + + # Add '_patched' before the file extension + new_zip_path = zip_path.stem + "_patched" + zip_path.suffix + + # Create the new path within the same directory + return zip_path.with_name(new_zip_path) + +def remove_file_from_zip( + zip_path: Path | str, + file_to_remove: Path | str, +) -> Path | str: + + new_zip_path = create_new_zip_path(zip_path) try: - temp_zip_path = f"{zip_path}.temp" file_removed = False # Open the existing ZIP file and create a new one - with zipfile.ZipFile(zip_path, "r") as src_zip, zipfile.ZipFile(temp_zip_path, "w") as dest_zip: + with zipfile.ZipFile(zip_path, "r") as src_zip, zipfile.ZipFile(new_zip_path, "w") as dest_zip: for item in src_zip.infolist(): if item.filename != file_to_remove: dest_zip.writestr(item, src_zip.read(item.filename)) else: file_removed = True - # Replace the old ZIP file with the new one - Path.replace(temp_zip_path, zip_path) - # Check the result if file_removed: - LOGGER.debug("File successfully removed from the zip archive.") + LOGGER.debug(f"File successfully removed from the zip archive and saved as {new_zip_path}.") else: - LOGGER.debug("File not found in the zip archive.") + LOGGER.debug(f"File not found in the zip archive. 
The new zip file is identical to the original.") except FileNotFoundError: LOGGER.exception("Error occurred while modifying the zip archive.") - # Clean up the temporary file in case of an error - if Path.exists(temp_zip_path): - Path.unlink(temp_zip_path) + + return new_zip_path def load_and_prepare_model(lightning_checkpoint_path: str) -> tuple[torch.nn.Module, dict]: @@ -111,10 +121,9 @@ def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> LOGGER.debug("Need to remove metadata from the checkpoint file due to issue #57..") file_to_delete = "archive/anemoi-metadata/ai-models.json" - # Construct and execute the command - remove_file_from_zip(ckpt_path, file_to_delete) - - checkpoint = torch.load(ckpt_path, map_location=model.device) + # Creates copy of checkpoint and removed the metadata from the copy + new_ckpt_path = remove_file_from_zip(ckpt_path, file_to_delete) + checkpoint = torch.load(new_ckpt_path, map_location=model.device) # Filter out layers with size mismatch state_dict = checkpoint["state_dict"] From 0d4fa5119a92d17be4bd5044a6dde845e25f489b Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 3 Dec 2024 14:29:11 +0100 Subject: [PATCH 21/43] gpc passed --- src/anemoi/training/train/train.py | 4 ++-- src/anemoi/training/utils/checkpoint.py | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 70542433..c197bba2 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -90,8 +90,8 @@ def datamodule(self) -> AnemoiDatasetsDataModule: """DataModule instance and DataSets.""" datamodule = AnemoiDatasetsDataModule(self.config) self.config.data.num_features = len(datamodule.ds_train.data.variables) - LOGGER.info("Number of data variables: " + str(len(datamodule.ds_train.data.variables))) - LOGGER.debug("Variables: " + str(datamodule.ds_train.data.variables)) + LOGGER.info("Number of data variables: %s", str(len(datamodule.ds_train.data.variables))) + LOGGER.debug("Variables: %s", str(datamodule.ds_train.data.variables)) return datamodule @cached_property diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index 046c3fba..065f8896 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -32,11 +32,12 @@ def create_new_zip_path(zip_path: Path | str) -> Path: # Create the new path within the same directory return zip_path.with_name(new_zip_path) + def remove_file_from_zip( - zip_path: Path | str, - file_to_remove: Path | str, + zip_path: Path | str, + file_to_remove: Path | str, ) -> Path | str: - + new_zip_path = create_new_zip_path(zip_path) try: file_removed = False @@ -51,9 +52,9 @@ def remove_file_from_zip( # Check the result if file_removed: - LOGGER.debug(f"File successfully removed from the zip archive and saved as {new_zip_path}.") + LOGGER.debug("File successfully removed from the zip archive and saved as %s.", new_zip_path) else: - LOGGER.debug(f"File not found in the zip archive. The new zip file is identical to the original.") + LOGGER.debug("File not found in the zip archive. 
The new zip file is identical to the original.") except FileNotFoundError: LOGGER.exception("Error occurred while modifying the zip archive.") From 32658929e0af524d51721200dc680e2d01205efd Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Wed, 4 Dec 2024 17:10:20 +0100 Subject: [PATCH 22/43] Removed logger in debugging mode --- .../config/model/graphtransformer.yaml | 38 +++++++++---------- src/anemoi/training/train/train.py | 1 - 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/anemoi/training/config/model/graphtransformer.yaml b/src/anemoi/training/config/model/graphtransformer.yaml index 9c48967b..2c008ff7 100644 --- a/src/anemoi/training/config/model/graphtransformer.yaml +++ b/src/anemoi/training/config/model/graphtransformer.yaml @@ -50,25 +50,25 @@ attributes: - edge_dirs nodes: [] -# Bounding configuration -bounding: #These are applied in order +# # Bounding configuration +# bounding: #These are applied in order - # Bound tp (total precipitation) with a Relu bounding layer - # ensuring a range of [0, infinity) to avoid negative precipitation values. - - _target_: anemoi.models.layers.bounding.ReluBounding #[0, infinity) - variables: - - tp +# # Bound tp (total precipitation) with a Relu bounding layer +# # ensuring a range of [0, infinity) to avoid negative precipitation values. +# - _target_: anemoi.models.layers.bounding.ReluBounding #[0, infinity) +# variables: +# - tp - # [OPTIONAL] Bound cp (convective precipitation) as a fraction of tp. - # This guarantees that cp is physically consistent with tp by restricting cp - # to a fraction of tp [0 to 1]. Uncomment the lines below to apply. - # NOTE: If this bounding strategy is used, the normalization of cp must be - # changed to "std" normalization, and the "cp" statistics should be remapped - # to those of tp to ensure consistency. +# # [OPTIONAL] Bound cp (convective precipitation) as a fraction of tp. +# # This guarantees that cp is physically consistent with tp by restricting cp +# # to a fraction of tp [0 to 1]. Uncomment the lines below to apply. +# # NOTE: If this bounding strategy is used, the normalization of cp must be +# # changed to "std" normalization, and the "cp" statistics should be remapped +# # to those of tp to ensure consistency. 
- # - _target_: anemoi.models.layers.bounding.FractionBounding # fraction of tp - # variables: - # - cp - # min_val: 0 - # max_val: 1 - # total_var: tp +# # - _target_: anemoi.models.layers.bounding.FractionBounding # fraction of tp +# # variables: +# # - cp +# # min_val: 0 +# # max_val: 1 +# # total_var: tp diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index c197bba2..5cb74b46 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -42,7 +42,6 @@ from torch_geometric.data import HeteroData LOGGER = logging.getLogger(__name__) -LOGGER.setLevel(logging.DEBUG) # Change DEBUG to INFO, WARNING, etc., as needed class AnemoiTrainer: From c325a9ef0ddb067010fae77ec68d69e4e53eef4c Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Thu, 5 Dec 2024 11:29:49 +0100 Subject: [PATCH 23/43] removed dataset lenght due to checkpointing issues --- src/anemoi/training/data/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index 153ac7cf..b550fe5d 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -293,9 +293,9 @@ def __repr__(self) -> str: Timeincrement: {self.timeincrement} """ - def __len__(self) -> int: - """Estimate the total number of samples based on valid indices.""" - return len(self.valid_date_indices) // self.effective_bs + # def __len__(self) -> int: + # """Estimate the total number of samples based on valid indices.""" + # return len(self.valid_date_indices) // self.effective_bs def worker_init_func(worker_id: int) -> None: From 4709d46e22615c6d752367c13dac43a165d4d8ec Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Thu, 5 Dec 2024 11:30:35 +0100 Subject: [PATCH 24/43] Reintroduced correct config on graphtansformer --- .../config/model/graphtransformer.yaml | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/anemoi/training/config/model/graphtransformer.yaml b/src/anemoi/training/config/model/graphtransformer.yaml index 2c008ff7..9c48967b 100644 --- a/src/anemoi/training/config/model/graphtransformer.yaml +++ b/src/anemoi/training/config/model/graphtransformer.yaml @@ -50,25 +50,25 @@ attributes: - edge_dirs nodes: [] -# # Bounding configuration -# bounding: #These are applied in order +# Bounding configuration +bounding: #These are applied in order -# # Bound tp (total precipitation) with a Relu bounding layer -# # ensuring a range of [0, infinity) to avoid negative precipitation values. -# - _target_: anemoi.models.layers.bounding.ReluBounding #[0, infinity) -# variables: -# - tp + # Bound tp (total precipitation) with a Relu bounding layer + # ensuring a range of [0, infinity) to avoid negative precipitation values. + - _target_: anemoi.models.layers.bounding.ReluBounding #[0, infinity) + variables: + - tp -# # [OPTIONAL] Bound cp (convective precipitation) as a fraction of tp. -# # This guarantees that cp is physically consistent with tp by restricting cp -# # to a fraction of tp [0 to 1]. Uncomment the lines below to apply. -# # NOTE: If this bounding strategy is used, the normalization of cp must be -# # changed to "std" normalization, and the "cp" statistics should be remapped -# # to those of tp to ensure consistency. + # [OPTIONAL] Bound cp (convective precipitation) as a fraction of tp. + # This guarantees that cp is physically consistent with tp by restricting cp + # to a fraction of tp [0 to 1]. 
Uncomment the lines below to apply. + # NOTE: If this bounding strategy is used, the normalization of cp must be + # changed to "std" normalization, and the "cp" statistics should be remapped + # to those of tp to ensure consistency. -# # - _target_: anemoi.models.layers.bounding.FractionBounding # fraction of tp -# # variables: -# # - cp -# # min_val: 0 -# # max_val: 1 -# # total_var: tp + # - _target_: anemoi.models.layers.bounding.FractionBounding # fraction of tp + # variables: + # - cp + # min_val: 0 + # max_val: 1 + # total_var: tp From b0023f934d48b4c9d4203c935fbaad955effa4c2 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Thu, 5 Dec 2024 11:32:57 +0100 Subject: [PATCH 25/43] gpc passed --- src/anemoi/training/data/dataset.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index b550fe5d..69aa154c 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -293,10 +293,6 @@ def __repr__(self) -> str: Timeincrement: {self.timeincrement} """ - # def __len__(self) -> int: - # """Estimate the total number of samples based on valid indices.""" - # return len(self.valid_date_indices) // self.effective_bs - def worker_init_func(worker_id: int) -> None: """Configures each dataset worker process. From 6a8ac97e0fd22de52c91077b08ad6fa989616dd3 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Thu, 5 Dec 2024 16:16:40 +0100 Subject: [PATCH 26/43] Removed patched for issue #57, code expects patched checkpoint already --- src/anemoi/training/utils/checkpoint.py | 57 +++++-------------------- 1 file changed, 11 insertions(+), 46 deletions(-) diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index 065f8896..f1fa70f7 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -10,7 +10,6 @@ from __future__ import annotations import logging -import zipfile from pathlib import Path import torch @@ -22,46 +21,17 @@ LOGGER = logging.getLogger(__name__) -def create_new_zip_path(zip_path: Path | str) -> Path: +def create_new_zip_path(zip_path: Path | str, patch_str: str = "patched") -> Path: # Convert the path to a Path object zip_path = Path(zip_path) # Add '_patched' before the file extension - new_zip_path = zip_path.stem + "_patched" + zip_path.suffix + new_zip_path = zip_path.stem + "_" + patch_str + zip_path.suffix # Create the new path within the same directory return zip_path.with_name(new_zip_path) -def remove_file_from_zip( - zip_path: Path | str, - file_to_remove: Path | str, -) -> Path | str: - - new_zip_path = create_new_zip_path(zip_path) - try: - file_removed = False - - # Open the existing ZIP file and create a new one - with zipfile.ZipFile(zip_path, "r") as src_zip, zipfile.ZipFile(new_zip_path, "w") as dest_zip: - for item in src_zip.infolist(): - if item.filename != file_to_remove: - dest_zip.writestr(item, src_zip.read(item.filename)) - else: - file_removed = True - - # Check the result - if file_removed: - LOGGER.debug("File successfully removed from the zip archive and saved as %s.", new_zip_path) - else: - LOGGER.debug("File not found in the zip archive. 
The new zip file is identical to the original.") - - except FileNotFoundError: - LOGGER.exception("Error occurred while modifying the zip archive.") - - return new_zip_path - - def load_and_prepare_model(lightning_checkpoint_path: str) -> tuple[torch.nn.Module, dict]: """Load the lightning checkpoint and extract the pytorch model and its metadata. @@ -113,18 +83,13 @@ def save_inference_checkpoint(model: torch.nn.Module, metadata: dict, save_path: def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> nn.Module: - # Load the checkpoint - try: - checkpoint = torch.load(ckpt_path, map_location=model.device) + # Related to issue #57 + patched_ckpt_file = create_new_zip_path(ckpt_path, patch_str="patched") + if Path(patched_ckpt_file).exists(): + ckpt_path = patched_ckpt_file - # TODO @anaprietonem: this is a patch for issue #57 - except RuntimeError: - LOGGER.debug("Need to remove metadata from the checkpoint file due to issue #57..") - - file_to_delete = "archive/anemoi-metadata/ai-models.json" - # Creates copy of checkpoint and removed the metadata from the copy - new_ckpt_path = remove_file_from_zip(ckpt_path, file_to_delete) - checkpoint = torch.load(new_ckpt_path, map_location=model.device) + # Load the checkpoint + checkpoint = torch.load(ckpt_path, map_location=model.device) # Filter out layers with size mismatch state_dict = checkpoint["state_dict"] @@ -133,9 +98,9 @@ def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> for key in state_dict.copy(): if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape: - LOGGER.debug("Skipping loading parameter: ", key) - LOGGER.debug("Checkpoint shape: ", state_dict[key].shape) - LOGGER.debug("Model shape: ", model_state_dict[key].shape) + LOGGER.debug("Skipping loading parameter: %s", key) + LOGGER.debug("Checkpoint shape: %s", str(state_dict[key].shape)) + LOGGER.debug("Model shape: %s", str(model_state_dict[key].shape)) del state_dict[key] # Remove the mismatched key From 355cca1cff89703fd9ef93dd7529213bfcc2feea Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Thu, 5 Dec 2024 17:12:05 +0100 Subject: [PATCH 27/43] Removed new path name for patched checkpoint (ignoring fully issue #57) + removed fix for missing config --- src/anemoi/training/train/train.py | 6 ------ src/anemoi/training/utils/checkpoint.py | 16 ---------------- 2 files changed, 22 deletions(-) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 5cb74b46..6812dd99 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -63,12 +63,6 @@ def __init__(self, config: DictConfig) -> None: OmegaConf.resolve(config) self.config = config - # Set Transfer Learning based on the other if not provided - if self.config.training.transfer_learning is None: - self.config.training.transfer_learning = ( - bool(self.config.training.run_id) or bool(self.config.training.fork_run_id) - ) and self.config.training.load_weights_only - self.start_from_checkpoint = bool(self.config.training.run_id) or bool(self.config.training.fork_run_id) self.load_weights_only = self.config.training.load_weights_only self.parent_uuid = None diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index f1fa70f7..a87ac146 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -21,17 +21,6 @@ LOGGER = logging.getLogger(__name__) -def create_new_zip_path(zip_path: Path | str, patch_str: str = 
"patched") -> Path: - # Convert the path to a Path object - zip_path = Path(zip_path) - - # Add '_patched' before the file extension - new_zip_path = zip_path.stem + "_" + patch_str + zip_path.suffix - - # Create the new path within the same directory - return zip_path.with_name(new_zip_path) - - def load_and_prepare_model(lightning_checkpoint_path: str) -> tuple[torch.nn.Module, dict]: """Load the lightning checkpoint and extract the pytorch model and its metadata. @@ -83,11 +72,6 @@ def save_inference_checkpoint(model: torch.nn.Module, metadata: dict, save_path: def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> nn.Module: - # Related to issue #57 - patched_ckpt_file = create_new_zip_path(ckpt_path, patch_str="patched") - if Path(patched_ckpt_file).exists(): - ckpt_path = patched_ckpt_file - # Load the checkpoint checkpoint = torch.load(ckpt_path, map_location=model.device) From b875ea0dceae4fbbefa7bc7c237b5033fc5df26b Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Thu, 5 Dec 2024 17:13:19 +0100 Subject: [PATCH 28/43] Adapted changelog --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 31649f41..e9c0997a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,6 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.1...HEAD) ### Fixed -- Patched issue [#57] to load checkpoints of large models. - Not update NaN-weight-mask for loss function when using remapper and no imputer [#178](https://github.com/ecmwf/anemoi-training/pull/178) - Dont crash when using the profiler if certain env vars arent set [#180](https://github.com/ecmwf/anemoi-training/pull/180) From b9b611b9ee8a09d4f752dc2216af54dbe588edaf Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Thu, 5 Dec 2024 18:49:51 +0100 Subject: [PATCH 29/43] Added Freezing functionality --- .../training/config/training/default.yaml | 2 ++ src/anemoi/training/train/train.py | 10 ++++++++++ src/anemoi/training/utils/checkpoint.py | 18 ++++++++++++++++++ 3 files changed, 30 insertions(+) diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index c397ff75..f9636b98 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -124,3 +124,5 @@ node_loss_weights: _target_: anemoi.training.losses.nodeweights.GraphNodeAttribute target_nodes: ${graph.data} node_attribute: area_weight + +submodules_to_freeze: [] \ No newline at end of file diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 6812dd99..3280877c 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -35,6 +35,7 @@ from anemoi.training.distributed.strategy import DDPGroupStrategy from anemoi.training.train.forecaster import GraphForecaster from anemoi.training.utils.checkpoint import transfer_learning_loading +from anemoi.training.utils.checkpoint import freeze_submodule_by_name from anemoi.training.utils.jsonify import map_config_to_primitives from anemoi.training.utils.seeding import get_base_seed @@ -150,6 +151,7 @@ def model(self) -> GraphForecaster: model = GraphForecaster(**kwargs) + # Load the model weights if self.load_weights_only: # Sanify the checkpoint for transfer learning if self.config.training.transfer_learning: @@ -161,6 +163,14 @@ def model(self) -> GraphForecaster: return model.load_from_checkpoint(self.last_checkpoint, **kwargs, 
strict=False) LOGGER.info("Model initialised from scratch.") + + # Freeze the chosen model weights + if self.config.training.submodules_2_freeze: + LOGGER.info("The following submodules will NOT be trained: %s", self.config.training.submodules_2_freeze) + for submodule_name in self.config.training.submodules_to_freeze: + freeze_submodule_by_name(model, submodule_name) + LOGGER.info("%s Frozen successfully.", submodule_name) + return model @rank_zero_only diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index a87ac146..ddb74285 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -91,3 +91,21 @@ def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> # Load the filtered st-ate_dict into the model model.load_state_dict(state_dict, strict=False) return model + + +def freeze_submodule_by_name(module: nn.Module, target_name:str) -> None: + """ + Recursively freezes the parameters of a submodule with the specified name. + + Args: + module (nn.Module): The parent module to search in. + target_name (str): The name of the submodule to freeze. + """ + for name, child in module.named_children(): + # If this is the target submodule, freeze its parameters + if name == target_name: + for param in child.parameters(): + param.requires_grad = False + else: + # Recursively search within children + freeze_submodule_by_name(child, target_name) \ No newline at end of file From 0f0dff0a13f8f3b8c0f95c2f2b153b04cd682f21 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Thu, 5 Dec 2024 18:50:24 +0100 Subject: [PATCH 30/43] Added Freezing functionality --- src/anemoi/training/config/training/default.yaml | 2 +- src/anemoi/training/train/train.py | 2 +- src/anemoi/training/utils/checkpoint.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index f9636b98..b62ebce2 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -125,4 +125,4 @@ node_loss_weights: target_nodes: ${graph.data} node_attribute: area_weight -submodules_to_freeze: [] \ No newline at end of file +submodules_to_freeze: [] diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 3280877c..0d24be3c 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -34,8 +34,8 @@ from anemoi.training.diagnostics.logger import get_wandb_logger from anemoi.training.distributed.strategy import DDPGroupStrategy from anemoi.training.train.forecaster import GraphForecaster -from anemoi.training.utils.checkpoint import transfer_learning_loading from anemoi.training.utils.checkpoint import freeze_submodule_by_name +from anemoi.training.utils.checkpoint import transfer_learning_loading from anemoi.training.utils.jsonify import map_config_to_primitives from anemoi.training.utils.seeding import get_base_seed diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index ddb74285..b2cfd6b0 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -93,7 +93,7 @@ def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> return model -def freeze_submodule_by_name(module: nn.Module, target_name:str) -> None: +def freeze_submodule_by_name(module: nn.Module, target_name: str) -> None: """ Recursively freezes the 
parameters of a submodule with the specified name. @@ -108,4 +108,4 @@ def freeze_submodule_by_name(module: nn.Module, target_name:str) -> None: param.requires_grad = False else: # Recursively search within children - freeze_submodule_by_name(child, target_name) \ No newline at end of file + freeze_submodule_by_name(child, target_name) From 03c4adb7319be3e2036afdc4b5030b1e65eb47c0 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Fri, 6 Dec 2024 09:47:44 +0100 Subject: [PATCH 31/43] =?UTF-8?q?Tested=20=E2=9C=85=20waiting=20for=20tran?= =?UTF-8?q?sfer=20learning=20merge=20to=20happen?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/anemoi/training/train/train.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 0d24be3c..c1bb894f 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -156,20 +156,19 @@ def model(self) -> GraphForecaster: # Sanify the checkpoint for transfer learning if self.config.training.transfer_learning: LOGGER.info("Loading weights with Transfer Learning from %s", self.last_checkpoint) - return transfer_learning_loading(model, self.last_checkpoint) + model = transfer_learning_loading(model, self.last_checkpoint) + else: + LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) + model = model.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) - LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) - - return model.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) - - LOGGER.info("Model initialised from scratch.") + else: + LOGGER.info("Model initialised from scratch.") # Freeze the chosen model weights - if self.config.training.submodules_2_freeze: - LOGGER.info("The following submodules will NOT be trained: %s", self.config.training.submodules_2_freeze) - for submodule_name in self.config.training.submodules_to_freeze: - freeze_submodule_by_name(model, submodule_name) - LOGGER.info("%s Frozen successfully.", submodule_name) + LOGGER.info("The following submodules will NOT be trained: %s", self.config.training.submodules_to_freeze) + for submodule_name in self.config.training.submodules_to_freeze: + freeze_submodule_by_name(model, submodule_name) + LOGGER.info("%s Frozen successfully.", submodule_name) return model From 7d51c75a4d1388e97c11e76de32b3c2861fda1a2 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Fri, 6 Dec 2024 10:13:32 +0100 Subject: [PATCH 32/43] Switched logging to info from debug --- src/anemoi/training/utils/checkpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index a87ac146..a78ef524 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -82,9 +82,9 @@ def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> for key in state_dict.copy(): if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape: - LOGGER.debug("Skipping loading parameter: %s", key) - LOGGER.debug("Checkpoint shape: %s", str(state_dict[key].shape)) - LOGGER.debug("Model shape: %s", str(model_state_dict[key].shape)) + LOGGER.info("Skipping loading parameter: %s", key) + LOGGER.info("Checkpoint shape: %s", str(state_dict[key].shape)) + LOGGER.info("Model shape: %s", str(model_state_dict[key].shape)) del 
state_dict[key] # Remove the mismatched key From 8c7d54c18460196dd747527e2db2ccdc84415e85 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Fri, 6 Dec 2024 11:16:31 +0100 Subject: [PATCH 33/43] GPC passed --- src/anemoi/training/train/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 37e34c61..42e8c5f6 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -34,8 +34,10 @@ from anemoi.training.diagnostics.logger import get_wandb_logger from anemoi.training.distributed.strategy import DDPGroupStrategy from anemoi.training.train.forecaster import GraphForecaster + <<<<<<< HEAD from anemoi.training.utils.checkpoint import freeze_submodule_by_name + ======= >>>>>>> develop from anemoi.training.utils.checkpoint import transfer_learning_loading From 4bce6f1a61e5afb843ec51669a057df6af8b5812 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Fri, 6 Dec 2024 11:23:05 +0100 Subject: [PATCH 34/43] Changelog updated --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9c0997a..ce46249f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,11 +14,14 @@ Keep it human-readable, your future self will thank you! - Dont crash when using the profiler if certain env vars arent set [#180](https://github.com/ecmwf/anemoi-training/pull/180) ### Added + +- Transfer Learning: enabled new functionality. You can now load checkpoints from different models and different training runs. - Introduce variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting. -- TRANSFER LEARNING: enabled new functionality. You can now load checkpoints from different models and different training runs. - Effective batch size: `(config.dataloader.batch_size["training"] * config.hardware.num_gpus_per_node * config.hardware.num_nodes) // config.hardware.num_gpus_per_model`. Used for experiment reproducibility across different computing configurations. - Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120) +- Model Freezing ❄️: enabled new functionality. You can now Freeze parts of your model by specifying a list of submodules to freeze with the new config parameter: submodules_to_freeze. +- Introduce new variable to configure: submodules_to_freeze -> List[str], list of submodules to freeze. 
### Changed From bd320962e42446825d16e46db676e4f278f4a543 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Fri, 6 Dec 2024 11:24:34 +0100 Subject: [PATCH 35/43] Completed Merge and code check --- src/anemoi/training/train/train.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 42e8c5f6..c1bb894f 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -34,12 +34,7 @@ from anemoi.training.diagnostics.logger import get_wandb_logger from anemoi.training.distributed.strategy import DDPGroupStrategy from anemoi.training.train.forecaster import GraphForecaster - -<<<<<<< HEAD from anemoi.training.utils.checkpoint import freeze_submodule_by_name - -======= ->>>>>>> develop from anemoi.training.utils.checkpoint import transfer_learning_loading from anemoi.training.utils.jsonify import map_config_to_primitives from anemoi.training.utils.seeding import get_base_seed From 6aac5482ed455e6a87404cba9e3ef342996b7808 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Wed, 11 Dec 2024 14:53:22 +0100 Subject: [PATCH 36/43] gpc --- src/anemoi/training/train/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 3a2e430f..6d18e2bd 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -164,7 +164,7 @@ def model(self) -> GraphForecaster: else: LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) model = model.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) - + else: LOGGER.info("Model initialised from scratch.") From 8478689c12d13d1b0d0973801238106335593db5 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 17 Dec 2024 12:12:14 +0100 Subject: [PATCH 37/43] Changed docstring and pytorch lightnening freeze --- src/anemoi/training/utils/checkpoint.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index 21b4fdae..97b3946f 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -97,15 +97,17 @@ def freeze_submodule_by_name(module: nn.Module, target_name: str) -> None: """ Recursively freezes the parameters of a submodule with the specified name. - Args: - module (nn.Module): The parent module to search in. - target_name (str): The name of the submodule to freeze. + Parameters + ---------- + model : torch.nn.Module + Pytorch model + target_name : str + The name of the submodule to freeze. 
""" for name, child in module.named_children(): # If this is the target submodule, freeze its parameters if name == target_name: - for param in child.parameters(): - param.requires_grad = False + child.freeze() else: # Recursively search within children freeze_submodule_by_name(child, target_name) From 2eb214021f44aef951fc5a647f5b4fabd5dd1db7 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 17 Dec 2024 14:37:46 +0100 Subject: [PATCH 38/43] Addressed review --- src/anemoi/training/train/train.py | 3 --- src/anemoi/training/utils/checkpoint.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index ca9fd81f..d3b92eb1 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -166,9 +166,6 @@ def model(self) -> GraphForecaster: LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) model = model.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) - else: - LOGGER.info("Model initialised from scratch.") - # Freeze the chosen model weights LOGGER.info("The following submodules will NOT be trained: %s", self.config.training.submodules_to_freeze) for submodule_name in self.config.training.submodules_to_freeze: diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index 97b3946f..76f19679 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -107,7 +107,7 @@ def freeze_submodule_by_name(module: nn.Module, target_name: str) -> None: for name, child in module.named_children(): # If this is the target submodule, freeze its parameters if name == target_name: - child.freeze() + module.freeze(child) else: # Recursively search within children freeze_submodule_by_name(child, target_name) From 742a7a8c3eaf4a3968e2f49da3f3e311eafd6884 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 17 Dec 2024 15:48:46 +0100 Subject: [PATCH 39/43] Changes for review --- src/anemoi/training/train/train.py | 20 +++++++++++--------- src/anemoi/training/utils/checkpoint.py | 5 +++-- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index d3b92eb1..f6e7f895 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -158,19 +158,21 @@ def model(self) -> GraphForecaster: # Load the model weights if self.load_weights_only: - # Sanify the checkpoint for transfer learning - if self.config.training.transfer_learning: - LOGGER.info("Loading weights with Transfer Learning from %s", self.last_checkpoint) - model = transfer_learning_loading(model, self.last_checkpoint) + if hasattr(self.config.training, "transfer_learning"): + # Sanify the checkpoint for transfer learning + if self.config.training.transfer_learning: + LOGGER.info("Loading weights with Transfer Learning from %s", self.last_checkpoint) + model = transfer_learning_loading(model, self.last_checkpoint) else: LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) model = model.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) - # Freeze the chosen model weights - LOGGER.info("The following submodules will NOT be trained: %s", self.config.training.submodules_to_freeze) - for submodule_name in self.config.training.submodules_to_freeze: - freeze_submodule_by_name(model, submodule_name) - LOGGER.info("%s Frozen successfully.", submodule_name) + if hasattr(self.config.training, 
"submodules_to_freeze"): + # Freeze the chosen model weights + LOGGER.info("The following submodules will NOT be trained: %s", self.config.training.submodules_to_freeze) + for submodule_name in self.config.training.submodules_to_freeze: + freeze_submodule_by_name(model, submodule_name) + LOGGER.info("%s frozen successfully.", submodule_name.upper()) return model diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index 76f19679..e123e355 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -99,7 +99,7 @@ def freeze_submodule_by_name(module: nn.Module, target_name: str) -> None: Parameters ---------- - model : torch.nn.Module + module : torch.nn.Module Pytorch model target_name : str The name of the submodule to freeze. @@ -107,7 +107,8 @@ def freeze_submodule_by_name(module: nn.Module, target_name: str) -> None: for name, child in module.named_children(): # If this is the target submodule, freeze its parameters if name == target_name: - module.freeze(child) + for param in child.parameters(): + param.requires_grad = False else: # Recursively search within children freeze_submodule_by_name(child, target_name) From 0b8a40739435e4da6fa80d8e696600ac059ee3c5 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Tue, 17 Dec 2024 15:51:07 +0100 Subject: [PATCH 40/43] Refactor CHANGELOG --- CHANGELOG.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 922734fc..ec70c50a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,15 +17,14 @@ Keep it human-readable, your future self will thank you! - Identify stretched grid models based on graph rather than configuration file [#204](https://github.com/ecmwf/anemoi-training/pull/204) ### Added - -- Transfer Learning: enabled new functionality. You can now load checkpoints from different models and different training runs. -- Introduce variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting. +- Introduce (optional) variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting. +- TRANSFER LEARNING: enabled new functionality. You can now load checkpoints from different models and different training runs. - Effective batch size: `(config.dataloader.batch_size["training"] * config.hardware.num_gpus_per_node * config.hardware.num_nodes) // config.hardware.num_gpus_per_model`. Used for experiment reproducibility across different computing configurations. - Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120) -- Model Freezing ❄️: enabled new functionality. You can now Freeze parts of your model by specifying a list of submodules to freeze with the new config parameter: submodules_to_freeze. -- Introduce new variable to configure: submodules_to_freeze -> List[str], list of submodules to freeze. - Added new metrics for stretched grid models to track losses inside/outside the regional domain [#199](https://github.com/ecmwf/anemoi-training/pull/199) +- Model Freezing ❄️: enabled new functionality. You can now Freeze parts of your model by specifying a list of submodules to freeze with the new config parameter: submodules_to_freeze. +- Introduce (optional) variable to configure: submodules_to_freeze -> List[str], list of submodules to freeze. 
### Changed From 8797fb3cc092378b1907e47cd9a8dc4b003c3a5f Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Wed, 18 Dec 2024 19:28:02 +0100 Subject: [PATCH 41/43] Rebased on develop --- src/anemoi/training/train/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index f6e7f895..d92a9cc3 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -165,7 +165,7 @@ def model(self) -> GraphForecaster: model = transfer_learning_loading(model, self.last_checkpoint) else: LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) - model = model.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) + model = GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) if hasattr(self.config.training, "submodules_to_freeze"): # Freeze the chosen model weights From 7705a7efec330737b4bdb11a4801541e2e451aaf Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Wed, 18 Dec 2024 19:40:00 +0100 Subject: [PATCH 42/43] Added documentation --- docs/user-guide/training.rst | 62 ++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/docs/user-guide/training.rst b/docs/user-guide/training.rst index e90b1583..6eda7f01 100644 --- a/docs/user-guide/training.rst +++ b/docs/user-guide/training.rst @@ -280,3 +280,65 @@ finished training. However if the user wants to restart the model from a specific point they can do this by setting ``config.hardware.files.warm_start`` to be the checkpoint they want to restart from.. + +******************* + Transfer Learning +******************* + +Transfer learning allows the model to reuse knowledge from a previously +trained checkpoint. This is particularly useful when the new task is +related to the old one, enabling faster convergence and often improving +model performance. + +To enable transfer learning, set the config.training.transfer_learning +flag to True in the configuration file. + +.. code:: yaml + + training: + # start the training from a checkpoint of a previous run + fork_run_id: '51a97d40a49e48d284494a3b5d87ef2b' + load_weights_only: True + transfer_learning: True + +When this flag is active and a checkpoint path is specified in +config.hardware.files.warm_start or self.last_checkpoint, the system +loads the pre-trained weights using the `transfer_learning_loading` +function. This approach ensures only compatible weights are loaded and +mismatched layers are handled appropriately. + +For example, transfer learning might be used to adapt a weather +forecasting model trained on one geographic region to another region +with similar characteristics. + +**************** + Model Freezing +**************** + +Model freezing is a technique where specific parts (submodules) of a +model are excluded from training. This is useful when certain parts of +the model have been sufficiently trained or should remain unchanged for +the current task. + +To specify which submodules to freeze, use the +config.training.submodules_to_freeze field in the configuration. List +the names of submodules to be frozen. During model initialization, these +submodules will have their parameters frozen, ensuring they are not +updated during training. + +For example with the following configuration, the processor will be +frozen and only the encoder and decoder will be trained: + +.. 
code:: yaml + + training: + # start the training from a checkpoint of a previous run + fork_run_id: '51a97d40a49e48d284494a3b5d87ef2b' + load_weights_only: True + + submodules_to_freeze: + - processor + +Freezing can be particularly beneficial in scenarios such as fine-tuning +when only specific components (e.g., the encoder, the decoder) need to +adapt to a new task while keeping others (e.g., the processor) fixed. From 463bec447aaa6daaf38e3b9c6dc041e168ee7ea1 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Wed, 18 Dec 2024 19:40:36 +0100 Subject: [PATCH 43/43] Added documentation --- docs/user-guide/training.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/user-guide/training.rst b/docs/user-guide/training.rst index 6eda7f01..a365c831 100644 --- a/docs/user-guide/training.rst +++ b/docs/user-guide/training.rst @@ -297,7 +297,7 @@ flag to True in the configuration file. training: # start the training from a checkpoint of a previous run - fork_run_id: '51a97d40a49e48d284494a3b5d87ef2b' + fork_run_id: ... load_weights_only: True transfer_learning: True @@ -333,7 +333,7 @@ frozen and only the encoder and decoder will be trained: training: # start the training from a checkpoint of a previous run - fork_run_id: '51a97d40a49e48d284494a3b5d87ef2b' + fork_run_id: ... load_weights_only: True submodules_to_freeze:
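
The freezing behaviour documented above can be sanity-checked outside a full training run. The sketch below mirrors the final recursion in freeze_submodule_by_name from src/anemoi/training/utils/checkpoint.py; the TinyModel class and its submodule names are placeholders standing in for the real encoder/processor/decoder and are not part of anemoi-training:

    from torch import nn


    def freeze_submodule_by_name(module: nn.Module, target_name: str) -> None:
        """Recursively freeze the parameters of every submodule named ``target_name``."""
        for name, child in module.named_children():
            if name == target_name:
                # Freeze this submodule: its parameters receive no gradient updates.
                for param in child.parameters():
                    param.requires_grad = False
            else:
                # Recursively search within children.
                freeze_submodule_by_name(child, target_name)


    class TinyModel(nn.Module):
        # Placeholder stand-in for an encoder / processor / decoder layout.
        def __init__(self) -> None:
            super().__init__()
            self.encoder = nn.Linear(8, 16)
            self.processor = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
            self.decoder = nn.Linear(16, 8)


    model = TinyModel()
    freeze_submodule_by_name(model, "processor")

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    print(f"trainable parameters: {trainable}, frozen parameters: {frozen}")
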