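"""Training script for speech and music detection.

Loads an experiment configuration and a dataset, builds the training and
validation data generators, and trains (or resumes) a Keras model with
checkpointing, early stopping, and learning-rate scheduling.
"""
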
import argparse
import os

import numpy as np
import keras.models
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping, ReduceLROnPlateau

import smd.utils as utils
from smd.data import preprocessing
from smd.data.data_augmentation import (block_mixing_spec, pitch_time_deformation_spec,
                                        random_filter_spec, random_loudness_spec)
from smd.data.data_generator import DataGenerator
from smd.data.dataset_loader import DatasetLoader
from smd.models.model_loader import load_model, compile_model

def training_data_processing(spec_file, annotation_file, mean, std,
                             spec_file2=None, annotation_file2=None):
    """Load a spectrogram, apply data augmentation, and return normalized
    mel bands with the matching frame-level labels."""
    spec = np.load(spec_file)
    # Augmentation chain: pitch/time deformation, random filtering,
    # random loudness.
    spec, stretching_rate = pitch_time_deformation_spec(spec)
    spec = random_filter_spec(spec)
    spec = random_loudness_spec(spec)
    # Labels are resampled to stay aligned with the time-stretched spectrogram.
    label = preprocessing.get_label(
        annotation_file, spec.shape[1], stretching_rate=stretching_rate)
    if spec_file2 is not None:
        # Optionally mix a second augmented example into the first, block-wise.
        spec2 = np.load(spec_file2)
        spec2, stretching_rate2 = pitch_time_deformation_spec(spec2)
        spec2 = random_filter_spec(spec2)
        spec2 = random_loudness_spec(spec2)
        label2 = preprocessing.get_label(
            annotation_file2, spec2.shape[1], stretching_rate=stretching_rate2)
        spec, label = block_mixing_spec(spec, spec2, label, label2)
    mels = preprocessing.get_scaled_mel_bands(spec)
    mels = preprocessing.normalize(mels, mean, std)
    return mels, label
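
# A minimal sketch of how the generator is expected to call the function
# above; the file paths and statistics below are hypothetical placeholders.
#
#   mels, label = training_data_processing("some_spec.npy", "some_annotation.txt",
#                                          mean=train_mean, std=train_std)
#   # mels: normalized mel bands, shape (n_bands, n_frames)
#   # label: frame-level targets aligned with the augmented spectrogram
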
def validation_data_processing(spec_file, annotation_file, mean, std):
    """Load a spectrogram and return normalized mel bands and labels."""
    spec = np.load(spec_file)
    mels = preprocessing.get_scaled_mel_bands(spec)
    mels = preprocessing.normalize(mels, mean, std)
    n_frame = mels.shape[1]
    label = preprocessing.get_label(
        annotation_file, n_frame, stretching_rate=1)
    return mels, label
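
# Unlike the training path, validation applies no augmentation, so
# val_loss stays comparable from epoch to epoch.
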
def train(train_set, val_set, cfg, config_name, resume, model_path):
    if model_path is not None:
        if resume:
            # Resume training: restore weights and optimizer state.
            print("Loading compiled model: " + model_path)
            model = keras.models.load_model(model_path, compile=True)
        else:
            # Restart from saved weights only, then recompile from the config.
            print("Loading uncompiled model: " + model_path)
            model = keras.models.load_model(model_path, compile=False)
            model = compile_model(model, cfg["model"])
    else:
        print("Loading the network...")
        model = load_model(cfg["model"])

    csv_logger = CSVLogger('checkpoint/' + config_name + '-training.log',
                           append=resume)
    # The config name is appended directly after the loss in the filename,
    # e.g. weights.01-0.25test1.hdf5.
    save_ckpt = ModelCheckpoint("checkpoint/weights.{epoch:02d}-{val_loss:.2f}" + config_name + ".hdf5",
                                monitor='val_loss',
                                verbose=1,
                                save_best_only=True,
                                period=1)
    early_stopping = EarlyStopping(monitor='val_loss',
                                   min_delta=0,
                                   patience=5,
                                   verbose=0,
                                   mode='auto')
    lr_schedule = ReduceLROnPlateau(monitor='val_loss',
                                    factor=0.1,
                                    patience=3,
                                    verbose=1,
                                    mode='auto',
                                    min_lr=1e-6)
    callback_list = [save_ckpt, early_stopping, lr_schedule, csv_logger]

    print("Starting the training...")
    model.fit_generator(train_set,
                        epochs=cfg["nb_epoch"],
                        callbacks=callback_list,
                        validation_data=val_set,
                        workers=cfg["workers"],
                        use_multiprocessing=cfg["use_multiprocessing"],
                        shuffle=True)
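
# Note: fit_generator targets standalone Keras 2.x; in tf.keras >= 2.1 it is
# deprecated and model.fit accepts generator inputs directly.
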
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Script to train a neural network for speech and music detection.")
    parser.add_argument('--config', type=str, default="test1",
                        help='the configuration of the training')
    parser.add_argument('--data_location', type=str,
                        default="/Users/quentin/Computer/DataSet/Music/speech_music_detection/",
                        help='the location of the data')
    # argparse's type=bool treats any non-empty string as True, so a flag is
    # used instead.
    parser.add_argument('--resume', action='store_true',
                        help='restart a previous training')
    parser.add_argument('--model', type=str, default=None,
                        help='path of the model to load when training is resumed')
    args = parser.parse_args()

    experiments = utils.load_json('experiments.json')
    cfg = experiments[args.config]

    if not os.path.isdir("checkpoint"):
        os.makedirs("checkpoint")
        print("Checkpoint folder created.")
print("Creating the dataset..")
datasets_config = utils.load_json("datasets.json")
dataset = DatasetLoader(
cfg["dataset"], args.data_location, datasets_config)
print("Creating the data generator..")
train_set = DataGenerator(dataset.get_train_set(),
cfg["batch_size"],
cfg["target_seq_length"],
training_data_processing,
dataset.get_training_mean(),
dataset.get_training_std(),
set_type="train")
val_set = DataGenerator(dataset.get_val_set(),
cfg["batch_size"],
cfg["target_seq_length"],
validation_data_processing,
dataset.get_training_mean(),
dataset.get_training_std(),
set_type="val")
train(train_set, val_set, cfg, args.config, args.resume, args.model)
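
# Example invocations (paths are placeholders):
#   python train.py --config test1 --data_location /path/to/dataset/
#   python train.py --config test1 --resume \
#       --model checkpoint/weights.05-0.42test1.hdf5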