1 file changed: +6 -2 lines changed
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,8 @@ def load_model_state_dict(self, model: PreTrainedModel):
8484 full_state_dict = load_safetensors_state_dict (
8585 self .model_path (self .previous_epoch ), "cuda:0"
8686 )
87- model .load_state_dict (full_state_dict )
87+ # Note: `strict=False` because we don't load the verifier weights
88+ model .load_state_dict (full_state_dict , strict = False )
8889
8990 def load_optimizer_state_dict (
9091 self ,
@@ -110,10 +111,13 @@ def load_model_state_dict(self, model: PreTrainedModel):
110111 full_state_dict = load_safetensors_state_dict (
111112 self .model_path (self .previous_epoch ), "cpu"
112113 )
114+ # Note: `strict=False` because we don't load the verifier weights
113115 set_model_state_dict (
114116 model ,
115117 full_state_dict , # type: ignore[arg-type]
116- options = StateDictOptions (full_state_dict = True , broadcast_from_rank0 = True ),
118+ options = StateDictOptions (
119+ full_state_dict = True , broadcast_from_rank0 = True , strict = False
120+ ),
117121 )
118122 dist .barrier ()
119123
You can’t perform that action at this time.
0 commit comments