
Commit

Merge commit '463bec447aaa6daaf38e3b9c6dc041e168ee7ea1' into develop
icedoom888 committed Dec 27, 2024
2 parents f8e5b39 + 463bec4 commit 781ea4e
Showing 5 changed files with 106 additions and 9 deletions.
4 changes: 4 additions & 0 deletions training/CHANGELOG.md
@@ -27,11 +27,15 @@ Keep it human-readable, your future self will thank you!
- Introduce (optional) variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting.
- <b> TRANSFER LEARNING</b>: enabled new functionality. You can now load checkpoints from different models and different training runs.
- Effective batch size: `(config.dataloader.batch_size["training"] * config.hardware.num_gpus_per_node * config.hardware.num_nodes) // config.hardware.num_gpus_per_model`.
Used for experiment reproducibility across different computing configurations.
- Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120)
- Added default configuration files for stretched grid and limited area model experiments [#173](https://github.com/ecmwf/anemoi-training/pull/173)
- Added new metrics for stretched grid models to track losses inside/outside the regional domain [#199](https://github.com/ecmwf/anemoi-training/pull/199)
- <b> Model Freezing ❄️</b>: enabled new functionality. You can now freeze parts of your model by specifying a list of submodules to freeze with the new config parameter `submodules_to_freeze`.
- Introduce (optional) variable to configure: submodules_to_freeze -> List[str], list of submodules to freeze.
- Add supporting arrays (numpy) to checkpoint
- Support for masking out unconnected nodes in LAM [#171](https://github.com/ecmwf/anemoi-training/pull/171)
- Improved validation metrics, allow 'all' to be scaled [#202](https://github.com/ecmwf/anemoi-training/pull/202)
62 changes: 62 additions & 0 deletions training/docs/user-guide/training.rst
@@ -280,3 +280,65 @@ finished training. However, if the user wants to restart the model from a
specific point, they can do this by setting
``config.hardware.files.warm_start`` to be the checkpoint they want to
restart from.

*******************
Transfer Learning
*******************

Transfer learning allows the model to reuse knowledge from a previously
trained checkpoint. This is particularly useful when the new task is
related to the old one, enabling faster convergence and often improving
model performance.

To enable transfer learning, set the ``config.training.transfer_learning``
flag to ``True`` in the configuration file.

.. code:: yaml

    training:
      # start the training from a checkpoint of a previous run
      fork_run_id: ...
      load_weights_only: True
      transfer_learning: True

When this flag is active and a checkpoint path is specified in
``config.hardware.files.warm_start`` or ``self.last_checkpoint``, the
system loads the pre-trained weights using the
``transfer_learning_loading`` function. This approach ensures only
compatible weights are loaded and mismatched layers are handled
appropriately.
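
The exact behaviour is implemented by ``transfer_learning_loading`` in
``anemoi.training.utils.checkpoint``; the snippet below is only a
minimal sketch of the idea, assuming a standard PyTorch checkpoint that
stores its weights under a ``state_dict`` key (the helper name
``load_compatible_weights`` is illustrative): tensors whose names and
shapes match the new model are kept, everything else is dropped and the
corresponding layers keep their freshly initialised values.

.. code:: python

    import torch
    import torch.nn as nn


    def load_compatible_weights(model: nn.Module, ckpt_path: str) -> nn.Module:
        """Load only the checkpoint tensors whose name and shape match the model."""
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        ckpt_state = checkpoint.get("state_dict", checkpoint)
        model_state = model.state_dict()

        # Keep only entries present in the model with an identical shape
        filtered = {
            name: tensor
            for name, tensor in ckpt_state.items()
            if name in model_state and tensor.shape == model_state[name].shape
        }

        # strict=False leaves mismatched or missing layers untouched
        model.load_state_dict(filtered, strict=False)
        return model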

For example, transfer learning might be used to adapt a weather
forecasting model trained on one geographic region to another region
with similar characteristics.

****************
Model Freezing
****************

Model freezing is a technique where specific parts (submodules) of a
model are excluded from training. This is useful when certain parts of
the model have been sufficiently trained or should remain unchanged for
the current task.

To specify which submodules to freeze, use the
``config.training.submodules_to_freeze`` field in the configuration.
List the names of the submodules to be frozen. During model
initialization, these submodules will have their parameters frozen,
ensuring they are not updated during training.

For example, with the following configuration, the processor will be
frozen and only the encoder and decoder will be trained:

.. code:: yaml

    training:
      # start the training from a checkpoint of a previous run
      fork_run_id: ...
      load_weights_only: True
      submodules_to_freeze:
        - processor

Freezing can be particularly beneficial in scenarios such as fine-tuning
when only specific components (e.g., the encoder, the decoder) need to
adapt to a new task while keeping others (e.g., the processor) fixed.
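
Freezing only toggles ``requires_grad`` on the parameters of the listed
submodules, so a quick sanity check after the model has been built is to
count trainable versus frozen parameters (a sketch, assuming ``model``
is the instantiated ``GraphForecaster``):

.. code:: python

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    print(f"trainable: {trainable:,} parameters, frozen: {frozen:,} parameters")
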
2 changes: 2 additions & 0 deletions training/src/anemoi/training/config/training/default.yaml
@@ -140,3 +140,5 @@ node_loss_weights:
  _target_: anemoi.training.losses.nodeweights.GraphNodeAttribute
  target_nodes: ${graph.data}
  node_attribute: area_weight

submodules_to_freeze: []
26 changes: 17 additions & 9 deletions training/src/anemoi/training/train/train.py
@@ -32,6 +32,7 @@
from anemoi.training.diagnostics.logger import get_wandb_logger
from anemoi.training.distributed.strategy import DDPGroupStrategy
from anemoi.training.train.forecaster import GraphForecaster
from anemoi.training.utils.checkpoint import freeze_submodule_by_name
from anemoi.training.utils.checkpoint import transfer_learning_loading
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.seeding import get_base_seed
@@ -155,17 +156,24 @@ def model(self) -> GraphForecaster:

        model = GraphForecaster(**kwargs)

        # Load the model weights
        if self.load_weights_only:
-           # Sanify the checkpoint for transfer learning
-           if self.config.training.transfer_learning:
-               LOGGER.info("Loading weights with Transfer Learning from %s", self.last_checkpoint)
-               return transfer_learning_loading(model, self.last_checkpoint)
+           if hasattr(self.config.training, "transfer_learning"):
+               # Sanify the checkpoint for transfer learning
+               if self.config.training.transfer_learning:
+                   LOGGER.info("Loading weights with Transfer Learning from %s", self.last_checkpoint)
+                   model = transfer_learning_loading(model, self.last_checkpoint)
+               else:
+                   LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
+                   model = GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)
+
+           if hasattr(self.config.training, "submodules_to_freeze"):
+               # Freeze the chosen model weights
+               LOGGER.info("The following submodules will NOT be trained: %s", self.config.training.submodules_to_freeze)
+               for submodule_name in self.config.training.submodules_to_freeze:
+                   freeze_submodule_by_name(model, submodule_name)
+                   LOGGER.info("%s frozen successfully.", submodule_name.upper())
-
-           LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
-
-           return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)

        LOGGER.info("Model initialised from scratch.")
        return model

@rank_zero_only
21 changes: 21 additions & 0 deletions training/src/anemoi/training/utils/checkpoint.py
@@ -91,3 +91,24 @@ def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) ->
    # Load the filtered state_dict into the model
    model.load_state_dict(state_dict, strict=False)
    return model


def freeze_submodule_by_name(module: nn.Module, target_name: str) -> None:
    """Recursively freezes the parameters of a submodule with the specified name.

    Parameters
    ----------
    module : torch.nn.Module
        PyTorch model
    target_name : str
        The name of the submodule to freeze.
    """
    for name, child in module.named_children():
        # If this is the target submodule, freeze its parameters
        if name == target_name:
            for param in child.parameters():
                param.requires_grad = False
        else:
            # Recursively search within children
            freeze_submodule_by_name(child, target_name)
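
A minimal usage sketch (illustrative only, not part of the committed file), showing the helper freezing one named submodule of a toy model:

    import torch.nn as nn

    # Toy model with an "encoder" and a "processor" submodule
    model = nn.ModuleDict({"encoder": nn.Linear(4, 4), "processor": nn.Linear(4, 4)})
    freeze_submodule_by_name(model, "processor")

    # The processor parameters no longer receive gradients; the encoder still does
    assert all(not p.requires_grad for p in model["processor"].parameters())
    assert all(p.requires_grad for p in model["encoder"].parameters())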
