Issue by zphang
Friday May 15, 2020 at 06:16 GMT
Originally opened as nyu-mll/jiant#1087
Summary: If you train a model with multi-GPU and then try to load it with single-GPU, the weights will not be loaded successfully. The run will print warnings but will not fail.
Background:
When a PyTorch module is wrapped in `DataParallel`, the original module is stored as `multi_gpu_model.module`.
This means that calling `state_dict` on `multi_gpu_model` will add a `"module."` prefix to every key in the `state_dict`.
In other words, `state_dict`s are not compatible between single-GPU and multi-GPU models.
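For illustration, a minimal sketch of the key mismatch using a toy module (not jiant code; wrapping in `DataParallel` here is only to show the prefixed keys):

```python
import torch.nn as nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)

print(list(model.state_dict().keys()))    # ['weight', 'bias']
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']
```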
What this affects:
Two-phase training runs launched as two separate commands (via `load_target_train_checkpoint`) will load the model weights incorrectly, reverting to the pretrained weights, if one phase uses multi-GPU and the other uses single-GPU.
E.g., if intermediate-task training runs on multi-GPU and target-task training runs on single-GPU, the target-task run will fail to load the intermediate-trained weights and revert to the pretrained weights. The logs will be filled with "parameter missed" warnings, but the run will not fail.
Proposed solution:
Don't call `state_dict` on a `DataParallel` model. Write a function that determines whether a model is `DataParallel`, and call `model.state_dict` or `model.module.state_dict` accordingly.
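A minimal sketch of such a helper, applied at save time (the name `unwrap_state_dict` is hypothetical, not an existing jiant function):

```python
import torch.nn as nn

def unwrap_state_dict(model: nn.Module) -> dict:
    """Return a state_dict with un-prefixed keys, whether or not the model
    is wrapped in DataParallel (hypothetical helper, not jiant API)."""
    if isinstance(model, nn.DataParallel):
        # Unwrap so keys are saved without the "module." prefix.
        return model.module.state_dict()
    return model.state_dict()
```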
Comment by pruksmhc Sunday May 17, 2020 at 16:32 GMT
Update: It's actually not quite as simple as applying the above function at save time, since modifying saving alone will not work if we want to restart a job on multi-GPU (because the checkpoint will then always contain un-prefixed `model.state_dict()` keys, while a `DataParallel` model expects the `module.`-prefixed keys produced by its own `state_dict()`).
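One possible direction, sketched here with a hypothetical `load_state_dict_flexibly` helper (not part of jiant), is to normalize the prefix at load time instead, so checkpoints saved from either setup can be restarted on either setup:

```python
import torch.nn as nn

def load_state_dict_flexibly(model: nn.Module, state_dict: dict) -> None:
    """Load a checkpoint into either a plain or DataParallel-wrapped model,
    adding or stripping the "module." prefix as needed (hypothetical helper)."""
    is_wrapped = isinstance(model, nn.DataParallel)
    has_prefix = any(k.startswith("module.") for k in state_dict)
    if is_wrapped and not has_prefix:
        state_dict = {"module." + k: v for k, v in state_dict.items()}
    elif not is_wrapped and has_prefix:
        state_dict = {k[len("module."):]: v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
```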