Incompatibility in model checkpoints between multi-GPU/single-GPU #1087
Labels: `bug` (Something isn't working), `high-priority` (Fix this before addressing any other major issue), `jiant-v1-legacy` (Relevant to versions <= v1.3.2)
Summary: If you train a model with multiple GPUs and then try to load the checkpoint on a single GPU, the weights will not be loaded successfully. The run prints warnings but does not fail.
Background:

- When a model is wrapped in `DataParallel`, the original module is stored as `multi_gpu_model.module`.
- Calling `.state_dict()` on `multi_gpu_model` adds a `"module."` prefix to every key in the `state_dict`.
- As a result, `state_dict`s are not compatible between single-GPU and multi-GPU models.

What this affects:

- Checkpoint loading between phases (e.g., `load_target_train_checkpoint`) will incorrectly load the model weights, silently reverting to the pretrained weights, if one phase uses multi-GPU and the other uses single-GPU.
Proposed solution: Avoid calling `state_dict` directly on a `DataParallel` model. Write a function that determines whether a model is `DataParallel`, and call `model.state_dict` or `model.module.state_dict` accordingly.