Skip to content

Commit 38fc772

Browse files
committed
Allow non-strict weight loading
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 077020d commit 38fc772

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/speculators/train/checkpointer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ def load_model_state_dict(self, model: PreTrainedModel):
8484
full_state_dict = load_safetensors_state_dict(
8585
self.model_path(self.previous_epoch), "cuda:0"
8686
)
87-
model.load_state_dict(full_state_dict)
87+
# Note: `strict=False` because we don't load the verifier weights
88+
model.load_state_dict(full_state_dict, strict=False)
8889

8990
def load_optimizer_state_dict(
9091
self,
@@ -110,10 +111,13 @@ def load_model_state_dict(self, model: PreTrainedModel):
110111
full_state_dict = load_safetensors_state_dict(
111112
self.model_path(self.previous_epoch), "cpu"
112113
)
114+
# Note: `strict=False` because we don't load the verifier weights
113115
set_model_state_dict(
114116
model,
115117
full_state_dict, # type: ignore[arg-type]
116-
options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
118+
options=StateDictOptions(
119+
full_state_dict=True, broadcast_from_rank0=True, strict=False
120+
),
117121
)
118122
dist.barrier()
119123

0 commit comments

Comments
 (0)