Skip to content

Commit 5229381

Browse files
ziyeqinghancopybara-github
authored andcommitted
add object detector DataLoader in Model Maker.
PiperOrigin-RevId: 348882520
1 parent 22f86fd commit 5229381

File tree

5 files changed

+407
-1
lines changed

5 files changed

+407
-1
lines changed
Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Dataloader for object detection."""
15+
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import glob
21+
import hashlib
22+
import json
23+
import os
24+
import tempfile
25+
26+
from lxml import etree
27+
import tensorflow as tf
28+
from tensorflow_examples.lite.model_maker.core.data_util import dataloader
29+
import yaml
30+
31+
HAS_OBJECT_DETECTION = True
32+
try:
33+
# pylint: disable=g-import-not-at-top
34+
from efficientdet import dataloader as det_dataloader
35+
from efficientdet.dataset import create_pascal_tfrecord
36+
from efficientdet.dataset import tfrecord_util
37+
from efficientdet.keras import label_util
38+
# pylint: enable=g-import-not-at-top
39+
except ImportError:
40+
HAS_OBJECT_DETECTION = False
41+
42+
43+
def _get_cache_prefix(image_dir, annotations_dir, annotations_list):
44+
"""Get the prefix for cached files."""
45+
46+
def _get_dir_basename(dirname):
47+
return os.path.basename(os.path.abspath(dirname))
48+
49+
hasher = hashlib.md5()
50+
hasher.update(_get_dir_basename(image_dir).encode('utf-8'))
51+
hasher.update(_get_dir_basename(annotations_dir).encode('utf-8'))
52+
if annotations_list:
53+
hasher.update(' '.join(sorted(annotations_list)).encode('utf-8'))
54+
return hasher.hexdigest()
55+
56+
57+
def _get_object_detector_cache_filenames(cache_dir, image_dir, annotations_dir,
58+
annotations_list, num_shards):
59+
"""Gets cache filenames for obejct detector."""
60+
if cache_dir is None:
61+
cache_dir = tempfile.mkdtemp()
62+
print('Create the cache directory: %s.', cache_dir)
63+
cache_prefix = _get_cache_prefix(image_dir, annotations_dir, annotations_list)
64+
cache_prefix = os.path.join(cache_dir, cache_prefix)
65+
66+
tfrecord_files = [
67+
cache_prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards)
68+
for i in range(num_shards)
69+
]
70+
annotations_json_file = cache_prefix + '_annotations.json'
71+
meta_data_file = cache_prefix + '_meta_data.yaml'
72+
73+
all_cached_files = tfrecord_files + [annotations_json_file, meta_data_file]
74+
is_cached = all(os.path.exists(path) for path in all_cached_files)
75+
return is_cached, cache_prefix, tfrecord_files, annotations_json_file, meta_data_file
76+
77+
78+
def _get_label_map(label_map):
79+
"""Gets the label map dict."""
80+
if isinstance(label_map, list):
81+
label_map_dict = {}
82+
for i, label in enumerate(label_map):
83+
# 0 is resevered for background.
84+
label_map_dict[i + 1] = label
85+
label_map = label_map_dict
86+
label_map = label_util.get_label_map(label_map)
87+
88+
if 0 in label_map and label_map[0] != 'background':
89+
raise ValueError('0 must be resevered for background.')
90+
label_map.pop(0, None)
91+
name_set = set()
92+
for idx, name in label_map.items():
93+
if not isinstance(idx, int):
94+
raise ValueError('The key (label id) in label_map must be integer.')
95+
if not isinstance(name, str):
96+
raise ValueError('The value (label name) in label_map must be string.')
97+
if name in name_set:
98+
raise ValueError('The value: %s (label name) can\'t be duplicated.' %
99+
name)
100+
name_set.add(name)
101+
return label_map
102+
103+
104+
class DataLoader(dataloader.DataLoader):
105+
"""DataLoader for object detector."""
106+
107+
def __init__(self,
108+
tfrecord_file_patten,
109+
size,
110+
label_map,
111+
annotations_json_file=None):
112+
"""Initialize DataLoader for object detector.
113+
114+
Args:
115+
tfrecord_file_patten: Glob for tfrecord files. e.g. "/tmp/coco*.tfrecord".
116+
size: The size of the dataset.
117+
label_map: Variable shows mapping label integers ids to string label
118+
names. 0 is the reserved key for `background` and doesn't need to be
119+
included in label_map. Label names can't be duplicated. Supported
120+
formats are:
121+
1. Dict, map label integers ids to string label names, such as {1:
122+
'person', 2: 'notperson'}. 2. List, a list of label names such as
123+
['person', 'notperson'] which is
124+
the same as setting label_map={1: 'person', 2: 'notperson'}.
125+
3. String, name for certain dataset. Accepted values are: 'coco', 'voc'
126+
and 'waymo'. 4. String, yaml filename that stores label_map.
127+
annotations_json_file: JSON with COCO data format containing golden
128+
bounding boxes. Used for validation. If None, use the ground truth from
129+
the dataloader. Refer to
130+
https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5
131+
for the description of COCO data format.
132+
"""
133+
if not HAS_OBJECT_DETECTION:
134+
raise NotImplementedError("Haven't support object detection yet.")
135+
super(DataLoader, self).__init__(dataset=None, size=size)
136+
self.tfrecord_file_patten = tfrecord_file_patten
137+
self.label_map = _get_label_map(label_map)
138+
self.annotations_json_file = annotations_json_file
139+
140+
@classmethod
141+
def from_pascal_voc(cls,
142+
images_dir,
143+
annotations_dir,
144+
label_map,
145+
annotations_list=None,
146+
ignore_difficult_instances=False,
147+
num_shards=100,
148+
max_num_images=None,
149+
cache_dir=None):
150+
"""Loads from dataset with PASCAL VOC format.
151+
152+
Refer to
153+
https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5#:~:text=Pascal%20VOC%20is%20an%20XML,for%20training%2C%20testing%20and%20validation
154+
for the description of PASCAL VOC data format.
155+
156+
LabelImg Tool (https://github.com/tzutalin/labelImg) can annotate the image
157+
and save annotations as XML files in PASCAL VOC data format.
158+
159+
Annotations are in the folder: ${annotations_dir}.
160+
Raw images are in the foloder: ${images_dir}.
161+
162+
Args:
163+
images_dir: Path to directory that store raw images.
164+
annotations_dir: Path to the annotations directory.
165+
label_map: Variable shows mapping label integers ids to string label
166+
names. 0 is the reserved key for `background`. Label names can't be
167+
duplicated. Supported format: 1. Dict, map label integers ids to string
168+
label names, e.g.
169+
{1: 'person', 2: 'notperson'}. 2. List, a list of label names. e.g.
170+
['person', 'notperson'] which is
171+
the same as setting label_map={1: 'person', 2: 'notperson'}.
172+
3. String, name for certain dataset. Accepted values are: 'coco', 'voc'
173+
and 'waymo'. 4. String, yaml filename that stores label_map.
174+
annotations_list: list of annotation filenames (strings) to be loaded. For
175+
instance, if there're 3 annotation files [0.xml, 1.xml, 2.xml] in
176+
`annotations_dir`, setting annotations_list=['0', '1'] makes this method
177+
only load [0.xml, 1.xml].
178+
ignore_difficult_instances: Whether to ignore difficult instances.
179+
`difficult` can be set inside `object` item in the annotation xml file.
180+
num_shards: Number of shards for output file.
181+
max_num_images: Max number of imags to process.
182+
cache_dir: The cache directory to save TFRecord and json file. When
183+
cache_dir is not set, a temporary folder will be created and will not be
184+
removed automatically after training which makes it can be used later.
185+
186+
Returns:
187+
ObjectDetectorDataLoader object.
188+
"""
189+
label_map = _get_label_map(label_map)
190+
is_cached, cache_prefix, tfrecord_files, ann_json_file, meta_data_file = \
191+
_get_object_detector_cache_filenames(cache_dir, images_dir,
192+
annotations_dir, annotations_list,
193+
num_shards)
194+
# If not cached, write data into tfrecord_file_paths and
195+
# annotations_json_file_path.
196+
# If `num_shards` differs, it's still not cached.
197+
if not is_cached:
198+
cls._write_pascal_tfrecord(images_dir, annotations_dir, label_map,
199+
annotations_list, ignore_difficult_instances,
200+
num_shards, max_num_images, tfrecord_files,
201+
ann_json_file, meta_data_file)
202+
203+
tfrecord_file_patten = cache_prefix + '-*-of-%05d.tfrecord' % num_shards
204+
if not glob.glob(tfrecord_file_patten):
205+
raise ValueError('TFRecord files are empty.')
206+
207+
with tf.io.gfile.GFile(meta_data_file, 'r') as f:
208+
meta_data = yaml.load(f, Loader=yaml.FullLoader)
209+
return DataLoader(tfrecord_file_patten, meta_data['size'],
210+
meta_data['label_map'], ann_json_file)
211+
212+
@classmethod
213+
def _write_pascal_tfrecord(cls, images_dir, annotations_dir, label_map_dict,
214+
annotations_list, ignore_difficult_instances,
215+
num_shards, max_num_images, tfrecord_files,
216+
annotations_json_file, meta_data_file):
217+
"""Write TFRecord and json file for PASCAL VOC data."""
218+
label_name2id_dict = {'background': 0}
219+
for idx, name in label_map_dict.items():
220+
label_name2id_dict[name] = idx
221+
writers = [tf.io.TFRecordWriter(path) for path in tfrecord_files]
222+
ann_json_dict = {'images': [], 'annotations': [], 'categories': []}
223+
# Gets the paths to annotations.
224+
if annotations_list:
225+
ann_path_list = [
226+
os.path.join(annotations_dir, annotation + '.xml')
227+
for annotation in annotations_list
228+
]
229+
else:
230+
ann_path_list = list(tf.io.gfile.glob(annotations_dir + r'/*.xml'))
231+
232+
for idx, ann_path in enumerate(ann_path_list):
233+
if max_num_images and idx >= max_num_images:
234+
break
235+
if idx % 100 == 0:
236+
tf.compat.v1.logging.info('On image %d of %d', idx, len(ann_path_list))
237+
with tf.io.gfile.GFile(ann_path, 'r') as fid:
238+
xml_str = fid.read()
239+
xml = etree.fromstring(xml_str)
240+
data = tfrecord_util.recursive_parse_xml_to_dict(xml)['annotation']
241+
tf_example = create_pascal_tfrecord.dict_to_tf_example(
242+
data,
243+
images_dir,
244+
label_name2id_dict,
245+
ignore_difficult_instances,
246+
ann_json_dict=ann_json_dict)
247+
writers[idx % num_shards].write(tf_example.SerializeToString())
248+
249+
meta_data = {'size': idx + 1, 'label_map': label_map_dict}
250+
with tf.io.gfile.GFile(meta_data_file, 'w') as f:
251+
yaml.dump(meta_data, f)
252+
253+
for writer in writers:
254+
writer.close()
255+
256+
with tf.io.gfile.GFile(annotations_json_file, 'w') as f:
257+
json.dump(ann_json_dict, f)
258+
259+
def gen_dataset(self,
260+
model_spec,
261+
batch_size=None,
262+
is_training=True,
263+
use_fake_data=False):
264+
"""Generate a batched tf.data.Dataset for training/evaluation.
265+
266+
Args:
267+
model_spec: Specification for the model.
268+
batch_size: A integer, the returned dataset will be batched by this size.
269+
is_training: A boolean, when True, the returned dataset will be optionally
270+
shuffled and repeated as an endless dataset.
271+
use_fake_data: Use fake input.
272+
273+
Returns:
274+
A TF dataset ready to be consumed by Keras model.
275+
"""
276+
reader = det_dataloader.InputReader(
277+
self.tfrecord_file_patten,
278+
is_training=is_training,
279+
use_fake_data=use_fake_data,
280+
max_instances_per_image=model_spec.config.max_instances_per_image,
281+
debug=model_spec.config.debug)
282+
self._dataset = reader(model_spec.config.as_dict(), batch_size=batch_size)
283+
return self._dataset
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
import os
20+
21+
import numpy as np
22+
import PIL.Image
23+
import tensorflow as tf
24+
from tensorflow_examples.lite.model_maker.core import test_util
25+
26+
try:
27+
# pylint: disable=g-import-not-at-top
28+
from tensorflow_examples.lite.model_maker.core.data_util import object_detector_dataloader
29+
from efficientdet import hparams_config
30+
from efficientdet import utils
31+
# pylint: enable=g-import-not-at-top
32+
except ImportError:
33+
pass
34+
35+
36+
class MockDetectorModelSpec(object):
37+
38+
def __init__(self, model_name):
39+
self.model_name = model_name
40+
config = hparams_config.get_detection_config(model_name)
41+
config.image_size = utils.parse_image_size(config.image_size)
42+
config.update({'debug': False})
43+
self.config = config
44+
45+
46+
class ObjectDectectorDataLoaderTest(tf.test.TestCase):
47+
48+
def _create_pascal_voc(self):
49+
# Saves the image into images_dir.
50+
image_file_name = '2012_12.jpg'
51+
image_data = np.random.rand(256, 256, 3)
52+
images_dir = os.path.join(self.get_temp_dir(), 'images')
53+
os.mkdir(images_dir)
54+
save_path = os.path.join(images_dir, image_file_name)
55+
image = PIL.Image.fromarray(image_data, 'RGB')
56+
image.save(save_path)
57+
58+
# Gets the annonation path.
59+
annotations_path = test_util.get_test_data_path('2012_12.xml')
60+
annotations_dir = os.path.dirname(annotations_path)
61+
62+
label_map = {
63+
1: 'person',
64+
2: 'notperson',
65+
}
66+
return images_dir, annotations_dir, label_map
67+
68+
def test_from_pascal_voc(self):
69+
if not object_detector_dataloader.HAS_OBJECT_DETECTION:
70+
return
71+
72+
images_dir, annotations_dir, label_map = self._create_pascal_voc()
73+
model_spec = MockDetectorModelSpec('efficientdet-lite0')
74+
75+
data = object_detector_dataloader.DataLoader.from_pascal_voc(
76+
images_dir, annotations_dir, label_map)
77+
78+
self.assertIsInstance(data, object_detector_dataloader.DataLoader)
79+
self.assertLen(data, 1)
80+
self.assertEqual(data.label_map, label_map)
81+
82+
ds = data.gen_dataset(model_spec, batch_size=1, is_training=False)
83+
for i, (images, labels) in enumerate(ds):
84+
self.assertEqual(i, 0)
85+
images_shape = tf.shape(images).numpy()
86+
expected_shape = np.array([1, *model_spec.config.image_size, 3])
87+
self.assertTrue((images_shape == expected_shape).all())
88+
self.assertLen(labels, 15)
89+
90+
ds1 = data.gen_dataset(model_spec, batch_size=1, is_training=True)
91+
self.assertEqual(ds1.cardinality(), tf.data.INFINITE_CARDINALITY)
92+
for images, labels in ds1.take(10):
93+
images_shape = tf.shape(images).numpy()
94+
expected_shape = np.array([1, *model_spec.config.image_size, 3])
95+
self.assertTrue((images_shape == expected_shape).all())
96+
self.assertLen(labels, 15)
97+
98+
99+
if __name__ == '__main__':
100+
tf.test.main()

0 commit comments

Comments
 (0)