diff --git a/cosmos_predict2/checkpointer.py b/cosmos_predict2/checkpointer.py index fa416996..cc9b628a 100644 --- a/cosmos_predict2/checkpointer.py +++ b/cosmos_predict2/checkpointer.py @@ -237,6 +237,7 @@ def load( strict=False if model.config.train_architecture == "lora" else True, ), ) + del state_dicts_to_load_for_dit_reg # Load EMA weights. if model.pipe.config.ema.enabled: set_model_state_dict( @@ -248,6 +249,7 @@ def load( strict=False if model.config.train_architecture == "lora" else True, ), ) + del state_dicts_to_load_for_dit_ema # Restore the attention operators. model.pipe.apply_cp()