PyTorch Checkpoint reading #230

Closed
blester125 opened this issue Feb 19, 2024 · 0 comments · Fixed by #234
Comments

@blester125 (Collaborator)

Right now, PyTorch checkpoint loading just calls `torch.load(${file})`.

There are some cases, such as when the checkpoint was saved with GPU tensors, where we probably need to pass `map_location=torch.device("cpu")`. Similarly, there was a checkpoint where a `.detach()` call was needed before the `.numpy()` call.

I think we just need to add handling like this so we can always get to CPU tensors and then NumPy arrays, unless there are cases where we would want things to live on the GPU?
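
A minimal sketch of what such a loader could look like, assuming a flat state dict of tensors (the function name `load_checkpoint_as_numpy` is illustrative, not part of the existing code):

```python
import torch


def load_checkpoint_as_numpy(path):
    """Load a PyTorch checkpoint and convert every tensor to a NumPy array."""
    # map_location moves tensors that were saved on a GPU onto the CPU,
    # so loading works even on machines without CUDA.
    state_dict = torch.load(path, map_location=torch.device("cpu"))
    # .detach() drops any autograd graph and .cpu() is a no-op after the
    # map_location remap, so .numpy() is always safe to call.
    return {name: t.detach().cpu().numpy() for name, t in state_dict.items()}
```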

@blester125 linked a pull request on Mar 26, 2024 that will close this issue.