Commit 5deb0bb

Author: Seppo Enarvi
Fixed checkpoint loading with WeightAveraging
1 parent: 822231f

2 files changed (+5, -2 lines)


Diff for: src/lightning/pytorch/callbacks/weight_averaging.py (+4, -1)

@@ -304,7 +304,10 @@ def on_load_checkpoint(
             average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
             average_model_state |= checkpoint["averaging_state"]
             self._average_model.load_state_dict(average_model_state)
-            checkpoint["state_dict"] = checkpoint["current_model_state"]
+            # The current model state has already been loaded from "state_dict" (which contains the average model
+            # weights) at this point, so overwriting "state_dict" in the checkpoint dictionary makes no difference. We
+            # have to reload the model state from "current_model_state".
+            pl_module.load_state_dict(checkpoint["current_model_state"])
         else:
             rank_zero_warn(
                 "The checkpoint was not created with WeightAveraging. Both the current and the average model will be "

Diff for: tests/tests_pytorch/callbacks/test_weight_averaging.py (+1, -1)

@@ -235,7 +235,7 @@ def test_ema_resume(tmp_path, crash_on_epoch):
     model2 = _train_and_resume(model2, dataset, tmp_path)
 
     for param1, param2 in zip(model1.parameters(), model2.parameters()):
-        assert torch.allclose(param1, param2, atol=0.001)
+        assert torch.allclose(param1, param2)
 
 
 @RunIf(skip_windows=True)
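Dropping the explicit tolerance means torch.allclose falls back to its defaults (rtol=1e-5, atol=1e-8), so the resumed run is now expected to reproduce the parameters essentially exactly rather than merely to within 1e-3. A small standalone illustration of the tightened check (the tensor values are made up for the example):

    import torch

    a = torch.tensor([1.0000, 2.0000])
    b = torch.tensor([1.0005, 2.0005])

    # The old explicit tolerance accepted a drift of roughly 1e-3...
    assert torch.allclose(a, b, atol=0.001)
    # ...whereas the defaults (rtol=1e-5, atol=1e-8) reject it, so the test now
    # requires the resumed parameters to match almost exactly.
    assert not torch.allclose(a, b)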
