-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathchunked_dataset.py
executable file
·180 lines (138 loc) · 5.21 KB
/
chunked_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# Copyright (c) 2021 Project Bee4Exp.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
""" Creates HDF5 dataset for model training, validation and test.
CL Args:
-i Path to directory with input video files.
--mask Path to directory with ground truth masks.
-o Path to HDF5 dataset.
--type Type of dataset to create (train, val, test).
"""
from skvideo.io import vread
from skvideo.utils import rgb2gray
from skimage.measure import label, regionprops
from skimage.morphology import closing, square
from skimage.util import img_as_float
import h5py
from skimage.transform import rescale
import numpy as np
import os
from util import get_parser
args = get_parser().parse_args()
TYPE = args.type
DATASET_PATH = args.output
# Spatial size (pixels) of each square patch and samples per HDF5 chunk.
SIZE = 256
CHUNK_SIZE = 64
# Five consecutive grayscale frames are stacked as channels per sample.
NUM_OF_CHANNELS = 5
# Per-class cap on stored samples (classes = 0..3 detected objects):
# smaller caps for the validation/test splits, larger for training.
if TYPE == 'val' or TYPE == 'test':
    MAX_CLASS = 3200
else:
    MAX_CLASS = 13440
BGSUB_PATH = args.input
MASKS_PATH = args.mask
# BUG FIX: use floor division -- '/' is true division in Python 3 and yields
# a float, which then propagates into create_dataset() shapes below; HDF5
# dataset shapes must be integers.
NUM_OF_CHUNKS = MAX_CLASS*4//CHUNK_SIZE
# Acceptance probabilities for patches containing 0..3 objects; rarer
# (many-object) patches are accepted more aggressively to balance classes.
PROB0 = 0.03
PROB1 = 0.08
PROB2 = 0.25
PROB3 = 0.9
# Running count of stored samples per class.
count0 = 0
count1 = 0
count2 = 0
count3 = 0
# In-memory staging buffers for one chunk; flushed to the HDF5 datasets
# whenever CHUNK_SIZE samples have been accumulated.
data_np = np.zeros((CHUNK_SIZE, SIZE, SIZE, NUM_OF_CHANNELS), dtype='float32')
label_np = np.zeros((CHUNK_SIZE, 1), dtype='u8')
mask_np = np.zeros((CHUNK_SIZE, SIZE//2, SIZE//2), dtype='float32')
hf = h5py.File(DATASET_PATH+'.h5', 'w')
data_dset = hf.create_dataset("data", (NUM_OF_CHUNKS*CHUNK_SIZE, SIZE, SIZE, NUM_OF_CHANNELS),
                              chunks=(CHUNK_SIZE, SIZE, SIZE, NUM_OF_CHANNELS),
                              maxshape=(None, SIZE, SIZE, NUM_OF_CHANNELS), compression="gzip",
                              dtype='f', compression_opts=4)
label_dset = hf.create_dataset("label", (NUM_OF_CHUNKS*CHUNK_SIZE, 1), chunks=(CHUNK_SIZE, 1),
                               maxshape=(None, 1), compression="gzip", dtype='u8', compression_opts=4)
mask_dset = hf.create_dataset("mask", (NUM_OF_CHUNKS*CHUNK_SIZE, SIZE//2, SIZE//2),
                              chunks=(CHUNK_SIZE, SIZE//2, SIZE//2), maxshape=(None, SIZE//2, SIZE//2),
                              compression="gzip", dtype='f', compression_opts=4)
def should_write(objects):
    """Randomly decide whether to keep a patch with `objects` detections.

    A patch is kept when a uniform draw falls below the per-class acceptance
    probability AND that class has not yet reached its MAX_CLASS quota.
    Classes with more than 3 objects are always rejected.
    """
    draw = np.random.rand()
    # (acceptance probability, samples stored so far) per object count.
    per_class = {
        0: (PROB0, count0),
        1: (PROB1, count1),
        2: (PROB2, count2),
        3: (PROB3, count3),
    }
    if objects not in per_class:
        return False
    prob, stored = per_class[objects]
    return draw < prob and stored < MAX_CLASS
def written(objects):
    """Record that one sample of class `objects` (0-3) was stored.

    Bumps the matching module-level class counter; counts outside 0-3
    are ignored, mirroring the classes tracked by should_write().
    """
    global count0, count1, count2, count3
    if objects == 0:
        count0 += 1
    elif objects == 1:
        count1 += 1
    elif objects == 2:
        count2 += 1
    elif objects == 3:
        count3 += 1
def count_objects(image):
    """Count connected bright regions in a grayscale patch.

    Thresholds the float image at 0.3, applies a morphological closing with
    a 3x3 square to merge nearby fragments, then labels the binary result
    and returns the number of distinct regions.
    """
    binary = closing(img_as_float(image) > 0.3, square(3))
    regions = regionprops(label(binary))
    return len(regions)
def extract_data(bgsub_path, mask_path, counter):
    """Slice one video pair into SIZE x SIZE patches and store sampled patches.

    Reads the background-subtracted video and its ground-truth mask video,
    tiles every usable frame into non-overlapping SIZE x SIZE patches, and
    probabilistically keeps patches per should_write(). Accepted patches are
    staged in the module-level numpy buffers and flushed to the HDF5 datasets
    one full chunk at a time.

    Args:
        bgsub_path: path to the background-subtracted input video.
        mask_path: path to the matching ground-truth mask video.
        counter: global sample counter at entry.
    Returns:
        (keep_going, counter): keep_going is False once every class quota
        (MAX_CLASS) is met; counter is the updated global sample count.
    """
    bgsub = rgb2gray(vread(bgsub_path))
    mask = rgb2gray(vread(mask_path))
    num_frames, height, width, depth = bgsub.shape
    internal_counter = counter
    # Skip the first/last two frames: each sample stacks frames [i-2, i+2].
    for i in range(2, num_frames - 2):
        for j in range(height//SIZE):
            for k in range(width//SIZE):
                # Half-resolution ground-truth mask for this patch.
                mask_data = rescale(np.uint8(mask[i, j*SIZE:(j+1)*SIZE, k*SIZE:(k+1)*SIZE, 0]), 0.5,
                                    anti_aliasing=False)
                objects = count_objects(mask_data)
                if should_write(objects):
                    slot = internal_counter % CHUNK_SIZE
                    # Five consecutive frames moved to the channel axis: (H, W, 5).
                    data_np[slot, :, :, :] = img_as_float(np.uint8(
                        bgsub[i-2:i+3, j*SIZE:(j+1)*SIZE, k*SIZE:(k+1)*SIZE, 0].transpose((1, 2, 0))))
                    label_np[slot, :] = objects
                    mask_np[slot, :, :] = mask_data
                    written(objects)
                    # BUG FIX: flush only after a sample was actually stored.
                    # The original checked the boundary on every patch, so once
                    # the counter sat at a chunk boundary each rejected patch
                    # redundantly rewrote the same compressed chunk to disk.
                    if (internal_counter + 1) % CHUNK_SIZE == 0:
                        start = (internal_counter//CHUNK_SIZE)*CHUNK_SIZE
                        stop = start + CHUNK_SIZE
                        data_dset[start:stop, :, :, :] = data_np
                        label_dset[start:stop, :] = label_np
                        mask_dset[start:stop, :, :] = mask_np
                    internal_counter += 1
                # Stop entirely once every class has reached its quota.
                if count0 >= MAX_CLASS and count1 >= MAX_CLASS and count2 >= MAX_CLASS and count3 >= MAX_CLASS:
                    return False, internal_counter
    return True, internal_counter
# Deterministic shuffle of video ids 1..1000; disjoint index windows of the
# same permutation give the train (0-499), val (500-749) and test (750-999)
# splits, so the splits never share a source video.
rng = np.random.RandomState(42)
vids = rng.permutation(np.arange(1, 1001))
i = 0
cnt = 0
running = True
if TYPE == 'val':
    div = 250
    offset = 500
elif TYPE == 'test':
    div = 250
    offset = 750
else:
    div = 500
    offset = 0
# Cycle through the split's videos until extract_data reports all class
# quotas are filled (running becomes False).
while running:
    vid = vids[(i % div) + offset]
    # BUG FIX: dropped the leading '/' from the mask filename. os.path.join
    # discards all preceding components when a later component is absolute,
    # so the mask was looked up at '/Mask_<id>.MP4' instead of inside
    # MASKS_PATH.
    running, cnt = extract_data(os.path.join(BGSUB_PATH, 'Out_' + str(vid) + '.MP4'),
                                os.path.join(MASKS_PATH, 'Mask_' + str(vid) + '.MP4'), cnt)
    print(cnt)
    i += 1
# Record the true number of stored samples (the datasets are pre-sized and
# may contain unused zero-filled rows at the end).
hf.attrs['elems'] = cnt
hf.close()