diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index e5a82dfda2..d0c65377c2 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,30 @@ ## Results +### zapformer (zapformer + pruned-transducer w/ CTC) + +Note: --num-real-epochs 40 takes about the same time as 20 epochs with the zipformer CR-CTC recipe. +(each epoch is really 3 epochs due to speed-perturb). So the time for training will be roughly 40% +of the old zipformer recipe. The "--epoch 13" reported below is the last epoch; the smaller +number of epochs has to do with the --min-copies and --max-copies options. We will add this to the +report later (later epochs take more real computation time because they make different SpecAug +copies of the data). + +# (non-streaming) +./zapformer/train.py --world-size 4 \ + --min-copies 1 --max-copies 8 --num-real-epochs 40 \ + --base-lr=0.023 --batches-per-epoch 2400 --start-epoch 1 --use-fp16 1 \ + --exp-dir zapformer/exp \ + --use-ctc 1 --use-transducer 1 \ + --base-dim 64 --ctc-loss-scale 0.2 \ + --full-libri 1 --max-duration 1200 --master-port 43039 + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| greedy_search | 1.81 | 3.73 | --epoch 13 --avg 3 | + +Note on other results: dev-clean=1.73, dev-other=3.55, giga test=16.69, giga dev=1.733 (i.e. on the model trained with Libri only). + + ### zipformer (zipformer + pruned-transducer w/ CR-CTC) See for more details. 
# diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
# index d2f6db8335..a9ce6b8d36 100644
#
# Hunk @@ -18,13 +18,21 @@: module imports after the change (glob, re, numpy
# and lhotse are the additions).  The hunk ends inside the unchanged
# `from lhotse.dataset import (CutConcatenate, CutMix, ...` statement, which is
# therefore not repeated here.
import argparse
import inspect
import glob
import logging
import re

from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np  # to set its random seed

import torch

import lhotse

from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy


# Hunk @@ -497,3 +505,107 @@: the two classes below are appended at the end
# of the module, right after the existing gigaspeech_test_cuts() method.


class GigaSpeech:
    """Lazy accessors for the GigaSpeech cut manifests found in one directory."""

    def __init__(self, manifest_dir: str):
        """
        Args:
          manifest_dir:
            It is expected to contain the following files:

                - gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz
                - gigaspeech_cuts_L.jsonl.gz
                - gigaspeech_cuts_M.jsonl.gz
                - gigaspeech_cuts_S.jsonl.gz
                - gigaspeech_cuts_XS.jsonl.gz
                - gigaspeech_cuts_DEV.jsonl.gz
                - gigaspeech_cuts_TEST.jsonl.gz
        """
        self.manifest_dir = Path(manifest_dir)

    def train_XL_cuts_split(self) -> CutSet:
        """Return the XL training cuts, combined lazily from the numbered splits.

        The split files are ordered by the integer index embedded in their
        names (not lexicographically), then chained with ``lhotse.combine``.
        Files matched by the glob whose name carries no numeric index are
        skipped with a warning.
        """
        logging.info("About to get train-XL cuts")

        filenames = glob.glob(
            f"{self.manifest_dir}/gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz"  # noqa
        )

        # The dots are escaped so that they only match the literal separators
        # of the file name; an unescaped "." would match any character.
        pattern = re.compile(r"gigaspeech_cuts_XL\.([0-9]+)\.jsonl\.gz")

        idx_filenames = []
        for filename in filenames:
            match = pattern.search(filename)
            if match is None:
                # The glob is looser than the regex (e.g. "XL.tmp.jsonl.gz");
                # previously this case crashed with an AttributeError on None.
                logging.warning(
                    f"Skipping {filename}: no numeric split index in its name"
                )
                continue
            idx_filenames.append((int(match.group(1)), filename))
        idx_filenames.sort(key=lambda x: x[0])

        sorted_filenames = [filename for _, filename in idx_filenames]

        logging.info(f"Loading {len(sorted_filenames)} splits")

        return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames)

    def train_XL_cuts(self) -> CutSet:
        """Return the XL training cuts from the single (unsplit) manifest."""
        f = self.manifest_dir / "gigaspeech_cuts_XL.jsonl.gz"
        logging.info(f"About to get train-XL cuts from {f}")
        return CutSet.from_jsonl_lazy(f)

    def train_L_cuts(self) -> CutSet:
        """Return the L training cuts, opened lazily."""
        f = self.manifest_dir / "gigaspeech_cuts_L.jsonl.gz"
        logging.info(f"About to get train-L cuts from {f}")
        return CutSet.from_jsonl_lazy(f)

    def train_M_cuts(self) -> CutSet:
        """Return the M training cuts, opened lazily."""
        f = self.manifest_dir / "gigaspeech_cuts_M.jsonl.gz"
        logging.info(f"About to get train-M cuts from {f}")
        return CutSet.from_jsonl_lazy(f)

    def train_S_cuts(self) -> CutSet:
        """Return the S training cuts, opened lazily."""
        f = self.manifest_dir / "gigaspeech_cuts_S.jsonl.gz"
        logging.info(f"About to get train-S cuts from {f}")
        return CutSet.from_jsonl_lazy(f)

    def train_XS_cuts(self) -> CutSet:
        """Return the XS training cuts, opened lazily."""
        f = self.manifest_dir / "gigaspeech_cuts_XS.jsonl.gz"
        logging.info(f"About to get train-XS cuts from {f}")
        return CutSet.from_jsonl_lazy(f)

    def test_cuts(self) -> CutSet:
        """Return the TEST cuts, opened lazily."""
        f = self.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
        logging.info(f"About to get TEST cuts from {f}")
        return load_manifest_lazy(f)

    def dev_cuts(self) -> CutSet:
        """Return the DEV cuts, opened lazily."""
        f = self.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
        logging.info(f"About to get DEV cuts from {f}")
        return load_manifest_lazy(f)


class CommonVoice:
    """Lazy accessors for the CommonVoice (cv22-en) cut manifests in one directory."""

    def __init__(self, manifest_dir: str):
        """
        Args:
          manifest_dir:
            It is expected to contain the following files::

                - cv22-en_cuts_train.jsonl.gz
                - cv22-en_cuts_dev.jsonl.gz
                - cv22-en_cuts_test.jsonl.gz
        """
        self.manifest_dir = Path(manifest_dir)

    def train_cuts(self) -> CutSet:
        """Return the training cuts, opened lazily."""
        logging.info("CommonVoice: About to get train cuts")
        return load_manifest_lazy(
            self.manifest_dir / "cv22-en_cuts_train.jsonl.gz"
        )

    def dev_cuts(self) -> CutSet:
        """Return the dev cuts, opened lazily."""
        logging.info("CommonVoice: About to get dev cuts")
        return load_manifest_lazy(
            self.manifest_dir / "cv22-en_cuts_dev.jsonl.gz"
        )

    def test_cuts(self) -> CutSet:
        """Return the test cuts, opened lazily."""
        logging.info("CommonVoice: About to get test cuts")
        return load_manifest_lazy(
            self.manifest_dir / "cv22-en_cuts_test.jsonl.gz"
        )
# (The first fragment on this line closes the tdnn_lstm_ctc/asr_datamodule.py
# hunk: CommonVoice.test_cuts returns
# load_manifest_lazy(self.manifest_dir / "cv22-en_cuts_test.jsonl.gz").)
#
# diff --git a/egs/librispeech/ASR/zapformer/.gitignore b/egs/librispeech/ASR/zapformer/.gitignore
# new file mode 100644
# index 0000000000..e47ac15828
# --- /dev/null
# +++ b/egs/librispeech/ASR/zapformer/.gitignore
# @@ -0,0 +1 @@
# +swoosh.pdf
#
# diff --git a/egs/librispeech/ASR/zapformer/alternating_spec_augment.py b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py
# new file mode 100644
# index 0000000000..e9e2faa83d
# --- /dev/null
# +++ b/egs/librispeech/ASR/zapformer/alternating_spec_augment.py
# @@ -0,0 +1,400 @@
import logging
import random
from typing import Any, Dict, Optional, Sequence, Tuple, TypeVar, Union

import torch


class AlternatingSpecAugment(torch.nn.Module):
    """
    AlternatingSpecAugment is a different version of feature-masking and frame-masking
    aspects of SpecAugment, without the time warping for now (we use code for time_warp
    adapted from lhotse which is the same as the original SpecAugment).

    The main difference is in how it selects the regions to be masked, they are selected
    in a way that usually ensures there is a good amount of space between successive masks.
    We also use a relatively large temporal masked-fraction (max_frame_mask_fraction)
    and have the number of masks be selected proportional to the utterance length.
    """

    def __init__(
        self,
        max_feature_mask_fraction: float = 0.675,
        num_feature_masks: int = 2,
        max_frame_mask_fraction: float = 0.725,
        max_frame_mask_size: float = 70,
        p=0.9,
        time_warp_p=0.9,
        time_warp_factor=80,
        seed=None,
    ):
        """
        Args:
          max_feature_mask_fraction:
            Max fraction of the feature axis that can possibly be masked; the
            expected masked fraction is half of this.
          num_feature_masks:
            Number of masks on the feature axis; 0 disables feature masking.
          max_frame_mask_fraction:
            Max fraction of the time axis that can possibly be masked; the
            expected temporal masked fraction is half of this.  0 disables
            frame masking.
          max_frame_mask_size:
            Max size in frames of temporal masks; together with the sequence
            length it determines the number of temporal masks.
          p:
            Probability, per sequence, of applying the masking.
          time_warp_p:
            Probability, per sequence, of applying time warping.
          time_warp_factor:
            Time-warp parameter, as in the original SpecAugment paper.
          seed:
            Seed for the per-device generators; if None, each generator is
            seeded with torch.randint(0, 100000, ()).item().
        """
        super().__init__()
        assert 0 <= p <= 1
        assert 0 <= max_feature_mask_fraction <= 1
        assert 0 <= max_frame_mask_fraction <= 1
        assert 0 <= max_frame_mask_size
        assert 0 <= num_feature_masks

        self.max_feature_mask_fraction = max_feature_mask_fraction
        self.num_feature_masks = num_feature_masks
        self.max_frame_mask_fraction = max_frame_mask_fraction
        self.max_frame_mask_size = max_frame_mask_size
        self.p = p

        self.time_warp_p = time_warp_p
        self.time_warp_factor = time_warp_factor

        self.seed = seed
        # Maps str(device) -> torch.Generator, created lazily by get_generator().
        self.device_to_generator = dict()

    def get_generator(self, device):
        """Return the generator for ``device``, creating and seeding it on first use."""
        key = str(device)
        gen = self.device_to_generator.get(key)
        if gen is None:
            gen = torch.Generator(device)
            gen.manual_seed(
                self.seed
                if self.seed is not None
                else torch.randint(0, 100000, ()).item()
            )
            self.device_to_generator[key] = gen
        return gen

    def forward(
        self,
        features: torch.Tensor,
        feature_lens: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Do augmentation and return modified features.
          features: (batch_size, seq_len, num_channels)
          feature_lens: (batch_size,), contains sequence lengths 0 < feature_lens <= seq_len

        Time warping (if time_warp_p > 0) is applied first and always draws
        from the CPU generator; masking (if p > 0) follows.
        """
        if self.time_warp_p > 0:
            features = time_warp(
                features,
                p=self.time_warp_p,
                time_warp_factor=self.time_warp_factor,
                feature_lens=feature_lens,
                generator=self.get_generator(torch.device('cpu')),
            )
        if self.p > 0:
            features = self.forward_masking(features)
        return features

    def forward_masking(
        self,
        features: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply the feature-axis and time-axis masking to a batch of feature
        matrices.  Time warping is not done here; it is applied by forward().

        Masked positions are replaced by the mean over the whole batch.  The
        masks are sampled for every sequence, and then each sequence
        independently keeps its masked version with probability ``self.p``
        (otherwise it is returned unmasked).  The masks span the whole padded
        length T; sequence lengths are not taken into account.

        :param features: a batch of feature matrices with shape ``(B, T, F)``.

        :return: an augmented tensor of shape ``(B, T, F)``; the input is not modified.
        :raises ValueError: if frame masking is enabled but
            ``max_frame_mask_size`` is not positive.
        """
        assert len(features.shape) == 3, (
            "SpecAugment only supports batches of " "single-channel feature matrices."
        )
        if self.max_frame_mask_fraction > 0 and self.max_frame_mask_size <= 0:
            # The number of frame masks is T * fraction / size; fail with a
            # clear message instead of a bare ZeroDivisionError.
            raise ValueError(
                "max_frame_mask_size must be positive when "
                "max_frame_mask_fraction > 0, got "
                f"{self.max_frame_mask_size}"
            )
        B, T, _ = features.shape
        features = features.clone()

        mean = features.mean()

        features_unmasked = features

        if self.num_feature_masks > 0:
            features = self._mask_on_axis(
                features,
                mean,
                axis=2,
                max_mask_fraction=self.max_feature_mask_fraction,
                num_masks=self.num_feature_masks,
            )

        if self.max_frame_mask_fraction > 0:
            # Number of temporal masks grows with the (padded) sequence length.
            num_masks = max(
                1, round((T * self.max_frame_mask_fraction) / self.max_frame_mask_size)
            )
            features = self._mask_on_axis(
                features,
                mean,
                axis=1,
                max_mask_fraction=self.max_frame_mask_fraction,
                num_masks=num_masks,
            )

        # Per-sequence coin flip: keep the masked version with probability p.
        generator = self.get_generator(features.device)
        use_masked = (
            torch.rand(B, 1, 1, device=features.device, generator=generator) < self.p
        )
        return torch.where(use_masked.expand_as(features), features, features_unmasked)

    def _mask_on_axis(
        self,
        features: torch.Tensor,
        mean: torch.Tensor,
        axis: int,
        max_mask_fraction: float,
        num_masks: int,
    ) -> torch.Tensor:
        """
        Mask ``features`` on a particular axis by replacing masked segments of that sequence with
        ``mean``.

        :param features: a batch of feature matrices with shape ``(B, T, F)``.
        :param mean: the overall feature-matrix mean, a scalar.
        :param axis: the axis to mask on, i.e. 1 for time, 2 for frequency/feature.
        :param max_mask_fraction: the maximum fraction of the data to mask (expected value will be
             close to half of this.)
        :param num_masks: the number of masked regions.
        :return: a new tensor of the same shape as ``features``.
        """
        assert axis in [1, 2]
        device = features.device
        B = features.shape[0]
        M = num_masks
        N = features.shape[axis]  # T or F

        mask_starts, mask_ends = self._sample_mask_starts_and_ends(
            B, N, num_masks, max_mask_fraction, device
        )

        # (B, 2 * M): all the starts followed by all the ends.
        mask_boundaries = torch.cat((mask_starts, mask_ends), dim=1)

        # round down to next integer.
        mask_boundaries = mask_boundaries.to(torch.long).clamp(min=0, max=N - 1)

        is_mask_start = torch.cat(
            (
                torch.ones(B, M, dtype=torch.bool, device=device),
                torch.zeros(B, M, dtype=torch.bool, device=device),
            ),
            dim=1,
        )

        # Sort the boundaries by position, carrying along their start/end type.
        mask_boundaries, indexes = mask_boundaries.sort(dim=1)
        is_mask_start = torch.gather(is_mask_start, dim=1, index=indexes)
        not_mask_start = torch.logical_not(is_mask_start)

        keep_boundary = torch.ones(B, 2 * M, device=device, dtype=torch.bool)
        # the following says: set to False all elements of keep_boundary where both this and the previous
        # element is a mask start.  I.e. remove redundant mask-starts corresponding to overlapping masks.
        keep_boundary[:, 1:][
            torch.logical_and(is_mask_start[:, :-1], is_mask_start[:, 1:])
        ] = False
        # the following says: set to False all elements of keep_boundary where both this and the next
        # element are mask ends.  I.e. remove redundant mask-ends corresponding to overlapping masks.
        keep_boundary[:, :-1][
            torch.logical_and(not_mask_start[:, :-1], not_mask_start[:, 1:])
        ] = False

        # Count the kept boundaries at each position; the running parity then
        # says whether a position lies inside a mask.
        keep_boundary = keep_boundary.to(dtype=torch.long)
        boundary_counts = torch.zeros(B, N, device=device, dtype=torch.long)
        boundary_counts.scatter_add_(index=mask_boundaries, dim=1, src=keep_boundary)
        cumsum = torch.cumsum(boundary_counts, dim=1)

        is_masked = (cumsum % 2) == 1  # (B, N), is True at spots to mask.
        if axis == 1:
            is_masked = is_masked.unsqueeze(-1)
        else:
            is_masked = is_masked.unsqueeze(1)

        return torch.where(
            is_masked.expand_as(features),
            mean[None, None, None].expand_as(features),
            features,
        )

    def _sample_mask_starts_and_ends(
        self, batch_size, seq_len, num_masks, max_mask_fraction, device
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample mask start and end positions along an axis of length ``seq_len``.

        :return: ``(mask_starts, mask_ends)``, two float tensors of shape
            ``(batch_size, num_masks)`` holding positions in ``[0, seq_len]``.
        """
        generator = self.get_generator(device)
        # we imagine there are "pairs of sequences" for historical reasons but one of each pair is not
        # a real sequence.
        B = batch_size
        # M is the number of masks we sample for each "pair of sequences." (i.e. for each sequence and its
        # imaginary twin)
        M = 2 * num_masks

        # "rlength" means relative length of each mask, i.e. relative to seq_len.  the
        # lengths in mask_rlengths are normalized lengths.
        mask_rlengths = torch.rand(B, M, device=device, generator=generator) * (
            max_mask_fraction / num_masks
        )
        mask_tot_rlen = mask_rlengths.sum(dim=1, keepdim=True)  # (batch_size, 1)

        # padding_tot_rlen is the total relative length of the padding segments.  We clamp to min=0.2
        # so there is some randomness in the positions even if the selected masks are unusually large.
        # (note: we expect the max_fraction values to be between about .5 to .7, so the expected-masked-fraction
        # values would be about .25 to 0.35 (since we sample between 0 and maximum); and if we double
        # it because we do the selection for pairs of masked regions, that gives us about .5 to .7.
        # so definitely this clamping will happen for less than half of the pairs of sequences.)
        padding_tot_rlen = (1.0 - mask_tot_rlen).clamp(min=0.2)  # (batch_size, 1)

        # get padding lengths by randomly placing dividers on the line of length "padding_tot_rlen"
        # P is the number of padding regions for each pair of sequences.
        P = M + 1
        # rpositions means positions expressed in relative length, i.e. normalized so that
        # seq_len is 1.
        padding_rpositions = (
            torch.rand(B, P - 1, device=device, generator=generator) * padding_tot_rlen
        )
        padding_rpositions = padding_rpositions.sort(dim=1)[0]
        zero = torch.zeros(B, 1, device=device)
        padding_rpositions = torch.cat((zero, padding_rpositions, padding_tot_rlen), dim=1)
        padding_rlengths = padding_rpositions[:, 1:] - padding_rpositions[:, :-1]

        # 'rlengths' are the normalized lengths of the padding regions and the masks,
        # interleaved as padding, mask, padding, ..., mask, padding.
        rlengths = torch.empty(B, 2 * M + 1, device=device)
        rlengths[:, 1::2] = mask_rlengths
        rlengths[:, 0::2] = padding_rlengths

        # lengths is the lengths of the masks and padding regions, converted to absolute
        # length.  We have to normalize before multiplying by seq_len because of the .clamp()
        # operation above-- not all sequences will sum to one.
        lengths = (rlengths / rlengths.sum(dim=1, keepdim=True)) * seq_len

        positions = torch.cumsum(lengths, dim=1)
        # last element of 'positions' should be seq_len
        assert torch.all((positions[:, -1] - seq_len).abs() < 0.0001 * seq_len)

        # positions does not have a leading zero, cumsum is inclusive; but do not treat final `seq_len` as a mask start position.
        mask_starts = positions[:, 0:-1:2]
        mask_ends = positions[:, 1::2]
        assert mask_starts.shape == (B, M) and mask_ends.shape == (B, M)

        # letting the start-position when we take alternating positions be
        # randomly 0 or 1 avoids any overall bias towards the start or end of
        # the sequence.
        index = torch.randint(
            0, 2, (B,), device=device, generator=generator
        ).unsqueeze(-1) + torch.arange(0, M, step=2, device=device)
        mask_starts = torch.gather(mask_starts, dim=1, index=index)
        mask_ends = torch.gather(mask_ends, dim=1, index=index)

        return mask_starts, mask_ends


def time_warp_impl(
    features: torch.Tensor, factor: int, generator: Optional[torch.Generator]
) -> torch.Tensor:
    """
    # modified from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py#L338C1-L369C1
    # to use torch rng rather than the numpy one, this has to do with which rngs
    # are synchronized and which are not.  (we keep the numpy and python rng's synchronized
    # for the sake of lhotse's sampler code, where they need to be synchronized to avoid data
    # overlap).

    Time warping as described in the SpecAugment paper.
    Implementation based on Espresso:
    https://github.com/freewym/espresso/blob/master/espresso/tools/specaug_interpolate.py#L51

    :param features: input tensor of shape ``(T, F)``
    :param factor: time warping parameter.
    :param generator: torch generator used to draw the warp center and target
        (None means the default generator).
    :return: a warped tensor of shape ``(T, F)``; the input itself is returned
        when it is too short to warp or the sampled warp is a no-op.
    """
    t = features.size(0)
    if t - factor <= factor + 1:
        # Too short: no valid range to sample the warp center from.
        return features
    center = torch.randint(factor + 1, t - factor, (), generator=generator).item()
    warped = torch.randint(
        center - factor, center + factor + 1, (), generator=generator
    ).item()
    if warped == center:
        return features
    # interpolate() wants (N, C, H, W); treat the matrix as a 1-channel image.
    features = features.unsqueeze(0).unsqueeze(0)
    left = torch.nn.functional.interpolate(
        features[:, :, :center, :],
        size=(warped, features.size(3)),
        mode="bicubic",
        align_corners=False,
    )
    right = torch.nn.functional.interpolate(
        features[:, :, center:, :],
        size=(t - warped, features.size(3)),
        mode="bicubic",
        align_corners=False,
    )
    return torch.cat((left, right), dim=2).squeeze(0).squeeze(0)


# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py
# it does not differ substantively from that; only, it accepts feature_lens rather than supervision
# segments, and uses torch as the random number generator.
def time_warp(
    features: torch.Tensor,
    p: float = 0.9,
    time_warp_factor: Optional[int] = 80,
    feature_lens: Optional[torch.Tensor] = None,
    generator: Optional[torch.Generator] = None,  # generator for CPU only
):
    """
    Apply time warping independently to each sequence of a batch.

    :param features: tensor of shape ``(B, T, F)``; it is not modified.
    :param p: probability of warping each sequence.
    :param time_warp_factor: warp parameter; None or a value < 1 disables warping
        (the input is then returned as is).
    :param feature_lens: optional ``(B,)`` lengths; if given, only the first
        ``feature_lens[i]`` frames of sequence ``i`` are warped.
    :param generator: CPU torch generator for the random draws.
    :return: a tensor of shape ``(B, T, F)``.
    """
    if time_warp_factor is None or time_warp_factor < 1:
        return features
    assert (
        len(features.shape) == 3
    ), f"SpecAugment only supports batches of single-channel feature matrices. {features.shape}"
    features = features.clone()

    # we use torch.rand(1).item() instead of random.random() for easier control of generators
    # that is more consistent with GPU generators.
    if feature_lens is None:
        # No feature_lens - apply time warping to full feature matrices.
        for sequence_idx in range(features.size(0)):
            if torch.rand(1, generator=generator).item() > p:
                # Randomly choose whether this transform is applied
                continue
            features[sequence_idx] = time_warp_impl(
                features[sequence_idx],
                factor=time_warp_factor,
                generator=generator,
            )
    else:
        for sequence_idx, num_frames in enumerate(feature_lens):
            if torch.rand(1, generator=generator).item() > p:
                # Randomly choose whether this transform is applied
                continue
            features[sequence_idx, :num_frames] = time_warp_impl(
                features[sequence_idx, :num_frames],
                factor=time_warp_factor,
                generator=generator,
            )

    return features


def _test_alternating_spec_augment():
    """Print masking statistics for AlternatingSpecAugment (n == 0) and, for
    comparison, for lhotse's SpecAugment (n == 1)."""
    for n in [0, 1]:
        # set device = 'cuda' to run the comparison on GPU
        B, T, F = 301, 600, 80
        device = 'cpu'

        if n == 0:
            aspec_augment = AlternatingSpecAugment(time_warp_p=0.0)
        else:
            from lhotse.dataset import SpecAugment

            time_mask_ratio = 3.5
            num_frame_masks = int(10 * time_mask_ratio)
            max_frames_mask_fraction = 0.15 * time_mask_ratio
            print(
                f"num_frame_masks: {num_frame_masks}, "
                f"max_frames_mask_fraction: {max_frames_mask_fraction}"
            )
            spec_augment = SpecAugment(
                time_warp_factor=0,  # Do time warping in model.py
                num_frame_masks=num_frame_masks,  # default: 10
                features_mask_size=27,
                num_feature_masks=2,
                frames_mask_size=100,
                max_frames_mask_fraction=max_frames_mask_fraction,  # default: 0.15
                p=0.9,
            )
            supervision_segments = torch.stack(
                (
                    torch.arange(B, device=device),  # sequence_idx
                    torch.zeros(B, device=device, dtype=torch.long),  # start_frame
                    T * torch.ones(B, device=device, dtype=torch.long),  # num_frames
                ),
                dim=-1,
            )
            aspec_augment = lambda x: spec_augment(x, supervision_segments)

        features = torch.randn(B, T, F, device=device)

        features = aspec_augment(features)

        # A frame (resp. feature bin) counts as masked when its first and last
        # entries are equal, i.e. both were replaced by the mean.
        frame_is_masked = features[:, :, 0] == features[:, :, -1]
        print("mean frame_is_masked = ", frame_is_masked.to(torch.float).mean())

        print(
            "mean frame_is_masked[per-frame][::10] = ",
            frame_is_masked.to(torch.float).mean(dim=0)[::10],
        )
        feature_is_masked = features[:, 0] == features[:, -1]
        print("mean feature_is_masked = ", feature_is_masked.to(torch.float).mean())
        print(
            "mean feature_is_masked[per-freq] = ",
            feature_is_masked.to(torch.float).mean(dim=0),
        )


if __name__ == '__main__':
    _test_alternating_spec_augment()


# diff --git a/egs/librispeech/ASR/zapformer/asr_datamodule.py b/egs/librispeech/ASR/zapformer/asr_datamodule.py
# new file mode 100755
# index 0000000000..c4a628df01
# --- /dev/null
# +++ b/egs/librispeech/ASR/zapformer/asr_datamodule.py
# @@ -0,0 +1,591 @@
#
# Copyright      2021  Piotr Żelasko
# Copyright      2022  Xiaomi Corporation     (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ + +import argparse +import inspect +import glob +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional +import re +import random # to set its random seed +import numpy as np # to set its random seed + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, +) +import lhotse + +# MulticopyDataset is a modified version of K2SpeechRecognitionDataset from +# lhotse.dataset, modified to, in training mode, to return a batch that has multiple +# different copies of the same data having different Musan +# augmentations and the first having none; and also include the key "num_copies" +# in the batch which would be 1 for the validation data (no Musan) and 2 for the +# different copies of the training data with musan. +from multicopy_dataset import MulticopyDataset # interface like K2SpeechRecognitionDataset + +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + random_seed = self.seed + 9999 * worker_id + random.seed(random_seed) + np.random.seed(random_seed) + +class AsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""When enabled, use 960h LibriSpeech; and 10000 hour GigaSpeech if --use-giga. Otherwise, use 100h and if applicable 250h subsets.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=float, + default=800.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch, including multiple copies, so if num_copies " + "is larger the actual duration prior to making copies will be smaller." 
+ ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. 
Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + parser.add_argument( + "--libri-copies", + type=int, + default=1, + help="The number of copies of librispeech data used per epoch, i.e. per epoch of gigaspeech, if --use-giga=True." + "(it is really libri_copies times 3, because of Librispeech using speed augmentation)." + ) + + parser.add_argument( + "--use-giga", + type=str2bool, + default=False, + help="If set to True, use gigaspeech in addition to librispeech. See also --libri-copies." + ) + + parser.add_argument( + "--use-cv", + type=str2bool, + default=False, + help="If set to True, use CommonVoice in addition to librispeech. See also --libri-copies." + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + num_copies: int = 1, + seed: int = 100, # lets us specify different seed if we create data loader on different epochs. + # note: the seed has to be the same across ranks, because the samplers need to be kept in sync + # so we can divide up the data accurately. + rank: int = 0, # the torch. distributed rank, affects the seed used for + + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create train dataset") + train = MulticopyDataset( + num_copies=num_copies, + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=[], + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = MulticopyDataset( + num_copies=num_copies, + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=[], + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info(f"Using DynamicBucketingSampler, num_copies={num_copies}") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration / num_copies, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + seed=seed, + ) + else: + logging.info(f"Using SimpleCutSampler, num_copies={num_copies}") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration / num_copies, + shuffle=self.args.shuffle, + seed=seed, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # the data-loader workers do not have to be synchronized across the process-group, + # we can give them rank-dependent seeds. (There may not actually be any randomization + # at this level in this zapformer recipe though, we do SpecAug in the main process + # and I think the musan-related stuff happens in the sampler. 
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
    """Build the validation dataloader for ``cuts_valid``.

    The dataset is always created with ``num_copies=1``.  When
    ``args.on_the_fly_feats`` is set, 80-bin Fbank features are computed on
    the fly; otherwise no ``input_strategy`` is passed and the dataset's
    default is used.  Batches come from an unshuffled
    ``DynamicBucketingSampler`` limited by ``args.max_duration``.

    Args:
      cuts_valid: The validation cuts.

    Returns:
      A ``DataLoader`` using 2 worker processes and sampler-driven batching
      (``batch_size=None``).
    """
    cfg = self.args

    # Cut concatenation is the only cut-level transform used for validation.
    cut_transforms = (
        [CutConcatenate(duration_factor=cfg.duration_factor, gap=cfg.gap)]
        if cfg.concatenate_cuts
        else []
    )

    logging.info("About to create dev dataset")
    dataset_kwargs = {
        "num_copies": 1,
        "cut_transforms": cut_transforms,
        "return_cuts": cfg.return_cuts,
    }
    if cfg.on_the_fly_feats:
        dataset_kwargs["input_strategy"] = OnTheFlyFeatures(
            Fbank(FbankConfig(num_mel_bins=80))
        )
    validate = MulticopyDataset(**dataset_kwargs)

    valid_sampler = DynamicBucketingSampler(
        cuts_valid,
        max_duration=cfg.max_duration,
        shuffle=False,
    )

    logging.info("About to create dev dataloader")
    return DataLoader(
        validate,
        sampler=valid_sampler,
        batch_size=None,
        num_workers=2,
        persistent_workers=False,
    )
librispeech_cuts_dev-other.jsonl.gz + - librispeech_cuts_test-clean.jsonl.gz + - librispeech_cuts_test-other.jsonl.gz + - librispeech_cuts_train-clean-100.jsonl.gz + - librispeech_cuts_train-clean-360.jsonl.gz + - librispeech_cuts_train-other-500.jsonl.gz + """ + self.manifest_dir = Path(manifest_dir) + + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean 
class GigaSpeech:
    """Lazy accessors for the GigaSpeech cut manifests found in one directory."""

    def __init__(self, manifest_dir: str):
        """
        Args:
          manifest_dir:
            It is expected to contain the following files:

                - gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz
                - gigaspeech_cuts_L.jsonl.gz
                - gigaspeech_cuts_M.jsonl.gz
                - gigaspeech_cuts_S.jsonl.gz
                - gigaspeech_cuts_XS.jsonl.gz
                - gigaspeech_cuts_DEV.jsonl.gz
                - gigaspeech_cuts_TEST.jsonl.gz
        """
        self.manifest_dir = Path(manifest_dir)

    def train_XL_cuts_split(self) -> CutSet:
        """Return the XL training cuts, combined lazily from the numbered splits.

        The split files are ordered by the integer index embedded in their
        file name (``gigaspeech_cuts_XL.<index>.jsonl.gz``).  Files matched by
        the glob that do not carry such an index are skipped with a warning
        instead of aborting with an ``AttributeError``.
        """
        logging.info("About to get train-XL cuts")

        filenames = glob.glob(
            f"{self.manifest_dir}/gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz"  # noqa
        )

        # The dots are escaped so that they only match literal dots.
        pattern = re.compile(r"gigaspeech_cuts_XL\.([0-9]+)\.jsonl\.gz")
        idx_filenames = []
        for f in filenames:
            m = pattern.search(f)
            if m is None:
                logging.warning(f"Skipping {f}: no numeric split index in its name")
                continue
            idx_filenames.append((int(m.group(1)), f))
        idx_filenames.sort(key=lambda x: x[0])

        sorted_filenames = [f for _, f in idx_filenames]

        logging.info(f"Loading {len(sorted_filenames)} splits")

        return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames)

    def _lazy_jsonl_cuts(self, subset: str) -> CutSet:
        """Open ``gigaspeech_cuts_<subset>.jsonl.gz`` lazily via ``CutSet.from_jsonl_lazy``."""
        f = self.manifest_dir / f"gigaspeech_cuts_{subset}.jsonl.gz"
        logging.info(f"About to get train-{subset} cuts from {f}")
        return CutSet.from_jsonl_lazy(f)

    def train_XL_cuts(self) -> CutSet:
        """Return the XL training cuts from the single (unsplit) manifest."""
        return self._lazy_jsonl_cuts("XL")

    def train_L_cuts(self) -> CutSet:
        """Return the L training cuts."""
        return self._lazy_jsonl_cuts("L")

    def train_M_cuts(self) -> CutSet:
        """Return the M training cuts."""
        return self._lazy_jsonl_cuts("M")

    def train_S_cuts(self) -> CutSet:
        """Return the S training cuts."""
        return self._lazy_jsonl_cuts("S")

    def train_XS_cuts(self) -> CutSet:
        """Return the XS training cuts."""
        return self._lazy_jsonl_cuts("XS")

    def test_cuts(self) -> CutSet:
        """Return the TEST cuts."""
        f = self.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
        logging.info(f"About to get TEST cuts from {f}")
        return load_manifest_lazy(f)

    def dev_cuts(self) -> CutSet:
        """Return the DEV cuts."""
        f = self.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
        logging.info(f"About to get DEV cuts from {f}")
        return load_manifest_lazy(f)
class AttentionDecoderModel(nn.Module):
    """Attention decoder wrapping a ``TransformerDecoder`` and a label-smoothing loss.

    Args:
      vocab_size: Number of output classes.
      decoder_dim: Model dimension of the decoder (passed on as ``d_model``).
      num_decoder_layers: Number of decoder layers.
      attention_dim: Total dimension of the multi-head attention.
      num_heads: Number of attention heads.
      feedforward_dim: Hidden dimension of the feed-forward modules.
      memory_dim: Dimension of the memory passed to the decoder.
      sos_id: Token id prepended to the decoder input.
      eos_id: Token id appended to the targets; also used to pad the inputs.
      dropout: Dropout rate passed to the decoder.
      ignore_id: Padding value of the targets; ignored by the losses.
      label_smoothing: Label-smoothing factor of the training loss.
    """

    def __init__(
        self,
        vocab_size: int,
        decoder_dim: int = 512,
        num_decoder_layers: int = 6,
        attention_dim: int = 512,
        num_heads: int = 8,
        feedforward_dim: int = 2048,
        memory_dim: int = 512,
        sos_id: int = 1,
        eos_id: int = 1,
        dropout: float = 0.1,
        ignore_id: int = -1,
        label_smoothing: float = 0.1,
    ):
        super().__init__()
        self.eos_id = eos_id
        self.sos_id = sos_id
        self.ignore_id = ignore_id

        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            d_model=decoder_dim,
            num_decoder_layers=num_decoder_layers,
            attention_dim=attention_dim,
            num_heads=num_heads,
            feedforward_dim=feedforward_dim,
            memory_dim=memory_dim,
            dropout=dropout,
        )

        # Used to calculate attention-decoder loss
        self.loss_fun = LabelSmoothingLoss(
            ignore_index=ignore_id, label_smoothing=label_smoothing, reduction="sum"
        )

    def _pre_ys_in_out(self, ys: k2.RaggedTensor, ys_lens: torch.Tensor):
        """Build padded decoder inputs and targets from ragged token ids.

        Returns:
          A tuple ``(ys_in_pad, ys_in_lens, ys_out_pad)``: the SOS-prefixed
          inputs padded with ``eos_id``, their lengths (``ys_lens + 1``), and
          the EOS-suffixed targets padded with ``ignore_id``.  Both padded
          tensors are converted to int64.
        """
        # [B, S+1], start with SOS
        inputs = add_sos(ys, sos_id=self.sos_id).pad(
            mode="constant", padding_value=self.eos_id
        )
        input_lens = ys_lens + 1

        # [B, S+1], end with EOS
        targets = add_eos(ys, eos_id=self.eos_id).pad(
            mode="constant", padding_value=self.ignore_id
        )

        return inputs.to(torch.int64), input_lens, targets.to(torch.int64)

    def _decode(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys: k2.RaggedTensor,
        ys_lens: torch.Tensor,
    ):
        """Run the decoder on ``ys``; return ``(decoder_out, ys_out_pad)``."""
        ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens)
        decoder_out = self.decoder(
            x=ys_in_pad,
            x_lens=ys_in_lens,
            memory=encoder_out,
            memory_lens=encoder_out_lens,
        )
        return decoder_out, ys_out_pad

    def calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys: k2.RaggedTensor,
        ys_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate attention-decoder loss.

        Args:
          encoder_out: (batch, num_frames, encoder_dim)
          encoder_out_lens: (batch,)
          ys: Ragged tensor holding the token ids of each utterance.
          ys_lens: Number of tokens of each utterance in ``ys``.

        Return: The attention-decoder loss, as computed by the
          label-smoothing loss created with ``reduction="sum"``.
        """
        decoder_out, ys_out_pad = self._decode(
            encoder_out, encoder_out_lens, ys, ys_lens
        )
        return self.loss_fun(x=decoder_out, target=ys_out_pad)

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        token_ids: List[List[int]],
    ) -> torch.Tensor:
        """Compute negative log likelihood(nll) from attention-decoder.

        Args:
          encoder_out: (batch, num_frames, encoder_dim)
          encoder_out_lens: (batch,)
          token_ids: A list of token id list.

        Return: A tensor of shape (batch, num_tokens), one (unreduced)
          cross-entropy value per padded target position.
        """
        ys = k2.RaggedTensor(token_ids).to(device=encoder_out.device)
        splits = ys.shape.row_splits(1)
        ys_lens = splits[1:] - splits[:-1]

        decoder_out, ys_out_pad = self._decode(
            encoder_out, encoder_out_lens, ys, ys_lens
        )

        batch_size, _, num_classes = decoder_out.size()
        per_token = nn.functional.cross_entropy(
            decoder_out.view(-1, num_classes),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        return per_token.view(batch_size, -1)
class DecoderLayer(nn.Module):
    """Single pre-norm decoder layer: self-attention, cross-attention, feed-forward.

    Each sub-module is applied to a layer-normalized input and added back to
    the residual stream after dropout.

    Args:
      d_model: equal to decoder_dim, total dimension of the decoder
      attention_dim: total dimension of multi head attention
      num_heads: number of attention heads
      feedforward_dim: hidden dimension of feed_forward module
      memory_dim: dimension of the memory attended to by the cross-attention
      dropout: dropout rate
    """

    def __init__(
        self,
        d_model: int = 512,
        attention_dim: int = 512,
        num_heads: int = 8,
        feedforward_dim: int = 2048,
        memory_dim: int = 512,
        dropout: float = 0.1,
    ):
        """Construct a DecoderLayer object."""
        super().__init__()

        self.norm_self_attn = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(
            d_model, attention_dim, num_heads, dropout=0.0
        )

        self.norm_src_attn = nn.LayerNorm(d_model)
        self.src_attn = MultiHeadAttention(
            d_model, attention_dim, num_heads, memory_dim=memory_dim, dropout=0.0
        )

        self.norm_ff = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, feedforward_dim),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(feedforward_dim, d_model),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        memory: Optional[torch.Tensor] = None,
        memory_attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: Input sequence of shape (seq_len, batch, embed_dim).
            attn_mask: A binary mask for self-attention module indicating which
                elements will be filled with -inf.
                Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
            memory: Memory sequence of shape (seq_len, batch, memory_dim).
                If None, the cross-attention module is skipped.
            memory_attn_mask: A binary mask for cross-attention module indicating which
                elements will be filled with -inf.
                Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).

        Returns:
            A tensor with the same shape as ``x``.
        """
        # self-attn module
        qkv = self.norm_self_attn(x)
        self_attn_out = self.self_attn(
            query=qkv, key=qkv, value=qkv, attn_mask=attn_mask
        )
        x = x + self.dropout(self_attn_out)

        # cross-attn module; `memory` is optional in the signature, so only
        # attend to it when it was actually provided (previously None crashed
        # inside the attention module).
        if memory is not None:
            q = self.norm_src_attn(x)
            src_attn_out = self.src_attn(
                query=q, key=memory, value=memory, attn_mask=memory_attn_mask
            )
            x = x + self.dropout(src_attn_out)

        # feed-forward module
        x = x + self.dropout(self.feed_forward(self.norm_ff(x)))

        return x
+ + Args: + query: Query tensor of shape (tgt_len, batch, embed_dim). + key: Key tensor of shape (src_len, batch, embed_dim or memory_dim). + value: Value tensor of shape (src_len, batch, embed_dim or memory_dim). + key_padding_mask: A binary mask indicating which elements are padding. + Its shape is (batch, src_len). + attn_mask: A binary mask indicating which elements will be filled with -inf. + Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). + + Returns: + Output tensor of shape (tgt_len, batch, embed_dim). + """ + num_heads = self.num_heads + head_dim = self.head_dim + + tgt_len, batch, _ = query.shape + src_len = key.shape[0] + + q = self.linear_q(query) # (tgt_len, batch, num_heads * head_dim) + k = self.linear_k(key) # (src_len, batch, num_heads * head_dim) + v = self.linear_v(value) # (src_len, batch, num_heads * head_dim) + + q = q.reshape(tgt_len, batch, num_heads, head_dim) + q = q.permute(1, 2, 0, 3) # (batch, head, tgt_len, head_dim) + k = k.reshape(src_len, batch, num_heads, head_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, head_dim, src_len) + v = v.reshape(src_len, batch, num_heads, head_dim) + v = v.reshape(src_len, batch * num_heads, head_dim).transpose(0, 1) + + # Note: could remove the scaling operation when using ScaledAdam + # (batch, head, tgt_len, src_len) + attn_weights = torch.matmul(q, k) / math.sqrt(head_dim) + + # From zipformer.py: + # This is a harder way of limiting the attention scores to not be too large. + # It incurs a penalty if any of them has an absolute value greater than 50.0. + # this should be outside the normal range of the attention scores. We use + # this mechanism instead of, say, a limit on entropy, because once the entropy + # gets very small gradients through the softmax can become very small, and + # some mechanisms like that become ineffective. 
class PositionalEncoding(nn.Module):
    """Sinusoidal absolute positional encoding.

    Copied from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py#L35.

    The input is scaled by ``sqrt(d_model)``, the encoding of each position is
    added, and dropout is applied to the sum.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Number of positions precomputed at construction time;
            the table grows on demand for longer inputs.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Construct a PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        # Precompute the table for `max_len` positions using a dummy input.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Make sure ``self.pe`` covers ``x.size(1)`` positions and matches x's dtype/device."""
        num_positions = x.size(1)

        if self.pe is not None and self.pe.size(1) >= num_positions:
            # Long enough already; `.to` is a no-op when dtype/device agree.
            self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

        position = torch.arange(0, num_positions, dtype=torch.float32).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * inv_freq

        table = torch.zeros(num_positions, self.d_model)
        table[:, 0::2] = torch.sin(angles)  # even channels
        table[:, 1::2] = torch.cos(angles)  # odd channels
        self.pe = table.unsqueeze(0).to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        encoded = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(encoded)
def _test_attention_decoder_model():
    """Smoke test: build a model, print its parameter count and the NLL of a toy batch."""
    model = AttentionDecoderModel(
        vocab_size=500,
        decoder_dim=512,
        num_decoder_layers=6,
        attention_dim=512,
        num_heads=8,
        feedforward_dim=2048,
        memory_dim=384,
        dropout=0.1,
        sos_id=1,
        eos_id=1,
        ignore_id=-1,
    )

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Number of model parameters: {param_count}")

    model.eval()
    batch_size, num_frames, memory_dim = 2, 50, 384
    encoder_out = torch.randn(batch_size, num_frames, memory_dim)
    encoder_out_lens = torch.full((batch_size,), num_frames)
    token_ids = [[1, 2, 3, 4], [2, 3, 10]]

    print(model.nll(encoder_out, encoder_out_lens, token_ids))
class BatchedOptimizer(Optimizer):
    """
    This class adds to class Optimizer the capability to optimize parameters in batches:
    it will stack the parameters and their grads for you so the optimizer can work
    on tensors with an extra leading dimension.  This is intended for speed with GPUs,
    as it reduces the number of kernels launched in the optimizer.

    Args:
      params: the parameters or parameter groups, as for class Optimizer.
      defaults: the default option values, as for class Optimizer.
    """

    # Upper bound on the size of one stacked batch, in bytes (computed assuming
    # 4 bytes per element).  Larger same-shape groups are split into several batches.
    max_batch_bytes = 2 ** 30  # 1024**3 == one gigabyte

    def __init__(self, params, defaults):
        super().__init__(params, defaults)

    @contextlib.contextmanager
    def batched_params(self, param_list):
        """
        This function returns (technically, yields) a list of tuples (p, state), where
        p is a `fake` parameter that is stacked (over axis 0) from real parameters
        that share the same shape and dtype, and its gradient is also stacked
        (a missing gradient is replaced by zeros);
        `state` is the state corresponding to this batch of parameters
        (it is physically located in the "state" of the first real
        parameter of the batch).

        This function is decorated as a context manager so that it can
        write parameters back to their "real" locations.

        The idea is, instead of doing:

            for p in group["params"]:
                state = self.state[p]
                ...

        you can do:

            with self.batched_params(group["params"]) as batches:
                for p, state in batches:
                    ...

        The write-back uses an in-place copy into the real parameters, so this
        is expected to be used with autograd disabled (e.g. inside an
        optimizer step running under torch.no_grad()).

        Args:
          param_list: a list of parameters, e.g. group["params"] for one of
            self.param_groups.
        """
        # maps from tuple (dtype_as_str, *shape) to list of nn.Parameter
        by_key = defaultdict(list)
        for p in param_list:
            by_key[(str(p.dtype), *p.shape)].append(p)

        # Now split up any groups that are too large.
        max_bytes = self.max_batch_bytes
        batches = []
        for b in by_key.values():
            num_tensors = len(b)
            # total bytes in group of tensors, assuming float
            num_bytes = num_tensors * b[0].numel() * 4
            # At least one group, even for zero-element tensors (avoids a
            # division by zero below); at most one group per tensor.
            num_groups = max(
                1, min(num_tensors, (num_bytes + max_bytes - 1) // max_bytes)
            )
            group_size = (num_tensors + num_groups - 1) // num_groups
            # Step through the list by group_size, so that no empty trailing
            # group can be produced when the sizes round up.
            batches.extend(
                b[start : start + group_size]
                for start in range(0, num_tensors, group_size)
            )

        # tuples will contain tuples of (stacked_param, state),
        # one for each batch in `batches`.
        tuples = []
        for batch in batches:
            # we arbitrarily store the state in the
            # state corresponding to the 1st parameter in the
            # group.  class Optimizer will take care of saving/loading state.
            state = self.state[batch[0]]
            p_stacked = torch.stack(batch)
            p_stacked.grad = torch.stack(
                [torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
            )
            tuples.append((p_stacked, state))

        yield tuples  # <-- calling code will do the actual optimization here!

        # Copy the (possibly updated) stacked values back into the real parameters.
        for (stacked_params, _state), batch in zip(tuples, batches):
            for i, p in enumerate(batch):  # batch is list of Parameter
                p.copy_(stacked_params[i])
def matrix_shape(shape):
    """
    Choose the most balanced way of viewing a tensor shape as a matrix.

    `shape` is a torch.Size or a list of dimensions.  The leading k dimensions
    are merged into `rows` and the remaining ones into `cols`, with k chosen to
    minimize abs(rows - cols); the smallest such k wins a tie.
    e.g. matrix_shape([ 2, 3, 10 ]) = (6, 10)

    Returns:
      (rows, cols), with rows * cols equal to the product of `shape`.
    """
    prefix_products = []
    running = 1
    for dim in list(shape):
        running *= dim
        prefix_products.append(running)
    total = running

    # imbalance[i] is |rows - cols| when splitting after dimension i.
    imbalance = [abs(rows - total // rows) for rows in prefix_products]
    best = imbalance.index(min(imbalance))  # index() returns the first minimum
    rows = prefix_products[best]
    return rows, total // rows
def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps):
    """
    Normalize x first per row and then per column by moving-average RMS
    statistics, updating those statistics in place along the way.

    Shapes:
            x: (batch_size, rows, cols)
    row_stats: (batch_size, rows, 1)   -- moving average of mean-square over cols
    col_stats: (batch_size, 1, cols)   -- moving average of mean-square over rows,
                                          measured *after* the row normalization
    Args:
        beta2: decay constant of the moving averages; each update is
               stats = beta2 * stats + (1 - beta2) * new_mean_square.
          eps: added to the square-rooted stats before dividing.
    Returns:
        normalized x, shape: (batch_size, rows, cols).  The input tensor x
        itself is not modified; row_stats and col_stats are.
    """
    fresh_weight = 1 - beta2

    # the stats hold squared norms (square root is taken only when dividing).
    row_meansq = x.square().mean(dim=2, keepdim=True)
    row_stats.mul_(beta2)
    row_stats.add_(row_meansq, alpha=fresh_weight)
    row_normalized = x / (row_stats.sqrt() + eps)

    col_meansq = row_normalized.square().mean(dim=1, keepdim=True)
    col_stats.mul_(beta2)
    col_stats.add_(col_meansq, alpha=fresh_weight)
    return row_normalized / (col_stats.sqrt() + eps)
Tensor, # (12, 768, 3072) - first moment buffer + second_momentum_buffer: Tensor, # (12, 768, 1) or (12, 1, 3072) - factored second moment + momentum_t: Tensor, # () - 0-D CPU tensor, momentum coefficient + lr_t: Tensor, # () - 0-D CPU tensor, learning rate + beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment + eps: Tensor, + ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations + red_dim: int, # -1 or -2 - reduction dimension for variance +) -> Tensor: + """ + Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update + All in one compiled graph to eliminate Python overhead between ops. + Some of the constants are 0-D CPU tensors to avoid recompilation when values change. + """ + + # Nesterov momentum + momentum = momentum_t.to(stacked_grads.dtype) + momentum_buffer.lerp_(stacked_grads, 1 - momentum) + g = stacked_grads.lerp_(momentum_buffer, momentum) + + # apply the same normalization both before and after + # the core muon step, the symmetry ensures it is a descent direction. 
def muon_core_step(group, state, grad):
    """
    Compute the Muon-style update for one batch of stacked parameters.

    Args:
      group: the param-group dict; reads "lr", "eps", "beta1" and "beta2".
      state: per-batch optimizer state; reads "step" and lazily creates
         "momentum_buffer" and "second_momentum_buffer".
      grad: stacked gradient whose first dimension is the batch; the
         remaining dimensions are folded into (rows, cols) by matrix_shape().
    Returns:
      The update returned by muon_step_fused(), reshaped back to grad.shape.
    """
    lr = group["lr"]
    eps = group["eps"]
    step_count = state["step"]
    # Both decay constants are capped by a step-dependent ceiling, so early
    # steps use a shorter averaging window than the configured one.
    beta1 = min(group["beta1"], 1. - 1. / (10. + 0.2 * step_count))
    beta2 = min(group["beta2"], step_count / (step_count + 1))

    orig_shape = grad.shape
    batch_size = orig_shape[0]
    rows, cols = matrix_shape(orig_shape[1:])
    is_tall = rows > cols
    grad = grad.reshape(batch_size, rows, cols)
    device = grad.device

    if "momentum_buffer" not in state:
        assert step_count < 2
        state["momentum_buffer"] = torch.zeros(
            batch_size, rows, cols, device=device, dtype=COMPUTE_DTYPE
        )
        # factored second moment: one value per row for tall matrices,
        # one value per column otherwise.
        stats_shape = (batch_size, rows, 1) if is_tall else (batch_size, 1, cols)
        state["second_momentum_buffer"] = torch.ones(
            *stats_shape, device=device, dtype=torch.float
        )

    momentum_buffer = state["momentum_buffer"]
    second_momentum_buffer = state["second_momentum_buffer"]

    if momentum_buffer.dtype == torch.float:  # Error due to loading state dict; TODO put this in load_state_dict()
        momentum_buffer = momentum_buffer.to(COMPUTE_DTYPE)
        state["momentum_buffer"] = momentum_buffer

    def as_scalar_tensor(value):
        # 0-D tensor on the gradient's device, in the compute dtype.
        return torch.tensor(value, device=device, dtype=COMPUTE_DTYPE)

    # muon_step_fused modifies its first argument in place, so it must get a
    # private copy: .to() copies when the dtype differs, otherwise clone().
    if grad.dtype != COMPUTE_DTYPE:
        grad = grad.to(COMPUTE_DTYPE)
    else:
        grad = grad.clone()

    update = muon_step_fused(
        grad,
        momentum_buffer,
        second_momentum_buffer,
        as_scalar_tensor(beta1),
        as_scalar_tensor(lr),
        as_scalar_tensor(beta2),
        as_scalar_tensor(eps),
        5,
        (-1 if is_tall else -2),
    )

    return update.reshape(orig_shape)
+ scalar_scale = group["scalar_scale"] + + if grad.ndim >= 2 and grad.numel() != grad.shape[0] * max(grad.shape[1:]): + delta = muon_core_step(group, state, grad) + else: + # biases and similar-shaped tensors + delta = adam_step(group, state, grad) + + dims = list(range(1, param.ndim)) + + try: + scale = state["scale"] + scale_grad_buf = state["scale_grad_buffer"] + except KeyError: + scale = (param ** 2).mean(dim=dims, keepdim=True).sqrt().clamp( + min=min_scale, max=max_scale).to(torch.float) + scale_grad_buf = torch.zeros_like(scale) + state["scale"] = scale + state["scale_grad_buffer"] = scale_grad_buf + + scale_grad = (grad * param.detach()).sum(dim=dims, keepdim=True) + scale_grad_buf.mul_(momentum).add_(scale_grad) # simple momentum + + nesterov = True + if nesterov: + # simple interpretation of nesterov: do an extra step of + # moving-average on scale_grad_buf, with scale_grad, like double-counting + # it. + negative_update = (scale_grad_buf * momentum + scale_grad).sign() + else: + negative_update = scale_grad_buf.sign() + + old_scale = scale.clone() + + scale.mul_(1. - lr * scalar_scale * negative_update) + scale.clamp_(min=min_scale, max=max_scale) + + scale_ratio = scale / old_scale + + delta_scale = (scale_ratio * (1 - 0.5 * (lr ** 2))) - 1 + return param * delta_scale + scale * delta + + +def adam_step(group, state, grad): + # this is the adam update but with a slight modification / simplification on + # how "bias correction" (startup on small step counts) is dealt with. 
class BatchedRubik(BatchedOptimizer):
    """
    Implements a batched version of the Rubik optimizer.

    Scalar parameters are updated with adam_step(); all other tensors go
    through scaling_step(), which combines a core update (muon_core_step()
    for matrix-like tensors, adam_step() for bias-like ones) with a
    separately learned per-tensor scale.

    Args:
      params: The parameters or param_groups to optimize; passed through to
          BatchedOptimizer.
          lr: The learning rate.
       beta1: Momentum constant for the gradient moving average in
          muon_core_step() (capped by a step-dependent ceiling there).
       beta2: Decay constant for the factored second-moment statistics in
          muon_core_step() (also capped by a step-dependent ceiling).
         eps: A general-purpose epsilon to prevent division by zero.
      scale_limits: (min_scale, max_scale); the learned per-tensor scale in
          scaling_step() is clamped to this range.
      scalar_scale: Multiplier on the learning rate used (i) for the update
          of the per-tensor scale in scaling_step(), and (ii) for the
          adam_step() update of scalar parameters.
      adam_beta1: Momentum constant for the gradient average in adam_step().
      adam_beta2: Decay constant for the squared-gradient average in
          adam_step().
      scale_momentum: Momentum constant for the buffer that accumulates the
          gradient w.r.t. the per-tensor scale in scaling_step(), e.g. 0.95.
      grad_aggregation: if None, no grad aggregation is done here (assume it is done in DDP if relevant);
          set it to torch.distributed.ReduceOp.AVG or torch.distributed.ReduceOp.SUM to have it done by this class.
    """
    def __init__(
            self,
            params,
            lr=1.2e-02,
            beta1=0.99,
            beta2=0.98,
            eps=1.0e-08,
            scale_limits=(0.03, 0.15),
            scalar_scale=0.05,
            adam_beta1=0.98,
            adam_beta2=0.98,
            scale_momentum=0.95,
            grad_aggregation=None,
    ):
        # set before super().__init__() so it is available as early as possible.
        self.grad_aggregation = grad_aggregation
        defaults = dict(
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            eps=eps,
            scale_limits=scale_limits,
            scalar_scale=scalar_scale,
            adam_beta1=adam_beta1,
            adam_beta2=adam_beta2,
            scale_momentum=scale_momentum,
        )

        super(BatchedRubik, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(BatchedRubik, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        Returns:
            The value returned by `closure`, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            with self.batched_params(group["params"]) as batches:
                # NOTE(review): p is presumably a stack of same-shaped
                # parameters with the batch on dim 0 -- batched_params() is
                # defined in BatchedOptimizer, confirm there.
                for p, state in batches:
                    grad = p.grad

                    if self.grad_aggregation is not None and dist.is_initialized():
                        # sync grads.
                        dist.all_reduce(grad, op=self.grad_aggregation)

                    try:
                        cur_step = state["step"]
                    except KeyError:
                        state["step"] = 0
                        cur_step = 0

                    if p.numel() == p.shape[0]:
                        # one element per batch entry, i.e. scalar parameters.
                        # "scalar_scale" the assumed parameter scale used for
                        # scalars, in this case it just acts as a multiplier on
                        # the learning rate.
                        p += group["scalar_scale"] * adam_step(group, state, grad)
                    else:
                        p += scaling_step(group, p.detach(), state, grad)

                    state["step"] = cur_step + 1

        return loss
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + if True: + Linear = torch.nn.Linear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] + + lr = 0.018 + optim = BatchedRubik(m.parameters(), lr=lr, beta1=0.998) + + num_epochs = 180 + + total_steps = num_epochs + def lr_lambda(current_step): + # a LR schedule similar to InterpCosineLRScheduler from combined_scheduler.py + progress = min(1, current_step / total_steps) + cos = math.cos(progress * math.pi / 2) + # the relatively small scale on cos means the linear cool-down phase + # is long/slow, as the loss of this easy task is dominated by + # parameter noise.. in practical scenarios we use larger scale on + # the cos term, e.g. as large as 0.66. + return 0.05 * cos + 0.95 * (cos ** 2) + + scheduler = LambdaLR(optim, lr_lambda) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + # if epoch == 100 and test in [2,3]: + # optim.reset_speedup() # check it doesn't crash. 
def _test_scaled_three_way_product():
    """Check the defining property of scaled_three_way_product().

    (i) a matrix whose singular values are all identical must be returned
    unchanged; (ii) a generic random matrix must not be.
    """
    # Private, seeded generator: the test is reproducible and does not
    # disturb the global RNG state used by other tests.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(3, 16, 32, generator=gen)
    # torch.linalg.svd returns (U, S, Vh).  With full_matrices=False, Vh has
    # shape (3, 16, 32) and orthonormal rows, so all its singular values are 1.
    _U, _S, Vh = torch.linalg.svd(x, full_matrices=False)
    W = Vh * torch.randn(3, 1, 1, generator=gen)
    # so now, within each matrix of W, all the singular values are identical
    # (but arbitrary: they equal the absolute value of the random scale).

    X = scaled_three_way_product(W)
    #print("X = ", X[0])
    #print("W = ", W[0])
    assert torch.allclose(W, X, atol=1.0e-03)
    # but the result won't be identical to the input if the singular values are not all identical.
    assert not torch.allclose(x, scaled_three_way_product(x), atol=1.0e-03)
def fast_beam_search_one_best(
    model: nn.Module,
    decoding_graph: k2.Fsa,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    beam: float,
    max_states: int,
    max_contexts: int,
    temperature: float = 1.0,
    ilme_scale: float = 0.0,
    blank_penalty: float = 0.0,
    return_timestamps: bool = False,
    allow_partial: bool = False,
) -> Union[List[List[int]], DecodingResults]:
    """It limits the maximum number of symbols per frame to 1.

    A lattice is built with fast_beam_search(), and the single best path
    through that lattice is returned as the decoding result.

    Args:
      model:
        An instance of `Transducer`.
      decoding_graph:
        Decoding graph used for decoding, may be a TrivialGraph or a LG.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      encoder_out_lens:
        A tensor of shape (N,) containing the number of frames in `encoder_out`
        before padding.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
      temperature:
        Softmax temperature.
      ilme_scale:
        Forwarded to fast_beam_search(); 0 (the default here) skips the
        internal-LM subtraction there.
      blank_penalty:
        Forwarded to fast_beam_search(), which subtracts it from the logit
        at index 0.
      return_timestamps:
        Whether to return timestamps.
      allow_partial:
        Forwarded to fast_beam_search().
    Returns:
      If return_timestamps is False, return the decoded result.
      Else, return a DecodingResults object containing
      decoded result and corresponding timestamps.
    """
    lattice = fast_beam_search(
        model=model,
        decoding_graph=decoding_graph,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        beam=beam,
        max_states=max_states,
        max_contexts=max_contexts,
        temperature=temperature,
        ilme_scale=ilme_scale,
        allow_partial=allow_partial,
        blank_penalty=blank_penalty,
    )

    best_path = one_best_decoding(lattice)

    # choose the extractor once, then apply it to the best path.
    extract = get_texts_with_timestamp if return_timestamps else get_texts
    return extract(best_path)
+ nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + allow_partial=allow_partial, + blank_penalty=blank_penalty, + ilme_scale=ilme_scale, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # The following code is modified from nbest.intersect() + word_fsa = k2.invert(nbest.fsa) + if hasattr(lattice, "aux_labels"): + # delete token IDs as it is not needed + del word_fsa.aux_labels + word_fsa.scores.zero_() + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) + path_to_utt_map = nbest.shape.row_ids(1) + + if hasattr(lattice, "aux_labels"): + # lattice has token IDs as labels and word IDs as aux_labels. 
def fast_beam_search_nbest(
    model: nn.Module,
    decoding_graph: k2.Fsa,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    beam: float,
    max_states: int,
    max_contexts: int,
    num_paths: int,
    nbest_scale: float = 0.5,
    use_double_scores: bool = True,
    temperature: float = 1.0,
    blank_penalty: float = 0.0,
    return_timestamps: bool = False,
    allow_partial: bool = False,
    ilme_scale: float = 0.0,
) -> Union[List[List[int]], DecodingResults]:
    """It limits the maximum number of symbols per frame to 1.

    The process to get the results is:
     - (1) Use fast beam search to get a lattice
     - (2) Select `num_paths` paths from the lattice using k2.random_paths()
     - (3) Unique the selected paths
     - (4) Intersect the selected paths with the lattice and compute the
           shortest path from the intersection result
     - (5) The path with the largest score is used as the decoding output.

    Args:
      model:
        An instance of `Transducer`.
      decoding_graph:
        Decoding graph used for decoding, may be a TrivialGraph or a LG.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      encoder_out_lens:
        A tensor of shape (N,) containing the number of frames in `encoder_out`
        before padding.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
      num_paths:
        Number of paths to extract from the decoded lattice.
      nbest_scale:
        It's the scale applied to the lattice.scores. A smaller value
        yields more unique paths.
      use_double_scores:
        True to use double precision for computation. False to use
        single precision.
      temperature:
        Softmax temperature.
      blank_penalty:
        Forwarded to fast_beam_search(), which subtracts it from the logit
        at index 0.
      return_timestamps:
        Whether to return timestamps.
      allow_partial:
        Forwarded to fast_beam_search().
      ilme_scale:
        Scale of the internal-LM subtraction done in fast_beam_search().
        It is always passed explicitly: fast_beam_search() itself defaults
        to 0.1 and applies the subtraction whenever the scale is non-zero,
        so omitting it would silently enable ILME.  The default of 0.0
        (no ILME) matches fast_beam_search_one_best() and
        fast_beam_search_nbest_LG().
    Returns:
      If return_timestamps is False, return the decoded result.
      Else, return a DecodingResults object containing
      decoded result and corresponding timestamps.
    """
    lattice = fast_beam_search(
        model=model,
        decoding_graph=decoding_graph,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        beam=beam,
        max_states=max_states,
        max_contexts=max_contexts,
        blank_penalty=blank_penalty,
        temperature=temperature,
        ilme_scale=ilme_scale,
        allow_partial=allow_partial,
    )

    nbest = Nbest.from_lattice(
        lattice=lattice,
        num_paths=num_paths,
        use_double_scores=use_double_scores,
        nbest_scale=nbest_scale,
    )

    # at this point, nbest.fsa.scores are all zeros.

    nbest = nbest.intersect(lattice)
    # Now nbest.fsa.scores contains acoustic scores

    max_indexes = nbest.tot_scores().argmax()

    best_path = k2.index_fsa(nbest.fsa, max_indexes)

    if not return_timestamps:
        return get_texts(best_path)
    else:
        return get_texts_with_timestamp(best_path)
+ nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + allow_partial=allow_partial, + blank_penalty=blank_penalty, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + hyps = nbest.build_levenshtein_graphs() + refs = k2.levenshtein_graph(ref_texts, device=hyps.device) + + levenshtein_alignment = k2.levenshtein_alignment( + refs=refs, + hyps=hyps, + hyp_to_ref_map=nbest.shape.row_ids(1), + sorted_match_ref=True, + ) + + tot_scores = levenshtein_alignment.get_tot_scores( + use_double_scores=False, log_semiring=False + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + + max_indexes = ragged_tot_scores.argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + temperature: float = 1.0, + subtract_ilme: bool = False, + ilme_scale: float = 0.1, + allow_partial: bool = False, + blank_penalty: float = 0.0, +) -> k2.Fsa: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. 
+ decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + temperature: + Softmax temperature. + Returns: + Return an FsaVec with axes [utt][state][arc] containing the decoded + lattice. Note: When the input graph is a TrivialGraph, the returned + lattice is actually an acceptor. + """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(k2.RnntDecodingStream(decoding_graph)) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + 
def greedy_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    max_sym_per_frame: int,
    blank_penalty: float = 0.0,
    return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
    """Decode a single utterance by always taking the best-scoring symbol.

    Args:
      model:
        The transducer model. Its `decoder` and `joiner` sub-modules are used.
      encoder_out:
        Encoder output of shape (N, T, C). Only N == 1 is supported.
      max_sym_per_frame:
        Upper bound on the number of non-blank symbols emitted on one frame.
        With 0 nothing is ever emitted, so the WER would be 100%.
      blank_penalty:
        Amount subtracted from the logit at vocabulary index 0 before the
        argmax is taken. 0 disables the penalty.
      return_timestamps:
        Whether to return timestamps.
    Returns:
      The decoded token ids if return_timestamps is False. Otherwise a
      DecodingResults object holding the token ids and, for each of them,
      the (subsampled) frame index on which it was emitted.
    """
    assert encoder_out.ndim == 3

    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)

    decoder = model.decoder
    joiner = model.joiner

    blank_id = decoder.blank_id
    context_size = decoder.context_size
    unk_id = getattr(model, "unk_id", blank_id)
    non_emitting = (blank_id, unk_id)

    device = next(model.parameters()).device

    def project_context(context: torch.Tensor) -> torch.Tensor:
        # Run the prediction network and project it into the joiner space.
        return joiner.decoder_proj(decoder(context, need_pad=False))

    initial_context = torch.tensor(
        [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
    ).reshape(1, context_size)
    decoder_out = project_context(initial_context)

    encoder_out = joiner.encoder_proj(encoder_out)
    num_frames = encoder_out.size(1)

    # The first `context_size` entries are placeholders; they are stripped
    # before returning.
    tokens = [blank_id] * context_size

    # emit_frames[i] is the frame index after subsampling on which the i-th
    # real token was decoded
    emit_frames = []

    # Hard cap on the number of symbols emitted for the whole utterance.
    max_sym_per_utt = 1000

    frame = 0
    emitted_on_frame = 0
    emitted_total = 0

    while frame < num_frames and emitted_total < max_sym_per_utt:
        if emitted_on_frame >= max_sym_per_frame:
            # Per-frame quota exhausted: move on to the next frame.
            emitted_on_frame = 0
            frame += 1
            continue

        frame_encoder_out = encoder_out[:, frame : frame + 1, :].unsqueeze(2)
        logits = joiner(
            frame_encoder_out, decoder_out.unsqueeze(1), project_input=False
        )
        # logits is (1, 1, 1, vocab_size)

        if blank_penalty != 0:
            logits[:, :, :, 0] -= blank_penalty

        best = logits.argmax().item()
        if best in non_emitting:
            # Nothing emitted on this frame; advance.
            emitted_on_frame = 0
            frame += 1
            continue

        tokens.append(best)
        emit_frames.append(frame)

        # The prediction network only needs refreshing after a real emission.
        context = torch.tensor([tokens[-context_size:]], device=device).reshape(
            1, context_size
        )
        decoder_out = project_context(context)

        emitted_total += 1
        emitted_on_frame += 1

    tokens = tokens[context_size:]  # remove blanks

    if return_timestamps:
        return DecodingResults(
            hyps=[tokens],
            timestamps=[emit_frames],
        )
    return tokens
def greedy_search_batch(
    model: nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    blank_penalty: float = 0,
    return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
    """Batched greedy search; at most one symbol is emitted per frame.

    Args:
      model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C), where N >= 1.
      encoder_out_lens:
        A 1-D tensor of shape (N,), containing number of valid frames in
        encoder_out before padding. Every entry must be positive.
      blank_penalty:
        Amount subtracted from the logit at vocabulary index 0 before the
        argmax is taken. 0 disables the penalty.
      return_timestamps:
        Whether to return timestamps.
    Returns:
      If return_timestamps is False, the decoded token ids of every
      utterance, in the order of the input batch. Otherwise a
      DecodingResults object that also carries the emission frame and the
      logit of every token.
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    # Packing orders the utterances by decreasing length, so the utterances
    # still active on frame t are always a prefix of the sorted batch.
    packed = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    device = next(model.parameters()).device

    blank_id = model.decoder.blank_id
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size

    active_per_frame = packed.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == active_per_frame[0], (N, active_per_frame)

    # All three lists below are indexed in *sorted* (packed) order.
    hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]
    # timestamps[n][i] is the frame index after subsampling on which
    # hyps[n][i] is decoded
    timestamps = [[] for _ in range(N)]
    # scores[n][i] is the logit on which hyps[n][i] is decoded
    scores = [[] for _ in range(N)]

    def project_contexts(contexts: List[List[int]]) -> torch.Tensor:
        # Prediction network followed by the joiner's decoder projection.
        context_tensor = torch.tensor(
            contexts,
            device=device,
            dtype=torch.int64,
        )
        return model.joiner.decoder_proj(
            model.decoder(context_tensor, need_pad=False)
        )

    decoder_out = project_contexts(hyps)
    # decoder_out: (N, 1, decoder_out_dim)

    encoder_out = model.joiner.encoder_proj(packed.data)

    offset = 0
    for t, batch_size in enumerate(active_per_frame):
        frame_encoder_out = encoder_out.data[offset : offset + batch_size]
        frame_encoder_out = frame_encoder_out.unsqueeze(1).unsqueeze(1)
        # frame_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
        offset += batch_size

        # Drop the rows of utterances that have already ended.
        decoder_out = decoder_out[:batch_size]

        logits = model.joiner(
            frame_encoder_out, decoder_out.unsqueeze(1), project_input=False
        )
        # logits' shape: (batch_size, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
        assert logits.ndim == 2, logits.shape

        if blank_penalty != 0:
            logits[:, 0] -= blank_penalty

        best_tokens = logits.argmax(dim=1).tolist()
        emitted = False
        for n, token in enumerate(best_tokens):
            if token in (blank_id, unk_id):
                continue
            hyps[n].append(token)
            timestamps[n].append(t)
            scores[n].append(logits[n, token].item())
            emitted = True

        if emitted:
            # At least one context changed: refresh the decoder output.
            decoder_out = project_contexts(
                [h[-context_size:] for h in hyps[:batch_size]]
            )

    # Undo the length-sorting done by pack_padded_sequence.
    restore = packed.unsorted_indices.tolist()
    ans = [hyps[j][context_size:] for j in restore]
    ans_timestamps = [timestamps[j] for j in restore]
    ans_scores = [scores[j] for j in restore]

    if return_timestamps:
        return DecodingResults(
            hyps=ans,
            timestamps=ans_timestamps,
            scores=ans_scores,
        )
    return ans
@dataclass
class Hypothesis:
    """One partial decoding path kept alive during beam search."""

    # Token ids predicted so far; each new token is appended at the end.
    ys: List[int]

    # Log-probability of `ys`, stored as a tensor holding a single entry.
    log_prob: torch.Tensor

    # Optional per-token probabilities; keywords_search appends one value
    # for every emitted token.
    ac_probs: Optional[List[float]] = None

    # timestamp[i] is the frame index (after subsampling) on which the i-th
    # emitted token was decoded.
    timestamp: List[int] = field(default_factory=list)

    # LM score for the next token given the current `ys`.
    lm_score: Optional[torch.Tensor] = None

    # RNNLM states (h and c of an LSTM).
    state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

    # N-gram LM state.
    state_cost: Optional[NgramLmStateCost] = None

    # Position inside the context graph.
    context_state: Optional[ContextState] = None

    # Counter of blanks seen since the last emitted token (maintained by
    # keywords_search).
    num_tailing_blanks: int = 0

    @property
    def key(self) -> str:
        """Return `ys` rendered as a string with "_" between the token ids."""
        return "_".join(str(token) for token in self.ys)
class HypothesisList(object):
    """A set of hypotheses indexed by `Hypothesis.key`.

    Hypotheses that share the same token sequence are merged on insertion,
    so the container holds at most one entry per distinct sequence.
    """

    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
        """
        Args:
          data:
            A dict of Hypotheses. Its key is its `value.key`. The dict is
            used as-is (not copied). An empty dict is created when None.
        """
        if data is None:
            self._data = {}
        else:
            self._data = data

    @property
    def data(self) -> Dict[str, Hypothesis]:
        """The underlying key -> Hypothesis mapping."""
        return self._data

    def add(self, hyp: Hypothesis) -> None:
        """Add a Hypothesis to `self`.

        If a hypothesis with the same key already exists, the existing
        object is kept and its `log_prob` is updated in-place with the
        log-sum-exp of both probabilities. All other fields of `hyp` are
        discarded in that case.

        Args:
          hyp:
            The hypothesis to be added.
        """
        key = hyp.key
        if key in self:
            old_hyp = self._data[key]  # shallow copy
            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
        else:
            self._data[key] = hyp

    def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
        """Get the most probable hypothesis, i.e., the one with
        the largest `log_prob`.

        Args:
          length_norm:
            If True, the `log_prob` of a hypothesis is normalized by the
            number of tokens in it (`len(hyp.ys)`).
        Returns:
          Return the hypothesis that has the largest `log_prob`.
        """
        if length_norm:
            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
        else:
            return max(self._data.values(), key=lambda hyp: hyp.log_prob)

    def remove(self, hyp: Hypothesis) -> None:
        """Remove a given hypothesis.

        Caution:
          `self` is modified **in-place**.

        Args:
          hyp:
            The hypothesis to be removed from `self`.
            Note: It must be contained in `self`. Otherwise,
            an exception is raised.
        """
        key = hyp.key
        assert key in self, f"{key} does not exist"
        del self._data[key]

    def filter(self, threshold: torch.Tensor) -> "HypothesisList":
        """Keep only the hypotheses whose log_prob is strictly greater than
        `threshold`.

        Caution:
          `self` is not modified. Instead, a new HypothesisList is returned.

        Returns:
          Return a new HypothesisList containing all hypotheses from `self`
          with `log_prob` being greater than the given `threshold`. The
          Hypothesis objects themselves are shared with `self`.
        """
        ans = HypothesisList()
        for _, hyp in self._data.items():
            if hyp.log_prob > threshold:
                ans.add(hyp)  # shallow copy
        return ans

    def topk(self, k: int, length_norm: bool = False) -> "HypothesisList":
        """Return the top-k hypothesis.

        Args:
          k:
            Maximum number of hypotheses to keep.
          length_norm:
            If True, the `log_prob` of a hypothesis is normalized by the
            number of tokens in it.
        Returns:
          A new HypothesisList holding the (at most) `k` best hypotheses,
          best first.
        """
        hyps = list(self._data.items())

        if length_norm:
            hyps = sorted(
                hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True
            )[:k]
        else:
            hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]

        ans = HypothesisList(dict(hyps))
        return ans

    def __contains__(self, key: str):
        """Membership is tested on the string key, not on the object."""
        return key in self._data

    def __iter__(self):
        """Iterate over the stored Hypothesis objects (not over the keys)."""
        return iter(self._data.values())

    def __len__(self) -> int:
        return len(self._data)

    def __str__(self) -> str:
        # Join the keys of the dict directly. Iterating `self` yields
        # Hypothesis objects (see __iter__), which str.join() rejects with
        # a TypeError.
        return ", ".join(self._data)
def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
    """Build the ragged shape [utt][num_hyps] describing `hyps`.

    Args:
      hyps:
        One HypothesisList per utterance, i.e. len(hyps) == batch_size.
        Each entry holds the hypotheses currently alive for that utterance.
    Returns:
      A ragged shape with 2 axes [utt][num_hyps]. The row splits are built
      from a tensor created without a device argument, i.e. on CPU.
    """
    # row_splits must start at 0. torch.cumsum() is an inclusive sum, so a
    # leading 0 is placed in front of the per-utterance counts.
    counts = torch.tensor([0] + [len(h) for h in hyps])
    row_splits = counts.cumsum(dim=0, dtype=torch.int32)
    return k2.ragged.create_ragged_shape2(
        row_splits=row_splits, cached_tot_size=row_splits[-1].item()
    )
def keywords_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    keywords_graph: ContextGraph,
    beam: int = 4,
    num_tailing_blanks: int = 0,
    blank_penalty: float = 0,
) -> List[List[KeywordResult]]:
    """Keyword spotting via beam search in batch mode, with
    --max-sym-per-frame=1 being hardcoded.

    On every frame the length-normalized best hypothesis of each utterance
    is checked against `keywords_graph`. When it matches a keyword, the
    keyword is recorded and the beam of that utterance is reset to a single
    empty hypothesis at the graph root.

    Args:
      model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C).
      encoder_out_lens:
        A 1-D tensor of shape (N,), containing number of valid frames in
        encoder_out before padding. Every entry must be positive.
      keywords_graph:
        A instance of ContextGraph containing keywords and their configurations.
        Must not be None.
      beam:
        Number of active paths during the beam search.
      num_tailing_blanks:
        The number of tailing blanks a keyword should be followed, this is for the
        scenario that a keyword will be the prefix of another. In most cases, you
        can just set it to 0.
      blank_penalty:
        The score used to penalize blank probability. It is subtracted from
        the logit at vocabulary index 0 before the softmax.
    Returns:
      Return a list of list of KeywordResult, one inner list per utterance,
      in the order of the input batch.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert encoder_out.size(0) >= 1, encoder_out.size(0)
    assert keywords_graph is not None

    # Packing orders utterances by decreasing length, so the utterances
    # still active at frame t are always the first `batch_size` entries.
    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    blank_id = model.decoder.blank_id
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size
    device = next(model.parameters()).device

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    # B[i] holds the live hypotheses of the i-th utterance (sorted order).
    # Each starts with one empty hypothesis positioned at the graph root.
    B = [HypothesisList() for _ in range(N)]
    for i in range(N):
        B[i].add(
            Hypothesis(
                ys=[-1] * (context_size - 1) + [blank_id],
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                context_state=keywords_graph.root,
                timestamp=[],
                ac_probs=[],
            )
        )

    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    finalized_B = []
    # sorted_ans[i] collects the keywords found in the i-th utterance
    # (sorted order).
    sorted_ans = [[] for _ in range(N)]
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
        offset = end

        # Utterances that ended before this frame are moved aside;
        # prepending keeps finalized_B in sorted-batch order.
        finalized_B = B[batch_size:] + finalized_B
        B = B[:batch_size]

        hyps_shape = get_hyps_shape(B).to(device)

        # A is the snapshot being expanded; B is rebuilt from scratch.
        A = [list(b) for b in B]

        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        # Repeat each utterance's frame once per hypothesis of that utterance.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
            project_input=False,
        )  # (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

        if blank_penalty != 0:
            logits[:, 0] -= blank_penalty

        # Plain probabilities are kept as well: they feed the per-token
        # `ac_probs` that are later compared with a keyword's ac_threshold.
        probs = logits.softmax(dim=-1)  # (num_hyps, vocab_size)

        log_probs = probs.log()

        probs = probs.reshape(-1)

        # Accumulate each path's score. `probs` is deliberately left as the
        # per-frame posterior.
        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        log_probs = log_probs.reshape(-1)

        # Each hypothesis contributes vocab_size candidates, so the
        # per-utterance row splits are scaled by vocab_size.
        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
        ragged_probs = k2.RaggedTensor(shape=log_probs_shape, value=probs)

        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
            hyp_probs = ragged_probs[i].tolist()

            # A flat index encodes (hypothesis, token) as
            # hyp_idx * vocab_size + token.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]
                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                new_timestamp = hyp.timestamp[:]
                new_ac_probs = hyp.ac_probs[:]
                context_score = 0
                new_context_state = hyp.context_state
                # Assume a blank; reset to 0 below if a token is emitted.
                new_num_tailing_blanks = hyp.num_tailing_blanks + 1
                if new_token not in (blank_id, unk_id):
                    new_ys.append(new_token)
                    new_timestamp.append(t)
                    new_ac_probs.append(hyp_probs[topk_indexes[k]])
                    (
                        context_score,
                        new_context_state,
                        _,
                    ) = keywords_graph.forward_one_step(hyp.context_state, new_token)
                    new_num_tailing_blanks = 0
                    # Back at the root: overwrite the decoder context with
                    # the initial (empty) context.
                    if new_context_state.token == -1:  # root
                        new_ys[-context_size:] = [-1] * (context_size - 1) + [blank_id]

                new_log_prob = topk_log_probs[k] + context_score

                new_hyp = Hypothesis(
                    ys=new_ys,
                    log_prob=new_log_prob,
                    timestamp=new_timestamp,
                    ac_probs=new_ac_probs,
                    context_state=new_context_state,
                    num_tailing_blanks=new_num_tailing_blanks,
                )
                B[i].add(new_hyp)

            top_hyp = B[i].get_most_probable(length_norm=True)
            matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
            if matched:
                # Mean acoustic probability over the last
                # `matched_state.level` emitted tokens.
                ac_prob = (
                    sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
                )
            # `ac_prob` is only defined when `matched` is True; the `and`
            # chain below short-circuits before reading it otherwise.
            if (
                matched
                and top_hyp.num_tailing_blanks > num_tailing_blanks
                and ac_prob >= matched_state.ac_threshold
            ):
                keyword = KeywordResult(
                    hyps=top_hyp.ys[-matched_state.level :],
                    timestamps=top_hyp.timestamp[-matched_state.level :],
                    phrase=matched_state.phrase,
                )
                sorted_ans[i].append(keyword)
                # Restart the search for this utterance from scratch.
                B[i] = HypothesisList()
                B[i].add(
                    Hypothesis(
                        ys=[-1] * (context_size - 1) + [blank_id],
                        log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                        context_state=keywords_graph.root,
                        timestamp=[],
                        ac_probs=[],
                    )
                )

    B = B + finalized_B

    # Final check at the end of every utterance. Unlike the per-frame
    # check above, no trailing-blank requirement is applied here.
    for i, hyps in enumerate(B):
        top_hyp = hyps.get_most_probable(length_norm=True)
        matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
        if matched:
            ac_prob = (
                sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
            )
        if matched and ac_prob >= matched_state.ac_threshold:
            keyword = KeywordResult(
                hyps=top_hyp.ys[-matched_state.level :],
                timestamps=top_hyp.timestamp[-matched_state.level :],
                phrase=matched_state.phrase,
            )
            sorted_ans[i].append(keyword)

    # Undo the length-sorting done by pack_padded_sequence.
    ans = []
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])
    return ans
def modified_beam_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    context_graph: Optional[ContextGraph] = None,
    beam: int = 4,
    temperature: float = 1.0,
    blank_penalty: float = 0.0,
    return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

    Args:
      model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C).
      encoder_out_lens:
        A 1-D tensor of shape (N,), containing number of valid frames in
        encoder_out before padding. Every entry must be positive.
      context_graph:
        Optional ContextGraph used for contextual biasing. When given, the
        score it returns for every emitted token is added to the path
        score, and the graph is finalized once decoding is done.
      beam:
        Number of active paths during the beam search.
      temperature:
        Softmax temperature.
      blank_penalty:
        Amount subtracted from the logit at vocabulary index 0 before the
        log-softmax. 0 disables the penalty.
      return_timestamps:
        Whether to return timestamps.
    Returns:
      If return_timestamps is False, return the decoded result.
      Else, return a DecodingResults object containing
      decoded result and corresponding timestamps.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    # Packing orders utterances by decreasing length, so the utterances
    # still active at frame t are always the first `batch_size` entries.
    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    blank_id = model.decoder.blank_id
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size
    device = next(model.parameters()).device

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    # B[i] holds the live hypotheses of the i-th utterance (sorted order).
    B = [HypothesisList() for _ in range(N)]
    for i in range(N):
        B[i].add(
            Hypothesis(
                ys=[-1] * (context_size - 1) + [blank_id],
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                context_state=None if context_graph is None else context_graph.root,
                timestamp=[],
            )
        )

    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    finalized_B = []
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
        offset = end

        # Utterances that ended before this frame are moved aside;
        # prepending keeps finalized_B in sorted-batch order.
        finalized_B = B[batch_size:] + finalized_B
        B = B[:batch_size]

        hyps_shape = get_hyps_shape(B).to(device)

        # A is the snapshot being expanded; B is rebuilt from scratch.
        A = [list(b) for b in B]

        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
            project_input=False,
        )  # (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

        if blank_penalty != 0:
            logits[:, 0] -= blank_penalty

        log_probs = (logits / temperature).log_softmax(dim=-1)  # (num_hyps, vocab_size)

        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        log_probs = log_probs.reshape(-1)

        # Each hypothesis contributes vocab_size candidates, so the
        # per-utterance row splits are scaled by vocab_size.
        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)

        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            # A flat index encodes (hypothesis, token) as
            # hyp_idx * vocab_size + token.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]
                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                new_timestamp = hyp.timestamp[:]
                context_score = 0
                new_context_state = None if context_graph is None else hyp.context_state
                if new_token not in (blank_id, unk_id):
                    new_ys.append(new_token)
                    new_timestamp.append(t)
                    if context_graph is not None:
                        # keywords_search in this file unpacks three values
                        # from the same ContextGraph method. The starred
                        # target accepts both the two- and the three-value
                        # form, so contextual biasing no longer fails with
                        # "too many values to unpack".
                        (
                            context_score,
                            new_context_state,
                            *_,
                        ) = context_graph.forward_one_step(hyp.context_state, new_token)

                new_log_prob = topk_log_probs[k] + context_score

                new_hyp = Hypothesis(
                    ys=new_ys,
                    log_prob=new_log_prob,
                    timestamp=new_timestamp,
                    context_state=new_context_state,
                )
                B[i].add(new_hyp)

    B = B + finalized_B

    # finalize context_state, if the matched contexts do not reach final state
    # we need to add the score on the corresponding backoff arc
    if context_graph is not None:
        finalized_B = [HypothesisList() for _ in range(len(B))]
        for i, hyps in enumerate(B):
            for hyp in list(hyps):
                context_score, new_context_state = context_graph.finalize(
                    hyp.context_state
                )
                finalized_B[i].add(
                    Hypothesis(
                        ys=hyp.ys,
                        log_prob=hyp.log_prob + context_score,
                        timestamp=hyp.timestamp,
                        context_state=new_context_state,
                    )
                )
        B = finalized_B

    best_hyps = [b.get_most_probable(length_norm=True) for b in B]

    sorted_ans = [h.ys[context_size:] for h in best_hyps]
    sorted_timestamps = [h.timestamp for h in best_hyps]

    # Undo the length-sorting done by pack_padded_sequence.
    ans = []
    ans_timestamps = []
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])
        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])

    if not return_timestamps:
        return ans
    else:
        return DecodingResults(
            hyps=ans,
            timestamps=ans_timestamps,
        )
def modified_beam_search_lm_rescore(
    model: nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    LM: LmScorer,
    lm_scale_list: List[float],
    beam: int = 4,
    temperature: float = 1.0,
    return_timestamps: bool = False,
) -> Dict[str, List[List[int]]]:
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
    Rescore the final results with RNNLM and return the one with the highest score

    Args:
      model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C).
      encoder_out_lens:
        A 1-D tensor of shape (N,), containing number of valid frames in
        encoder_out before padding. Every entry must be positive.
      LM:
        A neural network language model
      lm_scale_list:
        The LM scales to try. One result set is produced per scale.
      beam:
        Number of active paths during the beam search.
      temperature:
        Softmax temperature.
      return_timestamps:
        Not used by this function: no timestamps are returned.
    Returns:
      A dict with one entry per value in `lm_scale_list`. The key is
      f"nnlm_scale_{lm_scale:.2f}" and the value is, for every utterance in
      the order of the input batch, the token ids of the hypothesis whose
      combined score is highest for that scale.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    # Packing orders utterances by decreasing length, so the utterances
    # still active at frame t are always the first `batch_size` entries.
    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    blank_id = model.decoder.blank_id
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size
    device = next(model.parameters()).device

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    # B[i] holds the live hypotheses of the i-th utterance (sorted order).
    B = [HypothesisList() for _ in range(N)]
    for i in range(N):
        B[i].add(
            Hypothesis(
                ys=[-1] * (context_size - 1) + [blank_id],
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                timestamp=[],
            )
        )

    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    finalized_B = []
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
        offset = end

        # Utterances that ended before this frame are moved aside;
        # prepending keeps finalized_B in sorted-batch order.
        finalized_B = B[batch_size:] + finalized_B
        B = B[:batch_size]

        hyps_shape = get_hyps_shape(B).to(device)

        # A is the snapshot being expanded; B is rebuilt from scratch.
        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
            project_input=False,
        )  # (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

        log_probs = (logits / temperature).log_softmax(dim=-1)  # (num_hyps, vocab_size)

        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        log_probs = log_probs.reshape(-1)

        # Each hypothesis contributes vocab_size candidates, so the
        # per-utterance row splits are scaled by vocab_size.
        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)

        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            # A flat index encodes (hypothesis, token) as
            # hyp_idx * vocab_size + token.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                new_timestamp = hyp.timestamp[:]
                if new_token not in (blank_id, unk_id):
                    new_ys.append(new_token)
                    new_timestamp.append(t)

                new_log_prob = topk_log_probs[k]
                new_hyp = Hypothesis(
                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
                )
                B[i].add(new_hyp)

    B = B + finalized_B

    # get the am_scores for n-best list
    hyps_shape = get_hyps_shape(B)
    am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b])
    am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device)

    # now LM rescore
    # prepare input data to LM
    # candidate_seqs is flat: all hypotheses of utterance 0, then of
    # utterance 1, ... i.e. the same order as am_scores.values.
    candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b]
    possible_seqs = k2.RaggedTensor(candidate_seqs)
    row_splits = possible_seqs.shape.row_splits(1)
    sentence_token_lengths = row_splits[1:] - row_splits[:-1]
    # NOTE(review): id 1 is hardcoded for both SOS and EOS; confirm this
    # matches the vocabulary the LM was trained with.
    possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1)
    possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1)
    # account for the extra SOS/EOS token
    sentence_token_lengths += 1

    x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id)
    y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id)
    x = x.to(device).to(torch.int64)
    y = y.to(device).to(torch.int64)
    sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64)

    lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
    assert lm_scores.ndim == 2
    # presumably LM.lm returns per-token negative log-likelihoods, so the
    # negated sum is a sentence log-probability; TODO confirm
    lm_scores = -1 * lm_scores.sum(dim=1)

    ans = {}
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()

    # get the best hyp with different lm_scale
    for lm_scale in lm_scale_list:
        key = f"nnlm_scale_{lm_scale:.2f}"
        tot_scores = am_scores.values + lm_scores * lm_scale
        ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores)
        # one index per utterance, used to index the flat candidate_seqs
        max_indexes = ragged_tot_scores.argmax().tolist()
        unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes]
        # Undo the length-sorting done by pack_padded_sequence.
        hyps = []
        for idx in unsorted_indices:
            hyps.append(unsorted_hyps[idx])

        ans[key] = hyps
    return ans
+ encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + LM: + A neural network language model + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + 
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis( + ys=new_ys, 
log_prob=new_log_prob, timestamp=new_timestamp + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # get the am_scores for n-best list + hyps_shape = get_hyps_shape(B) + am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) + am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) + + # now LM rescore + # prepare input data to LM + candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] + possible_seqs = k2.RaggedTensor(candidate_seqs) + row_splits = possible_seqs.shape.row_splits(1) + sentence_token_lengths = row_splits[1:] - row_splits[:-1] + possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) + possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) + sentence_token_lengths += 1 + + x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) + y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) + x = x.to(device).to(torch.int64) + y = y.to(device).to(torch.int64) + sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) + + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + assert lm_scores.ndim == 2 + lm_scores = -1 * lm_scores.sum(dim=1) + + # now LODR scores + import math + + LODR_scores = [] + for seq in candidate_seqs: + tokens = " ".join(sp.id_to_piece(seq)) + LODR_scores.append(LODR_lm.score(tokens)) + LODR_scores = torch.tensor(LODR_scores).to(device) * math.log( + 10 + ) # arpa scores are 10-based + assert lm_scores.shape == LODR_scores.shape + + ans = {} + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + + LODR_scale_list = [0.05 * i for i in range(1, 20)] + # get the best hyp with different lm_scale and lodr_scale + for lm_scale in lm_scale_list: + for lodr_scale in LODR_scale_list: + key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}" + tot_scores = ( + am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale + ) + ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) + 
max_indexes = ragged_tot_scores.argmax().tolist() + unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] + hyps = [] + for idx in unsorted_indices: + hyps.append(unsorted_hyps[idx]) + + ans[key] = hyps + return ans + + +def _deprecated_modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + beam: int = 4, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. + + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + return_timestamps: + Whether to return timestamps. + + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
+ """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + device = next(model.parameters()).device + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, 1, -1 + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + # now logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + with warnings.catch_warnings(): + 
warnings.simplefilter("ignore") + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_timestamp = hyp.timestamp[:] + new_token = topk_token_indexes[i] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + if not return_timestamps: + return ys + else: + return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) + + +def beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + beam: int = 4, + temperature: float = 1.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: + """ + It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf + + espnet/nets/beam_search_transducer.py#L247 is used as a reference. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
+ """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + device = next(model.parameters()).device + + decoder_input = torch.tensor( + [blank_id] * context_size, + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + + B = HypothesisList() + B.add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], log_prob=0.0, timestamp=[] + ) + ) + + max_sym_per_utt = 20000 + + sym_per_utt = 0 + + decoder_cache: Dict[str, torch.Tensor] = {} + + while t < T and sym_per_utt < max_sym_per_utt: + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + A = B + B = HypothesisList() + + joint_cache: Dict[str, torch.Tensor] = {} + + # TODO(fangjun): Implement prefix search to update the `log_prob` + # of hypotheses in A + + while True: + y_star = A.get_most_probable() + A.remove(y_star) + + cached_key = y_star.key + + if cached_key not in decoder_cache: + decoder_input = torch.tensor( + [y_star.ys[-context_size:]], + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + decoder_cache[cached_key] = decoder_out + else: + decoder_out = decoder_cache[cached_key] + + cached_key += f"-t-{t}" + if cached_key not in joint_cache: + logits = model.joiner( + current_encoder_out, + decoder_out.unsqueeze(1), + project_input=False, + ) + + if blank_penalty != 0: + logits[:, :, :, 0] -= blank_penalty + + # TODO(fangjun): Scale the blank posterior + log_prob = (logits / temperature).log_softmax(dim=-1) + # 
log_prob is (1, 1, 1, vocab_size) + log_prob = log_prob.squeeze() + # Now log_prob is (vocab_size,) + joint_cache[cached_key] = log_prob + else: + log_prob = joint_cache[cached_key] + + # First, process the blank symbol + skip_log_prob = log_prob[blank_id] + new_y_star_log_prob = y_star.log_prob + skip_log_prob + + # ys[:] returns a copy of ys + B.add( + Hypothesis( + ys=y_star.ys[:], + log_prob=new_y_star_log_prob, + timestamp=y_star.timestamp[:], + ) + ) + + # Second, process other non-blank labels + values, indices = log_prob.topk(beam + 1) + for i, v in zip(indices.tolist(), values.tolist()): + if i in (blank_id, unk_id): + continue + new_ys = y_star.ys + [i] + new_log_prob = y_star.log_prob + v + new_timestamp = y_star.timestamp + [t] + A.add( + Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + ) + ) + + # Check whether B contains more than "beam" elements more probable + # than the most probable in A + A_most_probable = A.get_most_probable() + + kept_B = B.filter(A_most_probable.log_prob) + + if len(kept_B) >= beam: + B = kept_B.topk(beam) + break + + t += 1 + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + if not return_timestamps: + return ys + else: + return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) + + +def fast_beam_search_with_nbest_rescoring( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: + """It limits the maximum number of symbols per frame to 1. 
+ A lattice is first obtained using fast beam search, num_path are selected + and rescored using a given language model. The shortest path within the + lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. 
+ + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} + for s in ngram_lm_scale_list: + key = f"ngram_lm_scale_{s}" + tot_scores = am_scores.values + s * ngram_lm_scores + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = 
k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) + + return ans + + +def fast_beam_search_with_nbest_rnn_rescoring( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + rnn_lm_model: torch.nn.Module, + rnn_lm_scale_list: List[float], + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: + """It limits the maximum number of symbols per frame to 1. + A lattice is first obtained using fast beam search, num_path are selected + and rescored using a given language model and a rnn-lm. + The shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + rnn_lm_model: + A rnn-lm model used for LM rescoring + rnn_lm_scale_list: + A list of floats representing RNN score scales. + oov_word: + OOV words are replaced with this word. 
+ use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. 
We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + # Now RNN-LM + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("sos_id") + eos_id = sp.piece_to_id("eos_id") + + sos_tokens = add_sos(tokens, sos_id) + tokens_eos = add_eos(tokens, eos_id) + sos_tokens_row_splits = sos_tokens.shape.row_splits(1) + sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + + x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) + y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) + + x_tokens = x_tokens.to(torch.int64) + y_tokens = 
y_tokens.to(torch.int64) + sentence_lengths = sentence_lengths.to(torch.int64) + + rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) + assert rnn_lm_nll.ndim == 2 + assert rnn_lm_nll.shape[0] == len(token_list) + rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) + + ans: Dict[str, List[List[int]]] = {} + for n_scale in ngram_lm_scale_list: + for rnn_scale in rnn_lm_scale_list: + key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" + tot_scores = ( + am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) + + return ans + + +def modified_beam_search_ngram_rescoring( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ngram_lm: NgramLm, + ngram_lm_scale: float, + beam: int = 4, + temperature: float = 1.0, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. 
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + lm_scale = ngram_lm_scale + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state_cost=NgramLmStateCost(ngram_lm), + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [ + hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale + for hyps in A + for hyp in hyps + ] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape 
(num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + vocab_size = log_probs.size(-1) + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + state_cost = hyp.state_cost.forward_one_step(new_token) + else: + state_cost = hyp.state_cost + + # We only keep AM scores in new_hyp.log_prob + new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, state_cost=state_cost + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + 
for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +def modified_beam_search_LODR( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LODR_lm: NgramLm, + LODR_lm_scale: float, + LM: LmScorer, + beam: int = 4, + context_graph: Optional[ContextGraph] = None, +) -> List[List[int]]: + """This function implements LODR (https://arxiv.org/abs/2203.16776) with + `modified_beam_search`. It uses a bi-gram language model as the estimate + of the internal language model and subtracts its score during shallow fusion + with an external language model. This implementation uses a RNNLM as the + external language model. + + Args: + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + LODR_lm: + A low order n-gram LM, whose score will be subtracted during shallow fusion + LODR_lm_scale: + The scale of the LODR_lm + LM: + A neural net LM, e.g an RNNLM or transformer LM + beam (int, optional): + Beam size. Defaults to 4. + + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. 
+ + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert LM is not None + lm_scale = LM.lm_scale + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + sos_id = getattr(LM, "sos_id", 1) + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state=init_states, # state of the NN LM + lm_score=init_score.reshape(-1), + state_cost=NgramLmStateCost( + LODR_lm + ), # state of the source domain ngram + context_state=None if context_graph is None else context_graph.root, + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in 
range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + """ + for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + LM will score those tokens given the LM states. Note that + the variable `scores` is the LM score after seeing the new + non-blank token. 
+ """ + token_list = [] + hs = [] + cs = [] + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) + + # forward NN LM to get new states and scores + if len(token_list) != 0: + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) + + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + ys = hyp.ys[:] + + # current score of hyp + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = 
topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] + + context_score = 0 + new_context_state = None if context_graph is None else hyp.context_state + if new_token not in (blank_id, unk_id): + if context_graph is not None: + ( + context_score, + new_context_state, + ) = context_graph.forward_one_step(hyp.context_state, new_token) + + ys.append(new_token) + state_cost = hyp.state_cost.forward_one_step(new_token) + + # calculate the score of the latest token + current_ngram_score = state_cost.lm_score - hyp.state_cost.lm_score + + assert current_ngram_score <= 0.0, ( + state_cost.lm_score, + hyp.state_cost.lm_score, + ) + # score = score + TDLM_score - LODR_score + # LODR_LM_scale should be a negative number here + hyp_log_prob += ( + lm_score[new_token] * lm_scale + + LODR_lm_scale * current_ngram_score + + context_score + ) # add the lm score + + lm_score = scores[count] + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) + count += 1 + else: + state_cost = hyp.state_cost + + new_hyp = Hypothesis( + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score, + state_cost=state_cost, + context_state=new_context_state, + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # finalize context_state, if the matched contexts do not reach final state + # we need to add the score on the corresponding backoff arc + if context_graph is not None: + finalized_B = [HypothesisList() for _ in range(len(B))] + for i, hyps in enumerate(B): + for hyp in list(hyps): + context_score, new_context_state = context_graph.finalize( + hyp.context_state + ) + finalized_B[i].add( + Hypothesis( + ys=hyp.ys, + log_prob=hyp.log_prob + context_score, + timestamp=hyp.timestamp, + context_state=new_context_state, + ) + ) + B = finalized_B + + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices 
def modified_beam_search_lm_shallow_fusion(
    model: nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    LM: LmScorer,
    beam: int = 4,
    return_timestamps: bool = False,
) -> List[List[int]]:
    """Modified beam search with neural-network LM shallow fusion.

    At every frame the top-`beam` (hypothesis, token) pairs of each utterance
    are selected from the transducer scores. Every non-blank/non-unk
    extension is then scored by `LM`, and the LM score of the emitted token
    (scaled by `LM.lm_scale`) is added to the hypothesis score.

    Args:
      model (Transducer):
        The transducer model. Its `decoder` and `joiner` are used.
      encoder_out (torch.Tensor):
        Encoder output in (N,T,C)
      encoder_out_lens (torch.Tensor):
        A 1-D tensor of shape (N,), containing the number of
        valid frames in encoder_out before padding. All entries must be > 0.
      LM (LmScorer):
        A neural net LM, e.g RNN or Transformer. `LM.lm_type == "rnn"`
        selects the stateful (LSTM) scoring path; any other value re-scores
        the whole token history, as for a transformer LM.
      beam (int, optional):
        Beam size. Defaults to 4.
      return_timestamps (bool, optional):
        If True, return a `DecodingResults` that also carries, per token,
        the index of the frame it was emitted at. Defaults to False.

    Returns:
      If `return_timestamps` is False, a list-of-list of token IDs where
      ans[i] is the decoding result for the i-th utterance (in the input
      order). Otherwise a `DecodingResults` with `hyps` and `timestamps`.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert encoder_out.size(0) >= 1, encoder_out.size(0)
    assert LM is not None
    lm_scale = LM.lm_scale

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    blank_id = model.decoder.blank_id
    sos_id = getattr(LM, "sos_id", 1)
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size
    device = next(model.parameters()).device

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    # get initial lm score and lm state by scoring the "sos" token
    sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
    lens = torch.tensor([1]).to(device)
    init_score, init_states = LM.score_token(sos_token, lens)

    B = [HypothesisList() for _ in range(N)]
    for i in range(N):
        B[i].add(
            Hypothesis(
                ys=[-1] * (context_size - 1) + [blank_id],
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                state=init_states,
                lm_score=init_score.reshape(-1),
                timestamp=[],
            )
        )

    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    finalized_B = []
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]  # get batch
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
        offset = end

        # Utterances that have run out of frames are moved to finalized_B.
        finalized_B = B[batch_size:] + finalized_B
        B = B[:batch_size]

        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
        )

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)

        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
            project_input=False,
        )  # (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)

        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        log_probs = log_probs.reshape(-1)

        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)

        # First pass: for all hyps with a non-blank new token, collect this
        # token so that `LM` can score it given the LM state of its hyp.
        # The top-k selection of every utterance is cached here and reused by
        # the second pass below; the second pass indexes the LM outputs with
        # a running counter, so both passes must see the candidates in
        # exactly the same order.
        token_list = []  # a list of list
        hs = []
        cs = []
        topk_per_utt = []
        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()
            topk_per_utt.append((topk_log_probs, topk_hyp_indexes, topk_token_indexes))

            for hyp_idx, new_token in zip(topk_hyp_indexes, topk_token_indexes):
                hyp = A[i][hyp_idx]

                if new_token not in (blank_id, unk_id):
                    if LM.lm_type == "rnn":
                        token_list.append([new_token])
                        # store the LSTM states
                        hs.append(hyp.state[0])
                        cs.append(hyp.state[1])
                    else:
                        # for transformer LM
                        token_list.append(
                            [sos_id] + hyp.ys[context_size:] + [new_token]
                        )

        # forward NN LM to get new states and scores; `scores` is the LM
        # score after seeing the new non-blank token.
        if len(token_list) != 0:
            x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
            if LM.lm_type == "rnn":
                tokens_to_score = (
                    torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
                )
                hs = torch.cat(hs, dim=1).to(device)
                cs = torch.cat(cs, dim=1).to(device)
                state = (hs, cs)
            else:
                # for transformer LM
                tokens_list = [torch.tensor(tokens) for tokens in token_list]
                tokens_to_score = (
                    torch.nn.utils.rnn.pad_sequence(
                        tokens_list, batch_first=True, padding_value=0.0
                    )
                    .to(device)
                    .to(torch.int64)
                )

                state = None

            scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)

        # Second pass: build the new hypotheses.
        count = 0  # index, used to locate score and lm states
        for i in range(batch_size):
            topk_log_probs, topk_hyp_indexes, topk_token_indexes = topk_per_utt[i]

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                ys = hyp.ys[:]

                lm_score = hyp.lm_score
                state = hyp.state

                hyp_log_prob = topk_log_probs[k]  # get score of current hyp
                new_token = topk_token_indexes[k]
                new_timestamp = hyp.timestamp[:]
                if new_token not in (blank_id, unk_id):
                    ys.append(new_token)
                    new_timestamp.append(t)

                    hyp_log_prob += lm_score[new_token] * lm_scale  # add the lm score

                    lm_score = scores[count]
                    if LM.lm_type == "rnn":
                        state = (
                            lm_states[0][:, count, :].unsqueeze(1),
                            lm_states[1][:, count, :].unsqueeze(1),
                        )
                    count += 1

                new_hyp = Hypothesis(
                    ys=ys,
                    log_prob=hyp_log_prob,
                    state=state,
                    lm_score=lm_score,
                    timestamp=new_timestamp,
                )
                B[i].add(new_hyp)

    B = B + finalized_B
    best_hyps = [b.get_most_probable(length_norm=True) for b in B]

    sorted_ans = [h.ys[context_size:] for h in best_hyps]
    sorted_timestamps = [h.timestamp for h in best_hyps]
    ans = []
    ans_timestamps = []
    # pack_padded_sequence sorted the utterances by length; undo that here.
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])
        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])

    if not return_timestamps:
        return ans
    else:
        return DecodingResults(
            hyps=ans,
            timestamps=ans_timestamps,
        )
import logging
import math
from typing import List

import torch
from torch import Tensor
from torch.optim import Optimizer


class CombinedLRScheduler(object):
    """
    Base-class for learning rate schedulers where the learning-rate depends on both the
    batch and the epoch; it estimates the "progress" for you based on the epoch you are
    in and the estimated progress within the epoch based on
    the number of steps within the epoch.  The interface is as follows; suppose
    you're using CosineLRScheduler that inherits from this (below).  batches_per_epoch
    is your best guess at how many batches you will have per epoch; if you get this
    wrong there will be a discontinuity in the learning rate as you start the second
    epoch.

      num_epochs = 20
      scheduler = CosineLRScheduler(optimizer, batches_per_epoch=2512, num_epochs=num_epochs)
      for epoch in range(1, num_epochs + 1):
         scheduler.set_epoch(epoch)  # caution: one-based epoch count
         for batch_idx, batch in enumerate(train_dl):
            scheduler.set_batch(batch_idx)

    Args:
      optimizer: optimizer that we will set the learning rates in; the initial learning rate(s) in
          the optimizer is/are the base LRs and we set the LR as a fraction of those.
      batches_per_epoch: the estimated number of batches per epoch; use your best guess.
      num_epochs: the total number of epochs you will train for
      verbose: if True, log every learning-rate change (see print_lr()).
    """

    # The entries written by state_dict(); load_state_dict() restores only these.
    _STATE_KEYS = ("epoch", "batch", "batches_per_epoch")

    def __init__(self,
                 optimizer: Optimizer,
                 batches_per_epoch: int,
                 num_epochs: int,
                 verbose: bool = False):
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer
        self.verbose = verbose

        for group in optimizer.param_groups:
            group.setdefault("base_lr", group["lr"])

        self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]

        self.batches_per_epoch = batches_per_epoch
        self.num_epochs = num_epochs  # the number of epochs we plan to train for.

        self.epoch = -1
        self.batch = -1

        # Nothing has been computed yet, so report the LRs currently in the
        # optimizer; this makes get_last_lr() usable before the first
        # set_epoch()/set_batch() call instead of raising AttributeError.
        self._last_lr = [group["lr"] for group in optimizer.param_groups]

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains the epoch, the within-epoch batch index and
        batches_per_epoch.
        """
        return {
            # the user might try to override the base_lr, so don't include this in the state.
            # previously they were included.
            # "base_lrs": self.base_lrs,
            "epoch": self.epoch,
            "batch": self.batch,
            # Caution: storing batches_per_epoch with the state might not necessarily be what you want,
            # it's good for interrupted training runs only as long as you continue to train with the
            # same world-size.
            "batches_per_epoch": self.batches_per_epoch,
        }

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Only the entries that :meth:`state_dict` writes are restored; any
        other key is ignored.  In particular a "base_lrs" entry from an older
        checkpoint does not overwrite the base LRs taken from the optimizer,
        so the user can still override the base LR when resuming.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        for key in self._STATE_KEYS:
            if key in state_dict:
                setattr(self, key, state_dict[key])

    def get_last_lr(self) -> List[float]:
        """Return last computed learning rate by current scheduler.  Will be a list of float."""
        return self._last_lr

    def get_lr(self):
        # Compute list of learning rates from self.epoch and self.batch and
        # self.base_lrs; this must be overloaded by the user.
        # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
        raise NotImplementedError

    def set_batch(self, batch: int):
        """ Sets the batch index within the epoch, with zero-based counting (not that this matters much)."""
        # set the within-epoch batch index.
        self.batch = batch
        self._set_lrs()

    def set_epoch(self, epoch: int):
        """ Sets the epoch with one-based counting, so the first epoch is 1; the epoch should not exceed the num_epochs used
        in the constructor. """
        assert epoch > 0 and epoch <= self.num_epochs  # Epoch numbers are assumed to be 1-based indexes.
        if epoch == self.epoch + 1 and self.batch > 0:
            # We just finished an epoch, so we know how many batches it really had.
            logging.info(f"Overriding batches_per_epoch from {self.batches_per_epoch} to {self.batch+1} based on observed batch count.")
            self.batches_per_epoch = self.batch + 1

        if epoch != self.epoch:
            # Entering a different epoch: forget the batch index of the old
            # one.  Otherwise the LR set below would be computed as if we
            # were already at the end of the new epoch, until the next
            # set_batch() call.  (When the epoch is unchanged, e.g. resuming
            # mid-epoch from a checkpoint, the batch index is kept.)
            self.batch = -1

        self.epoch = epoch
        self._set_lrs()

    def get_progress(self):
        """Return the training progress as a float in [0, 1].

        It is 0.0 before the first set_epoch() call; otherwise it is the
        fraction of completed epochs plus the estimated fraction of the
        current epoch (the latter capped at one full epoch).
        """
        if self.epoch <= 0:
            return 0.0
        else:
            assert self.epoch <= self.num_epochs
            assert self.batches_per_epoch > 0
            whole_epoch_progress = (self.epoch - 1) / self.num_epochs
            batch = self.batch
            if batch < 0:
                partial_epoch_progress = 0
            else:
                partial_epoch_progress = min(1.0, batch / self.batches_per_epoch) / self.num_epochs
            return whole_epoch_progress + partial_epoch_progress

    def _set_lrs(self):
        """Compute the LRs with get_lr() and write them into the optimizer."""
        values = self.get_lr()
        assert len(values) == len(self.optimizer.param_groups)

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group["lr"] = lr
            self.print_lr(self.verbose, i, lr)
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, is_verbose, group, lr):
        """Display the current learning rate."""
        if is_verbose:
            logging.warning(
                f"Epoch={self.epoch}, batch={self.batch}, num_epochs={self.num_epochs}, batches_per_epoch={self.batches_per_epoch}: adjusting learning rate"
                f" of group {group} to {lr:.4e}."
            )


class CosineLRScheduler(CombinedLRScheduler):
    def __init__(self,
                 *args,
                 max_factor: float = 0.95,  # it will start the cosine schedule from the point where it would have this, but renormalize so initial factor is 1;
                 min_factor: float = 0.05,  # it will end the cosine schedule at where it's this value divided by max_factor
                 **kwargs):
        """
        Cosine learning rate scheduler that inherits from CombinedLRScheduler (see its documentation
        to understand general aspects of usage).
        Args:
           max_factor, min_factor:  The conventional cosine factor goes from 1 to 0 based on the formula:
                             factor = 0.5 * (1 + cos(pi * progress)).
                        This scheduler selects the part of that graph from factor=max_factor
                        to factor=min_factor (imagine cropping the graph by selecting lines
                        that intersect the y-axis at those values).  It renormalizes so the initial
                        factor is one by dividing by max_factor; the last factor will actually
                        be min_factor / max_factor.
        Raises:
           ValueError: if max_factor is not in (0, 1] or min_factor is not in [0, 1].
        """
        super().__init__(*args, **kwargs)
        # Outside these ranges acos() below has no solution, and get_lr()
        # would divide by zero for max_factor == 0.
        if not 0.0 < max_factor <= 1.0:
            raise ValueError(f"max_factor must be in (0, 1], got {max_factor}")
        if not 0.0 <= min_factor <= 1.0:
            raise ValueError(f"min_factor must be in [0, 1], got {min_factor}")
        self.max_factor = max_factor

        def factor_to_progress(factor):
            # inverse function of: factor = 0.5 * (1.0 + math.cos(math.pi * progress))
            cos = 2.0 * factor - 1.0
            return math.acos(cos) / math.pi

        # we'll divide the factors by max_factor in get_lr() after computing the cosine formula,
        # so the initial and final factors will actually be 1.0 and min_factor / max_factor respectively.
        self.initial_progress = factor_to_progress(max_factor)
        self.final_progress = factor_to_progress(min_factor)

    def get_lr(self):
        progress = self.get_progress()
        # map progress in [0..1] to a tighter range like [0.15..0.85]
        progress = self.initial_progress + (self.final_progress - self.initial_progress) * progress
        factor = 0.5 * (1.0 + math.cos(math.pi * progress))
        factor = factor / self.max_factor  # make it so the initial factor is 1.0 despite limiting range of progress
        return [x * factor for x in self.base_lrs]


class InterpCosineLRScheduler(CombinedLRScheduler):
    def __init__(self,
                 *args,
                 min_factor: float = 0.0,
                 half_cosine_scale: float = 0.0,
                 linear_scale: float = 0.0,
                 **kwargs):
        """
        This LR scheduler interpolates between three schedules of the progress p in [0, 1]:
          - the conventional cosine schedule, 0.5 * (1 + cos(pi * p)), which equals cos(pi/2 * p) ** 2;
          - the half-cosine schedule, cos(pi/2 * p), i.e. the cosine from 0 to pi/2;
          - the linear schedule, 1 - p.
        It inherits from CombinedLRScheduler (see its documentation
        to understand general aspects of usage).
        Args:
           min_factor: the factor at the end of training; the interpolated factor, which goes
                 from 1 to 0, is rescaled to go from 1 to min_factor.
           half_cosine_scale: weight of the half-cosine schedule.
           linear_scale: weight of the linear schedule.  The conventional cosine schedule gets the
                 remaining weight 1 - half_cosine_scale - linear_scale, which must not be negative.
        """
        self.min_factor = min_factor
        self.half_cosine_scale = half_cosine_scale
        self.linear_scale = linear_scale
        super().__init__(*args, **kwargs)

    def get_lr(self):
        progress = self.get_progress()
        half_cos = math.cos((math.pi / 2) * progress)
        cos = half_cos ** 2
        linear = 1. - progress

        linear_scale = self.linear_scale
        half_cosine_scale = self.half_cosine_scale
        cosine_scale = 1. - self.half_cosine_scale - linear_scale
        assert cosine_scale >= 0.0

        factor = linear_scale * linear + half_cosine_scale * half_cos + cosine_scale * cos
        # apply min_factor via interpolation
        factor = self.min_factor + factor * (1. - self.min_factor)
        return [x * factor for x in self.base_lrs]


class HalfCosineLRScheduler(CombinedLRScheduler):
    def __init__(self,
                 *args,
                 min_factor: float = 0.05,
                 **kwargs):  # takes also batches_per_epoch and num_epochs args.
        """
        This cosine LR scheduler is the cosine from 0 to pi/2, with no offset of 1.
        It inherits from CombinedLRScheduler (see its documentation
        to understand general aspects of usage).
        Args:
           min_factor: the factor at the end of training; the cosine, which goes from 1 to 0,
                 is rescaled to go from 1 to min_factor.
        """
        self.min_factor = min_factor
        super().__init__(*args, **kwargs)

    def get_lr(self):
        progress = self.get_progress()
        # (factor**2 would be the conventional cosine LR schedule with cosine from 0 to pi;
        # here we use the half-cosine itself.)
        factor = math.cos((math.pi / 2) * progress)
        # apply min_factor via interpolation
        factor = self.min_factor + factor * (1. - self.min_factor)
        return [x * factor for x in self.base_lrs]


class LinearLRScheduler(CombinedLRScheduler):
    def __init__(self,
                 *args,
                 min_factor: float = 0.0,
                 **kwargs):  # takes also batches_per_epoch and num_epochs args.
        """
        Linear LR scheduler: the factor decays linearly with the progress from 1.0 to min_factor.
        It inherits from CombinedLRScheduler (see its documentation
        to understand general aspects of usage).
        """
        super().__init__(*args, **kwargs)
        self.min_factor = min_factor

    def get_lr(self):
        progress = self.get_progress()
        # factor is 1.0 at the start and decays linearly to 0 at the end ...
        factor = 1.0 - progress
        # ... then it is rescaled so that it ends at min_factor instead of 0.
        min_factor = self.min_factor
        factor = min_factor + (1.0 - self.min_factor) * factor
        return [x * factor for x in self.base_lrs]
+""" +Usage: + +(1) ctc-greedy-search +./zapformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-greedy-search + +(2) ctc-decoding +./zapformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-decoding + +(3) 1best +./zapformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method 1best + +(4) nbest +./zapformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method nbest + +(5) nbest-rescoring +./zapformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method nbest-rescoring + +(6) whole-lattice-rescoring +./zapformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method whole-lattice-rescoring + +(7) attention-decoder-rescoring-no-ngram +./zapformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --decoding-method attention-decoder-rescoring-no-ngram + +(8) attention-decoder-rescoring-with-ngram +./zapformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method attention-decoder-rescoring-with-ngram +""" + + +import argparse +import logging +import math +import os +import re +from collections import defaultdict +from pathlib import Path +from typing 
# Words and tags that are excluded from GigaSpeech scoring.  Texts are
# upper-cased before filtering, so the filler words are listed in upper case.
conversational_filler = [
    "UH",
    "UHH",
    "UM",
    "EH",
    "MM",
    "HM",
    "AH",
    "HUH",
    "HA",
    "ER",
    "OOF",
    "HEE",
    "ACH",
    "EEE",
    "EW",
]
unk_tags = ["<UNK>", "<unk>"]
gigaspeech_punctuations = [
    "<COMMA>",
    "<PERIOD>",
    "<QUESTIONMARK>",
    "<EXCLAMATIONPOINT>",
]
gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
non_scoring_words = (
    conversational_filler
    + unk_tags
    + gigaspeech_punctuations
    + gigaspeech_garbage_utterance_tags
)


def giga_asr_text_post_processing(text: str) -> str:  # only used for gigaspeech
    """Normalize a transcript for GigaSpeech scoring.

    The text is upper-cased, hyphens are replaced by spaces, and every word
    listed in `non_scoring_words` (fillers, unk/punctuation/garbage tags) is
    removed.

    Args:
      text: the reference or hypothesis transcript.

    Returns:
      The remaining words joined by single spaces ("" if nothing remains).
    """
    # 1. convert to uppercase
    text = text.upper()

    # 2. remove hyphen
    #   "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
    text = text.replace("-", " ")

    # 3. remove non-scoring words from evaluation
    # Build the set once per call: O(1) membership tests while still
    # honouring any change made to the module-level list.
    excluded = set(non_scoring_words)
    return " ".join(word for word in text.split() if word not in excluded)


def giga_post_processing(
    results: List[Tuple[str, List[str], List[str]]],
) -> List[Tuple[str, List[str], List[str]]]:
    """Apply `giga_asr_text_post_processing` to every (key, ref, hyp) triple.

    Args:
      results: triples of utterance key, reference words and hypothesis words.

    Returns:
      A new list with the same keys and the normalized reference and
      hypothesis word lists.
    """
    new_results = []
    for key, ref, hyp in results:
        new_ref = giga_asr_text_post_processing(" ".join(ref)).split()
        new_hyp = giga_asr_text_post_processing(" ".join(hyp)).split()
        new_results.append((key, new_ref, new_hyp))
    return new_results


def cv_post_processing(
    results: List[Tuple[str, List[str], List[str]]],
) -> List[Tuple[str, List[str], List[str]]]:
    """Normalize (key, ref, hyp) triples for CommonVoice scoring.

    Every character that is neither a word character nor whitespace is
    removed and the text is upper-cased.

    Args:
      results: triples of utterance key, reference words and hypothesis words.

    Returns:
      A new list with the same keys and the normalized reference and
      hypothesis word lists.
    """

    def normalize(text):
        return re.sub(r'[^\w\s]', '', text).upper()

    new_results = []
    for key, ref, hyp in results:
        new_ref = normalize(" ".join(ref)).split()
        new_hyp = normalize(" ".join(hyp)).split()
        new_results.append((key, new_ref, new_hyp))
    return new_results


# Log of a tiny probability; used as the padding value for log-mel features.
LOG_EPS = math.log(1e-10)
Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zapformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--giga", + type=str2bool, + default=False, + help="If True, decode gigaspeech in addition to librispeech test sets." + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (3) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (4) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (5) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (6) whole-lattice-rescoring. 
Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (7) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + - (8) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding + lattice, rescore them with the attention decoder. + - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM + rescored lattice, rescore them with the attention decoder. + - (10) ctc-prefix-beam-search. Extract n paths with the given beam, the best + path of the n paths is the decoding result. + - (11) ctc-prefix-beam-search-attention-decoder-rescoring. Extract n paths with + the given beam, rescore them with the attention decoder. + - (12) ctc-prefix-beam-search-shallow-fusion. Use NNLM shallow fusion during + beam search, LODR and hotwords are also supported in this decoding method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--nnlm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--nnlm-scale", + type=float, + default=0, + help="""The scale of the neural network LM, 0 means don't use nnlm shallow fusion. + Used only when `--use-shallow-fusion` is set to True. 
+ """, + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.6, + help="""The scale to be applied to `hlg.scores`. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--lodr-ngram", + type=str, + help="The path to the lodr ngram", + ) + + parser.add_argument( + "--lodr-lm-scale", + type=float, + default=0, + help="The scale of lodr ngram, should be less than 0. 0 means don't use lodr.", + ) + + parser.add_argument( + "--context-score", + type=float, + default=0, + help=""" + The bonus score of each token for the context biasing words/phrases. + 0 means don't use contextual biasing. + Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion. 
def get_decoding_params() -> AttributeDict:
    """Return the fixed decoding parameters that are not command-line options.

    The search/output beams and the active-state limits are forwarded to
    `get_lattice`; `use_double_scores` is forwarded to the best-path
    helpers; `beam` is only used to tag result files of the prefix beam
    search methods.
    """
    return AttributeDict(
        {
            "frame_shift_ms": 10,
            "search_beam": 20,  # for k2 fsa composition
            "output_beam": 8,  # for k2 fsa composition
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
            "beam": 4,  # for prefix-beam-search
        }
    )


# Both spellings select the shallow-fusion prefix beam search.  The first one
# is the historical (misspelt) name and is kept for backward compatibility.
_SHALLOW_FUSION_METHODS = (
    "ctc-prefix-beam-search-shallow-fussion",
    "ctc-prefix-beam-search-shallow-fusion",
)


def _token_ids_to_words(bpe_model, token_ids) -> List[List[str]]:
    """Decode per-utterance BPE token IDs to text and split it into words."""
    # bpe_model.decode() returns one string per utterance, e.g. 'xxx yyy zzz'.
    return [text.split() for text in bpe_model.decode(token_ids)]


def _best_path_to_words(best_path, word_table) -> List[List[str]]:
    """Convert the word IDs found on `best_path` to word strings."""
    return [[word_table[i] for i in ids] for ids in get_texts(best_path)]


def _result_tag(test_set_name: str, key: str, results_dict: dict) -> str:
    """Return the file-name tag for one entry of `results_dict`.

    With a single decoding setting the tag is just the test-set name, which
    keeps the historical file names.  With several settings (e.g. one per LM
    scale) the key is appended so that the settings do not overwrite each
    other's files.
    """
    if len(results_dict) > 1:
        return f"{test_set_name}-{key}"
    return test_set_name


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    batch: dict,
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
    NNLM: Optional[LmScorer] = None,
    LODR_lm: Optional[NgramLm] = None,
    context_graph: Optional[ContextGraph] = None,
) -> Optional[Dict[str, List[List[str]]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

        - key: It indicates the setting used for decoding. For example,
               if no rescoring is used, the key is the string `no-rescore`.
               If LM rescoring is used, the key is the string `lm_scale_xxx`,
               where `xxx` is the value of `lm_scale`. An example key is
               `lm_scale_0.7`
        - value: It contains the decoding result. `len(value)` equals to
                 batch size. `value[i]` is the list of words decoded for
                 the i-th utterance in the given batch.

    Args:
      params:
        It's the return value of :func:`get_params`, updated with the
        decoding parameters and the command-line arguments. It must provide
        `device`, `causal`, `decoding_method` and `subsampling_factor`.
      model:
        The neural model. It must provide `forward_encoder` and `ctc_output`
        (and `attention_decoder` for the attention rescoring methods).
      HLG:
        The decoding graph. Used by the lattice-based methods when H is None.
      H:
        The ctc topo. Used by the lattice-based methods that work on tokens
        (ctc-decoding, attention-decoder-rescoring-no-ngram).
      bpe_model:
        The BPE model. Required by every method that produces token IDs.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table.
      G:
        An LM. It is not None when params.decoding_method is "nbest-rescoring",
        "whole-lattice-rescoring" or "attention-decoder-rescoring-with-ngram".
        In general, the G in HLG is a 3-gram LM, while this G is a 4-gram LM.
      NNLM:
        Neural LM forwarded to the shallow-fusion prefix beam search.
      LODR_lm:
        Token-level n-gram LM forwarded to the shallow-fusion prefix beam search.
      context_graph:
        Context-biasing graph forwarded to the shallow-fusion prefix beam search.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict. Note: If it decodes to nothing, then return None.
    Raises:
      ValueError: if params.decoding_method is not supported.
    """
    device = params.device
    feature = batch["inputs"]
    assert feature.ndim == 3
    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    if params.causal:
        # this seems to cause insertions at the end of the utterance if used with zapformer.
        pad_len = 30
        # Out-of-place add: `.to(device)` returns the very same tensor when it
        # already lives on `device` (e.g. CPU decoding), and an in-place `+=`
        # would then silently change supervisions["num_frames"], which is read
        # again below to build `supervision_segments`.
        feature_lens = feature_lens + pad_len
        feature = torch.nn.functional.pad(
            feature,
            pad=(0, 0, 0, pad_len),
            value=LOG_EPS,
        )

    encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)[:2]
    ctc_output = model.ctc_output(encoder_out)  # (N, T, C)

    method = params.decoding_method

    if method == "ctc-greedy-search":
        # token_ids is a list-of-list of IDs
        token_ids = ctc_greedy_search(ctc_output, encoder_out_lens)
        return {"ctc-greedy-search": _token_ids_to_words(bpe_model, token_ids)}

    if method == "ctc-prefix-beam-search":
        token_ids = ctc_prefix_beam_search(
            ctc_output=ctc_output, encoder_out_lens=encoder_out_lens
        )
        return {"prefix-beam-search": _token_ids_to_words(bpe_model, token_ids)}

    if method == "ctc-prefix-beam-search-attention-decoder-rescoring":
        best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
            ctc_output=ctc_output,
            attention_decoder=model.attention_decoder,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        return {
            a_scale_str: _token_ids_to_words(bpe_model, token_ids)
            for a_scale_str, token_ids in best_path_dict.items()
        }

    if method in _SHALLOW_FUSION_METHODS:
        token_ids = ctc_prefix_beam_search_shallow_fussion(
            ctc_output=ctc_output,
            encoder_out_lens=encoder_out_lens,
            NNLM=NNLM,
            LODR_lm=LODR_lm,
            LODR_lm_scale=params.lodr_lm_scale,
            context_graph=context_graph,
        )
        # The (misspelt) key is kept as is because it shows up in result files.
        key = "prefix-beam-search-shallow-fussion"
        return {key: _token_ids_to_words(bpe_model, token_ids)}

    # All remaining methods work on a lattice.
    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            torch.div(
                supervisions["start_frame"],
                params.subsampling_factor,
                rounding_mode="floor",
            ),
            torch.div(
                supervisions["num_frames"],
                params.subsampling_factor,
                rounding_mode="floor",
            ),
        ),
        1,
    ).to(torch.int32)

    if H is None:
        assert HLG is not None
        decoding_graph = HLG
    else:
        assert HLG is None
        assert bpe_model is not None
        decoding_graph = H

    lattice = get_lattice(
        nnet_output=ctc_output,
        decoding_graph=decoding_graph,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
        subsampling_factor=params.subsampling_factor,
    )

    if method == "ctc-decoding":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        # Note: `best_path.aux_labels` contains token IDs, not word IDs
        # since we are using H, not HLG here.
        #
        # token_ids is a list-of-list of IDs
        token_ids = get_texts(best_path)
        # The tokens are decoded with the BPE model, so words are returned.
        return {"ctc-decoding": _token_ids_to_words(bpe_model, token_ids)}

    if method == "attention-decoder-rescoring-no-ngram":
        best_path_dict = rescore_with_attention_decoder_no_ngram(
            lattice=lattice,
            num_paths=params.num_paths,
            attention_decoder=model.attention_decoder,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            nbest_scale=params.nbest_scale,
        )
        return {
            a_scale_str: _token_ids_to_words(bpe_model, get_texts(best_path))
            for a_scale_str, best_path in best_path_dict.items()
        }

    if method == "nbest-oracle":
        # Note: You can also pass rescored lattices to it.
        # We choose the HLG decoded lattice for speed reasons
        # as HLG decoding is faster and the oracle WER
        # is only slightly worse than that of rescored lattices.
        best_path = nbest_oracle(
            lattice=lattice,
            num_paths=params.num_paths,
            ref_texts=supervisions["text"],
            word_table=word_table,
            nbest_scale=params.nbest_scale,
            # Symbol that out-of-vocabulary reference words are mapped to.
            oov="<UNK>",
        )
        key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}"  # noqa
        return {key: _best_path_to_words(best_path, word_table)}

    if method in ("1best", "nbest"):
        if method == "1best":
            best_path = one_best_decoding(
                lattice=lattice, use_double_scores=params.use_double_scores
            )
            key = "no-rescore"
        else:
            best_path = nbest_decoding(
                lattice=lattice,
                num_paths=params.num_paths,
                use_double_scores=params.use_double_scores,
                nbest_scale=params.nbest_scale,
            )
            key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa

        # The IDs are looked up in `word_table`, i.e. words are returned.
        return {key: _best_path_to_words(best_path, word_table)}

    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

    if method == "nbest-rescoring":
        best_path_dict = rescore_with_n_best_list(
            lattice=lattice,
            G=G,
            num_paths=params.num_paths,
            lm_scale_list=lm_scale_list,
            nbest_scale=params.nbest_scale,
        )
    elif method == "whole-lattice-rescoring":
        best_path_dict = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=lm_scale_list,
        )
    elif method == "attention-decoder-rescoring-with-ngram":
        # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
        rescored_lattice = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=None,
        )
        best_path_dict = rescore_with_attention_decoder_with_ngram(
            lattice=rescored_lattice,
            num_paths=params.num_paths,
            attention_decoder=model.attention_decoder,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            nbest_scale=params.nbest_scale,
        )
    else:
        # An explicit raise (rather than `assert`) also works under `python -O`.
        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")

    if best_path_dict is None:
        return None
    return {
        lm_scale_str: _best_path_to_words(best_path, word_table)
        for lm_scale_str, best_path in best_path_dict.items()
    }


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
    NNLM: Optional[LmScorer] = None,
    LODR_lm: Optional[NgramLm] = None,
    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      HLG, H, bpe_model, word_table, G, NNLM, LODR_lm, context_graph:
        Forwarded unchanged to :func:`decode_one_batch`; see its
        documentation.
    Returns:
      Return a dict, whose key may be "no-rescore" if no LM rescoring
      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
      Its value is a list of tuples. Each tuple contains three elements:
      the cut ID, the reference words and the predicted words.
    Raises:
      RuntimeError: if the very first batch decodes to nothing, because the
        names of the decoding settings are not known yet at that point.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        # Lazy/streaming dataloaders have no length.
        num_batches = "?"

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            batch=batch,
            word_table=word_table,
            G=G,
            NNLM=NNLM,
            LODR_lm=LODR_lm,
            context_graph=context_graph,
        )

        if hyps_dict is None:
            # decode_one_batch() returns None when rescoring produced nothing.
            # Score such a batch as empty hypotheses for every known setting
            # instead of crashing on `None.items()`.
            if not results:
                raise RuntimeError(
                    "It should not decode to empty in the first batch!"
                )
            for name in results:
                results[name].extend(
                    (cut_id, ref_text.split(), [])
                    for cut_id, ref_text in zip(cut_ids, texts)
                )
        else:
            for name, hyps in hyps_dict.items():
                assert len(hyps) == len(texts)
                results[name].extend(
                    (cut_id, ref_text.split(), hyp_words)
                    for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts)
                )

        num_cuts += len(texts)

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results


def save_asr_output(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """
    Save text produced by ASR.

    One `recogs-*.txt` file is written into `params.res_dir` per entry of
    `results_dict`. Results of test sets whose name contains 'giga' or 'cv'
    are normalised with the matching post-processing function first.
    """
    for key, results in results_dict.items():
        tag = _result_tag(test_set_name, key, results_dict)
        recogs_filename = params.res_dir / f"recogs-{tag}-{params.suffix}.txt"

        results = sorted(results)
        if "giga" in test_set_name:
            results = giga_post_processing(results)
        if "cv" in test_set_name:
            results = cv_post_processing(results)

        store_transcripts(filename=recogs_filename, texts=results)

        logging.info(f"The transcripts are stored in {recogs_filename}")


def save_wer_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """
    Write error statistics per decoding setting and a WER summary.

    One `errs-*.txt` file is written into `params.res_dir` per entry of
    `results_dict`, followed by a single `wer-summary-*.txt` listing all
    settings sorted by WER (best first); the summary is also logged.
    """
    # Set it to False for these methods since there are too many logs.
    enable_log = params.decoding_method not in (
        "attention-decoder-rescoring-with-ngram",
        "whole-lattice-rescoring",
    )

    test_set_wers = dict()
    for key, results in results_dict.items():
        if "giga" in test_set_name:
            results = giga_post_processing(results)
        if "cv" in test_set_name:
            results = cv_post_processing(results)

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        tag = _result_tag(test_set_name, key, results_dict)
        errs_filename = params.res_dir / f"errs-{tag}-{params.suffix}.txt"
        with open(errs_filename, "w", encoding="utf8") as fd:
            wer = write_error_stats(
                fd, f"{test_set_name}_{key}", results, enable_log=enable_log
            )
            test_set_wers[key] = wer

        logging.info(f"Wrote detailed error stats to {errs_filename}")

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])

    wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"

    with open(wer_filename, "w", encoding="utf8") as fd:
        print("settings\tWER", file=fd)
        for key, val in test_set_wers:
            print(f"{key}\t{val}", file=fd)

    s = f"\nFor {test_set_name}, WER of different settings are:\n"
    note = f"\tbest for {test_set_name}"
    for key, val in test_set_wers:
        s += f"{key}\t{val}{note}\n"
        note = ""
    logging.info(s)


@torch.no_grad()
def main():
    """Parse the command line, load graphs/LMs/model and decode all test sets."""
    parser = get_parser()
    AsrDataModule.add_arguments(parser)
    LmScorer.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)
    args.lm_dir = Path(args.lm_dir)

    params = get_params()
    # add decoding params
    params.update(get_decoding_params())
    params.update(vars(args))

    # enable AudioCache
    set_caching_enabled(True)  # lhotse

    # Methods that decode tokens directly and therefore need no HLG graph.
    token_level_methods = (
        "ctc-decoding",
        "ctc-greedy-search",
        "ctc-prefix-beam-search",
        "ctc-prefix-beam-search-attention-decoder-rescoring",
        "attention-decoder-rescoring-no-ngram",
    ) + _SHALLOW_FUSION_METHODS
    # Methods that rescore with the 4-gram G.
    ngram_rescoring_methods = (
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder-rescoring-with-ngram",
    )
    supported_methods = (
        token_level_methods
        + ("1best", "nbest", "nbest-oracle")
        + ngram_rescoring_methods
    )
    if params.decoding_method not in supported_methods:
        raise ValueError(
            f"Unsupported decoding method: {params.decoding_method}. "
            f"Supported methods: {', '.join(supported_methods)}"
        )
    # Accept both spellings of the shallow-fusion method everywhere below.
    use_shallow_fusion = params.decoding_method in _SHALLOW_FUSION_METHODS

    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}_avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"

    if params.causal:
        assert (
            "," not in params.chunk_size
        ), "chunk_size should be one value in decoding."
        assert (
            "," not in params.left_context_frames
        ), "left_context_frames should be one value in decoding."
        params.suffix += f"_chunk-{params.chunk_size}"
        params.suffix += f"_left-context-{params.left_context_frames}"

    if "prefix-beam-search" in params.decoding_method:
        params.suffix += f"_beam-{params.beam}"
        if use_shallow_fusion:
            if params.nnlm_scale != 0:
                params.suffix += f"_nnlm-scale-{params.nnlm_scale}"
            if params.lodr_lm_scale != 0:
                params.suffix += f"_lodr-scale-{params.lodr_lm_scale}"
            if params.context_score != 0:
                params.suffix += f"_context_score-{params.context_score}"

    if params.use_averaged_model:
        params.suffix += "_use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    params.device = device

    logging.info(f"Device: {device}")
    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    num_classes = max_token_id + 1  # +1 for the blank

    params.vocab_size = num_classes
    # <blank> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = 0
    params.eos_id = 1
    params.sos_id = 1

    if params.decoding_method in token_level_methods:
        HLG = None
        H = None
        if params.decoding_method in (
            "ctc-decoding",
            "attention-decoder-rescoring-no-ngram",
        ):
            H = k2.ctc_topo(
                max_token=max_token_id,
                modified=False,
                device=device,
            )
        bpe_model = spm.SentencePieceProcessor()
        bpe_model.load(str(params.lang_dir / "bpe.model"))
    else:
        H = None
        bpe_model = None
        HLG = k2.Fsa.from_dict(
            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
        )
        assert HLG.requires_grad is False

        HLG.scores *= params.hlg_scale
        if not hasattr(HLG, "lm_scores"):
            HLG.lm_scores = HLG.scores.clone()

    if params.decoding_method in ngram_rescoring_methods:
        if not (params.lm_dir / "G_4_gram.pt").is_file():
            logging.info("Loading G_4_gram.fst.txt")
            logging.warning("It may take 8 minutes.")
            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
                first_word_disambig_id = lexicon.word_table["#0"]

                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
                # G.aux_labels is not needed in later computations, so
                # remove it here.
                del G.aux_labels
                # CAUTION: The following line is crucial.
                # Arcs entering the back-off state have label equal to #0.
                # We have to change it to 0 here.
                G.labels[G.labels >= first_word_disambig_id] = 0
                # See https://github.com/k2-fsa/k2/issues/874
                # for why we need to set G.properties to None
                G.__dict__["_properties"] = None
                G = k2.Fsa.from_fsas([G]).to(device)
                G = k2.arc_sort(G)
                # Save a dummy value so that it can be loaded in C++.
                # See https://github.com/pytorch/pytorch/issues/67902
                # for why we need to do this.
                G.dummy = 1

                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
        else:
            logging.info("Loading pre-compiled G_4_gram.pt")
            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
            G = k2.Fsa.from_dict(d)

        if params.decoding_method in (
            "whole-lattice-rescoring",
            "attention-decoder-rescoring-with-ngram",
        ):
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)

        # G.lm_scores is used to replace HLG.lm_scores during
        # LM rescoring.
        G.lm_scores = G.scores.clone()
    else:
        G = None

    # only load the neural network LM if required
    NNLM = None
    if use_shallow_fusion and params.nnlm_scale != 0:
        NNLM = LmScorer(
            lm_type=params.nnlm_type,
            params=params,
            device=device,
            lm_scale=params.nnlm_scale,
        )
        NNLM.to(device)
        NNLM.eval()

    LODR_lm = None
    if use_shallow_fusion and params.lodr_lm_scale != 0:
        # `--lodr-ngram` has no default, so guard against None as well.
        assert params.lodr_ngram is not None and os.path.exists(
            params.lodr_ngram
        ), f"LODR ngram does not exists, given path : {params.lodr_ngram}"
        logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}")
        LODR_lm = NgramLm(
            params.lodr_ngram,
            backoff_id=params.backoff_id,
            is_binary=False,
        )
        logging.info(f"num states: {LODR_lm.lm.num_states}")

    context_graph = None
    if use_shallow_fusion and params.context_score != 0:
        assert os.path.exists(
            params.context_file
        ), f"context_file does not exists, given path : {params.context_file}"
        # One biasing word/phrase per line; close the file when done.
        with open(params.context_file) as f:
            contexts = [bpe_model.encode(line.strip()) for line in f]
        context_graph = ContextGraph(params.context_score)
        context_graph.build(contexts)

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = [
                f"{params.exp_dir}/epoch-{i}.pt"
                for i in range(start, params.epoch + 1)
                if i >= 1
            ]
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True

    asr_datamodule = AsrDataModule(args)

    # The LibriSpeech dev/test sets are always decoded.
    librispeech = LibriSpeech(args.manifest_dir)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()
    dev_clean_cuts = librispeech.dev_clean_cuts()
    dev_other_cuts = librispeech.dev_other_cuts()

    test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts)
    test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts)
    dev_clean_dl = asr_datamodule.test_dataloaders(dev_clean_cuts)
    dev_other_dl = asr_datamodule.test_dataloaders(dev_other_cuts)

    test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
    test_dls = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]

    if args.giga:
        gigaspeech = GigaSpeech(args.manifest_dir)
        test_cuts = gigaspeech.test_cuts()
        dev_cuts = gigaspeech.dev_cuts()
        giga_test_dl = asr_datamodule.test_dataloaders(test_cuts)
        giga_dev_dl = asr_datamodule.test_dataloaders(dev_cuts)
        test_sets += ["giga-dev", "giga-test"]
        test_dls += [giga_dev_dl, giga_test_dl]

    if args.cv:
        commonvoice = CommonVoice(args.manifest_dir)
        test_cuts = commonvoice.test_cuts()
        dev_cuts = commonvoice.dev_cuts()
        cv_test_dl = asr_datamodule.test_dataloaders(test_cuts)
        cv_dev_dl = asr_datamodule.test_dataloaders(dev_cuts)
        test_sets += ["cv-dev", "cv-test"]
        test_dls += [cv_dev_dl, cv_test_dl]

    # Distinct names for the list and the loop variable (the original loop
    # rebound the list it was iterating over).
    for test_set, dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=dl,
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            word_table=lexicon.word_table,
            G=G,
            NNLM=NNLM,
            LODR_lm=LODR_lm,
            context_graph=context_graph,
        )

        save_asr_output(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

        if not params.skip_scoring:
            save_wer_results(
                params=params,
                test_set_name=test_set,
                results_dict=results_dict,
            )

    logging.info("Done!")


if __name__ == "__main__":
    main()
+""" +Usage: +(1) greedy search +./zapformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zapformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zapformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zapformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zapformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zapformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zapformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zapformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +import re +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import CommonVoice, LibriSpeech, GigaSpeech, AsrDataModule +from beam_search import ( + 
# Log-probability used as the value of padded feature frames.
LOG_EPS = math.log(1e-10)

# Words and tags that are excluded from GigaSpeech scoring.
conversational_filler = [
    "UH",
    "UHH",
    "UM",
    "EH",
    "MM",
    "HM",
    "AH",
    "HUH",
    "HA",
    "ER",
    "OOF",
    "HEE",
    "ACH",
    "EEE",
    "EW",
]
unk_tags = ["<UNK>", "<unk>"]
gigaspeech_punctuations = [
    "<COMMA>",
    "<PERIOD>",
    "<QUESTIONMARK>",
    "<EXCLAMATIONPOINT>",
]
gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
non_scoring_words = (
    conversational_filler
    + unk_tags
    + gigaspeech_punctuations
    + gigaspeech_garbage_utterance_tags
)


def giga_asr_text_post_processing(text: str) -> str:  # only used for gigaspeech
    """Normalise one transcript for GigaSpeech scoring.

    The text is upper-cased, hyphens are replaced by spaces and every word
    listed in `non_scoring_words` (fillers, unknown-word tags, punctuation
    tags and garbage-utterance tags) is dropped.

    Args:
      text:
        A transcript with words separated by whitespace.
    Returns:
      The normalised transcript with words joined by single spaces; it is
      the empty string when nothing remains.
    """
    # 1. convert to uppercase
    text = text.upper()

    # 2. remove hyphen
    #   "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
    text = text.replace("-", " ")

    # 3. remove non-scoring words from evaluation
    return " ".join(word for word in text.split() if word not in non_scoring_words)


def giga_post_processing(
    results: List[Tuple[str, List[str], List[str]]],
) -> List[Tuple[str, List[str], List[str]]]:
    """Apply :func:`giga_asr_text_post_processing` to every ref/hyp pair.

    Args:
      results:
        Tuples of (cut ID, reference words, hypothesis words).
    Returns:
      A new list with the same cut IDs and normalised word lists; the input
      is not modified.
    """
    return [
        (
            key,
            giga_asr_text_post_processing(" ".join(ref)).split(),
            giga_asr_text_post_processing(" ".join(hyp)).split(),
        )
        for key, ref, hyp in results
    ]


def cv_post_processing(
    results: List[Tuple[str, List[str], List[str]]],
) -> List[Tuple[str, List[str], List[str]]]:
    """Normalise ref/hyp pairs of the CommonVoice test sets.

    Every character that is neither a word character nor whitespace is
    removed and the text is upper-cased.

    Args:
      results:
        Tuples of (cut ID, reference words, hypothesis words).
    Returns:
      A new list with the same cut IDs and normalised word lists; the input
      is not modified.
    """

    def normalize(text: str) -> str:
        return re.sub(r"[^\w\s]", "", text).upper()

    return [
        (key, normalize(" ".join(ref)).split(), normalize(" ".join(hyp)).split())
        for key, ref, hyp in results
    ]
If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zapformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. 
+ """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. 
+ """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + parser.add_argument( + "--giga", + type=str2bool, + default=False, + help="""If True, decode gigaspeech in addition to librispeech test sets.""", + ) + + parser.add_argument( + "--cv", + type=str2bool, + default=False, + help="""If True, decode commonvoice in addition to librispeech test sets.""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. 
+ Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zapformer. 
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)[:2] + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == 
"greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, 
:encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) + prefix = f"{params.decoding_method}" + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + prefix += f"_beam-{params.beam}" + prefix += f"_max-contexts-{params.max_contexts}" + prefix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + prefix += f"_num-paths-{params.num_paths}" + prefix += f"_nbest-scale-{params.nbest_scale}" + if "LG" in params.decoding_method: + prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + + return {prefix: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix += f"_beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"_context-score-{params.context_score}" + return {prefix: hyps} + else: + prefix += f"_beam-size-{params.beam_size}" + return {prefix: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: 
Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results + + +def 
save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + if 'giga' in test_set_name: + results = giga_post_processing(results) + if 'cv' in test_set_name: + results = cv_post_processing(results) + + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): + if 'giga' in test_set_name: + results = giga_post_processing(results) + if 'cv' in test_set_name: + results = cv_post_processing(results) + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}_avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." 
+ assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"_nbest-scale-{params.nbest_scale}" + params.suffix += f"_num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"_context-{params.context_size}" + params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if 
not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + 
logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading 
{lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + asr_datamodule = AsrDataModule(args) + test_sets = [] + test_dl = [] + if True: # if not args.giga: + librispeech = LibriSpeech(args.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() + + test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) + test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) + dev_clean_dl = asr_datamodule.test_dataloaders(dev_clean_cuts) + dev_other_dl = asr_datamodule.test_dataloaders(dev_other_cuts) + + test_sets += ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl += [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + + if args.giga: + gigaspeech = GigaSpeech(args.manifest_dir) + test_cuts = gigaspeech.test_cuts() + dev_cuts = gigaspeech.dev_cuts() + giga_test_dl = asr_datamodule.test_dataloaders(test_cuts) + giga_dev_dl = asr_datamodule.test_dataloaders(dev_cuts) + test_sets += ["giga-dev", "giga-test"] + test_dl += [giga_dev_dl, giga_test_dl] + + if args.cv: + commonvoice = 
CommonVoice(args.manifest_dir) + test_cuts = commonvoice.test_cuts() + dev_cuts = commonvoice.dev_cuts() + cv_test_dl = asr_datamodule.test_dataloaders(test_cuts) + cv_dev_dl = asr_datamodule.test_dataloaders(dev_cuts) + test_sets += ["cv-dev", "cv-test"] + test_dl += [cv_dev_dl, cv_test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zapformer/decode_stream.py b/egs/librispeech/ASR/zapformer/decode_stream.py new file mode 100644 index 0000000000..a1bf671bf5 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/decode_stream.py @@ -0,0 +1,147 @@ +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import List, Optional, Tuple + +import k2 +import torch +from beam_search import Hypothesis, HypothesisList + +from icefall.utils import AttributeDict + + +class DecodeStream(object): + def __init__( + self, + params: AttributeDict, + cut_id: str, + initial_states: List[torch.Tensor], + decoding_graph: Optional[k2.Fsa] = None, + device: torch.device = torch.device("cpu"), + ) -> None: + """ + Args: + initial_states: + Initial decode states of the model, e.g. the return value of + `get_init_state` in conformer.py + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + Used only when decoding_method is fast_beam_search. + device: + The device to run this stream. + """ + if params.decoding_method == "fast_beam_search": + assert decoding_graph is not None + assert device == decoding_graph.device + + self.params = params + self.cut_id = cut_id + self.LOG_EPS = math.log(1e-10) + + self.states = initial_states + + # It contains a 2-D tensors representing the feature frames. + self.features: torch.Tensor = None + + self.num_frames: int = 0 + # how many frames have been processed. (before subsampling). + # we only modify this value in `func:get_feature_frames`. + self.num_processed_frames: int = 0 + + self._done: bool = False + + # The transcript of current utterance. + self.ground_truth: str = "" + + # The decoding result (partial or final) of current utterance. 
+ self.hyp: List = [] + + # how many frames have been processed, at encoder output + self.done_frames: int = 0 + + # The encoder_embed subsample features (T - 7) // 2 + self.pad_length = 7 + + if params.decoding_method == "greedy_search": + self.hyp = [-1] * (params.context_size - 1) + [params.blank_id] + elif params.decoding_method == "modified_beam_search": + self.hyps = HypothesisList() + self.hyps.add( + Hypothesis( + ys=[-1] * (params.context_size - 1) + [params.blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + elif params.decoding_method == "fast_beam_search": + # The rnnt_decoding_stream for fast_beam_search. + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + @property + def done(self) -> bool: + """Return True if all the features are processed.""" + return self._done + + @property + def id(self) -> str: + return self.cut_id + + def set_features( + self, + features: torch.Tensor, + tail_pad_len: int = 0, + ) -> None: + """Set features tensor of current utterance.""" + assert features.dim() == 2, features.dim() + self.features = torch.nn.functional.pad( + features, + (0, 0, 0, self.pad_length + tail_pad_len), + mode="constant", + value=self.LOG_EPS, + ) + self.num_frames = self.features.size(0) + + def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: + """Consume chunk_size frames of features""" + chunk_length = chunk_size + self.pad_length + + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) + + ret_features = self.features[ + self.num_processed_frames : self.num_processed_frames + ret_length # noqa + ] + + self.num_processed_frames += chunk_size + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_features, ret_length + + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + if 
class Decoder(nn.Module):
    """Stateless prediction network for a transducer.

    Based on the stateless decoder of

        RNN-transducer with stateless prediction network
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

    There is no recurrent connection; unlike the paper, a grouped Conv1d
    over the last ``context_size`` tokens follows the embedding layer.

    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        decoder_dim: int,
        blank_id: int,
        context_size: int,
    ):
        """
        Args:
          vocab_size:
            Number of tokens of the modeling unit including blank.
          decoder_dim:
            Dimension of the input embedding, and of the decoder output.
          blank_id:
            The ID of the blank symbol.
          context_size:
            Number of previous words to use to predict the next word.
            1 means bigram; 2 means trigram. n means (n+1)-gram.
        """
        super().__init__()

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=decoder_dim,
        )
        # The stored weights are shrunk by 0.1 here and forward() multiplies
        # the looked-up embeddings by 20.0.  With an optimizer that has
        # weight decay it's best if all parameters have a fairly similar
        # dynamic range.
        with torch.no_grad():
            self.embedding.weight[:] *= 0.1

        self.blank_id = blank_id

        assert context_size >= 1, context_size
        self.context_size = context_size
        self.vocab_size = vocab_size

        if context_size == 1:
            # Keep a `conv` attribute even when it is a no-op, to avoid
            # `RuntimeError: Module 'Decoder' has no attribute 'conv'`
            # when inference with torch.jit.script and context_size == 1
            self.conv = nn.Identity()
        else:
            self.conv = nn.Conv1d(
                in_channels=decoder_dim,
                out_channels=decoder_dim,
                kernel_size=context_size,
                padding=0,
                groups=decoder_dim // 4,  # group size == 4
                bias=False,
            )

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        """
        Args:
          y:
            A 2-D tensor of shape (N, U).
          need_pad:
            True to left pad the input. Should be True during training.
            False to not pad the input. Should be False during inference.
        Returns:
          Return a tensor of shape (N, U, decoder_dim).
        """
        token_ids = y.to(torch.int64)

        # Negative ids are used at utterance start in beam_search.py; they
        # are clamped for the lookup and their embeddings are zeroed out.
        is_real_token = (token_ids >= 0).unsqueeze(-1)
        embedding_out = self.embedding(token_ids.clamp(min=0)) * is_real_token * 20.0

        if self.context_size > 1:
            # (N, U, C) -> (N, C, U) for Conv1d
            embedding_out = embedding_out.permute(0, 2, 1)
            if need_pad is True:
                left_pad = self.context_size - 1
                embedding_out = F.pad(embedding_out, pad=(left_pad, 0))
            else:
                # During inference time, there is no need to do extra padding
                # as we only need one output
                assert embedding_out.size(-1) == self.context_size
            embedding_out = self.conv(embedding_out)
            # back to (N, U, C)
            embedding_out = embedding_out.permute(0, 2, 1)
            embedding_out = F.relu(embedding_out)

        return embedding_out
class EncoderInterface(nn.Module):
    """Abstract base for encoders; subclasses must override ``forward``."""

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a padded batch of feature sequences.

        Args:
          x:
            Input features, a tensor of shape
            (batch_size, input_seq_len, num_features).
          x_lens:
            A tensor of shape (batch_size,) with the number of frames of
            each sequence in `x` before padding.
        Returns:
          A tuple of two tensors:
            - encoder_out, of shape (batch_size, out_seq_len, output_dim),
              containing unnormalized probabilities, i.e., the output of a
              linear layer.
            - encoder_out_lens, of shape (batch_size,), the number of frames
              in `encoder_out` before padding.
        Raises:
          NotImplementedError: always; this class only defines the contract.
        """
        message = "Please implement it in a subclass"
        raise NotImplementedError(message)
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser of the streaming ONNX export script.

    Besides the options declared here, model-architecture options are added
    by ``add_model_arguments``.

    Returns:
      The configured argument parser.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--dynamic-batch",
        type=int,
        default=1,
        help="1 to support dynamic batch size. 0 to support only batch size == 1",
    )

    parser.add_argument(
        "--enable-int8-quantization",
        type=int,
        default=1,
        help="1 to also export int8 onnx models.",
    )

    # main() only accepts epoch checkpoints numbered >= 1, so the help text
    # says that epochs count from 1.
    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would export the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zapformer/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/lang_bpe_500/tokens.txt",
        help="Path to the tokens.txt",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--fp16",
        type=str2bool,
        default=False,
        help="Whether to export models in fp16",
    )

    add_model_arguments(parser)

    return parser
class OnnxEncoder(nn.Module):
    """A wrapper for Zapformer and the encoder_proj from the joiner.

    It bundles the feature embedding, the streaming encoder and the
    joiner's encoder projection into one module whose ``forward`` processes
    a single chunk and threads all caches through a flat list of tensors.
    """

    def __init__(
        self, encoder: Zapformer, encoder_embed: nn.Module, encoder_proj: nn.Linear
    ):
        """
        Args:
          encoder:
            The Zapformer encoder. Only the first entry of its `chunk_size`
            and `left_context_frames` is used.
          encoder_embed:
            The front-end module; it must provide `streaming_forward` and
            `get_init_cache`.
          encoder_proj:
            The projection layer for the encoder from the joiner.
        """
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed
        self.encoder_proj = encoder_proj
        self.chunk_size = encoder.chunk_size[0]
        self.left_context_len = encoder.left_context_frames[0]

    def forward(
        self,
        x: torch.Tensor,
        states: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run one streaming chunk.

        Args:
          x:
            Input features of one chunk; its first dimension is the batch
            size and it is treated as holding `chunk_size * 2 + 7` frames.
          states:
            Cached tensors laid out as returned by `get_init_states`:
            the encoder caches, then the embed cache (`states[-2]`), then
            `processed_lens` (`states[-1]`).
        Returns:
          A tuple `(encoder_out, new_states)`; `encoder_out` is the
          projected encoder output and `new_states` has the same layout as
          `states`.
        """
        N = x.size(0)
        T = self.chunk_size * 2 + 7
        # NOTE(review): built from a Python list, so under tracing the length
        # presumably stays at the trace-time batch size -- confirm.
        x_lens = torch.tensor([T] * N, device=x.device)
        left_context_len = self.left_context_len

        embed_cache = states[-2]
        x, x_lens, new_embed_cache = self.encoder_embed.streaming_forward(
            x=x,
            x_lens=x_lens,
            cache=embed_cache,
        )
        assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)

        # No frame of the current chunk is masked.  The mask is created on
        # x's device so that it can be concatenated with `processed_mask`
        # below regardless of where the module runs.
        src_key_padding_mask = torch.zeros(
            N, self.chunk_size, dtype=torch.bool, device=x.device
        )

        # Mask the part of the left context that has not been filled yet.
        processed_mask = torch.arange(left_context_len, device=x.device).expand(
            x.size(0), left_context_len
        )
        processed_lens = states[-1]  # (batch,)
        processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
        new_processed_lens = processed_lens + x_lens
        src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)

        x = x.permute(1, 0, 2)
        encoder_caches = states[:-2]
        logging.info(f"len_encoder_caches={len(encoder_caches)}")
        (
            encoder_out,
            encoder_out_lens,
            new_encoder_caches,
        ) = self.encoder.streaming_forward(
            x=x,
            x_lens=x_lens,
            caches=encoder_caches,
            src_key_padding_mask=src_key_padding_mask,
        )
        encoder_out = encoder_out.permute(1, 0, 2)
        encoder_out = self.encoder_proj(encoder_out)

        new_states = new_encoder_caches + [
            new_embed_cache,
            new_processed_lens,
        ]

        return encoder_out, new_states

    def get_init_states(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> List[torch.Tensor]:
        """
        Returns a list of cached tensors of all encoder layers. For layer-i,
        states[i*9:(i+1)*9] is (cached_key, cached_value, cached_conv,
        cached_norm_stats, cached_norm_len, cached_attn_wm_sum,
        cached_attn_wm_num_frames, cached_conv_wm_sum, cached_conv_wm_num_frames).
        states[-2] is the cached left padding for ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs).
        states[-1] is processed_lens of shape (batch,).
        """
        states = self.encoder.get_init_caches(batch_size, device)

        embed_cache = self.encoder_embed.get_init_cache(batch_size, device)
        states.append(embed_cache)

        processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
        states.append(processed_lens)

        return states
def export_encoder_model_onnx(
    encoder_model: OnnxEncoder,
    encoder_filename: str,
    opset_version: int = 11,
    feature_dim: int = 80,
    dynamic_batch: bool = True,
) -> None:
    """Export the streaming encoder wrapper to ONNX and attach meta data.

    The exported model takes ``x`` followed by every cached tensor (nine per
    encoder layer, then ``embed_cache`` and ``processed_lens``) and returns
    ``encoder_out`` followed by the updated caches, named ``new_<input>``.

    Args:
      encoder_model:
        The wrapped encoder to export.
      encoder_filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
      feature_dim:
        Feature dimension of the dummy input used for tracing.
      dynamic_batch:
        True to mark the batch axis of every input/output as dynamic.
    """
    encoder = encoder_model.encoder

    # Make the plain `forward` of the encoder class point at its streaming
    # implementation for the export.
    encoder.__class__.forward = encoder.__class__.streaming_forward

    decode_chunk_len = encoder_model.chunk_size * 2
    T = decode_chunk_len + 7

    x = torch.rand(1, T, feature_dim, dtype=torch.float32)
    init_state = encoder_model.get_init_states()
    logging.info(f"len(init_state): {len(init_state)}")

    # Warm up angular freq bases for tracing
    left_context_len = encoder_model.left_context_len
    max_seq_len = left_context_len + encoder_model.chunk_size
    encoder.warmup_angular_freq_bases(
        seq_len=max_seq_len, left_context_len=left_context_len, device=x.device
    )

    inputs = {}
    input_names = ["x"]

    outputs = {}
    output_names = ["encoder_out"]

    # Count total number of layers across all encoder stacks
    total_layers = sum(encoder.num_encoder_layers)
    logging.info(f"total encoder layers: {total_layers}")

    def register(name: str, tensor: torch.Tensor, batch_axis: int) -> None:
        """Record one cache tensor as an ONNX input and `new_<name>` output."""
        logging.info(f"{name}.shape: {tensor.shape}")
        inputs[name] = {batch_axis: "N"}
        outputs[f"new_{name}"] = {batch_axis: "N"}
        input_names.append(name)
        output_names.append(f"new_{name}")

    # (name prefix, index of the batch axis) for the nine caches of a layer,
    # in the order in which they appear in the state list.
    layer_cache_layout = (
        ("cached_key", 1),  # (downsample_left, batch_size, key_dim)
        ("cached_value", 1),  # (downsample_left, batch_size, value_dim)
        ("cached_conv", 0),  # (batch_size, embed_dim, conv_left_pad)
        ("cached_norm_stats", 0),  # (batch_size,)
        ("cached_norm_len", 0),  # (batch_size,)
        ("cached_attn_wm_sum", 1),  # (1, batch_size, attn_value_dim)
        ("cached_attn_wm_num_frames", 0),  # (batch_size,)
        ("cached_conv_wm_sum", 1),  # (1, batch_size, embed_dim)
        ("cached_conv_wm_num_frames", 0),  # (batch_size,)
    )
    caches_per_layer = len(layer_cache_layout)

    def build_inputs_outputs(tensors, i):
        assert len(tensors) == caches_per_layer, len(tensors)
        for (prefix, batch_axis), tensor in zip(layer_cache_layout, tensors):
            register(f"{prefix}_{i}", tensor, batch_axis)

    num_encoder_layers = encoder.num_encoder_layers
    encoder_dims = encoder.encoder_dim
    conv_params = encoder.conv_params
    ds = encoder.downsampling_factor
    left_context_len_per_stack = [left_context_len // k for k in ds]
    query_head_dims = encoder.query_head_dim
    value_head_dims = encoder.value_head_dim
    num_heads = encoder.num_heads

    def joined(values) -> str:
        """Format a sequence as a comma-separated string."""
        return ",".join(map(str, values))

    meta_data = {
        "model_type": "zapformer",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "streaming zapformer",
        "decode_chunk_len": str(decode_chunk_len),
        "T": str(T),
        "num_encoder_layers": joined(num_encoder_layers),
        "encoder_dims": joined(encoder_dims),
        "conv_params": joined(conv_params),
        "left_context_len": joined(left_context_len_per_stack),
        "query_head_dims": joined(query_head_dims),
        "value_head_dims": joined(value_head_dims),
        "num_heads": joined(num_heads),
    }

    logging.info(f"meta_data: {meta_data}")

    # 9 tensors per layer; the last two states are not per-layer caches.
    for i in range(len(init_state[:-2]) // caches_per_layer):
        build_inputs_outputs(
            init_state[i * caches_per_layer : (i + 1) * caches_per_layer], i
        )

    # (batch_size, channels, left_pad, freq)
    register("embed_cache", init_state[-2], 0)

    # (batch_size,)
    register("processed_lens", init_state[-1], 0)

    logging.info(f"input_names: {input_names}")
    logging.info(f"output_names: {output_names}")

    if dynamic_batch:
        dynamic_axes = {
            "x": {0: "N"},
            "encoder_out": {0: "N"},
            **inputs,
            **outputs,
        }
    else:
        dynamic_axes = {}

    torch.onnx.export(
        encoder_model,
        (x, init_state),
        encoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )

    add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
    joiner_model: nn.Module,
    joiner_filename: str,
    opset_version: int = 11,
    dynamic_batch: bool = True,
) -> None:
    """Export the joiner wrapper to ONNX and record ``joiner_dim`` as meta data.

    The exported model has two inputs, ``encoder_out`` and ``decoder_out``,
    and one output, ``logit``.

    Args:
      joiner_model:
        The joiner wrapper; its `output_linear` weight determines the
        joiner dimension.
      joiner_filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
      dynamic_batch:
        True to mark axis 0 of all inputs and outputs as dynamic.
    """
    joiner_dim = joiner_model.output_linear.weight.shape[1]
    logging.info(f"joiner dim: {joiner_dim}")

    # Dummy projected inputs used only for tracing.
    projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
    projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)

    io_inputs = ["encoder_out", "decoder_out"]
    io_outputs = ["logit"]
    if dynamic_batch:
        batch_axes = {tensor_name: {0: "N"} for tensor_name in io_inputs + io_outputs}
    else:
        batch_axes = {}

    torch.onnx.export(
        joiner_model,
        (projected_encoder_out, projected_decoder_out),
        joiner_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=io_inputs,
        output_names=io_outputs,
        dynamic_axes=batch_axes,
    )

    add_meta_data(
        filename=joiner_filename,
        meta_data={"joiner_dim": str(joiner_dim)},
    )
def _require_checkpoints(filenames, expected: int, params) -> None:
    """Raise ValueError unless at least `expected` iteration checkpoints exist."""
    if len(filenames) == 0:
        raise ValueError(
            f"No checkpoints found for"
            f" --iter {params.iter}, --avg {params.avg}"
        )
    if len(filenames) < expected:
        raise ValueError(
            f"Not enough checkpoints ({len(filenames)}) found for"
            f" --iter {params.iter}, --avg {params.avg}"
        )


def _load_export_weights(model: nn.Module, params, device: torch.device) -> None:
    """Load the weights selected by --epoch/--iter/--avg into `model` in place."""
    if params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            _require_checkpoints(filenames, params.avg + 1, params)
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            ),
            strict=False,
        )
        return

    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg
        ]
        _require_checkpoints(filenames, params.avg, params)
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        return
    else:
        start = params.epoch - params.avg + 1
        filenames = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(start, params.epoch + 1)
            if i >= 1
        ]

    logging.info(f"averaging {filenames}")
    model.load_state_dict(average_checkpoints(filenames, device=device), strict=False)


@torch.no_grad()
def main():
    """Export the streaming encoder, decoder and joiner to ONNX.

    The files are written into --exp-dir as
    ``{encoder,decoder,joiner}-<suffix>.onnx`` where the suffix encodes the
    checkpoint, --avg, --chunk-size and --left-context-frames.  fp16 copies
    are written when --fp16 is set and int8 copies when
    --enable-int8-quantization is non-zero.
    """
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    token_table = k2.SymbolTable.from_file(params.tokens)
    # The blank symbol is looked up by its name in tokens.txt; an empty key
    # is not a token and would fail the lookup.
    params.blank_id = token_table["<blk>"]
    params.vocab_size = num_tokens(token_table) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    model.to(device)
    _load_export_weights(model, params, device)

    # The export itself is done on CPU.
    model.to("cpu")
    model.eval()

    encoder = OnnxEncoder(
        encoder=model.encoder,
        encoder_embed=model.encoder_embed,
        encoder_proj=model.joiner.encoder_proj,
    )

    decoder = OnnxDecoder(
        decoder=model.decoder,
        decoder_proj=model.joiner.decoder_proj,
    )

    joiner = OnnxJoiner(output_linear=model.joiner.output_linear)

    encoder_num_param = sum(p.numel() for p in encoder.parameters())
    decoder_num_param = sum(p.numel() for p in decoder.parameters())
    joiner_num_param = sum(p.numel() for p in joiner.parameters())
    total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
    logging.info(f"encoder parameters: {encoder_num_param}")
    logging.info(f"decoder parameters: {decoder_num_param}")
    logging.info(f"joiner parameters: {joiner_num_param}")
    logging.info(f"total parameters: {total_num_param}")

    if params.iter > 0:
        suffix = f"iter-{params.iter}"
    else:
        suffix = f"epoch-{params.epoch}"

    suffix += f"-avg-{params.avg}"
    suffix += f"-chunk-{params.chunk_size}"
    suffix += f"-left-{params.left_context_frames}"

    opset_version = 13
    dynamic_batch = params.dynamic_batch == 1

    logging.info("Exporting encoder")
    encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
    export_encoder_model_onnx(
        encoder,
        str(encoder_filename),
        opset_version=opset_version,
        feature_dim=params.feature_dim,
        dynamic_batch=dynamic_batch,
    )
    logging.info(f"Exported encoder to {encoder_filename}")

    logging.info("Exporting decoder")
    decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
    export_decoder_model_onnx(
        decoder,
        str(decoder_filename),
        opset_version=opset_version,
        dynamic_batch=dynamic_batch,
    )
    logging.info(f"Exported decoder to {decoder_filename}")

    logging.info("Exporting joiner")
    joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
    export_joiner_model_onnx(
        joiner,
        str(joiner_filename),
        opset_version=opset_version,
        dynamic_batch=dynamic_batch,
    )
    logging.info(f"Exported joiner to {joiner_filename}")

    # (component name, fp32 file, op types quantized for the int8 variant)
    exported = (
        ("encoder", encoder_filename, ["MatMul"]),
        ("decoder", decoder_filename, ["MatMul", "Gather"]),
        ("joiner", joiner_filename, ["MatMul"]),
    )

    if params.fp16:
        logging.info("Generate fp16 models")
        for component, fp32_filename, _ in exported:
            fp16_filename = params.exp_dir / f"{component}-{suffix}.fp16.onnx"
            export_onnx_fp16(fp32_filename, fp16_filename)

    # Generate int8 quantization models
    if params.enable_int8_quantization:
        logging.info("Generate int8 quantization models")
        for component, fp32_filename, op_types in exported:
            int8_filename = params.exp_dir / f"{component}-{suffix}.int8.onnx"
            quantize_dynamic(
                model_input=fp32_filename,
                model_output=int8_filename,
                op_types_to_quantize=op_types,
                weight_type=QuantType.QInt8,
            )
+# +# Copyright 2021-2026 Xiaomi Corporation (Author: Fangjun Kuang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script exports a transducer model from PyTorch to ONNX. + +Usage: + +cd egs/librispeech/ASR + +./zapformer/export-onnx.py \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 13 \ + --avg 2 \ + --exp-dir zapformer/exp \ + --fp16 True + +It will generate the following 3 files inside the directory given by --exp-dir: + + - encoder-epoch-13-avg-2.onnx + - decoder-epoch-13-avg-2.onnx + - joiner-epoch-13-avg-2.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models.
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser of the non-streaming ONNX export script.

    Besides the options declared here, model-architecture options are added
    by ``add_model_arguments``.

    Returns:
      The configured argument parser.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # NOTE(review): the help states that epochs count from 1, matching the
    # recipe's `--start-epoch 1`; confirm against the checkpoint naming.
    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would export the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zapformer/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/lang_bpe_500/tokens.txt",
        help="Path to the tokens.txt",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--fp16",
        type=str2bool,
        default=False,
        help="Whether to export models in fp16",
    )

    add_model_arguments(parser)

    return parser
class OnnxEncoder(nn.Module):
    """A wrapper for Zapformer and the encoder_proj from the joiner"""

    def __init__(
        self, encoder: Zapformer, encoder_embed: nn.Module, encoder_proj: nn.Linear
    ):
        """
        Args:
          encoder:
            A Zapformer encoder.
          encoder_embed:
            The front-end module applied to the features before the encoder.
          encoder_proj:
            The projection layer for encoder from the joiner.
        """
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed
        self.encoder_proj = encoder_proj

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed, encode and project a padded batch of features.

        Please see the help information of Zapformer.forward

        Args:
          x:
            A 3-D tensor of shape (N, T, C)
          x_lens:
            A 1-D tensor of shape (N,). Its dtype is torch.int64
        Returns:
          Return a tuple containing:
            - encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
            - encoder_out_lens, A 1-D tensor of shape (N,)
        """
        embedded, embedded_lens = self.encoder_embed(x, x_lens)
        padding_mask = make_pad_mask(embedded_lens, embedded.shape[1])

        # The encoder works on time-major input: (N, T, C) -> (T, N, C)
        time_major = embedded.permute(1, 0, 2)
        encoded, encoder_out_lens = self.encoder(
            time_major, embedded_lens, padding_mask
        )

        # Back to batch-major before projecting to the joiner dimension.
        encoder_out = self.encoder_proj(encoded.permute(1, 0, 2))
        # Now encoder_out is of shape (N, T, joiner_dim)

        return encoder_out, encoder_out_lens
def export_encoder_model_onnx(
    encoder_model: OnnxEncoder,
    encoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the given encoder model to ONNX format.
    The exported model has two inputs:

        - x, a tensor of shape (N, T, C); dtype is torch.float32
        - x_lens, a tensor of shape (N,); dtype is torch.int64

    and it has two outputs:

        - encoder_out, a tensor of shape (N, T', joiner_dim)
        - encoder_out_lens, a tensor of shape (N,)

    Args:
      encoder_model:
        The input encoder model
      encoder_filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.

    Raises:
      Exception:
        Whatever ``torch.onnx.export`` raises is logged (with traceback) and
        re-raised unchanged.
    """
    import inspect

    # Use a large dummy input (3000 fbank frames ≈ 30s audio) so that the
    # relative position basis baked into the ONNX graph is large enough for
    # any input encountered at inference time. The basis becomes a constant
    # in the ONNX graph, and the GatherElements indices at runtime must not
    # exceed its size.
    T_max = 3000
    x = torch.zeros(1, T_max, 80, dtype=torch.float32)
    x_lens = torch.tensor([T_max], dtype=torch.int64)

    # Pre-compute angular frequency bases so tracing uses cached values
    # instead of recomputing with varying constants per layer.
    # After Conv2dSubsampling, T_max → ~(T_max-7)//2 ≈ 1496 frames.
    # Each encoder stack further downsamples, so the max seq_len seen by
    # any stack is ~1496. We use 1500 to be safe.
    encoder_model.encoder.warmup_angular_freq_bases(
        seq_len=1500, left_context_len=0, device=x.device
    )

    # `enable_onnx_checker` only exists in old PyTorch releases. Passing it
    # to a version whose export() does not declare it raises TypeError, so
    # forward it only when the installed signature still has it.
    extra_export_kwargs = {}
    if "enable_onnx_checker" in inspect.signature(torch.onnx.export).parameters:
        extra_export_kwargs["enable_onnx_checker"] = False

    try:
        torch.onnx.export(
            encoder_model,
            (x, x_lens),
            encoder_filename,
            # Same as the decoder/joiner exporters: do not dump the graph.
            verbose=False,
            opset_version=opset_version,
            input_names=["x", "x_lens"],
            output_names=["encoder_out", "encoder_out_lens"],
            dynamic_axes={
                "x": {0: "N", 1: "T"},
                "x_lens": {0: "N"},
                "encoder_out": {0: "N", 1: "T"},
                "encoder_out_lens": {0: "N"},
            },
            **extra_export_kwargs,
        )
    except Exception:
        # logging.exception records the message together with the traceback.
        logging.exception("Failed to export the encoder model to ONNX")
        raise

    meta_data = {
        "model_type": "zapformer",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "non-streaming zapformer",
    }
    logging.info(f"meta_data: {meta_data}")

    add_meta_data(filename=encoder_filename, meta_data=meta_data)
def _load_averaged_weights(model: nn.Module, params, device: torch.device) -> None:
    """Load into `model` the weights selected by --epoch/--iter/--avg.

    With --use-averaged-model the two boundary checkpoints are combined by
    `average_checkpoints_with_averaged_model`; otherwise the selected
    checkpoints are averaged by `average_checkpoints` (or a single one is
    loaded when --avg is 1 and --iter is not used).

    Raises:
      ValueError:
        If the requested checkpoints cannot be selected.
    """
    if params.iter > 0:
        # The averaged-model mode needs one extra checkpoint: the excluded
        # start of the range.
        num_needed = params.avg + 1 if params.use_averaged_model else params.avg
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            :num_needed
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        if len(filenames) < num_needed:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )

        if not params.use_averaged_model:
            logging.info(f"averaging {filenames}")
            model.load_state_dict(average_checkpoints(filenames, device=device))
            return

        filename_start = filenames[-1]
        filename_end = filenames[0]
        logging.info(
            "Calculating the averaged model over iteration checkpoints"
            f" from {filename_start} (excluded) to {filename_end}"
        )
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            )
        )
        return

    if not params.use_averaged_model:
        if params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
            return
        first_epoch = max(params.epoch - params.avg + 1, 1)
        filenames = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(first_epoch, params.epoch + 1)
        ]
        logging.info(f"averaging {filenames}")
        model.load_state_dict(average_checkpoints(filenames, device=device))
        return

    if params.avg <= 0:
        raise ValueError(f"--avg must be positive, given: {params.avg}")
    start = params.epoch - params.avg
    if start < 1:
        raise ValueError(
            f"--epoch minus --avg must be at least 1, given: {start}"
        )
    filename_start = f"{params.exp_dir}/epoch-{start}.pt"
    filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
    logging.info(
        f"Calculating the averaged model over epoch range from "
        f"{start} (excluded) to {params.epoch}"
    )
    model.load_state_dict(
        average_checkpoints_with_averaged_model(
            filename_start=filename_start,
            filename_end=filename_end,
            device=device,
        )
    )


@torch.no_grad()
def main():
    """Export the encoder, decoder and joiner of a trained model to ONNX.

    Writes ``{encoder,decoder,joiner}-<suffix>.onnx`` into --exp-dir, then
    the ``.fp16.onnx`` variants when --fp16 is set, and always the
    ``.int8.onnx`` dynamically quantized variants.
    """
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    token_table = k2.SymbolTable.from_file(params.tokens)
    # The blank symbol is named "<blk>" in tokens.txt.
    params.blank_id = token_table["<blk>"]
    params.vocab_size = num_tokens(token_table) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    # Moving the model once is enough; the averaged state dicts are loaded
    # onto the same device.
    model.to(device)
    _load_averaged_weights(model, params, device)

    # The export itself is done on CPU.
    model.to("cpu")
    model.eval()

    encoder = OnnxEncoder(
        encoder=model.encoder,
        encoder_embed=model.encoder_embed,
        encoder_proj=model.joiner.encoder_proj,
    )

    decoder = OnnxDecoder(
        decoder=model.decoder,
        decoder_proj=model.joiner.decoder_proj,
    )

    joiner = OnnxJoiner(output_linear=model.joiner.output_linear)

    encoder_num_param = sum(p.numel() for p in encoder.parameters())
    decoder_num_param = sum(p.numel() for p in decoder.parameters())
    joiner_num_param = sum(p.numel() for p in joiner.parameters())
    total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
    logging.info(f"encoder parameters: {encoder_num_param}")
    logging.info(f"decoder parameters: {decoder_num_param}")
    logging.info(f"joiner parameters: {joiner_num_param}")
    logging.info(f"total parameters: {total_num_param}")

    if params.iter > 0:
        suffix = f"iter-{params.iter}"
    else:
        suffix = f"epoch-{params.epoch}"

    suffix += f"-avg-{params.avg}"

    opset_version = 13

    logging.info("Exporting encoder")
    encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
    export_encoder_model_onnx(
        encoder,
        encoder_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported encoder to {encoder_filename}")

    logging.info("Exporting decoder")
    decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
    export_decoder_model_onnx(
        decoder,
        decoder_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported decoder to {decoder_filename}")

    logging.info("Exporting joiner")
    joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
    export_joiner_model_onnx(
        joiner,
        joiner_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported joiner to {joiner_filename}")

    fp32_filenames = {
        "encoder": encoder_filename,
        "decoder": decoder_filename,
        "joiner": joiner_filename,
    }

    if params.fp16:
        logging.info("Generate fp16 models")
        for name, fp32_filename in fp32_filenames.items():
            export_onnx_fp16(
                fp32_filename, params.exp_dir / f"{name}-{suffix}.fp16.onnx"
            )

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

    logging.info("Generate int8 quantization models")

    # Only the decoder additionally quantizes Gather.
    op_types_to_quantize = {
        "encoder": ["MatMul"],
        "decoder": ["MatMul", "Gather"],
        "joiner": ["MatMul"],
    }
    for name, fp32_filename in fp32_filenames.items():
        quantize_dynamic(
            model_input=fp32_filename,
            model_output=params.exp_dir / f"{name}-{suffix}.int8.onnx",
            op_types_to_quantize=op_types_to_quantize[name],
            weight_type=QuantType.QInt8,
        )
100755 index 0000000000..b1bda25d10 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/export.py @@ -0,0 +1,522 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2026 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is an example for the librispeech dataset. If you are using a different +dataset, you should change the argument values according to your dataset. + +(1) Export to torchscript model using torch.jit.script() + +- For non-streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. + +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +- For streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. 
+You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. + +Check ./jit_pretrained_streaming.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +- For non-streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +- For non-streaming model: + +To use the generated file with `zapformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./zapformer/decode.py \ + --exp-dir ./zapformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +- For streaming model: + +To use the generated file with `zapformer/decode.py` and `zapformer/streaming_decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + + # simulated streaming decoding + ./zapformer/decode.py \ + --exp-dir ./zapformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + + # chunk-wise streaming decoding + ./zapformer/streaming_decode.py \ + --exp-dir ./zapformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. 
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser of the export script.

    The options defined here select the checkpoint(s) to export (--epoch,
    --iter, --avg, --use-averaged-model), the experiment directory, the
    tokens file, whether to export with torch.jit.script (--jit) and the
    decoder context size. Model options are registered afterwards by
    ``add_model_arguments``.

    Returns:
      The configured argument parser.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=9,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zapformer/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/lang_bpe_500/tokens.txt",
        help="Path to the tokens.txt",
    )

    parser.add_argument(
        "--jit",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.script.
        It will generate a file named jit_script.pt.
        Check ./jit_pretrained.py for how to use it.
        """,
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    add_model_arguments(parser)

    return parser
+ super().__init__() + assert len(encoder.chunk_size) == 1, encoder.chunk_size + assert len(encoder.left_context_frames) == 1, encoder.left_context_frames + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """Streaming forward for encoder_embed and encoder. + + Args: + features: (N, T, C) + feature_lengths: (N,) + states: a list of Tensors. + states[:-2] are the encoder caches (9 tensors per layer). + states[-2] is the cached left padding for ConvNeXt module. + states[-1] is processed_lens of shape (batch,). + + Returns encoder outputs, output lengths, and updated states. + """ + chunk_size = self.chunk_size + left_context_len = self.left_context_len + + cached_embed = states[-2] + x, x_lens, new_cached_embed = self.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lengths, + cache=cached_embed, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_caches = states[:-2] + + ( + encoder_out, + encoder_out_lens, + new_encoder_caches, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + caches=encoder_caches, + src_key_padding_mask=src_key_padding_mask, + ) + 
def _restore_model_weights(model: nn.Module, params, device: torch.device) -> None:
    """Load into `model` the weights selected by --epoch/--iter/--avg.

    With --use-averaged-model the two boundary checkpoints are combined by
    `average_checkpoints_with_averaged_model`; otherwise the selected
    checkpoints are averaged by `average_checkpoints` (or a single one is
    loaded when --avg is 1 and --iter is not used).

    Raises:
      ValueError:
        If the requested checkpoints cannot be selected.
    """
    if params.iter > 0:
        # The averaged-model mode needs one extra checkpoint: the excluded
        # start of the range.
        num_needed = params.avg + 1 if params.use_averaged_model else params.avg
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            :num_needed
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        if len(filenames) < num_needed:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )

        if not params.use_averaged_model:
            logging.info(f"averaging {filenames}")
            model.load_state_dict(average_checkpoints(filenames, device=device))
            return

        filename_start = filenames[-1]
        filename_end = filenames[0]
        logging.info(
            "Calculating the averaged model over iteration checkpoints"
            f" from {filename_start} (excluded) to {filename_end}"
        )
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            )
        )
        return

    if not params.use_averaged_model:
        if params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
            return
        first_epoch = max(params.epoch - params.avg + 1, 1)
        filenames = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(first_epoch, params.epoch + 1)
        ]
        logging.info(f"averaging {filenames}")
        model.load_state_dict(average_checkpoints(filenames, device=device))
        return

    if params.avg <= 0:
        raise ValueError(f"--avg must be positive, given: {params.avg}")
    start = params.epoch - params.avg
    if start < 1:
        raise ValueError(
            f"--epoch minus --avg must be at least 1, given: {start}"
        )
    filename_start = f"{params.exp_dir}/epoch-{start}.pt"
    filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
    logging.info(
        f"Calculating the averaged model over epoch range from "
        f"{start} (excluded) to {params.epoch}"
    )
    model.load_state_dict(
        average_checkpoints_with_averaged_model(
            filename_start=filename_start,
            filename_end=filename_end,
            device=device,
        )
    )


@torch.no_grad()
def main():
    """Export a trained model either as TorchScript or as a state dict.

    With --jit the model is scripted and saved as ``jit_script.pt`` (or
    ``jit_script_chunk_<C>_left_<L>.pt`` when --causal is set); otherwise
    ``pretrained.pt`` containing ``{"model": state_dict}`` is written. Both
    files go to --exp-dir.
    """
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")

    logging.info(f"device: {device}")

    token_table = k2.SymbolTable.from_file(params.tokens)
    # Special symbols as they are named in tokens.txt.
    params.blank_id = token_table["<blk>"]
    params.sos_id = params.eos_id = token_table["<sos/eos>"]
    params.vocab_size = num_tokens(token_table) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    _restore_model_weights(model, params, device)

    model.eval()

    if params.jit:
        # We won't use the forward() method of the model in C++, so just ignore
        # it here.
        # Otherwise, one of its arguments is a ragged tensor and is not
        # torch scriptable.
        model.__class__.forward = torch.jit.ignore(model.__class__.forward)

        # Wrap encoder and encoder_embed as a module
        if params.causal:
            model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
            chunk_size = model.encoder.chunk_size
            left_context_len = model.encoder.left_context_len
            basename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
        else:
            model.encoder = EncoderModel(model.encoder, model.encoder_embed)
            basename = "jit_script.pt"

        logging.info("Using torch.jit.script")
        model = torch.jit.script(model)
        filename = params.exp_dir / basename
        model.save(str(filename))
        # Log the full path, as the state-dict branch does.
        logging.info(f"Saved to {filename}")
    else:
        logging.info("Not using torchscript. Export model.state_dict()")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = params.exp_dir / "pretrained.pt"
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser of the TorchScript decoding script.

    Both --nn-model-filename and --tokens are required: the script loads
    the model and the token table unconditionally, so a missing --tokens is
    reported at argument-parsing time instead of after decoding.

    Returns:
      The configured argument parser.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--nn-model-filename",
        type=str,
        required=True,
        help="Path to the torchscript model, e.g., jit_script.pt "
        "generated by ./export.py with --jit 1",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="""Path to tokens.txt.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    return parser
def greedy_search(
    model: torch.jit.ScriptModule,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
    Args:
      model:
        The transducer model. Its ``decoder`` must expose ``blank_id`` and
        ``context_size`` and is called as ``decoder(y, need_pad=...)``;
        its ``joiner`` is called as ``joiner(encoder_out, decoder_out)``.
      encoder_out:
        A 3-D tensor of shape (N, T, C)
      encoder_out_lens:
        A 1-D tensor of shape (N,). Every length must be positive.
    Returns:
      Return the decoded results for each utterance, in the order of the
      input batch.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert encoder_out.size(0) >= 1, encoder_out.size(0)
    # Checked before packing: pack_padded_sequence itself rejects
    # non-positive lengths, which would make a later check unreachable.
    assert torch.all(encoder_out_lens > 0), encoder_out_lens

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    device = encoder_out.device
    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert N == batch_size_list[0], (N, batch_size_list)

    # hyps follows the length-sorted order used by the packed sequence.
    hyps = [[blank_id] * context_size for _ in range(N)]

    def run_decoder(contexts: List[List[int]]) -> torch.Tensor:
        # contexts: one list of context_size token ids per active utterance
        decoder_input = torch.tensor(
            contexts,
            device=device,
            dtype=torch.int64,
        )  # (len(contexts), context_size)
        return model.decoder(
            decoder_input,
            need_pad=torch.tensor([False]),
        ).squeeze(1)

    decoder_out = run_decoder(hyps)

    offset = 0
    for batch_size in batch_size_list:
        # Frames of the utterances that are still active at this time step.
        current_encoder_out = packed_encoder_out.data[offset : offset + batch_size]
        # current_encoder_out's shape: (batch_size, encoder_out_dim)
        offset += batch_size

        decoder_out = decoder_out[:batch_size]

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
        )
        # logits'shape (batch_size, vocab_size)
        assert logits.ndim == 2, logits.shape

        emitted = False
        for i, v in enumerate(logits.argmax(dim=1).tolist()):
            if v != blank_id:
                hyps[i].append(v)
                emitted = True
        if emitted:
            # update decoder output
            decoder_out = run_decoder([h[-context_size:] for h in hyps[:batch_size]])

    # Strip the initial blank context and restore the input order.
    sorted_ans = [h[context_size:] for h in hyps]
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    return [sorted_ans[j] for j in unsorted_indices]
+ model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + + s = "\n" + + token_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + for filename, hyp in zip(args.sound_files, hyps): + words = token_ids_to_words(hyp) + s += f"{filename}:\n{words}\n" + + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zapformer/jit_pretrained_ctc.py new file mode 100755 index 0000000000..1430b97109 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/jit_pretrained_ctc.py @@ -0,0 +1,437 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. 
+You can generate the checkpoint with the following command: + +- For non-streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +- For streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +Usage of this script: + +(1) ctc-decoding +./zapformer/jit_pretrained_ctc.py \ + --model-filename ./zapformer/exp/jit_script.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./zapformer/jit_pretrained_ctc.py \ + --model-filename ./zapformer/exp/jit_script.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) nbest-rescoring +./zapformer/jit_pretrained_ctc.py \ + --model-filename ./zapformer/exp/jit_script.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) whole-lattice-rescoring +./zapformer/jit_pretrained_ctc.py \ + --model-filename ./zapformer/exp/jit_script.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from ctc_decode import get_decoding_params +from export import num_tokens +from torch.nn.utils.rnn import pad_sequence +from train import get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + 
def get_parser():
    """Build the command-line parser for the CTC decoding script.

    Returns:
      An ``argparse.ArgumentParser`` with the model path, the per-method
      resources (``--tokens``, ``--HLG``, ``--words-file``, ``--G``), the
      decoding ``--method`` and its tuning knobs, the expected
      ``--sample-rate`` and one or more positional ``sound_files``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-filename",
        type=str,
        required=True,
        help="Path to the torchscript model.",
    )

    parser.add_argument(
        "--words-file",
        type=str,
        help="""Path to words.txt.
        Used only when method is not ctc-decoding.
        """,
    )

    parser.add_argument(
        "--HLG",
        type=str,
        help="""Path to HLG.pt.
        Used only when method is not ctc-decoding.
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="""Path to tokens.txt.
        Used only when method is ctc-decoding.
        """,
    )

    parser.add_argument(
        "--method",
        type=str,
        default="1best",
        help="""Decoding method.
        Possible values are:
        (0) ctc-decoding - Use CTC decoding. It uses a token table,
            i.e., lang_dir/token.txt, to convert
            word pieces to words. It needs neither a lexicon
            nor an n-gram LM.
        (1) 1best - Use the best path as decoding output. Only
            the transformer encoder output is used for decoding.
            We call it HLG decoding.
        (2) nbest-rescoring. Extract n paths from the decoding lattice,
            rescore them with an LM, the path with
            the highest score is the decoding result.
            We call it HLG decoding + nbest n-gram LM rescoring.
        (3) whole-lattice-rescoring - Use an LM to rescore the
            decoding lattice and then use 1best to decode the
            rescored lattice.
            We call it HLG decoding + whole-lattice n-gram LM rescoring.
        """,
    )

    parser.add_argument(
        "--G",
        type=str,
        help="""An LM for rescoring.
        Used only when method is
        whole-lattice-rescoring or nbest-rescoring.
        It's usually a 4-gram LM.
        """,
    )

    # The help used to say "attention-decoder", a method this script does
    # not offer; the value is the n-best list size for nbest-rescoring.
    parser.add_argument(
        "--num-paths",
        type=int,
        default=100,
        help="""
        Used only when method is nbest-rescoring.
        It specifies the size of n-best list.""",
    )

    parser.add_argument(
        "--ngram-lm-scale",
        type=float,
        default=1.3,
        help="""
        Used only when method is whole-lattice-rescoring and nbest-rescoring.
        It specifies the scale for n-gram LM scores.
        (Note: You need to tune it on a dataset.)
        """,
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=1.0,
        help="""
        Used only when method is nbest-rescoring.
        It specifies the scale for lattice.scores when
        extracting n-best lists. A smaller value results in
        more unique number of paths with the risk of missing
        the best path.
        """,
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    # The help used to claim 16kHz was mandatory, which contradicted the
    # configurable --sample-rate option above.
    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to match --sample-rate.",
    )

    return parser
@torch.no_grad()
def main():
    """Decode sound files with the CTC head of an exported TorchScript model.

    The decoding strategy is chosen by ``--method``:

    - ``ctc-decoding`` composes the CTC output with a CTC topology and maps
      the best path to word pieces through ``--tokens``;
    - ``1best``, ``nbest-rescoring`` and ``whole-lattice-rescoring`` decode
      with ``--HLG`` (the rescoring variants additionally use ``--G``) and
      map the best path to words through ``--words-file``.

    The token table is loaded only when ``--tokens`` is supplied, and a
    clean command-line error is reported when ``ctc-decoding`` is requested
    without it. Previously the table was loaded unconditionally, so the HLG
    based methods crashed when ``--tokens`` was omitted.

    Raises:
      ValueError: if ``--method`` is not one of the four supported values.
    """
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    # add decoding params
    params.update(get_decoding_params())
    params.update(vars(args))

    token_table = None
    if params.tokens is not None:
        token_table = k2.SymbolTable.from_file(params.tokens)
        params.vocab_size = num_tokens(token_table) + 1
    elif params.method == "ctc-decoding":
        parser.error("--tokens is required when --method is ctc-decoding")

    logging.info(f"{params}")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    model = torch.jit.load(args.model_filename)
    model.to(device)
    model.eval()

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim
    opts.mel_opts.high_freq = -400

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    # Pad with log(1e-10) so that padded frames look like silence.
    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    feature_lengths = torch.tensor(feature_lengths, device=device)

    encoder_out, encoder_out_lens = model.encoder(features, feature_lengths)
    ctc_output = model.ctc_output(encoder_out)  # (N, T, C)

    batch_size = ctc_output.shape[0]
    # NOTE(review): the segment lengths are derived from
    # feature_lengths // subsampling_factor rather than from
    # encoder_out_lens -- confirm this matches the encoder's real output
    # length for this model.
    supervision_segments = torch.tensor(
        [
            [i, 0, feature_lengths[i].item() // params.subsampling_factor]
            for i in range(batch_size)
        ],
        dtype=torch.int32,
    )

    if params.method == "ctc-decoding":
        logging.info("Use CTC decoding")
        max_token_id = params.vocab_size - 1

        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )

        lattice = get_lattice(
            nnet_output=ctc_output,
            decoding_graph=H,
            supervision_segments=supervision_segments,
            search_beam=params.search_beam,
            output_beam=params.output_beam,
            min_active_states=params.min_active_states,
            max_active_states=params.max_active_states,
            subsampling_factor=params.subsampling_factor,
        )

        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        token_ids = get_texts(best_path)
        hyps = [[token_table[i] for i in ids] for ids in token_ids]
    elif params.method in [
        "1best",
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ]:
        logging.info(f"Loading HLG from {params.HLG}")
        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
        HLG = HLG.to(device)
        if not hasattr(HLG, "lm_scores"):
            # For whole-lattice-rescoring and attention-decoder
            HLG.lm_scores = HLG.scores.clone()

        if params.method in [
            "nbest-rescoring",
            "whole-lattice-rescoring",
        ]:
            logging.info(f"Loading G from {params.G}")
            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
            G = G.to(device)
            if params.method == "whole-lattice-rescoring":
                # Add epsilon self-loops to G as we will compose
                # it with the whole lattice later
                G = k2.add_epsilon_self_loops(G)
                G = k2.arc_sort(G)

            # G.lm_scores is used to replace HLG.lm_scores during
            # LM rescoring.
            G.lm_scores = G.scores.clone()

        lattice = get_lattice(
            nnet_output=ctc_output,
            decoding_graph=HLG,
            supervision_segments=supervision_segments,
            search_beam=params.search_beam,
            output_beam=params.output_beam,
            min_active_states=params.min_active_states,
            max_active_states=params.max_active_states,
            subsampling_factor=params.subsampling_factor,
        )

        # Exactly one of the three branches below runs.
        if params.method == "1best":
            logging.info("Use HLG decoding")
            best_path = one_best_decoding(
                lattice=lattice, use_double_scores=params.use_double_scores
            )
        elif params.method == "nbest-rescoring":
            logging.info("Use HLG decoding + LM rescoring")
            best_path_dict = rescore_with_n_best_list(
                lattice=lattice,
                G=G,
                num_paths=params.num_paths,
                lm_scale_list=[params.ngram_lm_scale],
                nbest_scale=params.nbest_scale,
            )
            best_path = next(iter(best_path_dict.values()))
        elif params.method == "whole-lattice-rescoring":
            logging.info("Use HLG decoding + LM rescoring")
            best_path_dict = rescore_with_whole_lattice(
                lattice=lattice,
                G_with_epsilon_loops=G,
                lm_scale_list=[params.ngram_lm_scale],
            )
            best_path = next(iter(best_path_dict.values()))

        hyps = get_texts(best_path)
        word_sym_table = k2.SymbolTable.from_file(params.words_file)
        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
    else:
        raise ValueError(f"Unsupported decoding method: {params.method}")

    s = "\n"
    if params.method == "ctc-decoding":
        # Word pieces are concatenated directly; the word-boundary marker
        # becomes a space.
        for filename, hyp in zip(params.sound_files, hyps):
            words = "".join(hyp)
            words = words.replace("▁", " ").strip()
            s += f"{filename}:\n{words}\n\n"
    elif params.method in [
        "1best",
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ]:
        for filename, hyp in zip(params.sound_files, hyps):
            words = " ".join(hyp)
            words = words.replace("▁", " ").strip()
            s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
def get_parser():
    """Build the command-line parser for the streaming decoding script.

    Returns:
      An ``argparse.ArgumentParser`` accepting the TorchScript model path,
      the token table, the expected sample rate and exactly one positional
      sound file.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--nn-model-filename",
        type=str,
        required=True,
        help="Path to the torchscript model jit_script.pt",
    )

    # main() always loads the token table, so a missing --tokens is
    # reported by argparse instead of failing later on a None path.
    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="""Path to tokens.txt.""",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    # A single file is accepted (no nargs), and its sample rate is checked
    # against --sample-rate rather than a fixed 16kHz.
    parser.add_argument(
        "sound_file",
        type=str,
        help="The input sound file to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to match --sample-rate.",
    )

    return parser
def greedy_search(
    decoder: torch.jit.ScriptModule,
    joiner: torch.jit.ScriptModule,
    encoder_out: torch.Tensor,
    decoder_out: Optional[torch.Tensor] = None,
    hyp: Optional[List[int]] = None,
    device: torch.device = torch.device("cpu"),
):
    """Greedily decode one chunk of encoder frames, resumable across chunks.

    Args:
      decoder:
        Prediction network; must expose ``context_size`` and ``blank_id``.
      joiner:
        Joiner network, called with one encoder frame and the current
        decoder output.
      encoder_out:
        A 2-D tensor whose first axis is time.
      decoder_out:
        Decoder output carried over from the previous chunk, or None for
        the first chunk.
      hyp:
        Token ids decoded so far (including the leading blank context), or
        None for the first chunk. It is extended in place.
      device:
        Device on which decoder inputs are created.
    Returns:
      Return a tuple ``(hyp, decoder_out)`` to feed into the next call.
    """
    assert encoder_out.ndim == 2
    context_size = decoder.context_size
    blank_id = decoder.blank_id

    def run_decoder(tokens: List[int]) -> torch.Tensor:
        # The decoder input has shape (1, context_size).
        batch = torch.tensor(tokens, dtype=torch.int32, device=device).unsqueeze(0)
        return decoder(batch, torch.tensor([False])).squeeze(1)

    if decoder_out is not None:
        # Continuing a stream: both pieces of state must be present.
        assert decoder_out.ndim == 2
        assert hyp is not None, hyp
    else:
        # First chunk: start from a context made only of blanks.
        assert hyp is None, hyp
        hyp = [blank_id] * context_size
        decoder_out = run_decoder(hyp)

    for t in range(encoder_out.size(0)):
        frame = encoder_out.narrow(0, t, 1)
        logits = joiner(frame, decoder_out).squeeze(0)
        token = logits.argmax(dim=0).item()
        if token == blank_id:
            continue

        # A real token was emitted: extend the hypothesis and refresh the
        # decoder state from the last `context_size` tokens.
        hyp.append(token)
        decoder_out = run_decoder(hyp[-context_size:])

    return hyp, decoder_out
@torch.no_grad()
def main():
    """Simulate streaming decoding of one sound file with a TorchScript model.

    Fbank features are computed for the whole file at once and then fed to
    the encoder in overlapping windows of ``T = chunk_length + 7`` frames,
    advancing by ``chunk_length`` frames each step, while encoder states and
    the greedy-search state are carried from one window to the next.

    Frames left over after the last full window are not decoded. If the
    file holds fewer than ``T`` frames, no window is decoded at all; this
    case now logs a warning and yields an empty transcript instead of
    failing while slicing a hypothesis that was never created.
    """
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    model = torch.jit.load(args.nn_model_filename)
    model.eval()
    model.to(device)

    encoder = model.encoder
    decoder = model.decoder
    joiner = model.joiner

    token_table = k2.SymbolTable.from_file(args.tokens)
    context_size = decoder.context_size

    logging.info("Computing fbank features")
    logging.info(f"Reading sound files: {args.sound_file}")
    wave_samples = read_sound_files(
        filenames=[args.sound_file],
        expected_sample_rate=args.sample_rate,
    )[0]
    logging.info(wave_samples.shape)

    # Compute all fbank features at once
    features = compute_fbank(wave_samples, args.sample_rate)
    logging.info(f"features shape: {features.shape}")

    logging.info("Decoding started")

    chunk_length = encoder.chunk_size * 2
    T = chunk_length + 7  # Conv2dSubsampling pad_length is a fixed constant

    logging.info(f"chunk_length: {chunk_length}")
    logging.info(f"T: {T}")

    states = encoder.get_init_states(device=device)

    num_frames = features.size(0)
    num_processed_frames = 0

    # Greedy-search state carried across windows; both stay None until the
    # first window has been decoded.
    hyp = None
    decoder_out = None

    while num_processed_frames + T <= num_frames:
        window = features[num_processed_frames : num_processed_frames + T]
        frames = window.to(device).unsqueeze(0)
        x_lens = torch.tensor([T], dtype=torch.int32, device=device)
        encoder_out, out_lens, states = encoder(
            features=frames,
            feature_lengths=x_lens,
            states=states,
        )
        # Consecutive windows overlap by T - chunk_length frames.
        num_processed_frames += chunk_length

        hyp, decoder_out = greedy_search(
            decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device
        )

    if hyp is None:
        logging.warning(
            f"Only {num_frames} feature frames available, fewer than the "
            f"{T} frames needed for one chunk; nothing was decoded."
        )
        token_ids = []
    else:
        # Drop the leading blank context that seeds the hypothesis.
        token_ids = hyp[context_size:]

    text = "".join(token_table[i] for i in token_ids)
    text = text.replace("▁", " ").strip()

    logging.info(args.sound_file)
    logging.info(text)

    logging.info("Decoding Done")


torch.set_num_threads(4)
torch.set_num_interop_threads(1)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
class Joiner(nn.Module):
    """Combine encoder and decoder representations into vocabulary logits.

    Both inputs are (optionally) projected to a shared ``joiner_dim``,
    summed, passed through ``tanh`` and mapped to ``vocab_size`` outputs.
    """

    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int,
    ):
        super().__init__()

        self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
        self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        """
        Args:
          encoder_out:
            Output from the encoder. Its shape is (N, T, s_range, C).
          decoder_out:
            Output from the decoder. Its shape is (N, T, s_range, C).
          project_input:
            If true, encoder_proj and decoder_proj are applied to the inputs
            here. If false, the caller must already have applied them.
        Returns:
          Return a tensor of shape (N, T, s_range, C).
        """
        assert encoder_out.ndim == decoder_out.ndim, (
            encoder_out.shape,
            decoder_out.shape,
        )

        if not project_input:
            combined = encoder_out + decoder_out
        else:
            combined = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)

        activated = torch.tanh(combined)

        # The factor 2.0 is arbitrary: it is meant to modulate how fast
        # output_linear trains, letting it train faster by reducing its scale.
        return 2.0 * self.output_linear(activated)
class LabelSmoothingLoss(torch.nn.Module):
    """
    Cross entropy against label-smoothed targets, as proposed in
    https://arxiv.org/pdf/1512.00567.pdf
    (Rethinking the Inception Architecture for Computer Vision)

    Positions whose target equals `ignore_index` contribute zero loss.
    """

    def __init__(
        self,
        ignore_index: int = -1,
        label_smoothing: float = 0.1,
        reduction: str = "sum",
    ) -> None:
        """
        Args:
          ignore_index:
            Class id marking target positions to be ignored.
          label_smoothing:
            Smoothing rate in [0, 1); 0.0 gives the conventional cross
            entropy loss.
          reduction:
            One of "none", "mean" or "sum", with the same meaning as the
            reduction in `torch.nn.CrossEntropyLoss`.
        """
        super().__init__()
        assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
        assert reduction in ("none", "sum", "mean"), reduction
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the label-smoothed loss between x and target.

        Args:
          x:
            Unnormalized predictions of shape
            (batch_size, input_length, number_of_classes).
          target:
            Class ids of shape (batch_size, input_length); ignored
            positions hold self.ignore_index.

        Returns:
          For "sum", a scalar with the total loss. For "mean", that total
          divided by the number of non-ignored positions. Otherwise a 1-D
          tensor of per-position losses of length batch_size * input_length.
        """
        assert x.ndim == 3
        assert target.ndim == 2
        assert x.shape[:2] == target.shape

        num_classes = x.size(-1)
        logits = x.reshape(-1, num_classes)  # (N*T, C)
        flat_target = target.reshape(-1)  # (N*T,)

        ignored = flat_target == self.ignore_index

        # Ignored positions get a valid dummy class (0) so that one_hot()
        # accepts them. torch.where is used instead of masked assignment,
        # see https://github.com/k2-fsa/icefall/issues/240
        # and https://github.com/k2-fsa/icefall/issues/297
        safe_target = torch.where(ignored, torch.zeros_like(flat_target), flat_target)

        one_hot = torch.nn.functional.one_hot(safe_target, num_classes=num_classes)
        one_hot = one_hot.to(logits)

        confidence = 1 - self.label_smoothing
        true_dist = one_hot * confidence + self.label_smoothing / num_classes

        # Zero the whole target distribution of ignored positions (again
        # without masked assignment, for the reasons linked above).
        true_dist = torch.where(
            ignored.unsqueeze(1).expand_as(true_dist),
            torch.zeros_like(true_dist),
            true_dist,
        )

        elementwise = -(torch.log_softmax(logits, dim=1) * true_dist)

        if self.reduction == "sum":
            return elementwise.sum()
        if self.reduction == "mean":
            return elementwise.sum() / (~ignored).sum()
        return elementwise.sum(dim=-1)
def __init__(
    self,
    encoder_embed: nn.Module,
    encoder: EncoderInterface,
    decoder: Optional[nn.Module] = None,
    joiner: Optional[nn.Module] = None,
    attention_decoder: Optional[nn.Module] = None,
    encoder_dim: int = 384,
    decoder_dim: int = 512,
    vocab_size: int = 500,
    use_transducer: bool = True,
    use_ctc: bool = False,
    use_attention_decoder: bool = False,
):
    """A joint CTC & Transducer ASR model.

    - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
    - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
    - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)

    Args:
      encoder_embed:
        It is a Convolutional 2D subsampling module. It converts
        an input of shape (N, T, idim) to an output of shape
        (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
      encoder:
        It is the transcription network in the paper. It accepts
        two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
        It returns two tensors: `logits` of shape (N, T, encoder_dim) and
        `logit_lens` of shape (N,).
        It must be an instance of `EncoderInterface`.
      decoder:
        It is the prediction network in the paper. Its input shape
        is (N, U) and its output shape is (N, U, decoder_dim).
        It should contain one attribute: `blank_id`.
        It is used when use_transducer is True and must be None otherwise.
      joiner:
        It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
        Its output shape is (N, T, U, vocab_size). Note that its output contains
        unnormalized probs, i.e., not processed by log-softmax.
        It is used when use_transducer is True and must be None otherwise.
      attention_decoder:
        Stored as `self.attention_decoder` when use_attention_decoder is
        True; it must be None otherwise.
      encoder_dim:
        Input dimension of `simple_am_proj` and of the CTC output layer.
      decoder_dim:
        Input dimension of `simple_lm_proj`.
      vocab_size:
        Output dimension of `simple_am_proj`, `simple_lm_proj` and of the
        CTC output layer.
      use_transducer:
        Whether use transducer head. Default: True.
      use_ctc:
        Whether use CTC head. Default: False.
      use_attention_decoder:
        Whether use attention-decoder head. Default: False.
    """
    super().__init__()

    assert (
        use_transducer or use_ctc
    ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"

    assert isinstance(encoder, EncoderInterface), type(encoder)

    self.encoder_embed = encoder_embed
    self.encoder = encoder

    self.use_transducer = use_transducer
    if use_transducer:
        # Modules for Transducer head
        assert decoder is not None
        assert hasattr(decoder, "blank_id")
        assert joiner is not None

        self.decoder = decoder
        self.joiner = joiner

        # Projections of the encoder and decoder outputs to vocab_size.
        self.simple_am_proj = ScaledLinear(
            encoder_dim, vocab_size, initial_scale=0.1,
        )
        self.simple_lm_proj = ScaledLinear(
            decoder_dim, vocab_size, initial_scale=0.1,
        )

    else:
        assert decoder is None
        assert joiner is None

    self.use_ctc = use_ctc
    if use_ctc:
        # Modules for CTC head
        # The head ends in LogSoftmax, so it outputs log-probabilities.
        self.ctc_output = nn.Sequential(
            ScaledLinear(encoder_dim, vocab_size, initial_scale=0.1),
            nn.LogSoftmax(dim=-1),
        )

    self.use_attention_decoder = use_attention_decoder
    if use_attention_decoder:
        # NOTE(review): unlike the transducer branch, a None
        # attention_decoder is not rejected here -- confirm callers
        # always supply one.
        self.attention_decoder = attention_decoder
    else:
        assert attention_decoder is None
def forward_ctc(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    targets: torch.Tensor,
    target_lengths: torch.Tensor,
) -> torch.Tensor:
    """Compute the summed CTC loss of a batch.
    Args:
      encoder_out:
        Encoder output, of shape (N, T, C).
      encoder_out_lens:
        Encoder output lengths, of shape (N,).
      targets:
        Target Tensor of shape (sum(target_lengths)). The targets are assumed
        to be un-padded and concatenated within 1 dimension.
      target_lengths:
        Number of target tokens of each utterance.
    Returns:
      The CTC loss summed over the batch (reduction="sum").
    """
    # self.ctc_output yields the per-frame CTC log-probs, (N, T, C);
    # ctc_loss wants them time-major, (T, N, C).
    log_probs = self.ctc_output(encoder_out)
    time_major_log_probs = log_probs.permute(1, 0, 2)

    # Why every integer argument goes through .long():
    # this is a workaround for a problem with
    # torch.nn.functional.ctc_loss() on newer torch versions. Earlier the
    # code used .cpu() instead, which activates the CUDNN implementation
    # (it is only used when the integer inputs are int32 and on CPU, see
    # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossCTC.cpp#L501).
    # On more recent torch/cuda versions that path failed with
    # "RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED".
    # (int32, CUDA) is not an option either, because the torch
    # implementation of ctc_loss seems to have a bug with int32 integer
    # arguments (it returns infinity). Calling .long() selects the torch
    # implementation and avoids that bug.
    input_lengths = encoder_out_lens.long()
    label_lengths = target_lengths.long()
    labels = targets.long()

    return torch.nn.functional.ctc_loss(
        log_probs=time_major_log_probs,
        targets=labels,
        input_lengths=input_lengths,
        target_lengths=label_lengths,
        reduction="sum",
    )
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.amp.autocast('cuda', enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). 
+ logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.amp.autocast('cuda', enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return simple_loss, pruned_loss + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + aux_loss_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + aux_loss_scale: + auxiliary-loss scale, for scaling cosine losses in the encoders. + sc_prob: + stochastic-depth probability: not a layer skipping probability but the probability + of taking the output of a randomly chosen layer, instead of the last layer. 
+ + Returns: + Return the transducer losses, CTC loss, and AED loss + in form of + (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) + + device = x.device + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens, + aux_loss_scale=aux_loss_scale) + + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + if self.use_transducer: + # Compute transducer loss + simple_loss, pruned_loss = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) + else: + simple_loss = torch.empty(0) + pruned_loss = torch.empty(0) + + if self.use_ctc: + targets = y.values + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + else: + ctc_loss = torch.empty(0) + + if self.use_attention_decoder: + attention_decoder_loss = self.attention_decoder.calc_att_loss( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ys=y.to(device), + ys_lens=y_lens.to(device), + ) + else: + attention_decoder_loss = torch.empty(0) + + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss diff --git a/egs/librispeech/ASR/zapformer/multicopy_dataset.py b/egs/librispeech/ASR/zapformer/multicopy_dataset.py new file mode 100755 index 0000000000..a41e9b4a1a --- /dev/null +++ b/egs/librispeech/ASR/zapformer/multicopy_dataset.py @@ -0,0 +1,226 @@ +from typing import Callable, Dict, 
List, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate + +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + + +class MulticopyDataset(torch.utils.data.Dataset): + """ + This is slightly modified from lhotse's K2SpeechRecognitionDataset, but + to support multiple parallel copies of the data, with augmentation applied + differently. + It uses ideas from Piotr in this thread: + https://github.com/k2-fsa/icefall/pull/1975 + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. 
+ """ + def __init__( + self, + return_cuts: bool = False, + num_copies: int = 1, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + ): + """ + k2 ASR IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + self.num_copies = num_copies + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + validate_for_asr(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. 
+ cuts = cuts.sort_by_duration(ascending=False) + + if self.num_copies > 1: + cuts = cuts.repeat(times=self.num_copies) + + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we successfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + batch = { + "inputs": inputs, + "num_copies": self.num_copies, + "supervisions": default_collate( + [ + { + "text": supervision.text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + has_word_alignments = all( + s.alignment is not None and "word" in s.alignment + for c in cuts + for s in c.supervisions + ) + if has_word_alignments: + # TODO: might need to refactor BatchIO API to move the following conditional logic + # into these objects (e.g. 
use like: self.input_strategy.convert_timestamp(), + # that returns either num_frames or num_samples depending on the strategy). + words, starts, ends = [], [], [] + frame_shift = cuts[0].frame_shift + sampling_rate = cuts[0].sampling_rate + if frame_shift is None: + try: + frame_shift = self.input_strategy.extractor.frame_shift + except AttributeError: + raise ValueError( + "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. " + ) + for c in cuts: + for s in c.supervisions: + words.append([aliword.symbol for aliword in s.alignment["word"]]) + starts.append( + [ + compute_num_frames( + aliword.start, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + ends.append( + [ + compute_num_frames( + aliword.end, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + batch["supervisions"]["word"] = words + batch["supervisions"]["word_start"] = starts + batch["supervisions"]["word_end"] = ends + + return batch + + +def validate_for_asr(cuts: CutSet) -> None: + validate(cuts) + tol = 2e-3 # 2 ms tolerance + for cut in cuts: + for supervision in cut.supervisions: + assert supervision.start >= -tol, ( + f"Supervisions starting before the cut are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) + + # Supervision start time is relative to Cut ... 
+ # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + # + # 'supervision.end' is end of supervision inside the Cut + assert supervision.end <= cut.duration + tol, ( + f"Supervisions ending after the cut " + f"are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) diff --git a/egs/librispeech/ASR/zapformer/my_profile.py b/egs/librispeech/ASR/zapformer/my_profile.py new file mode 100755 index 0000000000..458a759694 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/my_profile.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Usage: ./zapformer/my_profile.py +""" + +import argparse +import logging +from typing import Tuple + +import sentencepiece as spm +import torch +from torch import Tensor, nn +from train import ( + add_model_arguments, + get_encoder_embed, + get_encoder_model, + get_joiner_model, + get_params, +) + +from icefall.profiler import get_model_profile +from icefall.utils import make_pad_mask + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + add_model_arguments(parser) + + return parser + +class Model(nn.Module): + """A Wrapper for encoder, encoder_embed, and encoder_proj""" + + def __init__( + self, + encoder: nn.Module, + encoder_embed: nn.Module, + encoder_proj: nn.Module, + ) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.encoder_proj = encoder_proj + + def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]: + x, x_lens = self.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + logits = self.encoder_proj(encoder_out) + + return logits, encoder_out_lens + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + 
logging.info("About to create model") + + # We only profile the encoder part + model = Model( + encoder=get_encoder_model(params), + encoder_embed=get_encoder_embed(params), + encoder_proj=get_joiner_model(params).encoder_proj, + ) + model.eval() + model.to(device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # for 30-second input + B, T, D = 1, 3000, 80 + feature = torch.ones(B, T, D, dtype=torch.float32).to(device) + feature_lens = torch.full((B,), T, dtype=torch.int64).to(device) + + flops, params = get_model_profile( + model=model, + args=(feature, feature_lens), + #module_hoop_mapping=MODULE_HOOK_MAPPING, + ) + logging.info(f"For the encoder part, params: {params}, flops: {flops}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_check.py b/egs/librispeech/ASR/zapformer/onnx_check.py new file mode 100755 index 0000000000..daca7d81bd --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_check.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +This script checks that exported onnx models produce the same output +with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model via torchscript (torch.jit.script()) + +./zapformer/export.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit 1 + +It will generate the following file in $repo/exp: + - jit_script.pt + +3. Export the model to ONNX + +./zapformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +4. 
Run this file + +./zapformer/onnx_check.py \ + --jit-filename $repo/exp/jit_script.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx +""" + +import argparse +import logging + + +import torch +from onnx_pretrained import OnnxModel + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-filename", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the onnx encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the onnx decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the onnx joiner model", + ) + + return parser + + +def test_encoder( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + C = 80 + for i in range(3): + N = torch.randint(low=1, high=20, size=(1,)).item() + T = torch.randint(low=30, high=50, size=(1,)).item() + logging.info(f"test_encoder: iter {i}, N={N}, T={T}") + + x = torch.rand(N, T, C) + x_lens = torch.randint(low=30, high=T + 1, size=(N,)) + x_lens[0] = T + + torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) + torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out) + + onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens) + + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) + + +def test_decoder( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, 
size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) + + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() + ) + + +def test_joiner( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1] + decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) + + projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out) + projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out) + + torch_joiner_out = torch_model.joiner(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, projected_decoder_out + ) + + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_model = torch.jit.load(args.jit_filename) + + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) + + logging.info("Test encoder") + test_encoder(torch_model, onnx_model) + + logging.info("Test decoder") + test_decoder(torch_model, onnx_model) + + logging.info("Test joiner") + test_joiner(torch_model, onnx_model) + logging.info("Finished checking ONNX 
models") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + torch.manual_seed(20220727) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_decode.py b/egs/librispeech/ASR/zapformer/onnx_decode.py new file mode 100755 index 0000000000..075474a6bf --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_decode.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. 
+ +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zapformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal False + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file + +./zapformer/onnx_decode.py \ + --exp-dir $repo/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from k2 import SymbolTable +from onnx_pretrained import OnnxModel, greedy_search + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. 
", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zapformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, token_table: SymbolTable, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + token_table: + The token table. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + hyps = [token_ids_to_words(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + token_table: SymbolTable, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + token_table: + The token table. + + Returns: + - A list of tuples. 
Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
@torch.no_grad()
def main():
    """Decode LibriSpeech test-clean and test-other with exported ONNX models.

    Parses the command line, builds the ONNX transducer model, decodes both
    test sets with greedy search, logs the elapsed time and real time factor
    for each set, and writes transcripts and WER statistics via
    ``save_results``.

    Raises:
      AssertionError: if ``--decoding-method`` is not ``greedy_search``.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    assert (
        args.decoding_method == "greedy_search"
    ), "Only supports greedy_search currently."
    res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"

    setup_logger(f"{res_dir}/log-decode")
    logging.info("Decoding started")

    device = torch.device("cpu")
    logging.info(f"Device: {device}")

    token_table = SymbolTable.from_file(args.tokens)

    logging.info(vars(args))

    logging.info("About to create model")
    model = OnnxModel(
        encoder_model_filename=args.encoder_model_filename,
        decoder_model_filename=args.decoder_model_filename,
        joiner_model_filename=args.joiner_model_filename,
    )

    # we need cut ids to display recognition results.
    args.return_cuts = True
    librispeech = LibriSpeechAsrDataModule(args)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()

    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    test_sets = ["test-clean", "test-other"]
    # Keep the list under its own name so the loop variable below does not
    # shadow the collection being iterated.
    test_dls = [test_clean_dl, test_other_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        start_time = time.time()
        results, total_duration = decode_dataset(
            dl=test_dl, model=model, token_table=token_table
        )
        end_time = time.time()
        elapsed_seconds = end_time - start_time

        logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
        logging.info(f"Wave duration: {total_duration:.3f} s")
        if total_duration > 0:
            rtf = elapsed_seconds / total_duration
            logging.info(
                f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
            )
        else:
            # An empty test set has no audio; the RTF is undefined rather
            # than a reason to abort the whole run.
            logging.warning(
                f"No audio was decoded for {test_set}; skipping the RTF report"
            )

        save_results(res_dir=res_dir, test_set_name=test_set, results=results)

    logging.info("Done!")
def get_parser():
    """Build the command-line parser for streaming CTC decoding.

    Options:
      --model-filename: path to the exported streaming CTC ONNX model
        (required).
      --tokens: path to tokens.txt (required, because the decoded token ids
        are always mapped back to text through this file).
      sound_file: the single input sound file to transcribe.

    Returns:
      The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-filename",
        type=str,
        required=True,
        help="Path to the streaming CTC onnx model. ",
    )

    # The token table is read unconditionally after decoding, so a missing
    # path is reported here instead of failing once decoding has finished.
    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="""Path to tokens.txt.""",
    )

    parser.add_argument(
        "sound_file",
        type=str,
        help="The input sound file to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    return parser
logging.info(f"num_heads: {num_heads}") + + num_encoders = len(num_encoder_layers) + + self.states = [] + for i in range(num_encoders): + num_layers = num_encoder_layers[i] + key_dim = query_head_dims[i] * num_heads[i] + embed_dim = encoder_dims[i] + nonlin_attn_head_dim = 3 * embed_dim // 4 + value_dim = value_head_dims[i] * num_heads[i] + conv_left_pad = cnn_module_kernels[i] // 2 + + for layer in range(num_layers): + cached_key = torch.zeros( + left_context_len[i], batch_size, key_dim + ).numpy() + cached_nonlin_attn = torch.zeros( + 1, batch_size, left_context_len[i], nonlin_attn_head_dim + ).numpy() + cached_val1 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_val2 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + self.states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + embed_states = torch.zeros(batch_size, 128, 3, 19).numpy() + self.states.append(embed_states) + processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy() + self.states.append(processed_lens) + + self.num_encoders = num_encoders + + self.segment = T + self.offset = decode_chunk_len + + def _build_model_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + model_input = {"x": x.numpy()} + model_output = ["log_probs"] + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 6, len(tensors) + + # (downsample_left, batch_size, key_dim) + name = f"cached_key_{i}" + model_input[name] = tensors[0] + model_output.append(f"new_{name}") + + # (1, batch_size, downsample_left, nonlin_attn_head_dim) + name = f"cached_nonlin_attn_{i}" + model_input[name] = tensors[1] + model_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val1_{i}" + 
model_input[name] = tensors[2] + model_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val2_{i}" + model_input[name] = tensors[3] + model_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv1_{i}" + model_input[name] = tensors[4] + model_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv2_{i}" + model_input[name] = tensors[5] + model_output.append(f"new_{name}") + + for i in range(len(self.states[:-2]) // 6): + build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i) + + # (batch_size, channels, left_pad, freq) + name = "embed_states" + embed_states = self.states[-2] + model_input[name] = embed_states + model_output.append(f"new_{name}") + + # (batch_size,) + name = "processed_lens" + processed_lens = self.states[-1] + model_input[name] = processed_lens + model_output.append(f"new_{name}") + + return model_input, model_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor containing log_probs. Its shape is (N, T, vocab_size) + where T' is usually equal to ((T-7)//2 - 3)//2 + """ + model_input, model_output_names = self._build_model_input_output(x) + + out = self.model.run(model_output_names, model_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. 
def create_streaming_feature_extractor() -> OnlineFeature:
    """Build the online fbank extractor that feeds the streaming model.

    All options are fixed inside this function: the extractor runs on the
    CPU with dither disabled, ``snip_edges`` off, a 16000 Hz sample
    frequency, 80 mel bins and ``high_freq`` set to -400.

    Returns:
      An ``OnlineFbank`` built from those options.
    """
    fbank_opts = FbankOptions()
    fbank_opts.device = "cpu"

    # Mel filterbank settings.
    fbank_opts.mel_opts.num_bins = 80
    fbank_opts.mel_opts.high_freq = -400

    # Framing settings.
    fbank_opts.frame_opts.samp_freq = 16000
    fbank_opts.frame_opts.snip_edges = False
    fbank_opts.frame_opts.dither = 0

    extractor = OnlineFbank(fbank_opts)
    return extractor
@torch.no_grad()
def main():
    """Transcribe one sound file with a streaming CTC ONNX model.

    The waveform is fed to the online fbank extractor one second at a time.
    Whenever ``model.segment`` frames are available they are passed through
    the model, and the resulting log-probs are decoded greedily. The decoded
    token ids are finally mapped to text through tokens.txt and logged.

    CTC repeat-collapsing has to span the whole utterance, but
    ``greedy_search`` only sees one chunk at a time. This function therefore
    remembers the best token of the last frame of the previous chunk and
    drops the first token of a chunk when it merely continues that run.
    """
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    model = OnnxModel(model_filename=args.model_filename)

    sample_rate = 16000
    # Has to agree with the blank id that greedy_search() removes.
    blank_id = 0

    logging.info("Constructing Fbank computer")
    online_fbank = create_streaming_feature_extractor()

    logging.info(f"Reading sound files: {args.sound_file}")
    waves = read_sound_files(
        filenames=[args.sound_file],
        expected_sample_rate=sample_rate,
    )[0]

    # 0.3 seconds of trailing silence -- presumably so that the last real
    # frames still fill a complete segment (TODO confirm).
    tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
    wave_samples = torch.cat([waves, tail_padding])

    num_processed_frames = 0
    segment = model.segment
    offset = model.offset

    hyp = []
    # Best token of the most recent output frame; blank means "no open run".
    prev_frame_token = blank_id

    chunk = int(1 * sample_rate)  # 1 second
    start = 0
    while start < wave_samples.numel():
        end = min(start + chunk, wave_samples.numel())
        samples = wave_samples[start:end]
        start += chunk

        online_fbank.accept_waveform(
            sampling_rate=sample_rate,
            waveform=samples,
        )

        while online_fbank.num_frames_ready - num_processed_frames >= segment:
            frames = [
                online_fbank.get_frame(num_processed_frames + i)
                for i in range(segment)
            ]
            # Advance by the chunk shift, not by the (longer) segment.
            num_processed_frames += offset
            frames = torch.cat(frames, dim=0).unsqueeze(0)
            log_probs = model(frames)

            chunk_hyp = greedy_search(log_probs)
            if log_probs.shape[1] > 0:
                frame_tokens = log_probs[0].argmax(dim=1)
                first_token = int(frame_tokens[0])
                # A non-blank first frame equal to the previous chunk's last
                # frame continues the same run and was already emitted.
                if (
                    chunk_hyp
                    and first_token != blank_id
                    and first_token == prev_frame_token
                ):
                    chunk_hyp = chunk_hyp[1:]
                prev_frame_token = int(frame_tokens[-1])

            hyp += chunk_hyp

    # To handle byte-level BPE, we convert string tokens to utf-8 encoded bytes
    id2token = {}
    with open(args.tokens, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline).
                continue
            token, idx = fields
            if token[:3] == "<0x" and token[-1] == ">":
                token = int(token[1:-1], base=16)
                assert 0 <= token < 256, token
                token = token.to_bytes(1, byteorder="little")
            else:
                token = token.encode(encoding="utf-8")

            id2token[int(idx)] = token

    # Join once instead of growing a bytes object token by token.
    text = b"".join(id2token[i] for i in hyp)
    text = text.decode(encoding="utf-8")
    text = text.replace("▁", " ").strip()

    logging.info(args.sound_file)
    logging.info(text)

    logging.info("Decoding Done")
def get_parser():
    """Build the command-line parser for streaming transducer decoding.

    Options:
      --encoder-model-filename, --decoder-model-filename,
      --joiner-model-filename: paths to the three exported ONNX models
        (all required).
      --tokens: path to tokens.txt (required, because the decoded token ids
        are always mapped back to text through this file).
      sound_file: the single input sound file to transcribe.

    Returns:
      The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--encoder-model-filename",
        type=str,
        required=True,
        help="Path to the encoder onnx model. ",
    )

    parser.add_argument(
        "--decoder-model-filename",
        type=str,
        required=True,
        help="Path to the decoder onnx model. ",
    )

    parser.add_argument(
        "--joiner-model-filename",
        type=str,
        required=True,
        help="Path to the joiner onnx model. ",
    )

    # The symbol table is loaded unconditionally after decoding, so a missing
    # path is reported here instead of failing once decoding has finished.
    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="""Path to tokens.txt.""",
    )

    parser.add_argument(
        "sound_file",
        type=str,
        help="The input sound file to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    return parser
to_int_list(encoder_dims) + conv_params = to_int_list(conv_params) + left_context_len = to_int_list(left_context_len) + query_head_dims = to_int_list(query_head_dims) + value_head_dims = to_int_list(value_head_dims) + num_heads = to_int_list(num_heads) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"encoder_dims: {encoder_dims}") + logging.info(f"conv_params: {conv_params}") + logging.info(f"left_context_len: {left_context_len}") + logging.info(f"query_head_dims: {query_head_dims}") + logging.info(f"value_head_dims: {value_head_dims}") + logging.info(f"num_heads: {num_heads}") + + num_encoders = len(num_encoder_layers) + + self.states = [] + for i in range(num_encoders): + num_layers = num_encoder_layers[i] + key_dim = query_head_dims[i] * num_heads[i] + embed_dim = encoder_dims[i] + value_dim = value_head_dims[i] * num_heads[i] + conv_left_pad = conv_params[i] - 1 + + for layer in range(num_layers): + # (left_context_len, batch, key_dim) + cached_key = torch.zeros( + left_context_len[i], batch_size, key_dim + ).numpy() + # (left_context_len, batch, value_dim) + cached_value = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + # (batch, embed_dim, conv_left_pad) + cached_conv = torch.zeros( + batch_size, embed_dim, conv_left_pad + ).numpy() + # cached_norm_stats: (batch,) + cached_norm_stats = torch.zeros(batch_size).numpy() + # cached_norm_len: (batch,) + cached_norm_len = torch.zeros(batch_size).numpy() + # cached_attn_wm_sum: (1, batch, value_dim) + cached_attn_wm_sum = torch.zeros( + 1, batch_size, value_dim + ).numpy() + # cached_attn_wm_num_frames: (batch,) + cached_attn_wm_num_frames = torch.zeros( + batch_size, dtype=torch.int64 + ).numpy() + # cached_conv_wm_sum: (1, batch, embed_dim) + cached_conv_wm_sum = torch.zeros( + 1, batch_size, embed_dim + ).numpy() + # cached_conv_wm_num_frames: (batch,) + cached_conv_wm_num_frames 
= torch.zeros( + batch_size, dtype=torch.int64 + ).numpy() + + self.states += [ + cached_key, + cached_value, + cached_conv, + cached_norm_stats, + cached_norm_len, + cached_attn_wm_sum, + cached_attn_wm_num_frames, + cached_conv_wm_sum, + cached_conv_wm_num_frames, + ] + + # embed_cache: (batch, channels, left_pad, freq) + embed_cache = torch.zeros(batch_size, 128, 6, 19).numpy() + self.states.append(embed_cache) + # processed_lens: (batch,) + processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy() + self.states.append(processed_lens) + + self.num_encoders = num_encoders + + self.segment = T + self.offset = decode_chunk_len + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + encoder_input = {"x": x.numpy()} + encoder_output = ["encoder_out"] + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 9, len(tensors) + + # (left_context_len, batch_size, key_dim) + name = f"cached_key_{i}" + encoder_input[name] = tensors[0] + encoder_output.append(f"new_{name}") + + # (left_context_len, batch_size, value_dim) + name = f"cached_value_{i}" + encoder_input[name] = 
tensors[1] + encoder_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv_{i}" + encoder_input[name] = tensors[2] + encoder_output.append(f"new_{name}") + + # (batch_size,) + name = f"cached_norm_stats_{i}" + encoder_input[name] = tensors[3] + encoder_output.append(f"new_{name}") + + # (batch_size,) + name = f"cached_norm_len_{i}" + encoder_input[name] = tensors[4] + encoder_output.append(f"new_{name}") + + # (1, batch_size, value_dim) + name = f"cached_attn_wm_sum_{i}" + encoder_input[name] = tensors[5] + encoder_output.append(f"new_{name}") + + # (batch_size,) + name = f"cached_attn_wm_num_frames_{i}" + encoder_input[name] = tensors[6] + encoder_output.append(f"new_{name}") + + # (1, batch_size, embed_dim) + name = f"cached_conv_wm_sum_{i}" + encoder_input[name] = tensors[7] + encoder_output.append(f"new_{name}") + + # (batch_size,) + name = f"cached_conv_wm_num_frames_{i}" + encoder_input[name] = tensors[8] + encoder_output.append(f"new_{name}") + + for i in range(len(self.states[:-2]) // 9): + build_inputs_outputs(self.states[i * 9 : (i + 1) * 9], i) + + # (batch_size, channels, left_pad, freq) + name = "embed_cache" + embed_cache = self.states[-2] + encoder_input[name] = embed_cache + encoder_output.append(f"new_{name}") + + # (batch_size,) + name = "processed_lens" + processed_lens = self.states[-1] + encoder_input[name] = processed_lens + encoder_output.append(f"new_{name}") + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor of shape (N, T', joiner_dim) where + T' is usually equal to ((T-7)//2-3)//2 + """ + encoder_input, encoder_output_names = self._build_encoder_input_output(x) + + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return 
def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Load sound files and return the first channel of each one.

    Args:
      filenames:
        Paths of the sound files to load with ``torchaudio.load``.
      expected_sample_rate:
        Sample rate every file must have.
    Returns:
      One contiguous tensor per file, in input order, holding the samples
      of that file's first channel.
    Raises:
      AssertionError: if a file's sample rate differs from
        ``expected_sample_rate``.
    """

    def first_channel(path: str) -> torch.Tensor:
        wave, sample_rate = torchaudio.load(path)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # Only channel 0 is kept; any further channels are ignored.
        return wave[0].contiguous()

    return [first_channel(path) for path in filenames]
def greedy_search(
    model: "OnnxModel",
    encoder_out: torch.Tensor,
    context_size: int,
    decoder_out: Optional[torch.Tensor] = None,
    hyp: Optional[List[int]] = None,
) -> Tuple[List[int], torch.Tensor]:
    """Greedy search over one chunk of a single utterance.

    It hardcodes --max-sym-per-frame=1: at most one symbol is emitted per
    encoder frame. State is carried between chunks through ``decoder_out``
    and ``hyp``, which the caller passes back in on the next call.

    Args:
      model:
        The transducer model; its ``run_decoder`` and ``run_joiner`` are used.
      encoder_out:
        A 3-D tensor of shape (1, T, joiner_dim)
      context_size:
        The context size of the decoder model.
      decoder_out:
        Optional. Decoder output of the previous chunk. Pass None for the
        first chunk.
      hyp:
        Decoding results for previous chunks. Must be None exactly when
        ``decoder_out`` is None. It is extended in place.
    Returns:
      Return a tuple containing:
        - hyp, the decoded token ids so far. The first ``context_size``
          entries are blank padding, not decoded tokens.
        - decoder_out, the decoder output to pass to the next call.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    # squeeze(0) below only removes the batch axis when it has size 1.
    assert encoder_out.size(0) == 1, encoder_out.shape

    blank_id = 0

    if decoder_out is None:
        assert hyp is None, hyp
        # Start of the utterance: prime the decoder with blanks.
        hyp = [blank_id] * context_size
        decoder_input = torch.tensor([hyp], dtype=torch.int64)
        decoder_out = model.run_decoder(decoder_input)
    else:
        assert hyp is not None, hyp

    encoder_out = encoder_out.squeeze(0)
    T = encoder_out.size(0)
    for t in range(T):
        cur_encoder_out = encoder_out[t : t + 1]
        joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
        y = joiner_out.argmax(dim=0).item()
        if y != blank_id:
            hyp.append(y)
            # The decoder only needs re-running when its context changed.
            decoder_input = torch.tensor([hyp[-context_size:]], dtype=torch.int64)
            decoder_out = model.run_decoder(decoder_input)

    return hyp, decoder_out
features at once + logging.info("Computing fbank features") + features = compute_fbank(waves) + logging.info(f"features shape: {features.shape}") + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + num_frames = features.size(0) + + while num_processed_frames + segment <= num_frames: + frames = features[num_processed_frames : num_processed_frames + segment] + num_processed_frames += offset + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + token_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += token_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained.py b/egs/librispeech/ASR/zapformer/onnx_pretrained.py new file mode 100755 index 0000000000..39b5a70fd2 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained.py @@ -0,0 +1,422 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def get_parser():
    """Build the command-line parser for non-streaming transducer decoding.

    The parser accepts the three required ONNX model paths (encoder, decoder,
    joiner), an optional ``--tokens`` path, one or more positional sound
    files, and ``--sample-rate`` (default 16000).

    Returns:
      The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # The three model paths are declared identically apart from flag and help.
    model_options = (
        ("--encoder-model-filename", "Path to the encoder onnx model. "),
        ("--decoder-model-filename", "Path to the decoder onnx model. "),
        ("--joiner-model-filename", "Path to the joiner onnx model. "),
    )
    for flag, description in model_options:
        parser.add_argument(flag, type=str, required=True, help=description)

    parser.add_argument("--tokens", type=str, help="""Path to tokens.txt.""")

    sound_files_help = (
        "The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz."
    )
    parser.add_argument("sound_files", type=str, nargs="+", help=sound_files_help)

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    return parser
sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def run_encoder( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 2-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, its shape is (N, T', joiner_dim) + - encoder_out_lens, its shape is (N,) + """ + out = self.encoder.run( + [ + self.encoder.get_outputs()[0].name, + self.encoder.get_outputs()[1].name, + ], + { + self.encoder.get_inputs()[0].name: x.numpy(), + self.encoder.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. 
def greedy_search(
    model: "OnnxModel",
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

    Args:
      model:
        The transducer model. It must provide ``context_size``,
        ``run_decoder()`` and ``run_joiner()``.
      encoder_out:
        A 3-D tensor of shape (N, T, joiner_dim)
      encoder_out_lens:
        A 1-D tensor of shape (N,). Every entry must be positive.
    Returns:
      Return the decoded token IDs for each utterance, in the same order as
      the rows of ``encoder_out``.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert encoder_out.size(0) >= 1, encoder_out.size(0)
    # Validate the lengths before packing; otherwise pack_padded_sequence
    # rejects non-positive lengths first and this check is never reached.
    assert torch.all(encoder_out_lens > 0), encoder_out_lens

    # Packing sorts the utterances by decreasing length so that, at each
    # time step, the still-active utterances form a prefix of the batch.
    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    blank_id = 0  # hard-code to 0

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert N == batch_size_list[0], (N, batch_size_list)

    context_size = model.context_size
    # Each hypothesis starts with `context_size` blanks as decoder history.
    # hyps is indexed in the *sorted* order used by the packed sequence.
    hyps = [[blank_id] * context_size for _ in range(N)]

    decoder_input = torch.tensor(
        hyps,
        dtype=torch.int64,
    )  # (N, context_size)
    decoder_out = model.run_decoder(decoder_input)

    offset = 0
    for batch_size in batch_size_list:
        start = offset
        end = offset + batch_size
        # current_encoder_out's shape: (batch_size, joiner_dim)
        current_encoder_out = packed_encoder_out.data[start:end]
        offset = end

        # Utterances that have already ended are dropped from the tail.
        decoder_out = decoder_out[:batch_size]
        logits = model.run_joiner(current_encoder_out, decoder_out)
        # logits' shape: (batch_size, vocab_size)
        assert logits.ndim == 2, logits.shape

        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                hyps[i].append(v)
                emitted = True

        if emitted:
            # Only re-run the decoder when at least one history changed.
            decoder_input = torch.tensor(
                [h[-context_size:] for h in hyps[:batch_size]],
                dtype=torch.int64,
            )
            decoder_out = model.run_decoder(decoder_input)

    # Strip the initial blank context, then restore the caller's order.
    sorted_ans = [h[context_size:] for h in hyps]
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    return [sorted_ans[unsorted_indices[i]] for i in range(N)]
return text.replace("▁", " ").strip() + + for filename, hyp in zip(args.sound_files, hyps): + words = token_ids_to_words(hyp) + s += f"{filename}:\n{words}\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc.py new file mode 100755 index 0000000000..457e2370bc --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. Run this file + +./zapformer/onnx_pretrained_ctc.py \ + --nn-model /path/to/model.onnx \ + --tokens /path/to/data/lang_bpe_500/tokens.txt \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). 
class OnnxModel:
    """Wrap a single ONNX Runtime session whose two outputs are read as
    log-probabilities and their lengths."""

    def __init__(
        self,
        nn_model: str,
    ):
        """
        Args:
          nn_model:
            Path to the onnx model.
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.init_model(nn_model)

    def init_model(self, nn_model: str):
        """Create the CPU inference session and log the model's custom
        metadata."""
        self.model = ort.InferenceSession(
            nn_model,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        meta = self.model.get_modelmeta().custom_metadata_map
        # Use logging rather than print so that the metadata follows the
        # same format and destination as every other message of the script.
        logging.info(f"model metadata: {meta}")

    def __call__(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D float tensor of shape (N, T, C)
          x_lens:
            A 1-D int64 tensor of shape (N,)
        Returns:
          Return a tuple containing:
            - A float tensor containing log_probs of shape (N, T, C)
            - A int64 tensor containing log_probs_len of shape (N)
        """
        out = self.model.run(
            [
                self.model.get_outputs()[0].name,
                self.model.get_outputs()[1].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lens.numpy(),
            },
        )
        return torch.from_numpy(out[0]), torch.from_numpy(out[1])
@torch.no_grad()
def main():
    """Decode the sound files given on the command line with CTC greedy
    search and log one transcript per file."""
    args = get_parser().parse_args()
    logging.info(vars(args))
    model = OnnxModel(nn_model=args.nn_model)

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = "cpu"
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = args.sample_rate
    opts.mel_opts.num_bins = 80
    opts.mel_opts.high_freq = -400
    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = read_sound_files(
        filenames=args.sound_files,
        expected_sample_rate=args.sample_rate,
    )

    logging.info("Decoding started")
    features = fbank(waves)
    # Record the true lengths before the batch is padded.
    num_frames = [f.size(0) for f in features]
    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )
    feature_lengths = torch.tensor(num_frames, dtype=torch.int64)

    log_probs, log_probs_len = model(features, feature_lengths)

    token_table = k2.SymbolTable.from_file(args.tokens)

    def token_ids_to_words(token_ids: List[int]) -> str:
        # Join the token strings, then turn the "▁" markers into spaces.
        pieces = "".join(token_table[t] for t in token_ids)
        return pieces.replace("▁", " ").strip()

    blank_id = 0
    entries = []
    for idx in range(log_probs.size(0)):
        # greedy search: best token per frame, merge repeats, drop blanks
        frames = log_probs[idx, : log_probs_len[idx]]
        best = torch.unique_consecutive(frames.argmax(dim=-1))
        kept = best[best != blank_id]
        words = token_ids_to_words(kept.tolist())
        entries.append(f"{args.sound_files[idx]}:\n{words}\n\n")

    logging.info("\n" + "".join(entries))

    logging.info("Decoding Done")
new file mode 100755 index 0000000000..7472c61c5e --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_H.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. Run this file + +./zapformer/onnx_pretrained_ctc_H.py \ + --nn-model /path/to/model.onnx \ + --tokens /path/to/data/lang_bpe_500/tokens.txt \ + --H /path/to/H.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zapformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import Dict, List, Tuple + +import k2 +import kaldifeat +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--H", + type=str, + help="""Path to H.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
class OnnxModel:
    """Wrap a single ONNX Runtime session whose two outputs are read as
    log-probabilities and their lengths."""

    def __init__(
        self,
        nn_model: str,
    ):
        """
        Args:
          nn_model:
            Path to the onnx model.
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.init_model(nn_model)

    def init_model(self, nn_model: str):
        """Create the CPU inference session and log the model's custom
        metadata."""
        self.model = ort.InferenceSession(
            nn_model,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        meta = self.model.get_modelmeta().custom_metadata_map
        # Use logging rather than print so that the metadata follows the
        # same format and destination as every other message of the script.
        logging.info(f"model metadata: {meta}")

    def __call__(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D float tensor of shape (N, T, C)
          x_lens:
            A 1-D int64 tensor of shape (N,)
        Returns:
          Return a tuple containing:
            - A float tensor containing log_probs of shape (N, T, C)
            - A int64 tensor containing log_probs_len of shape (N)
        """
        out = self.model.run(
            [
                self.model.get_outputs()[0].name,
                self.model.get_outputs()[1].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lens.numpy(),
            },
        )
        return torch.from_numpy(out[0]), torch.from_numpy(out[1])
def decode(
    filename: str,
    log_probs: torch.Tensor,
    H: kaldifst.StdVectorFst,
    id2token: Dict[int, str],
) -> List[str]:
    """Decode one utterance with the H graph and return its words.

    Args:
      filename:
        Path to the filename for decoding. Used for debugging.
      log_probs:
        A 2-D float32 tensor of shape (num_frames, vocab_size). It
        contains output from log_softmax.
      H:
        The H graph.
      id2token:
        A map mapping token ID to token string.
    Returns:
      Return a list of decoded words. If decoding fails at any stage, the
      failure is logged and ``[""]`` is returned.
    """
    logging.info(f"{filename}, {log_probs.shape}")
    decodable = DecodableCtc(log_probs.cpu())

    decoder_opts = FasterDecoderOptions(max_active=3000)
    decoder = FasterDecoder(H, decoder_opts)
    decoder.decode(decodable)

    if not decoder.reached_final():
        logging.info(f"failed to decode {filename}")
        return [""]

    ok, best_path = decoder.get_best_path()
    # The status used to be overwritten below without being checked.
    if not ok:
        logging.info(f"failed to get best path for {filename}")
        return [""]

    (
        ok,
        _isymbols_out,
        osymbols_out,
        _total_weight,
    ) = kaldifst.get_linear_symbol_sequence(best_path)
    if not ok:
        logging.info(f"failed to get linear symbol sequence for {filename}")
        return [""]

    # Token IDs were shifted by 1 during graph construction, so label i
    # maps to token i - 1. Label 1 (token 0) is skipped -- presumably the
    # CTC blank; TODO confirm against the graph construction script.
    tokens = [id2token[i - 1] for i in osymbols_out if i != 1]

    # Words are delimited by "▁" (U+2581). Note that a hypothesis starting
    # with "▁" therefore yields a leading empty string.
    hyps = "".join(tokens).split("\u2581")

    return hyps
read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + token_table = k2.SymbolTable.from_file(args.tokens) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + H=H, + id2token=token_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HL.py new file mode 100755 index 0000000000..9e11535b2b --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HL.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. 
Run this file + +./zapformer/onnx_pretrained_ctc_HL.py \ + --nn-model /path/to/model.onnx \ + --words /path/to/data/lang_bpe_500/words.txt \ + --HL /path/to/HL.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zapformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import Dict, List, Tuple + +import k2 +import kaldifeat +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--words", + type=str, + help="""Path to words.txt.""", + ) + + parser.add_argument( + "--HL", + type=str, + help="""Path to HL.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
class OnnxModel:
    """Wrap a single ONNX Runtime session whose two outputs are read as
    log-probabilities and their lengths."""

    def __init__(
        self,
        nn_model: str,
    ):
        """
        Args:
          nn_model:
            Path to the onnx model.
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.init_model(nn_model)

    def init_model(self, nn_model: str):
        """Create the CPU inference session and log the model's custom
        metadata."""
        self.model = ort.InferenceSession(
            nn_model,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        meta = self.model.get_modelmeta().custom_metadata_map
        # Use logging rather than print so that the metadata follows the
        # same format and destination as every other message of the script.
        logging.info(f"model metadata: {meta}")

    def __call__(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D float tensor of shape (N, T, C)
          x_lens:
            A 1-D int64 tensor of shape (N,)
        Returns:
          Return a tuple containing:
            - A float tensor containing log_probs of shape (N, T, C)
            - A int64 tensor containing log_probs_len of shape (N)
        """
        out = self.model.run(
            [
                self.model.get_outputs()[0].name,
                self.model.get_outputs()[1].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lens.numpy(),
            },
        )
        return torch.from_numpy(out[0]), torch.from_numpy(out[1])
def decode(
    filename: str,
    log_probs: torch.Tensor,
    HL: kaldifst.StdVectorFst,
    id2word: Dict[int, str],
) -> List[str]:
    """Decode one utterance with the HL graph and return its words.

    Args:
      filename:
        Path to the filename for decoding. Used for debugging.
      log_probs:
        A 2-D float32 tensor of shape (num_frames, vocab_size). It
        contains output from log_softmax.
      HL:
        The HL graph.
      id2word:
        A map mapping word ID to word string.
    Returns:
      Return a list of decoded words. If decoding fails at any stage, the
      failure is logged and ``[""]`` is returned.
    """
    logging.info(f"{filename}, {log_probs.shape}")
    decodable = DecodableCtc(log_probs.cpu())

    decoder_opts = FasterDecoderOptions(max_active=3000)
    decoder = FasterDecoder(HL, decoder_opts)
    decoder.decode(decodable)

    if not decoder.reached_final():
        logging.info(f"failed to decode {filename}")
        return [""]

    ok, best_path = decoder.get_best_path()
    # The status used to be overwritten below without being checked.
    if not ok:
        logging.info(f"failed to get best path for {filename}")
        return [""]

    (
        ok,
        _isymbols_out,
        osymbols_out,
        _total_weight,
    ) = kaldifst.get_linear_symbol_sequence(best_path)
    if not ok:
        logging.info(f"failed to get linear symbol sequence for {filename}")
        return [""]

    # The output labels are looked up as word IDs directly; no offset is
    # applied here.
    hyps = [id2word[i] for i in osymbols_out]

    return hyps
features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + word_table = k2.SymbolTable.from_file(args.words) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + HL=HL, + id2word=word_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG.py new file mode 100755 index 0000000000..3d757386cb --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zapformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. 
Run this file + +./zapformer/onnx_pretrained_ctc_HLG.py \ + --nn-model /path/to/model.onnx \ + --words /path/to/data/lang_bpe_500/words.txt \ + --HLG /path/to/HLG.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zapformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import Dict, List, Tuple + +import k2 +import kaldifeat +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--words", + type=str, + help="""Path to words.txt.""", + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
class OnnxModel:
    """Wrap a single ONNX Runtime session whose two outputs are read as
    log-probabilities and their lengths."""

    def __init__(
        self,
        nn_model: str,
    ):
        """
        Args:
          nn_model:
            Path to the onnx model.
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.init_model(nn_model)

    def init_model(self, nn_model: str):
        """Create the CPU inference session and log the model's custom
        metadata."""
        self.model = ort.InferenceSession(
            nn_model,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        meta = self.model.get_modelmeta().custom_metadata_map
        # Use logging rather than print so that the metadata follows the
        # same format and destination as every other message of the script.
        logging.info(f"model metadata: {meta}")

    def __call__(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D float tensor of shape (N, T, C)
          x_lens:
            A 1-D int64 tensor of shape (N,)
        Returns:
          Return a tuple containing:
            - A float tensor containing log_probs of shape (N, T, C)
            - A int64 tensor containing log_probs_len of shape (N)
        """
        out = self.model.run(
            [
                self.model.get_outputs()[0].name,
                self.model.get_outputs()[1].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lens.numpy(),
            },
        )
        return torch.from_numpy(out[0]), torch.from_numpy(out[1])
def decode(
    filename: str,
    log_probs: torch.Tensor,
    HLG: kaldifst.StdVectorFst,
    id2word: Dict[int, str],
) -> List[str]:
    """Decode one utterance with the HLG graph and return its words.

    Args:
      filename:
        Path to the filename for decoding. Used for debugging.
      log_probs:
        A 2-D float32 tensor of shape (num_frames, vocab_size). It
        contains output from log_softmax.
      HLG:
        The HLG graph.
      id2word:
        A map mapping word ID to word string.
    Returns:
      Return a list of decoded words. If decoding fails at any stage, the
      failure is logged and ``[""]`` is returned.
    """
    logging.info(f"{filename}, {log_probs.shape}")
    decodable = DecodableCtc(log_probs.cpu())

    decoder_opts = FasterDecoderOptions(max_active=3000)
    decoder = FasterDecoder(HLG, decoder_opts)
    decoder.decode(decodable)

    if not decoder.reached_final():
        logging.info(f"failed to decode {filename}")
        return [""]

    ok, best_path = decoder.get_best_path()
    # The status used to be overwritten below without being checked.
    if not ok:
        logging.info(f"failed to get best path for {filename}")
        return [""]

    (
        ok,
        _isymbols_out,
        osymbols_out,
        _total_weight,
    ) = kaldifst.get_linear_symbol_sequence(best_path)
    if not ok:
        logging.info(f"failed to get linear symbol sequence for {filename}")
        return [""]

    # The output labels are looked up as word IDs directly; no offset is
    # applied here.
    hyps = [id2word[i] for i in osymbols_out]

    return hyps
started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + word_table = k2.SymbolTable.from_file(args.words) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + HLG=HLG, + id2word=word_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG_streaming.py b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG_streaming.py new file mode 100755 index 0000000000..e823c8d5a2 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/onnx_pretrained_ctc_HLG_streaming.py @@ -0,0 +1,439 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script loads ONNX models exported by ./export-onnx-streaming-ctc.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zapformer-small-2024-03-18 +as an example to show how to use this file. + +1. 
Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zapformer-small-2024-03-18 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp-ctc-rnnt-small/*.pt" +git lfs pull --include "data/lang_bpe_500/words.txt" +git lfs pull --include "data/lang_bpe_500/HLG.fst" +popd + +2. Export the model to ONNX + +./zapformer/export-onnx-streaming-ctc.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 3 \ + --exp-dir $repo/exp-ctc-rnnt-small \ + --causal 1 \ + --use-ctc 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 + +It will generate the following 2 files inside $repo/exp-ctc-rnnt-small: + + - ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx + - ctc-epoch-30-avg-3-chunk-16-left-128.onnx + +You can use either the ``int8.onnx`` model or just the ``.onnx`` model. + +3. Run this file with the exported ONNX models + +python3 ./zapformer/onnx_pretrained_ctc_HLG_streaming.py \ + --nn-model $repo/exp-ctc-rnnt-small/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \ + --words $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.fst \ + $repo/test_wavs/0.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. 
+ +Note: HLG.fst is generated directly from ../local/prepare_lang_fst.py +""" + +import argparse +import logging +from typing import Dict, List, Tuple + +import k2 +import kaldifst +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--words", + type=str, + required=True, + help="""Path to words.txt.""", + ) + + parser.add_argument( + "--HLG", + type=str, + required=True, + help="""Path to HLG.fst.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
class OnnxModel:
    """Streaming ONNX CTC model together with its per-layer cache states.

    The cache states are kept in ``self.states`` as a flat list of numpy
    arrays: six arrays per encoder layer, followed by ``embed_states`` and
    ``processed_lens``.
    """

    def __init__(
        self,
        model_filename: str,
    ):
        """Create a single-threaded CPU session and load ``model_filename``."""
        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1
        self.session_opts = opts

        self.init_model(model_filename)

    def init_model(self, model_filename: str):
        """Load the ONNX file on the CPU provider and reset the states."""
        self.model = ort.InferenceSession(
            model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        self.init_states()

    def init_states(self, batch_size: int = 1):
        """Build zero-filled cache states from the model's custom metadata.

        Also sets ``self.num_encoders``, ``self.segment`` (metadata ``T``)
        and ``self.offset`` (metadata ``decode_chunk_len``).
        """
        meta = self.model.get_modelmeta().custom_metadata_map
        logging.info(f"meta={meta}")

        model_type = meta["model_type"]
        assert model_type == "zapformer2", model_type

        decode_chunk_len = int(meta["decode_chunk_len"])
        T = int(meta["T"])

        # Comma-separated integer lists, one entry per encoder stack.
        list_keys = (
            "num_encoder_layers",
            "encoder_dims",
            "cnn_module_kernels",
            "left_context_len",
            "query_head_dims",
            "value_head_dims",
            "num_heads",
        )
        # Read every raw value first, then parse, so that a missing key is
        # reported before any malformed value.
        raw_values = [meta[key] for key in list_keys]
        parsed = [[int(tok) for tok in raw.split(",")] for raw in raw_values]
        (
            num_encoder_layers,
            encoder_dims,
            cnn_module_kernels,
            left_context_len,
            query_head_dims,
            value_head_dims,
            num_heads,
        ) = parsed

        logging.info(f"decode_chunk_len: {decode_chunk_len}")
        logging.info(f"T: {T}")
        for key, values in zip(list_keys, parsed):
            logging.info(f"{key}: {values}")

        def zeros(*shape):
            return torch.zeros(*shape).numpy()

        num_encoders = len(num_encoder_layers)

        self.states = []
        for i, num_layers in enumerate(num_encoder_layers):
            heads = num_heads[i]
            embed_dim = encoder_dims[i]
            left = left_context_len[i]
            key_dim = query_head_dims[i] * heads
            nonlin_attn_head_dim = 3 * embed_dim // 4
            value_dim = value_head_dims[i] * heads
            conv_left_pad = cnn_module_kernels[i] // 2

            for _ in range(num_layers):
                # Order must match the names used in _build_model_input_output.
                self.states.extend(
                    (
                        zeros(left, batch_size, key_dim),  # cached_key
                        zeros(1, batch_size, left, nonlin_attn_head_dim),
                        zeros(left, batch_size, value_dim),  # cached_val1
                        zeros(left, batch_size, value_dim),  # cached_val2
                        zeros(batch_size, embed_dim, conv_left_pad),  # conv1
                        zeros(batch_size, embed_dim, conv_left_pad),  # conv2
                    )
                )

        self.states.append(zeros(batch_size, 128, 3, 19))  # embed_states
        self.states.append(
            torch.zeros(batch_size, dtype=torch.int64).numpy()
        )  # processed_lens

        self.num_encoders = num_encoders

        self.segment = T
        self.offset = decode_chunk_len

    def _build_model_input_output(
        self,
        x: torch.Tensor,
    ) -> Tuple[Dict[str, np.ndarray], List[str]]:
        """Return the input feed and the requested output names for one chunk.

        Every cached input ``<name>`` is paired with an output ``new_<name>``;
        ``log_probs`` is always the first requested output.
        """
        model_input = {"x": x.numpy()}
        model_output = ["log_probs"]

        def register(name, value):
            model_input[name] = value
            model_output.append(f"new_{name}")

        # Per-layer caches, in the order they are stored in self.states:
        #   cached_key:         (downsample_left, batch_size, key_dim)
        #   cached_nonlin_attn: (1, batch_size, downsample_left, nonlin_attn_head_dim)
        #   cached_val1/2:      (downsample_left, batch_size, value_dim)
        #   cached_conv1/2:     (batch_size, embed_dim, conv_left_pad)
        cache_prefixes = (
            "cached_key",
            "cached_nonlin_attn",
            "cached_val1",
            "cached_val2",
            "cached_conv1",
            "cached_conv2",
        )
        num_layers = len(self.states[:-2]) // 6
        for i in range(num_layers):
            tensors = self.states[i * 6 : (i + 1) * 6]
            assert len(tensors) == 6, len(tensors)
            for prefix, tensor in zip(cache_prefixes, tensors):
                register(f"{prefix}_{i}", tensor)

        # (batch_size, channels, left_pad, freq)
        register("embed_states", self.states[-2])

        # (batch_size,)
        register("processed_lens", self.states[-1])

        return model_input, model_output

    def _update_states(self, states: List[np.ndarray]):
        """Replace the cached states with the ones returned by the model."""
        self.states = states

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Run one chunk through the model and advance the cached states.

        Args:
          x:
            A 3-D tensor of shape (N, T, C)
        Returns:
          Return a 3-D tensor containing log_probs. Its shape is
          (N, T', vocab_size) where T' is usually equal to ((T-7)//2 - 3)//2
        """
        feed, output_names = self._build_model_input_output(x)

        outputs = self.model.run(output_names, feed)

        # outputs[0] is log_probs; the remaining entries are the new states.
        self._update_states(outputs[1:])

        return torch.from_numpy(outputs[0])
def create_streaming_feature_extractor() -> OnlineFeature:
    """Build the online fbank extractor used for streaming decoding.

    The options are fixed for now: CPU device, no dither, ``snip_edges``
    disabled, 16 kHz input, 80 mel bins and ``high_freq=-400``. Passing the
    options in from outside is left for the future.

    Returns:
      Return a CPU streaming feature extractor.
    """
    fbank_opts = FbankOptions()
    fbank_opts.device = "cpu"

    # Framing options.
    frame = fbank_opts.frame_opts
    frame.dither = 0
    frame.snip_edges = False
    frame.samp_freq = 16000

    # Mel filter-bank options.
    mel = fbank_opts.mel_opts
    mel.num_bins = 80
    mel.high_freq = -400

    return OnlineFbank(fbank_opts)
min(start + chunk, wave_samples.numel()) + + # simulate streaming + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + + log_probs = model(frames) + log_probs = log_probs.squeeze(0).cpu().numpy() + + decodable = DecodableCtc(log_probs, offset=n) + n += log_probs.shape[0] + + num_processed_frames += offset + decoder.advance_decoding(decodable) + + if not decoder.reached_final(): + logging.info(f"Failed to decode {args.sound_file}") + return + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + + if not ok: + logging.info(f"Failed to get linear symbol sequence for {args.sound_file}") + return + + hyps = " ".join([word_table[i] for i in osymbols_out]).lower() + logging.info(f"\n{args.sound_file}\n{hyps}") + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/pretrained.py b/egs/librispeech/ASR/zapformer/pretrained.py new file mode 100755 index 0000000000..9e859332f8 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/pretrained.py @@ -0,0 +1,380 @@ +#!/usr/bin/env python3 +# Copyright 2021-2026 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +Note: This is an example for the LibriSpeech dataset; if you are using a different +dataset, you should change the argument values to match your dataset. + +- For non-streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +- For non-streaming model: + +(1) greedy search +./zapformer/pretrained.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zapformer/pretrained.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zapformer/pretrained.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +- For streaming model: + +(1) greedy search +./zapformer/pretrained.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ +
def get_parser():
    """Build the command-line parser for transducer decoding of sound files.

    Besides the options defined here, the model options registered by
    ``add_model_arguments`` are added to the parser.

    Returns:
      Return the configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    # main() always loads the token table, so fail early with a clear
    # argparse error instead of crashing later on a None path.
    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="""Path to tokens.txt.""",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - modified_beam_search
          - fast_beam_search
        """,
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to match --sample-rate.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --method is modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=4,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=4,
        help="""Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=8,
        help="""Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame. Used only when
        --method is greedy_search.
        """,
    )

    add_model_arguments(parser)

    return parser
@torch.no_grad()
def main():
    """Transcribe the sound files given on the command line and log the text.

    Steps: parse the arguments, load the token table and the checkpoint,
    compute fbank features for every file, run the encoder once on the padded
    batch and decode with the method selected by ``--method``.

    Raises:
      ValueError: If ``--method`` is not supported, or if ``greedy_search``
        is requested with ``--max-sym-per-frame`` other than 1.
    """
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    token_table = k2.SymbolTable.from_file(params.tokens)

    # The symbol names were missing (an empty key was looked up).
    # NOTE(review): "<blk>" and "<unk>" are the names icefall's BPE tokens.txt
    # is expected to use for these symbols; confirm against your lang dir.
    params.blank_id = token_table["<blk>"]
    params.unk_id = token_table["<unk>"]
    params.vocab_size = num_tokens(token_table) + 1

    logging.info(f"{params}")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    if params.causal:
        # Both options are strings here; a comma would mean a list of values.
        assert (
            "," not in params.chunk_size
        ), "chunk_size should be one value in decoding."
        assert (
            "," not in params.left_context_frames
        ), "left_context_frames should be one value in decoding."

    logging.info("Creating model")
    model = get_model(params)

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()

    logging.info("Constructing Fbank computer")

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = []
    for w in waves:
        feat = torchaudio.compliance.kaldi.fbank(
            w.unsqueeze(0),
            num_mel_bins=params.feature_dim,
            sample_frequency=params.sample_rate,
            dither=0,
            snip_edges=False,
            high_freq=-400,
        )  # (num_frames, feature_dim)
        features.append(feat.to(device))
    feature_lengths = [f.size(0) for f in features]

    # Pad with log(1e-10), i.e. a very small log-energy.
    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    feature_lengths = torch.tensor(feature_lengths, device=device)

    # model forward
    encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)

    logging.info(f"Using {params.method}")

    def token_ids_to_words(token_ids: List[int]) -> str:
        # Concatenate the word pieces; "▁" marks a word boundary.
        text = "".join(token_table[i] for i in token_ids)
        return text.replace("▁", " ").strip()

    if params.method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
    elif params.method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
    elif params.method == "greedy_search":
        # Previously reported as "Unsupported method", which was misleading.
        raise ValueError(
            "greedy_search is only supported with --max-sym-per-frame 1, "
            f"got {params.max_sym_per_frame}"
        )
    else:
        raise ValueError(f"Unsupported method: {params.method}")

    hyps = [token_ids_to_words(hyp) for hyp in hyp_tokens]

    s = "\n"
    for filename, hyp in zip(params.sound_files, hyps):
        s += f"{filename}:\n{hyp}\n\n"
    logging.info(s)

    logging.info("Decoding Done")
(authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +- For non-streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zapformer/export.py \ + --exp-dir ./zapformer/exp \ + --use-ctc 1 \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +(1) ctc-decoding +./zapformer/pretrained_ctc.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./zapformer/pretrained_ctc.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) nbest-rescoring +./zapformer/pretrained_ctc.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring 
+./zapformer/pretrained_ctc.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(5) attention-decoder-rescoring-no-ngram +./zapformer/pretrained_ctc.py \ + --checkpoint ./zapformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method attention-decoder-rescoring-no-ngram \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from ctc_decode import get_decoding_params +from export import num_tokens +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_attention_decoder_no_ngram, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt. + Used only when method is ctc-decoding. 
+ """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a token table, + i.e., lang_dir/tokens.txt, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + nbest n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + whole-lattice n-gram LM rescoring. + (4) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding + lattice, rescore them with the attention decoder. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. 
def read_sound_files(
    filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
    """Load the first channel of every sound file.

    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The sample rate every file must have; a mismatch fails an assertion.
    Returns:
      Return a list with one 1-D tensor per input file, in input order.
    """

    def first_channel(path: str) -> torch.Tensor:
        samples, rate = torchaudio.load(path)
        assert rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. " f"Given: {rate}"
        )
        # We use only the first channel
        return samples[0].contiguous()

    return [first_channel(path) for path in filenames]
model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + batch_size = ctc_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i].item() // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]: + max_token_id = params.vocab_size - 1 + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + else: + logging.info("Use attention decoder 
rescoring without ngram") + best_path_dict = rescore_with_attention_decoder_no_ngram( + lattice=lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + + token_ids = get_texts(best_path) + hyps = [[token_table[i] for i in ids] for ids in token_ids] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]: + for filename, hyp in zip(params.sound_files, hyps): + words = "".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s 
[%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zapformer/rubik.py b/egs/librispeech/ASR/zapformer/rubik.py new file mode 100644 index 0000000000..d9d352642a --- /dev/null +++ b/egs/librispeech/ASR/zapformer/rubik.py @@ -0,0 +1,552 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import math +import logging +import random +from collections import defaultdict +from torch.optim.lr_scheduler import LambdaLR + +from typing import Dict, List, Optional, Tuple, Union +import torch +import torch.distributed as dist +from torch import Tensor +from torch.optim import Optimizer + + +def three_way_product(x): + """ returns the 3-way matrix product x @ x.t() @ x """ + assert x.ndim == 2 + if x.shape[0] <= x.shape[1]: + x2 = torch.matmul(x, x.t()) + return torch.matmul(x2, x) + else: + x2 = torch.matmul(x.t(), x) + return torch.matmul(x, x2) + +def scaled_three_way_product(x): + """ + Returns alpha * (x @ x.t() @ x), + where alpha is computed from the 2-norm of x in such a way that if all the singular values of + x are the same, it will return x itself. (There is only one such formula.) If the singular + values of x differ from each other, the result will in general have a larger norm than x. 
+ """ + rows, cols = x.shape + eps = 1.0e-40 + x_meansq = (x ** 2).mean(dim=(-2, -1), keepdim=True) + eps + x = x * (x_meansq * max(rows, cols)) ** (-1/3) + return three_way_product(x) + +def compute_alpha(x: Tensor, y: Tensor, beta: float) -> Tensor: + """ + Computes the amount of cubic decay to do for each parameter tensor in the batch, as + a scalar. + + First compute alpha that solves the equation: ||x - alpha y||_2^2 == ||beta x||_2^2 + + + x.x - 2 alpha y.x + alpha^2 y.y = beta^2 x.x + alpha^2 y.y - 2 alpha x.y + (1-beta^2) x.x = 0 + (a,b,c) = (y.y, -2 alpha x.y, x.x) + alpha = (-b - sqrt(b^2 - 4ac) ) / 2a # this is the solution closest to zero. + + # factoring out 2 from the top and bottom we get: + so alpha = (x.y - sqrt(x.y * y.x - (1-beta^2) x.x * y.y)) / y.y + ... we treat the thing inside the sqrt as zero if it is negative, + which gives us the closest real solution + + We then apply a formula that you can see at the bottom, which chooses the + smallest (closest to zero) of two formulae, see the comments. This is basically + heuristic; the safety_factor * min_sum_scale is a safety thing to reduce the + chance of eigenvalues flipping sign. + """ + eps = 1.0e-40 + xx = x.square().mean() + xy = (x * y).mean() + yy = y.square().mean() + yyeps = yy + eps + + # this alpha is the value that solves exactly for the requested difference in norm. + # this will be negative. + alpha = (xy - (xy**2 - (1-beta*beta) * xx * yy).clamp(min=0).sqrt()) / yyeps + + # min_sum_scale is the value of alpha that would minimize the norm of a - alpha y. + min_sum_scale = xy / yyeps + # safety_factor = 0.5 means we are only willing to go halfway to that value that minimizes the norm, + # to avoid change of eigenvalue sign / overshoot, which can ultimately lead to certain + # parameter eigenvalues getting too large. + safety_factor = 0.5 + + # alpha_power is a heuristic value that interpolates between the computed alpha, and alpha=(1-beta). 
+ # the intention is that if the singular values are quite peaky (hence alpha << 1), + # we want to make sure that we're doing an adequate amount of decay for the smaller singular values. + alpha_power = 0.75 + + # return the closest to zero of the two formulae below. + return torch.minimum(safety_factor * min_sum_scale, + ((1-beta) ** (1-alpha_power)) * (alpha.clamp(min=1.0e-10) ** alpha_power)) + + + +def matrix_shape(shape): + """ + shape is expected to be a torch.Size with at least two dimensions. + Returns (rows, cols) such that a tensor of shape `shape` can be reshaped + to size (rows, cols), by combining dimensions in a way that minimizes the + difference between rows and cols. e.g. matrix_shape([ 2, 3, 10 ]) = (6, 10) + """ + shape = list(shape) + cumprod = [ ] + numel = 1 + for k in shape: + cumprod.append(k) + numel = numel * k + diffs = [ abs(k - numel // k) for k in cumprod ] + min_diff = min(diffs) + for i in range(len(shape)): + if diffs[i] == min_diff: + return cumprod[i], numel // cumprod[i] + assert False, shape + + +def half_normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): + """ + Normalize the rms of x using row-wise and column-wise stats, while + updating the moving-average stats; return the normalized x. + Shapes: + x: (rows, cols) +row_stats: (rows, 1) +col_stats: (1, cols) + Returns: + normalized x, shape: (rows, cols) + """ + row_stats.mul_(beta2).add_(x.abs().mean(dim=1, keepdim=True), alpha=(1 - beta2)) + row_denom = (row_stats + eps) + x = x / row_denom + col_stats.mul_(beta2).add_(x.abs().mean(dim=0, keepdim=True), alpha=(1 - beta2)) + col_denom = (col_stats + eps) + x_half_norm = (x * row_denom.sqrt()) / col_denom.sqrt() + x = x / col_denom + return x, x_half_norm + + +def normalize_and_update_stats(x, row_stats, col_stats, beta2, eps): + """ + Normalize the rms of x using row-wise and column-wise stats, while + updating the moving-average stats; return the normalized x. 
+ Shapes: + x: (batch_size, rows, cols) +row_stats: (batch_size, rows, 1) +col_stats: (batch_size, 1, cols) + Returns: + normalized x, shape: (batch_size, rows, cols) + """ + row_stats.mul_(beta2).add_((x ** 2).mean(dim=1, keepdim=True), alpha=(1 - beta2)) + row_denom = (row_stats.sqrt() + eps) + x = x / row_denom + col_stats.mul_(beta2).add_((x ** 2).mean(dim=0, keepdim=True), alpha=(1 - beta2)) + col_denom = (col_stats.sqrt() + eps) + x = x / col_denom + return x + + + + +def cubic_decay_step(group, state, grad): + lr = group["lr"] + eps = group["eps"] + step = state["step"] + beta1_ceil = 1. - 1. / (10. + 0.2 * step) + beta1 = min(group["beta1"], beta1_ceil) + beta2_ceil = step / (step + 1) + beta2 = min(group["beta2"], beta2_ceil) + + orig_shape = grad.shape + rows, cols = matrix_shape(orig_shape) + grad = grad.reshape(rows, cols) + + if "moving_grad" not in state: + assert step < 2 + state["moving_grad"] = torch.zeros(rows, cols, device=grad.device) + state["row_stats"] = torch.ones(rows, 1, device=grad.device) + state["col_stats"] = torch.ones(1, cols, device=grad.device) + + moving_grad = state["moving_grad"] + row_stats = state["row_stats"] + col_stats = state["col_stats"] + + # we half update the stats here, half update them later. + norm_grad, norm_grad_precon = half_normalize_and_update_stats(grad, row_stats, col_stats, beta2, eps) + + # add the grad to the moving-average grad; the scaling factor used here + # doesn't matter as it all gets normalized later. + moving_grad.add_(norm_grad_precon, alpha=(1-beta1)) + + # prod3 would have the same value as moving_grad_precon if moving_grad_precon's singular values were + # all equal, but in general its 2-norm is >= the 2-norm of moving_grad_precon. 
+ prod3 = scaled_three_way_product(moving_grad) + + cubic_alpha = compute_alpha(moving_grad, prod3, beta1) + # cubic_alpha shape: scalar + + moving_grad.add_(prod3 * cubic_alpha, alpha=-1) + + # assumed_scale is just a scalar factor to account for the fact that the moving-average "moving_grad" + # will have a smaller variance than the grad itself because of being a mean over independent elements. + # we rescale before getting the stats, to have the same variance as if it were the grad. + # The actual variance of moving_grad also depends on the variance of the original grads; this is just + # a scalar component in the variance to accountn for averaging-over-time effects. + assumed_scale = (1 - beta1) * ((1 - beta1**2)**-0.5) + + # use a beta2 that is much closer to 1 so we update the stats more slowly at this point; this will # make the stats update more dominated by grad rather than moving_grad. + beta2b_scale = 0.1 + beta2b = beta2b_scale * beta2 + (1 - beta2b_scale) + delta = assumed_scale * normalize_and_update_stats(moving_grad / assumed_scale, row_stats, col_stats, beta2b, eps) + + nesterov = True + if nesterov: + delta = torch.lerp(delta, norm_grad, weight=(1-beta1)) # beta1 * delta + (1 - beta1) * norm_grad # not in-place. + + debug = (step < 500 and (step % 50 == 0)) or (step % 500 == 0) + if debug: + cubic_alpha_ratio = cubic_alpha / (1-beta1) + scale = (assumed_scale / ((delta ** 2).mean().sqrt() + eps)) + logging.info(f"shape={prod3.shape}, scale={scale} [not applied], alpha_ratio={cubic_alpha_ratio}, delta-max={delta.abs().max()}") + + delta.mul_(-lr) + + return delta.reshape(orig_shape) + + + +def scaling_step(group, param, state, grad): + # we reach here for biases and weights but not scalars. 
+ # This does three things things: + # (i) multiply the step from "cubic_decay" by an estimate of the parameter scale + # (ii) apply parameter decay + # (iii) update the parameter scale, which means shrinking or growing the whole tensor + lr = group["lr"] + momentum = group["scale_momentum"] # e.g. 0.95 + min_scale, max_scale = group["scale_limits"] + # the scaling factor is implicitly a scalar; apply scalar_scale to its + # learning rate. + scalar_scale = group["scalar_scale"] + + if grad.ndim >= 2 and grad.numel() != max(grad.shape): + delta = cubic_decay_step(group, state, grad) + else: + # biases and similar-shaped tensors + delta = adam_step(group, state, grad) + + try: + scale = state["scale"] + scale_grad_buf = state["scale_grad_buffer"] + except KeyError: + scale = (param ** 2).mean().sqrt().clamp(min=min_scale, + max=max_scale).to(torch.float) + scale_grad_buf = torch.zeros_like(scale) + state["scale"] = scale + state["scale_grad_buffer"] = scale_grad_buf + + + scale_grad = (grad * param.detach()).sum() + scale_grad_buf.mul_(momentum).add_(scale_grad) # simple momentum + + old_scale = scale.clone() + + nesterov = True + if nesterov: + # simple interpretation of nesterov: do an extra step of + # moving-average on scale_grad_buf, with scale_grad, like double-counting + # it. + negative_update = (scale_grad_buf * momentum + scale_grad).sign() + else: + negative_update = scale_grad_buf.sign() + + scale.mul_(1. - lr * scalar_scale * negative_update) + scale.clamp_(min=min_scale, max=max_scale) + + scale_ratio = scale / old_scale + + delta_scale = (scale_ratio * (1 - 0.5 * (lr ** 2))) - 1 + return param * delta_scale + scale * delta + + +def adam_step(group, state, grad): + # this is the adam update but with a slight modification / simplification on + # how "bias correction" (startup on small step counts) is dealt with. 
+ lr = group["lr"] + step = state["step"] + eps = group["eps"] + beta1 = group["adam_beta1"] + # the following modification to beta2 makes it unnecessary to do bias correction; + # for small step values, we are just computing the mean over the steps so far + beta2 = min(group["adam_beta2"], step / (step + 1)) + + try: + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + except KeyError as e: + assert step < 2 + exp_avg = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) + exp_avg_sq = torch.zeros(*grad.shape, device=grad.device, dtype=torch.float) + state["exp_avg"] = exp_avg + state["exp_avg_sq"] = exp_avg_sq + + exp_avg.mul_(beta1).add_(grad, alpha=(1 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + denom = exp_avg_sq.sqrt() + eps + + nesterov = True + if nesterov: + # this is similar to double-counting grad + moving_grad = exp_avg * beta1 + grad * (1-beta1) + else: + moving_grad = exp_avg + + return -lr * (moving_grad / denom) + + + + +class Rubik(Optimizer): + """ + Version of TransformedAdam that doesn't do the batching or gradient clipping (may be easier to integrate + into other frameworks). + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses). + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + beta2: beta2 is the momentum constant for moving-grad-squared as in Adam. + Must satisfy 0 < beta <= beta2 < 1. + betas: a list of decay constants for momentum on the parameter-change + scales: a list of scales corresponding to each of the betas, that we multiply + each momentum-update by. Implicitly there is also a beta=0, scale=1, + i.e. a non-decayed update. 
+ """ + def __init__( + self, + params, + lr=1.2e-02, + beta1=0.99, + beta2=0.98, + eps=1.0e-08, + scale_limits=(0.03, 0.15), + scalar_scale=0.05, + adam_beta1=0.98, + adam_beta2=0.98, + scale_momentum=0.95, + ): + defaults = dict( + lr=lr, + beta1=beta1, + beta2=beta2, + eps=eps, + scale_limits=scale_limits, + scalar_scale=scalar_scale, + adam_beta1=adam_beta1, + adam_beta2=adam_beta2, + scale_momentum=scale_momentum, + ) + super().__init__(params, defaults) + + + def __setstate__(self, state): + super(Rubik, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + + for group in self.param_groups: + + for p in group['params']: + state = self.state[p] + grad = p.grad + + try: + cur_step = state["step"] + except KeyError: + state["step"] = 0 + cur_step = 0 + + if p.numel() == 1: + # "scalar_scale" the assumed parameter scale used for + # scalars, in this case it just acts as a multiplier on + # the learning rate. + p += group["scalar_scale"] * adam_step(group, state, grad) + else: + p += scaling_step(group, p.detach(), state, grad) + + state["step"] = cur_step + 1 + + return loss + + + +def _test_rubik(hidden_dim: int): + import timeit + + E = 100 + B = 4 + T = 2 + logging.info("in test_rubik") + # device = torch.device('cuda') + device = torch.device("cpu") + dtype = torch.float32 + + torch.random.manual_seed(42) + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. 
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + if True: + Linear = torch.nn.Linear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] + + lr = 0.018 + optim = Rubik(m.parameters(), lr=lr, beta1=0.998) + + num_epochs = 180 + + total_steps = num_epochs + def lr_lambda(current_step): + # a LR schedule similar to InterpCosineLRScheduler from combined_scheduler.py + progress = min(1, current_step / total_steps) + cos = math.cos(progress * math.pi / 2) + # the relatively small scale on cos means the linear cool-down phase + # is long/slow, as the loss of this easy task is dominated by + # parameter noise.. in practical scenarios we use larger scale on + # the cos term, e.g. as large as 0.66. + return 0.05 * cos + 0.95 * (cos ** 2) + + scheduler = LambdaLR(optim, lr_lambda) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + # if epoch == 100 and test in [2,3]: + # optim.reset_speedup() # check it doesn't crash. 
+ + # if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 512 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + for n, (x, y) in enumerate(train_pairs): + #scheduler.step_batch() + y_out = m(x) + loss = ((y_out - y) ** 2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + norm3 = '%.2e' % (m[4].weight**2).mean().sqrt().item() + + bias_norm1 = '%.2e' % (m[0].bias**2).mean().sqrt().item() + bias_norm2 = '%.2e' % (m[2].bias**2).mean().sqrt().item() + bias_norm3 = '%.2e' % (m[4].bias**2).mean().sqrt().item() + + lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm2,norm3}, bias_norms={bias_norm1,bias_norm2,bias_norm3}" + ) + loss.log().backward() + optim.step() + optim.zero_grad() + scheduler.step() # step once per epoch + + # diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Time taken: {stop - start}") + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + + + +def _test_scaled_three_way_product(): + x = torch.randn(16, 32) + _U, _S, V = torch.linalg.svd(x, full_matrices=False) + W = V * torch.randn(1, 1) + # so now all the singular values of x will be identical (but arbitrary) + + X = scaled_three_way_product(W) + #print("X = ", X[0]) + #print("W = ", W[0]) + assert torch.allclose(W, X, atol=1.0e-03) + # but the result won't be identical to the input if the singular values are not all identical. 
+ assert not torch.allclose(x, scaled_three_way_product(x), atol=1.0e-03) + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + logging.getLogger().setLevel(logging.INFO) + import subprocess + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) + logging.info(s) + import sys + + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 + + _test_scaled_three_way_product() + _test_rubik(hidden_dim) diff --git a/egs/librispeech/ASR/zapformer/speech_recognition.py b/egs/librispeech/ASR/zapformer/speech_recognition.py new file mode 100755 index 0000000000..dd069cf3da --- /dev/null +++ b/egs/librispeech/ASR/zapformer/speech_recognition.py @@ -0,0 +1,229 @@ +from typing import Callable, Dict, List, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate + +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + + +class K2SpeechRecognitionDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the speech recognition task using k2 library. + + This dataset expects to be queried with lists of cut IDs, + for which it loads features and automatically collates/batches them. + + To use it with a PyTorch DataLoader, set ``batch_size=None`` + and provide a :class:`SimpleCutSampler` sampler. + + Each item in this dataset is a dict of: + + .. 
code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + 'sequence_idx': Tensor[int] of shape (S,) + 'text': List[str] of len S + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + + Dimension symbols legend: + * ``B`` - batch size (number of Cuts) + * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions) + * ``T`` - number of frames of the longest Cut + * ``F`` - number of features + + The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset. + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + ): + """ + k2 ASR IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. 
+ """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_duration and max_cuts. + """ + validate_for_asr(cuts) + + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + if self.cut_transforms: + orig_cuts = cuts + + cuts = cuts.repeat(times=2) + + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + cuts = orig_cuts + cuts + num_copies = 3 + else: + num_copies = 1 + + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we successfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. 
+ segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + batch = { + "inputs": inputs, + "num_copies": num_copies, + "supervisions": default_collate( + [ + { + "text": supervision.text, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + has_word_alignments = all( + s.alignment is not None and "word" in s.alignment + for c in cuts + for s in c.supervisions + ) + if has_word_alignments: + # TODO: might need to refactor BatchIO API to move the following conditional logic + # into these objects (e.g. use like: self.input_strategy.convert_timestamp(), + # that returns either num_frames or num_samples depending on the strategy). + words, starts, ends = [], [], [] + frame_shift = cuts[0].frame_shift + sampling_rate = cuts[0].sampling_rate + if frame_shift is None: + try: + frame_shift = self.input_strategy.extractor.frame_shift + except AttributeError: + raise ValueError( + "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. 
" + ) + for c in cuts: + for s in c.supervisions: + words.append([aliword.symbol for aliword in s.alignment["word"]]) + starts.append( + [ + compute_num_frames( + aliword.start, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + ends.append( + [ + compute_num_frames( + aliword.end, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + batch["supervisions"]["word"] = words + batch["supervisions"]["word_start"] = starts + batch["supervisions"]["word_end"] = ends + + return batch + + +def validate_for_asr(cuts: CutSet) -> None: + validate(cuts) + tol = 2e-3 # 1ms + for cut in cuts: + for supervision in cut.supervisions: + assert supervision.start >= -tol, ( + f"Supervisions starting before the cut are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) + + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + # + # 'supervision.end' is end of supervision inside the Cut + assert supervision.end <= cut.duration + tol, ( + f"Supervisions ending after the cut " + f"are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) diff --git a/egs/librispeech/ASR/zapformer/streaming_beam_search.py b/egs/librispeech/ASR/zapformer/streaming_beam_search.py new file mode 100644 index 0000000000..d5a475627a --- /dev/null +++ b/egs/librispeech/ASR/zapformer/streaming_beam_search.py @@ -0,0 +1,295 @@ +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List + +import k2 +import torch +import torch.nn as nn +from beam_search import Hypothesis, HypothesisList, get_hyps_shape +from decode_stream import DecodeStream + +from icefall.decode import one_best_decoding +from icefall.utils import get_texts + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], + blank_penalty: float = 0.0, +) -> None: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + streams: + A list of Stream objects. 
def modified_beam_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[DecodeStream],
    num_active_paths: int = 4,
    blank_penalty: float = 0.0,
) -> None:
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

    The hypotheses of every stream are read from ``stream.hyps``, extended
    frame by frame, and written back to ``stream.hyps`` at the end; nothing
    is returned.

    Args:
      model:
        The RNN-T model.
      encoder_out:
        A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
        the encoder model.
      streams:
        A list of stream objects.
      num_active_paths:
        Number of active paths during the beam search.
      blank_penalty:
        Value subtracted from the blank logit before the log-softmax. A value
        of 0.0 leaves the logits untouched.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert len(streams) == encoder_out.size(0)

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = next(model.parameters()).device
    batch_size = len(streams)
    T = encoder_out.size(1)

    B = [stream.hyps for stream in streams]

    for t in range(T):
        current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)

        hyps_shape = get_hyps_shape(B).to(device)

        # A holds the hypotheses to be expanded at this frame; B collects
        # the survivors.
        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.stack(
            [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        # Repeat each stream's encoder frame once per hypothesis of that
        # stream; only dim 0 changes.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(current_encoder_out, decoder_out, project_input=False)
        # logits is of shape (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)

        if blank_penalty != 0.0:
            # Penalize the blank column itself rather than a hard-coded
            # column 0, so this agrees with the `blank_id` test below.
            logits[:, blank_id] -= blank_penalty

        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)

        # Accumulate the score of the hypothesis each row extends.
        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        log_probs = log_probs.reshape(-1)

        # One ragged row per stream, holding the scores of all
        # (hypothesis, token) pairs of that stream.
        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)

        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k, hyp_idx in enumerate(topk_hyp_indexes):
                hyp = A[i][hyp_idx]

                # Copy so that sibling expansions do not share one list.
                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                if new_token != blank_id:
                    new_ys.append(new_token)

                new_log_prob = topk_log_probs[k]
                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
                B[i].add(new_hyp)

    for i in range(batch_size):
        streams[i].hyps = B[i]
+ encoder_out: + A tensor of shape (N, T, C) from the encoder. + processed_lens: + A tensor of shape (N,) containing the number of processed frames + in `encoder_out` before padding. + streams: + A list of stream objects. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + """ + assert encoder_out.ndim == 3 + B, T, C = encoder_out.shape + assert B == len(streams) + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(streams[i].rnnt_decoding_stream) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + + if blank_penalty != 0.0: + logits[:, 0] -= blank_penalty + + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = 
decoding_streams.format_output(processed_lens.tolist()) + best_path = one_best_decoding(lattice) + hyp_tokens = get_texts(best_path) + + for i in range(B): + streams[i].hyp = hyp_tokens[i] diff --git a/egs/librispeech/ASR/zapformer/streaming_decode.py b/egs/librispeech/ASR/zapformer/streaming_decode.py new file mode 100755 index 0000000000..b7d2ed7af0 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/streaming_decode.py @@ -0,0 +1,952 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Usage: +./zapformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zapformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +from asr_datamodule import CommonVoice, LibriSpeech, GigaSpeech, AsrDataModule +from decode import cv_post_processing, giga_post_processing +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet, set_caching_enabled +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--label", + type=str, + default="", + help="""Extra label of the decoding run.""", + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. 
+ """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zapformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
def get_init_states(
    model: nn.Module,
    batch_size: int = 1,
    device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
    """Build the initial streaming state list for `model`.

    The list starts with whatever ``model.encoder.get_init_caches`` returns.
    For layer-i, states[i*9:(i+1)*9] is (cached_key, cached_value,
    cached_conv, cached_norm_stats, cached_norm_len, cached_attn_wm_sum,
    cached_attn_wm_num_frames, cached_conv_wm_sum, cached_conv_wm_num_frames).
    Two more entries are appended to that same list:

    states[-2] is the cached left padding for ConvNeXt module,
    of shape (batch_size, num_channels, left_pad, num_freqs)
    states[-1] is processed_lens of shape (batch,), which records the number
    of processed frames (at 50hz frame rate, after encoder_embed) for each
    sample in batch. It starts at zero with dtype int32.
    """
    caches = model.encoder.get_init_caches(batch_size, device)
    embed_cache = model.encoder_embed.get_init_cache(batch_size, device)
    frames_seen = torch.zeros(batch_size, dtype=torch.int32, device=device)

    # Extend in place: the encoder's own list object is what gets returned.
    caches.extend((embed_cache, frames_seen))
    return caches
+ + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zapformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*9:(i+1)*9] is (cached_key, cached_value, cached_conv, + cached_norm_stats, cached_norm_len, cached_attn_wm_sum, + cached_attn_wm_num_frames, cached_conv_wm_sum, cached_conv_wm_num_frames). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 9 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 9 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 9 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_value: (left_context_len, batch_size, value_dim) + cached_value = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_conv: (batch_size, channels, left_pad) + cached_conv = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=0 + ) + # cached_norm_stats: (batch_size, ...) + cached_norm_stats = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=0 + ) + # cached_norm_len: (batch_size, ...) 
def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
    """Split a batched zapformer state into one state list per utterance.

    Note:
      It is the inverse of :func:`stack_states`.

    Args:
      batch_states:
        Nine tensors per encoder layer, followed by the cached embed left
        padding and the processed lengths. The batch size is read from
        dim 0 of the last tensor.

    Returns:
      state_list: A list of list. The i-th entry is the state of the i-th
      utterance in the batch, with its tensors in the same order as in
      `batch_states`.
    """
    num_tensors = len(batch_states)
    assert (num_tensors - 2) % 9 == 0, num_tensors

    batch_size = batch_states[-1].shape[0]

    # Batch axis of each of the nine per-layer tensors, in order:
    # key, value, conv, norm_stats, norm_len, attn_wm_sum,
    # attn_wm_num_frames, conv_wm_sum, conv_wm_num_frames.
    batch_axes = (1, 1, 0, 0, 0, 1, 0, 1, 0)

    per_utt = [[] for _ in range(batch_size)]

    for pos in range(num_tensors - 2):
        pieces = batch_states[pos].chunk(chunks=batch_size, dim=batch_axes[pos % 9])
        for utt in range(batch_size):
            per_utt[utt].append(pieces[utt])

    # The embed left padding and the processed lengths are both batch-first.
    for tail in (batch_states[-2], batch_states[-1]):
        pieces = tail.chunk(chunks=batch_size, dim=0)
        for utt in range(batch_size):
            per_utt[utt].append(pieces[utt])

    return per_utt
def decode_one_chunk(
    params: AttributeDict,
    model: nn.Module,
    decode_streams: List[DecodeStream],
) -> List[int]:
    """Run one chunk of every stream through the model and the search.

    Each stream supplies ``2 * chunk_size`` feature frames. The frames are
    padded into a batch, passed through :func:`streaming_forward`, and the
    selected decoding method updates the streams in place. Each stream then
    gets its new states and an advanced ``done_frames``.

    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      decode_streams:
        A List of DecodeStream, each belonging to a utterance.
    Returns:
      The positions in `decode_streams` of the streams whose ``done`` flag
      is set after this chunk.
    Raises:
      ValueError: if ``params.decoding_method`` is not one of the three
      supported methods.
    """
    device = model.device
    chunk_size = int(params.chunk_size)
    left_context_len = int(params.left_context_frames)

    chunk_feats = []
    chunk_lens = []
    per_stream_states = []
    frames_done = []  # Used in fast-beam-search

    # Kept as a single loop so each stream is queried in the original order.
    for stream in decode_streams:
        feat, feat_len = stream.get_feature_frames(chunk_size * 2)
        chunk_feats.append(feat)
        chunk_lens.append(feat_len)
        per_stream_states.append(stream.states)
        frames_done.append(stream.done_frames)

    feature_lens = torch.tensor(chunk_lens, device=device)
    features = pad_sequence(chunk_feats, batch_first=True, padding_value=LOG_EPS)

    # Make sure the length after encoder_embed is at least 1.
    # The encoder_embed subsample features (T - 7) // 2
    min_frames = chunk_size * 2 + 7
    shortfall = min_frames - features.size(1)
    if shortfall > 0:
        feature_lens += shortfall
        features = torch.nn.functional.pad(
            features,
            (0, 0, 0, shortfall),
            mode="constant",
            value=LOG_EPS,
        )

    batched_states = stack_states(per_stream_states)

    encoder_out, encoder_out_lens, new_states = streaming_forward(
        features=features,
        feature_lens=feature_lens,
        model=model,
        states=batched_states,
        chunk_size=chunk_size,
        left_context_len=left_context_len,
    )

    encoder_out = model.joiner.encoder_proj(encoder_out)

    method = params.decoding_method
    if method == "greedy_search":
        greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
    elif method == "fast_beam_search":
        # Frames done before this chunk plus the frames of this chunk.
        processed_lens = torch.tensor(frames_done, device=device) + encoder_out_lens
        fast_beam_search_one_best(
            model=model,
            encoder_out=encoder_out,
            processed_lens=processed_lens,
            streams=decode_streams,
            beam=params.beam,
            max_states=params.max_states,
            max_contexts=params.max_contexts,
        )
    elif method == "modified_beam_search":
        modified_beam_search(
            model=model,
            streams=decode_streams,
            encoder_out=encoder_out,
            num_active_paths=params.num_active_paths,
        )
    else:
        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")

    new_per_stream_states = unstack_states(new_states)

    finished_streams = []
    for idx, stream in enumerate(decode_streams):
        stream.states = new_per_stream_states[idx]
        stream.done_frames += encoder_out_lens[idx]
        if stream.done:
            finished_streams.append(idx)

    return finished_streams
+ initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + 
def save_asr_output(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
    """
    Save text produced by ASR.

    One ``recogs-*.txt`` file is written under ``params.res_dir`` for every
    key of `results_dict`. Results are sorted first. They then go through
    the gigaspeech and/or commonvoice post-processing when `test_set_name`
    contains 'giga' and/or 'cv'.
    """
    needs_giga = 'giga' in test_set_name
    needs_cv = 'cv' in test_set_name

    for key in results_dict:
        recogs_filename = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        ordered = sorted(results_dict[key])
        if needs_giga:
            ordered = giga_post_processing(ordered)
        if needs_cv:
            ordered = cv_post_processing(ordered)
        store_transcripts(filename=recogs_filename, texts=ordered)
        logging.info(f"The transcripts are stored in {recogs_filename}")
@torch.no_grad()
def main():
    """Streaming-decode the evaluation sets with a trained zapformer model.

    Steps: parse the command line, build the result suffix and logger, load
    the BPE model, create the model and load (averaged) checkpoints, then
    decode every test set. Transcripts are always saved; WER statistics are
    saved unless ``--skip-scoring`` is set.
    """
    parser = get_parser()
    AsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    # enable AudioCache
    set_caching_enabled(True)  # lhotse

    params.res_dir = params.exp_dir / "streaming" / params.decoding_method

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    assert params.causal, params.causal
    assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
    assert (
        "," not in params.left_context_frames
    ), "left_context_frames should be one value in decoding."
    params.suffix += f"_chunk-{params.chunk_size}"
    params.suffix += f"_left-context-{params.left_context_frames}"

    # for fast_beam_search
    if params.decoding_method == "fast_beam_search":
        params.suffix += f"_beam-{params.beam}"
        params.suffix += f"_max-contexts-{params.max_contexts}"
        params.suffix += f"_max-states-{params.max_states}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    if params.label:
        params.suffix += f"-{params.label}"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bpe_model.py.
    # Looking both up with an empty piece name would give blank_id and
    # unk_id the same value, so the real piece names are used.
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            # Epochs count from 1, so anything below 1 is dropped.
            filenames = [
                f"{params.exp_dir}/epoch-{i}.pt"
                for i in range(start, params.epoch + 1)
                if i >= 1
            ]
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            # One extra checkpoint is needed: the oldest one is only the
            # (excluded) start of the averaging range.
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()
    # The decoding helpers read the device from this attribute.
    model.device = device

    decoding_graph = None
    if params.decoding_method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    librispeech = LibriSpeech(args.manifest_dir)

    test_sets = []
    test_cuts = []

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()
    dev_clean_cuts = librispeech.dev_clean_cuts()
    dev_other_cuts = librispeech.dev_other_cuts()

    test_sets += ["dev-clean", "dev-other", "test-clean", "test-other"]
    test_cuts += [dev_clean_cuts, dev_other_cuts, test_clean_cuts, test_other_cuts]

    if args.giga:
        gigaspeech = GigaSpeech(args.manifest_dir)
        giga_test_cuts = gigaspeech.test_cuts()
        giga_dev_cuts = gigaspeech.dev_cuts()
        test_sets += ["giga-dev", "giga-test"]
        test_cuts += [giga_dev_cuts, giga_test_cuts]

    if args.cv:
        commonvoice = CommonVoice(args.manifest_dir)
        cv_test_cuts = commonvoice.test_cuts()
        cv_dev_cuts = commonvoice.dev_cuts()
        test_sets += ["cv-dev", "cv-test"]
        test_cuts += [cv_dev_cuts, cv_test_cuts]

    for test_set, test_cut in zip(test_sets, test_cuts):
        results_dict = decode_dataset(
            cuts=test_cut,
            params=params,
            model=model,
            sp=sp,
            decoding_graph=decoding_graph,
        )

        save_asr_output(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

        if not params.skip_scoring:
            save_wer_results(
                params=params,
                test_set_name=test_set,
                results_dict=results_dict,
            )

    logging.info("Done!")
class AddNoise(nn.Module):
    """Add Gaussian noise proportional to each sample's RMS value (training only).

    For every batch element, the root-mean-square of `x` is computed over all
    non-batch dimensions; zero-mean Gaussian noise with standard deviation
    `rel_noise_scale * rms` is then added.  In eval mode the input is returned
    unchanged.

    The module was written for Conv2d-style input of shape (N, C, H, W), but
    any input with a leading batch dimension and at least one further
    dimension is accepted.

    Args:
      rel_noise_scale:
        Noise standard deviation, relative to the per-sample RMS of the input.
    """

    def __init__(self, rel_noise_scale: float):
        super().__init__()
        self.rel_noise_scale = rel_noise_scale

    def forward(self, x: Tensor) -> Tensor:
        """Return `x` plus scaled noise when training, else `x` itself."""
        if not self.training:
            return x
        # eps keeps sqrt() (and its gradient) finite for an all-zero sample.
        eps = 3.0e-08
        # Reduce over every dimension except the batch dimension; for 4-D
        # input this is exactly (1, 2, 3).
        reduce_dims = tuple(range(1, x.dim()))
        rms = ((x ** 2).mean(dim=reduce_dims, keepdim=True) + eps).sqrt()
        noise_scale = rms * self.rel_noise_scale
        return x + noise_scale * torch.randn_like(x)
class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where
        T' = (T-3)//2 - 2 == (T-7)//2

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layer1_channels: int = 16,
        layer2_channels: int = 64,
        layer3_channels: int = 128,
        causal: bool = False,
    ) -> None:
        """
        Args:
          in_channels:
            Number of channels in. The input shape is (N, T, in_channels).
            Caution: It requires: T >=7, in_channels >=7
          out_channels
            Output dim. The output shape is (N, (T-7)//2, out_channels)
          layer1_channels:
            Number of output channels of the first Conv2d layer.
          layer2_channels:
            Number of output channels of the second Conv2d layer.
          layer3_channels:
            Number of output channels of the third Conv2d layer; also the
            number of channels of the ConvNeXt layer that follows.
          causal:
            Passed to the ConvNeXt layer.
        """
        super().__init__()
        assert in_channels >= 7
        self.in_channels = in_channels
        # The AddNoise module is there to prevent the gradients
        # w.r.t. the weight or bias of the first Conv2d module in self.conv from
        # getting too large.  The justification in my mind for why this might work
        # is that the first Conv2d module increases the dimension of the input quite a bit,
        # so its output lives in a linear subspace; and there may in principle be quite large gradients
        # in directions not in this subspace, without affecting the model quality.
        # so by adding a little noise we force the model to "ignore" directions not in this subspace,
        # as much as possible, which will tend to avoid very large gradients.  The reason the
        # large gradients are a problem is because of float16 training with GradScaler, the infinities will
        # be detected and will make it scale the grads by a smaller amount..
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                padding=(0, 1),  # (time, freq)
            ),
            AddNoise(rel_noise_scale=5.0e-03),  # this AddNoise
            SwashR(),
            nn.Conv2d(
                in_channels=layer1_channels,
                out_channels=layer2_channels,
                kernel_size=3,
                stride=2,
                padding=0,
            ),
            SwashR(),
            nn.Conv2d(
                in_channels=layer2_channels,
                out_channels=layer3_channels,
                kernel_size=3,
                stride=(1, 2),  # (time, freq)
            ),
            SwashR(),
        )

        # just one convnext layer
        self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7), causal=causal)

        # Width of the frequency axis after self.conv:
        # ((in_channels-1)//2 - 1)//2, roughly (in_channels-3)//4
        self.out_width = (((in_channels - 1) // 2) - 1) // 2
        self.layer3_channels = layer3_channels

        # scale it up a bit, else the output is quite small.
        self.out = ScaledLinear(self.out_width * layer3_channels, out_channels)

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor, aux_loss_scale: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            each sequence of x.
          aux_loss_scale:
            Not used in this method.

        Returns:
          - a tensor of shape (N, (T-7)//2, odim)
          - output lengths, of shape (batch_size,)
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
        x = self.conv(x)
        x = self.convnext(x)

        # Now x is of shape (N, layer3_channels, (T-7)//2, out_width)
        b, c, t, f = x.size()

        x = x.transpose(1, 2).reshape(b, t, c * f)
        # now x: (N, (T-7)//2, out_width * layer3_channels))

        x = self.out(x)
        # Now x is of shape (N, (T-7)//2, odim)
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            x_lens = (x_lens - 7) // 2
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                x_lens = (x_lens - 7) // 2

        # The longest sequence in the batch must span the whole time axis.
        assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())

        return 0.15 * x, x_lens

    def streaming_forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        cache: Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            each sequence of x.
          cache:
            The cached left padding for ConvNeXt module, of shape
            (batch_size, num_channels, left_pad, num_freqs)

        Returns:
          - a tensor of shape (N, (T-7)//2, odim)
          - output lengths, of shape (batch_size,)
          - updated cache
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)

        # T' = (T-7)//2
        x = self.conv(x)

        x, cache = self.convnext.streaming_forward(x, cache=cache)

        # Now x is of shape (N, layer3_channels, T', ((idim-1)//2 - 1)//2)
        b, c, t, f = x.size()

        x = x.transpose(1, 2).reshape(b, t, c * f)
        # now x: (N, T', out_width * layer3_channels))

        x = self.out(x)
        # Now x is of shape (N, T', odim)
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            x_lens = (x_lens - 7) // 2
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                x_lens = (x_lens - 7) // 2

        assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())

        return 0.15 * x, x_lens, cache

    @torch.jit.export
    def get_init_cache(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> Tensor:
        """Get initial states for Conv2dSubsampling module.
        It is the cached left padding for ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs)
        """
        left_pad = self.convnext.left_pad
        freq = self.out_width
        channels = self.layer3_channels
        cache = torch.zeros(batch_size, channels, left_pad, freq, device=device)

        return cache
max diff: {diff_chunk}" + out_offset += out_chunk_len + + out_stream_cat = torch.cat(out_chunks, dim=1) + diff_total = torch.max(torch.abs(out_full - out_stream_cat)) + logging.info(f"Total Max Diff between full forward and streaming: {diff_total}") + assert torch.allclose(out_full, out_stream_cat, atol=1e-4), "Total outputs do not match!" + + logging.info("Passed") + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_conv2d_subsampling_streaming() diff --git a/egs/librispeech/ASR/zapformer/train.py b/egs/librispeech/ASR/zapformer/train.py new file mode 100755 index 0000000000..642285ddf6 --- /dev/null +++ b/egs/librispeech/ASR/zapformer/train.py @@ -0,0 +1,1796 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zapformer/train.py \ + --world-size 4 \ + --num-real-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zapformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zapformer/train.py \ + --world-size 4 \ + --num-real-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zapformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default) + - ctc loss + - attention decoder loss +""" + + +import argparse +import copy +import logging +import warnings +import math +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AsrDataModule, CommonVoice, LibriSpeech, GigaSpeech +from attention_decoder import AttentionDecoderModel +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +# the try-pass blocks around imports are to reduce the chance of failures when running multiple code +# versions in parallel; later, these can be removed.
+try: + from batched_rubik import BatchedRubik + # could also have done: + # from rubik import Rubik +except: + pass + + +from variable_combined_scheduler import VariableCombinedLRScheduler +try: + from variable_combined_scheduler import InterpCosineLRScheduler + LRSchedulerType = VariableCombinedLRScheduler +except: + pass + +SchedulerType = "VariableCombinedLRScheduler" + +from torch.optim.lr_scheduler import LambdaLR +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zapformer import Zapformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +import torch.distributed as dist +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from alternating_spec_augment import AlternatingSpecAugment + +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) +try: + from icefall.utils import dist_barrier +except: + pass + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
def lookup(params: "AttributeDict", name: str):
    """
    Interprets numerical arguments in `params` by taking into account base-dim;
    also parses comma-separated lists of integers, turning them into tuples.
    If a particular attribute ending in "dim" is not present we look up
    the same name but ending in "multiple", and multiply the elements by base_dim.

    Args:
      params:
        Object whose attributes hold the options.
      name:
        Attribute name to look up, e.g. "num_heads" or "embed_dim".

    Returns:
      - If `params.<name>` exists and is a string of comma-separated ints:
        a tuple of ints, or a single int if there is only one element.
      - If `params.<name>` exists but cannot be parsed that way: the value
        itself, unchanged.
      - Otherwise, for names ending in "dim": the value(s) of the
        corresponding "...multiple" attribute multiplied by `params.base_dim`.

    Raises:
      AttributeError: `name` is missing and does not end in "dim".
      RuntimeError: `name` ends in "dim" but neither it nor the matching
        "...multiple" attribute (or `base_dim`) is present.
    """
    try:
        attr = getattr(params, name)
    except AttributeError:
        if not name.endswith("dim"):
            raise
    else:
        try:
            parsed = tuple(map(int, attr.split(",")))  # tuple of comma-separated ints
        except (AttributeError, TypeError, ValueError):
            # Not a string, or not a list of ints: leave attr as it is.
            return attr
        return parsed[0] if len(parsed) == 1 else parsed

    # `name` ends in "dim" but is absent: infer it from "<prefix>multiple".
    try:
        attr = getattr(params, name[:-3] + "multiple")
        if isinstance(attr, str):
            attr = tuple(map(int, attr.split(",")))  # tuple of ints
            base_dim = params.base_dim
            attr = tuple([i * base_dim for i in attr])
            if len(attr) == 1:
                attr = attr[0]
        else:  # assume int.
            assert isinstance(attr, (int, float)), (name, attr)
            attr = attr * params.base_dim
        return attr
    except AttributeError as e:
        raise RuntimeError(
            f"cannot find or infer attribute {name} in params: {e}"
        ) from e
in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--conv-params", + type=str, + default="31,31,15,31", + help="Parameters per channel of convolution kernels", + ) + + parser.add_argument( + "--decoder-multiple", + type=int, + default=8, + help="Factor by which embedding dimension in the decoder model is larger than base-dim.", + ) + + parser.add_argument( + "--joiner-multiple", + type=int, + default=12, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--attention-decoder-multiple", + type=int, + default=8, + help="""Factor by which attention decoder dim is larger than base-dim""", + ) + + parser.add_argument( + "--attention-decoder-num-layers", + type=int, + default=6, + help="""Number of transformer layers used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-attention-multiple", + type=int, + default=8, + help="""Determines attention dimension used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-heads", + type=int, + default=8, + help="""Number of attention heads used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-feedforward-multiple", + type=int, + default=4, + help="""Factor by which feedforward hidden dim in attention decoder is larger than attention-decoder-dim""" + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. 
" + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=True, + help="If True, use CTC head.", + ) + + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-real-epochs", + type=int, + default=30, + help="Number of epochs to train, including number of copies; num-epochs will be <= this.", + ) + + parser.add_argument( + "--max-copies", + type=int, + default=1, + help="The num_copies to use in the dataloader on the last epoch (it rises linearly with step count from --min-copies)" + ) + + parser.add_argument( + "--min-copies", + type=int, + default=1, + help="The num_copies to use in the dataloader on the first epoch (it rises linearly with step count to --max-copies)" + ) + + parser.add_argument( + "--batches-per-epoch", + type=int, + default=2200, + help="Assumed number of batches per epoch for purposes of setting learning rate; only " + "makes a 
difference during the first batch, after which an observed value is used. This " + "is the num batches where num_copies==1, i.e. on the first epoch" + ) + + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zapformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.02, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=17500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--aux-loss-scale", + type=float, + default=0.05, + help="Scale on auxiliary losses that are defined in the model, such " + "as cosine loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--attention-decoder-loss-scale", + type=float, + default=0.8, + help="Scale for attention-decoder loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. 
The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. 
It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - warm_step: The warmup period that dictates the decay of the + scale on pruned loss (for transducer). + Expressed in terms of the "adjusted batch count", i.e. the + normalized batch count after adjusting for changes in batch size. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 10000, + # parameters for zapformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. 
class ParamPlotter:
    """Log fixed random projections of parameters and gradients to tensorboard.

    On selected steps every parameter (and its gradient, if it has one) is
    projected onto a random direction drawn from a generator that is re-seeded
    with the same seed on each call, and the summed projections are written as
    scalars:

      - 'train/param_proj' (and 'train/param_proj_diff' when torch.distributed
        is initialized) on steps where `batch_idx_train % period == 0`;
      - 'train/grad_same_sign' on steps where `batch_idx_train % period == 1`:
        the sign of the product of this step's gradient projection and the one
        stored on the last `% period == 0` step;
      - 'train/grad_proj' on steps where `batch_idx_train % period == 0` and
        on the first 50 steps out of every 1000.

    Args:
      model:
        The model to monitor; a DDP wrapper is unwrapped.
      tb_writer:
        Object with an `add_scalar(name, value, step)` method, or None to
        disable writing.
      period:
        Interval, in training batches, between regular measurements.
    """

    def __init__(self,
                 model: Union[nn.Module, DDP],
                 tb_writer: Optional["SummaryWriter"],
                 period: int = 50):
        if isinstance(model, DDP):
            model = model.module
        self.model = model
        self.tb_writer = tb_writer
        device = next(model.parameters()).device
        self.device = device
        self.period = period
        # Gradient projection remembered from the last `% period == 0` step.
        self.grad_proj = torch.tensor(0.0, device=device)

    def step(self, batch_idx_train: int):
        """Compute the projections for this batch index and log them if due."""
        # in addition to plotting param_proj and grad_proj and grad_proj_sign every "period" steps,
        # plot grad_proj for the first 50 out of every 1000 steps; this will give us a sense of how
        # stable the oscillations are.
        dense_period = 1000
        dense_length = 50
        if batch_idx_train % self.period > 1 and batch_idx_train % dense_period > dense_length:
            return

        # Re-seeding on every call yields the same projection directions each time.
        generator = torch.Generator(device=self.device)
        generator.manual_seed(1)

        with torch.no_grad():
            param_proj = torch.tensor(0.0, device=self.device)
            grad_proj = torch.tensor(0.0, device=self.device)
            for p in self.model.parameters():
                proj = torch.randn(p.shape, generator=generator, device=self.device)
                param_proj = param_proj + (p * proj).sum()
                # Parameters that received no gradient have p.grad == None;
                # multiplying None by a tensor raises TypeError, so skip them.
                if p.grad is not None:
                    grad_proj = grad_proj + (p.grad * proj).sum()

        tb_writer = self.tb_writer

        def dump(proj: Tensor, name: str):
            # Across ranks: log the spread (max - min) and the mean of `proj`.
            proj_min = proj.clone()
            proj_max = proj.clone()
            if dist.is_available() and dist.is_initialized():
                dist.all_reduce(proj_min, op=dist.ReduceOp.MIN)
                dist.all_reduce(proj_max, op=dist.ReduceOp.MAX)
                dist.all_reduce(proj, op=dist.ReduceOp.SUM)
                proj = proj / dist.get_world_size()
                proj_diff = proj_max - proj_min
                if tb_writer is not None:
                    tb_writer.add_scalar(name + '_diff', proj_diff.item(), batch_idx_train)
            if tb_writer is not None:
                tb_writer.add_scalar(name, proj.item(), batch_idx_train)

        if batch_idx_train % self.period == 0:
            dump(param_proj, 'train/param_proj')
            self.grad_proj = grad_proj
        if batch_idx_train % self.period == 1 and tb_writer is not None:
            tb_writer.add_scalar('train/grad_same_sign', (grad_proj * self.grad_proj).sign(), batch_idx_train)
        if (batch_idx_train % dense_period < dense_length or batch_idx_train % self.period == 0) and tb_writer is not None:
            tb_writer.add_scalar('train/grad_proj', grad_proj, batch_idx_train)
nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=lookup(params, "decoder_dim"), + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + output_downsampling_factor = 2 + joiner = Joiner( + encoder_dim=lookup(params, "embed_dim") * output_downsampling_factor, + decoder_dim=lookup(params, "decoder_dim"), + joiner_dim=lookup(params, "joiner_dim"), + vocab_size=params.vocab_size, + ) + return joiner + + +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + output_downsampling_factor = 2 + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=lookup(params, "attention_decoder_dim"), + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=lookup(params, "attention_decoder_attention_dim"), + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_multiple * lookup(params, "attention_decoder_attention_dim"), + memory_dim=lookup(params, "embed_dim") * output_downsampling_factor, + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + if params.use_attention_decoder: + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + + output_downsampling_factor = 2 + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + 
joiner=joiner, + attention_decoder=attention_decoder, + encoder_dim=output_downsampling_factor * lookup(params, "embed_dim"), + decoder_dim=lookup(params, "decoder_dim"), + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[SchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[SchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + aux_loss_scale: float = 0.0, + specaug: Optional[nn.Module] = None, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zapformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + features = batch["inputs"] + # at entry, features is (N, T, C) + assert features.ndim == 3 + features = features.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + if specaug is not None: + with torch.amp.autocast('cuda', enabled=False): + features = specaug(features.to(torch.float), feature_lens) + + + if batch_idx_train % 50 == 0: + logging.info( + f"rng_state={torch.cuda.get_rng_state()}, features-sum={features.sum()}" + ) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( + x=features, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + aux_loss_scale=aux_loss_scale, + ) + + loss = 0.0 + + adjusted_batch_count = params.batch_idx_train + warm_step = params.warm_step + def warmup_schedule(scale, initial_factor): + # geometric warmup schedules. + warmup_factor = (1. 
if adjusted_batch_count >= warm_step else + initial_factor + (adjusted_batch_count / warm_step) * (1 - initial_factor)) + return scale * warmup_factor + + if params.use_transducer: + simple_loss_scale = params.simple_loss_scale + pruned_loss_scale = warmup_schedule(1.0, 0.05) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + if params.use_attention_decoder: + loss += params.attention_decoder_loss_scale * attention_decoder_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + nframes = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = nframes + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss 
+ +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: SchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + specaug: Optional[nn.Module] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + param_plotter = ParamPlotter(model, tb_writer, period=50) + + def get_scaler_scale(): + if params.use_autocast and scaler._scale is not None: + return scaler._scale.item() + else: + return 1.0 + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.amp.autocast('cuda', + enabled=params.use_autocast, dtype=params.dtype + ): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + specaug=specaug, + aux_loss_scale=get_scaler_scale() * params.aux_loss_scale * (0.25 if params.batch_idx_train > 2000 else 1.0), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.set_batch(batch_idx) # sets batch-count within the epoch, and sets the LRs. 
+ scaler.step(optimizer) + scaler.update() + param_plotter.step(params.batch_idx_train) + optimizer.zero_grad() + except Exception as e: + logging.info(f"Caught exception: {e}.") + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if params.use_autocast: + cur_grad_scale = get_scaler_scale() + + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + if not params.inf_check: + register_inf_check_hooks(model) + logging.warning(f"Grad scale is small: {cur_grad_scale}") + + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ if (batch_idx % 25 == 0 and cur_grad_scale < 2.0 or + batch_idx % 100 == 0 and cur_grad_scale < 8.0 or + batch_idx % 400 == 0 and cur_grad_scale < 32.0): + scaler.update(cur_grad_scale * 2.0) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = get_scaler_scale() + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") + ) + + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_autocast: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. 
+ world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + # synchronize seeds. important for parameter initialization to be consistent. + fix_random_seed(params.seed) + + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + # need torch.distributed.barrier() after fix_random_seed() as it fixes + # random seeds of all GPUs, not just the GPU of this process. + dist_barrier() + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + torch.cuda.set_device(rank) + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.sos_id = params.eos_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + if not params.use_attention_decoder: + params.ctc_loss_scale = 1.0 + else: + assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( + params.ctc_loss_scale, + params.attention_decoder_loss_scale, + ) + + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" 
+ assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = BatchedRubik( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=False), + lr=params.base_lr, + beta1=0.99, + ) + + if True: + # Work out copies_per_epoch + copies_per_epoch = [ ] + cur_real_epochs = 0 + progress_increment = 1.0 / (params.max_copies + 1 - params.min_copies) + cur_progress = 0.0 + # go in backwards order to minimize rounding errors. + for n in reversed(range(params.min_copies, params.max_copies + 1)): + cur_progress += progress_increment + target_real_epochs = int(0.5 + cur_progress * params.num_real_epochs) # + 0.5 to round up. 
+ while cur_real_epochs < target_real_epochs: + copies_per_epoch.append(n) + cur_real_epochs += n + copies_per_epoch = list(reversed(copies_per_epoch)) + + num_epochs = len(copies_per_epoch) + logging.info(f"Num epochs = {len(copies_per_epoch)}, num-real-epochs={sum(copies_per_epoch)} vs target {params.num_real_epochs}") + logging.info(f"Copies per epoch: {copies_per_epoch}") + + # this InterpCosineLRScheduler inherits from VariableCombinedLRScheduler. + # this configuration is halfway between a linear function (1 to 0) and the conventional + # cosine LR scheduler. It decays to a minimum of 0.025. + scheduler = InterpCosineLRScheduler(optimizer, + min_factor=0.025, + linear_scale=0.5, + batches_per_epoch=[params.batches_per_epoch * n for n in copies_per_epoch]) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + asr_datamodule = AsrDataModule(args) + + librispeech = LibriSpeech(args.manifest_dir) + gigaspeech = GigaSpeech(args.manifest_dir) # gigaspeech will only be used if the --use-giga=True option is set + commonvoice = CommonVoice(args.manifest_dir) # commonvoice will only be used if the --use-cv=True option is set + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + # train_cuts_len = 960.0 * 3 # 960 hours times 3 for augmentation + train_cuts_len = 843723 # includes 3x speed perturbation + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled 
training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + train_cuts_len = 85617 # includes 3x speed perturbation + + if params.use_giga or params.use_cv: + if params.libri_copies > 1: + train_cuts = train_cuts.repeat(params.libri_copies) + train_cuts_len = train_cuts_len * params.libri_copies + datasets_and_weights = [(train_cuts, train_cuts_len)] + + if params.use_giga: + if params.full_libri: + gigaspeech_cuts = gigaspeech.train_XL_cuts() + gigaspeech_cuts_len = 8277188 # 10000.0 + else: + gigaspeech_cuts = gigaspeech.train_S_cuts() # e.g. for debugging + gigaspeech_cuts_len = 229394 # 250.0 + datasets_and_weights.append((gigaspeech_cuts, gigaspeech_cuts_len)) + + if params.use_cv: + import re + def normalize_text(c): + c.supervisions[0].text = re.sub(r'[^\w\s]', '', c.supervisions[0].text).upper() + return c + commonvoice_cuts = commonvoice.train_cuts().map(normalize_text) + commonvoice_cuts_len = 1822817 # 2600.0 + datasets_and_weights.append((commonvoice_cuts, commonvoice_cuts_len)) + + cuts, weights = zip(*datasets_and_weights) + train_cuts = CutSet.mux(*cuts, weights=weights) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zapformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics and False: + train_dl = asr_datamodule.train_dataloaders( + train_cuts, + num_copies=1, + seed=params.seed, + rank=rank, + ) + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + + for epoch in range(params.start_epoch, num_epochs + 1): + # fix all random seeds before starting the dataloaders, as they require + # all seeds to be synchronized, in particular for the sampler, which + # uns in the main process and relies on the currently-set random seed + # 
(in practice it's just the random module's + # seed and possibly the numpy seed that really matter here. + dist_barrier() + fix_random_seed(params.seed + epoch - 1) + dist_barrier() + + num_copies = copies_per_epoch[epoch - 1] + logging.info(f"On epoch {epoch}, for dataloader: num_copies={num_copies}, this will affect num batches.") + train_dl = asr_datamodule.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + num_copies=num_copies, + seed=params.seed + 500 * epoch, + rank=rank, + ) + + sampler_state_dict=None + # we don't do : + # train_dl.sampler.set_epoch(epoch) + # because we just created the sampler and its seed already depends on the epoch. + + (model.module if isinstance(model, DDP) else model).encoder.compute_projection_overlap(verbose=True) # for diagnostics + + seed = params.seed + 50 * epoch + 512 * rank + + specaug = AlternatingSpecAugment( + seed=seed, + ) # otherwise use all default settings. + + if torch.cuda.is_available(): + with torch.cuda.device(rank): + # set CUDA seed for "my GPU" in a rank-and-epoch-dependent way. 
+ # This is not not very important, it should just affect the + # AddNoise() module in subsampling.py + torch.cuda.manual_seed(seed) + else: + torch.manual_seed(seed) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + tb_writer.add_scalar("train/num_copies", num_copies, params.batch_idx_train) + + params.cur_epoch = epoch + scheduler.set_epoch(epoch) + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + specaug=specaug, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + d = diagnostic.print_diagnostics() + filename = params.exp_dir / f"diagnostics-epoch-{params.cur_epoch}.pt" + torch.save(d, filename) + logging.info(f"Saved detailed diagnostics to {filename}") + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.amp.autocast('cuda', + enabled=params.use_autocast, dtype=params.dtype + ): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
class VariableCombinedLRScheduler(object):
    """
    Base class for learning-rate schedulers where the learning rate depends on both the
    batch and the epoch; in this version the expected number of batches can be different
    for different epochs.

    Subclasses must implement get_lr(), normally as a function of get_progress().

    Example:

      base_batches = 3100
      multiples = [ 1, 1, 1, 2, 2, 2, 3, 3, 3 ]
      batches_per_epoch = [ m * base_batches for m in multiples ]

      scheduler = InterpCosineLRScheduler(optimizer, batches_per_epoch=batches_per_epoch)
      for epoch in range(len(multiples)):
          scheduler.set_epoch(epoch+1)  # caution: one-based epoch count
          train_dl = f(multiples[epoch])  # num batches propto multiples.
          for batch_idx, batch in enumerate(train_dl):
              scheduler.set_batch(batch_idx)

    Args:
      optimizer: optimizer that we will set the learning rates in; the initial learning rate(s) in
         the optimizer is/are the base LRs and we set the LR as a fraction of those.
      batches_per_epoch: the estimated number of batches of each epoch (one entry per epoch);
         use your best guess.  Its length is the total number of epochs you will train for.
      verbose: if True, log every learning rate that gets set.
    """
    def __init__(self,
                 optimizer: Optimizer,
                 batches_per_epoch: List[int],
                 verbose: bool = False):
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer
        self.verbose = verbose

        # Remember the LR the optimizer was constructed with as the base LR; setdefault()
        # keeps an already-present "base_lr" (e.g. one set by an earlier scheduler).
        for group in optimizer.param_groups:
            group.setdefault("base_lr", group["lr"])

        self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]

        self.batches_per_epoch = list(batches_per_epoch)  # copy the list in case it's modified
        self.tot_batches = sum(self.batches_per_epoch)
        self.adjust_factor = 1.0

        self.epoch = -1
        self.batch = -1

        # Defined here (not only in set_epoch()) so that get_progress() works when
        # set_batch() is called right after load_state_dict().
        self.past_batches = 0
        # Defined here so that get_last_lr() works before the first set_epoch()/set_batch().
        self._last_lr = [group["lr"] for group in optimizer.param_groups]

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        Only the epoch, the batch index and the adjust factor are stored.
        """
        return {
            # the user might try to override the base_lr, so don't include this in the state.
            # previously they were included.
            # "base_lrs": self.base_lrs,
            "epoch": self.epoch,
            "batch": self.batch,
            "adjust_factor": self.adjust_factor,
        }

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)
        # past_batches is derived from the epoch and is not part of the saved state,
        # so it has to be recomputed for the epoch we just loaded.
        self.past_batches = self._past_batches(self.epoch)

    def get_last_lr(self) -> List[float]:
        """Return last computed learning rate by current scheduler.  Will be a list of float."""
        return self._last_lr

    def get_lr(self):
        # Compute list of learning rates from self.epoch and self.batch and
        # self.base_lrs; this must be overloaded by the user.
        # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
        raise NotImplementedError

    def set_batch(self, batch: int):
        """ Sets the batch index within the epoch, with zero-based counting (not that this matters much)."""
        # set the within-epoch batch index.
        self.batch = batch
        self._set_lrs()

    def set_epoch(self, epoch: int):
        """ Sets the epoch with one-based counting, so the first epoch is 1; the epoch should not exceed
        len(batches_per_epoch) as given to the constructor. """
        assert epoch > 0 and epoch <= len(self.batches_per_epoch)  # Epoch numbers are assumed to be be 1-based indexes.
        if epoch == self.epoch + 1 and self.batch > 0 and self.epoch > 0:
            # We just finished an epoch: compare the expected number of batches with the last
            # batch index we saw, and rescale the batch counts of later epochs accordingly.
            self.adjust_factor = self.batches_per_epoch[self.epoch-1] / self.batch
            logging.info(f"Setting self.adjust_factor = {self.adjust_factor} = expected/observed batches {self.batches_per_epoch[self.epoch-1]}/{self.batch} on epoch {self.epoch}")

        self.epoch = epoch
        self.past_batches = self._past_batches(epoch)
        # NOTE: self.batch still holds the last batch index of the previous epoch until the
        # caller invokes set_batch() again.
        self._set_lrs()

    def get_progress(self):
        """Returns the fraction of the expected total number of batches that has been
        processed so far, as a float between 0.0 and 1.0."""
        if self.epoch <= 0:
            return 0.0
        else:
            # epoch indexes start from 1 so we have to subtract 1 before indexing self.batches_per_epoch
            past_batches = self.past_batches  # sum of batches on previous epochs
            tot_batches = self.tot_batches  # anticipated total batches
            cur_max_batches = self.batches_per_epoch[self.epoch - 1]
            # Clamp at zero: self.batch is -1 until set_batch() has been called, which would
            # otherwise give a negative progress (and e.g. a linear LR above the base LR).
            cur_batches = min(cur_max_batches, max(0.0, self.adjust_factor * self.batch))

            progress = (past_batches + cur_batches) / tot_batches
            assert progress <= 1.0
            return progress

    def _past_batches(self, epoch: int) -> int:
        # Expected number of batches in all epochs before the (one-based) `epoch`;
        # zero for epoch <= 1.
        return sum(self.batches_per_epoch[:max(epoch - 1, 0)])

    def _set_lrs(self):
        # Writes the learning rates given by get_lr() into the optimizer's param groups.
        values = self.get_lr()
        assert len(values) == len(self.optimizer.param_groups)

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group["lr"] = lr
            self.print_lr(self.verbose, i, lr)
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, is_verbose, group, lr):
        """Display the current learning rate."""
        if is_verbose:
            # The number of epochs is the length of batches_per_epoch; this class has no
            # separate num_epochs attribute.
            logging.warning(
                f"Epoch={self.epoch}, batch={self.batch}, num_epochs={len(self.batches_per_epoch)}, batches_per_epoch={self.batches_per_epoch}: adjusting learning rate"
                f" of group {group} to {lr:.4e}."
            )
class Zapformer(EncoderInterface):
    """
    Encoder consisting of a sequence of ZapformerEncoder stacks.  Each stack is applied
    to the input after it has been reshaped by downsample_by() with that stack's
    downsampling factor, and its output is reshaped back by upsample_by() before the
    next stack runs.

    Args:

    Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length
    as downsampling_factor if they are single ints or one-element tuples.  The length of
    downsampling_factor defines the number of stacks.

        input_dim (int): feature dimension of the input; stack i is constructed with
            dim = downsampling_factor[i] * input_dim.
        output_downsampling_factor (int): how much to downsample at the output.  Note:
            we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
            You should probably leave this at 2.
        downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
            Note: this is in addition to the downsampling factor of 2 that is applied in
            the frontend (self.encoder_embed).
        encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
            encoder stack.
        num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
        query_head_dim (int or Tuple[int]): dimension of query and key per attention
            head: per stack, if a tuple..
        value_head_dim (int or Tuple[int]): dimension of value in each attention head
        pos_head_dim (int or Tuple[int]): dimension of position-embedding in each attention head
        num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
            Must be at least 4.
        feedforward_multiple (int or Tuple[int]): determines hidden dimension in feedforward modules
        conv_params (int or Tuple[int])): Kernel size of convolution module
        num_freqs (int): passed unchanged to every ZapformerEncoderLayer.

        causal (bool): if True, support chunkwise causal convolution.  This should
            not hurt WER as no modeling power is lost, but the convolution modules will be
            slightly slower and use more memory.  Enables use of the chunk_size and
            left_context_chunks options in forward(), which simulates streaming
            decoding.
        chunk_size: (list of int): only set this to other than [-1] if causal;
            the chunk size will be randomly chosen from this list.  -1 means no chunking.
        left_context_frames: (list of int): determines the number of left-
            context chunks for causal training; will be rounded to a number of
            chunks.
    """
    def __init__(
        self,
        input_dim: int,
        output_downsampling_factor: int = 2,
        downsampling_factor: Tuple[int] = (2, 4),
        encoder_dim: Union[int, Tuple[int]] = 384,
        num_encoder_layers: Union[int, Tuple[int]] = 4,
        query_head_dim: Union[int, Tuple[int]] = 64,
        value_head_dim: Union[int, Tuple[int]] = 12,
        pos_head_dim: Union[int, Tuple[int]] = 4,
        num_heads: Union[int, Tuple[int]] = 8,
        feedforward_multiple: Union[int, Tuple[int]] = 4,
        conv_params: Union[int, Tuple[int]] = 31,
        num_freqs: int = 64,
        causal: bool = False,
        chunk_size: Tuple[int] = [-1],
        left_context_frames: Tuple[int] = [-1],
    ) -> None:
        super(Zapformer, self).__init__()

        def _to_tuple(x):
            """Converts a single int or a 1-tuple of an int to a tuple with the same length
            as downsampling_factor"""
            if isinstance(x, int):
                x = (x,)
            if len(x) == 1:
                x = x * len(downsampling_factor)
            else:
                assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
            return x

        self.output_downsampling_factor = output_downsampling_factor  # int

        self.downsampling_factor = downsampling_factor  # tuple
        self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
        num_encoder_layers = _to_tuple(num_encoder_layers)
        self.num_encoder_layers = num_encoder_layers
        self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
        self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
        self.pos_head_dim = pos_head_dim = _to_tuple(pos_head_dim)
        self.num_heads = num_heads = _to_tuple(num_heads)
        feedforward_multiple = _to_tuple(feedforward_multiple)
        self.conv_params = conv_params = _to_tuple(conv_params)
        self.num_freqs = num_freqs

        self.causal = causal
        # A bare int is wrapped into a 1-tuple; sequences are stored as given.
        self.chunk_size = (chunk_size,) if isinstance(chunk_size, int) else chunk_size
        self.left_context_frames = (left_context_frames,) if isinstance(left_context_frames, int) else left_context_frames

        # Every element appended below is a ZapformerEncoder.
        encoders = []

        num_encoders = len(downsampling_factor)

        # caution: some changes we made for this break the streaming, later we'll try to fix this.
        # NOTE(review): encoders_downsampling_factors is never read in this class.
        encoders_downsampling_factors = [ ]

        # make it so large the limit is never reached.
        # NOTE(review): max_proj_dim is computed but never read in this class.
        max_proj_dim = max(downsampling_factor) * max(encoder_dim)


        for i in range(num_encoders):
            # Template layer; ZapformerEncoder receives it together with the layer count.
            encoder_layer = ZapformerEncoderLayer(
                embed_dim=encoder_dim[i],
                num_heads=num_heads[i],
                query_head_dim=query_head_dim[i],
                value_head_dim=value_head_dim[i],
                pos_head_dim=pos_head_dim[i],
                feedforward_multiple=feedforward_multiple[i],
                conv_params=conv_params[i],
                num_freqs=num_freqs,
                causal=causal,
            )

            # For the segment of the warmup period, we let the Conv2dSubsampling
            # layer learn something.  Then we start to warm up the other encoders.
            encoder = ZapformerEncoder(
                encoder_layer,
                num_encoder_layers[i],
                dim=downsampling_factor[i]*input_dim,
            )

            encoders.append(encoder)

        self.encoders = nn.ModuleList(encoders)

        # Share a single AngularFreqBasis instance across all layers within each encoder stack
        for encoder in self.encoders:
            shared_basis = encoder.layers[0].self_attn.rel_pos.angular_freq_basis
            for layer in encoder.layers[1:]:
                layer.self_attn.rel_pos.angular_freq_basis = shared_basis

        self.out_norm = RmsNorm()


    def get_chunk_info(self) -> Tuple[int, int]:
        """
        Returns chunk_size and left_context_chunks.

        Both are -1 if the model is not causal.  Outside of scripting/tracing the values
        are drawn at random from self.chunk_size and self.left_context_frames; when
        scripting or tracing, those lists must have exactly one element.
        """
        if not self.causal:
            return -1, -1

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            assert len(self.chunk_size) == 1, self.chunk_size
            chunk_size = self.chunk_size[0]
        else:
            chunk_size = random.choice(self.chunk_size)

        if chunk_size == -1:
            left_context_chunks = -1
        else:
            if torch.jit.is_scripting() or torch.jit.is_tracing():
                assert len(self.left_context_frames) == 1, self.left_context_frames
                left_context_frames = self.left_context_frames[0]
            else:
                left_context_frames = random.choice(self.left_context_frames)
            # Note: in Python, -1 // n == -1 for n > 0
            left_context_chunks = left_context_frames // chunk_size
            # Fewer left-context frames than one chunk still gives one chunk of context.
            if left_context_chunks == 0:
                left_context_chunks = 1

        return chunk_size, left_context_chunks

    def forward(
        self,
        x: Tensor,
        x_lens: Tensor,
        src_key_padding_mask: Optional[Tensor] = None,
        aux_loss_scale: float = 0.0,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          x:
            The input tensor. Its shape is (seq_len, batch_size, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          src_key_padding_mask:
            The mask for padding, of shape (batch_size, seq_len); True means
            masked position. May be None.
          aux_loss_scale:
            If supplied, auxiliary losses such as CosineSimilarityLoss will be
            applied with this scale on the loss (note, these aux losses are
            reduced via summation over frames.)
        Returns:
          Return (embeddings, lengths), where:
            - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
              NOTE(review): every stack is followed by upsample_by() and the result is then
              passed through downsample_by(x, output_downsampling_factor), so the last
              dimension looks like feature_dim * output_downsampling_factor -- confirm.
            - lengths, a tensor of shape (batch_size,) containing the number
              of frames in `embeddings` before padding.
        """
        chunk_size, left_context_chunks = self.get_chunk_info()
        orig_seq_len = x.shape[0]

        pad = (-orig_seq_len) % max(self.downsampling_factor)
        # pad sequence length to be multiple of max(self.downsampling_factor)
        # (the padding repeats the last frame)
        x = torch.cat((x, x[-1:].repeat(pad, 1, 1)),
                      dim=0)

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            # Not support exporting a model for simulating streaming decoding
            attn_mask = None
        else:
            attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)

        # Extend the padding mask (with True) to the padded sequence length.
        src_key_padding_mask = pad_mask(src_key_padding_mask, x.shape[0])

        num_stacks = len(self.downsampling_factor)


        for i, module in enumerate(self.encoders):
            ds = self.downsampling_factor[i]
            x = downsample_by(x, ds)
            # NOTE(review): T is not used below.
            T = x.shape[0]
            x = module(
                x,
                chunk_size=chunk_size // ds if chunk_size > 0 else -1,
                # The masks are subsampled by taking every ds-th position.
                src_key_padding_mask=(
                    None
                    if src_key_padding_mask is None
                    else src_key_padding_mask[..., ::ds]
                ),
                attn_mask=(None
                           if attn_mask is None
                           else attn_mask[::ds, ::ds]
                ),
                aux_loss_scale=aux_loss_scale * ds / (self.output_downsampling_factor * num_stacks)
            )
            x = upsample_by(x, ds)

        od = self.output_downsampling_factor
        # NOTE(review): x was padded to a multiple of max(self.downsampling_factor), not of od;
        # presumably od is expected to divide that length -- confirm.
        x = downsample_by(x, od)
        x = x[:(orig_seq_len + od - 1) // od]  # truncate so seq len not affected by padding

        if od > 1:
            x_lens = (x_lens + od - 1) // od

        x = self.out_norm(x)

        # disable the projection-overlap loss
        #if self.training:
        #    # all of our losses and aux losses are proportional to the number of frames of data, so
        #    # we multiply by that factor.
        #    x = with_loss(x, aux_loss_scale * x.shape[0] * x.shape[1] * self.compute_projection_overlap())

        return x, x_lens


    def compute_projection_overlap(self, verbose: bool = False):
        """
        Returns the sum, over all pairs of encoder stacks, of relu(min_overlap - overlap),
        where the overlap is computed from the stacks' projection weights.

        Args:
          verbose: if True, log the overlap of every pair of stacks.
        """
        # This is currently just used for some diagnostics.

        # It also computes an auxiliary loss (currently unused) that
        # ensures that the projections from more-subsampled sequences "contain" enough of the
        # projections from the less-subsampled sequences-- specifically the direction where
        # all the less-subsampled projections co-vary in the same way, e.g. if there are
        # two frames, that the two frames are identical.

        min_overlap = 0.6  # we can tune this.  CAUTION: I turned off this aux loss by commenting
                           # it out in forward(),

        tot_loss = 0.0
        # between pairs of encoders
        N = len(self.encoders)

        covs = []
        ranks = []
        for i in range(N):
            proj_i = self.encoders[i].proj.get_weight()
            cov_i = torch.matmul(proj_i.t(), proj_i)
            covs.append(cov_i)
            ranks.append(proj_i.shape[0])

        for i in range(N):
            for j in range(i):
                cov_i, cov_j = covs[i], covs[j]
                rank_i, rank_j = ranks[i], ranks[j]
                # Swap if needed so that cov_i is the smaller of the two matrices.
                if cov_i.shape[0] > cov_j.shape[0]:
                    cov_i, cov_j = cov_j, cov_i
                    rank_i, rank_j = rank_j, rank_i
                dim_i = cov_i.shape[0]  # now this is <= proj_j.shape[0]
                dim_j = cov_j.shape[0]
                assert dim_i <= dim_j
                assert dim_j % dim_i == 0  # dims must be multiples of each other (these are the
                                           # feature dimension prior to project, i.e. the larger dimensions.)
                R = dim_j // dim_i  # e.g. 1, 2, 4
                assert R in [1, 2, 4, 8, 16]
                # Tile the smaller matrix so that both have the same shape.
                cov_i = cov_i.repeat(R, R) * (1. / R)
                # denominator is the minimum of the two ranks,
                # because due to the orthogonal constraint, the maximum possible value of (cov_i * cov_j).sum() would be the
                # smaller of the two ranks.
                cosine = (cov_i * cov_j).sum() / min(rank_i, rank_j)

                loss = (min_overlap - cosine).relu()
                tot_loss = tot_loss + loss
                if verbose:
                    logging.info(f"overlap[{i}, {j}] = {cosine}")
        return tot_loss

    def warmup_angular_freq_bases(self, seq_len: int, left_context_len: int, device: torch.device):
        """Pre-compute angular frequency bases for all encoder layers.
        Call this before torch.jit.trace to avoid tracer issues."""
        for module in self.encoders:
            for layer in module.layers:
                layer.self_attn.rel_pos.angular_freq_basis(seq_len, left_context_len, device)

    def _get_attn_mask(
        self, x: Tensor, chunk_size: int, left_context_chunks: int
    ) -> Optional[Tensor]:
        """
        Return None if chunk_size <= 0 (e.g. -1), else return attention mask of shape
        (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len).  True
        means a masked position.
        Args:
           x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
           chunk_size: chunk size, must be a multiple of every element of
               self.downsampling_factor
           left_context_chunks: number of chunks to the left that a frame may attend to;
               a negative value means unlimited left context.
        """
        if chunk_size <= 0:
            return None
        assert all(chunk_size % d == 0 for d in self.downsampling_factor)
        if left_context_chunks >= 0:
            num_encoders = len(self.encoder_dim)
            assert all(
                chunk_size * left_context_chunks
                >= (self.conv_params[i] // 2) * self.downsampling_factor[i]
                for i in range(num_encoders)
            )  # TODO: could test remove this
        else:
            # effectively unlimited left context
            left_context_chunks = 1000000

        seq_len = x.shape[0]

        # t is frame index, shape (seq_len,)
        t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
        # c is chunk index for each frame, shape (seq_len,)
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            c = t // chunk_size
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                c = t // chunk_size
        src_c = c
        tgt_c = c.unsqueeze(-1)

        # Masked: source chunks after the target's chunk, or more than left_context_chunks
        # chunks before it.
        attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
        if __name__ == "__main__":
            logging.info(f"attn_mask = {attn_mask}")
        return attn_mask

    def streaming_forward(
        self,
        x: Tensor,
        x_lens: Tensor,
        caches: List[Tensor],
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, List[Tensor]]:
        """
        Args:
          x:
            The input tensor. Its shape is (seq_len, batch_size, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          caches: list of cached tensors of all encoder layers. For layer-i,
            caches[i*9:(i+1)*9] is (cached_key, cached_value, cached_conv,
            cached_norm_stats, cached_norm_len, cached_attn_wm_sum,
            cached_attn_wm_num_frames, cached_conv_wm_sum, cached_conv_wm_num_frames).
          src_key_padding_mask:
            The mask for padding; True means masked position.  May be None.
            Its second dimension must be self.left_context_frames[0] + seq_len
            (this is asserted below).

        Returns:
          - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
            NOTE(review): as in forward(), the last dimension looks like
            feature_dim * output_downsampling_factor -- confirm.
          - lengths, a tensor of shape (batch_size,) containing the number
              of frames in `embeddings` before padding.
          - updated caches: an updated list of cache tensors.
        """
        orig_seq_len = x.shape[0]
        pad = (-orig_seq_len) % max(self.downsampling_factor)
        # pad sequence length to be multiple of max(self.downsampling_factor)
        x = torch.cat((x, x[-1:].repeat(pad, 1, 1)), dim=0)

        if src_key_padding_mask is not None:
            left_context_frames = src_key_padding_mask.shape[1] - orig_seq_len
            assert left_context_frames == self.left_context_frames[0]
            if pad > 0:
                # Only the part of the mask that covers the current chunk is extended; the
                # left-context part is kept in front of it unchanged.
                padded_mask = pad_mask(src_key_padding_mask[:, left_context_frames:], x.shape[0])
                assert padded_mask is not None
                src_key_padding_mask = torch.cat(
                    [src_key_padding_mask[:, :left_context_frames], padded_mask],
                    dim=1,
                )

        new_caches = []
        # Number of encoder layers in the stacks already processed; indexes into `caches`.
        layer_offset = 0

        for i, module in enumerate(self.encoders):
            num_layers = module.num_layers
            ds = self.downsampling_factor[i]

            x = downsample_by(x, ds)

            # Slice out the specific caches for the current module
            module_caches = caches[layer_offset * 9 : (layer_offset + num_layers) * 9]

            x, new_module_caches = module.streaming_forward(
                src=x,
                caches=module_caches,
                left_context_len=self.left_context_frames[0] // ds,
                src_key_padding_mask=(
                    None
                    if src_key_padding_mask is None
                    else src_key_padding_mask[..., ::ds]
                ),
            )

            layer_offset += num_layers
            new_caches.extend(new_module_caches)

            x = upsample_by(x, ds)

        # Output downsampling and normalization
        od = self.output_downsampling_factor
        x = downsample_by(x, od)

        x = x[:(orig_seq_len + od - 1) // od]  # truncate so seq len not affected by padding

        if od > 1:
            x_lens = (x_lens + od - 1) // od

        x = self.out_norm(x)

        return x, x_lens, new_caches

    @torch.jit.export
    def get_init_caches(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> List[Tensor]:
        """Get initial caches.

        A list of cached tensors of all encoder layers. For layer-i, caches[i*9:(i+1)*9]
        is (cached_key, cached_value, cached_conv, cached_norm_stats, cached_norm_len,
        cached_attn_wm_sum, cached_attn_wm_num_frames, cached_conv_wm_sum, cached_conv_wm_num_frames).

        The key, value, conv and weighted-mean caches are created as zeros; the norm caches
        come from each layer's norm.get_init_cache().
        """
        caches = []
        for i, module in enumerate(self.encoders):
            num_layers = module.num_layers
            embed_dim = self.encoder_dim[i]
            ds = self.downsampling_factor[i]
            num_heads = self.num_heads[i]
            key_dim = self.query_head_dim[i] * num_heads
            value_dim = self.value_head_dim[i] * num_heads
            # left-context length at this stack's frame rate
            downsample_left = self.left_context_frames[0] // ds
            conv_left_pad = self.conv_params[i] - 1

            for layer_idx, enc_layer in enumerate(module.layers):
                cached_key = torch.zeros(downsample_left, batch_size, key_dim, device=device)
                cached_value = torch.zeros(downsample_left, batch_size, value_dim, device=device)
                cached_conv = torch.zeros(batch_size, embed_dim, conv_left_pad, device=device)
                cached_norm_stats, cached_norm_len = enc_layer.norm.get_init_cache(batch_size)
                cached_norm_stats = cached_norm_stats.to(device)
                cached_norm_len = cached_norm_len.to(device)

                # NOTE(review): same value as value_dim above.
                attn_value_dim = self.value_head_dim[i] * num_heads
                cached_attn_wm_sum = torch.zeros(1, batch_size, attn_value_dim, device=device)
                cached_attn_wm_num_frames = torch.zeros(batch_size, dtype=torch.int64, device=device)
                cached_conv_wm_sum = torch.zeros(1, batch_size, embed_dim, device=device)
                cached_conv_wm_num_frames = torch.zeros(batch_size, dtype=torch.int64, device=device)

                # The order must match the 9-per-layer layout described in the docstring.
                caches.extend([
                    cached_key,
                    cached_value,
                    cached_conv,
                    cached_norm_stats,
                    cached_norm_len,
                    cached_attn_wm_sum,
                    cached_attn_wm_num_frames,
                    cached_conv_wm_sum,
                    cached_conv_wm_num_frames,
                ])

        return caches
class ZapformerEncoderLayer(nn.Module):
    """
    A single encoder layer: feed-forward, self-attention, convolution and a second
    feed-forward module, each added residually, followed by a scaled bypass of the
    layer input, a sequence norm and a clamp to [-5, 5].

    Args:
        embed_dim: the number of expected features in the input (required).
        num_heads: the number of heads in the self-attention module (required).
        query_head_dim: passed to the self-attention module.
        value_head_dim: passed to the self-attention module.
        pos_head_dim: passed to the self-attention module.
        feedforward_multiple: determines the hidden dimension of the feedforward module
            (embed_dim * feedforward_multiple).
        conv_params (int): params per channel of convolution module
        num_freqs: passed to the self-attention module.
        causal: passed to the self-attention and convolution modules; also selects
            CausalSequenceNorm instead of SequenceNorm.

    Examples::
        >>> encoder_layer = ZapformerEncoderLayer(embed_dim=512, num_heads=8, query_head_dim=64,
        ...     value_head_dim=12, pos_head_dim=4, feedforward_multiple=4, conv_params=31)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        query_head_dim: int,
        value_head_dim: int,
        pos_head_dim: int,
        feedforward_multiple: int,
        conv_params: int,
        num_freqs: int = 64,
        causal: bool = False,
    ) -> None:
        super(ZapformerEncoderLayer, self).__init__()
        self.embed_dim = embed_dim
        self.name = None  # will be set from training loop

        self.offset_scale_limiter = ScaleLimiter(max_rms=1.0)

        #power = 0.35  # power should be between 0 and 1.  1 would mean cov == I (unattainable)
        #limit = (1. / (embed_dim ** power)))
        limit = 0.25  # this is very enormous limit on correlations, it's just to prevent divergence
        # and bad parameter locations from which it's impossible for the optimizer to escape.  i.e.
        # it should impose no real limitation on "normal" training runs.
        self.correlation_limiter = CorrelationLimiter(limit=limit)

        self.self_attn = MultiheadRelPosGatedSelfAttention(
            embed_dim,
            num_heads=num_heads,
            query_head_dim=query_head_dim,
            value_head_dim=value_head_dim,
            pos_head_dim=pos_head_dim,
            num_freqs=num_freqs,
            causal=causal,
        )

        feedforward_dim = embed_dim * feedforward_multiple
        self.feed_forward1 = FeedforwardModule(embed_dim, feedforward_dim)

        self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim)

        self.conv_module = ConvolutionModule(embed_dim, conv_params, causal=causal)

        self.norm = CausalSequenceNorm() if causal else SequenceNorm()

    def forward(
        self,
        src: Tensor,
        chunk_size: int = -1,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        aux_loss_scale: float = 0.0,
    ) -> Tensor:
        """
        Pass the input through the encoder layer.
        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                True means masked position. May be None.
            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                masked position. May be None.
            aux_loss_scale:
                If supplied, auxiliary losses such as CosineSimilarityLoss will be
                applied with this scale on the loss (note, these aux losses are
                reduced via summation over frames.)

        Returns:
            A tensor which has the same shape as src
        """
        src_orig = src

        # Attach the correlation limiter's auxiliary loss to src; the limiter is called
        # with the first two dimensions of src swapped.
        src = with_loss(src, self.correlation_limiter(src.permute(1, 0, 2),
                                                      2. * aux_loss_scale, mask=src_key_padding_mask))

        # Kept so that self-attention can take the pre-feedforward activations as its
        # first input.
        src_pre_ff1 = src

        src = src + self.feed_forward1(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask)

        # may try changing src_pre_ff1 to src or vice versa.
        src = src + self.self_attn(src_pre_ff1, src, attn_mask=attn_mask,
                                   key_padding_mask=src_key_padding_mask,
                                   aux_loss_scale=0.1 * aux_loss_scale)

        src = src + self.conv_module(src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask, aux_loss_scale=0.1 * aux_loss_scale)

        src = src + self.feed_forward2(src, aux_loss_scale=0.1 * aux_loss_scale, src_key_padding_mask=src_key_padding_mask)

        # Only a quarter of the total change made by the four modules is added back to
        # the layer input.
        residual_scale = 0.25
        offset = (src - src_orig) * residual_scale

        offset = self.offset_scale_limiter(offset, aux_loss_scale)

        src = src_orig + offset

        src = self.norm(src, src_key_padding_mask)

        src = src.clamp(min=-5, max=5)

        return src

    def streaming_forward(
        self,
        src: Tensor,
        cached_key: Tensor,
        cached_value: Tensor,
        cached_conv: Tensor,
        cached_norm_stats: Tensor,
        cached_norm_len: Tensor,
        cached_attn_wm_sum: Tensor,
        cached_attn_wm_num_frames: Tensor,
        cached_conv_wm_sum: Tensor,
        cached_conv_wm_num_frames: Tensor,
        left_context_len: int,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Pass the input through the encoder layer in streaming forward mode.

        Unlike forward(), this does not apply the correlation limiter or the offset scale
        limiter, and no aux_loss_scale is passed to the sub-modules.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            cached_key: cached attention key tensor, of shape (left_context_len, batch_size, key_dim)
            cached_value: cached attention value tensor, of shape (left_context_len, batch_size, value_dim)
            cached_conv: cached left context for the convolution module, of shape (batch_size, channels, left_pad)
            cached_norm_stats: cached SequenceNorm stats, of shape (batch_size,)
            cached_norm_len: cached SequenceNorm length, scalar.
            cached_attn_wm_sum: (1, batch, channels), cumulative sum for attention weighted_mean
            cached_attn_wm_num_frames: (batch,), number of frames for attention weighted_mean
            cached_conv_wm_sum: (1, batch, channels), cumulative sum for conv weighted_mean
            cached_conv_wm_num_frames: (batch,), number of frames for conv weighted_mean
            left_context_len: number of left context frames.
            src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len + seq_len);
                True means masked position. May be None.

        Returns:
            - x, with the same shape as src
            - updated cached_key
            - updated cached_value
            - updated cached_conv
            - updated cached_norm_stats
            - updated cached_norm_len
            - updated cached_attn_wm_sum
            - updated cached_attn_wm_num_frames
            - updated cached_conv_wm_sum
            - updated cached_conv_wm_num_frames
        """
        src_orig = src

        src_pre_ff1 = src

        # Part of the padding mask that covers the current chunk only (left context dropped);
        # used by the feed-forward and convolution modules.  Self-attention gets the full mask.
        chunk_mask = None if src_key_padding_mask is None else src_key_padding_mask[:, left_context_len:]

        src = src + self.feed_forward1(src, src_key_padding_mask=chunk_mask)

        # may try changing src_pre_ff1 to src or vice versa.
        self_attn_out, cached_key, cached_value, cached_attn_wm_sum, cached_attn_wm_num_frames = self.self_attn.streaming_forward(
            x_qkp=src_pre_ff1,
            x_vg=src,
            left_context_len=left_context_len,
            cached_key=cached_key,
            cached_value=cached_value,
            cached_wm_sum=cached_attn_wm_sum,
            cached_wm_num_frames=cached_attn_wm_num_frames,
            key_padding_mask=src_key_padding_mask,
        )
        src = src + self_attn_out

        src_conv, cached_conv, cached_conv_wm_sum, cached_conv_wm_num_frames = self.conv_module.streaming_forward(
            src,
            cached_conv=cached_conv,
            cached_wm_sum=cached_conv_wm_sum,
            cached_wm_num_frames=cached_conv_wm_num_frames,
            src_key_padding_mask=chunk_mask,
        )
        src = src + src_conv

        src = src + self.feed_forward2(src, src_key_padding_mask=chunk_mask)

        # Same scaled bypass as in forward().
        residual_scale = 0.25
        offset = (src - src_orig) * residual_scale

        src = src_orig + offset

        src, cached_norm_stats, cached_norm_len = self.norm.streaming_forward(
            src,
            cached_stats_sum=cached_norm_stats,
            cached_len=cached_norm_len,
        )

        src = src.clamp(min=-5, max=5)

        return (
            src,
            cached_key,
            cached_value,
            cached_conv,
            cached_norm_stats,
            cached_norm_len,
            cached_attn_wm_sum,
            cached_attn_wm_num_frames,
            cached_conv_wm_sum,
            cached_conv_wm_num_frames,
        )
+ + Examples:: + >>> encoder_layer = ZapformerEncoderLayer(embed_dim=512, nhead=8) + >>> zapformer_encoder = ZapformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zapformer_encoder(src) + """ + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + dim: int, + ) -> None: + super().__init__() + + # self.downsample will also reverse the downsampling operation for us afterward. + self.proj = OrthogonalLinear(dim, + encoder_layer.embed_dim, + bias=False) + + self.name = None + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + self.residual_scales = nn.Parameter( + torch.cat([ -1.0 * torch.ones(1), + (1. / num_layers) * torch.ones(num_layers) ], + dim=0)) + + self.input_scale = nn.Parameter(torch.tensor([1.0])) + + self.copy_bypass = Identity() + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim), + but embed_dim is allowed to exceed the modules' embed_dim; we will bypass + any extra dimensions. + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + out, of the same shape as src. + """ + src_orig_fulldim = src + + src = self.proj(src) # project to layer dim. 
+ + num_layers = len(self.layers) + src_orig = src + + residual_scale = limit_param_value(self.residual_scales[0], + min=-1.0, max=-0.5) + input_scale = limit_param_value(self.input_scale, + min=0.5, max=2.0) + + src_with_bypass = residual_scale * src + src = input_scale * src + + for i, mod in enumerate(self.layers): + + src = mod( + src, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + aux_loss_scale=aux_loss_scale/num_layers, + ) + residual_scale = limit_param_value(self.residual_scales[i + 1], + min=0.0 if i + 1 < num_layers else min(0.5, 1. / num_layers), + max=1.0) + src_with_bypass = src_with_bypass + residual_scale * src + + + offset = src_with_bypass + + src = src_orig_fulldim + self.proj(offset, transpose=True) + # in effect src_orig_fulldim already contains src_orig with a scale of 1 for the missing dims, + # because of some identities involving orthogonal matrices. + + return src + + + def streaming_forward( + self, + src: Tensor, + caches: List[Tensor], + left_context_len: int, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn in streaming mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embed_dim). + caches: list of cached tensors of N encoder layers. For layer-i, + caches[i*9:(i+1)*9] is (cached_key, cached_value, cached_conv, + cached_norm_stats, cached_norm_len, cached_attn_wm_sum, + cached_attn_wm_num_frames, cached_conv_wm_sum, cached_conv_wm_num_frames). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated caches + """ + src_orig_fulldim = src + + # project to layer dim. 
+ src = self.proj(src) + + num_layers = len(self.layers) + assert len(caches) == num_layers * 9 + + residual_scale = self.residual_scales[0] + input_scale = self.input_scale + + src_with_bypass = residual_scale * src + src = input_scale * src + + new_caches = [] + for i, mod in enumerate(self.layers): + ( + cached_key, + cached_value, + cached_conv, + cached_norm_stats, + cached_norm_len, + cached_attn_wm_sum, + cached_attn_wm_num_frames, + cached_conv_wm_sum, + cached_conv_wm_num_frames, + ) = caches[i * 9 : (i + 1) * 9] + + ( + src, + new_cached_key, + new_cached_value, + new_cached_conv, + new_cached_norm_stats, + new_cached_norm_len, + new_cached_attn_wm_sum, + new_cached_attn_wm_num_frames, + new_cached_conv_wm_sum, + new_cached_conv_wm_num_frames, + ) = mod.streaming_forward( + src, + cached_key=cached_key, + cached_value=cached_value, + cached_conv=cached_conv, + cached_norm_stats=cached_norm_stats, + cached_norm_len=cached_norm_len, + cached_attn_wm_sum=cached_attn_wm_sum, + cached_attn_wm_num_frames=cached_attn_wm_num_frames, + cached_conv_wm_sum=cached_conv_wm_sum, + cached_conv_wm_num_frames=cached_conv_wm_num_frames, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + + layer_residual_scale = self.residual_scales[i + 1] + + src_with_bypass = src_with_bypass + layer_residual_scale * src + + new_caches.extend([ + new_cached_key, + new_cached_value, + new_cached_conv, + new_cached_norm_stats, + new_cached_norm_len, + new_cached_attn_wm_sum, + new_cached_attn_wm_num_frames, + new_cached_conv_wm_sum, + new_cached_conv_wm_num_frames, + ]) + + offset = src_with_bypass + src = src_orig_fulldim + self.proj(offset, transpose=True) + + return src, new_caches + + + + +class MultiheadRelPosGatedSelfAttention(nn.Module): + r""" + Module that computes multi-head attention weights with additive relative-position + scores that are kept separate from the regular scores. The values have gating. 
+ An RMSNorm module is used to pre-normalize the input embedding only as it is + input to the queries and keys, not the values. + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + """ + def __init__( + self, + embed_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int , + value_head_dim: int, + num_freqs: int = 64, + causal: bool = False, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.name = None # will be overwritten in training code; for diagnostics. + + self.in_norm = RmsNorm() + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.qkp_in_proj = ScaledLinear( + embed_dim, in_proj_dim, + bias=True, initial_scale=0.125, + ) + + self.rel_pos = RelPosScores(num_heads, pos_head_dim, num_freqs=num_freqs) + + self.copy_query = Identity() + self.copy_pos_query = Identity() + + # value and gating in_proj. + self.vg_in_proj = ScaledLinear(embed_dim, 3 * num_heads * value_head_dim, + initial_scale=0.1, bias=True) + + self.copy_v = nn.Identity() # diagnostics. + self.sigmoid_in = nn.Sigmoid() + self.sigmoid_out = nn.Sigmoid() + + # out proj for the value times gating. 
+ self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.1 + ) + + self.weighted_mean = WeightedMean(num_heads * value_head_dim, causal) # TODO: fix causal option + + def forward( + self, + x_qkp: Tensor, + x_vg: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + aux_loss_scale: float = 0.0, + ) -> Tensor: + r""" + Args: + x_qkp: input of shape (seq_len, batch_size, embed_dim), that is used for the queries, + keys and positions. + x_vg: input of shape (seq_len, batch_size, embed_dim), that is used for the values + and gates. May be the same as x_qk. + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + query_head_dim = self.query_head_dim + num_heads = self.num_heads + x_qkp = self.in_norm(x_qkp) + x_qkp = self.qkp_in_proj(x_qkp) + + seq_len, batch_size, _ = x_qkp.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x_qkp[..., 0:query_dim] + k = x_qkp[..., query_dim : 2 * query_dim] * (query_head_dim ** -0.5) + p = x_qkp[..., 2 * query_dim:] + + q = self.copy_query(q) # for diagnostics only, does nothing. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. 
+ + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, -1) + + #q = self.rope(q.permute(1, 0, 2, 3)) # (batch, seq, head, channel) + #k = self.rope(k.permute(1, 0, 2, 3)) # (batch, seq, head, channel) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) # (head, batch, time1, time2) + + p = p.permute(1, 2, 0, 3) + pos_scores = self.rel_pos(p) # (batch, head, time1, time2) + attn_scores = attn_scores + pos_scores.permute(1, 0, 2, 3) + + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing() and self.training: + attn_scores_limit = 8.0 # limit on our metric that affects how much grad we are likely to backpropagate. + attn_scores = PenalizeLargeAttentionScores.apply(attn_scores, attn_scores_limit, + 0.1 * aux_loss_scale, + key_padding_mask, self.name) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. 
+ attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001: + self._print_attn_entropy(attn_weights) + + vg = self.vg_in_proj(x_vg) + N = vg.shape[-1] // 3 + v = vg[..., :N] + g = vg[..., N:] + if self.training: + # don't let the sigmoid values get too extreme, limit to -2..2. + g = penalize_abs_values_gt(g, 2.0, penalty=0.02*aux_loss_scale) + + g_in, g_out = g.chunk(2, dim=-1) + v = v * self.sigmoid_in(g_in) + + wm = self.weighted_mean(v, key_padding_mask, apply_mask=True) + + v = v.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + v = self.copy_v(v) + value_head_dim = v.shape[-1] + # now v: (num_heads, batch_size, seq_len, value_head_dim) + + # todo: see whether there is benefit in overriding matmul + v = torch.matmul(attn_weights, v) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + v = ( + v.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + v = v + wm + v = v * self.sigmoid_out(g_out) + v = self.out_proj(v) + return v + + def streaming_forward( + self, + x_qkp: Tensor, + x_vg: Tensor, + left_context_len: int, + cached_key: Tensor, + cached_value: Tensor, + cached_wm_sum: Tensor, + cached_wm_num_frames: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + r""" + Args: + x_qkp: input of shape (seq_len, batch_size, embed_dim), that is used for the queries, + keys and positions. + x_vg: input of shape (seq_len, batch_size, embed_dim), that is used for the values + and gates. May be the same as x_qk. + left_context_len: length of the cached left context. + cached_key: cached attention key tensor, of shape (left_context_len, batch_size, key_dim). + cached_value: cached attention value tensor, of shape (left_context_len, batch_size, value_dim). 
+ cached_wm_sum: (1, batch, channels), cumulative sum for weighted_mean + cached_wm_num_frames: (batch,), number of frames seen so far + key_padding_mask: a bool tensor of shape (batch_size, left_context_len + seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + + Returns: + - attention output, of shape (seq_len, batch_size, embed_dim) + - updated cached_key, of shape (left_context_len, batch_size, key_dim) + - updated cached_value, of shape (left_context_len, batch_size, value_dim) + - Updated cached_wm_sum (1, batch, channels) + - Updated cached_wm_num_frames (batch,) + """ + query_head_dim = self.query_head_dim + num_heads = self.num_heads + x_qkp = self.in_norm(x_qkp) + x_qkp = self.qkp_in_proj(x_qkp) + + seq_len, batch_size, _ = x_qkp.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x_qkp[..., 0:query_dim] + k = x_qkp[..., query_dim : 2 * query_dim] + p = x_qkp[..., 2 * query_dim:] + + # append the cached key to the current key, and update the cache + assert cached_key.shape[0] == left_context_len, (cached_key.shape, left_context_len) + k = torch.cat([cached_key, k], dim=0) + kv_len = k.shape[0] + cached_key = k[kv_len - left_context_len:] + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + k = k.reshape(kv_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, -1) + + # time1 refers to target, time2 refers to source. 
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, query_head_dim, time2) + + attn_scores = torch.matmul(q, k) # (head, batch, time1, time2) + + p = p.permute(1, 2, 0, 3) + pos_scores = self.rel_pos(p, left_context_len) # (batch, head, time1, time2) + attn_scores = attn_scores + pos_scores.permute(1, 0, 2, 3) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, kv_len) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, kv_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill(key_padding_mask.unsqueeze(1), -1000) + + attn_weights = attn_scores.softmax(dim=-1) + + vg = self.vg_in_proj(x_vg) + N = vg.shape[-1] // 3 + v = vg[..., :N] + g = vg[..., N:] + g_in, g_out = g.chunk(2, dim=-1) + v = v * self.sigmoid_in(g_in) + + wm, cached_wm_sum, cached_wm_num_frames = self.weighted_mean.streaming_forward( + v, cached_wm_sum, cached_wm_num_frames + ) + + # append the cached value to the current value, and update the cache + assert cached_value.shape[0] == left_context_len, (cached_value.shape, left_context_len) + v = torch.cat([cached_value, v], dim=0) + cached_value = v[kv_len - left_context_len:] + + v = v.reshape(kv_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + value_head_dim = v.shape[-1] + # now v: (num_heads, batch_size, kv_len, value_head_dim) + + # todo: see whether there is benefit in overriding matmul + v = torch.matmul(attn_weights, v) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + v = ( + v.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. 
+ v = v + wm + v = v * self.sigmoid_out(g_out) + v = self.out_proj(v) + + return v, cached_key, cached_value, cached_wm_sum, cached_wm_num_frames + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class PenalizeLargeAttentionScores(torch.autograd.Function): + @staticmethod + def forward( + ctx, + attn_scores: Tensor, + limit: float, + aux_loss_scale: float, + key_padding_mask: Optional[Tensor], + name: str): + # attn_scores: (head, batch, query_time, key_time) + ctx.save_for_backward(attn_scores) + ctx.mask = key_padding_mask # has no grad + ctx.limit = limit + ctx.aux_loss_scale = aux_loss_scale + ctx.name = name + return attn_scores + + @staticmethod + def backward( + ctx, + attn_scores_grad): + attn_scores, = ctx.saved_tensors + mask = ctx.mask + (num_heads, batch_size, seq_len, _) = attn_scores.shape + with torch.amp.autocast('cuda', enabled=False): + attn_scores = attn_scores.to(torch.float) + attn_scores = attn_scores.detach() + # attn_scores: (head, batch, query_time, key_time) + attn_scores.requires_grad = True + with torch.enable_grad(): + probs = attn_scores.softmax(dim=-1) + scaled_scores = attn_scores.abs() * probs + avg_scores = scaled_scores.sum(dim=-1) # (head, batch, query_time) + if mask is not None: + avg_scores = avg_scores * (~mask) # mask: (batch, time) + query_scores = (avg_scores - ctx.limit).relu() + + if random.random() < 0.0005: + query_excess = query_scores.mean(dim=(1,2)).to('cpu') + avg_scores_mean = avg_scores.mean(dim=(1,2)).to('cpu') + logging.info(f"PenalizeLargeAttentionScores: 
{ctx.name}, limit={ctx.limit}, avg_scores={avg_scores_mean}, query_excess={query_excess}") + # all these losses have a "per-frame" scaling, i.e. scaled proportional to the total number + # of frames which is batch_size * seq_len. normalize by dividing by num heads. + # also divide by ctx.limit so it's like penalizing a relative excess. + query_scores.backward(gradient=torch.full_like(query_scores, ctx.aux_loss_scale / (num_heads * ctx.limit))) + + return attn_scores_grad + attn_scores.grad, None, None, None, None + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zapformer model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int): + super(FeedforwardModule, self).__init__() + # try to get in the useful range of the activation function, i.e. not too small. + self.in_proj = ScaledLinear(embed_dim, feedforward_dim) + # weight_min_rms will be interpreted by get_parameter_groups_with_lrs() and passed + # to the TransformedAdam optimizer. + self.in_proj.weight_min_rms = 0.02 + + self.out_proj = ActivationAndLinear( + feedforward_dim, + embed_dim, + activation="SwashR", + initial_scale=0.5, + bias=True, + ) + + + def forward(self, x: Tensor, aux_loss_scale: float = 0.0, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: + x = self.in_proj(x) + x = self.out_proj(x) + return x + +def round_up_to_power_of_two(x): + x = x - 1 + x = x | x >> 1 + x = x | x >> 2 + x = x | x >> 4 + x = x | x >> 8 + x = x | x >> 16 + x = x + 1 + return x + + +# wolfram alpha: +# the right part of the triangular bin, from 0 to +W. +# definite integral from omega = 0 to W of (1 - omega/W) exp(-i x \omega) d\omega +# = -(i W x + e^(-i W x) - 1)/(W x^2) +# Re[definite integral from omega = 0 to W of (1 - omega/W) exp(-i x \omega) d\omega] +# = (1 - cos(W x))/(W x^2) +# Im[definite integral from omega = 0 to W of (1 - omega/W) exp(-i x \omega) d\omega] +# = (sin(W x) - W x)/(W x^2) + +# the left part of the triangular bin, from -W to 0. 
# definite integral from omega = -W to 0 of (omega/W + 1) exp(-i x \omega) d\omega
# (i W x - e^(i W x) + 1)/(W x^2)
#
# Let the center frequency be C.
# right side:
# = e^(i C x) * -(i W x + e^(-i W x) - 1)/(W x^2)
# "alternate form including W, C and x are real": [note, this is left hand width, W_l]
# (W x sin(C x) - cos(x (C - W)) + cos(C x))/(W x^2) - (i (sin(x (C - W)) + W x cos(C x) - sin(C x)))/(W x^2)
#
# left side:
# e^(i C x) * (i W x - e^(i W x) + 1)/(W x^2)
# "alternate form including W, C and x are real": [note, this is right hand width, W_r]
# -(W x sin(C x) + cos(x (C + W)) - cos(C x))/(W x^2) + (i (-sin(x (C + W)) + W x cos(C x) + sin(C x)))/(W x^2)
#
# summing the left and right sides:
# Real part:
#
# (W_r x sin(C x) - cos(x (C - W_r)) + cos(C x))/(W_r x^2)
# -(W_l x sin(C x) + cos(x (C + W_l)) - cos(C x))/(W_l x^2)
# = (cos(C x) - cos((C - W_r)x)) / W_r x^2
# + (cos(C x) - cos((C + W_l)x)) / W_l x^2

# Imaginary part:
# -(sin(x (C - W_r)) + W_r x cos(C x) - sin(C x))) / (W_r x^2)
# +(-sin(x (C + W_l)) + W_l x cos(C x) + sin(C x)) / (W_l x^2)
# = ( sin(C x) - sin((C - W_r)x) ) / (W_r x^2)
# + ( sin(C x) - sin((C + W_l)x) ) / (W_l x^2)

def compute_angular_freq_basis_triangular(freqs: Tensor,
                                          t: Tensor,
                                          scale: bool) -> Tensor:
    """
    This function computes a set of windowed sinusoidal functions
    corresponding to the real and imaginary parts of possibly-asymmetrical
    triangular angular-frequency bins in frequency space.  This basis
    allows you to approximate functions whose fourier spectrum is
    a piecewise linear function of frequency, with the x-axis values of
    the inflection points of the piecewise linear function corresponding
    to the supplied "freqs".

    The computation is done in double precision and the result is cast back
    to the dtype that `freqs` had on entry.

    Args:
      freqs: the frequencies of the triangular-bin centers, a 1-D tensor with
         at least two elements; the left and right parts of the widths of the
         triangular bins correspond to the distances to the two adjacent bins;
         for the "edge" bins, the "edge" distances are duplicated.
      t: the "t" (or x) values for which we want to evaluate the basis; this
         will normally be some kind of arange expression e.g. arange(100).
      scale: if True, the returned basis will contain the "natural" scaling
         factors that arise from the bin widths; if False, it will be
         normalized so that the maximum absolute value of the real
         functions (attained at t==0) is 1.

    Returns:
      Returns the real and imaginary parts of the basis functions, with
      shape (t.size(), freqs.size(), 2)

    Raises:
      ValueError: if fewer than two frequencies are supplied (the bin widths
         are differences of adjacent frequencies, so they would be undefined).
    """
    if freqs.numel() < 2:
        raise ValueError(
            "compute_angular_freq_basis_triangular: need at least 2 frequencies, "
            f"got {freqs.numel()}"
        )

    out_dtype = freqs.dtype
    freqs = freqs.to(torch.double)
    # trailing singleton axis so that t broadcasts against the frequency axis.
    t = t.to(torch.double).unsqueeze(-1)

    C = freqs  # Center frequencies of bins.
    W = freqs[1:] - freqs[:-1]  # the differences between the frequencies
    W_l = torch.cat((W[:1], W))  # the difference between each center freq and the freq to the left
    W_r = torch.cat((W, W[-1:]))  # the difference between each center freq and the freq to the right

    # NOTE(review): W_r is paired with (C - W_r) and W_l with (C + W_l).  If the
    # bin centred at C is meant to span [C - W_l, C + W_r] (as "distances to the
    # two adjacent bins" suggests), this pairing looks mirrored; it only matters
    # when freqs are unevenly spaced.  Deliberately left unchanged: altering it
    # would change the basis seen by already-trained weights.  Confirm intent.
    angles = C * t
    angles_r = (C - W_r) * t
    angles_l = (C + W_l) * t
    t2 = t ** 2
    scale_factor = 0.5 * (W_r + W_l)

    cos_c = angles.cos()
    sin_c = angles.sin()
    # The closed-form expressions are 0/0 at t == 0, so that point is filled in
    # explicitly: the real part equals the bin "area" and the imaginary part is 0.
    at_zero = t == 0.
    re = torch.where(at_zero, scale_factor,
                     (cos_c - angles_r.cos()) / (W_r * t2) + (cos_c - angles_l.cos()) / (W_l * t2))
    im = torch.where(at_zero, 0.0,
                     (sin_c - angles_r.sin()) / (W_r * t2) + (sin_c - angles_l.sin()) / (W_l * t2))

    if not scale:
        re = re / scale_factor
        im = im / scale_factor

    return torch.stack((re, im), dim=-1).to(out_dtype)


class AngularFreqBasis(nn.Module):
    """
    Computes and caches the angular-frequency basis used in relative position scoring.

    Args:
      num_freqs: the number of frequencies of the sin and cos functions
      low_freq_factor: this is approximately the amount by which the lowest frequency will be
         less than the highest frequency, the highest frequency being the Nyquist (pi).
         The frequencies are close to a geometric series at higher frequency but linear
         at low frequency.
    """
    def __init__(self, num_freqs: int, low_freq_factor: float = 0.001):
        super().__init__()
        log_freqs = torch.linspace(math.log(low_freq_factor), math.log(1 + low_freq_factor), num_freqs)
        freqs = math.pi * (log_freqs.exp() - low_freq_factor)  # range from 0 to pi.
        freqs[0] = 0.0  # in case of roundoff
        self.register_buffer('freqs', freqs, persistent=False)

        # The cache is held in plain attributes (not buffers), so it does NOT
        # follow the module through .to()/.cuda(); forward() re-homes it on demand.
        self._cached_basis: Tensor = torch.empty(0)
        self._cached_seq_len: int = -1
        self._cached_left_context_len: int = -1

    def forward(self, seq_len: int, left_context_len: int, device: torch.device) -> Tensor:
        """
        Returns basis of shape (2 * seq_len + left_context_len - 1, 2 * num_freqs).

        Row k corresponds to relative offset t = k - (seq_len + left_context_len - 1);
        along the last axis the num_freqs real (cos-like) functions come first,
        followed by the num_freqs imaginary (sin-like) ones.

        The result is cached; if the requested (seq_len, left_context_len) fits
        within the cached range, the cached tensor is sliced rather than
        recomputed.  A cached basis that lives on a different device from the
        requested one is moved there first, so a warm-up performed before the
        model was moved to another device stays usable.

        Raises:
          RuntimeError: on a cache miss while torch.jit tracing is active.
        """
        S = self._cached_seq_len
        L = self._cached_left_context_len
        if (self._cached_basis.numel() > 0
                and seq_len <= S
                and seq_len + left_context_len <= S + L):
            if self._cached_basis.device != device:
                self._cached_basis = self._cached_basis.to(device)
            # The cache covers t in [-(S + L - 1), S - 1]; drop leading rows so
            # that the slice starts at t = -(seq_len + left_context_len - 1).
            start = S + L - seq_len - left_context_len
            end = start + 2 * seq_len + left_context_len - 1
            return self._cached_basis[start:end]

        if torch.jit.is_tracing():
            raise RuntimeError(
                "AngularFreqBasis: cache miss during tracing. "
                "Call warmup_angular_freq_bases() before tracing."
            )

        t = torch.arange(-(seq_len + left_context_len - 1), seq_len, dtype=torch.double, device=device)
        basis = compute_angular_freq_basis_triangular(self.freqs, t, scale=False)
        # basis: (2 * seq_len + left_context_len - 1, num_freqs, 2)
        basis = basis.permute(0, 2, 1)
        # permute it because of how we did the low-pass initialization of weight, we want
        # the cos and sin parts to each be continuous ranges, not interleaved.
        basis = basis.reshape(basis.shape[0], -1)
        # basis: (2 * seq_len + left_context_len - 1, 2 * num_freqs)

        self._cached_basis = basis
        self._cached_seq_len = seq_len
        self._cached_left_context_len = left_context_len
        return basis


class RelPosScores(nn.Module):
    def __init__(self,
                 num_heads: int,
                 pos_head_dim: int,
                 num_freqs: int):
        """
        Implementation of relative position scores; where conventional relative position scores
        would use sinusoids, we treat each sinusoid frequency as the central frequency of a
        triangular "bucket" (like the buckets in mel bins) of frequencies.  What this amounts
        to is that instead of a sinusoid we get something a bit like a sinusoid times a
        sinc-squared function (the sinc-squared function is the fourier transform of a triangular
        function).  Actually it's not the sinc-squared function, it's a slightly more complicated
        function than that because the "triangles" have uneven shapes, due to the center frequencies
        of the triangles not being evenly spaced.

        Args:
          num_heads: the number of heads
          pos_head_dim: the dimension of the head; in a conventionally structured model this would
              be identical to the query-dim but we make the "position query" independent of
              the main query and with a smaller dimension.
          num_freqs: the number of frequencies of the sin and cos functions
        """
        super().__init__()
        self.weight = nn.Parameter(0.04 * torch.randn(num_heads, pos_head_dim, 2 * num_freqs))
        with torch.no_grad():
            # initialize the weight in a low-pass way.  I think this is not so critical
            # actually, it may not matter.
            # (each pass averages every coefficient with its circular neighbour
            # along the basis axis.)
            for _ in range(10):
                self.weight[:] = (2 ** -0.5) * (self.weight + self.weight.roll(1, dims=2))

        self.angular_freq_basis = AngularFreqBasis(num_freqs)

    def forward(self, p: Tensor, left_context_len: int = 0) -> Tensor:
        """
        Compute and return unnormalized log scores for relative position.

        Args:
          p: these are the position-queries, of shape (batch_size, num_heads, seq_len, pos_head_dim)
             (they are obtained via projection, just like the queries).
          left_context_len: length of left context, must be 0 for non-streaming forward and > 0 for streaming forward.
        Returns:
          scores: (batch_size, num_heads, dest_seq_len, src_seq_len), where dest_seq_len relates to the
             query and src_seq_len to the key.
             In non-streaming forward, dest_seq_len and src_seq_len are numerically equal to seq_len;
             in streaming forward, dest_seq_len is seq_len and src_seq_len is seq_len + left_context_len.
        """
        (batch_size, num_heads, seq_len, pos_head_dim) = p.shape

        basis = self.angular_freq_basis(seq_len, left_context_len, p.device)
        # basis: (2 * seq_len + left_context_len - 1, 2 * num_freqs)

        x = torch.matmul(self.weight, basis.t())
        assert x.shape == (num_heads, pos_head_dim, 2 * seq_len + left_context_len - 1)

        # with seq_len2 = 2 * seq_len + left_context_len - 1,
        # (batch, head, seq_len, pos_head_dim) x (1, head, pos_head_dim, seq_len2) -> (batch, head, seq_len, seq_len2)
        pos_weights = torch.matmul(p, x)

        # Convert the last axis of pos_weights from relative to absolute position:
        #   out[b, h, i, j] = pos_weights[b, h, i, (seq_len - 1) - i + j]
        # Both branches below implement exactly this indexing.
        if torch.jit.is_tracing():
            # as_strided is avoided while tracing; use an explicit gather.
            seq_len2 = pos_weights.shape[-1]
            # build the indexes on the same device as the data being gathered.
            rows = torch.arange(start=seq_len - 1, end=-1, step=-1, device=pos_weights.device)
            cols = torch.arange(left_context_len + seq_len, device=pos_weights.device)
            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
            indexes = rows + cols
            pos_weights = pos_weights.reshape(-1, seq_len2)
            pos_weights = torch.gather(pos_weights, dim=1, index=indexes)
            pos_weights = pos_weights.reshape(batch_size, num_heads, seq_len, left_context_len + seq_len)
        else:
            # Stepping one row down also steps one column left (stride(2) - stride(3)),
            # starting from column seq_len - 1 of the first row.
            pos_weights = pos_weights.as_strided(
                (batch_size, num_heads, seq_len, left_context_len + seq_len),
                (
                    pos_weights.stride(0),
                    pos_weights.stride(1),
                    pos_weights.stride(2) - pos_weights.stride(3),
                    pos_weights.stride(3),
                ),
                storage_offset=pos_weights.stride(3) * (seq_len - 1),
            )
        return pos_weights


def round_up_to_power_of_two(x):
    """Return the smallest power of two that is >= x, for 1 <= x <= 2**32.

    Works by smearing the highest set bit of (x - 1) into every lower bit and
    then adding one; the shifts used cover 32 bits.  Returns 0 for x == 0.

    NOTE: an identical definition appears earlier in this file; this one simply
    rebinds the same name.
    """
    x = x - 1
    x = x | x >> 1
    x = x | x >> 2
    x = x | x >> 4
    x = x | x >> 8
    x = x | x >> 16
    x = x + 1
    return x



# FftConv was formerly used as the depthwise_conv module in ConvolutionModule.
# CAUTION: this is not used right now, we use BasisConv plus WeightedMean in
# parallel for the depthwise convolution in the ConvModule.  Using FftConv is
# actually just as good in WER terms and is also more efficient, versus (BasisConv
# plus WeightedMean); FftConv itself, should be about twice faster because it
# operates on a twice-shorter length than BasisConv since BasisConv pads for
# exactness.  For the overall training the speed difference is about 10%.
# The reason we use BasisConv is because it is properly invariant to
# how we pad different-length sequences into a batch, while FftConv cannot
# be made to give exactly the same results independent of the batch size, because
# it treats the signal as repeating in time which depends on the FFT size which
# depends on the longest sequence in the batch.
# Unfortunately, we don't know exactly how the model is going to be used, and
# we don't want it to become a deal-breaker that batching sequences of very
# different lengths together at inference time could significantly affect the
# model results.  For image tasks, FftConv may still be useful (after suitable
# adaptation), because you wouldn't normally run inference on different-size
# images in one batch.
class FftConv(nn.Module):
    """Depthwise (per-channel) circular convolution computed in the Fourier domain.

    Each channel owns ``params_per_channel`` learnable parameters.  These are
    projected to ``2 * params_per_channel`` complex values (real, imaginary),
    linearly interpolated to the number of rFFT bins of the (padded) input,
    and multiplied with the input spectrum.

    Args:
      num_channels: number of channels of the input.
      params_per_channel: number of learnable parameters per channel.
      bias: if True, add a learnable per-channel bias to the output.
    """
    def __init__(self,
                 num_channels: int,
                 params_per_channel: int,
                 bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(0.1 * torch.randn(num_channels, params_per_channel))
        # one factor of 2 is for (sin, cos); the other is to double the num representable freqs
        self.weight_proj = nn.Linear(params_per_channel, 4 * params_per_channel)
        if bias:
            self.bias = nn.Parameter(0.01 * torch.randn(num_channels))
        else:
            # Always define the attribute so forward() can test it explicitly
            # instead of catching AttributeError.
            self.register_parameter("bias", None)

    def forward(self,
                x: Tensor) -> Tensor:
        """
        Args:
          x: input of shape (seq_len, batch_size, num_channels).
        Returns:
          Tensor of shape (seq_len, batch_size, num_channels).  The FFT part is
          computed in float32.
        """
        (seq_len, batch_size, num_channels) = x.shape

        # select a power of two that's >= seq_len // 8 and round up seq_len
        # to a multiple of that power.  This means that rounded_seq_len
        # will be of the form (2**n) * k where k <= 8, so it won't contain
        # many factors other than two; this will make the FFT more efficient
        # without adding an excessive amount of padding.
        power_of_two = max(1, round_up_to_power_of_two(seq_len // 8))
        rounded_seq_len = power_of_two * ((seq_len + power_of_two - 1) // power_of_two)

        with torch.amp.autocast('cuda', enabled=False):
            # do it in float32 because non power of two seq_len is not supported in half precision.
            x = torch.fft.rfft(x.to(torch.float32), dim=0, n=rounded_seq_len)
            # x: (num_freqs, batch_size, num_channels)
            N = x.shape[0]  # num freqs
            # The constant up-scaling of the raw parameters (4x) is because of
            # interactions with commonly used optimizers; it helps this module
            # learn faster than it otherwise would.
            weight = 4. * self.weight
            weight = self.weight_proj(weight).reshape(num_channels, 2, -1)  # (num_channels, 2, 2 * params_per_channel)
            # stretch the learned frequency response over all N rFFT bins.
            weight = torch.nn.functional.interpolate(weight, N, mode='linear', align_corners=True)
            # (num_channels, 2, N) -> (N, num_channels, 2) -> complex (N, num_channels)
            weight = torch.view_as_complex(weight.permute(2, 0, 1).contiguous())
            weight = weight.unsqueeze(1)  # (N, 1, num_channels)
            x = x * weight
            x = torch.fft.irfft(x, n=rounded_seq_len, dim=0)

        # discard the frames that were only there as FFT padding.
        x = x[:seq_len]

        # getattr() keeps modules created before `bias` was always registered working.
        bias = getattr(self, "bias", None)
        if bias is not None:
            x = x + bias

        return x
tensor([ 0, 1, 2, 3, -4, -3, -2, -1]) + # the second half of the "t" values are interpreted as the "negative half" of the time range-- + # the time range representing t values from -seq_len // 2 to seq_len // 2 - 1. + # The way we use this will be to convolve it with a signal of size seq_len // 2 that + # has been padded with zeroes of length seq_len // 2, and we want the result to be as if we padded with the basis + # functions from -infinity to infinity. + + + scaled_t = t * math.pi / num_freqs + + # "freqs" are the t values multiplied by the basis frequencies + t_freqs = scaled_t * torch.arange(num_freqs + 1, **kwargs).unsqueeze(-1) + # t_freqs: (num_freqs + 1, seq_len) + + # it's a sinc-squared envelope, as the frequency domain envelope is a + # triangular, not a rectangular, function. the factor of 0.5 comes + # from the math + sinc_arg = 0.5 * scaled_t + envelope = torch.where(sinc_arg != 0.0, sinc_arg.sin() / sinc_arg, torch.ones_like(sinc_arg)) ** 2 + + + cos, sin = t_freqs.cos() * envelope, t_freqs.sin() * envelope + #plt.plot(envelope) + + # the factor of 0.5 is because the other freqs would get "counted twice" due + # to having two symmetric versions, the freqs at zero and the nyquist only have + # one copy. This ensures that if we give a coeff of all ones on all the + # cos terms, we get (a scaled version of) the delta function. + sin[0] = 0.5 * cos[-1] + cos[0] = 0.5 * cos[0] + # the sin coefficient of freq 0 and nyquist gives us nothing, so we use the cos + # at the nyquist in this position. + cos = cos[:num_freqs] + sin = sin[:num_freqs] + #scale = num_freqs ** -0.5 # scale to make the funcs have a value around 1. 
class WeightedMean(nn.Module):
    """Mean over time, multiplied by a learnable per-channel weight.

    This is like the core part of squeeze-and-excite.  It is added to a more
    conventional convolution; this was found helpful because a normal
    convolution cannot do averaging-over-time since it does not know the
    sequence length.

    Args:
      num_channels: number of channels; one learnable weight per channel.
      causal: if True, forward() returns the running mean up to each frame
         instead of a single mean per sequence.
    """
    def __init__(self,
                 num_channels: int,
                 causal: bool = False):
        super().__init__()
        self.causal = causal
        self.weights = nn.Parameter(0.1 * torch.randn(num_channels))

    def forward(self,
                x: Tensor,
                src_key_padding_mask: Optional[Tensor] = None,
                apply_mask: bool = True) -> Tensor:
        """
        Compute weighted mean.

        Args:
          x: (time, batch, channel)
          src_key_padding_mask: (batch, time), True for masked positions.
             Ignored in causal mode.
          apply_mask: if True, zero the masked frames of x before averaging;
             pass False if the caller has already zeroed them.

        Returns:
          (time, batch, channel) if causal else (batch, channel)
        """
        T = x.shape[0]
        if self.causal:
            # running mean: cumulative sum divided by number of frames so far.
            num_frames = torch.arange(1, T + 1, device=x.device)
            x_cumsum = torch.cumsum(x, dim=0)
            return x_cumsum / num_frames[:, None, None] * self.weights

        if src_key_padding_mask is None:
            return x.mean(dim=0) * self.weights

        mask = src_key_padding_mask.logical_not().to(torch.float)  # (batch, time), 1.0 == keep
        # num_frames: (batch_size, 1).  Clamp so that a fully-padded sequence
        # gives 0 rather than 0 * (T / 0) == NaN.
        num_frames = mask.sum(dim=1, keepdim=True).clamp(min=1.0)

        if apply_mask:
            x = x * mask.t().unsqueeze(-1)

        # x.mean() divides by T; rescale so we divide by the true length instead.
        return x.mean(dim=0) * (T / num_frames) * self.weights

    def streaming_forward(
        self,
        x: Tensor,
        cached_sum: Tensor,
        cached_num_frames: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Streaming forward for causal weighted mean.

        Args:
          x: (time, batch, channel), the current chunk
          cached_sum: (1, batch, channel), cumulative sum from previous chunks
          cached_num_frames: (batch,), number of frames seen so far

        Returns:
          - output: (time, batch, channel)
          - new_cached_sum: (1, batch, channel)
          - new_cached_num_frames: (batch,)
        """
        T = x.shape[0]
        # cumsum within this chunk, then add the historical sum
        x_cumsum = torch.cumsum(x, dim=0) + cached_sum  # (T, batch, channel)

        # num_frames for each position in this chunk: (T, batch)
        num_frames = cached_num_frames.unsqueeze(0) + torch.arange(
            1, T + 1, device=x.device
        ).unsqueeze(1)  # (T, batch)

        output = x_cumsum / num_frames.unsqueeze(-1) * self.weights

        new_cached_sum = x_cumsum[-1:, :, :]  # (1, batch, channel)
        new_cached_num_frames = cached_num_frames + T  # (batch,)

        return output, new_cached_sum, new_cached_num_frames
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Zapformer model.

    A gated depthwise convolution: the input is projected to two halves, one
    half goes through a depthwise Conv1d (plus a weighted mean over time), the
    other half through a sigmoid that gates it, followed by an output
    projection.

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of the depthwise conv.  Must be odd
            when ``causal`` is False.
        causal (bool): If True, the convolution only sees the current and
            past frames (the input is padded on the left by
            ``kernel_size - 1``) and the weighted mean is a running mean.

    """
    def __init__(
        self,
        channels: int,
        kernel_size: int,
        causal: bool,
    ) -> None:
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()

        bottleneck_dim = channels
        self.causal = causal

        self.in_proj = nn.Linear(
            channels,
            2 * bottleneck_dim,
        )
        # the gradients on in_proj are a little noisy, likely to do with the
        # sigmoid in glu.

        self.activation1 = Identity()  # for diagnostics

        self.sigmoid = nn.Sigmoid()

        if not causal:
            assert kernel_size % 2 == 1
            self.depthwise_conv = nn.Conv1d(
                in_channels=bottleneck_dim,
                out_channels=bottleneck_dim,
                groups=bottleneck_dim,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bias=False,
            )
        else:
            self.depthwise_conv = nn.Conv1d(
                in_channels=bottleneck_dim,
                out_channels=bottleneck_dim,
                groups=bottleneck_dim,
                kernel_size=kernel_size,
                padding=0,  # will pad manually, on one side.
                bias=False,
            )
            self.left_pad = kernel_size - 1

        with torch.no_grad():
            # make the non-central convolution weights much smaller.
            k = kernel_size // 2
            if k > 0:
                # Guard needed because with k == 0 the slice [..., -0:] is
                # [..., 0:], i.e. it would scale *every* tap (including the
                # only, central, one) by 0.1.
                self.depthwise_conv.weight[..., :k] *= 0.1
                self.depthwise_conv.weight[..., -k:] *= 0.1

        # add average-of-all-frames to the "convolution."; it has extra power vs the convolution
        # because the num frames differs between utterances.
        self.weighted_mean = WeightedMean(bottleneck_dim,
                                          causal=causal)

        self.out_proj = ActivationAndLinear(
            bottleneck_dim,
            channels,
            activation="SwashR",
            initial_scale=0.05,
        )

    def forward(
        self,
        x: Tensor,
        src_key_padding_mask: Optional[Tensor] = None,
        chunk_size: int = -1,
        aux_loss_scale: float = 0.0,
    ) -> Tensor:
        """Compute convolution module.

        Args:
            x: Input tensor (#time, batch, channels).
            src_key_padding_mask: the mask for the src keys per batch (optional):
                (batch, #time), contains True in masked positions.
            chunk_size: not used by this method; accepted for interface
                compatibility.
            aux_loss_scale: not used by this method; accepted for interface
                compatibility.

        Returns:
            Tensor: Output tensor (#time, batch, channels).
        """
        input_scale = 3.
        x = self.in_proj(x * input_scale)  # (time, batch, 2*bottleneck_dim)

        # x is the branch that gets convolved, y is the gate.
        x, y = x.chunk(2, dim=2)
        y = self.sigmoid(y)
        x = self.activation1(x)  # identity.

        # x: (time, batch, bottleneck_dim)
        # Zero the padded frames so they contribute nothing to the
        # convolution or to the weighted mean.
        if src_key_padding_mask is not None:
            x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0)

        wm = self.weighted_mean(x,
                                src_key_padding_mask,
                                apply_mask=False)  # just applied it.
        x = x.permute(1, 2, 0)  # (batch, bottleneck_dim, time)
        if self.causal:
            # Not support exporting a model for simulated streaming decoding
            assert not torch.jit.is_scripting() and not torch.jit.is_tracing()
            x_shape = x.shape
            x = torch.nn.functional.pad(x, (self.left_pad, 0))
            x = self.depthwise_conv(x)
            assert x.shape == x_shape, (x.shape, x_shape)
        else:
            x = self.depthwise_conv(x)  # x: (batch, bottleneck_dim, time)
        x = x.permute(2, 0, 1)  # (time, batch, bottleneck_dim)
        x = x + wm  # Add in the weighted-mean to the convolution; this adds extra power
                    # because the utterances differ in length.

        x = x * y
        x = self.out_proj(x)  # (time, batch, channels)

        return x

    def streaming_forward(
        self,
        x: Tensor,
        cached_conv: Tensor,
        cached_wm_sum: Tensor,
        cached_wm_num_frames: Tensor,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Compute convolution module in streaming mode.

        Only meaningful for a module constructed with ``causal=True``
        (``self.left_pad`` is only defined in that case).

        Args:
            x: Input tensor (#time, batch, channels).
            cached_conv: cached left context for depthwise_conv, of shape
                (#batch, channels, left_pad)
            cached_wm_sum: (1, batch, channels), cumulative sum for weighted_mean
            cached_wm_num_frames: (batch,), number of frames seen so far
            src_key_padding_mask: the mask for the src keys per batch (optional):
                (batch, #time), contains True in masked positions.

        Returns:
            - Output tensor (#time, batch, channels).
            - Updated cached_conv (#batch, channels, left_pad)
            - Updated cached_wm_sum (1, batch, channels)
            - Updated cached_wm_num_frames (batch,)
        """
        input_scale = 3.
        x = self.in_proj(x * input_scale)  # (time, batch, 2*bottleneck_dim)

        x, y = x.chunk(2, dim=2)
        y = self.sigmoid(y)

        if src_key_padding_mask is not None:
            x = x.masked_fill(src_key_padding_mask.t().unsqueeze(-1).expand_as(x), 0.0)

        wm, cached_wm_sum, cached_wm_num_frames = self.weighted_mean.streaming_forward(
            x, cached_wm_sum, cached_wm_num_frames
        )

        x = x.permute(1, 2, 0)  # (batch, bottleneck_dim, time)

        x_shape = x.shape
        assert cached_conv.shape[-1] == self.left_pad, (cached_conv.shape[-1], self.left_pad)
        # prepend the frames remembered from the previous chunk(s) as left context.
        x = torch.cat([cached_conv, x], dim=2)
        # Keep the last left_pad frames for the next call.  Index from the
        # front: x[..., -left_pad:] would return *everything* when
        # left_pad == 0.
        cached_conv = x[..., x.shape[-1] - self.left_pad:]

        x = self.depthwise_conv(x)
        assert x.shape == x_shape, (x.shape, x_shape)

        x = x.permute(2, 0, 1)  # (time, batch, bottleneck_dim)
        x = x + wm

        x = x * y
        x = self.out_proj(x)  # (time, batch, channels)

        return x, cached_conv, cached_wm_sum, cached_wm_num_frames
+ f, lengths = c( + torch.randn(seq_len, batch_size, input_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + aux_loss_scale=1.0, + ) + f.sum().backward() + c.eval() + x_ = c( + torch.randn(seq_len, batch_size, input_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + aux_loss_scale=1.0, + ) + x_ # to remove flake8 warnings + + logging.info(f"Zapformer forward test passed, causal={causal}") + + +def _test_zapformer_streaming(): + input_dim = 50 + batch_size = 2 + chunk_size = 32 + num_chunks = 10 + tail_chunk_size = 8 + seq_len = chunk_size * num_chunks + tail_chunk_size + left_context_frames = 128 + + model = Zapformer( + input_dim=input_dim, + encoder_dim=(64, 96, 128, 96), + num_heads=(4, 4, 4, 4), + conv_params=(31, 31, 15, 31), # it may be better to make these even if not in causal mode. + downsampling_factor=(1, 2, 4, 2), + causal=True, + chunk_size=(chunk_size,), + left_context_frames=(left_context_frames,), + ) + + model.compute_projection_overlap(verbose=True) + + model.eval() + + x_full = torch.randn(seq_len, batch_size, input_dim) + x_lens_full = torch.full((batch_size,), seq_len, dtype=torch.int64) + + with torch.no_grad(): + out_full, out_lens_full = model(x_full, x_lens_full) + + caches = model.get_init_caches(batch_size=batch_size) + + out_chunks = [] + out_offset = 0 + processed_lens = torch.full((batch_size,), 0, dtype=torch.int64) + + for i in range(num_chunks): + start = i * chunk_size + end = start + chunk_size + x_chunk = x_full[start:end] + x_lens = torch.full((batch_size,), chunk_size, dtype=torch.int64) + + src_key_padding_mask = make_pad_mask(x_lens) + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_frames).expand(batch_size, left_context_frames) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) 
+ src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + out_chunk, out_lens, caches = model.streaming_forward( + x=x_chunk, + x_lens=x_lens, + caches=caches, + src_key_padding_mask=src_key_padding_mask, + ) + out_chunks.append(out_chunk) + + out_chunk_len = out_chunk.shape[0] + expected_out = out_full[out_offset : out_offset + out_chunk_len] + diff_chunk = torch.max(torch.abs(expected_out - out_chunk)) + logging.info(f"Chunk {i+1} | Input: {x_chunk.shape} -> Output: {out_chunk.shape} | Max diff: {diff_chunk}") + assert torch.allclose(expected_out, out_chunk, atol=2e-5), f"Chunk {i+1} outputs do not match! Max diff: {diff_chunk}" + + out_offset += out_chunk_len + + x_tail = x_full[num_chunks * chunk_size:] + x_lens_tail = torch.full((batch_size,), tail_chunk_size, dtype=torch.int64) + src_key_padding_mask = make_pad_mask(x_lens_tail) + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_frames).expand(batch_size, left_context_frames) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + processed_lens = processed_lens + x_lens_tail + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + out_tail, out_lens_tail, caches = model.streaming_forward( + x=x_tail, + x_lens=x_lens_tail, + caches=caches, + src_key_padding_mask=src_key_padding_mask, + ) + out_chunks.append(out_tail) + + out_tail_len = out_tail.shape[0] + expected_out_tail = out_full[out_offset : out_offset + out_tail_len] + diff_tail = torch.max(torch.abs(expected_out_tail - out_tail)) + logging.info(f"Tail Chunk | Input: {x_tail.shape} -> Output: {out_tail.shape} | Max diff: {diff_tail}") + assert torch.allclose(expected_out_tail, out_tail, atol=2e-5), f"Tail Chunk outputs do not match! 
def _test_basis_conv():
    """Smoke test for BasisConv: run it on a short burst of noise followed by
    zeros and print the root-mean-square value of the input and the output."""
    num_channels = 11
    f = BasisConv(num_channels=num_channels,
                  num_freqs=4,
                  params_per_channel=2)

    seq_len = 100
    subseq_len = 10  # will help visualize the effect
    batch_size = 2
    # random signal for the first subseq_len frames, silence afterwards.
    x = torch.cat((torch.randn(subseq_len, batch_size, num_channels),
                   torch.zeros(seq_len - subseq_len, batch_size, num_channels)),
                  dim=0)

    y = f(x)

    def rms(a):
        # root-mean-square; the square root was previously missing, so the
        # value printed as "rms" was really the mean square.
        return (a ** 2).mean().sqrt().item()

    print(f"rms(x)={rms(x)}, rms(y)={rms(y)}")
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
    """Compute log(exp(x) + exp(y)) using only basic elementwise ops.

    Intended as a replacement for torch.logaddexp() where that operator
    cannot be exported.  Uses the numerically stable form
    max(x, y) + log1p(exp(-|x - y|)).

    Args:
      x, y: tensors of broadcastable shapes.
    Returns:
      Tensor of the broadcast shape.
    """
    max_value = torch.max(x, y)
    diff = torch.abs(x - y)
    # When x and y are the same infinity, x - y is NaN; the true difference
    # is 0.  Without this the result would be NaN instead of +/-inf, unlike
    # torch.logaddexp().
    diff = torch.where(x == y, torch.zeros_like(diff), diff)
    return max_value + torch.log1p(torch.exp(-diff))
+def _sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, mask: Optional[Tensor]): + stats = (x ** 2).mean(dim=2, keepdim=True) + T = x.shape[0] # time + if mask is None: + stats = stats.sum(dim=0) + lengths = torch.tensor(T, dtype=stats.dtype, device=stats.device) + else: + mask_f = (~mask).to(torch.float).t().unsqueeze(-1) + stats = stats * mask_f + stats = stats.sum(dim=0) + lengths = mask_f.sum(dim=0) + + scales = (lengths / stats).sqrt() + assert scales.shape == (x.shape[1], 1) + return x * ((scale * scales) + offset) + +# all arg tensors except x are scalars. +def _causal_sequence_norm(x: Tensor, offset: Tensor, scale: Tensor, ballast_rms: Tensor, ballast_frames: Tensor): + stats = (x ** 2).mean(dim=2, keepdim=True) + + # no need for mask in causal mode. + # ballast_frames should normally be positive due to limit_param_value, but there can be small excursions, so + # make absolutely sure using abs(). + ballast_frames = 100.0 * ballast_frames.abs() + ballast = ballast_frames * (ballast_rms ** 2) + T = x.shape[0] # time + + stats = stats.cumsum(dim=0) + ballast + lengths = ballast_frames + torch.arange(1, T + 1, dtype=x.dtype, device=x.device)[:, None, None] + + scales = (lengths / stats).sqrt() + assert scales.shape == (T, x.shape[1], 1) + return x * ((scale * scales) + offset) + + +# all arg tensors are scalars +def _causal_sequence_norm_streaming( + x: Tensor, + offset: Tensor, + scale: Tensor, + cached_stats_sum: Tensor, + cached_len: Tensor, +) -> Tuple[Tensor, Tensor, Tensor]: + """Streaming inference forward for _sequence_norm. We assume that ballast_frames and ballast_rms + are already included in cached_stats_sum and cached_len. 
class CausalSequenceNormFunction(torch.autograd.Function):
    """Autograd wrapper around _causal_sequence_norm().

    forward() saves only its inputs; backward() re-runs
    _causal_sequence_norm() in float32 with autocast disabled and
    back-propagates through that recomputation.
    """
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        offset: Tensor,
        scale: Tensor,
        ballast_rms: Tensor,
        ballast_frames: Tensor,
    ) -> Tensor:
        # Keep the inputs (not the intermediates) for the recomputation in backward().
        ctx.save_for_backward(x, offset, scale, ballast_rms, ballast_frames)
        return _causal_sequence_norm(x, offset, scale, ballast_rms, ballast_frames)


    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        x, offset, scale, ballast_rms, ballast_frames = ctx.saved_tensors

        with torch.amp.autocast('cuda', enabled=False):
            # Fresh float32 leaf copies, so that .grad gets populated on them below.
            x = x.to(torch.float32).detach().requires_grad_()
            offset = offset.to(torch.float32).detach().requires_grad_()
            scale = scale.to(torch.float32).detach().requires_grad_()
            ballast_rms = ballast_rms.to(torch.float32).detach().requires_grad_()
            ballast_frames = ballast_frames.to(torch.float32).detach().requires_grad_()

            # Grad mode has to be switched on explicitly here for the recomputation.
            with torch.enable_grad():
                ans = _causal_sequence_norm(x, offset, scale, ballast_rms, ballast_frames)
                ans.backward(gradient=ans_grad.to(torch.float32))

        def c(x):
            # this is to replace infinities that might be thrown up
            # in autocast mode: scalars will tend to have larger grads than non-scalars,
            # this code is to reduce the probabilities that any infinities could crash the
            # training (it may still happen if the world-size is so large that these
            # infinities get added together though).
            return x.clamp_(min=-30000.0, max=30000.0)

        # Only the grads of the scalar arguments are clamped; x.grad is returned as-is.
        return x.grad, c(offset.grad), c(scale.grad), c(ballast_rms.grad), c(ballast_frames.grad)
class CausalSequenceNorm(torch.nn.Module):
    """
    This is like RMSNorm but the stats for the RMS value of x are aggregated over the whole sequence
    up to the current point as well as the channels, with some padding of the stats with "default values"
    determined by ballast_frames, ballast_rms for robustness near the beginning of the sequence.

    There is also a learnable scalar scale, multiplicatively applied to the output, and a learnable
    "offset" value that acts multiplicatively on the input without taking into account the rms values.
    """
    def __init__(
        self,
    ) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(0.5))
        self.offset = nn.Parameter(torch.tensor(0.0001))

        # ballast_rms: assumed rms value of ballast frames used to pad stats
        self.ballast_rms = nn.Parameter(torch.tensor(0.1))
        # ballast_frames: number of ballast frames, in hundreds (will be multiplied by 100)
        self.ballast_frames = nn.Parameter(torch.tensor(0.05))  # number of ballast frames, will be multiplied by 100
        # name: only used in the diagnostic log message in forward(); None until something sets it.
        self.name = None

    def forward(self, x: Tensor, _mask: Optional[Tensor] = None) -> Tensor:
        # x: (seq, batch, channel)
        # The mask is ignored, it is allowed only for consistency of interface with SequenceNorm.
        # Returns a tensor of the same shape as x.
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            # Scripting / tracing: call the plain function directly, with the raw parameters.
            return _causal_sequence_norm(x, self.offset, self.scale, self.ballast_rms, self.ballast_frames)

        scale = limit_param_value(
            self.scale, min=0.05, max=2.0, training=self.training)

        offset = limit_param_value(
            self.offset, min=0.0, max=10.0, training=self.training)

        ballast_rms = limit_param_value(
            self.ballast_rms, min=0.0, max=10.0, training=self.training)

        ballast_frames = limit_param_value(
            self.ballast_frames, min=0.0, max=5.0, training=self.training)  # max of 5.0 would be 500 frames

        ans = CausalSequenceNormFunction.apply(
            x, offset, scale, ballast_rms, ballast_frames,
        )

        # Occasional diagnostics: logged on roughly 0.2% of calls.
        if random.random() < 0.002:
            x_rms = (x ** 2).mean().sqrt()
            ans_rms = (ans ** 2).mean().sqrt()
            logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}, offset={self.offset.item()}, ballast_rms={self.ballast_rms.item()}, ballast_frames*100={100*self.ballast_frames.item()}")

        return ans

    @torch.jit.export
    def get_init_cache(self, batch_size: int):
        """Get initial cache for streaming inference. We first include the ballast stats in the initial cache.

        Returns:
          (cached_stats_sum, cached_len), each of shape (batch_size,).
        """
        # ballast_frames should normally be positive due to limit_param_value, but there can be small excursions, so
        # make absolutely sure using abs().
        ballast_frames = 100.0 * self.ballast_frames.abs()
        ballast = ballast_frames * (self.ballast_rms ** 2)

        cached_stats_sum = ballast.unsqueeze(0).repeat(batch_size)  # (batch_size,)
        cached_len = ballast_frames.unsqueeze(0).repeat(batch_size)  # (batch_size,)

        return cached_stats_sum, cached_len

    def streaming_forward(
        self,
        x: Tensor,
        cached_stats_sum: Tensor,
        cached_len: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        # Normalize one chunk x, continuing from the statistics accumulated so far.
        # Returns (normalized x, updated cached_stats_sum, updated cached_len).
        # Note: self.offset and self.scale are passed as-is here (limit_param_value is not applied).

        x, cached_stats_sum, cached_len = _causal_sequence_norm_streaming(
            x, self.offset, self.scale, cached_stats_sum, cached_len)
        return x, cached_stats_sum, cached_len
+ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return _sequence_norm(x, self.offset, self.scale, mask) + + scale = limit_param_value( + self.scale, min=0.05, max=2.0, training=self.training) + + offset = limit_param_value( + self.offset, min=0.0, max=10.0, training=self.training) + + ans = SequenceNormFunction.apply( + x, offset, scale, mask, + ) + + if random.random() < 0.002: + x_rms = (x ** 2).mean().sqrt() + ans_rms = (ans ** 2).mean().sqrt() + logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, scale={self.scale.item()}, offset={self.offset.item()}") + + return ans + + @torch.jit.export + def get_init_cache(self, batch_size: int): + """Get initial cache for streaming inference.""" + cached_stats_sum = torch.zeros(batch_size) + cached_len = torch.zeros(batch_size) + return cached_stats_sum, cached_len + + +# assume layout: (time, batch, channel) +def _rms_norm(x: Tensor, eps: Tensor, scale: Tensor): + x_sq = torch.mean(x ** 2, dim=2, keepdim=True) + (eps * eps) + scales = scale / x_sq.sqrt() + return x * scales + + +class RmsNormFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + eps: Tensor, + scale: Tensor, + ) -> Tensor: + ctx.save_for_backward(x, eps, scale) + return _rms_norm(x, eps, scale) + + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + x, eps, scale = ctx.saved_tensors + + with torch.amp.autocast('cuda', enabled=False): + x, eps, scale = x.to(torch.float32), eps.to(torch.float32), scale.to(torch.float32) + x, eps, scale = x.detach(), eps.detach(), scale.detach() + + x.requires_grad = True + eps.requires_grad = True + scale.requires_grad = True + + with torch.enable_grad(): + ans = _rms_norm(x, eps, scale) + ans.backward(gradient=ans_grad.to(torch.float32)) + + def c(x): + # this is to replace infinities that might be thrown up + # in autocast mode. 
class RmsNorm(torch.nn.Module):
    """
    RMSNorm over the channel dimension, with a learnable scalar output scale
    and a learnable scalar epsilon.  The input layout is assumed to be
    (time, batch, channel).
    """
    def __init__(
            self,
    ) -> None:
        super().__init__()
        # scalar scale applied to the normalized output
        self.scale = nn.Parameter(torch.tensor(0.2))
        # scalar epsilon used inside the normalization
        self.eps = nn.Parameter(torch.tensor(0.1))
        # name used in the occasional diagnostic log line; None until someone sets it.
        self.name = None

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
          x: input tensor, assumed to be of shape (time, batch, channel).
        Returns:
          The normalized tensor.
        """
        # Scripted/traced code takes the plain function and skips the
        # parameter limiting, the custom autograd function and the logging.
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return _rms_norm(x, self.eps, self.scale)

        # limit_param_value is given the allowed range of each parameter.
        eps = limit_param_value(
            self.eps, min=0.0, max=10.0, training=self.training)
        scale = limit_param_value(
            self.scale, min=0.05, max=1.0, training=self.training)

        ans = RmsNormFunction.apply(x, eps, scale)

        # Occasionally log input/output RMS values for diagnostics.
        if random.random() < 0.002:
            x_rms = (x ** 2).mean().sqrt()
            ans_rms = (ans ** 2).mean().sqrt()
            logging.info(f"name={self.name}: x_rms={x_rms}, ans_rms={ans_rms}, eps={eps.item()}, scale={scale.item()}")

        return ans
class OrthogonalPenaltyFunction(torch.autograd.Function):
    """
    Identity in the forward pass; in the backward pass, adds to the incoming
    gradient the gradient of an orthogonality penalty on `weight`.

    With P = W^T W - I (if W has more rows than columns) or P = W W^T - I
    (otherwise), the penalty is

         0.5 * penalty_scale * mean(abs(weight_grad)) * (P ** 2).sum()

    i.e. the user-supplied penalty_scale is multiplied by the mean absolute
    value of the backpropagated gradient.
    """
    @staticmethod
    @custom_fwd
    def forward(ctx, weight: Tensor, penalty_scale: float, name: str):
        """
        Args:
          weight: the 2-dimensional weight matrix; returned unchanged.
          penalty_scale: scale on the penalty; 0.0 disables it.
          name: name used in the occasional diagnostic log line.
        """
        ctx.save_for_backward(weight)
        ctx.name = name
        ctx.penalty_scale = penalty_scale
        return weight

    @staticmethod
    @custom_bwd
    def backward(ctx, weight_grad):
        weight, = ctx.saved_tensors

        # Nothing to add: pass the gradient straight through.  Returning here
        # guarantees we never touch weight.grad when no penalty was computed.
        if not weight.requires_grad or ctx.penalty_scale == 0.0:
            return weight_grad, None, None

        penalty_scale = ctx.penalty_scale * weight_grad.abs().mean()

        # backward() normally runs with grad mode disabled; we need it on to
        # differentiate the penalty.
        with torch.enable_grad():
            weight = weight.detach()
            weight.requires_grad = True

            # Compute symmetric matrix-product prod with the smallest
            # dimension possible given the shape of weight.  This is not just
            # for efficiency; if we computed it the wrong way round, the
            # product would have deficient rank and could never be the identity.
            if weight.shape[0] > weight.shape[1]:
                prod = torch.matmul(weight.t(), weight)
            else:
                prod = torch.matmul(weight, weight.t())

            # Subtract the identity in place, through a strided view of the
            # diagonal that shares memory with prod.
            num_rows = prod.shape[0]
            (r_stride, c_stride) = prod.stride()
            diag = torch.as_strided(
                prod, size=(num_rows,), stride=(r_stride + c_stride,)
            )
            diag[:] -= 1.

            # The loss that we want to backprop is
            #     0.5 * (prod ** 2).sum() * penalty_scale.
            # Its derivative w.r.t. prod is prod * penalty_scale, so we can
            # backprop it without doing any reduction:
            prod.backward(gradient=prod * penalty_scale)

        if random.random() < 0.002:
            # we print a normalized version of the loss (mean rather than sum).
            loss = (prod ** 2).mean()
            logging.info(f"OrthogonalLinear: name={ctx.name}, loss={loss.detach().cpu()}, penalty_scale={penalty_scale}, grad_abs_mean={weight_grad.abs().mean()}")

        # add the extra gradient term from the orthogonality loss.
        return weight_grad + weight.grad, None, None
class ScaleLimiterFunction(torch.autograd.Function):
    """
    Identity in the forward pass; in the backward pass, adds the gradient of
    a penalty on activation vectors (taken over the last dim of x) whose RMS
    value exceeds max_rms.  For each such vector the penalty is

         aux_loss_scale * (rms / max_rms - 1)

    and vectors with rms <= max_rms are not penalized.
    """
    @staticmethod
    def forward(ctx, x: Tensor, max_rms: float, aux_loss_scale: float, name: str):
        """
        Args:
          x: the activations, with the channel dim last; returned unchanged.
          max_rms: largest per-vector RMS value that is not penalized.
          aux_loss_scale: scale on the penalty.
          name: name used in the occasional diagnostic log line.
        """
        ctx.save_for_backward(x)
        ctx.max_rms = max_rms
        ctx.aux_loss_scale = aux_loss_scale
        ctx.name = name
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        x, = ctx.saved_tensors
        with torch.enable_grad():
            # do the penalty computation in float32, outside of autocast.
            with torch.amp.autocast('cuda', enabled=False):
                x = x.to(torch.float)
                x = x.detach()
                x.requires_grad = True
                mean_sq = (x ** 2).mean(dim=-1)
                # Clamp before the sqrt: for an all-zero vector the derivative
                # of sqrt() at 0 is infinite, and multiplying it by the zero
                # gradient coming from relu() would give NaN.  Vectors this
                # small are never over the limit, so clamping does not change
                # the penalty.
                rms = mean_sq.clamp(min=1.0e-20).sqrt()
                numel = rms.numel()

                excess = (rms / ctx.max_rms - 1.).relu().mean()

                if random.random() < 0.002:
                    logging.info(
                        f"ScaleLimiter: name={ctx.name}, max_rms={ctx.max_rms}, "
                        f"rms={rms.mean().item()}, excess={excess.item()}, "
                        f"loss_scale={ctx.aux_loss_scale}"
                    )
                # multiplying by numel undoes the mean(), so that the penalty
                # is effectively summed over the activation vectors.
                excess.backward(gradient=torch.full_like(excess, ctx.aux_loss_scale * numel))
        return x_grad + x.grad, None, None, None
class CorrelationLimiterFunction(torch.autograd.Function):
    """
    Returns x unchanged in the forward pass.  The backward pass ignores the
    incoming gradient (it is assumed to be 1.0) and instead returns the
    gradient of a penalty on the correlation between channels of x.

    The batch is split into its even- and odd-indexed sequences; for each
    half we compute the (num_channels, num_channels) uncentered covariance
    S1, S2 of the per-frame RMS-normalized features, and the penalty is

        aux_loss_scale * batch_size * seq_len * relu(mean(S1 * S2) - limit).
    """
    @staticmethod
    def forward(ctx, x: Tensor, aux_loss_scale: float, limit: float, mask: Optional[Tensor], name: str):
        """
        Args:
          x: features of shape (batch_size, seq_len, num_channels).
          aux_loss_scale: scale on the penalty.
          limit: value of the correlation statistic below which there is
             no penalty.
          mask: optional bool tensor of shape (batch_size, seq_len); it is
             inverted before use, so True means the frame is excluded.
          name: name used in the occasional diagnostic log line.
        """
        ctx.save_for_backward(x)
        ctx.mask = mask
        ctx.limit = limit
        ctx.aux_loss_scale = aux_loss_scale
        ctx.name = name
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor):  # assume ans_grad is 1.0
        x, = ctx.saved_tensors
        mask = ctx.mask
        aux_loss_scale = ctx.aux_loss_scale
        (batch_size, seq_len, num_channels) = x.shape

        half_batch = batch_size // 2
        if half_batch <= 1:
            # the reason we also return None if half_batch==1 is because of CR-CTC
            # where they may really be duplicates.
            # (We check this first, to avoid doing work that would be discarded.)
            return None, None, None, None, None

        with torch.enable_grad():
            with torch.amp.autocast('cuda', enabled=False):
                x = x.to(torch.float)
                x = x.detach()
                x.requires_grad = True
                x_orig = x

                # normalize each frame to have unit mean-square over channels.
                eps = 1.0e-20
                x = x / ((x ** 2).mean(dim=-1, keepdim=True) + eps).sqrt()

                if mask is not None:
                    # (batch_size, seq_len, 1), 1.0 for frames that are kept.
                    mask = (~mask).to(x.dtype).unsqueeze(-1)
                    x = x * mask

                # even- and odd-indexed sequences, flattened to (frames, channels).
                x1 = x[0::2].reshape(-1, num_channels)
                x2 = x[1::2].reshape(-1, num_channels)

                if mask is not None:
                    # number of kept frames in each half.  clamp(min=1) avoids
                    # 0 * (1 / 0) = NaN if all frames of a half are masked; in
                    # that case the covariance is just zero.
                    numel1 = mask[0::2].sum().clamp(min=1.0)
                    numel2 = mask[1::2].sum().clamp(min=1.0)
                else:
                    numel1 = x1.shape[0]
                    numel2 = x2.shape[0]

                # S1, S2: (num_channels, num_channels)
                S1 = torch.matmul(x1.t(), x1) * (1. / numel1)
                S2 = torch.matmul(x2.t(), x2) * (1. / numel2)

                correlation = (S1 * S2).mean()
                loss = (correlation - ctx.limit).relu()

                if random.random() < 0.001:
                    logging.info(
                        f"CorrelationLimiter: name={ctx.name}, loss_scale={aux_loss_scale}, correlation={correlation}, limit={ctx.limit}, loss={loss}"
                    )

                loss.backward(gradient=torch.tensor(aux_loss_scale * batch_size * seq_len, device=loss.device))

        return x_orig.grad, None, None, None, None
def swashl_and_deriv(x: Tensor):
    """
    Compute the Swash-L activation and its derivative in one pass.

    Swash-L is  y = 0.25 * log(1 + exp(4 * x - 4)) - 0.08 * x - 0.00875,
    whose derivative is  dy/dx = 0.92 - 1 / (1 + exp(4 * x - 4)).

    Args:
      x: the input tensor.
    Returns:
      (y, deriv), two tensors of the same shape as x.
    """
    z = 4. * x - 4.
    one_plus_exp = 1. + torch.exp(z)

    # If exp() overflowed to infinity then log() is infinite as well; in that
    # regime log(1 + exp(z)) equals z to working precision, so substitute z.
    softplus = torch.log(one_plus_exp)
    softplus = torch.where(torch.isinf(softplus), z, softplus)
    y = 0.25 * softplus - 0.08 * x - 0.00875

    # note: 1 / infinity = 0, so the derivative saturates at 0.92.
    deriv = 0.92 - 1. / one_plus_exp
    return y, deriv
class ActivationAndLinearFunction(torch.autograd.Function):
    """
    An activation function followed by a linear layer, fused into a single
    autograd function.  Only the input x (plus weight and bias) is saved for
    backward; the activation is recomputed in the backward pass.
    """
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x: Tensor,
        weight: Tensor,
        bias: Optional[Tensor],
        forward_func: Any,
        backward_func: Any,
    ):
        """
        Args:
          x: input, of shape (..., in_channels).
          weight: linear weight, of shape (out_channels, in_channels).
          bias: optional linear bias, of shape (out_channels,).
          forward_func: callable mapping x to activation(x).
          backward_func: callable mapping x to the pair
             (activation(x), derivative of activation at x).
        Returns:
          linear(forward_func(x), weight, bias), of shape (..., out_channels).
        """
        ctx.save_for_backward(x, weight, bias)
        ctx.backward_func = backward_func
        return torch.nn.functional.linear(forward_func(x), weight, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, ans_grad: Tensor):
        x, weight, bias = ctx.saved_tensors

        # recompute the activation output, together with its derivative.
        activated, activation_deriv = ctx.backward_func(x)

        # activated: (..., in_channels), ans_grad: (..., out_channels)
        out_channels = weight.shape[0]
        in_channels = activated.shape[-1]

        # flatten the leading dims so the parameter grads are plain matrix ops.
        flat_grad = ans_grad.reshape(-1, out_channels)
        weight_deriv = torch.matmul(
            flat_grad.t(), activated.reshape(-1, in_channels)
        )
        bias_deriv = None if bias is None else flat_grad.sum(dim=0)

        # backprop through the linear layer, then through the activation.
        x_deriv = torch.matmul(ans_grad, weight) * activation_deriv
        return x_deriv, weight_deriv, bias_deriv, None, None
+ self.register_parameter("bias", l.bias) + + self.activation = activation + + assert activation in ["SwashL", "SwashR"] + if activation == "SwashL": + self.forward_func = torch_compile(swashl) + self.backward_func = torch_compile(swashl_and_deriv) + else: + self.forward_func = torch_compile(swashr) + self.backward_func = torch_compile(swashr_and_deriv) + + + def forward(self, x: Tensor): + if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + if self.activation == "SwashL": + x = swashl(x) + else: + x = swashr(x) + else: + x = self.forward_func(x) + return torch.nn.functional.linear(x, self.weight, self.bias) + + return ActivationAndLinearFunction.apply( + x, + self.weight, + self.bias, + self.forward_func, + self.backward_func, + ) + + +def _test_swashl_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwashL() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swashr_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwashR() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. 
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_activation_and_linear(): + in_channels = 20 + out_channels = 30 + + for bias in [True, False]: + if True: + for activation in ["SwashL", "SwashR"]: + m1 = nn.Sequential( + SwashL() if activation == "SwashL" else SwashR(), + ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=0.5 + ), + ) + m2 = ActivationAndLinear( + in_channels, + out_channels, + bias=bias, + initial_scale=0.5, + activation=activation, + ) + with torch.no_grad(): + m2.weight[:] = m1[1].weight + if bias: + m2.bias[:] = m1[1].bias + # make sure forward gives same result. + x1 = torch.randn(10, in_channels) + x1.requires_grad = True + + + x2 = x1.clone().detach() + x2.requires_grad = True + seed = 10 + torch.manual_seed(seed) + y1 = m1(x1) + y_grad = torch.randn_like(y1) + y1.backward(gradient=y_grad) + torch.manual_seed(seed) + y2 = m2(x2) + y2.backward(gradient=y_grad) + + print( + f"bias = {bias}, activation = {activation}" + ) + print("y1 = ", y1) + print("y2 = ", y2) + assert torch.allclose(y1, y2, atol=0.02) + print("grad1 = ", m1[1].weight.grad) + print("grad2 = ", m2.weight.grad) + + assert torch.allclose(m1[1].weight.grad, m2.weight.grad, atol=1.0e-05) + if bias: + assert torch.allclose(m1[1].bias.grad, m2.bias.grad, atol=1.0e-05) + print("x1.grad = ", x1.grad) + print("x2.grad = ", x2.grad) + + def isclose(a, b): + # return true if cosine similarity is > 0.9. + return (a * b).sum() > 0.9 * ( + (a**2).sum() * (b**2).sum() + ).sqrt() + + # the SwashL() implementation has a noisy gradient due to 1-byte + # storage of it. 
class SoftmaxFunction(torch.autograd.Function):
    """
    A memory-efficient implementation of softmax that does not require
    storing anything as fp32 in autocast mode: only the output is saved for
    backward, in the autocast dtype when autocast is enabled.
    """
    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        """
        Args:
          x: the input tensor.
          dim: the dimension over which to take the softmax.
        Returns:
          softmax of x over `dim`; in autocast mode it is converted to the
          autocast dtype.
        """
        ans = x.softmax(dim=dim)
        # if x dtype is float16, x.softmax() returns a float32 because
        # (presumably) that op does not support float16, and autocast
        # is enabled.
        if torch.is_autocast_enabled():
            ans = ans.to(torch.get_autocast_gpu_dtype())
        ctx.save_for_backward(ans)
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        with torch.amp.autocast('cuda', enabled=False):
            # Do the arithmetic in at least float32: half-precision values are
            # promoted to float32, while float64 is kept as float64 rather
            # than being silently downcast.
            compute_dtype = torch.promote_types(ans.dtype, torch.float32)
            ans_grad = ans_grad.to(compute_dtype)
            ans = ans.to(compute_dtype)
            # softmax backward: x_grad = ans * (ans_grad - sum(ans_grad * ans))
            x_grad = ans_grad * ans
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None
class LimitParamValue(torch.autograd.Function):
    """
    Identity in the forward pass.  In the backward pass, flips the sign of
    gradient elements wherever the gradient is positive while x < min, or
    negative while x > max.
    """
    @staticmethod
    def forward(ctx, x: Tensor, min: float, max: float):
        """
        Args:
          x: the tensor (typically a parameter) to be limited; returned unchanged.
          min: lower end of the allowed range.
          max: upper end of the allowed range; must be >= min.
        """
        assert max >= min
        ctx.save_for_backward(x)
        ctx.min, ctx.max = min, max
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (x,) = ctx.saved_tensors

        # where x < ctx.min, ensure all grads are negative (this will tend to
        # make x more positive).
        too_low = torch.logical_and(x_grad > 0, x < ctx.min)
        x_grad = x_grad * torch.where(too_low, -1.0, 1.0)

        # where x > ctx.max, ensure all grads are positive (this will tend to
        # make x more negative).  x_grad is a fresh tensor at this point, so
        # the in-place multiply does not modify the caller's gradient.
        too_high = torch.logical_and(x_grad < 0, x > ctx.max)
        x_grad *= torch.where(too_high, -1.0, 1.0)
        return x_grad, None, None
+ if torch.jit.is_scripting(): + return x + if training: + return LimitParamValue.apply(x, min, max) + else: + return x + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:, 0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:, 0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_softmax() diff --git a/egs/librispeech/ASR/zapformer2/.gitignore b/egs/librispeech/ASR/zapformer2/.gitignore new file mode 100644 index 0000000000..e47ac15828 --- /dev/null +++ b/egs/librispeech/ASR/zapformer2/.gitignore @@ -0,0 +1 @@ +swoosh.pdf diff --git a/egs/librispeech/ASR/zipformer/batched_rubik.py b/egs/librispeech/ASR/zipformer/batched_rubik.py new file mode 120000 index 0000000000..5c024cfd72 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/batched_rubik.py @@ -0,0 +1 @@ +../zapformer/batched_rubik.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer/combined_scheduler.py b/egs/librispeech/ASR/zipformer/combined_scheduler.py new file mode 120000 index 0000000000..04a0322459 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/combined_scheduler.py @@ -0,0 +1 @@ +../zapformer/combined_scheduler.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index 6462d22f86..ac6a44ae67 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -98,6 +98,7 @@ import logging import math import os +import re from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -106,7 +107,7 @@ import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from 
def cv_post_processing(
    results: List[Tuple[str, List[str], List[str]]],
) -> List[Tuple[str, List[str], List[str]]]:
    """
    Normalize reference and hypothesis word sequences before scoring.

    For each (key, ref, hyp) entry, the words are joined with spaces, every
    character that is neither a word character nor whitespace is deleted,
    the text is upper-cased, and it is split into words again.  Words that
    consisted only of punctuation therefore disappear.

    Args:
      results: list of (key, reference words, hypothesis words).
    Returns:
      A new list with the same keys, in the same order, and normalized
      word lists.
    """
    def clean_words(words):
        stripped = re.sub(r'[^\w\s]', '', " ".join(words))
        return stripped.upper().split()

    return [
        (key, clean_words(ref), clean_words(hyp))
        for key, ref, hyp in results
    ]
(for eval sets).""", ) + parser.add_argument( + "--giga", + type=str2bool, + default=False, + help="""If True, decode gigaspeech in addition to librispeech test sets.""", + ) + + parser.add_argument( + "--cv", + type=str2bool, + default=False, + help="""If True, decode commonvoice in addition to librispeech test sets.""", + ) + add_model_arguments(parser) return parser @@ -732,6 +821,10 @@ def save_asr_output( recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) + if 'giga' in test_set_name: + results = giga_post_processing(results) + if 'cv' in test_set_name: + results = cv_post_processing(results) store_transcripts(filename=recogs_filename, texts=results) logging.info(f"The transcripts are stored in {recogs_filename}") @@ -759,6 +852,10 @@ def save_wer_results( logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + if 'giga' in test_set_name: + results = giga_post_processing(results) + if 'cv' in test_set_name: + results = cv_post_processing(results) wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" @@ -1044,12 +1141,34 @@ def main(): test_clean_cuts = librispeech.test_clean_cuts() test_other_cuts = librispeech.test_other_cuts() + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] + dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts) + dev_other_dl = librispeech.test_dataloaders(dev_other_cuts) + + test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"] + test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl] + + if args.giga: + gigaspeech = GigaSpeech(args.manifest_dir) + test_cuts = gigaspeech.test_cuts() + dev_cuts = 
gigaspeech.dev_cuts() + giga_test_dl = librispeech.test_dataloaders(test_cuts) + giga_dev_dl = librispeech.test_dataloaders(dev_cuts) + test_sets += ["giga-dev", "giga-test"] + test_dl += [giga_dev_dl, giga_test_dl] + + if args.cv: + commonvoice = CommonVoice(args.manifest_dir) + test_cuts = commonvoice.test_cuts() + dev_cuts = commonvoice.dev_cuts() + cv_test_dl = librispeech.test_dataloaders(test_cuts) + cv_dev_dl = librispeech.test_dataloaders(dev_cuts) + test_sets += ["cv-dev", "cv-test"] + test_dl += [cv_dev_dl, cv_test_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 6ef2508192..c7ee0e5a6d 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -173,11 +173,22 @@ def forward_ctc( # Compute CTC log-prob ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + # the calls to .long() were added as a workaround for a problem with + # torch.nn.functional.ctc_loss() on newer torch versions. Previously + # instead of .long() we had .cpu(). .cpu() activates the use of CUDNN + # because it only uses CUDNN if integer inputs are in int32 and on CPU. + # (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossCTC.cpp#L501) + # But on more recent torch/cuda versions we were getting "RuntimeError: cuDNN error: + # CUDNN_STATUS_EXECUTION_FAILED" if we use the CUDNN implementation. + # We can't use (int32, CUDA) for the integer inputs because the torch implementation of ctc_loss + # seems to have a bug with "int32" integer arguments (it returns infinity), so we call + # .long() to use the torch implementation and avoid that bug. 
ctc_loss = torch.nn.functional.ctc_loss( log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) - targets=targets.cpu(), - input_lengths=encoder_out_lens.cpu(), - target_lengths=target_lengths.cpu(), + targets=targets.long(), + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), reduction="sum", ) return ctc_loss @@ -200,12 +211,22 @@ def forward_cr_ctc( to be un-padded and concatenated within 1 dimension. """ # Compute CTC loss + # the calls to .long() were added as a workaround for a problem with + # torch.nn.functional.ctc_loss() on newer torch versions. Previously + # instead of .long() we had .cpu(). .cpu() activates the use of CUDNN + # because it only uses CUDNN if integer inputs are in int32 and on CPU. + # (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossCTC.cpp#L501) + # But on more recent torch/cuda versions we were getting "RuntimeError: cuDNN error: + # CUDNN_STATUS_EXECUTION_FAILED" if we use the CUDNN implementation. + # We can't use (int32, CUDA) for the integer inputs because the torch implementation of ctc_loss + # seems to have a bug with "int32" integer arguments (it returns infinity), so we call + # .long() to use the torch implementation and avoid that bug. 
ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) ctc_loss = torch.nn.functional.ctc_loss( log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) - targets=targets.cpu(), - input_lengths=encoder_out_lens.cpu(), - target_lengths=target_lengths.cpu(), + targets=targets.long(), + input_lengths=encoder_out_lens.long(), + target_lengths=target_lengths.long(), reduction="sum", ) diff --git a/egs/librispeech/ASR/zipformer/speech_recognition.py b/egs/librispeech/ASR/zipformer/speech_recognition.py new file mode 120000 index 0000000000..cb33e61085 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/speech_recognition.py @@ -0,0 +1 @@ +../zapformer/speech_recognition.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer/train_newoptim.py b/egs/librispeech/ASR/zipformer/train_newoptim.py new file mode 100755 index 0000000000..d20d996194 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/train_newoptim.py @@ -0,0 +1,1612 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Amir Hussein +# Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default) + - ctc loss + - attention decoder loss + - cr-ctc loss (should use half the max-duration compared to regular ctc) +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from attention_decoder import AttentionDecoderModel +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset import SpecAugment +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from batched_rubik import BatchedRubik +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env 
import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, + torch_autocast, +) + + +from combined_scheduler import CombinedLRScheduler +from combined_scheduler import InterpCosineLRScheduler + + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler, CombinedLRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + 
type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. 
+ """, + ) + + parser.add_argument( + "--attention-decoder-dim", + type=int, + default=512, + help="""Dimension used in the attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-layers", + type=int, + default=6, + help="""Number of transformer layers used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-attention-dim", + type=int, + default=512, + help="""Attention dimension used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-heads", + type=int, + default=8, + help="""Number of attention heads used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-feedforward-dim", + type=int, + default=2048, + help="""Feedforward dimension used in attention decoder""", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + + parser.add_argument( + "--use-cr-ctc", + type=str2bool, + default=False, + help="If True, use consistency-regularized CTC.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--batches-per-epoch", + type=int, + default=2200, + help="Assumed number of batches per epoch for purposes of setting learning rate; only " + "makes a difference during the first batch, after which an observed value is used. This " + "is the num batches where num_copies==1, i.e. 
on the first epoch" + ) + + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.02, help="The base learning rate." + ) + + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.2, + help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--time-mask-ratio", + type=float, + default=2.5, + help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.", + ) + + parser.add_argument( + "--attention-decoder-loss-scale", + type=float, + default=0.8, + help="Scale for attention-decoder loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. 
+ """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. 
+ + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for attention-decoder + "ignore_id": -1, + "label_smoothing": 0.1, + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. 
+ encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=params.attention_decoder_dim, + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=params.attention_decoder_attention_dim, + num_heads=params.attention_decoder_num_heads, + 
feedforward_dim=params.attention_decoder_feedforward_dim, + memory_dim=max(_to_int_tuple(params.encoder_dim)), + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + if params.use_attention_decoder: + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + attention_decoder=attention_decoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, + ) + return model + + +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = int(10 * params.time_mask_ratio) + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: 
Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional["GradScaler"] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + spec_augment: Optional[SpecAugment] = None, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. 
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + use_cr_ctc = params.use_cr_ctc + use_spec_aug = use_cr_ctc and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + use_cr_ctc=use_cr_ctc, + use_spec_aug=use_spec_aug, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=params.spec_aug_time_warp_factor, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. 
+ simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + if use_cr_ctc: + # linear warmup + cr_loss_scale = min(batch_idx_train / warm_step, 1.0) * params.cr_loss_scale + loss += cr_loss_scale * cr_loss + + if params.use_attention_decoder: + loss += params.attention_decoder_loss_scale * attention_decoder_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_cr_ctc: + info["cr_loss"] = cr_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / 
tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: "GradScaler", + spec_augment: Optional[SpecAugment] = None, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch_autocast(enabled=params.use_autocast, dtype=params.dtype): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + spec_augment=spec_augment, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
+ scaler.scale(loss).backward() + scheduler.set_batch(batch_idx) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: + logging.info(f"Caught exception: {e}.") + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if params.use_autocast: + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + if not params.inf_check: + register_inf_check_hooks(model) + logging.warning(f"Grad scale is small: {cur_grad_scale}") + + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ if ( + batch_idx % 25 == 0 + and cur_grad_scale < 2.0 + or batch_idx % 100 == 0 + and cur_grad_scale < 8.0 + or batch_idx % 400 == 0 + and cur_grad_scale < 32.0 + ): + scaler.update(cur_grad_scale * 2.0) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_autocast: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. 
+ The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.sos_id = params.eos_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + if not params.use_attention_decoder: + params.ctc_loss_scale = 1.0 + else: + assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( + params.ctc_loss_scale, + params.attention_decoder_loss_scale, + ) + + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" 
+ assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + if params.use_cr_ctc: + assert params.use_ctc + assert not params.enable_spec_aug # we will do spec_augment in model.py + spec_augment = get_spec_augment(params) + else: + spec_augment = None + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = BatchedRubik( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=False), + lr=params.base_lr, + beta1=0.99, + ) + + + # this InterpCosineLRScheduler inherits from VariableCombinedLRScheduler. + # this configuration is halfway between a linear function (1 to 0) and the conventional + # cosine LR scheduler. It decays to a minimum of 0.025. 
+ scheduler = InterpCosineLRScheduler(optimizer, + min_factor=0.025, + linear_scale=0.5, + batches_per_epoch=params.batches_per_epoch, + num_epochs=params.num_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + # For CTC `(T - 2) < len(tokens)` is needed. otherwise inf. in loss appears. + # For Transducer `T < len(tokens)` was okay. + if (T - 2) < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training (too many supervision tokens). " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + spec_augment=spec_augment, + ) + + scaler = create_grad_scaler(enabled=params.use_autocast, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.set_epoch(epoch) + 
fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + spec_augment=spec_augment, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + spec_augment: Optional[SpecAugment] = None, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch_autocast(enabled=params.use_autocast, dtype=params.dtype): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + spec_augment=spec_augment, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index d923e88424..bdf3e02dcc 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -222,12 +222,17 @@ def accumulate(self, x, class_name: Optional[str] = None): else: this_dim_stats[stats_type].append(TensorAndCount(stats, count)) - def print_diagnostics(self): - """Print diagnostics for each dimension of the tensor.""" + def print_diagnostics(self) -> dict: + """Print diagnostics for each dimension of the tensor. Returns a dict containing more specific stats, as tensors, that can be used for further + analysis if needed""" if self.stats is None: print(f"Warning: the stats of {self.name} is None.") return + + ans_dict = dict() + for dim, this_dim_stats in enumerate(self.stats): + ans_dict[dim] = dict() if "rms" in this_dim_stats and "value" in this_dim_stats: # produce "stddev" stats, which is centered RMS. rms_stats_list = this_dim_stats["rms"] @@ -286,6 +291,8 @@ def get_count(count): # we stored the square; after aggregation we need to take sqrt. stats = stats.sqrt() + ans_dict[dim][stats_type] = stats + # if `summarize` we print percentiles of the stats; else, # we print out individual elements. summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized( @@ -314,6 +321,7 @@ def get_count(count): # can be attributed to the mean of the distribution. 
norm = (stats**2).sum().sqrt().item() ans += f", norm={norm:.2g}" + mean = stats.mean().item() rms = (stats**2).mean().sqrt().item() ans += f", mean={mean:.3g}, rms={rms:.3g}" @@ -331,184 +339,9 @@ def get_count(count): print( f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}" ) + return ans_dict -class ScalarDiagnostic(object): - """This class is not directly used by the user, it is responsible for - collecting diagnostics for a single module (subclass of torch.nn.Module) that - represents some kind of nonlinearity, e.g. ReLU, sigmoid, etc. - """ - - def __init__(self, opts: TensorDiagnosticOptions, name: str): - self.opts = opts - self.name = name - self.class_name = None # will assign in accumulate() - self.is_forward_pass = True - - self.tick_scale = None - - self.saved_inputs = [] - self.is_ok = True - - self.counts = None - self.sum_grad = None - self.sum_gradsq = None - self.sum_abs_grad = None - - def accumulate_input(self, x: Tensor, class_name: Optional[str] = None): - """ - Called in forward pass. - """ - if not self.is_forward_pass: - # in case we did a forward pass without a backward pass, for some reason. - self.saved_inputs = [] - self.is_forward_pass = True - - if class_name is not None: - self.class_name = class_name - if not self.is_ok: - return - - limit = 10 - if len(self.saved_inputs) > limit: - print( - f"ERROR: forward pass called for this module over {limit} times with no backward pass. " - f" Will not accumulate scalar stats." 
- ) - self.is_ok = False - return - self.saved_inputs.append(x) - - def accumulate_output_grad(self, grad: Tensor): - if not self.is_ok: - return - if self.is_forward_pass: - self.is_forward_pass = False - - last_shape = ( - "n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape - ) - if len(self.saved_inputs) == 0 or grad.shape != last_shape: - print( - f"ERROR: shape mismatch or no forward activation present when backward " - f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}" - f", shape-of-last-saved-input={last_shape}" - ) - self.is_ok = False - return - - x = self.saved_inputs.pop() - self.process_input_and_grad(x, grad) - - def process_input_and_grad(self, x: Tensor, grad: Tensor): - assert x.shape == grad.shape - x = x.flatten() - grad = grad.flatten() - - num_ticks_per_side = 256 - - if self.tick_scale is None: - x_abs_sorted = x.abs().sort()[0] - # take the 98th percentile as the largest value we count separately. - index = int(x.numel() * 0.98) - self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side) - - # integerize from tick * (-num ticks_per_side .. num_ticks_per_side - 1] - self.counts = torch.zeros( - 2 * num_ticks_per_side, dtype=torch.long, device=x.device - ) - self.sum_grad = torch.zeros( - 2 * num_ticks_per_side, dtype=torch.double, device=x.device - ) - # sum_gradsq is for getting error bars. - self.sum_gradsq = torch.zeros( - 2 * num_ticks_per_side, dtype=torch.double, device=x.device - ) - self.sum_abs_grad = torch.zeros( - 2 * num_ticks_per_side, dtype=torch.double, device=x.device - ) - - # this will round down. 
- x = (x / self.tick_scale).to(torch.long) - x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1) - x = x + num_ticks_per_side - - self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x)) - self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double)) - self.sum_gradsq.index_add_( - dim=0, index=x, source=(grad * grad).to(torch.double) - ) - self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double)) - - def print_diagnostics(self): - """Print diagnostics.""" - if self.is_ok is False or self.counts is None: - print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}") - return - - counts = self.counts.to("cpu") - sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32) - sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32) - sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32) - - counts_cumsum = counts.cumsum(dim=0) - counts_tot = counts_cumsum[-1] - - # subdivide the distribution up into `num_bins` intervals for analysis, for greater - # statistical significance. each bin corresponds to multiple of the original 'tick' intervals. 
- num_bins = 20 - - # integer division - counts_per_bin = (counts_tot // num_bins) + 1 - bin_indexes = counts_cumsum // counts_per_bin - bin_indexes = bin_indexes.clamp(min=0, max=num_bins).to(torch.long) - - bin_counts = torch.zeros(num_bins, dtype=torch.long) - bin_counts.index_add_(dim=0, index=bin_indexes, source=counts) - bin_grad = torch.zeros(num_bins) - bin_grad.index_add_(dim=0, index=bin_indexes, source=sum_grad) - bin_gradsq = torch.zeros(num_bins) - bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq) - bin_abs_grad = torch.zeros(num_bins) - bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad) - - avg_grad = bin_grad / bin_counts - avg_grad_stddev = (bin_gradsq / bin_counts).sqrt() - - bin_boundary_counts = ( - torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin - ) - bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts) - # boundaries are the "x" values between the bins, e.g. corresponding to the - # locations of percentiles of the distribution. - num_ticks_per_side = counts.numel() // 2 - bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale - - bin_grad = bin_grad / (bin_counts + 1) - bin_conf_interval = bin_gradsq.sqrt() / ( - bin_counts + 1 - ) # consider this a standard deviation. - # bin_grad / bin_abs_grad will give us a sense for how important in a practical sense, - # the gradients are. 
- bin_abs_grad = bin_abs_grad / (bin_counts + 1) - - bin_rel_grad = bin_grad / (bin_abs_grad + 1.0e-20) - bin_conf = bin_grad / (bin_conf_interval + 1.0e-20) - - def tensor_to_str(x: Tensor): - x = ["%.2g" % f for f in x] - x = "[" + " ".join(x) + "]" - return x - - maybe_class_name = ( - f" type={self.class_name}," if self.class_name is not None else "" - ) - - print( - f"module={self.name},{maybe_class_name} bin-boundaries={tensor_to_str(bin_boundaries)}, " - f"rel_grad={tensor_to_str(bin_rel_grad)}, grad_conf={tensor_to_str(bin_conf)}" - ) - class ModelDiagnostic(object): """This class stores diagnostics for all tensors in the torch.nn.Module. @@ -528,15 +361,17 @@ def __init__(self, opts: Optional[TensorDiagnosticOptions] = None): self.diagnostics = dict() def __getitem__(self, name: str): - T = ScalarDiagnostic if name[-7:] == ".scalar" else TensorDiagnostic if name not in self.diagnostics: - self.diagnostics[name] = T(self.opts, name) + self.diagnostics[name] = TensorDiagnostic(self.opts, name) return self.diagnostics[name] - def print_diagnostics(self): - """Print diagnostics for each tensor.""" + def print_diagnostics(self) -> dict: + """Print diagnostics for each tensor. Returns dict with more detailed per-dimension info + that could be further analyzed.""" + ans = dict() for k in sorted(self.diagnostics.keys()): - self.diagnostics[k].print_diagnostics() + ans[k] = self.diagnostics[k].print_diagnostics() + return ans def get_class_name(module: nn.Module): @@ -636,42 +471,6 @@ def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): else: module.register_backward_hook(backward_hook) - if type(module).__name__ in [ - "Sigmoid", - "Tanh", - "ReLU", - "TanSwish", - "Swish", - "DoubleSwish", - "Swoosh", - ]: - # For these specific module types, accumulate some additional diagnostics - # that can help us improve the activation function. 
These require a lot of memory, - # to save the forward activations, so limit this to some select classes. - # Note: this will not work correctly for all model types. - def scalar_forward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name - ): - if isinstance(_input, tuple): - (_input,) = _input - assert isinstance(_input, Tensor) - _model_diagnostic[f"{_name}.scalar"].accumulate_input( - _input, class_name=get_class_name(_module) - ) - - def scalar_backward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name - ): - if isinstance(_output, tuple): - (_output,) = _output - assert isinstance(_output, Tensor) - _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output) - - module.register_forward_hook(scalar_forward_hook) - if hasattr(module, "register_full_backward_hook"): - module.register_full_backward_hook(scalar_backward_hook) - else: - module.register_backward_hook(scalar_backward_hook) for name, parameter in model.named_parameters(): diff --git a/icefall/utils.py b/icefall/utils.py index 0d4e24db53..6f3bdd17e4 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -43,7 +43,6 @@ import torch import torch.distributed as dist import torch.nn as nn -from lhotse.dataset.signal_transforms import time_warp as time_warp_impl from packaging import version from pypinyin import lazy_pinyin, pinyin from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials @@ -159,6 +158,12 @@ def str2bool(v): raise argparse.ArgumentTypeError("Boolean value expected.") +def dist_barrier() -> None: + if dist.is_available() and dist.is_initialized(): + world_size = dist.get_world_size() + if world_size > 1: + dist.barrier() + def setup_logger( log_filename: Pathlike, log_level: str = "info", @@ -648,7 +653,7 @@ def store_translations( hyp_list = [] ref_list = [] dir_ = os.path.dirname(filename) - reftgt = os.path.join(dir_, "reftgt-" + str(os.path.basename(filename))) + reftgt = os.path.join(dir_, "reftgt-" + 
str(os.path.basename(filename))) refsrc = os.path.join(dir_, "refsrc-"+str(os.path.basename(filename))) hyp = os.path.join(dir_, "hyp-"+str( os.path.basename(filename))) bleu_file = os.path.join(dir_, "bleu-"+str( os.path.basename(filename))) @@ -661,7 +666,7 @@ def store_translations( print(f"{cut_id}: ref_tgt {ref_tgt}", file=f) print(f"{cut_id}: hyp {hyp}", file=f) print("\n", file=f) - + print(f"{ref}", file=f_src) print(f"{ref_tgt}", file=f_tgt) @@ -673,7 +678,7 @@ def store_translations( with open(bleu_file, 'w') as b: print(str(bleu.corpus_score(hyp_list, [ref_list])), file=b) print(f"BLEU signiture: {str(bleu.get_signature())}", file=b) - + logging.info( f"[{bleu.corpus_score(hyp_list, [ref_list])}] " f"BLEU signiture: {str(bleu.get_signature())}" @@ -2419,33 +2424,80 @@ def num_tokens( return num_tokens +def time_warp_impl(features: torch.Tensor, factor: int) -> torch.Tensor: + """ + # modified from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py#L338C1-L369C1 + # to use torch rng rather than the numpy one, this has to do with which rngs + # are synchronized and which are not. (we keep the numpy and python rng's synchronized + # for the sake of lhotse's sampler code, where they need to be synchronized to avoid data + # overlap). + + Time warping as described in the SpecAugment paper. + Implementation based on Espresso: + https://github.com/freewym/espresso/blob/master/espresso/tools/specaug_interpolate.py#L51 + + :param features: input tensor of shape ``(T, F)`` + :param factor: time warping parameter. 
+ :return: a warped tensor of shape ``(T, F)`` + """ + t = features.size(0) + if t - factor <= factor + 1: + return features + center = torch.randint(factor + 1, t - factor, ()).item() + warped = torch.randint(center - factor, center + factor + 1, ()).item() + if warped == center: + return features + features = features.unsqueeze(0).unsqueeze(0) + left = torch.nn.functional.interpolate( + features[:, :, :center, :], + size=(warped, features.size(3)), + mode="bicubic", + align_corners=False, + ) + right = torch.nn.functional.interpolate( + features[:, :, center:, :], + size=(t - warped, features.size(3)), + mode="bicubic", + align_corners=False, + ) + return torch.cat((left, right), dim=2).squeeze(0).squeeze(0) + + # Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py def time_warp( features: torch.Tensor, p: float = 0.9, time_warp_factor: Optional[int] = 80, supervision_segments: Optional[torch.Tensor] = None, + feature_lens: Optional[torch.Tensor] = None, ): - """Apply time warping on a batch of features""" + """Apply time warping on a batch of features + supervision_segments and feature_lens are two alternative ways of specifying the parts of the feature matrix to + warp, see the code for details. + """ if time_warp_factor is None or time_warp_factor < 1: return features assert ( len(features.shape) == 3 ), f"SpecAugment only supports batches of single-channel feature matrices. {features.shape}" features = features.clone() - if supervision_segments is None: + + # we use torch.rand(1).item() instead of random.random() because for lhotse reasons we keep the + # python RNG synchronized across ranks, but we keep the torch RNG desynchronized. + if supervision_segments is None and feature_lens is None: # No supervisions - apply spec augment to full feature matrices. 
for sequence_idx in range(features.size(0)): - if random.random() > p: + if torch.rand(1).item() > p: # Randomly choose whether this transform is applied continue features[sequence_idx] = time_warp_impl( features[sequence_idx], factor=time_warp_factor ) - else: + elif supervision_segments is not None: + assert feature_lens is None # Supervisions provided - we will apply time warping only on the supervised areas. for sequence_idx, start_frame, num_frames in supervision_segments: - if random.random() > p: + if torch.rand(1).item() > p: # Randomly choose whether this transform is applied continue end_frame = start_frame + num_frames @@ -2453,4 +2505,13 @@ def time_warp( features[sequence_idx, start_frame:end_frame], factor=time_warp_factor ) + else: + for sequence_idx, num_frames in enumerate(feature_lens): + if torch.rand(1).item() > p: + # Randomly choose whether this transform is applied + continue + features[sequence_idx, :num_frames] = time_warp_impl( + features[sequence_idx, :num_frames], factor=time_warp_factor + ) + return features