Commit 99b6638

Author: Seppo Enarvi
More generic customization of the WeightAveraging callback
- The user can specify when to update the average model by overriding the should_update() method.
- Any keyword arguments will be passed to the AveragedModel constructor (see the sketch below).
1 parent 075bfcf commit 99b6638
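For illustration, a minimal sketch of both customization points described in the commit message (the subclass name, the epoch-based schedule, and the ``decay`` value are hypothetical, not part of this commit):

    from lightning.pytorch.callbacks import WeightAveraging
    from torch.optim.swa_utils import get_ema_avg_fn

    class EpochEMAWeightAveraging(WeightAveraging):
        def __init__(self):
            # Keyword arguments are forwarded to torch.optim.swa_utils.AveragedModel.
            super().__init__(avg_fn=get_ema_avg_fn(decay=0.995))

        def should_update(self, step_idx=None, epoch_idx=None):
            # Update the average model at the end of every epoch instead of after every optimizer step.
            return epoch_idx is not None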

File tree

3 files changed (+191, -111 lines)

Diff for: src/lightning/pytorch/CHANGELOG.md

+12

@@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [unreleased] - YYYY-MM-DD
+
+### Added
+
+- WeightAveraging callback that wraps the PyTorch AveragedModel class ([#20545](https://github.com/Lightning-AI/pytorch-lightning/pull/20545))
+
+### Changed
+
+### Removed
+
+### Fixed
+
 ## [2.5.0] - 2024-12-19
 
 ### Added
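As context for the changelog entry, a minimal sketch of the default usage (``model`` and ``dataloader`` are placeholders), adapted from the docstring example in the diff below:

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import WeightAveraging

    # With no arguments, the callback keeps an equally weighted (SWA) average of the
    # parameters and buffers on the CPU, updated after every optimizer step.
    trainer = Trainer(callbacks=WeightAveraging(), max_epochs=10)
    trainer.fit(model, dataloader)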

Diff for: src/lightning/pytorch/callbacks/weight_averaging.py

+110 -56

@@ -18,77 +18,109 @@
 
 import itertools
 from copy import deepcopy
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
-from torch import Tensor
 from torch.optim.swa_utils import AveragedModel
+from typing_extensions import override
 
 import lightning.pytorch as pl
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 
 
-def _return_true(x: int) -> bool:
-    return True
-
-
-def _return_false(x: int) -> bool:
-    return False
-
-
 class WeightAveraging(Callback):
     r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
     (EMA) after each training step.
 
-    The user should provide either `update_on_step` or `update_on_epoch`, a function that determines when the average
-    model should be updated. If neither function is provided, the average model will be updated after every optimizer
-    step.
+    Arguments given to the constructor will be passed to the :class:`AveragedModel` constructor. There are a couple of
+    differences to the default values, however. By default, the average model is stored on the CPU. If ``device`` is set
+    to ``None``, the device will be inferred from the original model. By default, the callback will compute running
+    averages for both the parameters and the buffers of the model. Setting ``use_buffers`` to ``False`` will cause only
+    the model parameters to be averaged, leaving updating the batch normalization statistics to the user (using
+    ``torch.optim.swa_utils.update_bn()``).
+
+    You can provide a custom averaging function with the ``avg_fn`` or ``multi_avg_fn`` parameter. See the
+    :class:`AveragedModel` class for details. If no averaging function is provided, the default is to compute the
+    equally-weighted average of the weights (SWA).
+
+    You can customize when the average model is updated by overriding the ``should_update()`` method. The callback calls
+    it with either ``step_idx`` or ``epoch_idx`` and the method returns a boolean indicating whether to update after the
+    given step or epoch. The default is to update after every step.
 
     During validation and after the training finishes, the current model parameters will be replaced with the averaged
     values.
 
+    Example::
+
+        from lightning.pytorch.callbacks import WeightAveraging
+        from torch.optim.swa_utils import get_ema_avg_fn
+
+        class EMAWeightAveraging(WeightAveraging):
+            def __init__(self):
+                super().__init__(avg_fn=get_ema_avg_fn())
+
+            def should_update(self, step_idx=None, epoch_idx=None):
+                # Start after 100 steps.
+                return (step_idx is not None) and (step_idx >= 100)
+
+        trainer = Trainer(callbacks=EMAWeightAveraging(), max_epochs=10)
+        trainer.fit(model, dataloader)
+
     Args:
         device: If provided, the :class:`AveragedModel` will be stored on the ``device``. If ``None`` the device will be
            inferred from the original model.
-        avg_fn: The averaging function used to update the parameters. The function must take in an
-            :class:`AveragedModel` parameter, a current model parameter, and the number of models already averaged. If
-            ``None``, an equally weighted average will be used.
-        update_on_step: A function that takes the number of optimizer steps taken, and returns ``True`` if the average
-            model should be updated.
-        update_on_epoch: A function that takes the zero-based epoch number, and returns ``True`` if the average model
-            should be updated.
+        use_buffers: If ``False``, the buffers of the model will not be averaged.
+        kwargs: Additional keyword arguments to be passed to the :class:`AveragedModel` constructor, such as ``avg_fn``
+            or ``multi_avg_fn``.
 
     """
 
     def __init__(
         self,
-        device: Optional[Union[torch.device, int]] = torch.device("cpu"),
-        avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None,
-        update_on_step: Optional[Callable[[int], bool]] = None,
-        update_on_epoch: Optional[Callable[[int], bool]] = None,
-    ):
-        self._device = device
-        self._avg_fn = avg_fn
-
-        if (update_on_step is None) and (update_on_epoch is None):
-            self._update_on_step: Callable[[int], bool] = _return_true
-            self._update_on_epoch: Callable[[int], bool] = _return_false
+        device: Optional[Union[torch.device, str, int]] = "cpu",
+        use_buffers: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        # The default value is a string so that jsonargparse knows how to serialize it.
+        if isinstance(device, str):
+            self._device: Optional[Union[torch.device, int]] = torch.device(device)
         else:
-            self._update_on_step = _return_false if update_on_step is None else update_on_step
-            self._update_on_epoch = _return_false if update_on_epoch is None else update_on_epoch
+            self._device = device
+        self._use_buffers = use_buffers
+        self._kwargs = kwargs
 
         self._average_model: Optional[AveragedModel] = None
 
         # Number of optimizer steps taken, when the average model was last updated. Initializing this with zero ensures
-        # that the average model will be first updated after the first optimizer step, which takes place after N batches
-        # when using accumulate_grad_batches=N.
+        # that self.should_update() will be first called after the first optimizer step, which takes place after N
+        # batches when using accumulate_grad_batches=N.
         self._latest_update_step = 0
         # The epoch after which the average model was last updated. The first epoch is 0, so initializing this to a
-        # negative value means that if update_on_step(0) returns True, the first update is after the first epoch.
+        # negative value means that if self.should_update(epoch_idx=0) returns True, the first update is after the first
+        # epoch.
         self._latest_update_epoch = -1
 
+    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
+        """Called after every optimizer step and after every training epoch to check whether the average model should
+        be updated.
+
+        One of the arguments is set to the zero-based index of the last training step or epoch. The default
+        implementation returns ``True`` when any ``step_idx`` is provided. The user can customize when the average model
+        gets updated by overriding this method.
+
+        Args:
+            step_idx: Index of the last optimizer step, or ``None`` when called at the epoch end.
+            epoch_idx: Index of the last epoch, or ``None`` when called after an optimizer step.
+
+        Returns:
+            ``True`` if the average model should be updated and ``False`` if not.
+
+        """
+        return step_idx is not None
+
+    @override
     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         """Called when fit, validate, test, predict, or tune begins.
 
@@ -102,14 +134,17 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
         """
         if stage == "fit":
             device = self._device or pl_module.device
-            self._average_model = AveragedModel(model=pl_module, device=device, avg_fn=self._avg_fn, use_buffers=True)
+            self._average_model = AveragedModel(
+                model=pl_module, device=device, use_buffers=self._use_buffers, **self._kwargs
+            )
 
+    @override
     def on_train_batch_end(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
     ) -> None:
         """Called when a training batch ends.
 
-        Updates the :class:`AveragedModel` parameters, if requested by ``update_on_step()``.
+        Updates the :class:`AveragedModel` parameters, if requested by ``self.should_update()``.
 
         Args:
             trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
@@ -119,26 +154,31 @@ def on_train_batch_end(
             batch_idx: Index of the training batch.
 
         """
-        if self._update_on_step(trainer.global_step) and (trainer.global_step > self._latest_update_step):
+        # trainer.global_step is the number of optimizer steps taken so far, i.e. 1 after the first optimizer step. To
+        # make step_idx consistent with epoch_idx, we'll pass a zero-based index.
+        step_idx = trainer.global_step - 1
+        if (trainer.global_step > self._latest_update_step) and self.should_update(step_idx=step_idx):
             assert self._average_model is not None
             self._average_model.update_parameters(pl_module)
             self._latest_update_step = trainer.global_step
 
+    @override
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when a training epoch ends.
 
-        Updates the :class:`AveragedModel` parameters, if requested by ``update_on_epoch()``.
+        Updates the :class:`AveragedModel` parameters, if requested by ``self.should_update()``.
 
         Args:
             trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
             pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
 
         """
-        if self._update_on_epoch(trainer.current_epoch) and (trainer.current_epoch > self._latest_update_epoch):
+        if (trainer.current_epoch > self._latest_update_epoch) and self.should_update(epoch_idx=trainer.current_epoch):
             assert self._average_model is not None
             self._average_model.update_parameters(pl_module)
             self._latest_update_epoch = trainer.current_epoch
 
+    @override
     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when training ends.
 
@@ -150,8 +190,10 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
 
         """
         assert self._average_model is not None
+        rank_zero_info("Loading the average model parameters to the final model.")
         self._copy_average_to_current(pl_module)
 
+    @override
     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when a validation epoch begins.
 
@@ -166,6 +208,7 @@ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn
         rank_zero_info("Loading the average model parameters for validation.")
         self._swap_models(pl_module)
 
+    @override
     def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when a validation epoch ends.
 
@@ -180,6 +223,7 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin
         rank_zero_info("Recovering the current model parameters after validation.")
         self._swap_models(pl_module)
 
+    @override
     def state_dict(self) -> dict[str, Any]:
         """Called when saving a checkpoint.
 
@@ -191,6 +235,7 @@ def state_dict(self) -> dict[str, Any]:
         """
         return {"latest_update_step": self._latest_update_step}
 
+    @override
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         """Called when loading a checkpoint.
 
@@ -202,6 +247,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         """
         self._latest_update_step = state_dict["latest_update_step"]
 
+    @override
     def on_save_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any]
     ) -> None:
@@ -218,18 +264,23 @@ def on_save_checkpoint(
 
         """
         if self._average_model is None:
-            raise Exception("Trying to save a checkpoint, but no average model (outside fit). Don't know what to do.")
-
-        rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.")
-        average_model_state = self._average_model.state_dict()
-        checkpoint["current_model_state"] = checkpoint["state_dict"]
-        checkpoint["state_dict"] = {
-            name[7:]: value for name, value in average_model_state.items() if name.startswith("module.")
-        }
-        checkpoint["averaging_state"] = {
-            name: value for name, value in average_model_state.items() if not name.startswith("module.")
-        }
-
+            rank_zero_info(
+                "You're using the WeightAveraging callback, but saving a checkpoint outside the 'fit' stage. The state "
+                "of the WeightAveraging callback won't be saved in the checkpoint. If training has finished, the "
+                "average model parameters will be saved to the state_dict in the checkpoint."
+            )
+        else:
+            rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.")
+            average_model_state = self._average_model.state_dict()
+            checkpoint["current_model_state"] = checkpoint["state_dict"]
+            checkpoint["state_dict"] = {
+                name[7:]: value for name, value in average_model_state.items() if name.startswith("module.")
+            }
+            checkpoint["averaging_state"] = {
+                name: value for name, value in average_model_state.items() if not name.startswith("module.")
+            }
+
+    @override
     def on_load_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any]
     ) -> None:
@@ -244,9 +295,12 @@ def on_load_checkpoint(
 
         """
         if self._average_model is None:
-            raise Exception("Trying to load a checkpoint, but no average model (outside fit). Don't know what to do.")
-
-        if ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint):
+            rank_zero_warn(
+                "You're using the WeightAveraging callback, but loading a checkpoint outside the 'fit' stage. The "
+                "WeightAveraging state cannot be restored. If you're using the checkpoint for prediction or testing, "
+                "you can ignore this warning. To disable the warning, remove the WeightAveraging callback."
+            )
+        elif ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint):
             rank_zero_info("Found current_model_state in the checkpoint. This will be used to initialize the model.")
             average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
            average_model_state |= checkpoint["averaging_state"]

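A usage note on ``use_buffers=False``: the docstring above leaves updating the batch-normalization statistics to the user. A minimal sketch of that workflow (``model`` and ``train_dataloader`` are placeholders, and it assumes the model's ``forward()`` accepts the batches yielded by the dataloader, as ``update_bn()`` requires):

    from torch.optim.swa_utils import update_bn

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import WeightAveraging

    # Average only the parameters; buffers such as batch-norm running statistics are left untouched.
    trainer = Trainer(callbacks=WeightAveraging(use_buffers=False), max_epochs=10)
    trainer.fit(model, train_dataloader)

    # After fit(), the callback has copied the averaged parameters into the model, so the
    # batch-norm statistics can be recomputed with one pass over the training data.
    update_bn(train_dataloader, model)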