-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathaudio_train.py
76 lines (56 loc) · 1.67 KB
/
audio_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#%% Setup.
import signal
import sys
import numpy as np
import scipy.io.wavfile
from keras.utils.visualize_util import plot
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.utils import np_utils
from eva.models.wavenet import Wavenet, compute_receptive_field
from eva.util.mutil import sparse_labels
#%% Data
# Load the training audio: RATE is the sample rate and DATA the sample array
# returned by scipy.io.wavfile.read.
# NOTE(review): the training code uses DATA values directly as one-hot bin
# indices and copies DATA into a 1-D buffer, so the file is presumably mono
# with samples already quantized to [0, BINS) -- confirm.
RATE, DATA = scipy.io.wavfile.read('./data/undertale/undertale_001_once_upon_a_time.comp.wav')
#%% Train Config.
EPOCHS = 2000  # passed as nb_epoch to fit_generator
BATCH = 8  # number of windows in each generated batch
#%% Model Config.
MODEL = Wavenet  # model constructor, called with ARGS in the Model section
# FILTERS, DEPTH and STACKS are forwarded to the model constructor;
# DEPTH and STACKS are also passed to compute_receptive_field.
FILTERS = 32
DEPTH = 10
STACKS = 5
BINS = 256  # width of the one-hot encoding of each sample
LAST = RATE  # number of trailing samples of each window used as targets
# Window length: LAST target samples plus the first value returned by
# compute_receptive_field (presumably the receptive field in samples -- confirm).
LENGTH = LAST + compute_receptive_field(RATE, DEPTH, STACKS)[0]
LOAD = False  # when True, weights are loaded from 'model.h5' after building
#%% Model.
# The network consumes one-hot windows of shape (LENGTH, BINS).
INPUT = (LENGTH, BINS)
ARGS = (INPUT, FILTERS, DEPTH, STACKS, LAST)
M = MODEL(*ARGS)

# Optionally resume from previously saved weights.
if LOAD:
    M.load_weights('model.h5')

# Print the layer summary and render the model graph.
M.summary()
plot(M)
#%% Train.
# Prepend LENGTH-1 zeros to the audio so that a full LENGTH-sample window can
# start at any index of DATA; the result is a float array like np.zeros gives.
padded_data = np.concatenate((np.zeros(LENGTH - 1), DATA))
def train_gen():
    """Yield random training batches forever.

    Each batch is a tuple ``(x, y)`` built from BATCH windows of LENGTH
    consecutive samples taken from ``padded_data`` at random start offsets:

    * ``x`` has shape (BATCH, LENGTH, BINS) and is the one-hot encoding of
      the integer sample values in each window.
    * ``y`` has shape (BATCH, LAST, 1) and holds the integer sample values
      of the last LAST positions of each window, stored as floats.
    """
    positions = np.arange(LENGTH)
    while True:
        starts = np.random.randint(0, DATA.shape[0]-2, size=BATCH, dtype=int)
        # One row per window, truncated to integers so the values can be
        # used as bin indices.
        windows = np.stack(
            [padded_data[start:start + LENGTH] for start in starts]
        ).astype(int)
        # Fresh arrays every iteration: the consumer may still hold the
        # previously yielded batch.
        x = np.zeros((BATCH, LENGTH, BINS))
        for row, bins in enumerate(windows):
            x[row, positions, bins] = 1
        y = windows[:, -LAST:, np.newaxis].astype(float)
        yield x, y
def save():
    """Save the module-level model ``M`` to 'sigint_model.h5'."""
    M.save('sigint_model.h5')
def save_gracefully(signal, frame):
    """Signal handler: save the model, then exit with status 0.

    Both arguments are supplied by the signal machinery and are unused.
    Note that the ``signal`` parameter shadows the ``signal`` module inside
    this function; the name is kept so the signature stays unchanged.
    """
    save()
    sys.exit(0)
# On SIGINT (e.g. Ctrl-C), save the model and exit instead of dying mid-run.
signal.signal(signal.SIGINT, save_gracefully)
# Raise the interpreter recursion limit. The original author's note attributes
# the need to deep recursion in Theano.
# NOTE(review): confirm this is still required with the backend in use.
sys.setrecursionlimit(50000)
# Train from the endless batch generator; ModelCheckpoint writes 'model.h5'.
M.fit_generator(train_gen(), samples_per_epoch=RATE//4, nb_epoch=EPOCHS, callbacks=[ModelCheckpoint('model.h5')])