
Commit 7aacce4

Author: David Foster (committed)
Commit message: adding utils files
1 parent 9880aa2 commit 7aacce4

8 files changed: +827 −212 lines

06_02_qa_train.ipynb

+50 −81
Large diffs are not rendered by default.

06_03_qa_analysis.ipynb

+91 −129
Large diffs are not rendered by default.

model.png

-88 Bytes

utils/.gitignore

-2
This file was deleted.

utils/callbacks.py

+40
@@ -0,0 +1,40 @@
from keras.callbacks import Callback, LearningRateScheduler
import numpy as np
import matplotlib.pyplot as plt
import os


#### CALLBACKS
class CustomCallback(Callback):

    def __init__(self, run_folder, print_every_n_batches, initial_epoch, vae):
        self.epoch = initial_epoch
        self.run_folder = run_folder
        self.print_every_n_batches = print_every_n_batches
        self.vae = vae

    def on_batch_end(self, batch, logs={}):
        if batch % self.print_every_n_batches == 0:
            # decode a random latent vector and save the image to the run folder
            z_new = np.random.normal(size=(1, self.vae.z_dim))
            reconst = self.vae.decoder.predict(np.array(z_new))[0].squeeze()

            filepath = os.path.join(self.run_folder, 'images/img_' + str(self.epoch).zfill(3) + '_' + str(batch) + '.jpg')
            if len(reconst.shape) == 2:
                # single-channel output: save as a grayscale image
                plt.imsave(filepath, reconst, cmap='gray_r')
            else:
                plt.imsave(filepath, reconst)

    def on_epoch_begin(self, epoch, logs={}):
        self.epoch += 1


def step_decay_schedule(initial_lr, decay_factor=0.5, step_size=1):
    '''
    Wrapper function to create a LearningRateScheduler with step decay schedule.
    '''
    def schedule(epoch):
        # multiply the learning rate by decay_factor every step_size epochs
        new_lr = initial_lr * (decay_factor ** np.floor(epoch / step_size))
        return new_lr

    return LearningRateScheduler(schedule)
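
For orientation, a minimal usage sketch (not part of the commit): both helpers plug into a standard Keras fit call. The vae object, run folder, and training data below are assumptions for illustration; the callback expects the model wrapper to expose z_dim, decoder, and a compiled Keras model.

# Hypothetical wiring of the two helpers into a Keras training run.
from utils.callbacks import CustomCallback, step_decay_schedule

callbacks = [
    # save a decoded random latent sample every 100 batches
    CustomCallback(run_folder='run/vae/0001/', print_every_n_batches=100,
                   initial_epoch=0, vae=vae),
    # halve the learning rate every 10 epochs
    step_decay_schedule(initial_lr=0.0005, decay_factor=0.5, step_size=10),
]
vae.model.fit(x_train, x_train, batch_size=32, epochs=200, callbacks=callbacks)

Note that CustomCallback writes to run_folder/images/ without creating it, so that subfolder must exist before training starts.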

utils/loaders.py

+315
@@ -0,0 +1,315 @@
import pickle
import os

from keras.datasets import mnist, cifar100, cifar10
from keras.preprocessing.image import ImageDataGenerator, load_img, save_img, img_to_array

import pandas as pd

import numpy as np
from os import walk, getcwd
import h5py

import scipy
from glob import glob

from keras.applications import vgg19
from keras import backend as K
from keras.utils import to_categorical

import pdb


class ImageLabelLoader():
    def __init__(self, image_folder, target_size):
        self.image_folder = image_folder
        self.target_size = target_size

    def build(self, att, batch_size, label=None):
        # stream the images listed in the dataframe `att`, rescaled to [0, 1]
        data_gen = ImageDataGenerator(rescale=1./255)
        if label:
            # yield (image, attribute-label) pairs
            data_flow = data_gen.flow_from_dataframe(
                att
                , self.image_folder
                , x_col='image_id'
                , y_col=label
                , target_size=self.target_size
                , class_mode='other'
                , batch_size=batch_size
                , shuffle=True
            )
        else:
            # yield (image, image) pairs, e.g. for autoencoder training
            data_flow = data_gen.flow_from_dataframe(
                att
                , self.image_folder
                , x_col='image_id'
                , target_size=self.target_size
                , class_mode='input'
                , batch_size=batch_size
                , shuffle=True
            )

        return data_flow
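
A hedged usage sketch (not part of the diff): this loader is driven by a dataframe of image names and attributes, as in a CelebA-style setup. The paths and the 'Smiling' attribute below are assumptions.

# Hypothetical example: stream CelebA faces with an attribute label.
import pandas as pd
from utils.loaders import ImageLabelLoader

att = pd.read_csv('./data/celeb/list_attr_celeba.csv')  # assumed path
loader = ImageLabelLoader(image_folder='./data/celeb/img_align_celeba', target_size=(128, 128))
labelled_flow = loader.build(att, batch_size=32, label='Smiling')  # yields (images, labels)
unlabelled_flow = loader.build(att, batch_size=32)                 # yields (images, images)
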
class DataLoader():
    def __init__(self, dataset_name, img_res=(256, 256)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, domain, batch_size=1, is_testing=False):
        # sample a random batch of images from one domain ('A' or 'B')
        data_type = "train%s" % domain if not is_testing else "test%s" % domain
        path = glob('./data/%s/%s/*' % (self.dataset_name, data_type))

        batch_images = np.random.choice(path, size=batch_size)

        imgs = []
        for img_path in batch_images:
            img = self.imread(img_path)
            if not is_testing:
                img = scipy.misc.imresize(img, self.img_res)

                # random horizontal flip for training-time augmentation
                if np.random.random() > 0.5:
                    img = np.fliplr(img)
            else:
                img = scipy.misc.imresize(img, self.img_res)
            imgs.append(img)

        imgs = np.array(imgs)/127.5 - 1.  # scale pixels to [-1, 1]

        return imgs

    def load_batch(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "val"
        path_A = glob('./data/%s/%sA/*' % (self.dataset_name, data_type))
        path_B = glob('./data/%s/%sB/*' % (self.dataset_name, data_type))

        self.n_batches = int(min(len(path_A), len(path_B)) / batch_size)
        total_samples = self.n_batches * batch_size

        # Sample n_batches * batch_size from each path list so that the model
        # sees all samples from both domains
        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)

        for i in range(self.n_batches-1):
            batch_A = path_A[i*batch_size:(i+1)*batch_size]
            batch_B = path_B[i*batch_size:(i+1)*batch_size]
            imgs_A, imgs_B = [], []
            for img_A, img_B in zip(batch_A, batch_B):
                img_A = self.imread(img_A)
                img_B = self.imread(img_B)

                img_A = scipy.misc.imresize(img_A, self.img_res)
                img_B = scipy.misc.imresize(img_B, self.img_res)

                if not is_testing and np.random.random() > 0.5:
                    img_A = np.fliplr(img_A)
                    img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)

            imgs_A = np.array(imgs_A)/127.5 - 1.
            imgs_B = np.array(imgs_B)/127.5 - 1.

            yield imgs_A, imgs_B

    def load_img(self, path):
        img = self.imread(path)
        img = scipy.misc.imresize(img, self.img_res)
        img = img/127.5 - 1.
        return img[np.newaxis, :, :, :]

    def imread(self, path):
        return scipy.misc.imread(path, mode='RGB').astype(np.float)
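
A sketch of how the CycleGAN-style loader above is consumed (illustrative only; 'apple2orange' is an assumed dataset folder under ./data, and the training step is elided). Note that scipy.misc.imread and scipy.misc.imresize were removed from later SciPy releases, so this loader requires an older SciPy with Pillow installed.

# Hypothetical epoch loop over paired batches from domains A and B.
from utils.loaders import DataLoader

data_loader = DataLoader(dataset_name='apple2orange', img_res=(128, 128))
epochs = 200  # assumed
for epoch in range(epochs):
    for batch_i, (imgs_A, imgs_B) in enumerate(data_loader.load_batch(batch_size=1)):
        pass  # train discriminators and generators on imgs_A, imgs_B (both in [-1, 1])
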
def load_model(model_class, folder):
    # rebuild the model from its pickled constructor arguments, then restore weights
    with open(os.path.join(folder, 'params.pkl'), 'rb') as f:
        params = pickle.load(f)

    model = model_class(*params)

    model.load_weights(os.path.join(folder, 'weights/weights.h5'))

    return model
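
load_model pairs with a save convention in which each model writes its constructor arguments to params.pkl and its weights to weights/weights.h5 inside a run folder. A hypothetical round trip (the module path, class name, and folder are assumptions):

# Hypothetical reload of a previously trained model from its run folder.
from utils.loaders import load_model
from models.VAE import VariationalAutoencoder  # assumed module and class

vae = load_model(VariationalAutoencoder, './run/vae/0001')
# i.e. VariationalAutoencoder(*params), followed by vae.load_weights(...)
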
def load_mnist():
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    x_train = x_train.astype('float32') / 255.
    x_train = x_train.reshape(x_train.shape + (1,))
    x_test = x_test.astype('float32') / 255.
    x_test = x_test.reshape(x_test.shape + (1,))

    return (x_train, y_train), (x_test, y_test)


def load_mnist_gan():
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    x_train = (x_train.astype('float32') - 127.5) / 127.5
    x_train = x_train.reshape(x_train.shape + (1,))
    x_test = (x_test.astype('float32') - 127.5) / 127.5
    x_test = x_test.reshape(x_test.shape + (1,))

    return (x_train, y_train), (x_test, y_test)
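
The two MNIST loaders differ only in scaling: load_mnist maps pixels to [0, 1] (suited to sigmoid outputs), while load_mnist_gan maps them to [-1, 1] (suited to tanh outputs). Both append a channel axis:

(x_train, y_train), (x_test, y_test) = load_mnist()
# x_train.shape == (60000, 28, 28, 1), values in [0.0, 1.0]

(x_train_gan, _), _ = load_mnist_gan()
# same shape, values in [-1.0, 1.0]
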
def load_fashion_mnist(input_rows, input_cols, path='./data/fashion/fashion-mnist_train.csv'):
    # read the csv data
    df = pd.read_csv(path)
    # extract the image pixels and scale them to [-1, 1]
    X_train = df.drop(columns=['label'])
    X_train = X_train.values
    X_train = (X_train.astype('float32') - 127.5) / 127.5
    X_train = X_train.reshape(X_train.shape[0], input_rows, input_cols, 1)
    # extract the labels
    y_train = df['label'].values

    return X_train, y_train

def load_safari(folder):

    mypath = os.path.join("./data", folder)
    txt_name_list = []
    # collect the .npy files in the top level of the data folder
    for (dirpath, dirnames, filenames) in walk(mypath):
        for f in filenames:
            if f != '.DS_Store':
                txt_name_list.append(f)
        break

    slice_train = int(80000/len(txt_name_list))  ### Setting value to be 80000 for the final dataset
    i = 0
    seed = np.random.randint(1, 10e6)

    for txt_name in txt_name_list:
        txt_path = os.path.join(mypath, txt_name)
        x = np.load(txt_path)
        x = (x.astype('float32') - 127.5) / 127.5
        # x = x.astype('float32') / 255.0

        x = x.reshape(x.shape[0], 28, 28, 1)

        y = [i] * len(x)
        # shuffle images and labels with the same seed so they stay aligned
        np.random.seed(seed)
        np.random.shuffle(x)
        np.random.seed(seed)
        np.random.shuffle(y)
        x = x[:slice_train]
        y = y[:slice_train]
        if i != 0:
            xtotal = np.concatenate((x, xtotal), axis=0)
            ytotal = np.concatenate((y, ytotal), axis=0)
        else:
            xtotal = x
            ytotal = y
        i += 1

    return xtotal, ytotal
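
The repeated np.random.seed(seed) calls in load_safari are what keep images and labels aligned: re-seeding with the same value makes both shuffles apply the identical permutation. A tiny demonstration of the idiom:

# Re-seeding produces the same shuffle order for both arrays.
import numpy as np

seed = 42
a, b = np.arange(5), np.arange(5)
np.random.seed(seed); np.random.shuffle(a)
np.random.seed(seed); np.random.shuffle(b)
assert (a == b).all()
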
def load_cifar(label, num):
    if num == 10:
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    else:
        (x_train, y_train), (x_test, y_test) = cifar100.load_data(label_mode='fine')

    # keep only the images belonging to the requested class
    train_mask = [y[0] == label for y in y_train]
    test_mask = [y[0] == label for y in y_test]

    x_data = np.concatenate([x_train[train_mask], x_test[test_mask]])
    y_data = np.concatenate([y_train[train_mask], y_test[test_mask]])

    x_data = (x_data.astype('float32') - 127.5) / 127.5  # scale to [-1, 1]

    return (x_data, y_data)
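
load_cifar merges the train and test splits and keeps only one class. For example (class index 7 is 'horse' in CIFAR-10):

# All CIFAR-10 horse images (train + test), scaled to [-1, 1].
x_data, y_data = load_cifar(label=7, num=10)
# x_data.shape == (6000, 32, 32, 3); every entry of y_data equals 7
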
def load_celeb(data_name, image_size, batch_size):
    data_folder = os.path.join("./data", data_name)

    data_gen = ImageDataGenerator(preprocessing_function=lambda x: (x.astype('float32') - 127.5) / 127.5)

    # class_mode='input' yields (image, image) pairs for reconstruction models
    x_train = data_gen.flow_from_directory(data_folder
                                           , target_size=(image_size, image_size)
                                           , batch_size=batch_size
                                           , shuffle=True
                                           , class_mode='input'
                                           , subset="training"
                                           )

    return x_train

def load_music(data_name, filename, n_bars, n_steps_per_bar):
    file = os.path.join("./data", data_name, filename)

    with np.load(file, encoding='bytes') as f:
        data = f['train']

    data_ints = []

    for x in data:
        # advance past the leading timesteps that contain NaNs, four steps at a time
        counter = 0
        cont = True
        while cont:
            if not np.any(np.isnan(x[counter:(counter+4)])):
                cont = False
            else:
                counter += 4

        if n_bars * n_steps_per_bar < x.shape[0]:
            data_ints.append(x[counter:(counter + (n_bars * n_steps_per_bar)), :])

    data_ints = np.array(data_ints)

    n_songs = data_ints.shape[0]
    n_tracks = data_ints.shape[2]

    data_ints = data_ints.reshape([n_songs, n_bars, n_steps_per_bar, n_tracks])

    max_note = 83

    # map the remaining NaNs (rests) to an extra class above the highest pitch
    where_are_NaNs = np.isnan(data_ints)
    data_ints[where_are_NaNs] = max_note + 1
    max_note = max_note + 1

    data_ints = data_ints.astype(int)

    num_classes = max_note + 1

    # one-hot encode the pitches, using -1 instead of 0 for the 'off' entries,
    # then drop the rest-class column
    data_binary = np.eye(num_classes)[data_ints]
    data_binary[data_binary == 0] = -1
    data_binary = np.delete(data_binary, max_note, -1)

    data_binary = data_binary.transpose([0, 1, 2, 4, 3])

    return data_binary, data_ints, data

def preprocess_image(data_name, file, img_nrows, img_ncols):

    image_path = os.path.join('./data', data_name, file)

    # load, resize, add a batch axis, then apply VGG19's preprocessing
    img = load_img(image_path, target_size=(img_nrows, img_ncols))
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = vgg19.preprocess_input(img)
    return img
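
preprocess_image prepares a single image for a VGG19 feature extractor, as used in neural style transfer: it loads, resizes, adds a batch axis, then applies VGG19's channel-mean subtraction and RGB-to-BGR conversion. A hypothetical call (folder and file names assumed):

# Hypothetical: prepare a content image for VGG19.
content = preprocess_image('style_transfer', 'content.jpg', img_nrows=400, img_ncols=533)
# content.shape == (1, 400, 533, 3)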
