Skip to content

Commit 5c32494

Browse files
committed
🚀 Support MultiGPU/SingleGPU in one base_trainer, Refactor project.
1 parent 621cada commit 5c32494

File tree

101 files changed

+3112
-3195
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

101 files changed

+3112
-3195
lines changed

.gitattributes

100644100755
File mode changed.

.github/workflows/ci.yaml

100644100755
+2-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ jobs:
1616
strategy:
1717
max-parallel: 10
1818
matrix:
19-
python-version: [3.6]
20-
tensorflow-version: [2.1.0]
19+
python-version: [3.7]
20+
tensorflow-version: [2.2.0]
2121
steps:
2222
- uses: actions/checkout@master
2323
- uses: actions/setup-python@v1

.gitignore

100644100755
File mode changed.

LICENSE

100644100755
File mode changed.

README.md

100644100755
File mode changed.

docker-compose.yml

100644100755
File mode changed.

dockerfile

100644100755
File mode changed.

examples/fastspeech/README.md

100644100755
File mode changed.

examples/fastspeech/conf/fastspeech.v1.yaml

100644100755
File mode changed.

examples/fastspeech/conf/fastspeech.v3.yaml

100644100755
File mode changed.

examples/fastspeech/decode_fastspeech.py

100644100755
+6-8
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,16 @@
1818
import logging
1919
import os
2020
import sys
21+
2122
sys.path.append(".")
2223

2324
import numpy as np
24-
import yaml
2525
import tensorflow as tf
26-
26+
import yaml
2727
from tqdm import tqdm
2828

29-
from tensorflow_tts.configs import FastSpeechConfig
3029
from examples.fastspeech.fastspeech_dataset import CharactorDataset
30+
from tensorflow_tts.configs import FastSpeechConfig
3131
from tensorflow_tts.models import TFFastSpeech
3232

3333

@@ -111,7 +111,6 @@ def main():
111111
root_dir=args.rootdir,
112112
charactor_query=char_query,
113113
charactor_load_fn=char_load_fn,
114-
return_utt_id=True,
115114
)
116115
dataset = dataset.create(batch_size=args.batch_size)
117116

@@ -123,14 +122,13 @@ def main():
123122
fastspeech.load_weights(args.checkpoint)
124123

125124
for data in tqdm(dataset, desc="Decoding"):
126-
utt_ids = data[0]
127-
char_ids = data[1]
125+
utt_ids = data["utt_ids"]
126+
char_ids = data["input_ids"]
128127

129128
# fastspeech inference.
130129
masked_mel_before, masked_mel_after, duration_outputs = fastspeech.inference(
131130
char_ids,
132-
attention_mask=tf.math.not_equal(char_ids, 0),
133-
speaker_ids=tf.zeros(shape=[tf.shape(char_ids)[0]]),
131+
speaker_ids=tf.zeros(shape=[tf.shape(char_ids)[0]], dtype=tf.int32),
134132
speed_ratios=tf.ones(shape=[tf.shape(char_ids)[0]], dtype=tf.float32),
135133
)
136134

examples/fastspeech/fastspeech_dataset.py

100644100755
+41-187
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,15 @@
1414
# limitations under the License.
1515
"""Dataset modules."""
1616

17+
import itertools
1718
import logging
1819
import os
1920
import random
20-
import itertools
21-
import numpy as np
2221

22+
import numpy as np
2323
import tensorflow as tf
2424

2525
from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
26-
2726
from tensorflow_tts.utils import find_files
2827

2928

