
import itertools
from copy import deepcopy
- from typing import Any, Callable, Optional, Union
+ from typing import Any, Optional, Union

import torch
- from torch import Tensor
from torch.optim.swa_utils import AveragedModel
+ from typing_extensions import override

import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.types import STEP_OUTPUT


- def _return_true(x: int) -> bool:
-     return True
-
-
- def _return_false(x: int) -> bool:
-     return False
-
-
class WeightAveraging(Callback):
    r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
    (EMA) after each training step.

-     The user should provide either `update_on_step` or `update_on_epoch`, a function that determines when the average
-     model should be updated. If neither function is provided, the average model will be updated after every optimizer
-     step.
+     Arguments given to the constructor will be passed to the :class:`AveragedModel` constructor. There are a couple of
+     differences to the default values, however. By default, the average model is stored on the CPU. If ``device`` is set
+     to ``None``, the device will be inferred from the original model. By default, the callback will compute running
+     averages for both the parameters and the buffers of the model. Setting ``use_buffers`` to ``False`` will cause only
+     the model parameters to be averaged, leaving updating the batch normalization statistics to the user (using
+     ``torch.optim.swa_utils.update_bn()``).
+
+     You can provide a custom averaging function with the ``avg_fn`` or ``multi_avg_fn`` parameter. See the
+     :class:`AveragedModel` class for details. If no averaging function is provided, the default is to compute the
+     equally-weighted average of the weights (SWA).
+
+     You can customize when the average model is updated by overriding the ``should_update()`` method. The callback calls
+     it with either ``step_idx`` or ``epoch_idx`` and the method returns a boolean indicating whether to update after the
+     given step or epoch. The default is to update after every step.

    During validation and after the training finishes, the current model parameters will be replaced with the averaged
    values.

+     Example::
+
+         from lightning.pytorch.callbacks import WeightAveraging
+         from torch.optim.swa_utils import get_ema_avg_fn
+
+         class EMAWeightAveraging(WeightAveraging):
+             def __init__(self):
+                 super().__init__(avg_fn=get_ema_avg_fn())
+
+             def should_update(self, step_idx=None, epoch_idx=None):
+                 # Start after 100 steps.
+                 return (step_idx is not None) and (step_idx >= 100)
+
+         trainer = Trainer(callbacks=EMAWeightAveraging(), max_epochs=10)
+         trainer.fit(model, dataloader)
+
    Args:
        device: If provided, the :class:`AveragedModel` will be stored on the ``device``. If ``None`` the device will be
            inferred from the original model.
-         avg_fn: The averaging function used to update the parameters. The function must take in an
-             :class:`AveragedModel` parameter, a current model parameter, and the number of models already averaged. If
-             ``None``, an equally weighted average will be used.
-         update_on_step: A function that takes the number of optimizer steps taken, and returns ``True`` if the average
-             model should be updated.
-         update_on_epoch: A function that takes the zero-based epoch number, and returns ``True`` if the average model
-             should be updated.
+         use_buffers: If ``False``, the buffers of the model will not be averaged.
+         kwargs: Additional keyword arguments to be passed to the :class:`AveragedModel` constructor, such as ``avg_fn``
+             or ``multi_avg_fn``.

    """

    def __init__(
        self,
-         device: Optional[Union[torch.device, int]] = torch.device("cpu"),
-         avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None,
-         update_on_step: Optional[Callable[[int], bool]] = None,
-         update_on_epoch: Optional[Callable[[int], bool]] = None,
-     ):
-         self._device = device
-         self._avg_fn = avg_fn
-
-         if (update_on_step is None) and (update_on_epoch is None):
-             self._update_on_step: Callable[[int], bool] = _return_true
-             self._update_on_epoch: Callable[[int], bool] = _return_false
+         device: Optional[Union[torch.device, str, int]] = "cpu",
+         use_buffers: bool = True,
+         **kwargs: Any,
+     ) -> None:
+         # The default value is a string so that jsonargparse knows how to serialize it.
+         if isinstance(device, str):
+             self._device: Optional[Union[torch.device, int]] = torch.device(device)
        else:
-             self._update_on_step = _return_false if update_on_step is None else update_on_step
-             self._update_on_epoch = _return_false if update_on_epoch is None else update_on_epoch
+             self._device = device
+         self._use_buffers = use_buffers
+         self._kwargs = kwargs

        self._average_model: Optional[AveragedModel] = None

        # Number of optimizer steps taken, when the average model was last updated. Initializing this with zero ensures
-         # that the average model will be first updated after the first optimizer step, which takes place after N batches
-         # when using accumulate_grad_batches=N.
+         # that self.should_update() will be first called after the first optimizer step, which takes place after N
+         # batches when using accumulate_grad_batches=N.
        self._latest_update_step = 0
        # The epoch after which the average model was last updated. The first epoch is 0, so initializing this to a
-         # negative value means that if update_on_step(0) returns True, the first update is after the first epoch.
+         # negative value means that if self.should_update(epoch_idx=0) returns True, the first update is after the first
+         # epoch.
        self._latest_update_epoch = -1

+     def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
+         """Called after every optimizer step and after every training epoch to check whether the average model should
+         be updated.
+
+         One of the arguments is set to the zero-based index of the last training step or epoch. The default
+         implementation returns ``True`` when any ``step_idx`` is provided. The user can customize when the average model
+         gets updated by overriding this method.
+
+         Args:
+             step_idx: Index of the last optimizer step, or ``None`` when called at the epoch end.
+             epoch_idx: Index of the last epoch, or ``None`` when called after an optimizer step.
+
+         Returns:
+             ``True`` if the average model should be updated and ``False`` if not.
+
+         """
+         return step_idx is not None
+
+     @override
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """Called when fit, validate, test, predict, or tune begins.

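Note: the hunk above replaces the old ``update_on_step``/``update_on_epoch`` callables with an overridable ``should_update()`` hook. As a hedged illustration that is not part of this commit, a subclass could use the ``epoch_idx`` argument instead of ``step_idx`` to update the average only at epoch boundaries after a warm-up. The class name ``EpochSWA`` and the ``start_epoch`` parameter below are made up for the example:

    from typing import Optional

    from lightning.pytorch.callbacks import WeightAveraging


    class EpochSWA(WeightAveraging):
        """Sketch: equally weighted (SWA) averaging, updated once per epoch after a warm-up."""

        def __init__(self, start_epoch: int = 5):
            # Keep the default averaging function and CPU storage from the base class.
            super().__init__(device="cpu")
            self.start_epoch = start_epoch

        def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
            # Ignore the per-step calls; update only at the end of each epoch once warmed up.
            return (epoch_idx is not None) and (epoch_idx >= self.start_epoch)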
@@ -102,14 +134,17 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
        """
        if stage == "fit":
            device = self._device or pl_module.device
-             self._average_model = AveragedModel(model=pl_module, device=device, avg_fn=self._avg_fn, use_buffers=True)
+             self._average_model = AveragedModel(
+                 model=pl_module, device=device, use_buffers=self._use_buffers, **self._kwargs
+             )

+     @override
    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Called when a training batch ends.

-         Updates the :class:`AveragedModel` parameters, if requested by ``update_on_step()``.
+         Updates the :class:`AveragedModel` parameters, if requested by ``self.should_update()``.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
@@ -119,26 +154,31 @@ def on_train_batch_end(
            batch_idx: Index of the training batch.

        """
-         if self._update_on_step(trainer.global_step) and (trainer.global_step > self._latest_update_step):
+         # trainer.global_step is the number of optimizer steps taken so far, i.e. 1 after the first optimizer step. To
+         # make step_idx consistent with epoch_idx, we'll pass a zero-based index.
+         step_idx = trainer.global_step - 1
+         if (trainer.global_step > self._latest_update_step) and self.should_update(step_idx=step_idx):
            assert self._average_model is not None
            self._average_model.update_parameters(pl_module)
            self._latest_update_step = trainer.global_step

+     @override
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when a training epoch ends.

-         Updates the :class:`AveragedModel` parameters, if requested by ``update_on_epoch()``.
+         Updates the :class:`AveragedModel` parameters, if requested by ``self.should_update()``.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

        """
-         if self._update_on_epoch(trainer.current_epoch) and (trainer.current_epoch > self._latest_update_epoch):
+         if (trainer.current_epoch > self._latest_update_epoch) and self.should_update(epoch_idx=trainer.current_epoch):
            assert self._average_model is not None
            self._average_model.update_parameters(pl_module)
            self._latest_update_epoch = trainer.current_epoch

+     @override
    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when training ends.

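Note: as the comment added above explains, ``trainer.global_step`` counts completed optimizer steps, so the callback passes a zero-based ``step_idx`` and uses ``_latest_update_step`` to avoid updating twice for the same step. A rough sketch of what that means with gradient accumulation; ``EMAWeightAveraging`` is the subclass from the docstring example, and ``model`` and ``dataloader`` are placeholders:

    from lightning.pytorch import Trainer

    # With accumulate_grad_batches=4, the first optimizer step happens after batch 4,
    # so the callback first sees should_update(step_idx=0) there, then step_idx=1 after
    # batch 8, and so on. Batches that do not complete an optimizer step never trigger
    # an update, because trainer.global_step has not advanced.
    trainer = Trainer(
        callbacks=EMAWeightAveraging(),
        accumulate_grad_batches=4,
        max_epochs=10,
    )
    trainer.fit(model, dataloader)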
@@ -150,8 +190,10 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -

        """
        assert self._average_model is not None
+         rank_zero_info("Loading the average model parameters to the final model.")
        self._copy_average_to_current(pl_module)

+     @override
    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when a validation epoch begins.

@@ -166,6 +208,7 @@ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn
        rank_zero_info("Loading the average model parameters for validation.")
        self._swap_models(pl_module)

+     @override
    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when a validation epoch ends.

@@ -180,6 +223,7 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin
        rank_zero_info("Recovering the current model parameters after validation.")
        self._swap_models(pl_module)

+     @override
    def state_dict(self) -> dict[str, Any]:
        """Called when saving a checkpoint.

@@ -191,6 +235,7 @@ def state_dict(self) -> dict[str, Any]:
        """
        return {"latest_update_step": self._latest_update_step}

+     @override
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Called when loading a checkpoint.

@@ -202,6 +247,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """
        self._latest_update_step = state_dict["latest_update_step"]

+     @override
    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any]
    ) -> None:
@@ -218,18 +264,23 @@ def on_save_checkpoint(

        """
        if self._average_model is None:
-             raise Exception("Trying to save a checkpoint, but no average model (outside fit). Don't know what to do.")
-
-         rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.")
-         average_model_state = self._average_model.state_dict()
-         checkpoint["current_model_state"] = checkpoint["state_dict"]
-         checkpoint["state_dict"] = {
-             name[7:]: value for name, value in average_model_state.items() if name.startswith("module.")
-         }
-         checkpoint["averaging_state"] = {
-             name: value for name, value in average_model_state.items() if not name.startswith("module.")
-         }
-
+             rank_zero_info(
+                 "You're using the WeightAveraging callback, but saving a checkpoint outside the 'fit' stage. The state "
+                 "of the WeightAveraging callback won't be saved in the checkpoint. If training has finished, the "
+                 "average model parameters will be saved to the state_dict in the checkpoint."
+             )
+         else:
+             rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.")
+             average_model_state = self._average_model.state_dict()
+             checkpoint["current_model_state"] = checkpoint["state_dict"]
+             checkpoint["state_dict"] = {
+                 name[7:]: value for name, value in average_model_state.items() if name.startswith("module.")
+             }
+             checkpoint["averaging_state"] = {
+                 name: value for name, value in average_model_state.items() if not name.startswith("module.")
+             }
+
+     @override
    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any]
    ) -> None:
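Note: with the rewritten ``on_save_checkpoint()`` above, a checkpoint saved during or after fitting stores the averaged parameters under ``state_dict`` (with the ``AveragedModel`` wrapper's ``module.`` prefix stripped), keeps the un-averaged parameters under ``current_model_state``, and puts the remaining averaging bookkeeping (such as ``n_averaged``) under ``averaging_state``. A hedged sketch of what that implies downstream, assuming ``MyModule`` is your own LightningModule and the checkpoint path is illustrative:

    # Standard Lightning loading reads "state_dict", i.e. the averaged weights.
    model = MyModule.load_from_checkpoint("checkpoints/last.ckpt")

    # The un-averaged weights remain available under checkpoint["current_model_state"],
    # and checkpoint["averaging_state"] holds the AveragedModel bookkeeping (e.g. n_averaged).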
@@ -244,9 +295,12 @@ def on_load_checkpoint(

        """
        if self._average_model is None:
-             raise Exception("Trying to load a checkpoint, but no average model (outside fit). Don't know what to do.")
-
-         if ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint):
+             rank_zero_warn(
+                 "You're using the WeightAveraging callback, but loading a checkpoint outside the 'fit' stage. The "
+                 "WeightAveraging state cannot be restored. If you're using the checkpoint for prediction or testing, "
+                 "you can ignore this warning. To disable the warning, remove the WeightAveraging callback."
+             )
+         elif ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint):
            rank_zero_info("Found current_model_state in the checkpoint. This will be used to initialize the model.")
            average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
            average_model_state |= checkpoint["averaging_state"]
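Note: the final hunk mirrors the save path. When resuming ``fit``, ``on_load_checkpoint()`` adds the ``module.`` prefix back to the saved ``state_dict`` entries and merges them with ``averaging_state`` to rebuild the ``AveragedModel`` state. A minimal, self-contained sketch of that round trip with toy values, no Lightning required:

    # AveragedModel wraps the LightningModule as "module", so its state_dict holds
    # "module.<param>" entries plus bookkeeping keys such as "n_averaged".
    average_model_state = {"module.layer.weight": 0.5, "n_averaged": 3}

    # Saving: strip the "module." prefix for "state_dict", keep the rest separately.
    state_dict = {k[7:]: v for k, v in average_model_state.items() if k.startswith("module.")}
    averaging_state = {k: v for k, v in average_model_state.items() if not k.startswith("module.")}

    # Loading: add the prefix back and merge to restore the original mapping.
    restored = {"module." + k: v for k, v in state_dict.items()} | averaging_state
    assert restored == average_model_state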