Skip to content

Commit 6e61fc6

Browse files
authored
Merge pull request sgrvinod#93 from ngshya/patch-1
added map_location=str(device) in torch.load()
2 parents 6b0f830 + 094f0b9 commit 6e61fc6

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

caption.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def visualize_att(image_path, seq, alphas, rev_word_map, smooth=True):
197197
args = parser.parse_args()
198198

199199
# Load model
200-
checkpoint = torch.load(args.model)
200+
checkpoint = torch.load(args.model, map_location=str(device))
201201
decoder = checkpoint['decoder']
202202
decoder = decoder.to(device)
203203
decoder.eval()

0 commit comments

Comments
 (0)