File tree 1 file changed +5
-3
lines changed
1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change 18
18
19
19
def main (config , out_file ):
20
20
logger = config .get_logger ("test" )
21
-
21
+
22
+ # define cpu or gpu if possible
23
+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
24
+
22
25
# text_encoder
23
26
text_encoder = CTCCharTextEncoder .get_simple_alphabet ()
24
27
@@ -30,14 +33,13 @@ def main(config, out_file):
30
33
logger .info (model )
31
34
32
35
logger .info ("Loading checkpoint: {} ..." .format (config .resume ))
33
- checkpoint = torch .load (config .resume )
36
+ checkpoint = torch .load (config .resume , map_location = device )
34
37
state_dict = checkpoint ["state_dict" ]
35
38
if config ["n_gpu" ] > 1 :
36
39
model = torch .nn .DataParallel (model )
37
40
model .load_state_dict (state_dict )
38
41
39
42
# prepare model for testing
40
- device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
41
43
model = model .to (device )
42
44
model .eval ()
43
45
You can’t perform that action at this time.
0 commit comments