 # limitations under the License.
 """Dataset modules."""

+import itertools
 import logging
 import os
 import random
-import itertools
-import numpy as np

+import numpy as np
 import tensorflow as tf

 from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
-
 from tensorflow_tts.utils import find_files
@@ -39,8 +38,7 @@ def __init__(
         charactor_load_fn=np.load,
         mel_load_fn=np.load,
         duration_load_fn=np.load,
-        mel_length_threshold=None,
-        return_utt_id=False,
+        mel_length_threshold=0,
     ):
         """Initialize dataset.
@@ -60,57 +58,6 @@ def __init__(
         charactor_files = sorted(find_files(root_dir, charactor_query))
         mel_files = sorted(find_files(root_dir, mel_query))
         duration_files = sorted(find_files(root_dir, duration_query))
-        # filter by threshold
-        if mel_length_threshold is not None:
-            mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files]
-
-            idxs = [
-                idx
-                for idx in range(len(mel_files))
-                if mel_lengths[idx] > mel_length_threshold
-            ]
-            if len(mel_files) != len(idxs):
-                logging.warning(
-                    f"Some files are filtered by mel length threshold "
-                    f"({len(mel_files)} -> {len(idxs)})."
-                )
-            mel_files = [mel_files[idx] for idx in idxs]
-            charactor_files = [charactor_files[idx] for idx in idxs]
-            duration_files = [duration_files[idx] for idx in idxs]
-            mel_lengths = [mel_lengths[idx] for idx in idxs]
-
-            # bucket sequence length trick, sort based on mel length.
-            idx_sort = np.argsort(mel_lengths)
-
-            # sort
-            mel_files = np.array(mel_files)[idx_sort]
-            charactor_files = np.array(charactor_files)[idx_sort]
-            duration_files = np.array(duration_files)[idx_sort]
-            mel_lengths = np.array(mel_lengths)[idx_sort]
-
-            # group
-            idx_lengths = [
-                [idx, length]
-                for idx, length in zip(np.arange(len(mel_lengths)), mel_lengths)
-            ]
-            groups = [
-                list(g) for _, g in itertools.groupby(idx_lengths, lambda a: a[1])
-            ]
-
-            # group shuffle
-            random.shuffle(groups)
-
-            # get idxs after group shuffle
-            idxs = []
-            for group in groups:
-                for idx, _ in group:
-                    idxs.append(idx)
-
-            # re-arrange dataset
-            mel_files = np.array(mel_files)[idxs]
-            charactor_files = np.array(charactor_files)[idxs]
-            duration_files = np.array(duration_files)[idxs]
-            mel_lengths = np.array(mel_lengths)[idxs]

         # assert the number of files
         assert len(mel_files) != 0, f"Not found any mels files in ${root_dir}."
@@ -131,7 +78,7 @@ def __init__(
         self.mel_load_fn = mel_load_fn
         self.charactor_load_fn = charactor_load_fn
         self.duration_load_fn = duration_load_fn
-        self.return_utt_id = return_utt_id
+        self.mel_length_threshold = mel_length_threshold

     def get_args(self):
         return [self.utt_ids]
@@ -144,115 +91,16 @@ def generator(self, utt_ids):
             mel = self.mel_load_fn(mel_file)
             charactor = self.charactor_load_fn(charactor_file)
             duration = self.duration_load_fn(duration_file)
-            if self.return_utt_id:
-                items = utt_id, charactor, duration, mel
-            else:
-                items = charactor, duration, mel
-            yield items
-
-    def create(
-        self,
-        allow_cache=False,
-        batch_size=1,
-        is_shuffle=False,
-        map_fn=None,
-        reshuffle_each_iteration=True,
-    ):
-        """Create tf.dataset function."""
-        output_types = self.get_output_dtypes()
-        datasets = tf.data.Dataset.from_generator(
-            self.generator, output_types=output_types, args=(self.get_args())
-        )

-        if allow_cache:
-            datasets = datasets.cache()
+            items = {
+                "utt_ids": utt_id,
+                "input_ids": charactor,
+                "speaker_ids": 0,
+                "duration_gts": duration,
+                "mel_gts": mel,
+                "mel_lengths": len(mel),
+            }

-        if is_shuffle:
-            datasets = datasets.shuffle(
-                self.get_len_dataset(),
-                reshuffle_each_iteration=reshuffle_each_iteration,
-            )
-
-        datasets = datasets.padded_batch(
-            batch_size, padded_shapes=([None], [None], [None, None])
-        )
-        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
-        return datasets
-
-    def get_output_dtypes(self):
-        output_types = (tf.int32, tf.int32, tf.float32)
-        if self.return_utt_id:
-            output_types = (tf.dtypes.string, *output_types)
-        return output_types
-
-    def get_len_dataset(self):
-        return len(self.utt_ids)
-
-    def __name__(self):
-        return "CharactorDurationMelDataset"
-
-
-class CharactorDurationDataset(AbstractDataset):
-    """Tensorflow Charactor dataset."""
-
-    def __init__(
-        self,
-        root_dir,
-        charactor_query="*-ids.npy",
-        duration_query="*-durations.npy",
-        charactor_load_fn=np.load,
-        duration_load_fn=np.load,
-        return_utt_id=False,
-    ):
-        """Initialize dataset.
-
-        Args:
-            root_dir (str): Root directory including dumped files.
-            charactor_query (str): Query to find charactor files in root_dir.
-            duration_query (str): Query to find duration files in root_dir.
-            charactor_load_fn (func): Function to load charactor file.
-            duration_load_fn (func): Function to load duration file.
-            return_utt_id (bool): Whether to return the utterance id with arrays.
-
-        """
-        # find all of charactor and mel files.
-        charactor_files = sorted(find_files(root_dir, charactor_query))
-        duration_files = sorted(find_files(root_dir, duration_query))
-
-        # assert the number of files
-        assert (
-            len(charactor_files) != 0 or len(duration_files) != 0
-        ), f"Not found any char or duration files in ${root_dir}."
-
-        assert len(charactor_files) == len(
-            duration_files
-        ), "number of charactor and duration files are different."
-
-        if ".npy" in charactor_query:
-            suffix = charactor_query[1:]
-            utt_ids = [os.path.basename(f).replace(suffix, "") for f in charactor_files]
-
-        # set global params
-        self.utt_ids = utt_ids
-        self.charactor_files = charactor_files
-        self.duration_files = duration_files
-        self.charactor_load_fn = charactor_load_fn
-        self.duration_load_fn = duration_load_fn
-        self.return_utt_id = return_utt_id
-
-    def get_args(self):
-        return [self.utt_ids]
-
-    def generator(self, utt_ids):
-        for i, utt_id in enumerate(utt_ids):
-            charactor_file = self.charactor_files[i]
-            duration_file = self.duration_files[i]
-            charactor = self.charactor_load_fn(charactor_file)
-            duration = self.duration_load_fn(duration_file)
-            if self.return_utt_id:
-                items = utt_id, charactor, duration
-            else:
-                items = charactor, duration
             yield items

     def create(
@@ -269,6 +117,10 @@ def create(
             self.generator, output_types=output_types, args=(self.get_args())
         )

+        datasets = datasets.filter(
+            lambda x: x["mel_lengths"] > self.mel_length_threshold
+        )
+
         if allow_cache:
             datasets = datasets.cache()

@@ -278,36 +130,43 @@ def create(
                 reshuffle_each_iteration=reshuffle_each_iteration,
             )

-        padded_shapes = ([None], [None])
-        if self.return_utt_id:
-            padded_shapes = ([], *padded_shapes)
+        # define padded_shapes
+        padded_shapes = {
+            "utt_ids": [],
+            "input_ids": [None],
+            "speaker_ids": [],
+            "duration_gts": [None],
+            "mel_gts": [None, None],
+            "mel_lengths": [],
+        }

         datasets = datasets.padded_batch(batch_size, padded_shapes=padded_shapes)
         datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
         return datasets

     def get_output_dtypes(self):
-        output_types = (tf.int32, tf.int32)
-        if self.return_utt_id:
-            output_types = (tf.dtypes.string, *output_types)
+        output_types = {
+            "utt_ids": tf.string,
+            "input_ids": tf.int32,
+            "speaker_ids": tf.int32,
+            "duration_gts": tf.int32,
+            "mel_gts": tf.float32,
+            "mel_lengths": tf.int32,
+        }
         return output_types

     def get_len_dataset(self):
         return len(self.utt_ids)

     def __name__(self):
-        return "CharactorDurationDataset"
+        return "CharactorDurationMelDataset"


 class CharactorDataset(AbstractDataset):
     """Tensorflow Charactor dataset."""

     def __init__(
-        self,
-        root_dir,
-        charactor_query="*-ids.npy",
-        charactor_load_fn=np.load,
-        return_utt_id=False,
+        self, root_dir, charactor_query="*-ids.npy", charactor_load_fn=np.load,
     ):
         """Initialize dataset.
@@ -333,7 +192,6 @@ def __init__(
         self.utt_ids = utt_ids
         self.charactor_files = charactor_files
         self.charactor_load_fn = charactor_load_fn
-        self.return_utt_id = return_utt_id

     def get_args(self):
         return [self.utt_ids]
@@ -342,10 +200,9 @@ def generator(self, utt_ids):
         for i, utt_id in enumerate(utt_ids):
             charactor_file = self.charactor_files[i]
             charactor = self.charactor_load_fn(charactor_file)
-            if self.return_utt_id:
-                items = utt_id, charactor
-            else:
-                items = charactor
+
+            items = {"utt_ids": utt_id, "input_ids": charactor}
+
             yield items

     def create(
@@ -371,18 +228,15 @@ def create(
                 reshuffle_each_iteration=reshuffle_each_iteration,
             )

-        padded_shapes = ([None],)
-        if self.return_utt_id:
-            padded_shapes = ([], *padded_shapes)
+        # define padded shapes
+        padded_shapes = {"utt_ids": [], "input_ids": [None]}

         datasets = datasets.padded_batch(batch_size, padded_shapes=padded_shapes)
         datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
         return datasets

     def get_output_dtypes(self):
-        output_types = (tf.int32,)
-        if self.return_utt_id:
-            output_types = (tf.dtypes.string, *output_types)
+        output_types = {"utt_ids": tf.string, "input_ids": tf.int32}
         return output_types

     def get_len_dataset(self):
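
Below is a minimal, self-contained sketch (not part of the commit) of the pattern this change adopts: filtering short utterances inside the tf.data pipeline and padding dict-structured examples per key. The dummy generator, the 80-bin mel shape, and the threshold value are illustrative assumptions only.

import numpy as np
import tensorflow as tf

def fake_examples():
    # Three toy utterances with mel lengths 2, 5, and 7 (values are arbitrary).
    for n in (2, 5, 7):
        yield {
            "input_ids": np.arange(n, dtype=np.int32),
            "mel_gts": np.ones([n, 80], dtype=np.float32),
            "mel_lengths": n,
        }

datasets = tf.data.Dataset.from_generator(
    fake_examples,
    output_types={"input_ids": tf.int32, "mel_gts": tf.float32, "mel_lengths": tf.int32},
)

# Mirror of the commit's filter: drop anything at or below the threshold.
mel_length_threshold = 4  # assumed value for the demo
datasets = datasets.filter(lambda x: x["mel_lengths"] > mel_length_threshold)

# Per-key padding: variable-length axes get None, scalars get [].
datasets = datasets.padded_batch(
    2, padded_shapes={"input_ids": [None], "mel_gts": [None, None], "mel_lengths": []}
)

for batch in datasets:
    print(batch["mel_gts"].shape)  # (2, 7, 80): the two survivors, padded to length 7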
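
And a hedged usage sketch of the refactored class: the constructor arguments, create() keywords, and batch keys come from the diff above, while the import path, dump directory, threshold value, and batch size are assumptions.

from tensorflow_tts.datasets import CharactorDurationMelDataset  # import path assumed

dataset = CharactorDurationMelDataset(
    root_dir="./dump/train",     # assumed dump location
    mel_length_threshold=32,     # filtering now happens inside create()
).create(batch_size=8, is_shuffle=True)

for batch in dataset.take(1):
    # Batches are dicts keyed by name; the old return_utt_id flag is gone.
    print(batch["utt_ids"].shape, batch["input_ids"].shape, batch["mel_gts"].shape)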