Skip to content

Commit 3049df0

Browse files
author
Karthik Vadla
authored
Refactor (#194)
1 parent 02c5081 commit 3049df0

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

benchmarks/common/base_benchmark_util.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,28 @@ def validate_args(self, args):
174174
format(model_source_dir))
175175
self.check_for_link("model source directory", model_source_dir)
176176

177+
# check checkpoint location
178+
checkpoint_dir = args.checkpoint
179+
if checkpoint_dir is not None:
180+
if not os.path.exists(checkpoint_dir):
181+
raise IOError("The checkpoint location {} does not exist.".
182+
format(checkpoint_dir))
183+
elif not os.path.isdir(checkpoint_dir):
184+
raise IOError("The checkpoint location {} is not a directory.".
185+
format(checkpoint_dir))
186+
self.check_for_link("checkpoint directory", checkpoint_dir)
187+
188+
# check if input graph file exists
189+
input_graph = args.input_graph
190+
if input_graph is not None:
191+
if not os.path.exists(input_graph):
192+
raise IOError("The input graph {} does not exist.".
193+
format(input_graph))
194+
if not os.path.isfile(input_graph):
195+
raise IOError("The input graph {} must be a file.".
196+
format(input_graph))
197+
self.check_for_link("input graph", input_graph)
198+
177199
# check model_name exists
178200
if not args.model_name:
179201
raise ValueError("The model name is not valid")

benchmarks/launch_benchmark.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -78,28 +78,6 @@ def validate_args(self, args):
7878
raise ValueError("The specified framework is not supported: {}".
7979
format(args.framework))
8080

81-
# check checkpoint location
82-
checkpoint_dir = args.checkpoint
83-
if checkpoint_dir is not None:
84-
if not os.path.exists(checkpoint_dir):
85-
raise IOError("The checkpoint location {} does not exist.".
86-
format(checkpoint_dir))
87-
elif not os.path.isdir(checkpoint_dir):
88-
raise IOError("The checkpoint location {} is not a directory.".
89-
format(checkpoint_dir))
90-
self.check_for_link("checkpoint directory", checkpoint_dir)
91-
92-
# check if input graph file exists
93-
input_graph = args.input_graph
94-
if input_graph is not None:
95-
if not os.path.exists(input_graph):
96-
raise IOError("The input graph {} does not exist.".
97-
format(input_graph))
98-
if not os.path.isfile(input_graph):
99-
raise IOError("The input graph {} must be a file.".
100-
format(input_graph))
101-
self.check_for_link("input graph", input_graph)
102-
10381
# if neither benchmark_only or accuracy_only are specified, then enable
10482
# benchmark_only as the default
10583
if not args.benchmark_only and not args.accuracy_only:

0 commit comments

Comments
 (0)