Skip to content

Commit

Permalink
Fix audio inference; improve namings of files
Browse files Browse the repository at this point in the history
  • Loading branch information
israelg99 committed Feb 22, 2017
1 parent b200856 commit 7153d74
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions audio_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@

#%% Save.
print('Saving.')
samples = np.save(type(M).__name__ + '_samples.npy', samples)
audio = np.save(type(M).__name__ + '_audio.npy', audio)
np.save(type(M).__name__ + '_samples.npy', samples)
np.save(type(M).__name__ + '_audio.npy', audio)

for i in tqdm(range(BATCH_SIZE)):
scipy.io.wavfile.write(type(M).__name__ + '_audio.wav', RATE, audio[i])
scipy.io.wavfile.write('audio' + str(i) + '.wav', RATE, audio[i])
2 changes: 1 addition & 1 deletion audio_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@
TRAIN = TRAIN[:TRAIN.shape[0]//LENGTH*LENGTH].reshape(TRAIN.shape[0]//LENGTH, LENGTH, BINS)

M.fit(TRAIN, sparse_labels(TRAIN), nb_epoch=2000, batch_size=8,
callbacks=[TensorBoard(), ModelCheckpoint(type(M).__name__ + '_model.h5')])
callbacks=[TensorBoard(), ModelCheckpoint('model.h5')])
2 changes: 1 addition & 1 deletion img_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ARGS += (1,)

M = MODEL(*ARGS)
M.load_weights(type(M).__name__ + '_model.h5')
M.load_weights('model.h5')


#%% Choice (Probabilistic).
Expand Down
2 changes: 1 addition & 1 deletion img_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@
M.fit(DATA if LABELS is None else [DATA, LABELS],
[(np.expand_dims(DATA[:, :, :, c].reshape(DATA.shape[0], DATA.shape[1]*DATA.shape[2]), -1)*255).astype(int) for c in range(DATA.shape[3])],
batch_size=32, nb_epoch=200,
verbose=1, callbacks=[TensorBoard(), ModelCheckpoint(type(M).__name__ + '_model.h5', save_weights_only=True)]) # Only weights because Keras is a bitch.
verbose=1, callbacks=[TensorBoard(), ModelCheckpoint('model.h5', save_weights_only=True)]) # Only weights because Keras is a bitch.

0 comments on commit 7153d74

Please sign in to comment.