19
19
from parameterized import parameterized
20
20
21
21
from monai .bundle import ConfigParser
22
- from monai .bundle .scripts import ckpt_export
23
22
from monai .data import load_net_with_metadata
24
23
from monai .networks import save_state
25
- from tests .utils import skip_if_windows
24
+ from tests .utils import command_line_tests , skip_if_windows
26
25
27
26
TEST_CASE_1 = ["" , "" ]
28
27
@@ -52,6 +51,8 @@ def setUp(self):
52
51
self .parser .export_config_file (config = self .def_args , filepath = self .def_args_file )
53
52
self .parser .read_config (self .config_file )
54
53
self .net = self .parser .get_parsed_content ("network_def" )
54
+ self .cmd = ["coverage" , "run" , "-m" , "monai.bundle" , "ckpt_export" , "network_def" , "--filepath" , self .ts_file ]
55
+ self .cmd += ["--meta_file" , self .meta_file , "--config_file" , f"['{ self .config_file } ','{ self .def_args_file } ']" , "--ckpt_file" ]
55
56
56
57
def tearDown (self ):
57
58
if self .device is not None :
@@ -61,46 +62,35 @@ def tearDown(self):
61
62
self .tempdir_obj .cleanup ()
62
63
63
64
@parameterized .expand ([TEST_CASE_1 , TEST_CASE_2 , TEST_CASE_3 ])
64
- def test_ckpt_export_default (self , key_in_ckpt , use_trace ):
65
+ def test_export (self , key_in_ckpt , use_trace ):
66
+ save_state (src = self .net if key_in_ckpt == "" else {key_in_ckpt : self .net }, path = self .ckpt_file )
67
+ full_cmd = self .cmd + [self .ckpt_file , "--key_in_ckpt" , key_in_ckpt , "--args_file" , self .def_args_file ]
68
+ if use_trace == "True" :
69
+ full_cmd += ["--use_trace" , use_trace , "--input_shape" , "[1, 1, 96, 96, 96]" ]
70
+ command_line_tests (full_cmd )
71
+ self .assertTrue (os .path .exists (self .ts_file ))
72
+
73
+ _ , metadata , extra_files = load_net_with_metadata (
74
+ self .ts_file , more_extra_files = ["inference.json" , "def_args.json" ]
75
+ )
76
+ self .assertIn ("schema" , metadata )
77
+ self .assertIn ("meta_file" , json .loads (extra_files ["def_args.json" ]))
78
+ self .assertIn ("network_def" , json .loads (extra_files ["inference.json" ]))
79
+
80
+ @parameterized .expand ([TEST_CASE_1 , TEST_CASE_2 , TEST_CASE_3 ])
81
+ def test_default_value (self , key_in_ckpt , use_trace ):
65
82
ckpt_file = os .path .join (self .tempdir_obj .name , "models/model.pt" )
66
83
ts_file = os .path .join (self .tempdir_obj .name , "models/model.ts" )
67
84
68
85
save_state (src = self .net if key_in_ckpt == "" else {key_in_ckpt : self .net }, path = ckpt_file )
69
- ckpt_export (
70
- net_id = "network_def" ,
71
- filepath = ts_file ,
72
- meta_file = self .meta_file ,
73
- config_file = self .config_file ,
74
- ckpt_file = ckpt_file ,
75
- key_in_ckpt = key_in_ckpt ,
76
- args_file = self .def_args_file ,
77
- use_trace = use_trace ,
78
- input_shape = [1 , 1 , 96 , 96 , 96 ] if use_trace == "True" else None ,
79
- )
80
- self .assertTrue (os .path .exists (ts_file ))
81
86
82
- @parameterized .expand ([TEST_CASE_1 , TEST_CASE_2 , TEST_CASE_3 ])
83
- def test_ckpt_export (self , key_in_ckpt , use_trace ):
84
- save_state (src = self .net if key_in_ckpt == "" else {key_in_ckpt : self .net }, path = self .ckpt_file )
85
- ckpt_export (
86
- net_id = "network_def" ,
87
- filepath = self .ts_file ,
88
- meta_file = self .meta_file ,
89
- config_file = [self .config_file , self .def_args_file ],
90
- ckpt_file = self .ckpt_file ,
91
- key_in_ckpt = key_in_ckpt ,
92
- args_file = self .def_args_file ,
93
- use_trace = use_trace ,
94
- input_shape = [1 , 1 , 96 , 96 , 96 ] if use_trace == "True" else None ,
95
- )
96
- self .assertTrue (os .path .exists (self .ts_file ))
97
-
98
- _ , metadata , extra_files = load_net_with_metadata (
99
- self .ts_file , more_extra_files = ["inference.json" , "def_args.json" ]
100
- )
101
- self .assertIn ("schema" , metadata )
102
- self .assertIn ("meta_file" , json .loads (extra_files ["def_args.json" ]))
103
- self .assertIn ("network_def" , json .loads (extra_files ["inference.json" ]))
87
+ # check with default value
88
+ cmd = ["coverage" , "run" , "-m" , "monai.bundle" , "ckpt_export" , "--key_in_ckpt" , key_in_ckpt ]
89
+ cmd += ["--config_file" , self .config_file , "--bundle_root" , self .tempdir_obj .name ]
90
+ if use_trace == "True" :
91
+ cmd += ["--use_trace" , use_trace , "--input_shape" , "[1, 1, 96, 96, 96]" ]
92
+ command_line_tests (cmd )
93
+ self .assertTrue (os .path .exists (ts_file ))
104
94
105
95
106
96
if __name__ == "__main__" :
0 commit comments