Skip to content

Commit f973e72

Browse files
authored
add cpu support via map_location in test.py
1 parent 01e4689 commit f973e72

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

test.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818

1919
def main(config, out_file):
2020
logger = config.get_logger("test")
21-
21+
22+
# define cpu or gpu if possible
23+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24+
2225
# text_encoder
2326
text_encoder = CTCCharTextEncoder.get_simple_alphabet()
2427

@@ -30,14 +33,13 @@ def main(config, out_file):
3033
logger.info(model)
3134

3235
logger.info("Loading checkpoint: {} ...".format(config.resume))
33-
checkpoint = torch.load(config.resume)
36+
checkpoint = torch.load(config.resume, map_location=device)
3437
state_dict = checkpoint["state_dict"]
3538
if config["n_gpu"] > 1:
3639
model = torch.nn.DataParallel(model)
3740
model.load_state_dict(state_dict)
3841

3942
# prepare model for testing
40-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4143
model = model.to(device)
4244
model.eval()
4345

0 commit comments

Comments
 (0)