@@ -39,8 +38,7 @@ def __init__(
3938
charactor_load_fn=np.load,
4039
mel_load_fn=np.load,
4140
duration_load_fn=np.load,
42-
mel_length_threshold=None,
43-
return_utt_id=False,
41+
mel_length_threshold=0,
4442
):
4543
"""Initialize dataset.
4644
@@ -60,57 +58,6 @@ def __init__(
6058
charactor_files = sorted(find_files(root_dir, charactor_query))
6159
mel_files = sorted(find_files(root_dir, mel_query))
6260
duration_files = sorted(find_files(root_dir, duration_query))
63-
# filter by threshold
64-
if mel_length_threshold is not None:
65-
mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files]
66-
67-
idxs = [
68-
idx
69-
for idx in range(len(mel_files))
70-
if mel_lengths[idx] > mel_length_threshold
71-
]
72-
if len(mel_files) != len(idxs):
73-
logging.warning(
74-
f"Some files are filtered by mel length threshold "
75-
f"({len(mel_files)} -> {len(idxs)})."
76-
)
77-
mel_files = [mel_files[idx] for idx in idxs]
78-
charactor_files = [charactor_files[idx] for idx in idxs]
79-
duration_files = [duration_files[idx] for idx in idxs]
80-
mel_lengths = [mel_lengths[idx] for idx in idxs]
81-
82-
# bucket sequence length trick, sort based-on mel-length.
83-
idx_sort = np.argsort(mel_lengths)
84-
85-
# sort
86-
mel_files = np.array(mel_files)[idx_sort]
87-
charactor_files = np.array(charactor_files)[idx_sort]
88-
duration_files = np.array(duration_files)[idx_sort]
89-
mel_lengths = np.array(mel_lengths)[idx_sort]
90-
91-
# group
92-
idx_lengths = [
93-
[idx, length]
94-
for idx, length in zip(np.arange(len(mel_lengths)), mel_lengths)
95-
]
96-
groups = [
97-
list(g) for _, g in itertools.groupby(idx_lengths, lambda a: a[1])
98-
]
99-
100-
# group shuffle
101-
random.shuffle(groups)
102-
103-
# get idxs affter group shuffle
104-
idxs = []
105-
for group in groups:
106-
for idx, _ in group:
107-
idxs.append(idx)
108-
109-
# re-arange dataset
110-
mel_files = np.array(mel_files)[idxs]
111-
charactor_files = np.array(charactor_files)[idxs]
112-
duration_files = np.array(duration_files)[idxs]
113-
mel_lengths = np.array(mel_lengths)[idxs]
11461

11562
# assert the number of files
11663
assert len(mel_files) != 0, f"Not found any mels files in ${root_dir}."
@@ -131,7 +78,7 @@ def __init__(
13178
self.mel_load_fn = mel_load_fn
13279
self.charactor_load_fn = charactor_load_fn
13380
self.duration_load_fn = duration_load_fn
134-
self.return_utt_id = return_utt_id
81+
self.mel_length_threshold = mel_length_threshold
13582

13683
def get_args(self):
13784
return [self.utt_ids]
@@ -144,115 +91,16 @@ def generator(self, utt_ids):
14491
mel = self.mel_load_fn(mel_file)
14592
charactor = self.charactor_load_fn(charactor_file)
14693
duration = self.duration_load_fn(duration_file)
147-
if self.return_utt_id:
148-
items = utt_id, charactor, duration, mel
149-
else:
150-
items = charactor, duration, mel
151-
yield items
152-
153-
def create(
154-
self,
155-
allow_cache=False,
156-
batch_size=1,
157-
is_shuffle=False,
158-
map_fn=None,
159-
reshuffle_each_iteration=True,
160-
):
161-
"""Create tf.dataset function."""
162-
output_types = self.get_output_dtypes()
163-
datasets = tf.data.Dataset.from_generator(
164-
self.generator, output_types=output_types, args=(self.get_args())
165-
)
16694

167-
if allow_cache:
168-
datasets = datasets.cache()
95+
items = {
96+
"utt_ids": utt_id,
97+
"input_ids": charactor,
98+
"speaker_ids": 0,
99+
"duration_gts": duration,
100+
"mel_gts": mel,
101+
"mel_lengths": len(mel),
102+
}
169103

170-
if is_shuffle:
171-
datasets = datasets.shuffle(
172-
self.get_len_dataset(),
173-
reshuffle_each_iteration=reshuffle_each_iteration,
174-
)
175-
176-
datasets = datasets.padded_batch(
177-
batch_size, padded_shapes=([None], [None], [None, None])
178-
)
179-
datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
180-
return datasets
181-
182-
def get_output_dtypes(self):
183-
output_types = (tf.int32, tf.int32, tf.float32)
184-
if self.return_utt_id:
185-
output_types = (tf.dtypes.string, *output_types)
186-
return output_types
187-
188-
def get_len_dataset(self):
189-
return len(self.utt_ids)
190-
191-
def __name__(self):
192-
return "CharactorDurationMelDataset"
193-
194-
195-
class CharactorDurationDataset(AbstractDataset):
196-
"""Tensorflow Charactor dataset."""
197-
198-
def __init__(
199-
self,
200-
root_dir,
201-
charactor_query="*-ids.npy",
202-
duration_query="*-durations.npy",
203-
charactor_load_fn=np.load,
204-
duration_load_fn=np.load,
205-
return_utt_id=False,
206-
):
207-
"""Initialize dataset.
208-
209-
Args:
210-
root_dir (str): Root directory including dumped files.
211-
charactor_query (str): Query to find charactor files in root_dir.
212-
duration_query (str): Query to find duration files in root_dir.
213-
charactor_load_fn (func): Function to load charactor file.
214-
duration_load_fn (func): Function to load duration file.
215-
return_utt_id (bool): Whether to return the utterance id with arrays.
216-
217-
"""
218-
# find all of charactor and mel files.
219-
charactor_files = sorted(find_files(root_dir, charactor_query))
220-
duration_files = sorted(find_files(root_dir, duration_query))
221-
222-
# assert the number of files
223-
assert (
224-
len(charactor_files) != 0 or len(duration_files) != 0
225-
), f"Not found any char or duration files in ${root_dir}."
226-
227-
assert len(charactor_files) == len(
228-
duration_files
229-
), "number of charactor and duration files are different."
230-
231-
if ".npy" in charactor_query:
232-
suffix = charactor_query[1:]
233-
utt_ids = [os.path.basename(f).replace(suffix, "") for f in charactor_files]
234-
235-
# set global params
236-
self.utt_ids = utt_ids
237-
self.charactor_files = charactor_files
238-
self.duration_files = duration_files
239-
self.charactor_load_fn = charactor_load_fn
240-
self.duration_load_fn = duration_load_fn
241-
self.return_utt_id = return_utt_id
242-
243-
def get_args(self):
244-
return [self.utt_ids]
245-
246-
def generator(self, utt_ids):
247-
for i, utt_id in enumerate(utt_ids):
248-
charactor_file = self.charactor_files[i]
249-
duration_file = self.duration_files[i]
250-
charactor = self.charactor_load_fn(charactor_file)
251-
duration = self.duration_load_fn(duration_file)
252-
if self.return_utt_id:
253-
items = utt_id, charactor, duration
254-
else:
255-
items = charactor, duration
256104
yield items
257105

258106
def create(
@@ -269,6 +117,10 @@ def create(
269117
self.generator, output_types=output_types, args=(self.get_args())
270118
)
271119

120+
datasets = datasets.filter(
121+
lambda x: x["mel_lengths"] > self.mel_length_threshold
122+
)
123+
272124
if allow_cache:
273125
datasets = datasets.cache()
274126

@@ -278,36 +130,43 @@ def create(
278130
reshuffle_each_iteration=reshuffle_each_iteration,
279131
)
280132

281-
padded_shapes = ([None], [None])
282-
if self.return_utt_id:
283-
padded_shapes = ([], *padded_shapes)
133+
# define padded_shapes
134+
padded_shapes = {
135+
"utt_ids": [],
136+
"input_ids": [None],
137+
"speaker_ids": [],
138+
"duration_gts": [None],
139+
"mel_gts": [None, None],
140+
"mel_lengths": [],
141+
}
284142

285143
datasets = datasets.padded_batch(batch_size, padded_shapes=padded_shapes)
286144
datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
287145
return datasets
288146

289147
def get_output_dtypes(self):
290-
output_types = (tf.int32, tf.int32)
291-
if self.return_utt_id:
292-
output_types = (tf.dtypes.string, *output_types)
148+
output_types = {
149+
"utt_ids": tf.string,
150+
"input_ids": tf.int32,
151+
"speaker_ids": tf.int32,
152+
"duration_gts": tf.int32,
153+
"mel_gts": tf.float32,
154+
"mel_lengths": tf.int32,
155+
}
293156
return output_types
294157

295158
def get_len_dataset(self):
296159
return len(self.utt_ids)
297160

298161
def __name__(self):
299-
return "CharactorDurationDataset"
162+
return "CharactorDurationMelDataset"
300163

301164

302165
class CharactorDataset(AbstractDataset):
303166
"""Tensorflow Charactor dataset."""
304167

305168
def __init__(
306-
self,
307-
root_dir,
308-
charactor_query="*-ids.npy",
309-
charactor_load_fn=np.load,
310-
return_utt_id=False,
169+
self, root_dir, charactor_query="*-ids.npy", charactor_load_fn=np.load,
311170
):
312171
"""Initialize dataset.
313172
@@ -333,7 +192,6 @@ def __init__(
333192
self.utt_ids = utt_ids
334193
self.charactor_files = charactor_files
335194
self.charactor_load_fn = charactor_load_fn
336-
self.return_utt_id = return_utt_id
337195

338196
def get_args(self):
339197
return [self.utt_ids]
@@ -342,10 +200,9 @@ def generator(self, utt_ids):
342200
for i, utt_id in enumerate(utt_ids):
343201
charactor_file = self.charactor_files[i]
344202
charactor = self.charactor_load_fn(charactor_file)
345-
if self.return_utt_id:
346-
items = utt_id, charactor
347-
else:
348-
items = charactor
203+
204+
items = {"utt_ids": utt_id, "input_ids": charactor}
205+
349206
yield items
350207

351208
def create(
@@ -371,18 +228,15 @@ def create(
371228
reshuffle_each_iteration=reshuffle_each_iteration,
372229
)
373230

374-
padded_shapes = ([None],)
375-
if self.return_utt_id:
376-
padded_shapes = ([], *padded_shapes)
231+
# define padded shapes
232+
padded_shapes = {"utt_ids": [], "input_ids": [None]}
377233

378234
datasets = datasets.padded_batch(batch_size, padded_shapes=padded_shapes)
379235
datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
380236
return datasets
381237

382238
def get_output_dtypes(self):
383-
output_types = (tf.int32,)
384-
if self.return_utt_id:
385-
output_types = (tf.dtypes.string, *output_types)
239+
output_types = {"utt_ids": tf.string, "input_ids": tf.int32}
386240
return output_types
387241

388242
def get_len_dataset(self):

examples/fastspeech/fig/fastspeech.v1.png

100644100755
File mode changed.

0 commit comments

Comments
 (0)