
Pytorch Checkpoint reading #230

Closed
@blester125

Description


Right now, PyTorch checkpoint loading just calls `torch.load(${file})`.

In some cases, such as when the checkpoint was saved with GPU tensors, we probably need to add `map_location=torch.device("cpu")`. Similarly, one checkpoint needed a `.detach()` call before the `.numpy()` call.
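A minimal sketch of both fixes, assuming the checkpoint is a file that `torch.load` can read. The helper names `load_checkpoint_cpu` and `tensor_to_numpy` are hypothetical and not part of the current loader.

```python
import torch


def load_checkpoint_cpu(path):
    # map_location remaps any CUDA storages onto the CPU, so a checkpoint
    # saved from GPU tensors can still be loaded on a machine without a GPU.
    return torch.load(path, map_location=torch.device("cpu"))


def tensor_to_numpy(t):
    # .detach() drops the autograd graph: calling .numpy() on a tensor that
    # requires grad raises a RuntimeError. .cpu() is a no-op for CPU tensors.
    return t.detach().cpu().numpy()
```

Usage: `arr = tensor_to_numpy(load_checkpoint_cpu(path)["weight"])`, where `"weight"` stands in for whatever key the checkpoint actually uses.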

I think we just need to add steps like these so that we can always get CPU tensors and then NumPy arrays. Or are there cases where we would want things to stay on the GPU?
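One way to "always get to NumPy" is to walk the loaded checkpoint and convert every tensor it contains. This is a sketch under the assumption that checkpoints are built from plain dicts, lists, and tuples of tensors; `to_numpy_tree` is a hypothetical name, and other objects (ints, strings, optimizer settings) are passed through unchanged.

```python
import numpy as np
import torch


def to_numpy_tree(obj):
    """Recursively convert every tensor in a nested checkpoint to a NumPy array."""
    if isinstance(obj, torch.Tensor):
        # Detach from autograd and move to CPU before converting.
        return obj.detach().cpu().numpy()
    if isinstance(obj, dict):
        # Covers OrderedDict state dicts too; the result is a plain dict
        # that keeps the original key order.
        return {k: to_numpy_tree(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Assumes plain lists/tuples (namedtuples would need special handling).
        return type(obj)(to_numpy_tree(v) for v in obj)
    # Non-tensor leaves are returned as-is.
    return obj
```

Converting at load time keeps everything on the CPU; if some callers do want GPU tensors, they could skip this step and pass their own `map_location` instead.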

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
