Skip to content

Commit 08aacd7

Browse files
committed
Added train and export scripts
These scripts are provided by the TensorFlow Object Detection API. The train script is used for training a pretrained model on new data. I took SSD MobileNet and trained this pretrained model on traffic light images. It was already good at general object detection, such as identifying traffic lights; I further trained it to identify whether a traffic light is red, green, or yellow. The export inference script then exports the model checkpoint saved during training as a frozen graph, so that we can use it later. Exporting a frozen graph is very important for TensorFlow version compatibility. The SSD MobileNet was trained with TensorFlow 1.12, then I trained it with TensorFlow 1.4.0, using a GPU. The important part: if I want to use it with TensorFlow 1.3.0, then I need to export the trained model with TensorFlow 1.4.0, since they use the same CUDA and cuDNN versions.
1 parent f1515ab commit 08aacd7

File tree

2 files changed

+282
-0
lines changed

2 files changed

+282
-0
lines changed

export_inference_graph.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
r"""Tool to export an object detection model for inference.
17+
18+
Prepares an object detection tensorflow graph for inference using model
19+
configuration and an optional trained checkpoint. Outputs inference
20+
graph, associated checkpoint files, a frozen inference graph and a
21+
SavedModel (https://tensorflow.github.io/serving/serving_basic.html).
22+
23+
The inference graph contains one of three input nodes depending on the user
24+
specified option.
25+
* `image_tensor`: Accepts a uint8 4-D tensor of shape [None, None, None, 3]
26+
* `encoded_image_string_tensor`: Accepts a 1-D string tensor of shape [None]
27+
containing encoded PNG or JPEG images. Image resolutions are expected to be
28+
the same if more than 1 image is provided.
29+
* `tf_example`: Accepts a 1-D string tensor of shape [None] containing
30+
serialized TFExample protos. Image resolutions are expected to be the same
31+
if more than 1 image is provided.
32+
33+
and the following output nodes returned by the model.postprocess(..):
34+
* `num_detections`: Outputs float32 tensors of the form [batch]
35+
that specifies the number of valid boxes per image in the batch.
36+
* `detection_boxes`: Outputs float32 tensors of the form
37+
[batch, num_boxes, 4] containing detected boxes.
38+
* `detection_scores`: Outputs float32 tensors of the form
39+
[batch, num_boxes] containing class scores for the detections.
40+
* `detection_classes`: Outputs float32 tensors of the form
41+
[batch, num_boxes] containing classes for the detections.
42+
* `detection_masks`: Outputs float32 tensors of the form
43+
[batch, num_boxes, mask_height, mask_width] containing predicted instance
44+
masks for each box if it is present in the dictionary of postprocessed
45+
tensors returned by the model.
46+
47+
Notes:
48+
* This tool uses `use_moving_averages` from eval_config to decide which
49+
weights to freeze.
50+
51+
Example Usage:
52+
--------------
53+
python export_inference_graph.py \
54+
--input_type image_tensor \
55+
--pipeline_config_path path/to/ssd_inception_v2.config \
56+
--trained_checkpoint_prefix path/to/model.ckpt \
57+
--output_directory path/to/exported_model_directory
58+
59+
The expected output would be in the directory
60+
path/to/exported_model_directory (which is created if it does not exist)
61+
with contents:
62+
- graph.pbtxt
63+
- model.ckpt.data-00000-of-00001
64+
- model.ckpt.info
65+
- model.ckpt.meta
66+
- frozen_inference_graph.pb
67+
+ saved_model (a directory)
68+
"""
69+
import tensorflow as tf
70+
from google.protobuf import text_format
71+
from object_detection import exporter
72+
from object_detection.protos import pipeline_pb2
73+
74+
# Short aliases for the TF-Slim library and the TensorFlow flags module.
slim = tf.contrib.slim
flags = tf.app.flags

# Kind of input node the exported graph exposes.
flags.DEFINE_string(
    'input_type',
    'image_tensor',
    'Type of input node. Can be one of [`image_tensor`, '
    '`encoded_image_string_tensor`, `tf_example`]')

# Optional fixed shape for the `image_tensor` input, as comma-separated ints.
flags.DEFINE_string(
    'input_shape',
    None,
    'If input_type is `image_tensor`, this can explicitly set the shape of '
    'this input tensor to a fixed size. The dimensions are to be provided as '
    'a comma-separated list of integers. A value of -1 can be used for '
    'unknown dimensions. If not specified, for an `image_tensor, the default '
    'shape will be partially specified as `[None, None, None, 3]`.')

# Location of the pipeline configuration that describes the model.
flags.DEFINE_string(
    'pipeline_config_path',
    None,
    'Path to a pipeline_pb2.TrainEvalPipelineConfig config file.')

