Closed
Description
Right now, PyTorch checkpoint loading just uses `torch.load(${file})`.
In some cases, such as when the checkpoint was saved with GPU tensors, we probably need to add `map_location=torch.device("cpu")`. Similarly, there was a checkpoint where a `.detach()` call was needed before the `.numpy()` call.
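A minimal sketch of both fixes, assuming the goal is a CPU-backed NumPy array. The helpers `load_checkpoint_to_cpu` and `tensor_to_numpy` are hypothetical names for illustration, and an in-memory buffer stands in for the real checkpoint file so the example is self-contained:

```python
import io

import torch


def load_checkpoint_to_cpu(path_or_buffer):
    """Hypothetical helper: load a checkpoint, remapping any GPU tensors to the CPU.

    Without map_location, a checkpoint saved from CUDA tensors fails to load
    on a machine that has no GPU.
    """
    return torch.load(path_or_buffer, map_location=torch.device("cpu"))


def tensor_to_numpy(t):
    """Hypothetical helper: convert a tensor to a NumPy array.

    .detach() drops the autograd graph (calling .numpy() on a tensor that
    requires grad raises a RuntimeError); .cpu() is a no-op for CPU tensors.
    """
    return t.detach().cpu().numpy()


# Usage: round-trip a small checkpoint through an in-memory buffer.
weight = torch.ones(2, 3, requires_grad=True)
buffer = io.BytesIO()
torch.save({"weight": weight}, buffer)
buffer.seek(0)

checkpoint = load_checkpoint_to_cpu(buffer)
array = tensor_to_numpy(checkpoint["weight"])
print(type(array).__name__, array.shape)  # ndarray (2, 3)
```

Chaining `.detach().cpu()` before `.numpy()` is harmless for tensors that are already on the CPU or don't require grad, so it can be applied unconditionally.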
I think we just need to add steps like these to make sure we can always get to CPU tensors and then NumPy arrays, unless there are cases where we would want things to stay on the GPU?
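If the answer is that everything should end up as NumPy arrays, one option is to convert the whole loaded checkpoint in a single pass. This is a sketch under that assumption; `checkpoint_to_numpy` is a hypothetical helper, and it only walks dicts, lists and tuples (other containers would need extra cases):

```python
import io

import torch


def checkpoint_to_numpy(obj):
    """Hypothetical helper: recursively replace every tensor with a NumPy array.

    Dicts (including OrderedDict state dicts) come back as plain dicts, and
    tuples (including namedtuples) come back as plain tuples. Non-tensor
    values such as ints and strings are returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        # Detach from autograd and move to the CPU before calling .numpy().
        return obj.detach().cpu().numpy()
    if isinstance(obj, dict):
        return {key: checkpoint_to_numpy(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [checkpoint_to_numpy(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(checkpoint_to_numpy(value) for value in obj)
    return obj


# Usage: a nested checkpoint with metadata, weights and a list of tensors.
state = {
    "epoch": 3,
    "model": {"linear.weight": torch.zeros(2, 2, requires_grad=True)},
    "history": [torch.tensor(0.5), torch.tensor(0.25)],
}
buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)

loaded = torch.load(buffer, map_location=torch.device("cpu"))
converted = checkpoint_to_numpy(loaded)
print(converted["epoch"], converted["model"]["linear.weight"].shape)  # 3 (2, 2)
```

Keeping the conversion separate from `torch.load` leaves room to skip it if some use case does want tensors to stay on the GPU.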