-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_dataset.py
79 lines (64 loc) · 2.84 KB
/
load_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import os, re, glob, pdb
import hparams as hp
import torch
def collate_fn(batch):
    """Collate a list of (text, mel) pairs into padded batch tensors.

    Args:
        batch: list of tuples ``(text, mel)`` where ``text`` is a 1-D
            LongTensor of token ids and ``mel`` is a 2-D FloatTensor of
            shape ``(frames, num_mels)``.

    Returns:
        Tuple of:
            text_padded  (batch, max_text_len)  LongTensor, zero-padded
            text_lengths (batch,)               LongTensor, original lengths
            mel_padded   (batch, max_mel_len, num_mels) FloatTensor, zero-padded
            mel_lengths  (batch,)               LongTensor, original frame counts
            stop_padded  (batch, max_mel_len)   FloatTensor, 1 from the last
                                                real frame onward (stop-token target)
    """
    batch_size = len(batch)
    max_text_len = max([x[0].size(0) for x in batch])
    max_mel_len = max([x[1].size(0) for x in batch])
    num_mels = batch[0][1].size(1)
    text_padded = torch.zeros(batch_size, max_text_len, dtype=torch.long)
    mel_padded = torch.zeros(batch_size, max_mel_len, num_mels)
    stop_padded = torch.zeros(batch_size, max_mel_len)
    text_lengths, mel_lengths = torch.zeros(batch_size, dtype=torch.long), torch.zeros(batch_size, dtype=torch.long)
    for i, (text, mel) in enumerate(batch):
        text_padded[i, :text.size(0)] = text
        text_lengths[i] = text.size(0)
        mel_padded[i, :mel.size(0), :] = mel
        mel_lengths[i] = mel.size(0)
        # Mark the stop target on the final real frame and across the padding.
        stop_padded[i, mel.size(0)-1:] = 1
    return text_padded, text_lengths, mel_padded, mel_lengths, stop_padded
class PrepareDataset(Dataset):
    """RUSLAN character-level text / mel-spectrogram dataset.

    Pairs each transcript row of a pipe-separated csv with its precomputed
    mel feature file ``<utt_id>-feats.npy`` found in ``wav_dir``. Rows
    without a matching feature file are dropped.
    """

    def __init__(self, csv_file, wav_dir):
        """
        Args:
            csv_file (string): Path to the csv file with text.
            wav_dir (string): Directory with all the dumped feature files.
        """
        self.dump_dir = wav_dir
        self.unk = "<unk>"
        self.eos = "<eos>"
        df = pd.read_csv(csv_file, sep='|', header=None)
        # Keep only utterances that actually have a dumped '-feats.npy' file.
        wav_files = [os.path.basename(x.replace('-feats.npy', '')) for x in glob.glob(self.dump_dir + '/*')]
        self.csv_file = df[df.iloc[:, 0].isin(wav_files)]
        # Character vocabulary from the (lower-cased) surviving transcripts.
        self.token_list = {char.lower() for utt in list(self.csv_file.iloc[:, 1]) for char in utt}
        self.token_list.add(self.unk)
        self.token_list = sorted(self.token_list)
        # Ids start at 1 so that id 0 really is reserved for padding
        # (collate_fn zero-pads text); previously the first sorted character
        # collided with the pad id.
        self.token2id = {t: i + 1 for i, t in enumerate(self.token_list)}
        self.token2id[self.eos] = len(self.token2id) + 1

    def __len__(self):
        """Number of utterances with both a transcript and features."""
        return len(self.csv_file)

    def whitespace_clean(self, text):
        """Collapse whitespace runs to single spaces and strip the ends."""
        text = re.sub(r'\s+', ' ', text)
        text = text.strip()
        return text

    def tokenize(self, text):
        """Map cleaned, lower-cased text to a list of token ids.

        Characters absent from the vocabulary map to the <unk> id.
        """
        text = self.whitespace_clean(text).lower()
        tokens = []
        for token in text:
            # Fall back to <unk> for characters outside the vocabulary.
            tokens.append(self.token2id.get(token, self.token2id[self.unk]))
        return tokens

    def __getitem__(self, idx):
        """Return (token-id LongTensor, mel FloatTensor of shape (frames, n_mels))."""
        mel_path = os.path.join(self.dump_dir, self.csv_file.iloc[idx, 0]) + '-feats.npy'
        text = self.csv_file.iloc[idx, 1]
        tokenized_text = self.tokenize(text)
        mel = np.load(mel_path)
        return torch.tensor(tokenized_text), torch.tensor(mel, dtype=torch.float)