diff --git a/training/CHANGELOG.md b/training/CHANGELOG.md
index 19160c59..f28e228e 100644
--- a/training/CHANGELOG.md
+++ b/training/CHANGELOG.md
@@ -27,11 +27,15 @@ Keep it human-readable, your future self will thank you!
- Introduce variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting.
- TRANSFER LEARNING: enabled new functionality. You can now load checkpoints from different models and different training runs.
+- Introduce (optional) variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting.
+- TRANSFER LEARNING: enabled new functionality. You can now load checkpoints from different models and different training runs.
- Effective batch size: `(config.dataloader.batch_size["training"] * config.hardware.num_gpus_per_node * config.hardware.num_nodes) // config.hardware.num_gpus_per_model`.
Used for experiment reproducibility across different computing configurations.
- Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120)
- Added default configuration files for stretched grid and limited area model experiments [#173](https://github.com/ecmwf/anemoi-training/pull/173)
- Added new metrics for stretched grid models to track losses inside/outside the regional domain [#199](https://github.com/ecmwf/anemoi-training/pull/199)
+- Model Freezing ❄️: enabled new functionality. You can now freeze parts of your model by specifying a list of submodules to freeze with the new config parameter: submodules_to_freeze.
+- Introduce (optional) variable to configure: submodules_to_freeze -> List[str], list of submodules to freeze.
- Add supporting arrays (numpy) to checkpoint
- Support for masking out unconnected nodes in LAM [#171](https://github.com/ecmwf/anemoi-training/pull/171)
- Improved validation metrics, allow 'all' to be scaled [#202](https://github.com/ecmwf/anemoi-training/pull/202)
diff --git a/training/docs/user-guide/training.rst b/training/docs/user-guide/training.rst
index e90b1583..a365c831 100644
--- a/training/docs/user-guide/training.rst
+++ b/training/docs/user-guide/training.rst
@@ -280,3 +280,65 @@ finished training. However if the user wants to restart the model from a
specific point they can do this by setting
``config.hardware.files.warm_start`` to be the checkpoint they want to
restart from.
+
+*******************
+ Transfer Learning
+*******************
+
+Transfer learning allows the model to reuse knowledge from a previously
+trained checkpoint. This is particularly useful when the new task is
+related to the old one, enabling faster convergence and often improving
+model performance.
+
+To enable transfer learning, set the ``config.training.transfer_learning``
+flag to ``True`` in the configuration file:
+
+.. code:: yaml
+
+ training:
+ # start the training from a checkpoint of a previous run
+ fork_run_id: ...
+ load_weights_only: True
+ transfer_learning: True
+
+When this flag is active and a checkpoint path is specified in
+``config.hardware.files.warm_start`` (or the run's last checkpoint),
+the pre-trained weights are loaded through the
+``transfer_learning_loading`` function. This ensures that only
+compatible weights are loaded and that mismatched layers are handled
+appropriately.
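+
+A minimal sketch of this compatibility filtering (not the exact
+``transfer_learning_loading`` implementation, and assuming a
+Lightning-style checkpoint with a ``state_dict`` entry) could look
+like:
+
+.. code:: python
+
+   import torch
+
+   def load_compatible_weights(model: torch.nn.Module, ckpt_path: str) -> torch.nn.Module:
+       """Load only checkpoint tensors whose name and shape match the model."""
+       checkpoint = torch.load(ckpt_path, map_location="cpu")
+       model_state = model.state_dict()
+       filtered = {
+           name: tensor
+           for name, tensor in checkpoint["state_dict"].items()
+           if name in model_state and tensor.shape == model_state[name].shape
+       }
+       # strict=False keeps mismatched or missing layers at their initial values
+       model.load_state_dict(filtered, strict=False)
+       return model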
+
+For example, transfer learning might be used to adapt a weather
+forecasting model trained on one geographic region to another region
+with similar characteristics.
+
+****************
+ Model Freezing
+****************
+
+Model freezing is a technique where specific parts (submodules) of a
+model are excluded from training. This is useful when certain parts of
+the model have been sufficiently trained or should remain unchanged for
+the current task.
+
+To specify which submodules to freeze, list their names under the
+``config.training.submodules_to_freeze`` field in the configuration.
+During model initialisation the parameters of these submodules are
+frozen, so they are not updated during training.
+
+For example, with the following configuration the processor is frozen
+and only the encoder and decoder are trained:
+
+.. code:: yaml
+
+ training:
+ # start the training from a checkpoint of a previous run
+ fork_run_id: ...
+ load_weights_only: True
+
+ submodules_to_freeze:
+ - processor
+
+Freezing can be particularly beneficial in fine-tuning scenarios, when
+only specific components (e.g., the encoder, the decoder) need to adapt
+to a new task while others (e.g., the processor) remain fixed.
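+
+Under the hood, freezing sets ``requires_grad = False`` on every
+parameter of the listed submodules (see ``freeze_submodule_by_name`` in
+``anemoi.training.utils.checkpoint``). A quick, illustrative way to
+check how many parameters remain trainable on an already-built model
+is:
+
+.. code:: python
+
+   # count trainable vs. frozen parameters after the model is constructed
+   trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+   frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+   print(f"trainable: {trainable:,} | frozen: {frozen:,}")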
diff --git a/training/src/anemoi/training/config/training/default.yaml b/training/src/anemoi/training/config/training/default.yaml
index 6c915eb5..c604c081 100644
--- a/training/src/anemoi/training/config/training/default.yaml
+++ b/training/src/anemoi/training/config/training/default.yaml
@@ -140,3 +140,5 @@ node_loss_weights:
_target_: anemoi.training.losses.nodeweights.GraphNodeAttribute
target_nodes: ${graph.data}
node_attribute: area_weight
+
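+# Names of submodules whose parameters are frozen (excluded from training); empty list trains everything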
+submodules_to_freeze: []
diff --git a/training/src/anemoi/training/train/train.py b/training/src/anemoi/training/train/train.py
index d786c13a..dcc11474 100644
--- a/training/src/anemoi/training/train/train.py
+++ b/training/src/anemoi/training/train/train.py
@@ -32,6 +32,7 @@
from anemoi.training.diagnostics.logger import get_wandb_logger
from anemoi.training.distributed.strategy import DDPGroupStrategy
from anemoi.training.train.forecaster import GraphForecaster
+from anemoi.training.utils.checkpoint import freeze_submodule_by_name
from anemoi.training.utils.checkpoint import transfer_learning_loading
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.seeding import get_base_seed
@@ -155,17 +156,24 @@ def model(self) -> GraphForecaster:
model = GraphForecaster(**kwargs)
+ # Load the model weights
if self.load_weights_only:
- # Sanify the checkpoint for transfer learning
- if self.config.training.transfer_learning:
- LOGGER.info("Loading weights with Transfer Learning from %s", self.last_checkpoint)
- return transfer_learning_loading(model, self.last_checkpoint)
+        # Sanitise the checkpoint for transfer learning
+        if hasattr(self.config.training, "transfer_learning") and self.config.training.transfer_learning:
+            LOGGER.info("Loading weights with Transfer Learning from %s", self.last_checkpoint)
+            model = transfer_learning_loading(model, self.last_checkpoint)
+        else:
+            LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
+            model = GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)
+
+ if hasattr(self.config.training, "submodules_to_freeze"):
+ # Freeze the chosen model weights
+ LOGGER.info("The following submodules will NOT be trained: %s", self.config.training.submodules_to_freeze)
+ for submodule_name in self.config.training.submodules_to_freeze:
+ freeze_submodule_by_name(model, submodule_name)
+ LOGGER.info("%s frozen successfully.", submodule_name.upper())
- LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
-
- return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)
-
- LOGGER.info("Model initialised from scratch.")
return model
@rank_zero_only
diff --git a/training/src/anemoi/training/utils/checkpoint.py b/training/src/anemoi/training/utils/checkpoint.py
index 28152d11..54206024 100644
--- a/training/src/anemoi/training/utils/checkpoint.py
+++ b/training/src/anemoi/training/utils/checkpoint.py
@@ -91,3 +91,24 @@ def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) ->
    # Load the filtered state_dict into the model
model.load_state_dict(state_dict, strict=False)
return model
+
+
+def freeze_submodule_by_name(module: nn.Module, target_name: str) -> None:
+    """Recursively freeze the parameters of all submodules with the specified name.
+
+    Parameters
+    ----------
+    module : torch.nn.Module
+        PyTorch model to search.
+    target_name : str
+        Name of the submodule(s) to freeze.
+    """
+ for name, child in module.named_children():
+ # If this is the target submodule, freeze its parameters
+ if name == target_name:
+ for param in child.parameters():
+ param.requires_grad = False
+ else:
+ # Recursively search within children
+ freeze_submodule_by_name(child, target_name)