# Checkpoint whose weights are exported.
flags.DEFINE_string(
    'trained_checkpoint_prefix',
    None,
    'Path to trained checkpoint, typically of the form path/to/model.ckpt')

# Destination directory for everything the export produces.
flags.DEFINE_string(
    'output_directory',
    None,
    'Path to write outputs.')

# These three flags have no usable default, so they are marked as required.
flags.mark_flag_as_required('pipeline_config_path')
flags.mark_flag_as_required('trained_checkpoint_prefix')
flags.mark_flag_as_required('output_directory')

FLAGS = flags.FLAGS
100+
101+
102+
def _parse_input_shape(shape_flag):
  """Turns the `--input_shape` flag value into a list of dimensions.

  Args:
    shape_flag: Comma-separated integers (for example '1,-1,-1,3'), or a
      falsy value when the flag was not given.

  Returns:
    None when `shape_flag` is falsy; otherwise a list with one entry per
    comma-separated field, where -1 is replaced by None (unknown dimension).

  Raises:
    ValueError: If a field is not a valid integer.
  """
  if not shape_flag:
    return None
  # Compare the parsed integer rather than the raw text, so that fields with
  # surrounding whitespace (e.g. '1, -1, -1, 3') still map -1 to None.
  parsed_dims = (int(field) for field in shape_flag.split(','))
  return [None if dim == -1 else dim for dim in parsed_dims]


def main(_):
  """Exports an inference graph as described by the command-line flags.

  Reads the text-format pipeline config from `--pipeline_config_path`, parses
  the optional `--input_shape` and passes both, together with `--input_type`,
  `--trained_checkpoint_prefix` and `--output_directory`, to
  `exporter.export_inference_graph`.

  Args:
    _: Unused positional argument.

  Raises:
    ValueError: If `--input_shape` contains a field that is not an integer.
  """
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)

  input_shape = _parse_input_shape(FLAGS.input_shape)

  exporter.export_inference_graph(FLAGS.input_type, pipeline_config,
                                  FLAGS.trained_checkpoint_prefix,
                                  FLAGS.output_directory, input_shape)
116+
117+
118+
# Run only when this file is executed as a script, not when it is imported.
# Control is handed to tf.app.run(), which is expected to parse the flags
# defined in this file and then call main().
if __name__ == '__main__':
  tf.app.run()

train.py

+163
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
r"""Training executable for detection models.
17+
18+
This executable is used to train DetectionModels. There are two ways of
19+
configuring the training job:
20+
21+
1) A single pipeline_pb2.TrainEvalPipelineConfig configuration file
22+
can be specified by --pipeline_config_path.
23+
24+
Example usage:
25+
./train \
26+
--logtostderr \
27+
--train_dir=path/to/train_dir \
28+
--pipeline_config_path=pipeline_config.pbtxt
29+
30+
2) Three configuration files can be provided: a model_pb2.DetectionModel
31+
configuration file to define what type of DetectionModel is being trained, an
32+
input_reader_pb2.InputReader file to specify what training data will be used and
33+
a train_pb2.TrainConfig file to configure training parameters.
34+
35+
Example usage:
36+
./train \
37+
--logtostderr \
38+
--train_dir=path/to/train_dir \
39+
--model_config_path=model_config.pbtxt \
40+
--train_config_path=train_config.pbtxt \
41+
--input_config_path=train_input_config.pbtxt
42+
"""
43+
44+
import functools
45+
import json
46+
import os
47+
import tensorflow as tf
48+
49+
from object_detection import trainer
50+
from object_detection.builders import input_reader_builder
51+
from object_detection.builders import model_builder
52+
from object_detection.utils import config_util
53+
54+
# Emit INFO-level (and more severe) TensorFlow log messages.
tf.logging.set_verbosity(tf.logging.INFO)

flags = tf.app.flags

# --- Cluster / deployment flags ---
flags.DEFINE_string(
    'master',
    '',
    'Name of the TensorFlow master to use.')
flags.DEFINE_integer(
    'task',
    0,
    'task id')
flags.DEFINE_integer(
    'num_clones',
    1,
    'Number of clones to deploy per worker.')
flags.DEFINE_boolean(
    'clone_on_cpu',
    False,
    'Force clones to be deployed on CPU. Note that even if set to False '
    '(allowing ops to run on gpu), some ops may still be run on the CPU if '
    'they have no GPU kernel.')
flags.DEFINE_integer(
    'worker_replicas',
    1,
    'Number of worker+trainer replicas.')
flags.DEFINE_integer(
    'ps_tasks',
    0,
    'Number of parameter server tasks. If None, does not use a parameter '
    'server.')

# --- Output location ---
flags.DEFINE_string(
    'train_dir',
    '',
    'Directory to save the checkpoints and training summaries.')

