Skip to content

Commit

Permalink
generate.py: allow loading checkpoint across devices
Browse files (browse the repository at this point in the history)
  • Loading branch information
zhoupingjay committed Jan 30, 2024
1 parent 8634cd7 commit cd45dc4
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import time
import streamlit as st
from torch.serialization import MAP_LOCATION

from model import SanGuoGPTModel
from sanguo_data import encoder, decoder, load_token_map
Expand All @@ -20,14 +21,15 @@
args = parser.parse_args()

print(f"Loading model from {args.resume_from}")
checkpoint = torch.load(args.resume_from)
model_args = checkpoint['model_args']
# Pick the compute device: CUDA when available, otherwise Apple MPS,
# otherwise fall back to the CPU.
device = (
"cuda" if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available()
else "cpu"
)

# map_location=device remaps the tensors stored in the checkpoint onto the
# device selected above, so a checkpoint saved on one device type can be
# loaded on another.
# NOTE(review): depending on the PyTorch version / weights_only setting,
# torch.load may unpickle arbitrary objects -- only load checkpoint files
# from a trusted source.
checkpoint = torch.load(args.resume_from, map_location=device)
# Constructor arguments saved alongside the weights in the checkpoint.
model_args = checkpoint['model_args']

model = SanGuoGPTModel(vocab_size=model_args['vocab_size'],
d_model=model_args['d_model'],
n_layer=model_args['n_layer'],
Expand All @@ -36,10 +38,11 @@
n_head=model_args['n_head'],
device=device
)
model.load_state_dict(checkpoint['model'])
model = model.to(device)

if args.compile:
model = torch.compile(model) # requires PyTorch 2.0
model.load_state_dict(checkpoint['model'])

c2i, i2c = load_token_map(c2i_file=args.c2i, i2c_file=args.i2c)
print(f"Loading token map file from {args.c2i} and {args.i2c}")
Expand Down

0 comments on commit cd45dc4

Please sign in to comment.