diff --git a/README.md b/README.md
index a34b26a..7a2556a 100644
--- a/README.md
+++ b/README.md
@@ -36,10 +36,10 @@ You can unzip the folder of the language model into any directory, but it is rec
 
 ## Usage
 
-Now you are able to run `vosk-cli -i <input file> -o <output file> -m <model>`.
+Now you are able to run `vosk-cli -i <input file> -o <output file> -m <model> -p <punctuation model>`.
 
-For example, if there is a `video.mp4` file in your download folder and a model named `vosk-model-en-us-0.22` in the `./models` folder you created, you can run
+For example, if there is a `video.mp4` file in your download folder, a model named `vosk-model-en-us-0.22` and a punctuation model `vosk-recasepunc-en-0.22` (containing a `checkpoint` file) in the `./models` folder you created, you can run
 
-`vosk-cli -i ~/Downloads/video.mp4 -o text -m vosk-model-en-us-0.22`
+`vosk-cli -i ~/Downloads/video.mp4 -o text -m vosk-model-en-us-0.22 -p vosk-recasepunc-en-0.22`
 
 This will create a `text.vtt` file (which contains the transcribed captions) in your current directory.
diff --git a/recasepunc/LICENSE b/recasepunc/LICENSE
new file mode 100644
index 0000000..b6d4c82
--- /dev/null
+++ b/recasepunc/LICENSE
@@ -0,0 +1,11 @@
+Copyright 2021 Benoit Favre
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
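The `text.vtt` mentioned in the README snippet above is written with webvtt-py (already listed in `requirements.txt`). A minimal sketch of that output path, with invented timings and caption text, might look roughly like this:

```python
from webvtt import WebVTT, Caption

# Illustrative only: vosk-cli derives start/end times from the word timestamps
# Vosk reports and groups words into lines/paragraphs before saving.
vtt = WebVTT()
vtt.captions.append(Caption('00:00:00.240', '00:00:03.840',
                            ['This is the first transcribed line,',
                             'now capitalized and punctuated.']))
vtt.save('text.vtt')  # creates the WebVTT captions file in the current directory
```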
diff --git a/recasepunc/__init__.py b/recasepunc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/recasepunc/recasepunc.py b/recasepunc/recasepunc.py new file mode 100644 index 0000000..cd0a106 --- /dev/null +++ b/recasepunc/recasepunc.py @@ -0,0 +1,742 @@ +import sys +import collections +import os +import regex as re +#from mosestokenizer import * +from tqdm import tqdm +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import random +import unicodedata +import numpy as np +import argparse +from torch.utils.data import TensorDataset, DataLoader + +from transformers import AutoModel, AutoTokenizer, BertTokenizer + +default_config = argparse.Namespace( + seed=871253, + lang='de', + #flavor='flaubert/flaubert_base_uncased', + flavor=None, + max_length=256, + batch_size=16, + updates=24000, + period=1000, + lr=1e-5, + dab_rate=0.1, + device='cuda', + debug=False +) + +default_flavors = { + 'fr': 'flaubert/flaubert_base_uncased', + 'en': 'bert-base-uncased', + 'zh': 'ckiplab/bert-base-chinese', + 'tr': 'dbmdz/bert-base-turkish-uncased', + 'de': 'dbmdz/bert-base-german-uncased', + 'pt': 'neuralmind/bert-base-portuguese-cased' +} + +class Config(argparse.Namespace): + def __init__(self, **kwargs): + for key, value in default_config.__dict__.items(): + setattr(self, key, value) + for key, value in kwargs.items(): + setattr(self, key, value) + + assert self.lang in ['fr', 'en', 'zh', 'tr', 'pt', 'de'] + + if 'lang' in kwargs and ('flavor' not in kwargs or kwargs['flavor'] is None): + self.flavor = default_flavors[self.lang] + + #print(self.lang, self.flavor) + + +def init_random(seed): + # make sure everything is deterministic + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + #torch.use_deterministic_algorithms(True) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + random.seed(seed) + np.random.seed(seed) + +# NOTE: it is assumed in the implementation that y[:,0] is the punctuation label, and y[:,1] is the case label! 
+ +punctuation = { + 'O': 0, + 'COMMA': 1, + 'PERIOD': 2, + 'QUESTION': 3, + 'EXCLAMATION': 4, +} + +punctuation_syms = ['', ',', '.', ' ?', ' !'] + +case = { + 'LOWER': 0, + 'UPPER': 1, + 'CAPITALIZE': 2, + 'OTHER': 3, +} + + +class Model(nn.Module): + def __init__(self, flavor, device): + super().__init__() + self.bert = AutoModel.from_pretrained(flavor) + # need a proper way of determining representation size + size = self.bert.dim if hasattr(self.bert, 'dim') else self.bert.config.pooler_fc_size if hasattr(self.bert.config, 'pooler_fc_size') else self.bert.config.emb_dim if hasattr(self.bert.config, 'emb_dim') else self.bert.config.hidden_size + self.punc = nn.Linear(size, 5) + self.case = nn.Linear(size, 4) + self.dropout = nn.Dropout(0.3) + self.to(device) + + def forward(self, x): + output = self.bert(x) + representations = self.dropout(F.gelu(output['last_hidden_state'])) + punc = self.punc(representations) + case = self.case(representations) + return punc, case + + +# randomly create sequences that align to punctuation boundaries +def drop_at_boundaries(rate, x, y, cls_token_id, sep_token_id, pad_token_id): + for i, dropped in enumerate(torch.rand((len(x),)) < rate): + if dropped: + # select all indices that are sentence endings + indices = (y[i,:,0] > 1).nonzero(as_tuple=True)[0] + if len(indices) < 2: + continue + start = indices[0] + 1 + end = indices[random.randint(1, len(indices) - 1)] + 1 + length = end - start + if length + 2 > len(x[i]): + continue + x[i, 0] = cls_token_id + x[i, 1: length + 1] = x[i, start: end].clone() + x[i, length + 1] = sep_token_id + x[i, length + 2:] = pad_token_id + y[i, 0] = 0 + y[i, 1: length + 1] = y[i, start: end].clone() + y[i, length + 1:] = 0 + + +def compute_performance(config, model, loader): + device = config.device + criterion = nn.CrossEntropyLoss() + model.eval() + total_loss = all_correct1 = all_correct2 = num_loss = num_perf = 0 + num_ref = collections.defaultdict(float) + num_hyp = collections.defaultdict(float) + num_correct = collections.defaultdict(float) + for x, y in loader: + x = x.long().to(device) + y = y.long().to(device) + y1 = y[:,:,0] + y2 = y[:,:,1] + with torch.no_grad(): + y_scores1, y_scores2 = model(x.to(device)) + loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1))) + loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1))) + loss = loss1 + loss2 + y_pred1 = torch.max(y_scores1, 2)[1] + y_pred2 = torch.max(y_scores2, 2)[1] + for label in range(1, 5): + ref = (y1 == label) + hyp = (y_pred1 == label) + correct = (ref * hyp == 1) + num_ref[label] += ref.sum() + num_hyp[label] += hyp.sum() + num_correct[label] += correct.sum() + num_ref[0] += ref.sum() + num_hyp[0] += hyp.sum() + num_correct[0] += correct.sum() + all_correct1 += (y_pred1 == y1).sum() + all_correct2 += (y_pred2 == y2).sum() + total_loss += loss.item() + num_loss += len(y) + num_perf += len(y) * config.max_length + recall = {} + precision = {} + fscore = {} + for label in range(0, 5): + recall[label] = num_correct[label] / num_ref[label] if num_ref[label] > 0 else 0 + precision[label] = num_correct[label] / num_hyp[label] if num_hyp[label] > 0 else 0 + fscore[label] = (2 * recall[label] * precision[label] / (recall[label] + precision[label])).item() if recall[label] + precision[label] > 0 else 0 + return total_loss / num_loss, all_correct2.item() / num_perf, all_correct1.item() / num_perf, fscore + + +def fit(config, model, checkpoint_path, train_loader, valid_loader, 
iterations, valid_period=200, lr=1e-5): + device = config.device + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=lr) + iteration = 0 + while True: + model.train() + total_loss = num = 0 + for x, y in tqdm(train_loader): + x = x.long().to(device) + y = y.long().to(device) + drop_at_boundaries(config.dab_rate, x, y, config.cls_token_id, config.sep_token_id, config.pad_token_id) + y1 = y[:,:,0] + y2 = y[:,:,1] + optimizer.zero_grad() + y_scores1, y_scores2 = model(x) + loss1 = criterion(y_scores1.view(y1.size(0) * y1.size(1), -1), y1.view(y1.size(0) * y1.size(1))) + loss2 = criterion(y_scores2.view(y2.size(0) * y2.size(1), -1), y2.view(y2.size(0) * y2.size(1))) + loss = loss1 + loss2 + loss.backward() + optimizer.step() + total_loss += loss.item() + num += len(y) + if iteration % valid_period == valid_period - 1: + train_loss = total_loss / num + valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore = compute_performance(config, model, valid_loader) + torch.save({ + 'iteration': iteration + 1, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'train_loss': train_loss, + 'valid_loss': valid_loss, + 'valid_accuracy_case': valid_accuracy_case, + 'valid_accuracy_punc': valid_accuracy_punc, + 'valid_fscore': valid_fscore, + 'config': config.__dict__, + }, '%s.%d' % (checkpoint_path, iteration + 1)) + print(iteration + 1, train_loss, valid_loss, valid_accuracy_case, valid_accuracy_punc, valid_fscore) + total_loss = num = 0 + + iteration += 1 + if iteration > iterations: + return + + sys.stderr.flush() + sys.stdout.flush() + + +def batchify(max_length, x, y): + print (x.shape) + print (y.shape) + x = x[:(len(x) // max_length) * max_length].reshape(-1, max_length) + y = y[:(len(y) // max_length) * max_length, :].reshape(-1, max_length, 2) + return x, y + + +def train(config, train_x_fn, train_y_fn, valid_x_fn, valid_y_fn, checkpoint_path): + X_train, Y_train = batchify(config.max_length, torch.load(train_x_fn), torch.load(train_y_fn)) + X_valid, Y_valid = batchify(config.max_length, torch.load(valid_x_fn), torch.load(valid_y_fn)) + + train_set = TensorDataset(X_train, Y_train) + valid_set = TensorDataset(X_valid, Y_valid) + + train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True) + valid_loader = DataLoader(valid_set, batch_size=config.batch_size) + + model = Model(config.flavor, config.device) + + fit(config, model, checkpoint_path, train_loader, valid_loader, config.updates, config.period, config.lr) + + +def run_eval(config, test_x_fn, test_y_fn, checkpoint_path): + X_test, Y_test = batchify(config.max_length, torch.load(test_x_fn), torch.load(test_y_fn)) + test_set = TensorDataset(X_test, Y_test) + test_loader = DataLoader(test_set, batch_size=config.batch_size) + + loaded = torch.load(checkpoint_path, map_location=config.device) + if 'config' in loaded: + config = Config(**loaded['config']) + init(config) + + model = Model(config.flavor, config.device) + model.load_state_dict(loaded['model_state_dict']) + + print(*compute_performance(config, model, test_loader)) + + +def recase(token, label): + if label == case['LOWER']: + return token.lower() + elif label == case['CAPITALIZE']: + return token.lower().capitalize() + elif label == case['UPPER']: + return token.upper() + else: + return token + + +class CasePuncPredictor: + def __init__(self, checkpoint_path, lang=default_config.lang, flavor=default_config.flavor, 
device=default_config.device): + loaded = torch.load(checkpoint_path, map_location=device if torch.cuda.is_available() else 'cpu') + if 'config' in loaded: + self.config = Config(**loaded['config']) + else: + self.config = Config(lang=lang, flavor=flavor, device=device) + init(self.config) + + self.model = Model(self.config.flavor, self.config.device) + self.model.load_state_dict(loaded['model_state_dict']) + self.model.eval() + self.model.to(self.config.device) + + self.rev_case = {b: a for a, b in case.items()} + self.rev_punc = {b: a for a, b in punctuation.items()} + + def tokenize(self, text): + return [self.config.cls_token] + self.config.tokenizer.tokenize(text) + [self.config.sep_token] + + def predict(self, tokens, getter=lambda x: x): + max_length = self.config.max_length + device = self.config.device + if type(tokens) == str: + tokens = self.tokenize(tokens) + previous_label = punctuation['PERIOD'] + for start in range(0, len(tokens), max_length): + instance = tokens[start: start + max_length] + if type(getter(instance[0])) == str: + ids = self.config.tokenizer.convert_tokens_to_ids(getter(token) for token in instance) + else: + ids = [getter(token) for token in instance] + if len(ids) < max_length: + ids += [0] * (max_length - len(ids)) + x = torch.tensor([ids]).long().to(device) + y_scores1, y_scores2 = self.model(x) + y_pred1 = torch.max(y_scores1, 2)[1] + y_pred2 = torch.max(y_scores2, 2)[1] + for i, id, token, punc_label, case_label in zip(range(len(instance)), ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]): + if id == self.config.cls_token_id or id == self.config.sep_token_id: + continue + if previous_label != None and previous_label > 1: + if case_label in [case['LOWER'], case['OTHER']]: # LOWER, OTHER + case_label = case['CAPITALIZE'] + if i + start == len(tokens) - 2 and punc_label == punctuation['O']: + punc_label = punctuation['PERIOD'] + yield (token, self.rev_case[case_label], self.rev_punc[punc_label]) + previous_label = punc_label + + def map_case_label(self, token, case_label): + if token.endswith(''): + token = token[:-4] + if token.startswith('##'): + token = token[2:] + return recase(token, case[case_label]) + + def map_punc_label(self, token, punc_label): + if token.endswith(''): + token = token[:-4] + if token.startswith('##'): + token = token[2:] + return token + punctuation_syms[punctuation[punc_label]] + + + +def generate_predictions(config, checkpoint_path): + loaded = torch.load(checkpoint_path, map_location=config.device if torch.cuda.is_available() else 'cpu') + if 'config' in loaded: + config = Config(**loaded['config']) + init(config) + + model = Model(config.flavor, config.device) + model.load_state_dict(loaded['model_state_dict']) + + rev_case = {b: a for a, b in case.items()} + rev_punc = {b: a for a, b in punctuation.items()} + + for line in sys.stdin: + # also drop punctuation that we may generate + line = ''.join([c for c in line if c not in mapped_punctuation]) + if config.debug: + print(line) + tokens = [config.cls_token] + config.tokenizer.tokenize(line) + [config.sep_token] + if config.debug: + print(tokens) + previous_label = punctuation['PERIOD'] + first_time = True + was_word = False + for start in range(0, len(tokens), config.max_length): + instance = tokens[start: start + config.max_length] + ids = config.tokenizer.convert_tokens_to_ids(instance) + #print(len(ids), file=sys.stderr) + if len(ids) < config.max_length: + ids += [config.pad_token_id] * (config.max_length - len(ids)) + x = 
torch.tensor([ids]).long().to(config.device) + y_scores1, y_scores2 = model(x) + y_pred1 = torch.max(y_scores1, 2)[1] + y_pred2 = torch.max(y_scores2, 2)[1] + for id, token, punc_label, case_label in zip(ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]): + if config.debug: + print(id, token, punc_label, case_label, file=sys.stderr) + if id == config.cls_token_id or id == config.sep_token_id: + continue + if previous_label != None and previous_label > 1: + if case_label in [case['LOWER'], case['OTHER']]: + case_label = case['CAPITALIZE'] + previous_label = punc_label + # different strategy due to sub-lexical token encoding in Flaubert + if config.lang == 'fr': + if token.endswith(''): + cased_token = recase(token[:-4], case_label) + if was_word: + print(' ', end='') + print(cased_token + punctuation_syms[punc_label], end='') + was_word = True + else: + cased_token = recase(token, case_label) + if was_word: + print(' ', end='') + print(cased_token, end='') + was_word = False + else: + if token.startswith('##'): + cased_token = recase(token[2:], case_label) + print(cased_token, end='') + else: + cased_token = recase(token, case_label) + if not first_time: + print(' ', end='') + first_time = False + print(cased_token + punctuation_syms[punc_label], end='') + if previous_label == 0: + print('.', end='') + print() + + +def label_for_case(token): + token = re.sub('[^\p{Han}\p{Ll}\p{Lu}]', '', token) + if token == token.lower(): + return 'LOWER' + elif token == token.lower().capitalize(): + return 'CAPITALIZE' + elif token == token.upper(): + return 'UPPER' + else: + return 'OTHER' + + +def make_tensors(config, input_fn, output_x_fn, output_y_fn): + # count file lines without loading them + size = 0 + with open(input_fn) as fp: + for line in fp: + size += 1 + + with open(input_fn) as fp: + X = torch.IntTensor(size) + Y = torch.ByteTensor(size, 2) + + offset = 0 + for n, line in enumerate(fp): + word, case_label, punc_label = line.strip().split('\t') + id = config.tokenizer.convert_tokens_to_ids(word) + if config.debug: + assert word.lower() == tokenizer.convert_ids_to_tokens(id) + X[offset] = id + Y[offset, 0] = punctuation[punc_label] + Y[offset, 1] = case[case_label] + offset += 1 + + torch.save(X, output_x_fn) + torch.save(Y, output_y_fn) + + +mapped_punctuation = { + '.': 'PERIOD', + '...': 'PERIOD', + ',': 'COMMA', + ';': 'COMMA', + ':': 'COMMA', + '(': 'COMMA', + ')': 'COMMA', + '?': 'QUESTION', + '!': 'EXCLAMATION', + ',': 'COMMA', + '!': 'EXCLAMATION', + '?': 'QUESTION', + ';': 'COMMA', + ':': 'COMMA', + '(': 'COMMA', + '(': 'COMMA', + ')': 'COMMA', + '[': 'COMMA', + ']': 'COMMA', + '【': 'COMMA', + '】': 'COMMA', + '└': 'COMMA', + '└ ': 'COMMA', + '_': 'O', + '。': 'PERIOD', + '、': 'COMMA', # enumeration comma + '、': 'COMMA', + '…': 'PERIOD', + '—': 'COMMA', + '「': 'COMMA', + '」': 'COMMA', + '.': 'PERIOD', + '《': 'O', + '》': 'O', + ',': 'COMMA', + '“': 'O', + '”': 'O', + '"': 'O', + '-': 'O', + '-': 'O', + '〉': 'COMMA', + '〈': 'COMMA', + '↑': 'O', + '〔': 'COMMA', + '〕': 'COMMA', +} + +def preprocess_text(config, max_token_count=-1): + global num_tokens_output + max_token_count = int(max_token_count) + num_tokens_output = 0 + def process_segment(text, punctuation): + global num_tokens_output + text = text.replace('\t', ' ') + tokens = config.tokenizer.tokenize(text) + for i, token in enumerate(tokens): + case_label = label_for_case(token) + if i == len(tokens) - 1: + print(token.lower(), case_label, punctuation, sep='\t') + else: + print(token.lower(), 
case_label, 'O', sep='\t') + num_tokens_output += 1 + # a bit too ugly, but alternative is to throw an exception + if max_token_count > 0 and num_tokens_output >= max_token_count: + sys.exit(0) + + for line in sys.stdin: + line = line.strip() + if line != '': + line = unicodedata.normalize("NFC", line) + if config.debug: + print(line) + start = 0 + for i, char in enumerate(line): + if char in mapped_punctuation: + if i > start and line[start: i].strip() != '': + process_segment(line[start: i], mapped_punctuation[char]) + start = i + 1 + if start < len(line): + process_segment(line[start:], 'PERIOD') + + +def preprocess_text_old_fr(config): + assert config.lang == 'fr' + splitsents = MosesSentenceSplitter(lang) + tokenize = MosesTokenizer(lang, extra=['-no-escape']) + normalize = MosesPunctuationNormalizer(lang) + + for line in sys.stdin: + if line.strip() != '': + for sentence in splitsents([normalize(line)]): + tokens = tokenize(sentence) + previous_token = None + for token in tokens: + if token in mapped_punctuation: + if previous_token != None: + print(previous_token, mapped_punctuation[token], sep='\t') + previous_token = None + elif not re.search('[\p{Han}\p{Ll}\p{Lu}\d]', token): # remove non-alphanumeric tokens + continue + else: + if previous_token != None: + print(previous_token, 'O', sep='\t') + previous_token = token + if previous_token != None: + print(previous_token, 'PERIOD', sep='\t') + + +# modification of the wordpiece tokenizer to keep case information even if vocab is lower cased +# forked from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100, keep_case=True): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + self.keep_case = keep_case + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`. + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + Returns: + A list of wordpiece tokens. 
+ """ + + output_tokens = [] + for token in text.strip().split(): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + # optionaly lowercase substring before checking for inclusion in vocab + if (self.keep_case and substr.lower() in self.vocab) or (substr in self.vocab): + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +# modification of XLM bpe tokenizer for keeping case information when vocab is lowercase +# forked from https://github.com/huggingface/transformers/blob/cd56f3fe7eae4a53a9880e3f5e8f91877a78271c/src/transformers/models/xlm/tokenization_xlm.py +def bpe(self, token): + def to_lower(pair): + #print(' ',pair) + return (pair[0].lower(), pair[1].lower()) + + from transformers.models.xlm.tokenization_xlm import get_pairs + + word = tuple(token[:-1]) + (token[-1] + "",) + if token in self.cache: + return self.cache[token] + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(to_lower(pair), float("inf"))) + #print(bigram) + if to_lower(bigram) not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + if word == "\n ": + word = "\n" + self.cache[token] = word + return word + + + +def init(config): + init_random(config.seed) + + if config.lang == 'fr': + config.tokenizer = tokenizer = AutoTokenizer.from_pretrained(config.flavor, do_lower_case=False) + + from transformers.models.xlm.tokenization_xlm import XLMTokenizer + assert isinstance(tokenizer, XLMTokenizer) + + # monkey patch XLM tokenizer + import types + tokenizer.bpe = types.MethodType(bpe, tokenizer) + else: + # warning: needs to be BertTokenizer for monkey patching to work + config.tokenizer = tokenizer = BertTokenizer.from_pretrained(config.flavor, do_lower_case=False) + + # warning: monkey patch tokenizer to keep case information + #from recasing_tokenizer import WordpieceTokenizer + config.tokenizer.wordpiece_tokenizer = WordpieceTokenizer(vocab=tokenizer.vocab, unk_token=tokenizer.unk_token) + + if config.lang == 'fr': + config.pad_token_id = tokenizer.pad_token_id + config.cls_token_id = tokenizer.bos_token_id + config.cls_token = tokenizer.bos_token + config.sep_token_id = tokenizer.sep_token_id + config.sep_token = tokenizer.sep_token + else: + config.pad_token_id = tokenizer.pad_token_id + config.cls_token_id = tokenizer.cls_token_id + config.cls_token = tokenizer.cls_token + config.sep_token_id = tokenizer.sep_token_id + config.sep_token = tokenizer.sep_token + + if not torch.cuda.is_available() and config.device == 'cuda': + print('WARNING: reverting to cpu as cuda is not 
available', file=sys.stderr) + config.device = torch.device(config.device if torch.cuda.is_available() else 'cpu') + + +def main(config, action, args): + init(config) + + if action == 'train': + train(config, *args) + elif action == 'eval': + run_eval(config, *args) + elif action == 'predict': + generate_predictions(config, *args) + elif action == 'tensorize': + make_tensors(config, *args) + elif action == 'preprocess': + preprocess_text(config, *args) + else: + print('invalid action "%s"' % action) + sys.exit(1) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("action", help="train|eval|predict|tensorize|preprocess", type=str) + parser.add_argument("action_args", help="arguments for selected action", type=str, nargs='*') + parser.add_argument("--seed", help="random seed", default=default_config.seed, type=int) + parser.add_argument("--lang", help="language (fr, en, zh)", default=default_config.lang, type=str) + parser.add_argument("--flavor", help="bert flavor in transformers model zoo", default=default_config.flavor, type=str) + parser.add_argument("--max-length", help="maximum input length", default=default_config.max_length, type=int) + parser.add_argument("--batch-size", help="size of batches", default=default_config.batch_size, type=int) + parser.add_argument("--device", help="computation device (cuda, cpu)", default=default_config.device, type=str) + parser.add_argument("--debug", help="whether to output more debug info", default=default_config.debug, type=bool) + parser.add_argument("--updates", help="number of training updates to perform", default=default_config.updates, type=bool) + parser.add_argument("--period", help="validation period in updates", default=default_config.period, type=bool) + parser.add_argument("--lr", help="learning rate", default=default_config.lr, type=bool) + parser.add_argument("--dab-rate", help="drop at boundaries rate", default=default_config.dab_rate, type=bool) + config = Config(**parser.parse_args().__dict__) + + main(config, config.action, config.action_args) + + diff --git a/requirements.txt b/requirements.txt index b04e8ae..22f3fcd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,7 @@ webvtt-py vosk +numpy +regex +torch +tqdm +transformers diff --git a/setup.py b/setup.py index 630d58d..166ba66 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,7 @@ def read(filename): author='Martin Wygas', author_email='mwygas@uos.de', license='Apache-2.0', - packages=['voskcli'], + packages=['voskcli', 'recasepunc'], license_files=('LICENSE'), include_package_data=True, install_requires=read('requirements.txt').split(), diff --git a/voskcli/__init__.py b/voskcli/__init__.py new file mode 100644 index 0000000..ca3f621 --- /dev/null +++ b/voskcli/__init__.py @@ -0,0 +1,2 @@ +# Have to be imported for punctuation in this file!! 
+from recasepunc.recasepunc import WordpieceTokenizer  # noqa: F401
diff --git a/voskcli/transcribe.py b/voskcli/transcribe.py
old mode 100644
new mode 100755
index 0d58c6a..8584f54
--- a/voskcli/transcribe.py
+++ b/voskcli/transcribe.py
@@ -17,6 +17,9 @@ from vosk import Model, KaldiRecognizer, SetLogLevel
 from webvtt import WebVTT, Caption
 from argparse import ArgumentParser
+from recasepunc.recasepunc import CasePuncPredictor
+import sys
+import voskcli
 import os
 import subprocess
 import json
 
@@ -106,7 +109,7 @@ def write_captions_paragraph(vtt, paragraph):
     vtt.captions.append(caption)
 
 
-def write_webvtt_captions(rec_results):
+def write_webvtt_captions(result_list):
     '''
     Process transcription data.
 
@@ -126,29 +129,25 @@ def write_webvtt_captions(rec_results):
     line = []
     paragraph = []
     char_count = 0
-    for i, rec_result in enumerate(rec_results):
-        result = json.loads(rec_result).get('result')
-        if not result:
-            continue
+    for i, result in enumerate(result_list):
         # main logic for the captions "format"
         # (words per line and lines per paragraph)
-        for entry in result:
-            char_count += len(entry['word'])
-            if char_count > MAX_CHARS_PER_LINE and len(line) != 0:
-                if len(paragraph) == MAX_LINES_IN_PARAGRAPH:
-                    write_captions_paragraph(vtt, paragraph)
-                    paragraph = [line]
-                    line = [entry]
-                    char_count = len(entry['word'])
-                    continue
-                else:
-                    paragraph.append(line)
-                    line = [entry]
-                    char_count = len(entry['word'])
+        char_count += len(result['word'])
+        if char_count > MAX_CHARS_PER_LINE and len(line) != 0:
+            if len(paragraph) == MAX_LINES_IN_PARAGRAPH:
+                write_captions_paragraph(vtt, paragraph)
+                paragraph = [line]
+                line = [result]
+                char_count = len(result['word'])
+                continue
+            else:
-                line.append(entry)
-                char_count += 1  # add 1 because of whitespace
+                paragraph.append(line)
+                line = [result]
+                char_count = len(result['word'])
+        else:
+            line.append(result)
+            char_count += 1  # add 1 because of whitespace
 
     # write the remaining words into the captions file
     if len(paragraph) != 0:
@@ -164,10 +163,9 @@ def write_webvtt_captions(rec_results):
     return vtt
 
 
-def transcribe(inputFile, outputFile, model):
+def transcribe(inputFile, outputFile, model, punc):
     '''
-    Produce transcription.
-
+    Produce transcription.
     Create transcription data from inputFile, process data and save finished transcription to outputFile.
@@ -189,6 +187,8 @@ def transcribe(inputFile, outputFile, model):
     process = subprocess.Popen(command, stdout=subprocess.PIPE)
 
     rec_results = []
+    result_list = []
+    case_result_list = []
     while True:
         data = process.stdout.read(4000)
         if len(data) == 0:
@@ -197,10 +197,59 @@
             rec_results.append(rec.Result())
     rec_results.append(rec.FinalResult())
 
-    vtt = write_webvtt_captions(rec_results)
+    print('Finished transcribing...')
+
+    if punc != "":
+        print(f'Start punctuating with model {punc}')
+        # Punctuation
+        # Load the text from the JSON results
+        text = ''
+        for rec_result in rec_results:
+            result = json.loads(rec_result).get('result')
+            if not result:
+                continue
+            text += ' '.join([entry['word'] for entry in result]) + ' '
+
+        # Predict punctuation and casing for the text
+        # Temporarily swap __main__ so the pickled checkpoint can be loaded
+        old_main = sys.modules['__main__']
+        sys.modules['__main__'] = voskcli
+        predictor = CasePuncPredictor(punc + '/checkpoint')
+        sys.modules['__main__'] = old_main
+
+        # Run the case/punctuation prediction
+        tokens = list(enumerate(predictor.tokenize(text)))
+        case_result = ""
+        predicted = predictor.predict(tokens, lambda x: x[1])
+        for token, case_label, punc_label in predicted:
+            map_label = predictor.map_case_label(token[1], case_label)
+            prediction = predictor.map_punc_label(map_label, punc_label)
+            if token[1][0] != '#' \
+                    and prediction != '-' \
+                    and not case_result.endswith('-'):
+                case_result = case_result + ' ' + prediction
+            else:
+                case_result = case_result + prediction
+        case_result_list = case_result.split(" ")
+        print('Finished punctuating...')
+    else:
+        print('No punctuation requested...')
 
-    # save webvtt
-    print('Finished transcribing. Saving WebVTT file...')
+    # Build the word list for write_webvtt_captions()
+    word = 1  # case_result starts with a space, so index 0 of the split is empty
+    for rec_result in rec_results:
+        result = json.loads(rec_result).get('result')
+        if not result:
+            continue
+        if punc != "":
+            for entry in result:
+                entry['word'] = case_result_list[word]
+                word += 1
+        result_list += result
+    vtt = write_webvtt_captions(result_list)
+
+    # save WebVTT
+    print('Finished writing. Saving WebVTT file...')
     vtt.save(outputFile)
     print('WebVTT saved.')
     # print(vtt.content)
@@ -225,9 +274,15 @@
                         'media file. Value will be checked in the following '
                         'order: 1. value as system path. 2. Value in local '
                         './model folder. 3. Value in /usr/share/vosk/models/.')
+    parser.add_argument('-p', '--punctuation', type=str, dest='punc',
+                        help='The punctuation model to use to punctuate the '
+                             'media file. Value will be checked in the following '
+                             'order: 1. value as system path. 2. Value in local '
+                             './model folder. 3. Value in /usr/share/vosk/models/.')
     args = parser.parse_args()
     inputFile = args.inputFile
+    punctuationFile = args.punc
     outputFile = args.outputFile
     if args.language:
         model = '/usr/share/vosk/language/' + args.language
     else:
         model = args.model
     model = model_path(model)
+    punc = ""
+    if punctuationFile:
+        punc = model_path(punctuationFile)
 
-    transcribe(inputFile, outputFile, model)
+    transcribe(inputFile, outputFile, model, punc)
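A note on the `sys.modules['__main__']` swap above: the recasepunc checkpoint is a pickled torch file that apparently references `WordpieceTokenizer` through `__main__`, which is why `voskcli/__init__.py` re-exports that class and `transcribe()` temporarily points `__main__` at the `voskcli` package while loading. A standalone sketch of the same loading-and-prediction flow, with an illustrative checkpoint path matching the README example, might look like this:

```python
import sys

import voskcli  # re-exports WordpieceTokenizer so torch can unpickle the checkpoint
from recasepunc.recasepunc import CasePuncPredictor

# Illustrative path: the folder passed via -p, which must contain a 'checkpoint' file.
checkpoint = './models/vosk-recasepunc-en-0.22/checkpoint'

old_main = sys.modules['__main__']
sys.modules['__main__'] = voskcli  # let the unpickler resolve __main__.WordpieceTokenizer
predictor = CasePuncPredictor(checkpoint)
sys.modules['__main__'] = old_main

tokens = list(enumerate(predictor.tokenize('hello how are you today')))
text = ''
for token, case_label, punc_label in predictor.predict(tokens, lambda x: x[1]):
    word = predictor.map_punc_label(predictor.map_case_label(token[1], case_label), punc_label)
    # WordPiece continuation tokens ('##...') glue onto the previous word
    text += word if token[1].startswith('##') else ' ' + word
print(text.strip())  # e.g. 'Hello, how are you today?'
```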