-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecode.py
108 lines (96 loc) · 4.09 KB
/
decode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from util import *
from models import *
from train import *
from data import *
from logging.handlers import RotatingFileHandler
from tokenizers import Tokenizer
from speechbrain.processing.features import InputNormalization
from hydra import initialize, compose
from omegaconf import OmegaConf
import hydra
import torch.nn as nn
import torch
import pdb
import logging
import copy
import argparse
import time
import random
import sys
import resource
import torch.distributed as dist
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
class TestDataset(torch.utils.data.Dataset):
    """Map-style dataset over the rows of the CSV at ``cfg.paths.test_path``.

    Each item is an ``(audio_path, ground_truth)`` pair: the row's
    ``audio_file`` (remapped for LibriSpeech, see ``__getitem__``) and the
    row's ``utterance`` passed through ``clean4asr``.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.df = pd.read_csv(cfg.paths.test_path)

    def __len__(self):
        # Number of CSV rows.
        return len(self.df)

    def __getitem__(self, index):
        record = self.df.iloc[index]
        audio_path = record.audio_file
        if self.cfg.corpus == 'librispeech':
            # Rewrite the corpus location recorded in the CSV to a different
            # hard-coded directory. NOTE(review): machine-specific path.
            audio_path = audio_path.replace(
                '/data/corpora2/librispeech/LibriSpeech/',
                '/research/nfs_fosler_1/vishal/audio/libri/',
            )
        return audio_path, clean4asr(record.utterance)
class Collator:
    """Turn a batch of ``(path, ground_truth)`` pairs into two parallel lists."""

    def __init__(self, cfg):
        # Stored for callers; not read by __call__.
        self.cfg = cfg

    def __call__(self, lst):
        # Transpose the pairs into (paths, truths); an empty batch or items that
        # are not 2-tuples raise ValueError during the unpack.
        paths, truths = map(list, zip(*lst))
        return paths, truths
def main(cfg):
    """Transcribe this rank's share of the test set and write results to disk.

    Meant to run once per process under a distributed launcher that sets
    ``RANK``/``LOCAL_RANK`` plus the ``env://`` rendezvous variables. Each rank
    writes ``<rank>.txt`` and a rotating ``<rank>.log`` into
    ``cfg.paths.decode_path``; every line has the form
    ``<audio path> ----> <ground truth> ----> <hypothesis>``.

    Args:
        cfg: composed Hydra config. Read here: ``distributed.world_size``,
            ``model_name``, ``paths.ckpt_path`` and ``paths.decode_path``
            (``TestDataset`` reads further fields).

    Raises:
        ValueError: ``cfg.model_name`` is not 'ctc', 'sc_ctc' or 'hc_ctc'.
        KeyError: neither ``RANK`` nor ``LOCAL_RANK`` is set.
    """
    import os  # no explicit module-level `import os` in this file

    # Prefer the global RANK: LOCAL_RANK restarts at 0 on every node, so on a
    # multi-node job it would produce duplicate ranks and clashing output files.
    rank = int(os.environ['RANK'] if 'RANK' in os.environ else os.environ['LOCAL_RANK'])
    world_size = cfg.distributed.world_size
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=world_size, rank=rank)

    # Data: a strided, non-overlapping slice of the test set for this rank.
    # DistributedSampler is not used because with drop_last=False it pads by
    # repeating indices so all ranks get equal counts, which would transcribe
    # some utterances twice across the per-rank outputs.
    data = TestDataset(cfg)
    shard = torch.utils.data.Subset(data, range(rank, len(data), world_size))
    loader = torch.utils.data.DataLoader(shard, batch_size=1, shuffle=False,
                                         num_workers=0, collate_fn=Collator(cfg))

    # Load model
    print('Loading model.')
    if cfg.model_name == 'ctc':
        model = baseCTC.from_pretrained(cfg, cfg.paths.ckpt_path)
    elif cfg.model_name == 'sc_ctc':
        model = baseSCCTC.from_pretrained(cfg, cfg.paths.ckpt_path)
    elif cfg.model_name == 'hc_ctc':
        model = baseHCCTC.from_pretrained(cfg, cfg.paths.ckpt_path)
    else:
        raise ValueError(f"Model '{cfg.model_name}' is invalid. Current implementation includes: 'ctc', 'sc_ctc' and 'hc_ctc'")

    # Decode
    print('Decoding.')
    decode_dir = cfg.paths.decode_path
    # exist_ok makes this safe to call from every rank concurrently, so no
    # rank has to wait for another one to create the directory.
    os.makedirs(decode_dir, exist_ok=True)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    rfh = RotatingFileHandler(os.path.join(decode_dir, f'{rank}.log'),
                              maxBytes=1000000, backupCount=10, encoding="UTF-8")
    logger.addHandler(rfh)
    print('Starting transcription.')
    try:
        # UTF-8 to match the log handler; the locale default encoding might not
        # be able to represent every transcript character.
        with open(os.path.join(decode_dir, f'{rank}.txt'), 'w', encoding='utf-8') as dP:
            for audio_paths, texts in tqdm(loader, disable=(rank != 0)):
                # batch_size=1, so each batch holds exactly one utterance.
                # NOTE(review): hypothesis and reference are normalized
                # differently (" ' " vs " '", applied in a different order) —
                # confirm this matches the tokenizer / clean4asr output.
                hyp = model.transcribe(audio_paths[0]).replace(" ' ", "'").replace("-", "")
                gt = texts[0].replace("-", "").replace(" '", "'")
                record = f'{audio_paths[0]} ----> {gt} ----> {hyp}'
                logger.info(record)
                dP.write(record + '\n')
    finally:
        # Detach and close this run's handler so the root logger does not keep
        # an open file, then release the process group.
        logger.removeHandler(rfh)
        rfh.close()
        dist.destroy_process_group()
if __name__ == "__main__":
    # Parse config_path and config_name from the command line; everything
    # argparse does not recognise is forwarded to Hydra's compose() as config
    # overrides.
    parser = argparse.ArgumentParser()
    # Launchers append the local rank as --local-rank (newer torch) or
    # --local_rank (older torch.distributed.launch). Registering both spellings
    # keeps either one out of the Hydra overrides, whose grammar does not
    # accept a leading "--". The parsed value is unused: main() reads the rank
    # from the environment.
    parser.add_argument("--local-rank", "--local_rank", type=int, default=0)
    parser.add_argument("--config-path", type=str, default=".", help="Path to the config directory")
    parser.add_argument("--config-name", type=str, default="config", help="Name of the config file (without .yaml extension)")
    args, unknown = parser.parse_known_args()
    # NOTE(review): Hydra's initialize() resolves config_path relative to the
    # calling module rather than the working directory — confirm for the
    # installed Hydra version.
    with initialize(config_path=args.config_path):
        cfg = compose(config_name=args.config_name, overrides=unknown)
        main(cfg)