diff --git a/training/CHANGELOG.md b/training/CHANGELOG.md
index 19160c59..f28e228e 100644
--- a/training/CHANGELOG.md
+++ b/training/CHANGELOG.md
@@ -27,11 +27,15 @@ Keep it human-readable, your future self will thank you!
 
 - Introduce variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting.
 - TRANSFER LEARNING: enabled new functionality. You can now load checkpoints from different models and different training runs.
+- Introduce (optional) variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting.
+- TRANSFER LEARNING: enabled new functionality. You can now load checkpoints from different models and different training runs.
 - Effective batch size: `(config.dataloader.batch_size["training"] * config.hardware.num_gpus_per_node * config.hardware.num_nodes) // config.hardware.num_gpus_per_model`. Used for experiment reproducibility across different computing configurations.
 - Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120)
 - Added default configuration files for stretched grid and limited area model experiments [173](https://github.com/ecmwf/anemoi-training/pull/173)
 - Added new metrics for stretched grid models to track losses inside/outside the regional domain [#199](https://github.com/ecmwf/anemoi-training/pull/199)
+- Model Freezing ❄️: enabled new functionality. You can now freeze parts of your model by specifying a list of submodules to freeze with the new config parameter: submodules_to_freeze.
+- Introduce (optional) variable to configure: submodules_to_freeze -> List[str], list of submodules to freeze.
 - Add supporting arrrays (numpy) to checkpoint
 - Support for masking out unconnected nodes in LAM [#171](https://github.com/ecmwf/anemoi-training/pull/171)
 - Improved validation metrics, allow 'all' to be scaled [#202](https://github.com/ecmwf/anemoi-training/pull/202)
diff --git a/training/docs/user-guide/training.rst b/training/docs/user-guide/training.rst
index e90b1583..a365c831 100644
--- a/training/docs/user-guide/training.rst
+++ b/training/docs/user-guide/training.rst
@@ -280,3 +280,65 @@ finished training.
 However if the user wants to restart the model from a specific point
 they can do this by setting ``config.hardware.files.warm_start`` to be
 the checkpoint they want to restart from..
+
+*******************
+ Transfer Learning
+*******************
+
+Transfer learning allows the model to reuse knowledge from a previously
+trained checkpoint. This is particularly useful when the new task is
+related to the old one, enabling faster convergence and often improving
+model performance.
+
+To enable transfer learning, set the ``config.training.transfer_learning``
+flag to ``True`` in the configuration file.
+
+.. code:: yaml
+
+   training:
+      # start the training from a checkpoint of a previous run
+      fork_run_id: ...
+      load_weights_only: True
+      transfer_learning: True
+
+When this flag is active and a checkpoint path is specified in
+``config.hardware.files.warm_start`` or ``self.last_checkpoint``, the
+system loads the pre-trained weights using the ``transfer_learning_loading``
+function. This approach ensures that only compatible weights are loaded
+and that mismatched layers are handled appropriately.
+
+For example, transfer learning might be used to adapt a weather
+forecasting model trained on one geographic region to another region
+with similar characteristics.
+
+****************
+ Model Freezing
+****************
+
+Model freezing is a technique where specific parts (submodules) of a
+model are excluded from training. This is useful when certain parts of
+the model have been sufficiently trained or should remain unchanged for
+the current task.
+
+To specify which submodules to freeze, use the
+``config.training.submodules_to_freeze`` field in the configuration.
+List the names of the submodules to be frozen. During model
+initialization, these submodules will have their parameters frozen,
+ensuring they are not updated during training.
+
+For example, with the following configuration, the processor will be
+frozen and only the encoder and decoder will be trained:
+
+.. code:: yaml
+
+   training:
+      # start the training from a checkpoint of a previous run
+      fork_run_id: ...
+      load_weights_only: True
+
+      submodules_to_freeze:
+         - processor
+
+Freezing can be particularly beneficial in fine-tuning scenarios, when
+only specific components (e.g., the encoder and decoder) need to adapt
+to a new task while others (e.g., the processor) remain fixed.
diff --git a/training/src/anemoi/training/config/training/default.yaml b/training/src/anemoi/training/config/training/default.yaml
index 6c915eb5..c604c081 100644
--- a/training/src/anemoi/training/config/training/default.yaml
+++ b/training/src/anemoi/training/config/training/default.yaml
@@ -140,3 +140,5 @@ node_loss_weights:
   _target_: anemoi.training.losses.nodeweights.GraphNodeAttribute
   target_nodes: ${graph.data}
   node_attribute: area_weight
+
+submodules_to_freeze: []
diff --git a/training/src/anemoi/training/train/train.py b/training/src/anemoi/training/train/train.py
index d786c13a..dcc11474 100644
--- a/training/src/anemoi/training/train/train.py
+++ b/training/src/anemoi/training/train/train.py
@@ -32,6 +32,7 @@
 from anemoi.training.diagnostics.logger import get_wandb_logger
 from anemoi.training.distributed.strategy import DDPGroupStrategy
 from anemoi.training.train.forecaster import GraphForecaster
+from anemoi.training.utils.checkpoint import freeze_submodule_by_name
 from anemoi.training.utils.checkpoint import transfer_learning_loading
 from anemoi.training.utils.jsonify import map_config_to_primitives
 from anemoi.training.utils.seeding import get_base_seed
@@ -155,17 +156,24 @@ def model(self) -> GraphForecaster:
 
         model = GraphForecaster(**kwargs)
 
+        # Load the model weights
         if self.load_weights_only:
-            # Sanify the checkpoint for transfer learning
-            if self.config.training.transfer_learning:
-                LOGGER.info("Loading weights with Transfer Learning from %s", self.last_checkpoint)
-                return transfer_learning_loading(model, self.last_checkpoint)
+            if hasattr(self.config.training, "transfer_learning"):
+                # Sanify the checkpoint for transfer learning
+                if self.config.training.transfer_learning:
+                    LOGGER.info("Loading weights with Transfer Learning from %s", self.last_checkpoint)
+                    model = transfer_learning_loading(model, self.last_checkpoint)
+            else:
+                LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
+                model = GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)
+
+        if hasattr(self.config.training, "submodules_to_freeze"):
+            # Freeze the chosen model weights
+            LOGGER.info("The following submodules will NOT be trained: %s", self.config.training.submodules_to_freeze)
+            for submodule_name in self.config.training.submodules_to_freeze:
+                freeze_submodule_by_name(model, submodule_name)
+                LOGGER.info("%s frozen successfully.", submodule_name.upper())
LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) - - return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) - - LOGGER.info("Model initialised from scratch.") return model @rank_zero_only diff --git a/training/src/anemoi/training/utils/checkpoint.py b/training/src/anemoi/training/utils/checkpoint.py index 28152d11..54206024 100644 --- a/training/src/anemoi/training/utils/checkpoint.py +++ b/training/src/anemoi/training/utils/checkpoint.py @@ -91,3 +91,24 @@ def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> # Load the filtered st-ate_dict into the model model.load_state_dict(state_dict, strict=False) return model + + +def freeze_submodule_by_name(module: nn.Module, target_name: str) -> None: + """ + Recursively freezes the parameters of a submodule with the specified name. + + Parameters + ---------- + module : torch.nn.Module + Pytorch model + target_name : str + The name of the submodule to freeze. + """ + for name, child in module.named_children(): + # If this is the target submodule, freeze its parameters + if name == target_name: + for param in child.parameters(): + param.requires_grad = False + else: + # Recursively search within children + freeze_submodule_by_name(child, target_name)