-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathbuild_tf_records.py
120 lines (94 loc) · 3.44 KB
/
build_tf_records.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author: ZuoXiang
@contact: [email protected]
@file: build_tf_records.py
@time: 2019/4/18 17:31
@desc:
"""
import re
import sys
import random
import threading
import numpy as np
import tensorflow as tf
from queue import Queue
from datetime import datetime
from hparams import hparams as hp
from data_utils import ImageReader, process_image_files_batch
# TF 1.x slim alias. NOTE(review): not referenced anywhere in this file —
# presumably used by other modules importing it from here; confirm before removing.
slim = tf.contrib.slim
def create_tfrecord(dataset, dataset_name, output_directory, num_shards,
                    num_threads, shuffle=True, store_image=True):
    """Write `dataset` out as sharded TFRecord files using worker threads.

    :param dataset: list, a list of per-image example records (JSON-like dicts).
    :param dataset_name: str, name prefix used for the output shard files.
    :param output_directory: str, directory that receives the shard files.
    :param num_shards: int, total number of TFRecord shards to produce.
    :param num_threads: int, number of writer threads to launch.
    :param shuffle: bool, shuffle `dataset` in place before sharding.
    :param store_image: bool, forwarded to `process_image_files_batch`.
    :return: list of error messages collected from the worker threads.
    """
    # Images in a TFRecord set must be shuffled properly.
    if shuffle:
        random.shuffle(dataset)

    # Break all images into num_threads batches, batch i covering
    # [ranges[i][0], ranges[i][1]).
    # BUG FIX: spacing was previously computed from num_shards, which launched
    # num_shards threads and silently ignored the num_threads argument (the
    # launch message also reported the wrong count).
    spacing = np.linspace(0, len(dataset), num_threads + 1).astype(int)
    ranges = []
    threads = []
    for i in range(len(spacing) - 1):
        ranges.append([spacing[i], spacing[i + 1]])

    # Launch a thread for each batch.
    print('Launching %d threads for spacings: %s' % (num_threads, ranges))
    sys.stdout.flush()

    # Create a mechanism for monitoring when all threads finished.
    coord = tf.train.Coordinator()

    # Create a generic TensorFlow-based utility for converting all image reader.
    image_reader = ImageReader()

    # A Queue to hold the image examples that fail to process.
    error_queue = Queue()

    for thread_index in range(len(ranges)):
        args = (image_reader, thread_index, ranges, dataset_name, output_directory,
                dataset, num_shards, store_image, error_queue)
        t = threading.Thread(target=process_image_files_batch, args=args)
        t.start()
        threads.append(t)

    # Wait for all the threads to terminate.
    coord.join(threads)
    print('%s: Finished writing all %d images in data set.' %
          (datetime.now(), len(dataset)))

    # Collect the error messages.
    errors = []
    while not error_queue.empty():
        errors.append(error_queue.get())
    print('%d examples failed.' % (len(errors),))

    return errors
def create_tf_records(input_file, save_file):
    """Convert a tab-separated annotation file into a single TFRecord file.

    Each line of `input_file` is expected to hold three tab-separated fields:
    image filename (relative to ``hp.image_path``), an integer category id,
    and 1000 whitespace-separated integer attribute flags.

    :param input_file: str, path to the annotation text file.
    :param save_file: str, path of the TFRecord file to write.
    :raises AssertionError: if a line does not split into 3 fields, or the
        attribute vector does not have exactly 1000 entries.
    """
    with open(input_file, 'r') as f1:
        data = f1.readlines()

    # Renamed from `iter`, which shadowed the builtin.
    count = 0
    with tf.python_io.TFRecordWriter(save_file) as writer:
        for line in data:
            fields = line.strip().split('\t')
            if len(fields) != 3:
                raise AssertionError('Data split error! Please check data!')

            filenames = hp.image_path + fields[0]
            with open(filenames, 'rb') as f2:
                encode_image = f2.read()

            category = int(fields[1])
            # str.split() with no argument splits on runs of whitespace —
            # equivalent to the previous re.split('\s+', ...) on the stripped
            # field, without the non-raw regex string.
            attribute = [int(i) for i in fields[2].split()]
            if len(attribute) != 1000:
                raise AssertionError("Attribute vector's shape not equal 1000! Please check data!")

            # The previous `try: ... except Exception as e: raise e` wrapper
            # was a no-op and has been removed; exceptions still propagate.
            tf_example = _image_example(encode_image, category, attribute)
            count += 1
            if count % 500 == 0:
                print('Processed image num: {}'.format(count))
            writer.write(tf_example.SerializeToString())
    print('Done!')
if __name__ == '__main__':
    # NOTE(review): placeholder arguments — fill in the annotation-file path
    # and the output TFRecord path before running; with empty strings the
    # open() inside create_tf_records will fail.
    create_tf_records(r'', r'')