# --- Configuration: either one pipeline file ... ---
flags.DEFINE_string(
    'pipeline_config_path',
    '',
    'Path to a pipeline_pb2.TrainEvalPipelineConfig config file. If '
    'provided, other configs are ignored')

# --- ... or three separate config files ---
flags.DEFINE_string(
    'train_config_path',
    '',
    'Path to a train_pb2.TrainConfig config file.')
flags.DEFINE_string(
    'input_config_path',
    '',
    'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string(
    'model_config_path',
    '',
    'Path to a model_pb2.DetectionModel config file.')

FLAGS = flags.FLAGS
84+
85+
86+
def main(_):
  """Trains a detection model as configured by the command-line flags.

  Loads the model, train and train-input configs (from a single pipeline file
  when `--pipeline_config_path` is set, otherwise from the three separate
  config files), copies the config file(s) into `--train_dir` when `--task`
  is 0, derives the cluster layout from the `TF_CONFIG` environment variable
  and finally calls `trainer.train`.

  Args:
    _: Unused positional argument.

  Raises:
    ValueError: If `--train_dir` is empty, or if `TF_CONFIG` describes more
      than one worker replica without any parameter server task.
  """
  # Validate explicitly instead of with `assert`, which is removed when
  # Python runs with -O and would let an empty directory slip through.
  if not FLAGS.train_dir:
    raise ValueError('`train_dir` is missing.')

  # Only task 0 creates the training directory and stores config copies.
  if FLAGS.task == 0:
    tf.gfile.MakeDirs(FLAGS.train_dir)

  if FLAGS.pipeline_config_path:
    configs = config_util.get_configs_from_pipeline_file(
        FLAGS.pipeline_config_path)
    if FLAGS.task == 0:
      tf.gfile.Copy(FLAGS.pipeline_config_path,
                    os.path.join(FLAGS.train_dir, 'pipeline.config'),
                    overwrite=True)
  else:
    configs = config_util.get_configs_from_multiple_files(
        model_config_path=FLAGS.model_config_path,
        train_config_path=FLAGS.train_config_path,
        train_input_config_path=FLAGS.input_config_path)
    if FLAGS.task == 0:
      for name, config in [('model.config', FLAGS.model_config_path),
                           ('train.config', FLAGS.train_config_path),
                           ('input.config', FLAGS.input_config_path)]:
        tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
                      overwrite=True)

  model_config = configs['model']
  train_config = configs['train_config']
  input_config = configs['train_input_config']

  # Bind the configs now; the remaining arguments are supplied by the caller
  # of these partials.
  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=True)

  create_input_dict_fn = functools.partial(
      input_reader_builder.build, input_config)

  # The cluster layout and this process's role come from the TF_CONFIG
  # environment variable (a JSON object); an unset variable means "{}".
  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  cluster_data = env.get('cluster', None)
  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
  task_data = env.get('task', None) or {'type': 'master', 'index': 0}
  # Wrap the task dict in an ad-hoc class so its keys read as attributes
  # (task_info.type, task_info.index).
  task_info = type('TaskSpec', (object,), task_data)

  # Parameters for a single worker.
  ps_tasks = 0
  worker_replicas = 1
  worker_job_name = 'lonely_worker'
  task = 0
  is_chief = True
  master = ''

  if cluster_data and 'worker' in cluster_data:
    # Number of total worker replicas include "worker"s and the "master".
    worker_replicas = len(cluster_data['worker']) + 1
  if cluster_data and 'ps' in cluster_data:
    ps_tasks = len(cluster_data['ps'])

  if worker_replicas > 1 and ps_tasks < 1:
    raise ValueError('At least 1 ps task is needed for distributed training.')

  if worker_replicas >= 1 and ps_tasks > 0:
    # Set up distributed training. `cluster` is already a ClusterSpec here
    # (ps_tasks > 0 implies cluster_data is non-empty), so pass it directly.
    server = tf.train.Server(cluster, protocol='grpc',
                             job_name=task_info.type,
                             task_index=task_info.index)
    if task_info.type == 'ps':
      # Parameter servers run no training loop; they only serve.
      server.join()
      return

    worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
    task = task_info.index
    is_chief = (task_info.type == 'master')
    master = server.target

  trainer.train(create_input_dict_fn, model_fn, train_config, master, task,
                FLAGS.num_clones, worker_replicas, FLAGS.clone_on_cpu, ps_tasks,
                worker_job_name, is_chief, FLAGS.train_dir)
160+
161+
162+
# Run only when this file is executed as a script, not when it is imported.
# Control is handed to tf.app.run(), which is expected to parse the flags
# defined in this file and then call main().
if __name__ == '__main__':
  tf.app.run()

0 commit comments

Comments
 (0)