Skip to content

Commit 3e0fe76

Browse files
committed
Revert "Directly tests export_ckpt function instead of using command_line_tests"
This reverts commit d4b01e6.
1 parent ba16743 commit 3e0fe76

File tree

1 file changed

+27
-37
lines changed

1 file changed

+27
-37
lines changed

tests/test_bundle_ckpt_export.py

+27-37
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@
1919
from parameterized import parameterized
2020

2121
from monai.bundle import ConfigParser
22-
from monai.bundle.scripts import ckpt_export
2322
from monai.data import load_net_with_metadata
2423
from monai.networks import save_state
25-
from tests.utils import skip_if_windows
24+
from tests.utils import command_line_tests, skip_if_windows
2625

2726
TEST_CASE_1 = ["", ""]
2827

@@ -52,6 +51,8 @@ def setUp(self):
5251
self.parser.export_config_file(config=self.def_args, filepath=self.def_args_file)
5352
self.parser.read_config(self.config_file)
5453
self.net = self.parser.get_parsed_content("network_def")
54+
self.cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", self.ts_file]
55+
self.cmd += ["--meta_file", self.meta_file, "--config_file", f"['{self.config_file}','{self.def_args_file}']", "--ckpt_file"]
5556

5657
def tearDown(self):
5758
if self.device is not None:
@@ -61,46 +62,35 @@ def tearDown(self):
6162
self.tempdir_obj.cleanup()
6263

6364
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
64-
def test_ckpt_export_default(self, key_in_ckpt, use_trace):
65+
def test_export(self, key_in_ckpt, use_trace):
66+
save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=self.ckpt_file)
67+
full_cmd = self.cmd + [self.ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", self.def_args_file]
68+
if use_trace == "True":
69+
full_cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
70+
command_line_tests(full_cmd)
71+
self.assertTrue(os.path.exists(self.ts_file))
72+
73+
_, metadata, extra_files = load_net_with_metadata(
74+
self.ts_file, more_extra_files=["inference.json", "def_args.json"]
75+
)
76+
self.assertIn("schema", metadata)
77+
self.assertIn("meta_file", json.loads(extra_files["def_args.json"]))
78+
self.assertIn("network_def", json.loads(extra_files["inference.json"]))
79+
80+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
81+
def test_default_value(self, key_in_ckpt, use_trace):
6582
ckpt_file = os.path.join(self.tempdir_obj.name, "models/model.pt")
6683
ts_file = os.path.join(self.tempdir_obj.name, "models/model.ts")
6784

6885
save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=ckpt_file)
69-
ckpt_export(
70-
net_id="network_def",
71-
filepath=ts_file,
72-
meta_file=self.meta_file,
73-
config_file=self.config_file,
74-
ckpt_file=ckpt_file,
75-
key_in_ckpt=key_in_ckpt,
76-
args_file=self.def_args_file,
77-
use_trace=use_trace,
78-
input_shape=[1, 1, 96, 96, 96] if use_trace == "True" else None,
79-
)
80-
self.assertTrue(os.path.exists(ts_file))
8186

82-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
83-
def test_ckpt_export(self, key_in_ckpt, use_trace):
84-
save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=self.ckpt_file)
85-
ckpt_export(
86-
net_id="network_def",
87-
filepath=self.ts_file,
88-
meta_file=self.meta_file,
89-
config_file=[self.config_file, self.def_args_file],
90-
ckpt_file=self.ckpt_file,
91-
key_in_ckpt=key_in_ckpt,
92-
args_file=self.def_args_file,
93-
use_trace=use_trace,
94-
input_shape=[1, 1, 96, 96, 96] if use_trace == "True" else None,
95-
)
96-
self.assertTrue(os.path.exists(self.ts_file))
97-
98-
_, metadata, extra_files = load_net_with_metadata(
99-
self.ts_file, more_extra_files=["inference.json", "def_args.json"]
100-
)
101-
self.assertIn("schema", metadata)
102-
self.assertIn("meta_file", json.loads(extra_files["def_args.json"]))
103-
self.assertIn("network_def", json.loads(extra_files["inference.json"]))
87+
# check with default value
88+
cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "--key_in_ckpt", key_in_ckpt]
89+
cmd += ["--config_file", self.config_file, "--bundle_root", self.tempdir_obj.name]
90+
if use_trace == "True":
91+
cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
92+
command_line_tests(cmd)
93+
self.assertTrue(os.path.exists(ts_file))
10494

10595

10696
if __name__ == "__main__":

0 commit comments

Comments
 (0)