From 4856ae0919aafbb239ed67d6727ec8ed8b163a79 Mon Sep 17 00:00:00 2001 From: Yifan Yang Date: Sun, 21 Jun 2026 19:59:36 +0800 Subject: [PATCH 1/4] fix comment for gigaspeech recipe --- egs/gigaspeech/ASR/prepare.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh index ef6a667f96..9006fe7b09 100755 --- a/egs/gigaspeech/ASR/prepare.sh +++ b/egs/gigaspeech/ASR/prepare.sh @@ -165,7 +165,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "State 3: Preprocess GigaSpeech manifest" + log "Stage 3: Preprocess GigaSpeech manifest" if [ ! -f data/fbank/.preprocess_complete ]; then python3 ./local/preprocess_gigaspeech.py touch data/fbank/.preprocess_complete @@ -173,7 +173,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute features for DEV, TEST, L, M, S, and XS subsets of GigaSpeech." + log "Stage 4: Compute features for DEV and TEST subsets of GigaSpeech." python3 ./local/compute_fbank_gigaspeech.py fi From 612806874789791e61bdacc105c7f1eddfdd489f Mon Sep 17 00:00:00 2001 From: Yifan Yang Date: Sun, 21 Jun 2026 22:34:40 +0800 Subject: [PATCH 2/4] Enable on-the-fly features and skip musan by default for gigaspeech - Add `--on-the-fly` (default true) to compute_fbank_gigaspeech.py and compute_fbank_gigaspeech_splits.py: skip storing fbank features and only produce the trimmed cut manifests, since zipformer extracts features on-the-fly during training. - Parallelize the on-the-fly split trimming with a process pool (the GPU feature-compute path stays serial). - Add `use_musan` (default false) to prepare.sh, skipping the musan download/manifest/fbank stages (0, 2, 7) unless `--use-musan true`. - Default `--enable-musan` to false in zipformer/asr_datamodule.py to match. - Tune train dataloader for on-the-fly: num_workers default 8, add `--prefetch-factor` (default 4, was hardcoded 16), guard num_workers=0. - Fix stage log typo and the stage 4 subset description. Co-Authored-By: Claude Opus 4.8 --- .../ASR/local/compute_fbank_gigaspeech.py | 51 +++++++++++++++---- .../local/compute_fbank_gigaspeech_splits.py | 48 ++++++++++++++++- egs/gigaspeech/ASR/prepare.sh | 34 ++++++++----- .../ASR/zipformer/asr_datamodule.py | 18 +++++-- 4 files changed, 123 insertions(+), 28 deletions(-) diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py index 14353008ce..0f6017ce3f 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py @@ -16,12 +16,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import logging from pathlib import Path import torch from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig +from icefall.utils import str2bool + # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect @@ -30,7 +33,22 @@ torch.set_num_interop_threads(1) -def compute_fbank_gigaspeech(): +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--on-the-fly", + type=str2bool, + default=True, + help="When True, do not compute and store fbank features; only " + "produce the trimmed cut manifests so that features are extracted " + "on-the-fly during training.", + ) + return parser.parse_args() + + +def compute_fbank_gigaspeech(on_the_fly: bool = True): in_out_dir = Path("data/fbank") # number of workers in dataloader @@ -51,7 +69,11 @@ def compute_fbank_gigaspeech(): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + + # on-the-fly mode does not need the extractor (and kaldifeat may be absent) + extractor = None + if not on_the_fly: + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) logging.info(f"device: {device}") @@ -66,15 +88,21 @@ def compute_fbank_gigaspeech(): logging.info(f"Loading {raw_cuts_path}") cut_set = CutSet.from_file(raw_cuts_path) - logging.info("Computing features") + if on_the_fly: + logging.info( + "on-the-fly is enabled - skipping feature extraction, " + "only saving the trimmed cut manifest" + ) + else: + logging.info("Computing features") + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}", + num_workers=num_workers, + batch_duration=batch_duration, + overwrite=True, + ) - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}", - num_workers=num_workers, - batch_duration=batch_duration, - overwrite=True, - ) cut_set = cut_set.trim_to_supervisions( keep_overlapping=False, min_duration=None ) @@ -88,7 +116,8 @@ def main(): formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_gigaspeech() + args = get_args() + compute_fbank_gigaspeech(on_the_fly=args.on_the_fly) if __name__ == "__main__": diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index c1645f7cc9..5ed78015c0 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -19,11 +19,14 @@ import argparse import logging import os +from concurrent.futures import ProcessPoolExecutor from pathlib import Path import torch from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig +from icefall.utils import str2bool + # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect @@ -71,9 +74,36 @@ def get_args(): default=-1, help="Stop processing pieces until this number (exclusive).", ) + + parser.add_argument( + "--on-the-fly", + type=str2bool, + default=True, + help="When True, do not compute and store fbank features; only " + "produce the trimmed cut manifests so that features are extracted " + "on-the-fly during training.", + ) return parser.parse_args() +def trim_split(idx: str, output_dir: Path) -> None: + """Trim one raw split to supervisions and save it (no feature extraction).""" + cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + return + + raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz" + if not raw_cuts_path.is_file(): + logging.info(f"{raw_cuts_path} does not exist - skipping it") + return + + cut_set = CutSet.from_file(raw_cuts_path) + cut_set = cut_set.trim_to_supervisions(keep_overlapping=False, min_duration=None) + cut_set.to_file(cuts_path) + logging.info(f"Saved to {cuts_path}") + + def compute_fbank_gigaspeech_splits(args): num_splits = args.num_splits output_dir = "data/fbank/gigaspeech_XL_split" @@ -87,13 +117,29 @@ def compute_fbank_gigaspeech_splits(args): stop = min(stop, num_splits) + num_digits = 8 # num_digits is fixed by lhotse split-lazy + + if args.on_the_fly: + # on-the-fly does not compute features, so the per-split work is pure + # CPU and independent -- fan it out across processes instead of the + # serial loop the GPU path needs. + logging.info( + f"on-the-fly is enabled - trimming splits in parallel " + f"with {args.num_workers} workers" + ) + idxs = [f"{i}".zfill(num_digits) for i in range(start, stop)] + with ProcessPoolExecutor(max_workers=args.num_workers) as ex: + futures = [ex.submit(trim_split, idx, output_dir) for idx in idxs] + for f in futures: + f.result() + return + device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) logging.info(f"device: {device}") - num_digits = 8 # num_digits is fixed by lhotse split-lazy for i in range(start, stop): idx = f"{i}".zfill(num_digits) logging.info(f"Processing {idx}/{num_splits}") diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh index 9006fe7b09..515011b3d7 100755 --- a/egs/gigaspeech/ASR/prepare.sh +++ b/egs/gigaspeech/ASR/prepare.sh @@ -15,6 +15,13 @@ stop_stage=8 start=0 stop=-1 # -1 means until the end +# If true, skip storing GigaSpeech fbank features; only produce the (trimmed) +# cut manifests for on-the-fly feature extraction during training. +on_the_fly=true + +# If false (default), skip all musan steps (download, manifest, fbank). +use_musan=false + # Note: This script just prepares the minimal requirements needed by a # transducer training with bpe units. # @@ -138,7 +145,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then # # ln -svf /path/to/musan $dl_dir/ # - if [ ! -d $dl_dir/musan ]; then + if [ $use_musan == true ] && [ ! -d $dl_dir/musan ]; then lhotse download musan $dl_dir fi fi @@ -156,12 +163,12 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then $dl_dir/GigaSpeech data/manifests fi -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to $dl_dir/musan - mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests +if [ $use_musan == true ] && [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to $dl_dir/musan + mkdir -p data/manifests + lhotse prepare musan $dl_dir/musan data/manifests fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then @@ -174,7 +181,7 @@ fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Compute features for DEV and TEST subsets of GigaSpeech." - python3 ./local/compute_fbank_gigaspeech.py + python3 ./local/compute_fbank_gigaspeech.py --on-the-fly $on_the_fly fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then @@ -196,13 +203,14 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then --batch-duration 600 \ --num-splits $num_splits \ --start $start \ - --stop $stop + --stop $stop \ + --on-the-fly $on_the_fly fi -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Compute fbank for musan" - mkdir -p data/fbank - ./local/compute_fbank_musan.py +if [ $use_musan == true ] && [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Compute fbank for musan" + mkdir -p data/fbank + ./local/compute_fbank_musan.py fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py index 93a41b27a4..a32e3aae0e 100644 --- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -164,11 +164,19 @@ def add_arguments(cls, parser: argparse.ArgumentParser): group.add_argument( "--num-workers", type=int, - default=2, + default=8, help="The number of training dataloader workers that " "collect the batches.", ) + group.add_argument( + "--prefetch-factor", + type=int, + default=4, + help="Number of batches each worker prefetches in advance. " + "Ignored when --num-workers is 0.", + ) + group.add_argument( "--enable-spec-aug", type=str2bool, @@ -189,7 +197,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): group.add_argument( "--enable-musan", type=str2bool, - default=True, + default=False, help="When enabled, select noise from MUSAN and mix it" "with training dataset. ", ) @@ -343,8 +351,12 @@ def train_dataloaders( sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, - persistent_workers=False, + persistent_workers=self.args.num_workers > 0, worker_init_fn=worker_init_fn, + pin_memory=True, + prefetch_factor=self.args.prefetch_factor + if self.args.num_workers > 0 + else None, ) return train_dl From ab2f2171852e4f25eb83441aafd27df8feb2fe5f Mon Sep 17 00:00:00 2001 From: Yifan Yang Date: Sun, 21 Jun 2026 23:02:42 +0800 Subject: [PATCH 3/4] support bf16 for gigaspeech recipe update training update update update update update Update egs/gigaspeech/ASR/zipformer/train.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Update ctc_decode.py Update ctc_decode.py --- .../ASR/zipformer/asr_datamodule.py | 9 +- egs/gigaspeech/ASR/zipformer/ctc_decode.py | 135 ++++++++++++++---- .../ASR/zipformer/run_train_cluster.sh | 44 ++++++ egs/gigaspeech/ASR/zipformer/train.py | 120 ++++++++++++---- 4 files changed, 250 insertions(+), 58 deletions(-) create mode 100755 egs/gigaspeech/ASR/zipformer/run_train_cluster.sh diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py index a32e3aae0e..45fa19d0e6 100644 --- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -134,7 +134,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): group.add_argument( "--on-the-fly-feats", type=str2bool, - default=False, + default=True, help="When enabled, use on-the-fly cut mixing and feature " "extraction. Will drop existing precomputed feature manifests " "if available.", @@ -172,7 +172,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): group.add_argument( "--prefetch-factor", type=int, - default=4, + default=8, help="Number of batches each worker prefetches in advance. " "Ignored when --num-workers is 0.", ) @@ -401,8 +401,11 @@ def valid_dataloaders( validate, sampler=valid_sampler, batch_size=None, - num_workers=2, + num_workers=self.args.num_workers, persistent_workers=False, + prefetch_factor=self.args.prefetch_factor + if self.args.num_workers > 0 + else None, ) return valid_dl diff --git a/egs/gigaspeech/ASR/zipformer/ctc_decode.py b/egs/gigaspeech/ASR/zipformer/ctc_decode.py index c28abf0206..e735c9cf30 100755 --- a/egs/gigaspeech/ASR/zipformer/ctc_decode.py +++ b/egs/gigaspeech/ASR/zipformer/ctc_decode.py @@ -21,7 +21,25 @@ """ Usage: -(1) ctc-decoding +(1) ctc-greedy-search +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-greedy-search + +(2) ctc-prefix-beam-search +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-prefix-beam-search + +(3) ctc-decoding ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -30,7 +48,7 @@ --max-duration 600 \ --decoding-method ctc-decoding -(2) 1best +(4) 1best ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -40,7 +58,7 @@ --hlg-scale 0.6 \ --decoding-method 1best -(3) nbest +(5) nbest ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -50,7 +68,7 @@ --hlg-scale 0.6 \ --decoding-method nbest -(4) nbest-rescoring +(6) nbest-rescoring ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -62,7 +80,7 @@ --lm-dir data/lm \ --decoding-method nbest-rescoring -(5) whole-lattice-rescoring +(7) whole-lattice-rescoring ./zipformer/ctc_decode.py \ --epoch 30 \ --avg 15 \ @@ -88,6 +106,7 @@ import torch import torch.nn as nn from asr_datamodule import GigaSpeechAsrDataModule +from gigaspeech_scoring import asr_text_post_processing from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( @@ -97,6 +116,8 @@ load_checkpoint, ) from icefall.decode import ( + ctc_greedy_search, + ctc_prefix_beam_search, get_lattice, nbest_decoding, nbest_oracle, @@ -195,6 +216,10 @@ def get_parser(): default="ctc-decoding", help="""Decoding method. Supported values are: + - ctc-greedy-search. Use CTC greedy search. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + - ctc-prefix-beam-search. Extract n paths with the given beam; the best + path is the decoding result. - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece model, i.e., lang_dir/bpe.model, to convert word pieces to words. It needs neither a lexicon nor an n-gram LM. @@ -269,6 +294,7 @@ def get_decoding_params() -> AttributeDict: "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, + "beam": 4, # for ctc-prefix-beam-search } ) return params @@ -327,10 +353,7 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. Note: If it decodes to nothing, then return None. """ - if HLG is not None: - device = HLG.device - else: - device = H.device + device = params.device feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) @@ -352,6 +375,26 @@ def decode_one_batch( encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) ctc_output = model.ctc_output(encoder_out) # (N, T, C) + if params.decoding_method == "ctc-greedy-search": + token_ids = ctc_greedy_search(ctc_output, encoder_out_lens) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + return {"ctc-greedy-search": hyps} + + if params.decoding_method == "ctc-prefix-beam-search": + token_ids = ctc_prefix_beam_search( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + ) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + return {"ctc-prefix-beam-search": hyps} + supervision_segments = torch.stack( ( supervisions["sequence_idx"], @@ -559,6 +602,17 @@ def decode_dataset( return results +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + def save_results( params: AttributeDict, test_set_name: str, @@ -566,22 +620,31 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) test_set_wers[key] = wer logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -610,6 +673,8 @@ def main(): params.update(vars(args)) assert params.decoding_method in ( + "ctc-greedy-search", + "ctc-prefix-beam-search", "ctc-decoding", "1best", "nbest", @@ -634,6 +699,9 @@ def main(): params.suffix += f"-chunk-{params.chunk_size}" params.suffix += f"-left-context-{params.left_context_frames}" + if "prefix-beam-search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -643,6 +711,7 @@ def main(): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) + params.device = device logging.info(f"Device: {device}") logging.info(params) @@ -655,20 +724,28 @@ def main(): # and are defined in local/train_bpe_model.py params.blank_id = 0 - if params.decoding_method == "ctc-decoding": + if params.decoding_method in [ + "ctc-greedy-search", + "ctc-prefix-beam-search", + "ctc-decoding", + ]: HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) + H = None + if params.decoding_method == "ctc-decoding": + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) bpe_model = spm.SentencePieceProcessor() bpe_model.load(str(params.lang_dir / "bpe.model")) else: H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -707,7 +784,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": @@ -813,16 +892,16 @@ def main(): args.return_cuts = True gigaspeech = GigaSpeechAsrDataModule(args) - test_clean_cuts = gigaspeech.test_clean_cuts() - test_other_cuts = gigaspeech.test_other_cuts() + dev_cuts = gigaspeech.dev_cuts() + test_cuts = gigaspeech.test_cuts() - test_clean_dl = gigaspeech.test_dataloaders(test_clean_cuts) - test_other_dl = gigaspeech.test_dataloaders(test_other_cuts) + dev_dl = gigaspeech.test_dataloaders(dev_cuts) + test_dl = gigaspeech.test_dataloaders(test_cuts) - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] - for test_set, test_dl in zip(test_sets, test_dl): + for test_set, test_dl in zip(test_sets, test_dls): results_dict = decode_dataset( dl=test_dl, params=params, diff --git a/egs/gigaspeech/ASR/zipformer/run_train_cluster.sh b/egs/gigaspeech/ASR/zipformer/run_train_cluster.sh new file mode 100755 index 0000000000..b9b6d844c0 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/run_train_cluster.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +# +# Launch zipformer training with torchrun (single- or multi-node). +# train.py reads RANK / LOCAL_RANK / WORLD_SIZE from the env, so +# `--world-size` is ignored here. +# +# Single 8-GPU node: bash zipformer/run_train_cluster.sh +# Multi-node: the scheduler is expected to export WORLD_SIZE (#nodes), +# RANK (node rank), MASTER_ADDR and MASTER_PORT. + +set -euo pipefail + +export PYTHONPATH=`pwd`/../../../ +export NPROC_PER_NODE=${NPROC_PER_NODE:-8} # GPUs per node +NNODES=${WORLD_SIZE:-1} +NODE_RANK=${RANK:-0} +MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} +MASTER_PORT=${MASTER_PORT:-29500} + +export NCCL_IB_TC=136 +export NCCL_IB_SL=5 +export NCCL_IB_GID_INDEX=3 +export NCCL_SOCKET_IFNAME=eth +export NCCL_DEBUG=WARN +export NCCL_IB_HCA=mlx5 +export NCCL_IB_TIMEOUT=22 +export NCCL_IB_QPS_PER_CONNECTION=8 +export NCCL_MIN_NCHANNELS=4 +export NCCL_NET_PLUGIN=none +export OMP_NUM_THREADS=4 +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +torchrun \ + --nnodes "${NNODES}" \ + --node_rank "${NODE_RANK}" \ + --nproc_per_node "${NPROC_PER_NODE}" \ + --master_addr "${MASTER_ADDR}" \ + --master_port "${MASTER_PORT}" \ + ./zipformer/train.py \ + --num-epochs 100 \ + --start-epoch 1 \ + --use-bf16 1 \ + --exp-dir zipformer/exp \ + --max-duration 5000 diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index d586fc26a8..cefe7b6148 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Copyright 2021-2026 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang, # Mingshuang Luo, # Zengwei Yao, @@ -29,19 +29,28 @@ --world-size 8 \ --num-epochs 30 \ --start-epoch 1 \ - --use-fp16 1 \ + --use-bf16 1 \ --exp-dir zipformer/exp \ - --max-duration 1000 + --max-duration 5000 + +# For multi-node / cluster training, launch with torchrun +# (`--world-size` is ignored; world size comes from the env): +torchrun --nproc_per_node 8 ./zipformer/train.py \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-bf16 1 \ + --exp-dir zipformer/exp \ + --max-duration 5000 # For streaming model training: ./zipformer/train.py \ --world-size 8 \ --num-epochs 30 \ --start-epoch 1 \ - --use-fp16 1 \ + --use-bf16 1 \ --exp-dir zipformer/exp \ --causal 1 \ - --max-duration 1000 + --max-duration 5000 It supports training with: - transducer loss (default), with `--use-transducer True --use-ctc False` @@ -53,11 +62,18 @@ import argparse import copy import logging +import os import warnings from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union +warnings.filterwarnings( + "ignore", + message=r"`torch\.cuda\.amp\.custom_(fwd|bwd)\(args\.\.\.\)` is deprecated.*", + category=FutureWarning, +) + import k2 import optim import sentencepiece as spm @@ -75,7 +91,6 @@ from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -87,13 +102,20 @@ save_checkpoint_with_global_batch_idx, update_averaged_model, ) -from icefall.dist import cleanup_dist, setup_dist +from icefall.dist import ( + cleanup_dist, + get_local_rank, + get_rank, + get_world_size, + setup_dist, +) from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, get_parameter_groups_with_lrs, setup_logger, str2bool, @@ -480,6 +502,13 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + add_model_arguments(parser) return parser @@ -537,7 +566,7 @@ def get_params() -> AttributeDict: "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, - "log_interval": 500, + "log_interval": 100, "reset_interval": 2000, "valid_interval": 20000, # parameters for zipformer @@ -719,7 +748,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -895,7 +924,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -959,7 +988,7 @@ def save_bad_model(suffix: str = ""): batch_size = len(batch["supervisions"]["text"]) try: - with torch_autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_autocast, dtype=params.dtype): loss, loss_info = compute_loss( params=params, model=model, @@ -1019,7 +1048,7 @@ def save_bad_model(suffix: str = ""): rank=rank, ) - if batch_idx % 100 == 0 and params.use_fp16: + if batch_idx % 100 == 0 and params.use_autocast: # If the grad scale was less than 1, try increasing it. The _growth_interval # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. @@ -1038,14 +1067,14 @@ def save_bad_model(suffix: str = ""): if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") ) if tb_writer is not None: @@ -1057,7 +1086,7 @@ def save_bad_model(suffix: str = ""): tb_writer, "train/current_", params.batch_idx_train ) tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: + if params.use_autocast: tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, params.batch_idx_train ) @@ -1093,7 +1122,8 @@ def run(rank, world_size, args): Args: rank: It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. + passed automatically by `mp.spawn()` in :func:`main`, or read from + the env when launched with `torchrun`. The node with rank 0 is responsible for saving checkpoint. world_size: Number of GPUs for DDP training. @@ -1104,8 +1134,17 @@ def run(rank, world_size, args): params.update(vars(args)) fix_random_seed(params.seed) + + # torchrun sets up the env; otherwise fall back to mp.spawn. + use_ddp_launch = "RANK" in os.environ and "WORLD_SIZE" in os.environ + # Multi-node: pick the CUDA device by local rank, not global rank. + local_rank = get_local_rank() if use_ddp_launch else rank if world_size > 1: - setup_dist(rank, world_size, params.master_port) + if use_ddp_launch: + setup_dist(use_ddp_launch=True) + torch.cuda.set_device(local_rank) + else: + setup_dist(rank, world_size, params.master_port) setup_logger(f"{params.exp_dir}/log/log-train") logging.info("Training started") @@ -1117,7 +1156,7 @@ def run(rank, world_size, args): device = torch.device("cpu") if torch.cuda.is_available(): - device = torch.device("cuda", rank) + device = torch.device("cuda", local_rank) logging.info(f"Device: {device}") sp = spm.SentencePieceProcessor() @@ -1130,6 +1169,21 @@ def run(rank, world_size, args): if not params.use_transducer: params.ctc_loss_scale = 1.0 + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + logging.info(params) logging.info("About to create model") @@ -1152,7 +1206,7 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) optimizer = ScaledAdam( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), @@ -1186,7 +1240,12 @@ def run(rank, world_size, args): def remove_short_utt(c: Cut): # In ./zipformer.py, the conv module uses the following expression # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 + # When using on-the-fly feature extraction, c.num_frames is None, + # so we estimate it from the duration (Fbank frame_shift = 0.01s). + num_frames = ( + c.num_frames if c.num_frames is not None else round(c.duration / 0.01) + ) + T = ((num_frames - 7) // 2 + 1) // 2 return T > 0 gigaspeech = GigaSpeechAsrDataModule(args) @@ -1225,7 +1284,7 @@ def remove_short_utt(c: Cut): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_autocast, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1325,7 +1384,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch_autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_autocast, dtype=params.dtype): loss, _ = compute_loss( params=params, model=model, @@ -1357,12 +1416,19 @@ def main(): args = parser.parse_args() args.exp_dir = Path(args.exp_dir) - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + # Launched with torchrun: read rank/world_size from env, no mp.spawn. + # `--world-size` is ignored in this mode. + world_size = get_world_size() + rank = get_rank() + run(rank=rank, world_size=world_size, args=args) else: - run(rank=0, world_size=1, args=args) + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) torch.set_num_threads(1) From e234c260cce73151f1567b971d95fb3407102787 Mon Sep 17 00:00:00 2001 From: yunchongxiao Date: Fri, 26 Jun 2026 21:05:17 +0800 Subject: [PATCH 4/4] fix all flake8 fix display_and_save_batch target-text handling in iwslt22_ta ST - pass sp_tgt= (not sp=) to match the function signature on the non-finite-loss path - index tgt_text["eng"] before SentencePiece encoding so the OOM/exception diagnostic path does not itself crash fix black and isort formatting Apply black (line-length 88) and isort (black profile) across recipes. Formatting-only changes; no logic modified. --- egs/aishell/ASR/conformer_ctc/decode.py | 4 +- egs/aishell/ASR/conformer_ctc/pretrained.py | 4 +- egs/aishell/ASR/conformer_mmi/decode.py | 4 +- egs/aishell/ASR/tdnn_lstm_ctc/decode.py | 4 +- egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py | 4 +- egs/aishell/ASR/whisper/decode.py | 16 +- egs/aishell/ASR/zipformer/ctc_decode.py | 24 +- egs/aishell/ASR/zipformer/train.py | 6 +- egs/ami/SURT/dprnn_zipformer/train.py | 4 +- egs/ami/SURT/dprnn_zipformer/train_adapt.py | 4 +- egs/gigaspeech/ASR/conformer_ctc/decode.py | 8 +- egs/gigaspeech/ASR/zipformer/train.py | 6 + egs/grid/VSR/conformer_ctc2/asr_datamodule.py | 59 ++- egs/grid/VSR/conformer_ctc2/conformer.py | 19 +- egs/grid/VSR/conformer_ctc2/decode.py | 3 +- egs/grid/VSR/conformer_ctc2/train.py | 43 +- egs/grid/VSR/conformer_ctc2/transformer.py | 2 +- egs/grid/VSR/local/compile_hlg_phone.py | 8 +- egs/grid/VSR/local/compute_avhubert_grid.py | 149 ++++--- egs/grid/VSR/local/split_manifests.py | 14 +- egs/iwslt22_ta/ASR/local/compute_fbank_gpu.py | 30 +- .../ASR/local/prepare_transcripts.py | 2 +- .../asr_datamodule.py | 28 +- .../pruned_transducer_stateless5/decode.py | 33 +- .../pruned_transducer_stateless5/finetune.py | 67 +-- .../finetune_loadmodel.py | 67 +-- .../pretrained.py | 18 +- .../ASR/pruned_transducer_stateless5/train.py | 63 +-- .../ASR/zipformer/asr_datamodule.py | 28 +- egs/iwslt22_ta/ASR/zipformer/decode.py | 41 +- egs/iwslt22_ta/ASR/zipformer/profile.py | 26 +- egs/iwslt22_ta/ASR/zipformer/train.py | 5 +- egs/iwslt22_ta/ST/local/cer.py | 42 +- egs/iwslt22_ta/ST/local/compute_fbank_gpu.py | 30 +- .../convert_transcript_words_to_tokens.py | 8 +- egs/iwslt22_ta/ST/local/cuts_validate.py | 29 +- egs/iwslt22_ta/ST/local/prepare_lexicon.py | 4 +- .../ST/local/prepare_transcripts.py | 22 +- .../ST/local/transcript_cleaning.py | 387 +++++++++++------- .../asr_datamodule.py | 28 +- .../ST/pruned_transducer_stateless5/decode.py | 50 ++- .../pruned_transducer_stateless5/finetune.py | 67 +-- .../finetune_loadmodel.py | 67 +-- .../pretrained.py | 18 +- .../ST/pruned_transducer_stateless5/train.py | 81 ++-- egs/iwslt22_ta/ST/zipformer/asr_datamodule.py | 24 +- egs/iwslt22_ta/ST/zipformer/beam_search.py | 39 +- egs/iwslt22_ta/ST/zipformer/decode.py | 58 ++- egs/iwslt22_ta/ST/zipformer/decoder.py | 1 - egs/iwslt22_ta/ST/zipformer/export.py | 7 +- .../ST/zipformer/generate_averaged_model.py | 6 +- egs/iwslt22_ta/ST/zipformer/model.py | 7 +- egs/iwslt22_ta/ST/zipformer/pretrained.py | 7 +- egs/iwslt22_ta/ST/zipformer/profile.py | 2 +- .../ST/zipformer/streaming_decode.py | 57 +-- egs/iwslt22_ta/ST/zipformer/train.py | 33 +- egs/ksponspeech/ASR/zipformer/ctc_decode.py | 8 +- egs/libricss/SURT/dprnn_zipformer/train.py | 4 +- .../SURT/dprnn_zipformer/train_adapt.py | 4 +- egs/librispeech/ASR/conformer_ctc2/train.py | 6 +- egs/librispeech/ASR/conformer_ctc3/train.py | 6 +- .../train.py | 6 +- .../train.py | 6 +- .../ASR/lstm_transducer_stateless/train.py | 6 +- .../ASR/lstm_transducer_stateless2/train.py | 6 +- .../ASR/lstm_transducer_stateless3/train.py | 8 +- .../ASR/pruned2_knowledge/train.py | 8 +- .../pruned_stateless_emformer_rnnt2/train.py | 6 +- .../ASR/pruned_transducer_stateless2/train.py | 6 +- .../ASR/pruned_transducer_stateless3/train.py | 6 +- .../ASR/pruned_transducer_stateless4/train.py | 6 +- .../ASR/pruned_transducer_stateless5/train.py | 6 +- .../ASR/pruned_transducer_stateless6/train.py | 6 +- .../pruned_transducer_stateless7/finetune.py | 6 +- .../ASR/pruned_transducer_stateless7/train.py | 6 +- .../pruned_transducer_stateless7_ctc/train.py | 6 +- .../train.py | 6 +- .../train.py | 6 +- .../train.py | 6 +- .../ASR/pruned_transducer_stateless8/train.py | 6 +- .../ASR/tiny_transducer_ctc/train.py | 6 +- egs/librispeech/ASR/zipformer/ctc_align.py | 11 +- .../export_rknn_transducer_streaming.py | 2 +- egs/librispeech/ASR/zipformer/finetune.py | 8 +- .../ASR/zipformer/scaling_converter.py | 5 +- egs/librispeech/ASR/zipformer/train.py | 6 +- .../ASR/zipformer_adapter/train.py | 6 +- egs/librispeech/ASR/zipformer_ctc/train.py | 6 +- .../ASR/zipformer_lora/finetune.py | 6 +- egs/librispeech/ASR/zipformer_lora/train.py | 6 +- egs/librispeech/ASR/zipformer_mmi/train.py | 6 +- egs/librispeech/SSL/hubert/finetune_ce.py | 2 +- egs/librispeech/SSL/zipformer/finetune.py | 2 +- egs/librispeech/SSL/zipformer/zipformer.py | 3 +- .../WSASR/conformer_ctc2/decode.py | 4 +- .../WSASR/conformer_ctc2/decode_phone.py | 4 +- .../WSASR/conformer_ctc2/train_phone.py | 2 +- egs/libritts/ASR/zipformer/ctc_decode.py | 8 +- egs/mgb2/ASR/conformer_ctc/decode.py | 8 +- egs/mgb2/ASR/conformer_ctc/pretrained.py | 8 +- .../ST/hent_srt/beam_search.py | 251 +++++++----- egs/multi_conv_zh_es_ta/ST/hent_srt/decode.py | 109 ++--- .../ST/hent_srt/decoder.py | 21 +- egs/multi_conv_zh_es_ta/ST/hent_srt/export.py | 11 +- .../ST/hent_srt/load_pretrained_model.py | 14 +- egs/multi_conv_zh_es_ta/ST/hent_srt/model.py | 311 +++++++------- .../ST/hent_srt/streaming_beam_search.py | 25 +- egs/multi_conv_zh_es_ta/ST/hent_srt/train.py | 178 ++++---- .../ST/hent_srt/zipformer.py | 30 +- egs/multi_conv_zh_es_ta/ST/local/cer.py | 52 ++- .../ST/local/compute_fbank_gpu.py | 52 +-- .../ST/local/cuts_validate.py | 28 +- .../ST/local/prepare_st_transcripts.py | 20 +- .../ST/local/prepare_transcripts.py | 20 +- .../asr_datamodule.py | 18 +- .../zipformer_multijoiner_st/beam_search.py | 76 ++-- .../ST/zipformer_multijoiner_st/decode.py | 109 ++--- .../ST/zipformer_multijoiner_st/decoder.py | 1 - .../ST/zipformer_multijoiner_st/joiner.py | 6 +- .../ST/zipformer_multijoiner_st/model.py | 255 ++++++------ .../ST/zipformer_multijoiner_st/profile.py | 14 +- .../ST/zipformer_multijoiner_st/scaling.py | 9 +- .../zipformer_multijoiner_st/subsampling.py | 4 +- .../ST/zipformer_multijoiner_st/train.py | 108 +++-- .../ST/zipformer_multijoiner_st/zipformer.py | 25 +- egs/multi_zh-hans/ASR/whisper/decode.py | 16 +- egs/multi_zh-hans/ASR/whisper/train.py | 4 +- egs/multi_zh-hans/ASR/zipformer/pretrained.py | 2 +- .../ASR_LLM/whisper_llm_zh/decode.py | 11 +- .../ASR_LLM/whisper_llm_zh/train.py | 8 +- egs/speechio/ASR/whisper/decode.py | 16 +- egs/swbd/ASR/conformer_ctc/decode.py | 8 +- egs/tedlium3/ASR/conformer_ctc2/decode.py | 4 +- egs/timit/ASR/tdnn_ligru_ctc/decode.py | 8 +- egs/timit/ASR/tdnn_ligru_ctc/pretrained.py | 8 +- egs/timit/ASR/tdnn_lstm_ctc/decode.py | 8 +- egs/timit/ASR/tdnn_lstm_ctc/pretrained.py | 8 +- egs/wenetspeech/ASR/whisper/decode.py | 16 +- .../test_rknn_on_cpu_simulator_ctc.py | 1 - egs/wenetspeech4tts/TTS/f5-tts/train.py | 4 +- icefall/checkpoint.py | 6 +- icefall/lexicon.py | 4 +- icefall/utils.py | 29 +- 143 files changed, 2135 insertions(+), 1895 deletions(-) delete mode 120000 egs/wenetspeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc.py diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py index 90881ee400..2ce18594e8 100755 --- a/egs/aishell/ASR/conformer_ctc/decode.py +++ b/egs/aishell/ASR/conformer_ctc/decode.py @@ -503,7 +503,9 @@ def main(): else: H = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py index 4caff4e160..edad3a8310 100755 --- a/egs/aishell/ASR/conformer_ctc/pretrained.py +++ b/egs/aishell/ASR/conformer_ctc/pretrained.py @@ -315,7 +315,9 @@ def main(): hyps = [[token_sym_table[i] for i in ids] for ids in token_ids] elif params.method in ["1best", "attention-decoder"]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False)) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py index c88aea41a0..5e013be7d7 100755 --- a/egs/aishell/ASR/conformer_mmi/decode.py +++ b/egs/aishell/ASR/conformer_mmi/decode.py @@ -516,7 +516,9 @@ def main(): else: H = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py index a6dfd8a751..86d6b32ecc 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py @@ -337,7 +337,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False)) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) assert HLG.requires_grad is False diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py index 6cfe2de899..728f034b69 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py @@ -145,7 +145,9 @@ def main(): model.eval() logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False)) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py index 75d3c5a65b..44ec2845d7 100755 --- a/egs/aishell/ASR/whisper/decode.py +++ b/egs/aishell/ASR/whisper/decode.py @@ -108,9 +108,13 @@ def average_checkpoints( for i in range(1, n): if "model" in torch.load(filenames[i], map_location=device, weights_only=False): - state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"] + state_dict = torch.load( + filenames[i], map_location=device, weights_only=False + )["model"] else: - state_dict = torch.load(filenames[i], map_location=device, weights_only=False) + state_dict = torch.load( + filenames[i], map_location=device, weights_only=False + ) for k in uniqued_names: avg[k] += state_dict[k] @@ -440,7 +444,9 @@ def main(): start = params.epoch - params.avg assert start >= 1, start checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False + f"{params.exp_dir}/epoch-{params.epoch}.pt", + map_location="cpu", + weights_only=False, ) if "model" not in checkpoint: # deepspeed converted checkpoint only contains model state_dict @@ -469,7 +475,9 @@ def main(): torch.save(model.state_dict(), filename) else: checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False + f"{params.exp_dir}/epoch-{params.epoch}.pt", + map_location="cpu", + weights_only=False, ) if "model" not in checkpoint: model.load_state_dict(checkpoint, strict=True) diff --git a/egs/aishell/ASR/zipformer/ctc_decode.py b/egs/aishell/ASR/zipformer/ctc_decode.py index 940aef1e57..b9e2099d01 100755 --- a/egs/aishell/ASR/zipformer/ctc_decode.py +++ b/egs/aishell/ASR/zipformer/ctc_decode.py @@ -66,10 +66,7 @@ find_checkpoints, load_checkpoint, ) -from icefall.decode import ( - ctc_greedy_search, - ctc_prefix_beam_search, -) +from icefall.decode import ctc_greedy_search, ctc_prefix_beam_search from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -169,6 +166,7 @@ def get_decoding_params() -> AttributeDict: ) return params + def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -242,17 +240,15 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) if params.decoding_method == "ctc-greedy-search": - return {"ctc-greedy-search" : hyps} + return {"ctc-greedy-search": hyps} elif params.decoding_method == "ctc-prefix-beam-search": - return {"ctc-prefix-beam-search" : hyps} + return {"ctc-prefix-beam-search": hyps} else: assert False, f"Unsupported decoding method: {params.decoding_method}" @@ -305,7 +301,7 @@ def decode_dataset( for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): this_batch.append((cut_id, ref_text, hyp_words)) results[name].extend(this_batch) - + num_cuts += len(texts) if batch_idx % log_interval == 0: @@ -326,7 +322,7 @@ def save_results( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results, char_level = True) + store_transcripts(filename=recog_path, texts=results, char_level=True) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -379,7 +375,7 @@ def main(): assert params.decoding_method in ( "ctc-greedy-search", "ctc-prefix-beam-search", - ) # support ctc-greedy-search and ctc-prefix-beam-search + ) # support ctc-greedy-search and ctc-prefix-beam-search params.res_dir = params.exp_dir / params.decoding_method if params.iter > 0: @@ -414,7 +410,7 @@ def main(): logging.info(f"Device: {device}") lexicon = Lexicon(params.lang_dir) - + params.blank_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index 80b3aff08d..6d06985193 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -647,6 +647,7 @@ def get_model(params: AttributeDict) -> nn.Module: ) return model + def get_spec_augment(params: AttributeDict) -> SpecAugment: num_frame_masks = int(10 * params.time_mask_ratio) max_frames_mask_fraction = 0.15 * params.time_mask_ratio @@ -664,6 +665,7 @@ def get_spec_augment(params: AttributeDict) -> SpecAugment: ) return spec_augment + def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, @@ -883,7 +885,9 @@ def compute_loss( loss += params.ctc_loss_scale * ctc_loss if use_cr_ctc: # linear warmup - cr_loss_scale = min(batch_idx_train / warm_step, 1.0) * params.cr_loss_scale + cr_loss_scale = ( + min(batch_idx_train / warm_step, 1.0) * params.cr_loss_scale + ) loss += cr_loss_scale * cr_loss assert loss.requires_grad == is_training diff --git a/egs/ami/SURT/dprnn_zipformer/train.py b/egs/ami/SURT/dprnn_zipformer/train.py index d5025b4774..1fb52c7138 100755 --- a/egs/ami/SURT/dprnn_zipformer/train.py +++ b/egs/ami/SURT/dprnn_zipformer/train.py @@ -1263,7 +1263,9 @@ def run(rank, world_size, args): logging.info( f"Initializing model with checkpoint from {params.model_init_ckpt}" ) - init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False) + init_ckpt = torch.load( + params.model_init_ckpt, map_location=device, weights_only=False + ) model.load_state_dict(init_ckpt["model"], strict=False) if world_size > 1: diff --git a/egs/ami/SURT/dprnn_zipformer/train_adapt.py b/egs/ami/SURT/dprnn_zipformer/train_adapt.py index 35b3ced31f..7fd75c05e8 100755 --- a/egs/ami/SURT/dprnn_zipformer/train_adapt.py +++ b/egs/ami/SURT/dprnn_zipformer/train_adapt.py @@ -1254,7 +1254,9 @@ def run(rank, world_size, args): logging.info( f"Initializing model with checkpoint from {params.model_init_ckpt}" ) - init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False) + init_ckpt = torch.load( + params.model_init_ckpt, map_location=device, weights_only=False + ) model.load_state_dict(init_ckpt["model"], strict=False) if world_size > 1: diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index 47f35174f1..ed18b40ce1 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -589,7 +589,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -628,7 +630,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.method in ["whole-lattice-rescoring", "attention-decoder"]: diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index cefe7b6148..81a375611f 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -74,10 +74,16 @@ category=FutureWarning, ) +from typing import TYPE_CHECKING + import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import GigaSpeechAsrDataModule diff --git a/egs/grid/VSR/conformer_ctc2/asr_datamodule.py b/egs/grid/VSR/conformer_ctc2/asr_datamodule.py index 2b151808db..50ff9fb55a 100644 --- a/egs/grid/VSR/conformer_ctc2/asr_datamodule.py +++ b/egs/grid/VSR/conformer_ctc2/asr_datamodule.py @@ -18,6 +18,8 @@ import argparse import inspect import logging +import random +from dataclasses import replace from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional @@ -35,15 +37,14 @@ ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, - OnTheFlyFeatures, BatchIO, + OnTheFlyFeatures, ) from lhotse.utils import fix_random_seed, supervision_to_frames from torch.utils.data import DataLoader from icefall.utils import str2bool -from dataclasses import replace -import random + class _SeedWorkers: def __init__(self, seed: int): @@ -52,6 +53,7 @@ def __init__(self, seed: int): def __call__(self, worker_id: int): fix_random_seed(self.seed + worker_id) + class GridAsrDataModule: """ DataModule for k2 VSR experiments on the GRID corpus. @@ -187,7 +189,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): default="PrecomputedFeatures", help="AudioSamples or PrecomputedFeatures", ) - + def train_dataloaders( self, cuts_train: CutSet, @@ -223,7 +225,7 @@ def train_dataloaders( # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. - + input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -251,7 +253,7 @@ def train_dataloaders( shuffle=self.args.shuffle, drop_last=self.args.drop_last, ) - + logging.info("About to create train dataloader") if sampler_state_dict is not None: logging.info("Loading sampler state dict") @@ -285,11 +287,11 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: cut_transforms=transforms, return_cuts=self.args.return_cuts, ) - + valid_sampler = SimpleCutSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, ) logging.info("About to create dev dataloader") @@ -310,10 +312,10 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: return_cuts=self.args.return_cuts, ) sampler = SimpleCutSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) logging.debug("About to create test dataloader") test_dl = DataLoader( @@ -323,7 +325,7 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: num_workers=self.args.num_workers, ) return test_dl - + @lru_cache() def train_all_cuts(self) -> CutSet: cuts = load_manifest_lazy(self.args.manifest_dir / "grid_cuts_train.jsonl.gz") @@ -331,9 +333,9 @@ def train_all_cuts(self) -> CutSet: lambda s: replace(s, text=" ".join(w for w in s.text.split() if w != "sp")) ) return cuts - + def split_train_valid(self, cuts: CutSet, valid_ratio=0.03, seed=42): - cuts = cuts.shuffle(random.Random(seed)) + cuts = cuts.shuffle(random.Random(seed)) n = len(cuts) n_valid = int(n * valid_ratio) @@ -341,39 +343,36 @@ def split_train_valid(self, cuts: CutSet, valid_ratio=0.03, seed=42): train_cuts = cuts.subset(last=n - n_valid) return train_cuts, valid_cuts - - @lru_cache() def test_cuts(self) -> CutSet: logging.info("Grid: About to get test cuts") - cuts = load_manifest_lazy( - self.args.manifest_dir / "grid_cuts_test.jsonl.gz" - ) + cuts = load_manifest_lazy(self.args.manifest_dir / "grid_cuts_test.jsonl.gz") cuts = cuts.map_supervisions( - lambda s: replace( - s, - text=" ".join(w for w in s.text.split() if w != "sp") - ) + lambda s: replace(s, text=" ".join(w for w in s.text.split() if w != "sp")) ) return cuts + class VisualFeatureInputStrategy(BatchIO): def __init__(self, frame_shift: float = 0.04): super().__init__() self.frame_shift = frame_shift - + def __call__(self, cuts): - feats = [torch.from_numpy(cut.load_custom("video_features")).float() for cut in cuts] + feats = [ + torch.from_numpy(cut.load_custom("video_features")).float() for cut in cuts + ] lengths = torch.tensor([f.shape[0] for f in feats], dtype=torch.int32) feats = torch.nn.utils.rnn.pad_sequence(feats, batch_first=True) return feats, lengths - + @property def extractor(self): class DummyExtractor: def __init__(self, frame_shift): self.frame_shift = frame_shift + return DummyExtractor(self.frame_shift) def supervision_intervals(self, cuts: CutSet) -> Dict[str, torch.Tensor]: @@ -385,7 +384,7 @@ def supervision_intervals(self, cuts: CutSet) -> Dict[str, torch.Tensor]: ) start_frames.append(start) nums_frames.append(num) - sequence_idx.append(i) + sequence_idx.append(i) return { "sequence_idx": torch.tensor(sequence_idx, dtype=torch.int32), "start_frame": torch.tensor(start_frames, dtype=torch.int32), diff --git a/egs/grid/VSR/conformer_ctc2/conformer.py b/egs/grid/VSR/conformer_ctc2/conformer.py index 144719dd47..8a7766bfb9 100644 --- a/egs/grid/VSR/conformer_ctc2/conformer.py +++ b/egs/grid/VSR/conformer_ctc2/conformer.py @@ -33,11 +33,12 @@ from torch import Tensor, nn from transformer import Supervisions, Transformer, encoder_padding_mask + # SSL Feature Projection class SSLFeatureProjection(nn.Module): def __init__(self, num_features=768, d_model=256): super().__init__() - self.proj = nn.Linear(num_features , d_model) + self.proj = nn.Linear(num_features, d_model) self.dropout = nn.Dropout(p=0.1) self.act = nn.GELU() @@ -96,15 +97,15 @@ def __init__( self.num_features = num_features self.subsampling_factor = subsampling_factor - - # SSL feature input handling + + # SSL feature input handling if num_features == d_model: self.encoder_embed = nn.Identity() else: self.encoder_embed = SSLFeatureProjection( - num_features=num_features, - d_model=d_model, - ) + num_features=num_features, + d_model=d_model, + ) self.input_layer_norm = nn.LayerNorm(d_model) @@ -147,10 +148,10 @@ def run_encoder( Tensor: Mask tensor of dimension (batch_size, input_length) """ x = self.encoder_embed(x) - - #it doesn't seem to help + + # it doesn't seem to help # x = self.input_layer_norm(x) - + x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) diff --git a/egs/grid/VSR/conformer_ctc2/decode.py b/egs/grid/VSR/conformer_ctc2/decode.py index f9cbb33e06..9cb3dc288b 100755 --- a/egs/grid/VSR/conformer_ctc2/decode.py +++ b/egs/grid/VSR/conformer_ctc2/decode.py @@ -60,6 +60,7 @@ write_error_stats, ) + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -961,7 +962,7 @@ def main(): test_cuts = grid.test_cuts() test_dl = grid.test_dataloaders(test_cuts) - + test_sets = ["test"] test_dl = [test_dl] diff --git a/egs/grid/VSR/conformer_ctc2/train.py b/egs/grid/VSR/conformer_ctc2/train.py index fdd6124f78..49aa5f92d5 100755 --- a/egs/grid/VSR/conformer_ctc2/train.py +++ b/egs/grid/VSR/conformer_ctc2/train.py @@ -46,6 +46,7 @@ import argparse import copy import logging +import os import warnings from pathlib import Path from shutil import copyfile @@ -53,9 +54,14 @@ import k2 import optim -import os + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +from typing import TYPE_CHECKING + import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler # Optional: set deterministic flags torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False @@ -331,7 +337,7 @@ def get_params() -> AttributeDict: "batch_idx_train": 0, "log_interval": 5, "reset_interval": 10, - "valid_interval": 20, + "valid_interval": 20, # parameters for conformer "feature_dim": 768, "subsampling_factor": 1, @@ -466,7 +472,7 @@ def save_checkpoint( if params.best_valid_epoch == params.cur_epoch: best_valid_filename = params.exp_dir / "best-valid-loss.pt" copyfile(src=filename, dst=best_valid_filename) - + def compute_loss( params: AttributeDict, @@ -562,7 +568,7 @@ def compute_loss( sos_id=graph_compiler.sos_id, eos_id=graph_compiler.eos_id, ) - + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss else: loss = ctc_loss @@ -704,13 +710,10 @@ def train_one_epoch( if "CUDA out of memory" in str(e): logging.error(f"failing batch size:{batch_size} ") raise - + scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), - max_norm=1.0 - ) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) scheduler.step_batch(params.batch_idx_train) scaler.step(optimizer) @@ -766,9 +769,11 @@ def train_one_epoch( ): logging.error("Your loss contains inf, something goes wrong") if tb_writer is not None: - - tb_writer.add_scalar("train/grad_norm", grad_norm, params.batch_idx_train) - + + tb_writer.add_scalar( + "train/grad_norm", grad_norm, params.batch_idx_train + ) + tb_writer.add_scalar( "train/learning_rate", cur_lr, params.batch_idx_train ) @@ -815,7 +820,6 @@ def run(rank, world_size, args): """ params = get_params() params.update(vars(args)) - fix_random_seed(params.seed) if world_size > 1: @@ -882,7 +886,7 @@ def run(rank, world_size, args): num_decoder_layers=params.num_decoder_layers, dropout=0.1, layer_dropout=0.1, - dim_feedforward=1024, + dim_feedforward=1024, ) print(model) @@ -925,20 +929,17 @@ def run(rank, world_size, args): if params.print_diagnostics: diagnostic = diagnostics.attach_diagnostics(model) - grid = GridAsrDataModule(args) - cuts = grid.train_all_cuts() + cuts = grid.train_all_cuts() - train_cuts, valid_cuts = grid.split_train_valid(cuts, 0.03, seed=params.seed) - + train_cuts, valid_cuts = grid.split_train_valid(cuts, 0.03, seed=params.seed) # train_dl = grid.train_dataloaders(train_cuts) valid_dl = grid.valid_dataloaders(valid_cuts) - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint # saved in the middle of an epoch @@ -946,9 +947,7 @@ def run(rank, world_size, args): else: sampler_state_dict = None - train_dl = grid.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) + train_dl = grid.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict) if params.print_diagnostics: scan_pessimistic_batches_for_oom( diff --git a/egs/grid/VSR/conformer_ctc2/transformer.py b/egs/grid/VSR/conformer_ctc2/transformer.py index f6d7f56f3c..375e31f6ae 100644 --- a/egs/grid/VSR/conformer_ctc2/transformer.py +++ b/egs/grid/VSR/conformer_ctc2/transformer.py @@ -938,7 +938,7 @@ def encoder_padding_mask( ( supervisions["sequence_idx"], supervisions["start_frame"], - supervisions["num_frames"], + supervisions["num_frames"], ), 1, ).to(torch.int32) diff --git a/egs/grid/VSR/local/compile_hlg_phone.py b/egs/grid/VSR/local/compile_hlg_phone.py index 3231797a19..2766c484c1 100755 --- a/egs/grid/VSR/local/compile_hlg_phone.py +++ b/egs/grid/VSR/local/compile_hlg_phone.py @@ -57,8 +57,12 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: # first_token_disambig_id = lexicon.token_table["#0"] # first_word_disambig_id = lexicon.word_table["#0"] - first_token_disambig_id = lexicon.token_table["#0"] if "#0" in lexicon.token_table else 10**9 - first_word_disambig_id = lexicon.word_table["#0"] if "#0" in lexicon.word_table else 10**9 + first_token_disambig_id = ( + lexicon.token_table["#0"] if "#0" in lexicon.token_table else 10**9 + ) + first_word_disambig_id = ( + lexicon.word_table["#0"] if "#0" in lexicon.word_table else 10**9 + ) L = k2.arc_sort(L) G = k2.arc_sort(G) diff --git a/egs/grid/VSR/local/compute_avhubert_grid.py b/egs/grid/VSR/local/compute_avhubert_grid.py index 2acd691145..4a3dd003a4 100755 --- a/egs/grid/VSR/local/compute_avhubert_grid.py +++ b/egs/grid/VSR/local/compute_avhubert_grid.py @@ -10,25 +10,25 @@ The generated avhubert features are saved in data/avhubert. """ from __future__ import annotations -import dlib -import torch.nn.functional as F -import cv2 -from fairseq import checkpoint_utils +import argparse +import contextlib +import itertools import logging import os import sys +from concurrent.futures import ProcessPoolExecutor from pathlib import Path + +import cv2 +import dlib +import numpy as np import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +from fairseq import checkpoint_utils from lhotse import CutSet, NumpyHdf5Writer from lhotse.recipes.utils import read_manifests_if_cached -import numpy as np -from concurrent.futures import ProcessPoolExecutor -import logging -import torch.multiprocessing as mp -import itertools -import contextlib -import argparse # CLI arguments @@ -52,15 +52,14 @@ def parse_args() -> argparse.Namespace: "--dlib-predictor", type=Path, default=Path("download/dlib/shape_predictor_68_face_landmarks.dat"), - help="Path to the dlib 68-point landmark model. " - "Default: %(default)s", + help="Path to the dlib 68-point landmark model. " "Default: %(default)s", ) parser.add_argument( "--layer", type=int, default=9, help="Number of encoder layers to keep (0-indexed upper bound). " - "Default: %(default)s", + "Default: %(default)s", ) parser.add_argument( "--n-workers", @@ -85,7 +84,7 @@ def parse_args() -> argparse.Namespace: type=int, default=1, help="Run face detection every N frames; reuse previous landmarks otherwise. " - "Default: %(default)s", + "Default: %(default)s", ) parser.add_argument( "--feats-dir", @@ -116,21 +115,23 @@ def load_globals(args: argparse.Namespace): ------- dict with keys: avhubert_utils, detector, predictor, device, model, transform """ - # AV-HuBERT imports + # AV-HuBERT imports if not args.avhubert_code_dir.exists(): - raise FileNotFoundError(f"AV-HuBERT code directory not found: {args.avhubert_code_dir}") + raise FileNotFoundError( + f"AV-HuBERT code directory not found: {args.avhubert_code_dir}" + ) with _avhubert_on_path(args.avhubert_code_dir): from avhubert.utils import Compose, Normalize - # Dlib + # Dlib if not args.dlib_predictor.exists(): raise FileNotFoundError( f"dlib landmark model not found: {args.dlib_predictor}\n" "Download: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2" ) - detector = dlib.get_frontal_face_detector() + detector = dlib.get_frontal_face_detector() predictor = dlib.shape_predictor(str(args.dlib_predictor)) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -139,20 +140,24 @@ def load_globals(args: argparse.Namespace): if not args.avhubert_ckpt.exists(): raise FileNotFoundError(f"AV-HuBERT checkpoint not found: {args.avhubert_ckpt}") - models, cfg, task = checkpoint_utils.load_model_ensemble_and_task([str(args.avhubert_ckpt)]) + models, cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [str(args.avhubert_ckpt)] + ) model = models[0] - model.encoder.layers = model.encoder.layers[:args.layer] + model.encoder.layers = model.encoder.layers[: args.layer] model.to(device).eval() logging.info( f"Loaded AV-HuBERT checkpoint: {args.avhubert_ckpt} " f"(layers 0–{args.layer - 1}, device: {device})" ) - # Image transform - transform = Compose([ - Normalize(0.0, 255.0), - Normalize(task.cfg.image_mean, task.cfg.image_std), - ]) + # Image transform + transform = Compose( + [ + Normalize(0.0, 255.0), + Normalize(task.cfg.image_mean, task.cfg.image_std), + ] + ) return dict( detector=detector, @@ -211,12 +216,14 @@ def extract_features_from_visual( frames = list(np.load(mouth_frames_file)["frames"]) logging.info(f"Loaded cached mouth frames from {mouth_frames_file}.") else: - # Landmarks (or cache load) + # Landmarks (or cache load) if landmarks_file.exists(): landmarks = np.load(landmarks_file)["landmarks"] logging.info(f"Loaded cached landmarks from {landmarks_file}.") else: - landmarks = _detect_landmarks(video_path, dlib_detector, dlib_predictor ,detect_every) + landmarks = _detect_landmarks( + video_path, dlib_detector, dlib_predictor, detect_every + ) if landmarks is None: return None np.savez_compressed(landmarks_file, landmarks=landmarks.astype(np.int16)) @@ -227,19 +234,23 @@ def extract_features_from_visual( video_path, landmarks, mouth_w, mouth_h, ROI_SIZE, MOUTH_LEFT, MOUTH_RIGHT ) if len(frames) < MIN_FRAMES: - logging.warning(f"Skipping {video}: only {len(frames)} frames (min {MIN_FRAMES}).") + logging.warning( + f"Skipping {video}: only {len(frames)} frames (min {MIN_FRAMES})." + ) return None np.savez_compressed(mouth_frames_file, frames=np.array(frames, dtype=np.uint8)) logging.info(f"Saved mouth frames to {mouth_frames_file}.") if len(frames) < MIN_FRAMES: - logging.warning(f"Skipping {video}: only {len(frames)} frames (min {MIN_FRAMES}).") + logging.warning( + f"Skipping {video}: only {len(frames)} frames (min {MIN_FRAMES})." + ) return None - + frames_np = transform(np.float32(np.stack(frames))) tensor = torch.from_numpy(frames_np).unsqueeze(0).unsqueeze(0).to(device) - # AV-HuBERT feature extraction + # AV-HuBERT feature extraction with torch.no_grad(): features, _ = model.extract_finetune( source={"video": tensor, "audio": None}, @@ -258,7 +269,9 @@ def _open_video(path: Path) -> cv2.VideoCapture: return cap -def _detect_landmarks(video_path: Path, dlib_detector, dlib_predictor, detect_every: int) -> np.ndarray | None: +def _detect_landmarks( + video_path: Path, dlib_detector, dlib_predictor, detect_every: int +) -> np.ndarray | None: """ Run dlib face + landmark detection on every ``detect_every``-th frame. @@ -279,7 +292,9 @@ def _detect_landmarks(video_path: Path, dlib_detector, dlib_predictor, detect_e if frame_idx % detect_every == 0: faces = dlib_detector(gray) if not faces: - logging.warning(f"No face detected in {video_path} at frame {frame_idx}.") + logging.warning( + f"No face detected in {video_path} at frame {frame_idx}." + ) return None shape = dlib_predictor(gray, faces[0]) last_lm = np.array([[p.x, p.y] for p in shape.parts()]) @@ -334,6 +349,7 @@ def _extract_mouth_frames( _worker_globals: dict = {} + def _worker_init(args: argparse.Namespace) -> None: """Called once per worker process to load model and detector.""" global _worker_globals @@ -367,14 +383,21 @@ def process_worker(args: tuple) -> list: - Duplicates the first frame if only 74 frames are returned (expects 75). - All cuts are assigned a fixed duration of 3.0 s at 25 fps. """ - worker_id, recordings_subset, supervisions_subset, feats_dir, partition, layer = args + ( + worker_id, + recordings_subset, + supervisions_subset, + feats_dir, + partition, + layer, + ) = args # Access resources initialised by _worker_init - model = _worker_globals["model"] + model = _worker_globals["model"] transform = _worker_globals["transform"] - detector = _worker_globals["detector"] + detector = _worker_globals["detector"] predictor = _worker_globals["predictor"] - device = _worker_globals["device"] - + device = _worker_globals["device"] + FIXED_DURATION: float = 3.0 FRAME_SHIFT: float = 0.04 EXPECTED_FRAMES: int = 75 @@ -387,11 +410,20 @@ def process_worker(args: tuple) -> list: for recording in recordings_subset: supervision = sup_by_rec.get(recording.id) if supervision is None: - logging.warning(f"Worker {worker_id} - no supervision for {recording.id}, skipping.") + logging.warning( + f"Worker {worker_id} - no supervision for {recording.id}, skipping." + ) continue try: - feats = extract_features_from_visual(recording.sources[0].source, detector, \ - predictor, device, model, transform, layer) + feats = extract_features_from_visual( + recording.sources[0].source, + detector, + predictor, + device, + model, + transform, + layer, + ) if feats is None: logging.warning( @@ -465,7 +497,7 @@ def compute_avhubert_grid(): list(manifests.keys()), dataset_parts, ) - + for partition, m in manifests.items(): logging.info(f"\nProcessing {partition} with {args.n_workers} workers") @@ -480,19 +512,22 @@ def compute_avhubert_grid(): end = min(start + chunk_size, len(recordings)) if start >= end: break - tasks.append(( - i, - recordings[start:end], - supervisions[start:end], - feats_dir, - partition, - args.layer, - - )) + tasks.append( + ( + i, + recordings[start:end], + supervisions[start:end], + feats_dir, + partition, + args.layer, + ) + ) all_cuts = [] - with ProcessPoolExecutor(max_workers=args.n_workers, initializer=_worker_init, initargs=(args,)) as ex: + with ProcessPoolExecutor( + max_workers=args.n_workers, initializer=_worker_init, initargs=(args,) + ) as ex: results = ex.map(process_worker, tasks) for worker_cuts in results: all_cuts.extend(worker_cuts) @@ -501,12 +536,14 @@ def compute_avhubert_grid(): cut_set = CutSet.from_cuts(all_cuts) cut_set.to_file(feats_dir / f"grid_cuts_{partition}.jsonl.gz") - logging.info(f"Done {partition} → {len(all_cuts)} cuts, stored in {args.n_workers} .h5 files") - + logging.info( + f"Done {partition} → {len(all_cuts)} cuts, stored in {args.n_workers} .h5 files" + ) + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) args = parse_args() - mp.set_start_method('spawn', force=True) + mp.set_start_method("spawn", force=True) compute_avhubert_grid() diff --git a/egs/grid/VSR/local/split_manifests.py b/egs/grid/VSR/local/split_manifests.py index 2afe62c47a..2b0b656536 100755 --- a/egs/grid/VSR/local/split_manifests.py +++ b/egs/grid/VSR/local/split_manifests.py @@ -1,6 +1,7 @@ -from lhotse import load_manifest -from pathlib import Path import os +from pathlib import Path + +from lhotse import load_manifest # --- Configuration --- supervisions_path = Path("data/manifests/grid_supervisions.jsonl.gz") @@ -9,9 +10,11 @@ # --- Unseen speaker setup --- test_speakers = {"s1", "s2", "s20", "s22"} + def video_exists(recording): return all(os.path.exists(source.source) for source in recording.sources) + recordings = load_manifest(recordings_path) recordings = recordings.filter(video_exists) @@ -21,12 +24,11 @@ def video_exists(recording): def get_speaker_id(rec_id): return rec_id.split("_")[0] + train_recordings = recordings.filter( lambda rec: get_speaker_id(rec.id) not in test_speakers ) -test_recordings = recordings.filter( - lambda rec: get_speaker_id(rec.id) in test_speakers -) +test_recordings = recordings.filter(lambda rec: get_speaker_id(rec.id) in test_speakers) train_recordings.to_file(output_dir / "grid_recordings_train.jsonl.gz") test_recordings.to_file(output_dir / "grid_recordings_test.jsonl.gz") @@ -43,4 +45,4 @@ def get_speaker_id(rec_id): print(f"Done splitting RecordingSet by speaker.") print(f"Train recordings: {len(train_recordings)}") -print(f"Test recordings: {len(test_recordings)}") \ No newline at end of file +print(f"Test recordings: {len(test_recordings)}") diff --git a/egs/iwslt22_ta/ASR/local/compute_fbank_gpu.py b/egs/iwslt22_ta/ASR/local/compute_fbank_gpu.py index e96576bac4..5c436a0aab 100755 --- a/egs/iwslt22_ta/ASR/local/compute_fbank_gpu.py +++ b/egs/iwslt22_ta/ASR/local/compute_fbank_gpu.py @@ -23,29 +23,29 @@ The generated fbank features are saved in data/fbank. """ +import argparse import logging import os from pathlib import Path -import argparse import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - from lhotse.features.kaldifeat import ( KaldifeatFbank, KaldifeatFbankConfig, KaldifeatFrameOptions, KaldifeatMelOptions, ) +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect # even when we are not invoking the main (e.g. when spawning subprocesses). + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( @@ -114,20 +114,17 @@ def compute_fbank_gpu(args): cut_set = cut_set.resample(sr) cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, - keep_all_channels=False) - cut_set = cut_set.filter(lambda c: c.duration >= .2 and c.duration <= 30) + keep_overlapping=False, keep_all_channels=False + ) + cut_set = cut_set.filter(lambda c: c.duration >= 0.2 and c.duration <= 30) if "train" in partition: - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set = cut_set.to_eager() chunk_size = len(cut_set) // args.num_splits cut_sets = cut_set.split_lazy( output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}", - chunk_size=chunk_size,) + chunk_size=chunk_size, + ) start = args.start stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits num_digits = len(str(args.num_splits)) @@ -157,10 +154,9 @@ def compute_fbank_gpu(args): ) cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz") + if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/iwslt22_ta/ASR/local/prepare_transcripts.py b/egs/iwslt22_ta/ASR/local/prepare_transcripts.py index 4a7e2b1c10..aedd37765d 120000 --- a/egs/iwslt22_ta/ASR/local/prepare_transcripts.py +++ b/egs/iwslt22_ta/ASR/local/prepare_transcripts.py @@ -1 +1 @@ -/exp/ahussein/tmp/icefall/egs/iwslt22_ta/ST/local/prepare_transcripts.py \ No newline at end of file +../../ST/local/prepare_transcripts.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/asr_datamodule.py index dd1ba4a8db..a28daa601d 100644 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -201,14 +201,10 @@ def train_dataloaders( if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -230,9 +226,7 @@ def train_dataloaders( input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -275,9 +269,7 @@ def train_dataloaders( # Drop feats to be on the safe side. train = K2Speech2textTranslationDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -332,8 +324,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: if self.args.on_the_fly_feats: validate = K2Speech2textTranslationDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -360,8 +351,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = K2Speech2textTranslationDataset( - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80))) + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats else PrecomputedFeatures(), return_cuts=self.args.return_cuts, @@ -381,9 +371,7 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: @lru_cache() def train_cuts(self) -> CutSet: logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_train_shuf.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz") @lru_cache() def dev_cuts(self) -> CutSet: @@ -393,4 +381,4 @@ def dev_cuts(self) -> CutSet: @lru_cache() def test_cuts(self) -> CutSet: logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_test1.jsonl.gz") \ No newline at end of file + return load_manifest_lazy(self.args.manifest_dir / "cuts_test1.jsonl.gz") diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode.py index 591870029a..6e02992645 100755 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode.py @@ -44,7 +44,7 @@ from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple -from lhotse.qa import validate_cut + import k2 import sentencepiece as spm import torch @@ -61,6 +61,7 @@ modified_beam_search, modified_beam_search_rnnlm_shallow_fusion, ) +from lhotse.qa import validate_cut from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -420,8 +421,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -556,6 +556,8 @@ def decode_one_batch( return {key: hyps} else: return {f"beam_size_{params.beam_size}": hyps} + + def remove_short_and_long_utt(c): # Keep only utterances with duration between 1 second and 20 seconds # @@ -566,9 +568,9 @@ def remove_short_and_long_utt(c): # an utterance duration distribution for your dataset to select # the threshold if c.duration < 0.5 or c.duration > 30.0: - #logging.warning( + # logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - #) + # ) return False if c.supervisions == []: return False @@ -581,12 +583,14 @@ def remove_short_and_long_utt(c): return True + # def remove_seg(c): # if c.supervisions[0].id != 'fla_0102_1_0B_00107': # return True # else: # return False + def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, @@ -664,8 +668,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -677,8 +680,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / - f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -699,8 +701,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / - f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -809,8 +810,7 @@ def main(): ) logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device)) + model.load_state_dict(average_checkpoints(filenames, device=device)) elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: @@ -821,8 +821,7 @@ def main(): filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device)) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ @@ -905,8 +904,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None @@ -933,7 +931,6 @@ def main(): test_sets = ["test", "dev"] test_all_dl = [test_dl, dev_dl] - for test_set, test_dl in zip(test_sets, test_all_dl): results_dict = decode_dataset( diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/finetune.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/finetune.py index ad474dbc64..67ad27b4ea 100755 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/finetune.py +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/finetune.py @@ -89,9 +89,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -248,8 +246,7 @@ def get_parser(): "--initial-lr", type=float, default=0.0003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need " "to be changed.", ) parser.add_argument( @@ -272,8 +269,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -296,8 +292,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -648,11 +643,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -700,14 +691,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -715,9 +701,7 @@ def compute_loss( with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa # info["utterances"] = feature.size(0) @@ -842,14 +826,10 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, - warmup=( - params.batch_idx_train / params.model_warm_step - ), + warmup=(params.batch_idx_train / params.model_warm_step), ) # summary stats - tot_loss = ( - tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. @@ -906,7 +886,7 @@ def train_one_epoch( if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] # https://silpara.medium.com/check-gpu-memory-usage-from-python-ccca503322ea - #memory_debugging() + # memory_debugging() logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " @@ -936,9 +916,7 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, validation: {valid_info}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") if tb_writer is not None: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train @@ -1056,13 +1034,13 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) MGB2 = MGB2AsrDataModule(args) train_cuts = MGB2.train_cuts() - #pdb.set_trace() + # pdb.set_trace() # def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 30 seconds # @@ -1083,9 +1061,9 @@ def remove_short_and_long_utt(c: Cut): # an utterance duration distribution for your dataset to select # the threshold if c.duration < 0.5 or c.duration > 30.0: - #logging.warning( + # logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - #) + # ) return False if c.supervisions == []: return False @@ -1115,18 +1093,19 @@ def remove_short_and_long_text(c: Cut): # Keep only text with charachters between 20 and 400 return 20 <= len(c.supervisions[0].text) <= 400 + # def remove_seg(c: Cut): # # Keep only text with charachters between 20 and 400 - # return c.supervisions[0].id != "arb_glf-20040221_024028_0A_00102" - #logging.info(f"Total duration before filtering {train_cuts.describe()}") + # return c.supervisions[0].id != "arb_glf-20040221_024028_0A_00102" + # logging.info(f"Total duration before filtering {train_cuts.describe()}") train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_text) # train_cuts = train_cuts.filter(remove_seg) # for c in train_cuts: # if c.supervisions[0].id == "arb_glf-20040221_024028_0A_00102": - # print(c) - #logging.info(f"Total duration after filtering {train_cuts.describe()}") + # print(c) + # logging.info(f"Total duration after filtering {train_cuts.describe()}") if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1135,9 +1114,7 @@ def remove_short_and_long_text(c: Cut): else: sampler_state_dict = None - train_dl = MGB2.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) + train_dl = MGB2.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict) valid_cuts = MGB2.dev_cuts() valid_cuts = valid_cuts.filter(remove_short_and_long_utt) valid_cuts = valid_cuts.filter(remove_short_and_long_text) diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/finetune_loadmodel.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/finetune_loadmodel.py index 26d255dd09..bf019fc2d3 100755 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/finetune_loadmodel.py +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/finetune_loadmodel.py @@ -89,9 +89,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -254,8 +252,7 @@ def get_parser(): "--initial-lr", type=float, default=0.00001, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need " "to be changed.", ) parser.add_argument( @@ -278,8 +275,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -302,8 +298,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -657,11 +652,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -709,14 +700,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -724,9 +710,7 @@ def compute_loss( with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa # info["utterances"] = feature.size(0) @@ -851,14 +835,10 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, - warmup=( - params.batch_idx_train / params.model_warm_step - ), + warmup=(params.batch_idx_train / params.model_warm_step), ) # summary stats - tot_loss = ( - tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. @@ -915,7 +895,7 @@ def train_one_epoch( if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] # https://silpara.medium.com/check-gpu-memory-usage-from-python-ccca503322ea - #memory_debugging() + # memory_debugging() logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " @@ -945,9 +925,7 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, validation: {valid_info}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") if tb_writer is not None: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train @@ -1039,7 +1017,7 @@ def run(rank, world_size, args): if rank == 0: # model_avg is only used with rank 0 model_avg = copy.deepcopy(model) - + checkpoints = load_checkpoint_if_available(params=params, model=model) model.to(device) @@ -1065,13 +1043,13 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) MGB2 = MGB2AsrDataModule(args) train_cuts = MGB2.train_cuts() - #pdb.set_trace() + # pdb.set_trace() # def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 30 seconds # @@ -1092,9 +1070,9 @@ def remove_short_and_long_utt(c: Cut): # an utterance duration distribution for your dataset to select # the threshold if c.duration < 0.5 or c.duration > 30.0: - #logging.warning( + # logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - #) + # ) return False if c.supervisions == []: return False @@ -1124,7 +1102,7 @@ def remove_short_and_long_text(c: Cut): # Keep only text with charachters between 20 and 400 return 3 <= len(c.supervisions[0].text) <= 400 - + # def remove_bad(c: Cut): # tol = 1e-2 # if (c.supervisions[0].end > c.duration + tol) or (c.supervisions[0].start < -tol): @@ -1132,16 +1110,15 @@ def remove_short_and_long_text(c: Cut): # return False # else: # return True - - #logging.info(f"Total duration before filtering {train_cuts.describe()}") + # logging.info(f"Total duration before filtering {train_cuts.describe()}") train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_text) # train_cuts = train_cuts.filter(remove_bad) # train_cuts = train_cuts.filter(remove_seg) # for c in train_cuts: # if c.supervisions[0].id == "arb_glf-20040221_024028_0A_00102": - # print(c) + # print(c) logging.info(f"Total duration after filtering {train_cuts.describe()}") if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: @@ -1151,9 +1128,7 @@ def remove_short_and_long_text(c: Cut): else: sampler_state_dict = None - train_dl = MGB2.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) + train_dl = MGB2.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict) valid_cuts = MGB2.dev_cuts() valid_cuts = valid_cuts.filter(remove_short_and_long_utt) valid_cuts = valid_cuts.filter(remove_short_and_long_text) diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/pretrained.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/pretrained.py index 1e100fcbd7..77ba0873b5 100755 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/pretrained.py @@ -166,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -199,8 +198,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -264,15 +262,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +338,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/train.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/train.py index 68f01dfac4..79d6c2e806 100755 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/train.py +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/train.py @@ -76,9 +76,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -240,8 +238,7 @@ def get_parser(): "--initial-lr", type=float, default=0.001, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need " "to be changed.", ) parser.add_argument( @@ -264,8 +261,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -288,8 +284,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -641,11 +636,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -653,7 +644,7 @@ def compute_loss( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - #pdb.set_trace() + # pdb.set_trace() texts = batch["supervisions"]["text"] tgt_texts = batch["supervisions"]["tgt_text"] y = sp.encode(texts, out_type=int) @@ -696,14 +687,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -711,9 +697,7 @@ def compute_loss( with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa # info["utterances"] = feature.size(0) @@ -837,14 +821,10 @@ def train_one_epoch( sp_tgt=sp_tgt, batch=batch, is_training=True, - warmup=( - params.batch_idx_train / params.model_warm_step - ), + warmup=(params.batch_idx_train / params.model_warm_step), ) # summary stats - tot_loss = ( - tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. @@ -901,7 +881,7 @@ def train_one_epoch( if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] # https://silpara.medium.com/check-gpu-memory-usage-from-python-ccca503322ea - #memory_debugging() + # memory_debugging() logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " @@ -927,14 +907,12 @@ def train_one_epoch( params=params, model=model, sp=sp, - sp_tgt=sp_tgt, + sp_tgt=sp_tgt, valid_dl=valid_dl, world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, validation: {valid_info}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") if tb_writer is not None: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train @@ -1056,7 +1034,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1082,9 +1060,9 @@ def remove_short_and_long_utt(c: Cut): # an utterance duration distribution for your dataset to select # the threshold if c.duration < 0.3 or c.duration > 30.0: - #logging.warning( + # logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - #) + # ) return False if c.supervisions == []: return False @@ -1114,10 +1092,11 @@ def remove_short_and_long_text(c: Cut): # Keep only text with charachters between 20 and 400 return 3 <= len(c.supervisions[0].text) <= 400 - #logging.info(f"Total duration before filtering {train_cuts.describe()}") + + # logging.info(f"Total duration before filtering {train_cuts.describe()}") train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_text) - #logging.info(f"Total duration after filtering {train_cuts.describe()}") + # logging.info(f"Total duration after filtering {train_cuts.describe()}") if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1226,7 +1205,7 @@ def display_and_save_batch( logging.info(f"features shape: {features.shape}") y = sp.encode(supervisions["text"], out_type=int) - #y_tgt = sp_tgt.encode(supervisions["tgt_text"], out_type=int) + # y_tgt = sp_tgt.encode(supervisions["tgt_text"], out_type=int) num_tokens = sum(len(i) for i in y) logging.info(f"num tokens: {num_tokens}") diff --git a/egs/iwslt22_ta/ASR/zipformer/asr_datamodule.py b/egs/iwslt22_ta/ASR/zipformer/asr_datamodule.py index dd1ba4a8db..a28daa601d 100644 --- a/egs/iwslt22_ta/ASR/zipformer/asr_datamodule.py +++ b/egs/iwslt22_ta/ASR/zipformer/asr_datamodule.py @@ -201,14 +201,10 @@ def train_dataloaders( if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -230,9 +226,7 @@ def train_dataloaders( input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -275,9 +269,7 @@ def train_dataloaders( # Drop feats to be on the safe side. train = K2Speech2textTranslationDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -332,8 +324,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: if self.args.on_the_fly_feats: validate = K2Speech2textTranslationDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -360,8 +351,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = K2Speech2textTranslationDataset( - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80))) + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats else PrecomputedFeatures(), return_cuts=self.args.return_cuts, @@ -381,9 +371,7 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: @lru_cache() def train_cuts(self) -> CutSet: logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_train_shuf.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz") @lru_cache() def dev_cuts(self) -> CutSet: @@ -393,4 +381,4 @@ def dev_cuts(self) -> CutSet: @lru_cache() def test_cuts(self) -> CutSet: logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_test1.jsonl.gz") \ No newline at end of file + return load_manifest_lazy(self.args.manifest_dir / "cuts_test1.jsonl.gz") diff --git a/egs/iwslt22_ta/ASR/zipformer/decode.py b/egs/iwslt22_ta/ASR/zipformer/decode.py index 6f5ec49e61..1075461f6c 100755 --- a/egs/iwslt22_ta/ASR/zipformer/decode.py +++ b/egs/iwslt22_ta/ASR/zipformer/decode.py @@ -285,8 +285,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -383,9 +382,7 @@ def decode_one_batch( src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = model.encoder( - x, x_lens, src_key_padding_mask - ) + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) hyps = [] @@ -445,10 +442,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -580,9 +574,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -615,8 +607,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -642,7 +633,7 @@ def main(): params.update(vars(args)) # use predefined parameters that were used during the training - # params.num_encoder_layers = "2,2,2,2,2,2" + # params.num_encoder_layers = "2,2,2,2,2,2" # params.feedforward_dim = "256,512,768,1024,768,512" # params.encoder_dim = "128,256,256,512,256,256" # params.encoder_unmasked_dim = "64,128,128,256,128,128" @@ -683,9 +674,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -717,9 +706,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -746,9 +735,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -807,9 +796,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/iwslt22_ta/ASR/zipformer/profile.py b/egs/iwslt22_ta/ASR/zipformer/profile.py index b460b53389..b1f1c0e4d3 100755 --- a/egs/iwslt22_ta/ASR/zipformer/profile.py +++ b/egs/iwslt22_ta/ASR/zipformer/profile.py @@ -22,24 +22,24 @@ import argparse import logging -import sentencepiece as spm -import torch - from typing import Tuple -from torch import Tensor, nn -from icefall.utils import make_pad_mask -from icefall.profiler import get_model_profile +import sentencepiece as spm +import torch from scaling import BiasNorm +from torch import Tensor, nn from train import ( + add_model_arguments, get_encoder_embed, get_encoder_model, get_joiner_model, - add_model_arguments, get_params, ) from zipformer import BypassModule +from icefall.profiler import get_model_profile +from icefall.utils import make_pad_mask + def get_parser(): parser = argparse.ArgumentParser( @@ -100,17 +100,13 @@ def __init__( self.encoder_embed = encoder_embed self.encoder_proj = encoder_proj - def forward( - self, feature: Tensor, feature_lens: Tensor - ) -> Tuple[Tensor, Tensor]: + def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]: x, x_lens = self.encoder_embed(feature, feature_lens) src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = self.encoder( - x, x_lens, src_key_padding_mask - ) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C) logits = self.encoder_proj(encoder_out) @@ -168,9 +164,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/iwslt22_ta/ASR/zipformer/train.py b/egs/iwslt22_ta/ASR/zipformer/train.py index 5acae85c6e..170165d438 100755 --- a/egs/iwslt22_ta/ASR/zipformer/train.py +++ b/egs/iwslt22_ta/ASR/zipformer/train.py @@ -1208,9 +1208,9 @@ def remove_short_and_long_utt(c: Cut): # an utterance duration distribution for your dataset to select # the threshold if c.duration < 0.3 or c.duration > 30.0: - #logging.warning( + # logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - #) + # ) return False if c.supervisions == []: return False @@ -1240,6 +1240,7 @@ def remove_short_and_long_text(c: Cut): # Keep only text with charachters between 20 and 400 return 3 <= len(c.supervisions[0].text) <= 400 + train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_text) diff --git a/egs/iwslt22_ta/ST/local/cer.py b/egs/iwslt22_ta/ST/local/cer.py index 3635d2e22d..55670670e8 100644 --- a/egs/iwslt22_ta/ST/local/cer.py +++ b/egs/iwslt22_ta/ST/local/cer.py @@ -6,53 +6,51 @@ This script computes CER for the decodings generated by icefall recipe """ -import argparse +import argparse +import os + import jiwer -import os + def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--dec-file", - type=str, - help="file with decoded text" - ) + parser.add_argument("--dec-file", type=str, help="file with decoded text") return parser + def cer_(file): hyp = [] ref = [] cer_results = 0 ref_lens = 0 - with open(file, 'r', encoding='utf-8') as dec: - + with open(file, "r", encoding="utf-8") as dec: + for line in dec: - id, target = line.split('\t') + id, target = line.split("\t") id = id[0:-2] target, txt = target.split("=") - if target == 'ref': - words = txt.strip().strip('[]').split(', ') + if target == "ref": + words = txt.strip().strip("[]").split(", ") word_list = [word.strip("'") for word in words] ref.append(" ".join(word_list)) - elif target == 'hyp': - words = txt.strip().strip('[]').split(', ') + elif target == "hyp": + words = txt.strip().strip("[]").split(", ") word_list = [word.strip("'") for word in words] hyp.append(" ".join(word_list)) for h, r in zip(hyp, ref): - #breakpoint() - cer_results += (jiwer.cer(r, h)*len(r)) + # breakpoint() + cer_results += jiwer.cer(r, h) * len(r) ref_lens += len(r) print(os.path.basename(file)) - print(cer_results/ref_lens) - - + print(cer_results / ref_lens) def main(): parse = get_args() - args = parse.parse_args() + args = parse.parse_args() cer_(args.dec_file) - + + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/egs/iwslt22_ta/ST/local/compute_fbank_gpu.py b/egs/iwslt22_ta/ST/local/compute_fbank_gpu.py index 05ed0a74a9..9960fe5dd0 100755 --- a/egs/iwslt22_ta/ST/local/compute_fbank_gpu.py +++ b/egs/iwslt22_ta/ST/local/compute_fbank_gpu.py @@ -23,29 +23,29 @@ The generated fbank features are saved in data/fbank. """ +import argparse import logging import os from pathlib import Path -import argparse import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - from lhotse.features.kaldifeat import ( KaldifeatFbank, KaldifeatFbankConfig, KaldifeatFrameOptions, KaldifeatMelOptions, ) +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect # even when we are not invoking the main (e.g. when spawning subprocesses). + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( @@ -117,20 +117,17 @@ def compute_fbank_gpu(args): cut_set = cut_set.resample(sr) cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, - keep_all_channels=False) - cut_set = cut_set.filter(lambda c: c.duration >= .2 and c.duration <= 30) + keep_overlapping=False, keep_all_channels=False + ) + cut_set = cut_set.filter(lambda c: c.duration >= 0.2 and c.duration <= 30) if "train" in partition: - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set = cut_set.to_eager() chunk_size = len(cut_set) // args.num_splits cut_sets = cut_set.split_lazy( output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}", - chunk_size=chunk_size,) + chunk_size=chunk_size, + ) start = args.start stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits num_digits = len(str(args.num_splits)) @@ -160,10 +157,9 @@ def compute_fbank_gpu(args): ) cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz") + if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/iwslt22_ta/ST/local/convert_transcript_words_to_tokens.py b/egs/iwslt22_ta/ST/local/convert_transcript_words_to_tokens.py index 133499c8bc..a8d5117c97 100755 --- a/egs/iwslt22_ta/ST/local/convert_transcript_words_to_tokens.py +++ b/egs/iwslt22_ta/ST/local/convert_transcript_words_to_tokens.py @@ -51,16 +51,12 @@ def get_args(): "lines. Each line consists of space separated words.", ) parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument( - "--oov", type=str, default="", help="The OOV word." - ) + parser.add_argument("--oov", type=str, default="", help="The OOV word.") return parser.parse_args() -def process_line( - lexicon: Dict[str, List[str]], line: str, oov_token: str -) -> None: +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: """ Args: lexicon: diff --git a/egs/iwslt22_ta/ST/local/cuts_validate.py b/egs/iwslt22_ta/ST/local/cuts_validate.py index 16b3f18059..446690fb16 100755 --- a/egs/iwslt22_ta/ST/local/cuts_validate.py +++ b/egs/iwslt22_ta/ST/local/cuts_validate.py @@ -7,11 +7,11 @@ and CutSets """ -from lhotse import RecordingSet, SupervisionSet, CutSet import argparse import logging -from lhotse.qa import fix_manifests, validate_recordings_and_supervisions +from lhotse import CutSet, RecordingSet, SupervisionSet +from lhotse.qa import fix_manifests, validate_recordings_and_supervisions def get_parser(): @@ -45,13 +45,12 @@ def get_parser(): help="name of the cutset to be saved", ) - - return parser + def valid_asr(cut): tol = 2e-3 - i=0 + i = 0 total_dur = 0 for c in cut: if c.supervisions != []: @@ -59,10 +58,14 @@ def valid_asr(cut): logging.info(f"Supervision beyond the cut. Cut number: {i}") total_dur += c.duration - logging.info(f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}") + logging.info( + f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}" + ) elif c.supervisions[0].start < -tol: logging.info(f"Supervision starts before the cut. Cut number: {i}") - logging.info(f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}") + logging.info( + f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}" + ) else: continue else: @@ -70,7 +73,7 @@ def valid_asr(cut): logging.info(f"id: {c.id}") i += 1 logging.info(f"filtered duration: {total_dur}") - + def main(): @@ -91,8 +94,11 @@ def main(): logging.info("Validating manifests") validate_recordings_and_supervisions(recordings, supervisions) - - cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,) + + cuts = CutSet.from_manifests( + recordings=recordings, + supervisions=supervisions, + ) cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False) logging.info("Example from cut:") @@ -103,5 +109,6 @@ def main(): if args.savecut != "": cuts.to_file(args.savecut) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/egs/iwslt22_ta/ST/local/prepare_lexicon.py b/egs/iwslt22_ta/ST/local/prepare_lexicon.py index 8075795035..1997f2741b 100755 --- a/egs/iwslt22_ta/ST/local/prepare_lexicon.py +++ b/egs/iwslt22_ta/ST/local/prepare_lexicon.py @@ -25,9 +25,7 @@ def main(): for line in f: line = line.strip() characters = list(line) - characters = " ".join( - ["V" if char == "*" else char for char in characters] - ) + characters = " ".join(["V" if char == "*" else char for char in characters]) lex[line] = characters with open(args.output, "w", encoding="utf-8") as fp: diff --git a/egs/iwslt22_ta/ST/local/prepare_transcripts.py b/egs/iwslt22_ta/ST/local/prepare_transcripts.py index c4e1398299..28768a3c44 100755 --- a/egs/iwslt22_ta/ST/local/prepare_transcripts.py +++ b/egs/iwslt22_ta/ST/local/prepare_transcripts.py @@ -5,12 +5,13 @@ This script prepares transcript_words.txt from cutset """ -from lhotse import CutSet import argparse import logging +import os import pdb from pathlib import Path -import os + +from lhotse import CutSet def get_parser(): @@ -36,7 +37,7 @@ def get_parser(): help="name of the target lang-dir", ) return parser - + def main(): @@ -50,17 +51,20 @@ def main(): langdirs = [Path(args.src_langdir), Path(args.tgt_langdir)] else: langdirs = [Path(args.src_langdir)] - + for langdir in langdirs: if not os.path.exists(langdir): os.makedirs(langdir) - with open(langdirs[0] / "transcript_words.txt", 'w') as src, open(langdirs[1] / "transcript_words.txt", 'w') as tgt: + with open(langdirs[0] / "transcript_words.txt", "w") as src, open( + langdirs[1] / "transcript_words.txt", "w" + ) as tgt: for c in cuts: src_txt = c.supervisions[0].text - tgt_txt = c.supervisions[0].custom['translated_text']['eng'] - src.write(src_txt + '\n') - tgt.write(tgt_txt + '\n') + tgt_txt = c.supervisions[0].custom["translated_text"]["eng"] + src.write(src_txt + "\n") + tgt.write(tgt_txt + "\n") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/egs/iwslt22_ta/ST/local/transcript_cleaning.py b/egs/iwslt22_ta/ST/local/transcript_cleaning.py index ecc8765bb2..48634c9bf6 100755 --- a/egs/iwslt22_ta/ST/local/transcript_cleaning.py +++ b/egs/iwslt22_ta/ST/local/transcript_cleaning.py @@ -1,20 +1,21 @@ #!/usr/bin/env python # Copyright 2020 Kanari AI (Amir Hussein) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +import argparse +import os import pdb -import numpy as np -import pandas as pd import re import string -import argparse import sys -import os + +import numpy as np +import pandas as pd import pyarabic.number as number from pyarabic import araby -_unicode = u"\u0622\u0624\u0626\u0628\u062a\u062c\u06af\u062e\u0630\u0632\u0634\u0636\u0638\u063a\u0640\u0642\u0644\u0646\u0648\u064a\u064c\u064e\u0650\u0652\u0670\u067e\u0686\u0621\u0623\u0625\u06a4\u0627\u0629\u062b\u062d\u062f\u0631\u0633\u0635\u0637\u0639\u0641\u0643\u0645\u0647\u0649\u064b\u064d\u064f\u0651\u0671" -_buckwalter = u"|&}btjGx*z$DZg_qlnwyNaio`PJ'>_()*&^%][ـ،/:"؟.,'{}~¦+|!”…“–ـ''' - english_punctuations = string.punctuation - all_punctuations = set(arabic_punctuations + english_punctuations)-{'@','%','.'} # remove all non verbatim punctuations - - for p in all_punctuations: - if p in text: - text = text.replace(p, '') - text = re.sub('\s+\.','',text) # keep only the "." that is part of a word: marsad@aljazeera.net . => marsad@aljazeera.net - return text + """This function removes all punctuations except the verbatim""" + + arabic_punctuations = """﴿﴾‘`÷×؛<>_()*&^%][ـ،/:"؟.,'{}~¦+|!”…“–ـ""" + english_punctuations = string.punctuation + all_punctuations = set(arabic_punctuations + english_punctuations) - { + "@", + "%", + ".", + } # remove all non verbatim punctuations + + for p in all_punctuations: + if p in text: + text = text.replace(p, "") + text = re.sub( + "\s+\.", "", text + ) # keep only the "." that is part of a word: marsad@aljazeera.net . => marsad@aljazeera.net + return text + def remove_extra_space(text): - text = text.lower() - text = re.sub('\s+', ' ', text) - text = re.sub('\s+\.\s+', '.', text) - return text + text = text.lower() + text = re.sub("\s+", " ", text) + text = re.sub("\s+\.\s+", ".", text) + return text + def remove_dot(text): - words = text.split() - res = [] - for word in words: - if word.replace('.','').isnumeric(): # remove the dot if it is not part of a number - res.append(word) + words = text.split() + res = [] + for word in words: + if word.replace( + ".", "" + ).isnumeric(): # remove the dot if it is not part of a number + res.append(word) + + else: + word = re.sub("\s+\.", "", word) + res.append(word) + + return " ".join(res) + - else: - word = re.sub('\s+\.','',word) - res.append(word) - - return " ".join(res) - def east_to_west_num(text): - eastern_to_western = {"٠":"0","١":"1","٢":"2","٣":"3","٤":"4","٥":"5","٦":"6","٧":"7","٨":"8","٩":"9","٪":"%","_":" ","ڤ":"ف","|":" "} - trans_string = str.maketrans(eastern_to_western) - return text.translate(trans_string) - + eastern_to_western = { + "٠": "0", + "١": "1", + "٢": "2", + "٣": "3", + "٤": "4", + "٥": "5", + "٦": "6", + "٧": "7", + "٨": "8", + "٩": "9", + "٪": "%", + "_": " ", + "ڤ": "ف", + "|": " ", + } + trans_string = str.maketrans(eastern_to_western) + return text.translate(trans_string) + + def remove_repeating_char(text): - return re.sub(r'(.)\1+', r'\1', text) - + return re.sub(r"(.)\1+", r"\1", text) + + def remove_single_char_word(text): - """ - Remove single character word from text - Example: I am in a a home for two years => am in home for two years - Args: - text (str): text - Returns: - (str): text with single char removed - """ - words = text.split() - - filter_words = [word for word in words if len(word) > 1 or word.isnumeric()] - return " ".join(filter_words) + """ + Remove single character word from text + Example: I am in a a home for two years => am in home for two years + Args: + text (str): text + Returns: + (str): text with single char removed + """ + words = text.split() + + filter_words = [word for word in words if len(word) > 1 or word.isnumeric()] + return " ".join(filter_words) + def seperate_english_characters(text): - text = text.lower() - res = re.findall(r'[a-z]+', text) # search for english - for match in res: - if match not in {'.',' '}: - text = re.sub(match, " "+ match+ " ",text) - text = re.sub('\s+', ' ', text) - return text + text = text.lower() + res = re.findall(r"[a-z]+", text) # search for english + for match in res: + if match not in {".", " "}: + text = re.sub(match, " " + match + " ", text) + text = re.sub("\s+", " ", text) + return text + def digit2num(text, dig2num=False): - """ This function is used to clean numbers""" - - # search for numbers with spaces - # 100 . 000 => 100.000 - - res = re.search('[0-9]+\s\.\s[0-9]+', text) - if res != None: - t = re.sub(r'\s', '', res[0]) - text = re.sub(res[0], t, text) - - # seperate numbers glued with words - # 3أشهر => 3 أشهر - # من10الى15 => من 10 الى 15 - # pdb.set_trace() - res = re.findall(r'[^\u0600-\u06FF\%\@a-z]+', text) # search for digits - for match in res: - if match not in {'.',' '}: - text = re.sub(match, " "+ match+ " ",text) - text = re.sub('\s+', ' ', text) - - # transliterate numbers to digits - # 13 => ثلاثة عشر - - if dig2num == True: - words = araby.tokenize(text) - for i in range(len(words)): - digit = re.sub(r'[\u0600-\u06FF]+', '', words[i]) - if digit.isnumeric(): - sub_word = re.sub(r'[^\u0600-\u06FF]+', '', words[i]) - if number.number2text(digit) != 'صفر': - words[i] = sub_word + number.number2text(digit) - else: - pass - - return " ".join(words) - else: - return text - + """This function is used to clean numbers""" + + # search for numbers with spaces + # 100 . 000 => 100.000 + + res = re.search("[0-9]+\s\.\s[0-9]+", text) + if res != None: + t = re.sub(r"\s", "", res[0]) + text = re.sub(res[0], t, text) + + # seperate numbers glued with words + # 3أشهر => 3 أشهر + # من10الى15 => من 10 الى 15 + # pdb.set_trace() + res = re.findall(r"[^\u0600-\u06FF\%\@a-z]+", text) # search for digits + for match in res: + if match not in {".", " "}: + text = re.sub(match, " " + match + " ", text) + text = re.sub("\s+", " ", text) + + # transliterate numbers to digits + # 13 => ثلاثة عشر + + if dig2num == True: + words = araby.tokenize(text) + for i in range(len(words)): + digit = re.sub(r"[\u0600-\u06FF]+", "", words[i]) + if digit.isnumeric(): + sub_word = re.sub(r"[^\u0600-\u06FF]+", "", words[i]) + if number.number2text(digit) != "صفر": + words[i] = sub_word + number.number2text(digit) + else: + pass + + return " ".join(words) + else: + return text + + def data_cleaning(text): - # text = remove_special_words(text) - text = remove_punctuations(text) - text = remove_single_char_word(text) - text = remove_diacritics(text) - text = seperate_english_characters(text) - text = remove_extra_space(text) - text = remove_dot(text) - text = east_to_west_num(text) - text = digit2num(text, True) - text = normalizeArabic(text) - #text = re.sub(r'#\w{1,2}\b', '', text) # text = re.sub(r'#\w{1,3}\b', '', text) - #text = remove_hashes(text) - #text = normalizeArabic(text) - return text + # text = remove_special_words(text) + text = remove_punctuations(text) + text = remove_single_char_word(text) + text = remove_diacritics(text) + text = seperate_english_characters(text) + text = remove_extra_space(text) + text = remove_dot(text) + text = east_to_west_num(text) + text = digit2num(text, True) + text = normalizeArabic(text) + # text = re.sub(r'#\w{1,2}\b', '', text) # text = re.sub(r'#\w{1,3}\b', '', text) + # text = remove_hashes(text) + # text = normalizeArabic(text) + return text + def main(): - input_file = sys.argv[1] # input transcription file with format - to_BW = str(sys.argv[2]) # transform to BW True|False - output_file=sys.argv[3] # output file name - data = read_tsv(input_file) - new_data = [] - for i in range(len(data)): - - tokens = data[i][0].split() - - tokens[1:] = data_cleaning(" ".join(tokens[1:])).split() - #tokens = data_cleaning(" ".join(tokens)).split() - if to_BW == "True": - for i in range(len(tokens[1:])): - tokens[i+1] = fromBuckWalter(tokens[i+1]) - new_data.append(" ".join(tokens)) - else: - new_data.append(" ".join(tokens)) + input_file = sys.argv[1] # input transcription file with format + to_BW = str(sys.argv[2]) # transform to BW True|False + output_file = sys.argv[3] # output file name + data = read_tsv(input_file) + new_data = [] + for i in range(len(data)): + + tokens = data[i][0].split() + + tokens[1:] = data_cleaning(" ".join(tokens[1:])).split() + # tokens = data_cleaning(" ".join(tokens)).split() + if to_BW == "True": + for i in range(len(tokens[1:])): + tokens[i + 1] = fromBuckWalter(tokens[i + 1]) + new_data.append(" ".join(tokens)) + else: + new_data.append(" ".join(tokens)) + + df = pd.DataFrame(data=new_data) + df.to_csv(output_file, sep="\n", header=False, index=False) + - df = pd.DataFrame(data=new_data) - df.to_csv(output_file, sep = '\n', header=False, index=False) - if __name__ == "__main__": main() diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/asr_datamodule.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/asr_datamodule.py index dd1ba4a8db..a28daa601d 100644 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/asr_datamodule.py @@ -201,14 +201,10 @@ def train_dataloaders( if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -230,9 +226,7 @@ def train_dataloaders( input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -275,9 +269,7 @@ def train_dataloaders( # Drop feats to be on the safe side. train = K2Speech2textTranslationDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -332,8 +324,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: if self.args.on_the_fly_feats: validate = K2Speech2textTranslationDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -360,8 +351,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = K2Speech2textTranslationDataset( - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80))) + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats else PrecomputedFeatures(), return_cuts=self.args.return_cuts, @@ -381,9 +371,7 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: @lru_cache() def train_cuts(self) -> CutSet: logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_train_shuf.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz") @lru_cache() def dev_cuts(self) -> CutSet: @@ -393,4 +381,4 @@ def dev_cuts(self) -> CutSet: @lru_cache() def test_cuts(self) -> CutSet: logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_test1.jsonl.gz") \ No newline at end of file + return load_manifest_lazy(self.args.manifest_dir / "cuts_test1.jsonl.gz") diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode.py index 545e0073d8..47c75e7e67 100755 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode.py +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode.py @@ -43,7 +43,7 @@ from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple -from lhotse.qa import validate_cut + import k2 import sentencepiece as spm import torch @@ -60,6 +60,7 @@ modified_beam_search, modified_beam_search_rnnlm_shallow_fusion, ) +from lhotse.qa import validate_cut from train_st import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -419,8 +420,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -555,6 +555,8 @@ def decode_one_batch( return {key: hyps} else: return {f"beam_size_{params.beam_size}": hyps} + + def remove_short_and_long_utt(c): # Keep only utterances with duration between 1 second and 20 seconds # @@ -565,9 +567,9 @@ def remove_short_and_long_utt(c): # an utterance duration distribution for your dataset to select # the threshold if c.duration < 0.5 or c.duration > 30.0: - #logging.warning( + # logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - #) + # ) return False if c.supervisions == []: return False @@ -580,12 +582,14 @@ def remove_short_and_long_utt(c): return True + # def remove_seg(c): # if c.supervisions[0].id != 'fla_0102_1_0B_00107': # return True # else: # return False + def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, @@ -648,26 +652,26 @@ def decode_dataset( rnnlm=rnnlm, rnnlm_scale=rnnlm_scale, ) - #breakpoint() + # breakpoint() for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - - for cut_id, hyp_words, ref_text, ref_text_tgt in zip(cut_ids, hyps, texts, texts_tgt): + + for cut_id, hyp_words, ref_text, ref_text_tgt in zip( + cut_ids, hyps, texts, texts_tgt + ): ref_words = ref_text.split() ref_words_tgt = ref_text_tgt.split() this_batch.append((cut_id, ref_words, ref_words_tgt, hyp_words)) results[name].extend(this_batch) - #breakpoint() + # breakpoint() num_cuts += len(texts) if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -677,11 +681,9 @@ def save_results( results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() - + for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"{test_set_name}-{key}-{params.suffix}.txt" results = sorted(results) store_translations(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -693,11 +695,9 @@ def save_results( results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() - + for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"{test_set_name}-{key}-{params.suffix}.txt" results = sorted(results) store_translations(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -797,8 +797,7 @@ def main(): ) logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device)) + model.load_state_dict(average_checkpoints(filenames, device=device)) elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: @@ -809,8 +808,7 @@ def main(): filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device)) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ @@ -893,8 +891,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None @@ -921,7 +918,6 @@ def main(): test_sets = ["test", "dev"] test_all_dl = [test_dl, dev_dl] - for test_set, test_dl in zip(test_sets, test_all_dl): results_dict = decode_dataset( diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/finetune.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/finetune.py index ad474dbc64..67ad27b4ea 100755 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/finetune.py +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/finetune.py @@ -89,9 +89,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -248,8 +246,7 @@ def get_parser(): "--initial-lr", type=float, default=0.0003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need " "to be changed.", ) parser.add_argument( @@ -272,8 +269,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -296,8 +292,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -648,11 +643,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -700,14 +691,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -715,9 +701,7 @@ def compute_loss( with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa # info["utterances"] = feature.size(0) @@ -842,14 +826,10 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, - warmup=( - params.batch_idx_train / params.model_warm_step - ), + warmup=(params.batch_idx_train / params.model_warm_step), ) # summary stats - tot_loss = ( - tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. @@ -906,7 +886,7 @@ def train_one_epoch( if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] # https://silpara.medium.com/check-gpu-memory-usage-from-python-ccca503322ea - #memory_debugging() + # memory_debugging() logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " @@ -936,9 +916,7 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, validation: {valid_info}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") if tb_writer is not None: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train @@ -1056,13 +1034,13 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) MGB2 = MGB2AsrDataModule(args) train_cuts = MGB2.train_cuts() - #pdb.set_trace() + # pdb.set_trace() # def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 30 seconds # @@ -1083,9 +1061,9 @@ def remove_short_and_long_utt(c: Cut): # an utterance duration distribution for your dataset to select # the threshold if c.duration < 0.5 or c.duration > 30.0: - #logging.warning( + # logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - #) + # ) return False if c.supervisions == []: return False @@ -1115,18 +1093,19 @@ def remove_short_and_long_text(c: Cut): # Keep only text with charachters between 20 and 400 return 20 <= len(c.supervisions[0].text) <= 400 + # def remove_seg(c: Cut): # # Keep only text with charachters between 20 and 400 - # return c.supervisions[0].id != "arb_glf-20040221_024028_0A_00102" - #logging.info(f"Total duration before filtering {train_cuts.describe()}") + # return c.supervisions[0].id != "arb_glf-20040221_024028_0A_00102" + # logging.info(f"Total duration before filtering {train_cuts.describe()}") train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_text) # train_cuts = train_cuts.filter(remove_seg) # for c in train_cuts: # if c.supervisions[0].id == "arb_glf-20040221_024028_0A_00102": - # print(c) - #logging.info(f"Total duration after filtering {train_cuts.describe()}") + # print(c) + # logging.info(f"Total duration after filtering {train_cuts.describe()}") if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1135,9 +1114,7 @@ def remove_short_and_long_text(c: Cut): else: sampler_state_dict = None - train_dl = MGB2.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) + train_dl = MGB2.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict) valid_cuts = MGB2.dev_cuts() valid_cuts = valid_cuts.filter(remove_short_and_long_utt) valid_cuts = valid_cuts.filter(remove_short_and_long_text) diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/finetune_loadmodel.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/finetune_loadmodel.py index 26d255dd09..bf019fc2d3 100755 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/finetune_loadmodel.py +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/finetune_loadmodel.py @@ -89,9 +89,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -254,8 +252,7 @@ def get_parser(): "--initial-lr", type=float, default=0.00001, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need " "to be changed.", ) parser.add_argument( @@ -278,8 +275,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -302,8 +298,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -657,11 +652,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -709,14 +700,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -724,9 +710,7 @@ def compute_loss( with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa # info["utterances"] = feature.size(0) @@ -851,14 +835,10 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, - warmup=( - params.batch_idx_train / params.model_warm_step - ), + warmup=(params.batch_idx_train / params.model_warm_step), ) # summary stats - tot_loss = ( - tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. @@ -915,7 +895,7 @@ def train_one_epoch( if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] # https://silpara.medium.com/check-gpu-memory-usage-from-python-ccca503322ea - #memory_debugging() + # memory_debugging() logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " @@ -945,9 +925,7 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, validation: {valid_info}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") if tb_writer is not None: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train @@ -1039,7 +1017,7 @@ def run(rank, world_size, args): if rank == 0: # model_avg is only used with rank 0 model_avg = copy.deepcopy(model) - + checkpoints = load_checkpoint_if_available(params=params, model=model) model.to(device) @@ -1065,13 +1043,13 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) MGB2 = MGB2AsrDataModule(args) train_cuts = MGB2.train_cuts() - #pdb.set_trace() + # pdb.set_trace() # def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 30 seconds # @@ -1092,9 +1070,9 @@ def remove_short_and_long_utt(c: Cut): # an utterance duration distribution for your dataset to select # the threshold if c.duration < 0.5 or c.duration > 30.0: - #logging.warning( + # logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - #) + # ) return False if c.supervisions == []: return False @@ -1124,7 +1102,7 @@ def remove_short_and_long_text(c: Cut): # Keep only text with charachters between 20 and 400 return 3 <= len(c.supervisions[0].text) <= 400 - + # def remove_bad(c: Cut): # tol = 1e-2 # if (c.supervisions[0].end > c.duration + tol) or (c.supervisions[0].start < -tol): @@ -1132,16 +1110,15 @@ def remove_short_and_long_text(c: Cut): # return False # else: # return True - - #logging.info(f"Total duration before filtering {train_cuts.describe()}") + # logging.info(f"Total duration before filtering {train_cuts.describe()}") train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_text) # train_cuts = train_cuts.filter(remove_bad) # train_cuts = train_cuts.filter(remove_seg) # for c in train_cuts: # if c.supervisions[0].id == "arb_glf-20040221_024028_0A_00102": - # print(c) + # print(c) logging.info(f"Total duration after filtering {train_cuts.describe()}") if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: @@ -1151,9 +1128,7 @@ def remove_short_and_long_text(c: Cut): else: sampler_state_dict = None - train_dl = MGB2.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) + train_dl = MGB2.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict) valid_cuts = MGB2.dev_cuts() valid_cuts = valid_cuts.filter(remove_short_and_long_utt) valid_cuts = valid_cuts.filter(remove_short_and_long_text) diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/pretrained.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/pretrained.py index 1e100fcbd7..77ba0873b5 100755 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/pretrained.py +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/pretrained.py @@ -166,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -199,8 +198,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -264,15 +262,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +338,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/train.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/train.py index 5801fde9b5..8673eccc34 100755 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/train.py +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/train.py @@ -77,9 +77,7 @@ str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -235,8 +233,7 @@ def get_parser(): "--initial-lr", type=float, default=0.001, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need " "to be changed.", ) parser.add_argument( @@ -259,8 +256,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -283,8 +279,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -636,11 +631,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -648,9 +639,9 @@ def compute_loss( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - #pdb.set_trace() + # pdb.set_trace() texts = batch["supervisions"]["text"] - tgt_texts = batch["supervisions"]["tgt_text"]['eng'] + tgt_texts = batch["supervisions"]["tgt_text"]["eng"] y = sp.encode(texts, out_type=int) y_tgt = sp_tgt.encode(tgt_texts, out_type=int) y = k2.RaggedTensor(y).to(device) @@ -679,7 +670,7 @@ def compute_loss( f"simple_loss: {simple_loss}\n" f"pruned_loss: {pruned_loss}" ) - display_and_save_batch(batch, params=params, sp=sp) + display_and_save_batch(batch, params=params, sp_tgt=sp_tgt) simple_loss = simple_loss[simple_loss_is_finite] pruned_loss = pruned_loss[pruned_loss_is_finite] @@ -691,14 +682,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -706,9 +692,7 @@ def compute_loss( with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa # info["utterances"] = feature.size(0) @@ -830,14 +814,10 @@ def train_one_epoch( sp_tgt=sp_tgt, batch=batch, is_training=True, - warmup=( - params.batch_idx_train / params.model_warm_step - ), + warmup=(params.batch_idx_train / params.model_warm_step), ) # summary stats - tot_loss = ( - tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. @@ -850,7 +830,7 @@ def train_one_epoch( else: continue except: # noqa - display_and_save_batch(batch, params=params, sp=sp) + display_and_save_batch(batch, params=params, sp_tgt=sp_tgt) raise if params.print_diagnostics and batch_idx == 5: @@ -894,7 +874,7 @@ def train_one_epoch( if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] # https://silpara.medium.com/check-gpu-memory-usage-from-python-ccca503322ea - #memory_debugging() + # memory_debugging() logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " @@ -919,14 +899,12 @@ def train_one_epoch( valid_info = compute_validation_loss( params=params, model=model, - sp_tgt=sp_tgt, + sp_tgt=sp_tgt, valid_dl=valid_dl, world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, validation: {valid_info}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") if tb_writer is not None: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train @@ -1046,7 +1024,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1072,9 +1050,9 @@ def remove_short_and_long_utt(c: Cut): # an utterance duration distribution for your dataset to select # the threshold if c.duration < 0.1 or c.duration > 30.0: - #logging.warning( + # logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - #) + # ) return False if c.supervisions == []: return False @@ -1085,7 +1063,7 @@ def remove_short_and_long_utt(c: Cut): # In ./conformer.py, the conv module uses the following expression # for subsampling T = ((c.num_frames - 1) // 2 - 1) // 2 - tokens = sp_tgt.encode(c.supervisions[0].custom['tgt_text'], out_type=str) + tokens = sp_tgt.encode(c.supervisions[0].custom["tgt_text"], out_type=str) if T < len(tokens): # logging.warning( @@ -1103,11 +1081,12 @@ def remove_short_and_long_utt(c: Cut): def remove_short_and_long_text(c: Cut): # Keep only text with charachters between 20 and 400 - return 3 <= len(c.supervisions[0].custom['tgt_text']) <= 400 - #logging.info(f"Total duration before filtering {train_cuts.describe()}") + return 3 <= len(c.supervisions[0].custom["tgt_text"]) <= 400 + + # logging.info(f"Total duration before filtering {train_cuts.describe()}") train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_text) - #logging.info(f"Total duration after filtering {train_cuts.describe()}") + # logging.info(f"Total duration after filtering {train_cuts.describe()}") if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1189,7 +1168,6 @@ def remove_short_and_long_text(c: Cut): def display_and_save_batch( batch: dict, params: AttributeDict, - sp: spm.SentencePieceProcessor, sp_tgt: spm.SentencePieceProcessor, ) -> None: """Display the batch statistics and save the batch into disk. @@ -1200,7 +1178,7 @@ def display_and_save_batch( for the content in it. params: Parameters for training. See :func:`get_params`. - sp: + sp_tgt: The BPE model. """ from lhotse.utils import uuid4 @@ -1214,9 +1192,8 @@ def display_and_save_batch( logging.info(f"features shape: {features.shape}") - y = sp.encode(supervisions["text"], out_type=int) - y_tgt = sp_tgt.encode(supervisions["tgt_text"], out_type=int) - num_tokens = sum(len(i) for i in y) + y_tgt = sp_tgt.encode(supervisions["tgt_text"]["eng"], out_type=int) + num_tokens = sum(len(i) for i in y_tgt) logging.info(f"num tokens: {num_tokens}") @@ -1265,7 +1242,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch(batch, params=params, sp=sp, sp_tgt=sp_tgt) + display_and_save_batch(batch, params=params, sp_tgt=sp_tgt) raise diff --git a/egs/iwslt22_ta/ST/zipformer/asr_datamodule.py b/egs/iwslt22_ta/ST/zipformer/asr_datamodule.py index b822b7c685..eb280a1efe 100644 --- a/egs/iwslt22_ta/ST/zipformer/asr_datamodule.py +++ b/egs/iwslt22_ta/ST/zipformer/asr_datamodule.py @@ -201,9 +201,7 @@ def train_dataloaders( if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) @@ -228,9 +226,7 @@ def train_dataloaders( input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -273,9 +269,7 @@ def train_dataloaders( # Drop feats to be on the safe side. train = K2Speech2TextTranslationDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -330,8 +324,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: if self.args.on_the_fly_feats: validate = K2Speech2TextTranslationDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -358,8 +351,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = K2Speech2TextTranslationDataset( - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80))) + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats else PrecomputedFeatures(), return_cuts=self.args.return_cuts, @@ -379,9 +371,7 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: @lru_cache() def train_cuts(self) -> CutSet: logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_train_shuf.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz") @lru_cache() def dev_cuts(self) -> CutSet: @@ -391,4 +381,4 @@ def dev_cuts(self) -> CutSet: @lru_cache() def test_cuts(self) -> CutSet: logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_test1.jsonl.gz") \ No newline at end of file + return load_manifest_lazy(self.args.manifest_dir / "cuts_test1.jsonl.gz") diff --git a/egs/iwslt22_ta/ST/zipformer/beam_search.py b/egs/iwslt22_ta/ST/zipformer/beam_search.py index 1eaa380497..4c1195536e 100644 --- a/egs/iwslt22_ta/ST/zipformer/beam_search.py +++ b/egs/iwslt22_ta/ST/zipformer/beam_search.py @@ -1024,20 +1024,24 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) if use_hat == True: - # For blank symbol, log-prob is log-sigmoid of the score - logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) - # Additionally, to ensure the the probs of blank and non-blank sum to 1, we - # need to add the following term to the log-probs of non-blank symbols. This - # is equivalent to log(1 - sigmoid(logits[..., 0])). - #breakpoint() - nb_shift = logp_b - logits[..., 0] - nb_shift = nb_shift.unsqueeze(-1) - log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift # (num_hyps, vocab_size-1) - log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) - log_probs.add_(ys_log_probs) + # For blank symbol, log-prob is log-sigmoid of the score + logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) + # Additionally, to ensure the the probs of blank and non-blank sum to 1, we + # need to add the following term to the log-probs of non-blank symbols. This + # is equivalent to log(1 - sigmoid(logits[..., 0])). + # breakpoint() + nb_shift = logp_b - logits[..., 0] + nb_shift = nb_shift.unsqueeze(-1) + log_probs1 = (logits[..., 1:] / temperature).log_softmax( + dim=-1 + ) + nb_shift # (num_hyps, vocab_size-1) + log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) + log_probs.add_(ys_log_probs) else: - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - log_probs.add_(ys_log_probs) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) + log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1094,6 +1098,7 @@ def modified_beam_search( timestamps=ans_timestamps, ) + def modified_beam_search_hat( model: Transducer, encoder_out: torch.Tensor, @@ -1203,7 +1208,6 @@ def modified_beam_search_hat( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - # For blank symbol, log-prob is log-sigmoid of the score logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) # Additionally, to ensure the the probs of blank and non-blank sum to 1, we @@ -1212,8 +1216,10 @@ def modified_beam_search_hat( breakpoint() nb_shift = logp_b - logits[..., 0] - log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift # (num_hyps, vocab_size-1) - log_probs = torch.cat((logp_b, log_probs), dim=-1) + log_probs1 = (logits[..., 1:] / temperature).log_softmax( + dim=-1 + ) + nb_shift # (num_hyps, vocab_size-1) + log_probs = torch.cat((logp_b, log_probs1), dim=-1) log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1271,6 +1277,7 @@ def modified_beam_search_hat( timestamps=ans_timestamps, ) + def modified_beam_search_lm_rescore( model: Transducer, encoder_out: torch.Tensor, diff --git a/egs/iwslt22_ta/ST/zipformer/decode.py b/egs/iwslt22_ta/ST/zipformer/decode.py index b13a72f1b6..b2c646781e 100755 --- a/egs/iwslt22_ta/ST/zipformer/decode.py +++ b/egs/iwslt22_ta/ST/zipformer/decode.py @@ -248,8 +248,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -346,9 +345,7 @@ def decode_one_batch( src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = model.encoder( - x, x_lens, src_key_padding_mask - ) + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) hyps = [] @@ -408,10 +405,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -425,7 +419,7 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - use_hat=params.use_hat_decode + use_hat=params.use_hat_decode, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -519,7 +513,7 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - texts_tgt = batch["supervisions"]["tgt_text"]['eng'] + texts_tgt = batch["supervisions"]["tgt_text"]["eng"] hyps_dict = decode_one_batch( params=params, @@ -529,26 +523,26 @@ def decode_dataset( word_table=word_table, batch=batch, ) - + for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - - for cut_id, hyp_words, ref_text, ref_text_tgt in zip(cut_ids, hyps, texts, texts_tgt): + + for cut_id, hyp_words, ref_text, ref_text_tgt in zip( + cut_ids, hyps, texts, texts_tgt + ): ref_words = ref_text.split() ref_words_tgt = ref_text_tgt.split() this_batch.append((cut_id, ref_words, ref_words_tgt, hyp_words)) results[name].extend(this_batch) - + num_cuts += len(texts) if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -558,11 +552,9 @@ def save_results( results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() - + for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"{test_set_name}-{key}-{params.suffix}.txt" results = sorted(results) store_translations(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -579,7 +571,7 @@ def main(): params.update(vars(args)) # use predefined parameters that were used during the training - # params.num_encoder_layers = "2,2,2,2,2,2" + # params.num_encoder_layers = "2,2,2,2,2,2" # params.feedforward_dim = "256,512,768,1024,768,512" # params.encoder_dim = "128,256,256,512,256,256" # params.encoder_unmasked_dim = "64,128,128,256,128,128" @@ -620,9 +612,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -654,9 +644,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -683,9 +673,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -744,9 +734,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/iwslt22_ta/ST/zipformer/decoder.py b/egs/iwslt22_ta/ST/zipformer/decoder.py index 0ca06233ad..93e0f9f7ef 100644 --- a/egs/iwslt22_ta/ST/zipformer/decoder.py +++ b/egs/iwslt22_ta/ST/zipformer/decoder.py @@ -17,7 +17,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - from scaling import Balancer diff --git a/egs/iwslt22_ta/ST/zipformer/export.py b/egs/iwslt22_ta/ST/zipformer/export.py index e424aa0193..7169c153d6 100755 --- a/egs/iwslt22_ta/ST/zipformer/export.py +++ b/egs/iwslt22_ta/ST/zipformer/export.py @@ -208,6 +208,7 @@ import sentencepiece as spm import torch +from scaling_converter import convert_scaled_to_non_scaled from torch import Tensor, nn from train import add_model_arguments, get_params, get_transducer_model @@ -218,7 +219,6 @@ load_checkpoint, ) from icefall.utils import make_pad_mask, str2bool -from scaling_converter import convert_scaled_to_non_scaled def get_parser(): @@ -305,6 +305,7 @@ def get_parser(): class EncoderModel(nn.Module): """A wrapper for encoder and encoder_embed""" + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: super().__init__() self.encoder = encoder @@ -323,9 +324,7 @@ def forward( src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = self.encoder( - x, x_lens, src_key_padding_mask - ) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) return encoder_out, encoder_out_lens diff --git a/egs/iwslt22_ta/ST/zipformer/generate_averaged_model.py b/egs/iwslt22_ta/ST/zipformer/generate_averaged_model.py index fe29355f26..5df4bb436f 100755 --- a/egs/iwslt22_ta/ST/zipformer/generate_averaged_model.py +++ b/egs/iwslt22_ta/ST/zipformer/generate_averaged_model.py @@ -43,13 +43,9 @@ import sentencepiece as spm import torch from asr_datamodule import LibriSpeechAsrDataModule - from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints_with_averaged_model, - find_checkpoints, -) +from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints def get_parser(): diff --git a/egs/iwslt22_ta/ST/zipformer/model.py b/egs/iwslt22_ta/ST/zipformer/model.py index 05bed93645..8f626f309d 100644 --- a/egs/iwslt22_ta/ST/zipformer/model.py +++ b/egs/iwslt22_ta/ST/zipformer/model.py @@ -19,9 +19,9 @@ import torch import torch.nn as nn from encoder_interface import EncoderInterface +from scaling import ScaledLinear from icefall.utils import add_sos, make_pad_mask -from scaling import ScaledLinear class Transducer(nn.Module): @@ -82,7 +82,6 @@ def __init__( initial_scale=0.25, ) - def forward( self, x: torch.Tensor, @@ -219,7 +218,7 @@ def forward( ) return (simple_loss, pruned_loss) - + class Transducer_asr_st(Transducer): """ @@ -419,7 +418,7 @@ def forward( reduction="sum", return_grad=True, ) - + with torch.cuda.amp.autocast(enabled=False): simple_loss_tgt, (px_grad_tgt, py_grad_tgt) = k2.rnnt_loss_smoothed( lm=lm.float(), diff --git a/egs/iwslt22_ta/ST/zipformer/pretrained.py b/egs/iwslt22_ta/ST/zipformer/pretrained.py index a4b7c2c361..19d72659cc 100755 --- a/egs/iwslt22_ta/ST/zipformer/pretrained.py +++ b/egs/iwslt22_ta/ST/zipformer/pretrained.py @@ -120,10 +120,11 @@ greedy_search_batch, modified_beam_search, ) -from icefall.utils import make_pad_mask from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import make_pad_mask + def get_parser(): parser = argparse.ArgumentParser( @@ -323,9 +324,7 @@ def main(): src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = model.encoder( - x, x_lens, src_key_padding_mask - ) + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) hyps = [] diff --git a/egs/iwslt22_ta/ST/zipformer/profile.py b/egs/iwslt22_ta/ST/zipformer/profile.py index c93adbd143..70af99d1f1 120000 --- a/egs/iwslt22_ta/ST/zipformer/profile.py +++ b/egs/iwslt22_ta/ST/zipformer/profile.py @@ -1 +1 @@ -../../../librispeech/ASR/zipformer/profile.py \ No newline at end of file +../../ASR/zipformer/profile.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/zipformer/streaming_decode.py b/egs/iwslt22_ta/ST/zipformer/streaming_decode.py index c2d58cb1e5..5d50de5887 100755 --- a/egs/iwslt22_ta/ST/zipformer/streaming_decode.py +++ b/egs/iwslt22_ta/ST/zipformer/streaming_decode.py @@ -282,9 +282,7 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: ) batch_states.append(cached_embed_left_pad) - processed_lens = torch.cat( - [state_list[i][-1] for i in range(batch_size)], dim=0 - ) + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) batch_states.append(processed_lens) return batch_states @@ -322,9 +320,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: for layer in range(tot_num_layers): layer_offset = layer * 6 # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk( - chunks=batch_size, dim=1 - ) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( chunks=batch_size, dim=1 @@ -355,9 +351,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: cached_conv2_list[i], ] - cached_embed_left_pad_list = batch_states[-2].chunk( - chunks=batch_size, dim=0 - ) + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) for i in range(batch_size): state_list[i].append(cached_embed_left_pad_list[i]) @@ -380,11 +374,7 @@ def streaming_forward( Returns encoder outputs, output lengths, and updated states. """ cached_embed_left_pad = states[-2] - ( - x, - x_lens, - new_cached_embed_left_pad, - ) = model.encoder_embed.streaming_forward( + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( x=features, x_lens=feature_lens, cached_left_pad=cached_embed_left_pad, @@ -404,9 +394,7 @@ def streaming_forward( new_processed_lens = processed_lens + x_lens # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat( - [processed_mask, src_key_padding_mask], dim=1 - ) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_states = states[:-2] @@ -494,9 +482,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = torch.tensor(processed_lens, device=device) processed_lens = processed_lens + encoder_out_lens @@ -517,9 +503,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = unstack_states(new_states) @@ -577,9 +561,7 @@ def decode_dataset( decode_streams = [] for num, cut in enumerate(cuts): # each utterance has a DecodeStream. - initial_states = get_init_states( - model=model, batch_size=1, device=device - ) + initial_states = get_init_states(model=model, batch_size=1, device=device) decode_stream = DecodeStream( params=params, cut_id=cut.id, @@ -649,9 +631,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -684,8 +664,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -718,9 +697,7 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" assert params.causal, params.causal - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." @@ -760,9 +737,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -789,9 +766,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/iwslt22_ta/ST/zipformer/train.py b/egs/iwslt22_ta/ST/zipformer/train.py index 4b661d93ef..3290396a6f 100755 --- a/egs/iwslt22_ta/ST/zipformer/train.py +++ b/egs/iwslt22_ta/ST/zipformer/train.py @@ -335,7 +335,7 @@ def get_parser(): files, e.g., checkpoints, log, etc, are saved """, ) - + parser.add_argument( "--bpe-tgt-model", type=str, @@ -790,11 +790,8 @@ def compute_loss( batch_idx_train = params.batch_idx_train warm_step = params.warm_step - texts = batch["supervisions"]["text"] - tgt_texts = batch["supervisions"]["tgt_text"]['eng'] - y = sp.encode(texts, out_type=int) + tgt_texts = batch["supervisions"]["tgt_text"]["eng"] y_tgt = sp_tgt.encode(tgt_texts, out_type=int) - y = k2.RaggedTensor(y).to(device) y_tgt = k2.RaggedTensor(y_tgt).to(device) with torch.set_grad_enabled(is_training): @@ -817,7 +814,7 @@ def compute_loss( f"simple_loss: {simple_loss}\n" f"pruned_loss: {pruned_loss}" ) - display_and_save_batch(batch, params=params, sp=sp, sp_tgt=sp_tgt) + display_and_save_batch(batch, params=params, sp_tgt=sp_tgt) simple_loss = simple_loss[simple_loss_is_finite] pruned_loss = pruned_loss[pruned_loss_is_finite] @@ -985,7 +982,7 @@ def save_bad_model(suffix: str = ""): continue except: # noqa save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) + display_and_save_batch(batch, params=params, sp_tgt=sp_tgt) raise if params.print_diagnostics and batch_idx == 5: @@ -1074,7 +1071,6 @@ def save_bad_model(suffix: str = ""): valid_info = compute_validation_loss( params=params, model=model, - sp=sp, sp_tgt=sp_tgt, valid_dl=valid_dl, world_size=world_size, @@ -1132,8 +1128,8 @@ def run(rank, world_size, args): sp_tgt.load(params.bpe_tgt_model) # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = sp_tgt.piece_to_id("") + params.vocab_size = sp_tgt.get_piece_size() logging.info(params) @@ -1201,9 +1197,9 @@ def remove_short_and_long_utt(c: Cut): # an utterance duration distribution for your dataset to select # the threshold if c.duration < 0.1 or c.duration > 30.0: - #logging.warning( + # logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - #) + # ) return False if c.supervisions == []: return False @@ -1214,7 +1210,9 @@ def remove_short_and_long_utt(c: Cut): # In ./conformer.py, the conv module uses the following expression # for subsampling T = ((c.num_frames - 1) // 2 - 1) // 2 - tokens = sp_tgt.encode(c.supervisions[0].custom['translated_text']['eng'], out_type=str) + tokens = sp_tgt.encode( + c.supervisions[0].custom["translated_text"]["eng"], out_type=str + ) if T < len(tokens): # logging.warning( @@ -1232,7 +1230,8 @@ def remove_short_and_long_utt(c: Cut): def remove_short_and_long_text(c: Cut): # Keep only text with charachters between 20 and 400 - return 3 <= len(c.supervisions[0].custom['translated_text']['eng']) <= 400 + return 3 <= len(c.supervisions[0].custom["translated_text"]["eng"]) <= 400 + train_cuts = train_cuts.filter(remove_short_and_long_utt) # train_cuts = train_cuts.filter(remove_short_and_long_text) @@ -1314,7 +1313,6 @@ def remove_short_and_long_text(c: Cut): def display_and_save_batch( batch: dict, params: AttributeDict, - sp: spm.SentencePieceProcessor, sp_tgt: spm.SentencePieceProcessor, ) -> None: """Display the batch statistics and save the batch into disk. @@ -1325,7 +1323,7 @@ def display_and_save_batch( for the content in it. params: Parameters for training. See :func:`get_params`. - sp: + sp_tgt: The BPE model. """ from lhotse.utils import uuid4 @@ -1339,7 +1337,6 @@ def display_and_save_batch( logging.info(f"features shape: {features.shape}") - y = sp.encode(supervisions["text"], out_type=int) y_tgt = sp_tgt.encode(supervisions["tgt_text"], out_type=int) num_tokens = sum(len(i) for i in y_tgt) logging.info(f"num tokens: {num_tokens}") @@ -1380,7 +1377,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch(batch, params=params, sp=sp, sp_tgt=sp_tgt) + display_and_save_batch(batch, params=params, sp_tgt=sp_tgt) raise logging.info( f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" diff --git a/egs/ksponspeech/ASR/zipformer/ctc_decode.py b/egs/ksponspeech/ASR/zipformer/ctc_decode.py index 10239db5ef..aecea87cf9 100755 --- a/egs/ksponspeech/ASR/zipformer/ctc_decode.py +++ b/egs/ksponspeech/ASR/zipformer/ctc_decode.py @@ -666,7 +666,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -705,7 +707,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": diff --git a/egs/libricss/SURT/dprnn_zipformer/train.py b/egs/libricss/SURT/dprnn_zipformer/train.py index 186d4f6fb5..d5a9607138 100755 --- a/egs/libricss/SURT/dprnn_zipformer/train.py +++ b/egs/libricss/SURT/dprnn_zipformer/train.py @@ -1286,7 +1286,9 @@ def run(rank, world_size, args): logging.info( f"Initializing model with checkpoint from {params.model_init_ckpt}" ) - init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False) + init_ckpt = torch.load( + params.model_init_ckpt, map_location=device, weights_only=False + ) model.load_state_dict(init_ckpt["model"], strict=False) if world_size > 1: diff --git a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py index 4d1f3cf022..94b4a737ed 100755 --- a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py +++ b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py @@ -1175,7 +1175,9 @@ def run(rank, world_size, args): logging.info( f"Initializing model with checkpoint from {params.model_init_ckpt}" ) - init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False) + init_ckpt = torch.load( + params.model_init_ckpt, map_location=device, weights_only=False + ) model.load_state_dict(init_ckpt["model"], strict=True) if world_size > 1: diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index 14c132ada7..42b14de959 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -51,11 +51,15 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py index 64e77f4218..91d560814b 100755 --- a/egs/librispeech/ASR/conformer_ctc3/train.py +++ b/egs/librispeech/ASR/conformer_ctc3/train.py @@ -61,11 +61,15 @@ import logging from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index fc33f95124..64754e40fd 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -62,12 +62,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index b00cc6cc63..2312cf6bd0 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -62,12 +62,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index e23da3b56d..366aeb4eb6 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -48,12 +48,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 271b21f5ea..1c1ec3fe15 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -53,12 +53,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Iterator, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import AsrDataModule diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index e169b499f3..687839b7c0 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -48,12 +48,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule @@ -1137,7 +1141,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch_.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 0611fd8cb2..9e86432a04 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -48,12 +48,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule @@ -76,9 +80,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import ( - create_grad_scaler, AttributeDict, MetricsTracker, + create_grad_scaler, setup_logger, str2bool, torch_autocast, diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index 2af8f3f4cd..fee51dd909 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -38,11 +38,15 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index ce6c89614b..c00de70ae5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -60,12 +60,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index f0396eb2fc..8669d90580 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -53,12 +53,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Iterator, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import AsrDataModule diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index c35f52309e..2d353829c0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -62,12 +62,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 6f9f926235..0ec81a79a2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -50,12 +50,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index 35ee74f15c..64f7f927c8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -61,12 +61,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index d3d996b4ac..0625d06125 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -48,12 +48,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index f94da97888..85f439c1ea 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -50,12 +50,16 @@ import logging import warnings from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index a26f11c826..73625525d8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -50,12 +50,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index 5585d74de8..c90206f48f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -44,12 +44,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 4d8a2644d4..cf3711064d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -49,12 +49,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index 1442aa1215..b18e50f25d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -50,12 +50,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Iterator, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import AsrDataModule diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 9b372109f4..900a2eb355 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -55,12 +55,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Iterator, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import AsrDataModule diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py index 368bd20faa..2c1ec38bb0 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -35,11 +35,15 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/zipformer/ctc_align.py b/egs/librispeech/ASR/zipformer/ctc_align.py index fff05146fe..3b6b716034 100755 --- a/egs/librispeech/ASR/zipformer/ctc_align.py +++ b/egs/librispeech/ASR/zipformer/ctc_align.py @@ -52,10 +52,7 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule as AsrDataModule from lhotse import set_caching_enabled -from torchaudio.functional import ( - forced_align, - merge_tokens, -) +from torchaudio.functional import forced_align, merge_tokens from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( @@ -64,11 +61,7 @@ find_checkpoints, load_checkpoint, ) -from icefall.utils import ( - AttributeDict, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, setup_logger, str2bool LOG_EPS = math.log(1e-10) diff --git a/egs/librispeech/ASR/zipformer/export_rknn_transducer_streaming.py b/egs/librispeech/ASR/zipformer/export_rknn_transducer_streaming.py index 27ff81b91e..61381d3b52 100755 --- a/egs/librispeech/ASR/zipformer/export_rknn_transducer_streaming.py +++ b/egs/librispeech/ASR/zipformer/export_rknn_transducer_streaming.py @@ -9,9 +9,9 @@ from rknn.api import RKNN from test_rknn_on_cpu_simulator_ctc_streaming import ( MetaData, + export_rknn, get_meta_data, init_model, - export_rknn, ) logging.basicConfig(level=logging.WARNING) diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py index 94e8b273a4..8b4f13310d 100755 --- a/egs/librispeech/ASR/zipformer/finetune.py +++ b/egs/librispeech/ASR/zipformer/finetune.py @@ -59,12 +59,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule @@ -94,9 +98,9 @@ from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( - create_grad_scaler, AttributeDict, MetricsTracker, + create_grad_scaler, get_parameter_groups_with_lrs, setup_logger, str2bool, diff --git a/egs/librispeech/ASR/zipformer/scaling_converter.py b/egs/librispeech/ASR/zipformer/scaling_converter.py index 1d24a159e4..b0d7eecc7c 100644 --- a/egs/librispeech/ASR/zipformer/scaling_converter.py +++ b/egs/librispeech/ASR/zipformer/scaling_converter.py @@ -38,10 +38,7 @@ SwooshROnnx, Whiten, ) -from zipformer import ( - CompactRelPositionalEncoding, - SimpleDownsample, -) +from zipformer import CompactRelPositionalEncoding, SimpleDownsample class NonStreamingChunkCausalDepthwiseConv1d(torch.nn.Module): diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 6a6ce447e0..9db62afe87 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -59,12 +59,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py index fcd7272e94..d33b11a0f1 100755 --- a/egs/librispeech/ASR/zipformer_adapter/train.py +++ b/egs/librispeech/ASR/zipformer_adapter/train.py @@ -48,12 +48,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py index bd3bfa3323..ea142d5857 100755 --- a/egs/librispeech/ASR/zipformer_ctc/train.py +++ b/egs/librispeech/ASR/zipformer_ctc/train.py @@ -33,9 +33,13 @@ import logging from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py index c26a2f5cc9..f338d68bd6 100755 --- a/egs/librispeech/ASR/zipformer_lora/finetune.py +++ b/egs/librispeech/ASR/zipformer_lora/finetune.py @@ -59,12 +59,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py index 2b83d58ef6..6ca4a9c86f 100755 --- a/egs/librispeech/ASR/zipformer_lora/train.py +++ b/egs/librispeech/ASR/zipformer_lora/train.py @@ -57,12 +57,16 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import sentencepiece as spm import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py index e0ca0a6a51..45d943b38c 100755 --- a/egs/librispeech/ASR/zipformer_mmi/train.py +++ b/egs/librispeech/ASR/zipformer_mmi/train.py @@ -50,11 +50,15 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import k2 import optim import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py index 1c1dc25a5f..e9d60678af 100644 --- a/egs/librispeech/SSL/hubert/finetune_ce.py +++ b/egs/librispeech/SSL/hubert/finetune_ce.py @@ -81,12 +81,12 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import ( - torch_autocast, AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py index 8b044fbb55..dbfc73bc92 100644 --- a/egs/librispeech/SSL/zipformer/finetune.py +++ b/egs/librispeech/SSL/zipformer/finetune.py @@ -81,12 +81,12 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import ( - torch_autocast, AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] diff --git a/egs/librispeech/SSL/zipformer/zipformer.py b/egs/librispeech/SSL/zipformer/zipformer.py index 5071a91a82..290f0ea7f0 100644 --- a/egs/librispeech/SSL/zipformer/zipformer.py +++ b/egs/librispeech/SSL/zipformer/zipformer.py @@ -22,7 +22,6 @@ import random import warnings from typing import List, Optional, Tuple, Union -from icefall.utils import torch_autocast import torch from encoder_interface import EncoderInterface @@ -48,6 +47,8 @@ ) from torch import Tensor, nn +from icefall.utils import torch_autocast + class Zipformer2(EncoderInterface): """ diff --git a/egs/librispeech/WSASR/conformer_ctc2/decode.py b/egs/librispeech/WSASR/conformer_ctc2/decode.py index 822df6722c..e997eaf9b6 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/decode.py +++ b/egs/librispeech/WSASR/conformer_ctc2/decode.py @@ -578,7 +578,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False diff --git a/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py b/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py index 95b57b8e8f..ff378f5673 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py +++ b/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py @@ -457,7 +457,9 @@ def main(): params.num_classes = num_classes - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False)) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) assert HLG.requires_grad is False diff --git a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py index a32183bf79..1182c0d51d 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py +++ b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py @@ -79,13 +79,13 @@ from icefall.lexicon import Lexicon from icefall.otc_phone_graph_compiler import OtcPhoneTrainingGraphCompiler from icefall.utils import ( - torch_autocast, AttributeDict, MetricsTracker, encode_supervisions_otc, get_texts, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] diff --git a/egs/libritts/ASR/zipformer/ctc_decode.py b/egs/libritts/ASR/zipformer/ctc_decode.py index bd360b74f0..42d63daf56 100755 --- a/egs/libritts/ASR/zipformer/ctc_decode.py +++ b/egs/libritts/ASR/zipformer/ctc_decode.py @@ -802,7 +802,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -842,7 +844,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method in [ diff --git a/egs/mgb2/ASR/conformer_ctc/decode.py b/egs/mgb2/ASR/conformer_ctc/decode.py index 26e470bd70..0937b1e479 100755 --- a/egs/mgb2/ASR/conformer_ctc/decode.py +++ b/egs/mgb2/ASR/conformer_ctc/decode.py @@ -575,7 +575,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -614,7 +616,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.method in ["whole-lattice-rescoring", "attention-decoder"]: diff --git a/egs/mgb2/ASR/conformer_ctc/pretrained.py b/egs/mgb2/ASR/conformer_ctc/pretrained.py index 8a3655bf65..5fe8de0b39 100755 --- a/egs/mgb2/ASR/conformer_ctc/pretrained.py +++ b/egs/mgb2/ASR/conformer_ctc/pretrained.py @@ -347,7 +347,9 @@ def main(): "attention-decoder", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False)) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -358,7 +360,9 @@ def main(): "attention-decoder", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu", weights_only=False)) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = G.to(device) diff --git a/egs/multi_conv_zh_es_ta/ST/hent_srt/beam_search.py b/egs/multi_conv_zh_es_ta/ST/hent_srt/beam_search.py index bb032e89c5..893faa8439 100644 --- a/egs/multi_conv_zh_es_ta/ST/hent_srt/beam_search.py +++ b/egs/multi_conv_zh_es_ta/ST/hent_srt/beam_search.py @@ -170,12 +170,13 @@ def greedy_search_batch( scores=ans_scores, ) + def greedy_search_st( model: nn.Module, encoder_out: torch.Tensor, st_encoder_out: torch.Tensor, max_sym_per_frame: int, - max_sym_per_frame_asr:int=1, + max_sym_per_frame_asr: int = 1, st_blank_penalty: float = 0.0, return_timestamps: bool = False, ) -> Union[List[int], DecodingResults]: @@ -203,7 +204,7 @@ def greedy_search_st( blank_id_st = model.st_decoder.blank_id context_size_st = model.st_decoder.context_size unk_id_st = getattr(model, "unk_id", blank_id_st) - + blank_id = model.decoder.blank_id context_size = model.decoder.context_size unk_id = getattr(model, "unk_id", blank_id_st) @@ -213,14 +214,14 @@ def greedy_search_st( decoder_input_st = torch.tensor( [-1] * (context_size_st - 1) + [blank_id_st], device=device, dtype=torch.int64 ).reshape(1, context_size_st) - + decoder_input = torch.tensor( [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 ).reshape(1, context_size) - + decoder_out_st = model.st_decoder(decoder_input_st, need_pad=False) decoder_out_st = model.st_joiner.decoder_proj(decoder_out_st) - encoder_out_st = model.st_joiner.encoder_proj(st_encoder_out) + encoder_out_st = model.st_joiner.encoder_proj(st_encoder_out) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -264,9 +265,9 @@ def greedy_search_st( if y_st not in (blank_id_st, unk_id_st): hyp_st.append(y_st) timestamp_st.append(t) - decoder_input_st = torch.tensor([hyp_st[-context_size_st:]], device=device).reshape( - 1, context_size_st - ) + decoder_input_st = torch.tensor( + [hyp_st[-context_size_st:]], device=device + ).reshape(1, context_size_st) decoder_out_st = model.st_decoder(decoder_input_st, need_pad=False) decoder_out_st = model.st_joiner.decoder_proj(decoder_out_st) @@ -276,7 +277,7 @@ def greedy_search_st( else: sym_per_frame = 0 t += 1 - + T = encoder_out.size(1) t = 0 hyp = [blank_id] * context_size @@ -324,7 +325,7 @@ def greedy_search_st( else: sym_per_frame = 0 t += 1 - + hyp_st = hyp_st[context_size_st:] hyp = hyp[context_size:] # remove blanks @@ -336,6 +337,7 @@ def greedy_search_st( timestamps=[timestamp, timestamp_st], ) + def greedy_search_batch_st( model: nn.Module, encoder_out: torch.Tensor, @@ -375,18 +377,16 @@ def greedy_search_batch_st( blank_id_st = model.st_decoder.blank_id unk_id_st = getattr(model, "unk_id", blank_id_st) context_size_st = model.st_decoder.context_size - - + blank_id = model.decoder.blank_id unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size - batch_size_list = packed_st_encoder_out.batch_sizes.tolist() N = st_encoder_out.size(0) assert torch.all(st_encoder_out_lens > 0), st_encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) - + hyps_st = [[-1] * (context_size_st - 1) + [blank_id_st] for _ in range(N)] timestamps_st = [[] for _ in range(N)] @@ -446,7 +446,7 @@ def greedy_search_batch_st( ) decoder_out_st = model.st_decoder(decoder_input_st, need_pad=False) decoder_out_st = model.st_joiner.decoder_proj(decoder_out_st) - + # ASR packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( input=encoder_out, @@ -521,7 +521,7 @@ def greedy_search_batch_st( ans_st.append(sorted_ans_st[unsorted_indices_st[i]]) ans_timestamps_st.append(timestamps_st[unsorted_indices_st[i]]) ans_scores_st.append(scores_st[unsorted_indices_st[i]]) - + # ASR sorted_ans = [h[context_size:] for h in hyps] ans = [] @@ -533,7 +533,6 @@ def greedy_search_batch_st( ans_timestamps.append(timestamps[unsorted_indices[i]]) ans_scores.append(scores[unsorted_indices[i]]) - if not return_timestamps: return [ans, ans_st] else: @@ -542,7 +541,8 @@ def greedy_search_batch_st( timestamps=[ans_timestamps, ans_timestamps_st] # scores=[ans_scores, ans_scores_st], ) - + + @dataclass class Hypothesis: # The predicted tokens so far. @@ -568,9 +568,8 @@ class Hypothesis: # Context graph state context_state: Optional[ContextState] = None - + decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - @property def key(self) -> str: @@ -819,7 +818,7 @@ def modified_beam_search( finalized_B = B[batch_size:] + finalized_B B = B[:batch_size] hyps_shape = get_hyps_shape(B).to(device) - + finalized_B_st = B_st[batch_size:] + finalized_B_st B_st = B_st[:batch_size] @@ -874,20 +873,24 @@ def modified_beam_search( # For blank symbol, log-prob is log-sigmoid of the score if use_hat == True: - # For blank symbol, log-prob is log-sigmoid of the score - logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) - # Additionally, to ensure the the probs of blank and non-blank sum to 1, we - # need to add the following term to the log-probs of non-blank symbols. This - # is equivalent to log(1 - sigmoid(logits[..., 0])). - #breakpoint() - nb_shift = logp_b - logits[..., 0] - nb_shift = nb_shift.unsqueeze(-1) - log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift # (num_hyps, vocab_size-1) - log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) - log_probs.add_(ys_log_probs) + # For blank symbol, log-prob is log-sigmoid of the score + logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) + # Additionally, to ensure the the probs of blank and non-blank sum to 1, we + # need to add the following term to the log-probs of non-blank symbols. This + # is equivalent to log(1 - sigmoid(logits[..., 0])). + # breakpoint() + nb_shift = logp_b - logits[..., 0] + nb_shift = nb_shift.unsqueeze(-1) + log_probs1 = (logits[..., 1:] / temperature).log_softmax( + dim=-1 + ) + nb_shift # (num_hyps, vocab_size-1) + log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) + log_probs.add_(ys_log_probs) else: - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - log_probs.add_(ys_log_probs) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) + log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -899,7 +902,6 @@ def modified_beam_search( ) ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - # st current_encoder_out_st = torch.index_select( current_encoder_out_st, @@ -928,9 +930,11 @@ def modified_beam_search( log_probs_st = torch.cat((logp_b_st.unsqueeze(-1), log_probs1_st), dim=-1) log_probs_st.add_(ys_log_probs_st) else: - log_probs_st = (logits_st / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs_st = (logits_st / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs_st.add_(ys_log_probs_st) - + vocab_size_st = log_probs_st.size(-1) log_probs_st = log_probs_st.reshape(-1) @@ -940,7 +944,9 @@ def modified_beam_search( log_probs_shape_st = k2.ragged.create_ragged_shape2( row_splits=row_splits_st, cached_tot_size=log_probs_st.numel() ) - ragged_log_probs_st = k2.RaggedTensor(shape=log_probs_shape_st, value=log_probs_st) + ragged_log_probs_st = k2.RaggedTensor( + shape=log_probs_shape_st, value=log_probs_st + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1092,7 +1098,7 @@ def modified_beam_search_st( current_encoder_out_st = current_encoder_out_st.unsqueeze(1).unsqueeze(1) # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) offset = end - + finalized_B_st = B_st[batch_size:] + finalized_B_st B_st = B_st[:batch_size] @@ -1141,9 +1147,11 @@ def modified_beam_search_st( log_probs_st = torch.cat((logp_b_st.unsqueeze(-1), log_probs1_st), dim=-1) log_probs_st.add_(ys_log_probs_st) else: - log_probs_st = (logits_st / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs_st = (logits_st / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs_st.add_(ys_log_probs_st) - + vocab_size_st = log_probs_st.size(-1) log_probs_st = log_probs_st.reshape(-1) @@ -1153,7 +1161,9 @@ def modified_beam_search_st( log_probs_shape_st = k2.ragged.create_ragged_shape2( row_splits=row_splits_st, cached_tot_size=log_probs_st.numel() ) - ragged_log_probs_st = k2.RaggedTensor(shape=log_probs_shape_st, value=log_probs_st) + ragged_log_probs_st = k2.RaggedTensor( + shape=log_probs_shape_st, value=log_probs_st + ) for i in range(batch_size): topk_log_probs_st, topk_indexes_st = ragged_log_probs_st[i].topk(beam) @@ -1180,7 +1190,6 @@ def modified_beam_search_st( ) B_st[i].add(new_hyp_st) - B_st = B_st + finalized_B_st best_hyps_st = [b.get_most_probable(length_norm=True) for b in B_st] @@ -1200,7 +1209,7 @@ def modified_beam_search_st( hyps=[None, ans_st], timestamps=[None, ans_timestamps_st], ) - + def modified_beam_search_st2( model: nn.Module, @@ -1275,7 +1284,7 @@ def modified_beam_search_st2( current_encoder_out_st = current_encoder_out_st.unsqueeze(1).unsqueeze(1) # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) offset = end - + finalized_B_st = B_st[batch_size:] + finalized_B_st B_st = B_st[:batch_size] @@ -1292,7 +1301,7 @@ def modified_beam_search_st2( device=device, dtype=torch.int64, ) # (num_hyps, context_size) - + decoder_out_st = model.st_decoder(decoder_input_st, need_pad=False).unsqueeze(1) decoder_out_st = model.st_joiner.decoder_proj(decoder_out_st) # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) @@ -1326,9 +1335,11 @@ def modified_beam_search_st2( log_probs_st = torch.cat((logp_b_st.unsqueeze(-1), log_probs1_st), dim=-1) log_probs_st.add_(ys_log_probs_st) else: - log_probs_st = (logits_st / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs_st = (logits_st / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs_st.add_(ys_log_probs_st) - + vocab_size_st = log_probs_st.size(-1) log_probs_st = log_probs_st.reshape(-1) @@ -1338,7 +1349,9 @@ def modified_beam_search_st2( log_probs_shape_st = k2.ragged.create_ragged_shape2( row_splits=row_splits_st, cached_tot_size=log_probs_st.numel() ) - ragged_log_probs_st = k2.RaggedTensor(shape=log_probs_shape_st, value=log_probs_st) + ragged_log_probs_st = k2.RaggedTensor( + shape=log_probs_shape_st, value=log_probs_st + ) for i in range(batch_size): topk_log_probs_st, topk_indexes_st = ragged_log_probs_st[i].topk(beam) @@ -1434,20 +1447,24 @@ def modified_beam_search_st2( # For blank symbol, log-prob is log-sigmoid of the score if use_hat == True: - # For blank symbol, log-prob is log-sigmoid of the score - logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) - # Additionally, to ensure the the probs of blank and non-blank sum to 1, we - # need to add the following term to the log-probs of non-blank symbols. This - # is equivalent to log(1 - sigmoid(logits[..., 0])). - #breakpoint() - nb_shift = logp_b - logits[..., 0] - nb_shift = nb_shift.unsqueeze(-1) - log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift # (num_hyps, vocab_size-1) - log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) - log_probs.add_(ys_log_probs) + # For blank symbol, log-prob is log-sigmoid of the score + logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) + # Additionally, to ensure the the probs of blank and non-blank sum to 1, we + # need to add the following term to the log-probs of non-blank symbols. This + # is equivalent to log(1 - sigmoid(logits[..., 0])). + # breakpoint() + nb_shift = logp_b - logits[..., 0] + nb_shift = nb_shift.unsqueeze(-1) + log_probs1 = (logits[..., 1:] / temperature).log_softmax( + dim=-1 + ) + nb_shift # (num_hyps, vocab_size-1) + log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) + log_probs.add_(ys_log_probs) else: - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - log_probs.add_(ys_log_probs) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) + log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1458,7 +1475,7 @@ def modified_beam_search_st2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - + for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) with warnings.catch_warnings(): @@ -1495,7 +1512,6 @@ def modified_beam_search_st2( ans.append(sorted_ans[unsorted_indices[i]]) ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - B_st = B_st + finalized_B_st best_hyps_st = [b.get_most_probable(length_norm=True) for b in B_st] @@ -1593,7 +1609,7 @@ def modified_beam_search_st2_lstm( offset = end # breakpoint() finalized_B_st = B_st[batch_size:] + finalized_B_st - B_st = B_st[:batch_size] + B_st = B_st[:batch_size] hyps_shape_st = get_hyps_shape(B_st).to(device) A_st = [list(b) for b in B_st] @@ -1603,17 +1619,25 @@ def modified_beam_search_st2_lstm( ) # (num_hyps, 1) decoder_input_st = torch.tensor( - [hyp.ys[-1] for hyps in A_st for hyp in hyps], - device=device, - dtype=torch.int64, - ).unsqueeze(1) # (num_hyps, 1) - + [hyp.ys[-1] for hyps in A_st for hyp in hyps], + device=device, + dtype=torch.int64, + ).unsqueeze( + 1 + ) # (num_hyps, 1) + st_decode_state = [hyp.decoder_state[-1] for hyps in A_st for hyp in hyps] - cur_h_states_st = torch.cat([state[0] for state in st_decode_state], dim=1) # (num_layers, batch, hidden_dim) - cur_c_states_st = torch.cat([state[1] for state in st_decode_state], dim=1) # (num_layers, batch, hidden_dim) + cur_h_states_st = torch.cat( + [state[0] for state in st_decode_state], dim=1 + ) # (num_layers, batch, hidden_dim) + cur_c_states_st = torch.cat( + [state[1] for state in st_decode_state], dim=1 + ) # (num_layers, batch, hidden_dim) st_decode_state = (cur_h_states_st, cur_c_states_st) - - decoder_out_st, (new_h_states, new_c_states) = model.st_decoder(decoder_input_st, st_decode_state, need_pad=False) + + decoder_out_st, (new_h_states, new_c_states) = model.st_decoder( + decoder_input_st, st_decode_state, need_pad=False + ) decoder_out_st = decoder_out_st.unsqueeze(1) decoder_out_st = model.st_joiner.decoder_proj(decoder_out_st) # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) @@ -1645,9 +1669,11 @@ def modified_beam_search_st2_lstm( log_probs_st = torch.cat((logp_b_st.unsqueeze(-1), log_probs1_st), dim=-1) log_probs_st.add_(ys_log_probs_st) else: - log_probs_st = (logits_st / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs_st = (logits_st / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs_st.add_(ys_log_probs_st) - + vocab_size_st = log_probs_st.size(-1) log_probs_st = log_probs_st.reshape(-1) @@ -1657,7 +1683,9 @@ def modified_beam_search_st2_lstm( log_probs_shape_st = k2.ragged.create_ragged_shape2( row_splits=row_splits_st, cached_tot_size=log_probs_st.numel() ) - ragged_log_probs_st = k2.RaggedTensor(shape=log_probs_shape_st, value=log_probs_st) + ragged_log_probs_st = k2.RaggedTensor( + shape=log_probs_shape_st, value=log_probs_st + ) for i in range(batch_size): topk_log_probs_st, topk_indexes_st = ragged_log_probs_st[i].topk(beam) @@ -1679,14 +1707,14 @@ def modified_beam_search_st2_lstm( new_timestamp_st.append(t) new_log_prob_st = topk_log_probs_st[k] - - # Hypothesis( - # ys=[blank_id_st], - # log_prob=torch.zeros(1, dtype=torch.float32, device=device), - # timestamp=[], - # decoder_state=[(h, c)], - # ) - # LSTM decoder_state + + # Hypothesis( + # ys=[blank_id_st], + # log_prob=torch.zeros(1, dtype=torch.float32, device=device), + # timestamp=[], + # decoder_state=[(h, c)], + # ) + # LSTM decoder_state if new_token == blank_id_st: new_hyp_st = Hypothesis( ys=new_ys_st, @@ -1700,10 +1728,12 @@ def modified_beam_search_st2_lstm( ys=new_ys_st, log_prob=new_log_prob_st, timestamp=new_timestamp_st, - decoder_state=[( - new_h_states[:, hyp_idx: hyp_idx + 1, :], - new_c_states[:, hyp_idx: hyp_idx + 1, :], - )], + decoder_state=[ + ( + new_h_states[:, hyp_idx : hyp_idx + 1, :], + new_c_states[:, hyp_idx : hyp_idx + 1, :], + ) + ], ) # new_hyp_st = Hypothesis( # ys=new_ys_st, log_prob=new_log_prob_st, timestamp=new_timestamp_st @@ -1780,20 +1810,24 @@ def modified_beam_search_st2_lstm( # For blank symbol, log-prob is log-sigmoid of the score if use_hat == True: - # For blank symbol, log-prob is log-sigmoid of the score - logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) - # Additionally, to ensure the the probs of blank and non-blank sum to 1, we - # need to add the following term to the log-probs of non-blank symbols. This - # is equivalent to log(1 - sigmoid(logits[..., 0])). - #breakpoint() - nb_shift = logp_b - logits[..., 0] - nb_shift = nb_shift.unsqueeze(-1) - log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift # (num_hyps, vocab_size-1) - log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) - log_probs.add_(ys_log_probs) + # For blank symbol, log-prob is log-sigmoid of the score + logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) + # Additionally, to ensure the the probs of blank and non-blank sum to 1, we + # need to add the following term to the log-probs of non-blank symbols. This + # is equivalent to log(1 - sigmoid(logits[..., 0])). + # breakpoint() + nb_shift = logp_b - logits[..., 0] + nb_shift = nb_shift.unsqueeze(-1) + log_probs1 = (logits[..., 1:] / temperature).log_softmax( + dim=-1 + ) + nb_shift # (num_hyps, vocab_size-1) + log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) + log_probs.add_(ys_log_probs) else: - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - log_probs.add_(ys_log_probs) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) + log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1804,7 +1838,7 @@ def modified_beam_search_st2_lstm( row_splits=row_splits, cached_tot_size=log_probs.numel() ) ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - + for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) with warnings.catch_warnings(): @@ -1841,7 +1875,6 @@ def modified_beam_search_st2_lstm( ans.append(sorted_ans[unsorted_indices[i]]) ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - B_st = B_st + finalized_B_st best_hyps_st = [b.get_most_probable(length_norm=True) for b in B_st] if st_lstm_pred: @@ -2144,7 +2177,7 @@ def modified_beam_search_lm_shallow_fusion( hyps=ans, timestamps=ans_timestamps, ) - + def modified_beam_search_lm_rescore_LODR( model: nn.Module, @@ -2241,7 +2274,7 @@ def modified_beam_search_lm_rescore_LODR( device=device, dtype=torch.int64, ) # (num_hyps, context_size) - + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) decoder_out = model.joiner.decoder_proj(decoder_out) # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) @@ -2268,9 +2301,9 @@ def modified_beam_search_lm_rescore_LODR( # is equivalent to log(1 - sigmoid(logits[..., 0])). nb_shift = logp_b - logits[..., 0] nb_shift = nb_shift.unsqueeze(-1) - log_probs1 = (logits[..., 1:]/ temperature).log_softmax(dim=-1) + nb_shift + log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift - #log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + # log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) log_probs.add_(ys_log_probs) @@ -2677,4 +2710,4 @@ def modified_beam_search_LODR( for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - return ans \ No newline at end of file + return ans diff --git a/egs/multi_conv_zh_es_ta/ST/hent_srt/decode.py b/egs/multi_conv_zh_es_ta/ST/hent_srt/decode.py index 0cb1afd42c..70b8716624 100755 --- a/egs/multi_conv_zh_es_ta/ST/hent_srt/decode.py +++ b/egs/multi_conv_zh_es_ta/ST/hent_srt/decode.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # Copyright 2024-2025 Johns Hopkins University (Author: Amir Hussein) -# +# # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -90,6 +90,8 @@ import logging import math import os +import re +import string from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -100,16 +102,16 @@ import torch.nn as nn from asr_datamodule import MultiLingAsrDataModule from beam_search import ( - greedy_search_st, - greedy_search_batch_st, greedy_search_batch, - modified_beam_search_st, - modified_beam_search_st2, - modified_beam_search_st2_lstm, + greedy_search_batch_st, + greedy_search_st, modified_beam_search, - modified_beam_search_lm_shallow_fusion, modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, modified_beam_search_LODR, + modified_beam_search_st, + modified_beam_search_st2, + modified_beam_search_st2_lstm, ) from train import add_model_arguments, get_model, get_params @@ -130,33 +132,36 @@ str2bool, write_error_stats, ) -import string -import re LOG_EPS = math.log(1e-10) -def remove_punc(text, replacement_char='_'): + +def remove_punc(text, replacement_char="_"): """This function removes all English punctuations except the single quote (verbatim).""" english_punctuations = string.punctuation + "¿¡" - english_punctuations = english_punctuations.replace("'",'') - #english_punctuations = ''.join(c for c in string.punctuation if c != "'") + english_punctuations = english_punctuations.replace("'", "") + # english_punctuations = ''.join(c for c in string.punctuation if c != "'") # Create a translation table that maps each punctuation to the replacement character. - translator = str.maketrans(english_punctuations, replacement_char * len(english_punctuations)) - + translator = str.maketrans( + english_punctuations, replacement_char * len(english_punctuations) + ) + # Translate the text using the translation table text = text.translate(translator) - text = text.replace('_','') - + text = text.replace("_", "") + return text + def clean(text): text = remove_punc(text) text = text.lower() - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"\s+", " ", text) text = text.rstrip() return text + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -221,11 +226,7 @@ def get_parser(): help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( - "--dev-lang", - type=str, - default=None, - help="""dev language for evaluation""" - + "--dev-lang", type=str, default=None, help="""dev language for evaluation""" ) parser.add_argument( @@ -379,7 +380,7 @@ def get_parser(): Used only when `--use-shallow-fusion` is set to True. """, ) - + parser.add_argument( "--use-hat-decode", type=str2bool, @@ -514,7 +515,12 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens = model.forward_encoder(feature, feature_lens) + ( + encoder_out, + encoder_out_lens, + st_encoder_out, + st_encoder_out_lens, + ) = model.forward_encoder(feature, feature_lens) hyps = [] @@ -525,7 +531,7 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, st_encoder_out=st_encoder_out, st_encoder_out_lens=st_encoder_out_lens, - st_blank_penalty=params.st_blank_penalty + st_blank_penalty=params.st_blank_penalty, ) for hyp, hyp_st in zip(sp.decode(hyp_tokens[0]), sp_st.decode(hyp_tokens[1])): hyps.append([hyp.split(), hyp_st.split()]) @@ -538,10 +544,10 @@ def decode_one_batch( st_encoder_out_lens=st_encoder_out_lens, beam=params.beam_size, use_hat=params.use_hat_decode, - st_blank_penalty =params.st_blank_penalty + st_blank_penalty=params.st_blank_penalty, ) for hyp, hyp_st in zip(sp.decode(hyp_tokens[0]), sp_st.decode(hyp_tokens[1])): - + hyps.append([hyp.split(), hyp_st.split()]) elif params.decoding_method == "modified_beam_search_lstm": hyp_tokens = modified_beam_search_st2_lstm( @@ -554,9 +560,9 @@ def decode_one_batch( use_hat=params.use_hat_decode, ) for hyp, hyp_st in zip(sp.decode(hyp_tokens[0]), sp_st.decode(hyp_tokens[1])): - + hyps.append([hyp.split(), hyp_st.split()]) - + elif params.decoding_method == "modified_beam_search_LODR": hyp_tokens = modified_beam_search_LODR( model=model, @@ -584,8 +590,8 @@ def decode_one_batch( hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": - lm_scale_list = [0.05 * i for i in range(4, 10)] - hyp_tokens = modified_beam_search_lm_rescore_LODR( + lm_scale_list = [0.05 * i for i in range(4, 10)] + hyp_tokens = modified_beam_search_lm_rescore_LODR( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, @@ -595,7 +601,7 @@ def decode_one_batch( sp=sp, lm_scale_list=lm_scale_list, ) - for hyp in sp.decode(hyp_tokens): + for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) else: @@ -612,7 +618,7 @@ def decode_one_batch( encoder_out=encoder_out_i, st_encoder_out=encoder_out_st_i, max_sym_per_frame=params.max_sym_per_frame, - st_blank_penalty =params.st_blank_penalty + st_blank_penalty=params.st_blank_penalty, ) # elif params.decoding_method == "beam_search": # hyp = beam_search( @@ -690,7 +696,7 @@ def decode_dataset( results_st = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] - texts_st = batch["supervisions"]['tgt_text']['en'] + texts_st = batch["supervisions"]["tgt_text"]["en"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -709,8 +715,10 @@ def decode_dataset( this_batch = [] this_batch_st = [] assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text, ref_text_st in zip(cut_ids, hyps, texts, texts_st): - + for cut_id, hyp_words, ref_text, ref_text_st in zip( + cut_ids, hyps, texts, texts_st + ): + if params.clean: tmp_hyp = " ".join(hyp_words[0]) tmp_hyp = clean(tmp_hyp) @@ -726,7 +734,9 @@ def decode_dataset( ref_words_st = ref_text_st.split() if ref_words_st: this_batch.append((cut_id, ref_words, hyp_words_asr)) - this_batch_st.append((cut_id, ref_words, ref_words_st, hyp_words_st)) + this_batch_st.append( + (cut_id, ref_words, ref_words_st, hyp_words_st) + ) results_asr[name].extend(this_batch) results_st[name].extend(this_batch_st) @@ -790,9 +800,7 @@ def save_st_results( results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"{test_set_name}-{key}-{params.suffix}.txt" results = sorted(results) store_translations(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -816,7 +824,6 @@ def main(): "modified_beam_search_LODR", "modified_beam_search", "modified_beam_search_lstm", - ) params.res_dir = params.exp_dir / params.decoding_method @@ -974,12 +981,10 @@ def main(): model.eval() # only load the neural network LM if required - if ( - params.use_shallow_fusion - or params.decoding_method in ( - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - "modified_beam_search_lm_rescore_LODR",) + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + "modified_beam_search_lm_rescore_LODR", ): LM = LmScorer( lm_type=params.lm_type, @@ -1065,8 +1070,8 @@ def main(): dev_all = multiling.dev_all_cuts() if params.dev_lang: logging.info(f"Evaluating on dev") - dev_sp = dev_all.filter(lambda c: c.supervisions[0].language[0]=='es') - dev_zh = dev_all.filter(lambda c: c.supervisions[0].language[0]=='zh') + dev_sp = dev_all.filter(lambda c: c.supervisions[0].language[0] == "es") + dev_zh = dev_all.filter(lambda c: c.supervisions[0].language[0] == "zh") # # test_fisher_cs = multling.test_fisher_cs() # test_ta1 = multiling.ta_test_cuts() # dev_ta = multiling.ta_dev_cuts() @@ -1078,8 +1083,8 @@ def main(): # # test_callhome_dl = multiling.test_dataloaders(test_callhome) # # test_fisher_cs_dl = multling.test_dataloaders(test_fisher_cs) - # dev_sp_dl = multiling.test_dataloaders(dev_sp) - # dev_zh_dl = multiling.test_dataloaders(dev_zh) + # dev_sp_dl = multiling.test_dataloaders(dev_sp) + # dev_zh_dl = multiling.test_dataloaders(dev_zh) # test_sets = ["test-fisher", "test-iwslt22","test-bnn", "test-callhome", "test-fisher-cs"] test_sets = ["test-fisher", "iwslt-ta", "test-bnn"] # test_sets = ["dev_ta","dev_sp", "dev_zh"] @@ -1120,4 +1125,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/egs/multi_conv_zh_es_ta/ST/hent_srt/decoder.py b/egs/multi_conv_zh_es_ta/ST/hent_srt/decoder.py index ffbb7c93e6..3e1b0b4b65 100644 --- a/egs/multi_conv_zh_es_ta/ST/hent_srt/decoder.py +++ b/egs/multi_conv_zh_es_ta/ST/hent_srt/decoder.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Tuple + import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence - from scaling import Balancer - +from torch.nn.utils.rnn import pad_sequence # Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, # Zengrui Jin, @@ -40,13 +40,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F -from scaling import Balancer - class LSTMDecoder(nn.Module): """LSTM decoder.""" @@ -109,7 +102,7 @@ def __init__( batch_first=True, dropout=rnn_dropout, ) - + self.balancer2 = Balancer( decoder_dim, channel_dim=-1, @@ -124,7 +117,7 @@ def forward( self, y: torch.Tensor, states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - need_pad: bool = False + need_pad: bool = False, ) -> torch.Tensor: """ Args: @@ -142,7 +135,9 @@ def forward( embedding_out = self.balancer(embedding_out) if need_pad is True: - embedding_out = pad_sequence(embedding_out, batch_first=True, padding_value=0) + embedding_out = pad_sequence( + embedding_out, batch_first=True, padding_value=0 + ) rnn_out, (h, c) = self.rnn(embedding_out, states) diff --git a/egs/multi_conv_zh_es_ta/ST/hent_srt/export.py b/egs/multi_conv_zh_es_ta/ST/hent_srt/export.py index 075065f49a..c0d49b938e 100755 --- a/egs/multi_conv_zh_es_ta/ST/hent_srt/export.py +++ b/egs/multi_conv_zh_es_ta/ST/hent_srt/export.py @@ -389,7 +389,12 @@ def forward( features: (N, T, C) feature_lengths: (N,) """ - encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens = model.forward_encoder(feature, feature_lengths) + ( + encoder_out, + encoder_out_lens, + st_encoder_out, + st_encoder_out_lens, + ) = self.model.forward_encoder(features, feature_lengths) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) st_encoder_out = st_encoder_out.permute(1, 0, 2) return encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens @@ -596,8 +601,8 @@ def main(): filename_start=filename_start, filename_end=filename_end, device=device, - - ), strict=False + ), + strict=False, ) model.eval() diff --git a/egs/multi_conv_zh_es_ta/ST/hent_srt/load_pretrained_model.py b/egs/multi_conv_zh_es_ta/ST/hent_srt/load_pretrained_model.py index 83e31e1737..5871ac9a4b 100644 --- a/egs/multi_conv_zh_es_ta/ST/hent_srt/load_pretrained_model.py +++ b/egs/multi_conv_zh_es_ta/ST/hent_srt/load_pretrained_model.py @@ -90,7 +90,7 @@ def get_attr(obj: Any, key: str): >>> assert A.linear.weight is get_attr(A, 'linear.weight') """ - + if key.strip() == "": return obj for k in key.split("."): @@ -98,20 +98,20 @@ def get_attr(obj: Any, key: str): return obj obj = get_attr(model, dst_key) - + src_state = torch.load(path, map_location=map_location) - + if excludes is not None: for e in excludes.split(","): src_state = {k: v for k, v in src_state.items() if not k.startswith(e)} if src_key is not None: - src_state['model'] = { + src_state["model"] = { k[len(src_key) + 1 :]: v - for k, v in src_state['model'].items() + for k, v in src_state["model"].items() if k.startswith(src_key) } dst_state = obj.state_dict() if ignore_init_mismatch: - src_state = filter_state_dict(dst_state, src_state['model']) + src_state = filter_state_dict(dst_state, src_state["model"]) dst_state.update(src_state) - obj.load_state_dict(dst_state) \ No newline at end of file + obj.load_state_dict(dst_state) diff --git a/egs/multi_conv_zh_es_ta/ST/hent_srt/model.py b/egs/multi_conv_zh_es_ta/ST/hent_srt/model.py index 2d57f8f10c..8440f6910b 100644 --- a/egs/multi_conv_zh_es_ta/ST/hent_srt/model.py +++ b/egs/multi_conv_zh_es_ta/ST/hent_srt/model.py @@ -21,11 +21,12 @@ import k2 import torch -from torch import Tensor -from lhotse.dataset import SpecAugment import torch.nn as nn from encoder_interface import EncoderInterface +from lhotse.dataset import SpecAugment from scaling import ScaledLinear +from torch import Tensor + from icefall.utils import add_sos, make_pad_mask, time_warp @@ -48,7 +49,7 @@ def __init__( use_ctc: bool = False, use_st_ctc: bool = False, use_hat: bool = False, - use_lstm_pred:bool=False, + use_lstm_pred: bool = False, ): """A multitask Transducer ASR-ST model with seperate joiners and predictors but shared acoustic encoder. @@ -102,7 +103,7 @@ def __init__( self.decoder = decoder self.joiner = joiner - + self.st_joiner = st_joiner self.st_decoder = st_decoder self.st_encoder = st_encoder @@ -165,17 +166,19 @@ def forward_encoder( src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens, st_input = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out, encoder_out_lens, st_input = self.encoder( + x, x_lens, src_key_padding_mask + ) if self.st_encoder is not None: - st_src_key_padding_mask = make_pad_mask(encoder_out_lens) - - st_encoder_out, st_encoder_out_lens = self.st_encoder( + st_src_key_padding_mask = make_pad_mask(encoder_out_lens) + + st_encoder_out, st_encoder_out_lens = self.st_encoder( st_input, x_lens, src_key_padding_mask ) - st_encoder_out = st_encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + st_encoder_out = st_encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) else: - st_encoder_out_lens = None - st_encoder_out = None + st_encoder_out_lens = None + st_encoder_out = None encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) return encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens @@ -283,7 +286,7 @@ def forward_cr_ctc( cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() return ctc_loss, cr_loss - + def forward_st_cr_ctc( self, st_encoder_out: torch.Tensor, @@ -325,7 +328,7 @@ def forward_st_cr_ctc( # target_lengths=target_lengths.cpu(), # reduction="sum", # ) - + # if not torch.isfinite(st_ctc_loss): # breakpoint() @@ -380,7 +383,6 @@ def forward_st_transducer( part """ # Now for the decoder, i.e., the prediction network - blank_id = self.decoder.blank_id st_blank_id = self.st_decoder.blank_id @@ -394,9 +396,9 @@ def forward_st_transducer( # decoder_out: [B, S + 1, decoder_dim] decoder_out = self.decoder(sos_y_padded) if self.use_lstm_pred: - st_decoder_out, _ = self.st_decoder(st_sos_y_padded) + st_decoder_out, _ = self.st_decoder(st_sos_y_padded) else: - st_decoder_out = self.st_decoder(st_sos_y_padded) + st_decoder_out = self.st_decoder(st_sos_y_padded) # Note: y does not start with SOS # y_padded : [B, S] @@ -424,7 +426,7 @@ def forward_st_transducer( lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - + st_lm = self.simple_st_lm_proj(st_decoder_out) st_am = self.simple_st_am_proj(st_encoder_out) @@ -434,33 +436,32 @@ def forward_st_transducer( # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, - ) - st_simple_loss, (st_px_grad, st_py_grad) = k2.rnnt_loss_smoothed( - lm=st_lm.float(), - am=st_am.float(), - symbols=st_y_padded, - termination_symbol=st_blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=st_boundary, - reduction="sum", - return_grad=True, - ) - + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + st_simple_loss, (st_px_grad, st_py_grad) = k2.rnnt_loss_smoothed( + lm=st_lm.float(), + am=st_am.float(), + symbols=st_y_padded, + termination_symbol=st_blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=st_boundary, + reduction="sum", + return_grad=True, + ) # am_pruned : [B, T, prune_range, encoder_dim] # lm_pruned : [B, T, prune_range, decoder_dim] - + # ranges : [B, T, prune_range] ranges = k2.get_rnnt_prune_ranges( px_grad=px_grad, @@ -473,26 +474,26 @@ def forward_st_transducer( lm=self.joiner.decoder_proj(decoder_out), ranges=ranges, ) - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - use_hat_loss=self.use_hat, - ) - # logits : [B, T, prune_range, vocab_size] - + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + use_hat_loss=self.use_hat, + ) + # logits : [B, T, prune_range, vocab_size] + st_ranges = k2.get_rnnt_prune_ranges( - px_grad=st_px_grad, - py_grad=st_py_grad, - boundary=st_boundary, - s_range=st_prune_range, + px_grad=st_px_grad, + py_grad=st_py_grad, + boundary=st_boundary, + s_range=st_prune_range, ) st_am_pruned, st_lm_pruned = k2.do_rnnt_pruning( am=self.st_joiner.encoder_proj(st_encoder_out), @@ -503,18 +504,19 @@ def forward_st_transducer( st_logits = self.st_joiner(st_am_pruned, st_lm_pruned, project_input=False) # Compute HAT loss for st with torch.cuda.amp.autocast(enabled=False): - pruned_st_loss = k2.rnnt_loss_pruned( - logits=st_logits.float(), - symbols=st_y.pad(mode="constant", padding_value=blank_id).to(torch.int64), - ranges=st_ranges, - termination_symbol=st_blank_id, - boundary=st_boundary, - reduction="sum", - use_hat_loss=self.use_hat, - ) - - return simple_loss, st_simple_loss, pruned_loss, pruned_st_loss + pruned_st_loss = k2.rnnt_loss_pruned( + logits=st_logits.float(), + symbols=st_y.pad(mode="constant", padding_value=blank_id).to( + torch.int64 + ), + ranges=st_ranges, + termination_symbol=st_blank_id, + boundary=st_boundary, + reduction="sum", + use_hat_loss=self.use_hat, + ) + return simple_loss, st_simple_loss, pruned_loss, pruned_st_loss def forward_transducer( self, @@ -546,7 +548,6 @@ def forward_transducer( part """ # Now for the decoder, i.e., the prediction network - blank_id = self.decoder.blank_id sos_y = add_sos(y, sos_id=blank_id) @@ -578,22 +579,21 @@ def forward_transducer( # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, - ) - + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) # am_pruned : [B, T, prune_range, encoder_dim] # lm_pruned : [B, T, prune_range, decoder_dim] - + # ranges : [B, T, prune_range] ranges = k2.get_rnnt_prune_ranges( px_grad=px_grad, @@ -606,20 +606,20 @@ def forward_transducer( lm=self.joiner.decoder_proj(decoder_out), ranges=ranges, ) - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - use_hat_loss=self.use_hat, - ) - # logits : [B, T, prune_range, vocab_size] + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + use_hat_loss=self.use_hat, + ) + # logits : [B, T, prune_range, vocab_size] return simple_loss, pruned_loss def forward( @@ -629,7 +629,7 @@ def forward( y: k2.RaggedTensor, st_y: k2.RaggedTensor, prune_range: int = 5, - st_prune_range: int =10, + st_prune_range: int = 10, am_scale: float = 0.0, lm_scale: float = 0.0, use_st_cr_ctc: bool = False, @@ -673,7 +673,7 @@ def forward( Parameter for the time warping; larger values mean more warping. Set to ``None``, or less than ``1``, to disable. Used only if use_cr_ctc is True. - + Returns: Return the transducer losses and CTC loss, in form of (simple_loss, pruned_loss, ctc_loss) @@ -691,26 +691,31 @@ def forward( assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) device = x.device if use_st_cr_ctc or use_asr_cr_ctc: - assert self.use_ctc or self.use_st_ctc - if use_spec_aug: - assert spec_augment is not None and spec_augment.time_warp_factor < 1 - # Apply time warping before input duplicating - assert supervision_segments is not None - x = time_warp( - x, - time_warp_factor=time_warp_factor, - supervision_segments=supervision_segments, - ) - # Independently apply frequency masking and time masking to the two copies - x = spec_augment(x.repeat(2, 1, 1)) - else: - x = x.repeat(2, 1, 1) - x_lens = x_lens.repeat(2) - y = k2.ragged.cat([y, y], axis=0) - if self.st_joiner != None and self.use_st_ctc: - st_y = k2.ragged.cat([st_y, st_y], axis=0) + assert self.use_ctc or self.use_st_ctc + if use_spec_aug: + assert spec_augment is not None and spec_augment.time_warp_factor < 1 + # Apply time warping before input duplicating + assert supervision_segments is not None + x = time_warp( + x, + time_warp_factor=time_warp_factor, + supervision_segments=supervision_segments, + ) + # Independently apply frequency masking and time masking to the two copies + x = spec_augment(x.repeat(2, 1, 1)) + else: + x = x.repeat(2, 1, 1) + x_lens = x_lens.repeat(2) + y = k2.ragged.cat([y, y], axis=0) + if self.st_joiner != None and self.use_st_ctc: + st_y = k2.ragged.cat([st_y, st_y], axis=0) # Compute encoder outputs - encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens = self.forward_encoder(x, x_lens) + ( + encoder_out, + encoder_out_lens, + st_encoder_out, + st_encoder_out_lens, + ) = self.forward_encoder(x, x_lens) row_splits = y.shape.row_splits(1) y_lens = row_splits[1:] - row_splits[:-1] @@ -720,38 +725,43 @@ def forward( if self.use_transducer: # Compute transducer loss if self.st_joiner != None: - simple_loss, st_simple_loss, pruned_loss, st_pruned_loss = self.forward_st_transducer( - st_encoder_out=st_encoder_out, - st_encoder_out_lens=st_encoder_out_lens, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - y=y.to(x.device), - y_lens=y_lens, - st_y=st_y.to(x.device), - st_y_lens=st_y_lens, - prune_range=st_prune_range, - am_scale=am_scale, - lm_scale=lm_scale, - ) - if use_asr_cr_ctc: - simple_loss = simple_loss * 0.5 - pruned_loss = pruned_loss * 0.5 - if use_st_cr_ctc: - st_simple_loss = st_simple_loss * 0.5 - st_pruned_loss = st_pruned_loss * 0.5 + ( + simple_loss, + st_simple_loss, + pruned_loss, + st_pruned_loss, + ) = self.forward_st_transducer( + st_encoder_out=st_encoder_out, + st_encoder_out_lens=st_encoder_out_lens, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(x.device), + y_lens=y_lens, + st_y=st_y.to(x.device), + st_y_lens=st_y_lens, + prune_range=st_prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) + if use_asr_cr_ctc: + simple_loss = simple_loss * 0.5 + pruned_loss = pruned_loss * 0.5 + if use_st_cr_ctc: + st_simple_loss = st_simple_loss * 0.5 + st_pruned_loss = st_pruned_loss * 0.5 else: simple_loss, pruned_loss = self.forward_transducer( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - y=y.to(x.device), - y_lens=y_lens, - prune_range=prune_range, - am_scale=am_scale, - lm_scale=lm_scale, - ) + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(x.device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) if use_asr_cr_ctc: - simple_loss = simple_loss * 0.5 - pruned_loss = pruned_loss * 0.5 + simple_loss = simple_loss * 0.5 + pruned_loss = pruned_loss * 0.5 st_simple_loss, st_pruned_loss = torch.empty(0), torch.empty(0) else: simple_loss = torch.empty(0) @@ -808,4 +818,13 @@ def forward( st_ctc_loss = torch.empty(0) st_cr_loss = torch.empty(0) - return simple_loss, st_simple_loss, pruned_loss, st_pruned_loss, ctc_loss, st_ctc_loss, cr_loss, st_cr_loss + return ( + simple_loss, + st_simple_loss, + pruned_loss, + st_pruned_loss, + ctc_loss, + st_ctc_loss, + cr_loss, + st_cr_loss, + ) diff --git a/egs/multi_conv_zh_es_ta/ST/hent_srt/streaming_beam_search.py b/egs/multi_conv_zh_es_ta/ST/hent_srt/streaming_beam_search.py index 56aa926641..ebaa362d66 100644 --- a/egs/multi_conv_zh_es_ta/ST/hent_srt/streaming_beam_search.py +++ b/egs/multi_conv_zh_es_ta/ST/hent_srt/streaming_beam_search.py @@ -59,8 +59,8 @@ def greedy_search_st( context_size = model.decoder.context_size unk_id = getattr(model, "unk_id", blank_id_st) - #ST - + # ST + decoder_input_st = torch.tensor( [stream.hyp_st[-context_size_st:] for stream in streams], device=device, @@ -69,7 +69,7 @@ def greedy_search_st( # decoder_out is of shape (N, 1, decoder_out_dim) decoder_out_st = model.st_decoder(decoder_input_st, need_pad=False) decoder_out_st = model.st_joiner.decoder_proj(decoder_out_st) - + # ASR decoder_input = torch.tensor( [stream.hyp_asr[-context_size:] for stream in streams], @@ -111,7 +111,7 @@ def greedy_search_st( for i, v in enumerate(y_st): if v not in (blank_id_st, unk_id_st): streams[i].hyp_st.append(v) - + # update decoder output # decoder_input_st = torch.tensor( # [stream.hyp_st[-context_size_st:].reshape( @@ -119,10 +119,17 @@ def greedy_search_st( # device=device, # dtype=torch.int64, # ) - decoder_input_st = torch.stack([ - torch.tensor(stream.hyp_st[-context_size_st:], device=device, dtype=torch.int64) - for stream in streams]).reshape(len(streams), context_size_st) - + decoder_input_st = torch.stack( + [ + torch.tensor( + stream.hyp_st[-context_size_st:], + device=device, + dtype=torch.int64, + ) + for stream in streams + ] + ).reshape(len(streams), context_size_st) + decoder_out_st = model.st_decoder( decoder_input_st, need_pad=False, @@ -241,4 +248,4 @@ def modified_beam_search( B[i].add(new_hyp) for i in range(batch_size): - streams[i].hyps = B[i] \ No newline at end of file + streams[i].hyps = B[i] diff --git a/egs/multi_conv_zh_es_ta/ST/hent_srt/train.py b/egs/multi_conv_zh_es_ta/ST/hent_srt/train.py index 0410a54ef3..59171dd46d 100755 --- a/egs/multi_conv_zh_es_ta/ST/hent_srt/train.py +++ b/egs/multi_conv_zh_es_ta/ST/hent_srt/train.py @@ -59,11 +59,9 @@ import logging import warnings from pathlib import Path -from load_pretrained_model import load_pretrained_model from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union -from torch.optim import Optimizer -from lhotse.dataset import SpecAugment + import k2 import optim import sentencepiece as spm @@ -74,8 +72,10 @@ from decoder import Decoder, LSTMDecoder from joiner import Joiner from lhotse.cut import Cut +from lhotse.dataset import SpecAugment from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed +from load_pretrained_model import load_pretrained_model from model import HENT_SRT from optim import Eden, ScaledAdam from scaling import ScheduledFloat @@ -83,6 +83,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -177,7 +178,6 @@ def add_model_arguments(parser: argparse.ArgumentParser): default="12", help="Value dimension per head in encoder stacks: a single int or comma-separated list.", ) - parser.add_argument( "--pos-head-dim", @@ -197,7 +197,6 @@ def add_model_arguments(parser: argparse.ArgumentParser): default=2, help="ST encoder output downsampling factor", ) - parser.add_argument( "--pos-dim", @@ -335,7 +334,8 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--st-value-head-dim", type=str, default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.",) + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) parser.add_argument( "--st-pos-head-dim", type=str, @@ -448,6 +448,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="If True, use CTC head with ST encoder.", ) + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -822,6 +823,7 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module: ) return encoder_embed + def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = Zipformer2( output_downsampling_factor=params.output_downsampling_factor, @@ -841,9 +843,11 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: causal=params.causal, chunk_size=_to_int_tuple(params.chunk_size), left_context_frames=_to_int_tuple(params.left_context_frames), - st_output_layer=-1,) + st_output_layer=-1, + ) return encoder + def get_st_encoder_model(params: AttributeDict) -> nn.Module: st_encoder = Zipformer2( output_downsampling_factor=params.st_output_downsampling_factor, @@ -866,6 +870,7 @@ def get_st_encoder_model(params: AttributeDict) -> nn.Module: ) return st_encoder + def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, @@ -875,6 +880,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: ) return decoder + def get_lstm_decoder_model(params: AttributeDict) -> nn.Module: decoder = LSTMDecoder( vocab_size=params.vocab_size, @@ -890,7 +896,9 @@ def get_lstm_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), #int(params.encoder_dim.split(",")[-1]), + encoder_dim=max( + _to_int_tuple(params.encoder_dim) + ), # int(params.encoder_dim.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -907,15 +915,19 @@ def get_st_decoder_model(params: AttributeDict) -> nn.Module: ) return decoder + def get_st_joiner_model(params: AttributeDict) -> nn.Module: st_joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.st_encoder_dim)), #int(params.st_encoder_dim.split(",")[-1]), + encoder_dim=max( + _to_int_tuple(params.st_encoder_dim) + ), # int(params.st_encoder_dim.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.st_joiner_dim, vocab_size=params.vocab_st_size, ) return st_joiner + def get_model(params: AttributeDict) -> nn.Module: assert params.use_transducer or params.use_ctc, ( f"At least one of them should be True, " @@ -944,7 +956,6 @@ def get_model(params: AttributeDict) -> nn.Module: decoder = None joiner = None - model = HENT_SRT( encoder_embed=encoder_embed, encoder=encoder, @@ -952,8 +963,12 @@ def get_model(params: AttributeDict) -> nn.Module: joiner=joiner, st_joiner=st_joiner, st_decoder=st_decoder, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), #int(params.encoder_dim.split(",")[-1]), - st_encoder_dim=max(_to_int_tuple(params.st_encoder_dim)), #int(params.st_encoder_dim.split(",")[-1]), + encoder_dim=max( + _to_int_tuple(params.encoder_dim) + ), # int(params.encoder_dim.split(",")[-1]), + st_encoder_dim=max( + _to_int_tuple(params.st_encoder_dim) + ), # int(params.st_encoder_dim.split(",")[-1]), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, st_vocab_size=params.vocab_st_size, @@ -962,10 +977,11 @@ def get_model(params: AttributeDict) -> nn.Module: use_st_ctc=params.use_st_ctc, use_hat=params.use_hat, st_encoder=st_encoder, - use_lstm_pred=params.use_lstm_predictor + use_lstm_pred=params.use_lstm_predictor, ) return model + def get_spec_augment(params: AttributeDict) -> SpecAugment: num_frame_masks = int(10 * params.time_mask_ratio) max_frames_mask_fraction = 0.15 * params.time_mask_ratio @@ -983,6 +999,7 @@ def get_spec_augment(params: AttributeDict) -> SpecAugment: ) return spec_augment + def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, @@ -1023,7 +1040,7 @@ def load_checkpoint_if_available( return None assert filename.is_file(), f"{filename} does not exist!" - + saved_params = load_checkpoint( filename, model=model, @@ -1093,16 +1110,17 @@ def save_checkpoint( if params.best_train_epoch == params.cur_epoch: best_train_filename = params.exp_dir / "best-train-loss.pt" copyfile(src=filename, dst=best_train_filename) - + if params.best_valid_epoch == params.cur_epoch: best_valid_filename = params.exp_dir / "best-valid-loss.pt" copyfile(src=filename, dst=best_valid_filename) + def find_bad_length(feats, texts): - - feats_dur = ((feats-7)//2+1)//2 + + feats_dur = ((feats - 7) // 2 + 1) // 2 text_lengths = [len(t) for t in texts] - text_lengths_tensor = torch.tensor(text_lengths, device='cuda:0') + text_lengths_tensor = torch.tensor(text_lengths, device="cuda:0") # Check where text length is greater than duration errors = text_lengths_tensor > feats_dur @@ -1113,7 +1131,8 @@ def find_bad_length(feats, texts): if error_indices: print(f"Feature after subsampling: {feats_dur[error_indices[0]]}") print(f"Tokens: {text_lengths[error_indices[0]]}") - + + def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], @@ -1143,7 +1162,7 @@ def compute_loss( spec_augment: The SpecAugment instance used only when use_cr_ctc is True. """ - + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) @@ -1157,7 +1176,7 @@ def compute_loss( warm_step = params.warm_step texts = batch["supervisions"]["text"] - st_texts = batch["supervisions"]['tgt_text']['en'] + st_texts = batch["supervisions"]["tgt_text"]["en"] texts = [utt.replace("| ", "") for utt in texts] st_y = st_sp.encode(st_texts, out_type=int) st_y = k2.RaggedTensor(st_y) @@ -1165,13 +1184,13 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) if params.st_scale != 1: - alpha_st = params.st_scale - alpha_asr = 1-params.st_scale - + alpha_st = params.st_scale + alpha_asr = 1 - params.st_scale + else: alpha_st, alpha_asr = 1, 1 - - use_asr_cr_ctc, use_st_cr_ctc = params.use_asr_cr_ctc, params.use_st_cr_ctc + + use_asr_cr_ctc, use_st_cr_ctc = params.use_asr_cr_ctc, params.use_st_cr_ctc use_spec_aug = (use_asr_cr_ctc or use_st_cr_ctc) and is_training if use_spec_aug: supervision_intervals = batch["supervisions"] @@ -1187,7 +1206,16 @@ def compute_loss( supervision_segments = None with torch.set_grad_enabled(is_training): # find_bad_length(feature_lens, st_sp.encode(st_texts, out_type=int)) - simple_loss, st_simple_loss, pruned_loss, st_pruned_loss, ctc_loss, st_ctc_loss, cr_loss, st_cr_loss = model( + ( + simple_loss, + st_simple_loss, + pruned_loss, + st_pruned_loss, + ctc_loss, + st_ctc_loss, + cr_loss, + st_cr_loss, + ) = model( x=feature, x_lens=feature_lens, y=y, @@ -1209,13 +1237,12 @@ def compute_loss( if params.use_st_joiner: st_simple_loss_is_finite = torch.isfinite(st_simple_loss) st_pruned_loss_is_finite = torch.isfinite(st_pruned_loss) - + # if use_st_cr_ctc: - - # if not st_ctc_loss_is_finite: - # breakpoint() - + # if not st_ctc_loss_is_finite: + # breakpoint() + is_finite = ( simple_loss_is_finite & pruned_loss_is_finite @@ -1232,22 +1259,18 @@ def compute_loss( ) st_simple_loss = st_simple_loss[st_simple_loss_is_finite] st_pruned_loss = st_pruned_loss[st_pruned_loss_is_finite] - + else: - is_finite = ( - simple_loss_is_finite - & pruned_loss_is_finite - ) + is_finite = simple_loss_is_finite & pruned_loss_is_finite if not torch.all(is_finite): logging.info( "Not all losses are finite!\n" f"simple_losses: {simple_loss}\n" f"pruned_losses: {pruned_loss}\n" ) - + simple_loss = simple_loss[simple_loss_is_finite] pruned_loss = pruned_loss[pruned_loss_is_finite] - if params.use_transducer: s = params.simple_loss_scale @@ -1259,17 +1282,22 @@ def compute_loss( else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss += alpha_asr*(simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss) - #loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss += alpha_asr * ( + simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + ) + # loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if params.use_st_joiner: - - loss += alpha_st* (simple_loss_scale * st_simple_loss + pruned_loss_scale * st_pruned_loss) - #loss += pruned_loss_scale * lid_pruned_loss + + loss += alpha_st * ( + simple_loss_scale * st_simple_loss + + pruned_loss_scale * st_pruned_loss + ) + # loss += pruned_loss_scale * lid_pruned_loss if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss @@ -1294,7 +1322,7 @@ def compute_loss( info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_st_joiner: info["st_simple_loss"] = st_simple_loss.detach().cpu().item() - info["st_pruned_loss"] = st_pruned_loss.detach().cpu().item() + info["st_pruned_loss"] = st_pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.use_asr_cr_ctc: @@ -1624,21 +1652,21 @@ def run(rank, world_size, args): # model.load_state_dict(init_ckpt["model"], strict=True) # missing_keys, unexpected_keys = model.load_state_dict(init_ckpt["model"], strict=True) pretrained_paths = [ - params.model_init_ckpt+":encoder_embed:encoder_embed", - params.model_init_ckpt+":encoder:encoder", # Load ASR encoder - params.model_init_ckpt+":decoder:decoder", - params.model_init_ckpt+":joiner:joiner", - params.model_init_ckpt+":simple_lm_proj:simple_lm_proj", - params.model_init_ckpt+":simple_am_proj:simple_am_proj", - params.model_init_ckpt+":ctc_output:ctc_output", - ] + params.model_init_ckpt + ":encoder_embed:encoder_embed", + params.model_init_ckpt + ":encoder:encoder", # Load ASR encoder + params.model_init_ckpt + ":decoder:decoder", + params.model_init_ckpt + ":joiner:joiner", + params.model_init_ckpt + ":simple_lm_proj:simple_lm_proj", + params.model_init_ckpt + ":simple_am_proj:simple_am_proj", + params.model_init_ckpt + ":ctc_output:ctc_output", + ] for pretrained_path in pretrained_paths: logging.info(f"Loading pretrained params from {pretrained_path}") load_pretrained_model( model=model, init_param=pretrained_path, ignore_init_mismatch=True, # Set to False if you want an error for mismatched layers - map_location=device # Use "cuda" for GPU + map_location=device, # Use "cuda" for GPU ) if params.freeze_main_model: @@ -1672,7 +1700,7 @@ def run(rank, world_size, args): ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - + # if checkpoints and "optimizer" in checkpoints: # logging.info("Loading optimizer state dict") # optimizer.load_state_dict(checkpoints["optimizer"]) @@ -1718,19 +1746,21 @@ def remove_short_and_long_utt(c: Cut): # and S is the number of tokens in the utterance # In ./zipformer.py, the conv module uses the following expression - # for subsampling + # for subsampling if params.use_st_cr_ctc: T = ((c.num_frames - 7) // 2 + 1) // (params.st_output_downsampling_factor) - st_tokens = st_sp.encode(c.supervisions[0].custom['translated_text']['en'], out_type=str) + st_tokens = st_sp.encode( + c.supervisions[0].custom["translated_text"]["en"], out_type=str + ) if T <= len(st_tokens): - # logging.warning( - # f"Exclude cut with ID {c.id} from training. " - # f"Number of frames (before subsampling): {c.num_frames}. " - # f"Number of frames (after subsampling): {T}. " - # f"ST Text: {c.supervisions[0].custom['translated_text']['en']}. " - # f"ST Tokens: {st_tokens}. " - # f"Number of tokens: {len(st_tokens)}" - # ) + # logging.warning( + # f"Exclude cut with ID {c.id} from training. " + # f"Number of frames (before subsampling): {c.num_frames}. " + # f"Number of frames (after subsampling): {T}. " + # f"ST Text: {c.supervisions[0].custom['translated_text']['en']}. " + # f"ST Tokens: {st_tokens}. " + # f"Number of tokens: {len(st_tokens)}" + # ) return False if params.use_asr_cr_ctc: T = ((c.num_frames - 7) // 2 + 1) // (params.output_downsampling_factor) @@ -1748,6 +1778,7 @@ def remove_short_and_long_utt(c: Cut): return False return True + train_cuts = train_cuts.filter(remove_short_and_long_utt) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: @@ -1760,7 +1791,7 @@ def remove_short_and_long_utt(c: Cut): train_dl = multling.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) - + valid_cuts = multling.dev_all_cuts() valid_cuts = valid_cuts.filter(remove_short_and_long_utt) valid_dl = multling.valid_dataloaders(valid_cuts) @@ -1796,14 +1827,14 @@ def remove_short_and_long_utt(c: Cut): optimizer=optimizer, scheduler=scheduler, sp=sp, - st_sp = st_sp, + st_sp=st_sp, train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, tb_writer=tb_writer, world_size=world_size, rank=rank, - spec_augment=spec_augment + spec_augment=spec_augment, ) if params.print_diagnostics: @@ -1917,5 +1948,6 @@ def main(): else: run(rank=0, world_size=1, args=args) + if __name__ == "__main__": main() diff --git a/egs/multi_conv_zh_es_ta/ST/hent_srt/zipformer.py b/egs/multi_conv_zh_es_ta/ST/hent_srt/zipformer.py index 8d7d20571f..32f55ac5e3 100644 --- a/egs/multi_conv_zh_es_ta/ST/hent_srt/zipformer.py +++ b/egs/multi_conv_zh_es_ta/ST/hent_srt/zipformer.py @@ -18,28 +18,33 @@ # limitations under the License. import copy +import logging import math +import random import warnings from typing import List, Optional, Tuple, Union -import logging + import torch -import random from encoder_interface import EncoderInterface from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. +) +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) +from scaling import ( + ActivationDropoutAndLinear, Balancer, BiasNorm, - Dropout2, ChunkCausalDepthwiseConv1d, - ActivationDropoutAndLinear, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + Dropout2, + FloatLike, + ScheduledFloat, Whiten, - Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. + convert_num_channels, + limit_param_value, penalize_abs_values_gt, softmax, - ScheduledFloat, - FloatLike, - limit_param_value, - convert_num_channels, ) from torch import Tensor, nn @@ -144,7 +149,7 @@ def _to_tuple(x): self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) pos_head_dim = _to_tuple(pos_head_dim) - + self.num_heads = num_heads = _to_tuple(num_heads) feedforward_dim = _to_tuple(feedforward_dim) self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) @@ -367,7 +372,6 @@ def forward( return x, lengths, outputs[self.st_output_layer] else: return x, lengths - def _get_attn_mask( self, x: Tensor, chunk_size: int, left_context_chunks: int @@ -462,7 +466,7 @@ def streaming_forward( num_layers = module.num_layers ds = self.downsampling_factor[i] x = convert_num_channels(x, self.encoder_dim[i]) - + x, new_layer_states = module.streaming_forward( x, states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], diff --git a/egs/multi_conv_zh_es_ta/ST/local/cer.py b/egs/multi_conv_zh_es_ta/ST/local/cer.py index fc2a2fb9a8..39c01bc3a5 100644 --- a/egs/multi_conv_zh_es_ta/ST/local/cer.py +++ b/egs/multi_conv_zh_es_ta/ST/local/cer.py @@ -1,15 +1,13 @@ -import argparse -import jiwer -import os +import argparse +import os import re +import jiwer + + def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--dec-file", - type=str, - help="file with decoded text" - ) + parser.add_argument("--dec-file", type=str, help="file with decoded text") return parser @@ -24,50 +22,50 @@ def contains_chinese(text): Returns: bool: True if the string contains at least one Chinese character, False otherwise. """ - chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]') + chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]") return bool(chinese_char_pattern.search(text)) + def cer_(file): hyp = [] ref = [] cer_results = 0 ref_lens = 0 - with open(file, 'r', encoding='utf-8') as dec: - + with open(file, "r", encoding="utf-8") as dec: + for line in dec: - id, target = line.split('\t') + id, target = line.split("\t") id = id[0:-1] target, txt = target.split("=") - - if target == 'ref': - words = txt.strip().strip('[]').split(', ') + + if target == "ref": + words = txt.strip().strip("[]").split(", ") word_list = [word.strip("'") for word in words] # if contains_chinese(" ".join(word_list)): # word_list = [" ".join(re.findall(r".",word.strip("'"))) for word in words] ref.append("".join(word_list)) - elif target == 'hyp': - words = txt.strip().strip('[]').split(', ') - + elif target == "hyp": + words = txt.strip().strip("[]").split(", ") + word_list = [word.strip("'") for word in words] # if contains_chinese(" ".join(word_list)): - + # word_list = ["".join(re.findall(r".",word.strip("'"))) for word in words] hyp.append("".join(word_list)) for h, r in zip(hyp, ref): if r: - cer_results += (jiwer.cer(r, h)*len(r)) - + cer_results += jiwer.cer(r, h) * len(r) + ref_lens += len(r) - #print(os.path.basename(file)) + # print(os.path.basename(file)) print(cer_results / ref_lens) - - def main(): parse = get_args() - args = parse.parse_args() + args = parse.parse_args() cer_(args.dec_file) - + + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/egs/multi_conv_zh_es_ta/ST/local/compute_fbank_gpu.py b/egs/multi_conv_zh_es_ta/ST/local/compute_fbank_gpu.py index 05828b4915..4ce04f3c9c 100755 --- a/egs/multi_conv_zh_es_ta/ST/local/compute_fbank_gpu.py +++ b/egs/multi_conv_zh_es_ta/ST/local/compute_fbank_gpu.py @@ -23,29 +23,29 @@ The generated fbank features are saved in data/fbank. """ +import argparse import logging import os from pathlib import Path -import argparse import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - from lhotse.features.kaldifeat import ( KaldifeatFbank, KaldifeatFbankConfig, KaldifeatFrameOptions, KaldifeatMelOptions, ) +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect # even when we are not invoking the main (e.g. when spawning subprocesses) + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( @@ -81,29 +81,26 @@ def get_args(): def compute_fbank_gpu(args): - src_dir = Path(args.datadir+"/manifests") - output_dir = Path(args.datadir+"/fbank") - num_jobs = min(os.cpu_count(),10) + src_dir = Path(args.datadir + "/manifests") + output_dir = Path(args.datadir + "/fbank") + num_jobs = min(os.cpu_count(), 10) num_mel_bins = 80 sampling_rate = 16000 sr = 16000 logging.info(f"Cpus {num_jobs}") - if args.test: - dataset_parts = ( - "hkust_test", - "iwslt_ta_test", - "fisher-sp_test", - "dev") + dataset_parts = ("hkust_test", "iwslt_ta_test", "fisher-sp_test", "dev") else: - dataset_parts = ( - "train") + dataset_parts = "train" prefix = "cts" suffix = "jsonl.gz" manifests = read_manifests_if_cached( - prefix=prefix, dataset_parts=dataset_parts, output_dir=src_dir,suffix=suffix, + prefix=prefix, + dataset_parts=dataset_parts, + output_dir=src_dir, + suffix=suffix, ) assert manifests is not None @@ -132,15 +129,11 @@ def compute_fbank_gpu(args): cut_set = cut_set.resample(sr) cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, - keep_all_channels=False) - cut_set = cut_set.filter(lambda c: c.duration >= .2 and c.duration <= 30) + keep_overlapping=False, keep_all_channels=False + ) + cut_set = cut_set.filter(lambda c: c.duration >= 0.2 and c.duration <= 30) if "train" in partition: - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, storage_path=f"{output_dir}/{prefix}_feats_{partition}", @@ -149,7 +142,7 @@ def compute_fbank_gpu(args): num_workers=num_jobs, storage_type=LilcomChunkyWriter, overwrite=True, - ) + ) cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz") else: logging.info(f"Processing {partition}") @@ -161,13 +154,12 @@ def compute_fbank_gpu(args): num_workers=num_jobs, storage_type=LilcomChunkyWriter, overwrite=True, - ) + ) cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz") + if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/multi_conv_zh_es_ta/ST/local/cuts_validate.py b/egs/multi_conv_zh_es_ta/ST/local/cuts_validate.py index f5cfb47280..e09ccc5bad 100644 --- a/egs/multi_conv_zh_es_ta/ST/local/cuts_validate.py +++ b/egs/multi_conv_zh_es_ta/ST/local/cuts_validate.py @@ -1,11 +1,11 @@ #!/usr/bin/python -from lhotse import RecordingSet, SupervisionSet, CutSet import argparse import logging -from lhotse.qa import fix_manifests, validate_recordings_and_supervisions import pdb +from lhotse import CutSet, RecordingSet, SupervisionSet +from lhotse.qa import fix_manifests, validate_recordings_and_supervisions def get_parser(): @@ -44,7 +44,7 @@ def get_parser(): def valid_asr(cut): tol = 2e-3 - i=0 + i = 0 total_dur = 0 for c in cut: if c.supervisions != []: @@ -52,10 +52,14 @@ def valid_asr(cut): logging.info(f"Supervision beyond the cut. Cut number: {i}") total_dur += c.duration - logging.info(f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}") + logging.info( + f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}" + ) elif c.supervisions[0].start < -tol: logging.info(f"Supervision starts before the cut. Cut number: {i}") - logging.info(f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}") + logging.info( + f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}" + ) else: continue else: @@ -63,7 +67,7 @@ def valid_asr(cut): logging.info(f"id: {c.id}") i += 1 logging.info(f"filtered duration: {total_dur}") - + def main(): @@ -74,7 +78,7 @@ def main(): else: recordings = RecordingSet.from_file(args.rec) supervisions = SupervisionSet.from_file(args.sup) - # breakpoint() + # breakpoint() logging.info("Example from supervisions:") logging.info(supervisions[0]) logging.info("Example from recordings") @@ -82,8 +86,11 @@ def main(): recordings, supervisions = fix_manifests(recordings, supervisions) logging.info("Validating manifests") validate_recordings_and_supervisions(recordings, supervisions) - - cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,) + + cuts = CutSet.from_manifests( + recordings=recordings, + supervisions=supervisions, + ) cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False) cuts.describe() logging.info("Example from cut:") @@ -93,5 +100,6 @@ def main(): if args.savecut != "": cuts.to_file(args.savecut) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/egs/multi_conv_zh_es_ta/ST/local/prepare_st_transcripts.py b/egs/multi_conv_zh_es_ta/ST/local/prepare_st_transcripts.py index 96abc26f65..67be3860f1 100755 --- a/egs/multi_conv_zh_es_ta/ST/local/prepare_st_transcripts.py +++ b/egs/multi_conv_zh_es_ta/ST/local/prepare_st_transcripts.py @@ -5,12 +5,13 @@ This script prepares transcript_words.txt from cutset """ -from lhotse import CutSet import argparse import logging +import os import pdb from pathlib import Path -import os + +from lhotse import CutSet def get_parser(): @@ -30,7 +31,7 @@ def get_parser(): help="name of the lang-dir", ) return parser - + def main(): @@ -40,14 +41,15 @@ def main(): logging.info("Reading the cuts") cuts = CutSet.from_file(args.cut) langdir = Path(args.langdir) - + if not os.path.exists(langdir): os.makedirs(langdir) - - with open(langdir / "st_words.txt", 'w') as txt: + + with open(langdir / "st_words.txt", "w") as txt: for c in cuts: - text = c.supervisions[0].custom['translated_text']['en'] - txt.write(text + '\n') + text = c.supervisions[0].custom["translated_text"]["en"] + txt.write(text + "\n") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/egs/multi_conv_zh_es_ta/ST/local/prepare_transcripts.py b/egs/multi_conv_zh_es_ta/ST/local/prepare_transcripts.py index a9da2d6954..0aba9f75d7 100755 --- a/egs/multi_conv_zh_es_ta/ST/local/prepare_transcripts.py +++ b/egs/multi_conv_zh_es_ta/ST/local/prepare_transcripts.py @@ -5,12 +5,13 @@ This script prepares transcript_words.txt from cutset """ -from lhotse import CutSet import argparse import logging +import os import pdb from pathlib import Path -import os + +from lhotse import CutSet def get_parser(): @@ -30,7 +31,7 @@ def get_parser(): help="name of the lang-dir", ) return parser - + def main(): @@ -40,15 +41,16 @@ def main(): logging.info("Reading the cuts") cuts = CutSet.from_file(args.cut) langdir = Path(args.langdir) - + if not os.path.exists(langdir): os.makedirs(langdir) - - with open(langdir / "transcript_words.txt", 'w') as txt: + + with open(langdir / "transcript_words.txt", "w") as txt: for c in cuts: - #breakpoint() + # breakpoint() text = c.supervisions[0].text - txt.write(text + '\n') + txt.write(text + "\n") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/asr_datamodule.py b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/asr_datamodule.py index 55ee1872c6..220b39f480 100644 --- a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/asr_datamodule.py +++ b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/asr_datamodule.py @@ -29,8 +29,8 @@ CutConcatenate, CutMix, DynamicBucketingSampler, - K2SpeechRecognitionDataset, K2Speech2TextTranslationDataset, + K2SpeechRecognitionDataset, PrecomputedFeatures, SimpleCutSampler, SpecAugment, @@ -392,30 +392,26 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: @lru_cache() def train_cuts(self) -> CutSet: logging.info("Train data: About to get training cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_train.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_train.jsonl.gz") @lru_cache() def dev_all_cuts(self) -> CutSet: logging.info("Dev data: About to get develop cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_dev.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz") @lru_cache() def test_hkust(self) -> CutSet: logging.info("About to get test-hkust cuts") - return load_manifest_lazy( - self.args.manifest_dir / "cuts_hkust_test.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "cuts_hkust_test.jsonl.gz") + def test_iwslt22(self) -> CutSet: logging.info("About to get test-iwslt22 cuts") return load_manifest_lazy( self.args.manifest_dir / "cuts_iwslt_ta_test.jsonl.gz" ) + def test_fisher(self) -> CutSet: logging.info("About to get test-fisher cuts") return load_manifest_lazy( self.args.manifest_dir / "cuts_fisher-sp_test.jsonl.gz" - ) \ No newline at end of file + ) diff --git a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/beam_search.py b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/beam_search.py index d83bde2631..dc595b2d25 100644 --- a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/beam_search.py +++ b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/beam_search.py @@ -42,7 +42,7 @@ def greedy_search_st( model: nn.Module, encoder_out: torch.Tensor, max_sym_per_frame: int, - max_sym_per_frame_asr:int=1, + max_sym_per_frame_asr: int = 1, st_blank_penalty: float = 0.0, return_timestamps: bool = False, ) -> Union[List[int], DecodingResults]: @@ -70,7 +70,7 @@ def greedy_search_st( blank_id_st = model.st_decoder.blank_id context_size_st = model.st_decoder.context_size unk_id_st = getattr(model, "unk_id", blank_id_st) - + blank_id = model.decoder.blank_id context_size = model.decoder.context_size unk_id = getattr(model, "unk_id", blank_id_st) @@ -80,14 +80,14 @@ def greedy_search_st( decoder_input_st = torch.tensor( [-1] * (context_size_st - 1) + [blank_id_st], device=device, dtype=torch.int64 ).reshape(1, context_size_st) - + decoder_input = torch.tensor( [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 ).reshape(1, context_size) - + decoder_out_st = model.st_decoder(decoder_input_st, need_pad=False) decoder_out_st = model.st_joiner.decoder_proj(decoder_out_st) - encoder_out_st = model.st_joiner.encoder_proj(encoder_out) + encoder_out_st = model.st_joiner.encoder_proj(encoder_out) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -131,9 +131,9 @@ def greedy_search_st( if y_st not in (blank_id_st, unk_id_st): hyp_st.append(y_st) timestamp_st.append(t) - decoder_input_st = torch.tensor([hyp_st[-context_size_st:]], device=device).reshape( - 1, context_size_st - ) + decoder_input_st = torch.tensor( + [hyp_st[-context_size_st:]], device=device + ).reshape(1, context_size_st) decoder_out_st = model.st_decoder(decoder_input_st, need_pad=False) decoder_out_st = model.st_joiner.decoder_proj(decoder_out_st) @@ -143,7 +143,7 @@ def greedy_search_st( else: sym_per_frame = 0 t += 1 - + T = encoder_out.size(1) t = 0 hyp = [blank_id] * context_size @@ -191,7 +191,7 @@ def greedy_search_st( else: sym_per_frame = 0 t += 1 - + hyp_st = hyp_st[context_size_st:] hyp = hyp[context_size:] # remove blanks @@ -202,7 +202,8 @@ def greedy_search_st( hyps=[hyp, hyp_st], timestamps=[timestamp, timestamp_st], ) - + + def greedy_search_batch( model: nn.Module, encoder_out: torch.Tensor, @@ -600,7 +601,7 @@ def modified_beam_search( finalized_B = B[batch_size:] + finalized_B B = B[:batch_size] hyps_shape = get_hyps_shape(B).to(device) - + finalized_B_st = B_st[batch_size:] + finalized_B_st B_st = B_st[:batch_size] @@ -655,20 +656,24 @@ def modified_beam_search( # For blank symbol, log-prob is log-sigmoid of the score if use_hat == True: - # For blank symbol, log-prob is log-sigmoid of the score - logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) - # Additionally, to ensure the the probs of blank and non-blank sum to 1, we - # need to add the following term to the log-probs of non-blank symbols. This - # is equivalent to log(1 - sigmoid(logits[..., 0])). - #breakpoint() - nb_shift = logp_b - logits[..., 0] - nb_shift = nb_shift.unsqueeze(-1) - log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift # (num_hyps, vocab_size-1) - log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) - log_probs.add_(ys_log_probs) + # For blank symbol, log-prob is log-sigmoid of the score + logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) + # Additionally, to ensure the the probs of blank and non-blank sum to 1, we + # need to add the following term to the log-probs of non-blank symbols. This + # is equivalent to log(1 - sigmoid(logits[..., 0])). + # breakpoint() + nb_shift = logp_b - logits[..., 0] + nb_shift = nb_shift.unsqueeze(-1) + log_probs1 = (logits[..., 1:] / temperature).log_softmax( + dim=-1 + ) + nb_shift # (num_hyps, vocab_size-1) + log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) + log_probs.add_(ys_log_probs) else: - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - log_probs.add_(ys_log_probs) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) + log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -680,7 +685,6 @@ def modified_beam_search( ) ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - # st current_encoder_out_st = torch.index_select( current_encoder_out_st, @@ -709,9 +713,11 @@ def modified_beam_search( log_probs_st = torch.cat((logp_b_st.unsqueeze(-1), log_probs1_st), dim=-1) log_probs_st.add_(ys_log_probs_st) else: - log_probs_st = (logits_st / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs_st = (logits_st / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs_st.add_(ys_log_probs_st) - + vocab_size_st = log_probs_st.size(-1) log_probs_st = log_probs_st.reshape(-1) @@ -720,7 +726,9 @@ def modified_beam_search( log_probs_shape_st = k2.ragged.create_ragged_shape2( row_splits=row_splits_st, cached_tot_size=log_probs_st.numel() ) - ragged_log_probs_st = k2.RaggedTensor(shape=log_probs_shape_st, value=log_probs_st) + ragged_log_probs_st = k2.RaggedTensor( + shape=log_probs_shape_st, value=log_probs_st + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1079,7 +1087,7 @@ def modified_beam_search_lm_shallow_fusion( hyps=ans, timestamps=ans_timestamps, ) - + def modified_beam_search_lm_rescore_LODR( model: nn.Module, @@ -1176,7 +1184,7 @@ def modified_beam_search_lm_rescore_LODR( device=device, dtype=torch.int64, ) # (num_hyps, context_size) - + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) decoder_out = model.joiner.decoder_proj(decoder_out) # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) @@ -1203,9 +1211,9 @@ def modified_beam_search_lm_rescore_LODR( # is equivalent to log(1 - sigmoid(logits[..., 0])). nb_shift = logp_b - logits[..., 0] nb_shift = nb_shift.unsqueeze(-1) - log_probs1 = (logits[..., 1:]/ temperature).log_softmax(dim=-1) + nb_shift + log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift - #log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + # log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) log_probs.add_(ys_log_probs) @@ -1612,4 +1620,4 @@ def modified_beam_search_LODR( for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - return ans \ No newline at end of file + return ans diff --git a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/decode.py b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/decode.py index 4a8ce554b9..6428641180 100755 --- a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/decode.py +++ b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/decode.py @@ -99,6 +99,8 @@ import logging import math import os +import re +import string from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -109,11 +111,11 @@ import torch.nn as nn from asr_datamodule import MultiLingAsrDataModule from beam_search import ( - greedy_search_st, greedy_search_batch, + greedy_search_st, modified_beam_search, - modified_beam_search_lm_shallow_fusion, modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, modified_beam_search_LODR, ) from train import add_model_arguments, get_model, get_params @@ -135,33 +137,36 @@ str2bool, write_error_stats, ) -import string -import re LOG_EPS = math.log(1e-10) -def remove_punc(text, replacement_char='_'): + +def remove_punc(text, replacement_char="_"): """This function removes all English punctuations except the single quote (verbatim).""" english_punctuations = string.punctuation + "¿¡" - english_punctuations = english_punctuations.replace("'",'') - #english_punctuations = ''.join(c for c in string.punctuation if c != "'") + english_punctuations = english_punctuations.replace("'", "") + # english_punctuations = ''.join(c for c in string.punctuation if c != "'") # Create a translation table that maps each punctuation to the replacement character. - translator = str.maketrans(english_punctuations, replacement_char * len(english_punctuations)) - + translator = str.maketrans( + english_punctuations, replacement_char * len(english_punctuations) + ) + # Translate the text using the translation table text = text.translate(translator) - text = text.replace('_','') - + text = text.replace("_", "") + return text + def clean(text): text = remove_punc(text) text = text.lower() - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"\s+", " ", text) text = text.rstrip() return text + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -371,15 +376,11 @@ def get_parser(): Used only when `--use-shallow-fusion` is set to True. """, ) - + parser.add_argument( - "--dev-lang", - type=str, - default=None, - help="""dev language for evaluation""" - + "--dev-lang", type=str, default=None, help="""dev language for evaluation""" ) - + parser.add_argument( "--use-hat-decode", type=str2bool, @@ -535,9 +536,9 @@ def decode_one_batch( use_hat=params.use_hat_decode, ) for hyp, hyp_st in zip(sp.decode(hyp_tokens[0]), sp_st.decode(hyp_tokens[1])): - + hyps.append([hyp.split(), hyp_st.split()]) - + elif params.decoding_method == "modified_beam_search_LODR": hyp_tokens = modified_beam_search_LODR( model=model, @@ -565,8 +566,8 @@ def decode_one_batch( hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": - lm_scale_list = [0.05 * i for i in range(4, 10)] - hyp_tokens = modified_beam_search_lm_rescore_LODR( + lm_scale_list = [0.05 * i for i in range(4, 10)] + hyp_tokens = modified_beam_search_lm_rescore_LODR( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, @@ -576,7 +577,7 @@ def decode_one_batch( sp=sp, lm_scale_list=lm_scale_list, ) - for hyp in sp.decode(hyp_tokens): + for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) else: @@ -591,7 +592,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, max_sym_per_frame=params.max_sym_per_frame, - st_blank_penalty =params.st_blank_penalty + st_blank_penalty=params.st_blank_penalty, ) # elif params.decoding_method == "beam_search": # hyp = beam_search( @@ -669,7 +670,7 @@ def decode_dataset( results_st = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] - texts_st = batch["supervisions"]['tgt_text']['en'] + texts_st = batch["supervisions"]["tgt_text"]["en"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, @@ -689,7 +690,9 @@ def decode_dataset( this_batch = [] this_batch_st = [] assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text, ref_text_st in zip(cut_ids, hyps, texts, texts_st): + for cut_id, hyp_words, ref_text, ref_text_st in zip( + cut_ids, hyps, texts, texts_st + ): if params.clean: tmp_hyp = " ".join(hyp_words[0]) tmp_hyp = clean(tmp_hyp) @@ -769,14 +772,15 @@ def save_st_results( results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"{test_set_name}-{key}-{params.suffix}.txt" results = sorted(results) store_translations(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") + import time + + def measure_real_time_factor(decode_fn, *args, **kwargs): """ Measures the real-time factor (RTF) of the decoding function. @@ -784,7 +788,7 @@ def measure_real_time_factor(decode_fn, *args, **kwargs): Args: decode_fn: The decoding function to wrap (e.g., decode_dataset). *args, **kwargs: Arguments passed to the decode function. - + Returns: results: Output of the decode function. rtf: Real-time factor value. @@ -806,9 +810,10 @@ def measure_real_time_factor(decode_fn, *args, **kwargs): print(f"Total audio duration: {total_duration:.2f}s") print(f"Total decoding time: {decoding_time:.2f}s") print(f"Real-time factor (RTF): {rtf:.4f}") - + return results, rtf + @torch.no_grad() def main(): parser = get_parser() @@ -981,12 +986,10 @@ def main(): model.eval() # only load the neural network LM if required - if ( - params.use_shallow_fusion - or params.decoding_method in ( - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - "modified_beam_search_lm_rescore_LODR",) + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + "modified_beam_search_lm_rescore_LODR", ): LM = LmScorer( lm_type=params.lm_type, @@ -1071,25 +1074,25 @@ def main(): test_iwslt22_dl = multiling.test_dataloaders(test_iwslt22) test_fisher_dl = multiling.test_dataloaders(test_fisher) - test_sets = [ "test-fisher", "iwslt-ta", "test-hkust"] + test_sets = ["test-fisher", "iwslt-ta", "test-hkust"] test_dl = [test_fisher_dl, test_iwslt22_dl, test_hkust_dl] for test_set, test_dl in zip(test_sets, test_dl): - (results_dict_asr, results_dict_st), rtf = measure_real_time_factor( - decode_dataset, - test_dl, # dataloader - params=params, - model=model, - sp=sp, - sp_st=sp_st, - word_table=word_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) + (results_dict_asr, results_dict_st), rtf = measure_real_time_factor( + decode_dataset, + test_dl, # dataloader + params=params, + model=model, + sp=sp, + sp_st=sp_st, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) save_asr_results( params=params, diff --git a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/decoder.py b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/decoder.py index 283d902c60..cadf43690c 100644 --- a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/decoder.py +++ b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/decoder.py @@ -18,7 +18,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - from scaling import Balancer diff --git a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/joiner.py b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/joiner.py index 4d3e59ee9b..7dd18807de 100644 --- a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/joiner.py +++ b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/joiner.py @@ -14,10 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + import torch import torch.nn as nn from scaling import ScaledLinear -from typing import Optional + class Joiner(nn.Module): def __init__( @@ -33,7 +35,7 @@ def __init__( self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) if encoder_lid: self.lid_proj = ScaledLinear(encoder_lid, joiner_dim, initial_scale=0.25) - self.output_linear = nn.Linear(joiner_dim, vocab_size) + self.output_linear = nn.Linear(joiner_dim, vocab_size) def forward( self, diff --git a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/model.py b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/model.py index 33f15cc694..27cb571636 100644 --- a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/model.py +++ b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/model.py @@ -21,13 +21,13 @@ import k2 import torch -from torch import Tensor -from lhotse.dataset import SpecAugment import torch.nn as nn from encoder_interface import EncoderInterface - -from icefall.utils import add_sos, make_pad_mask, time_warp +from lhotse.dataset import SpecAugment from scaling import ScaledLinear +from torch import Tensor + +from icefall.utils import add_sos, make_pad_mask, time_warp class StModel(nn.Module): @@ -99,7 +99,7 @@ def __init__( self.decoder = decoder self.joiner = joiner - + self.st_joiner = st_joiner self.st_decoder = st_decoder @@ -166,6 +166,7 @@ def forward_encoder( encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) return encoder_out, encoder_out_lens + def forward_st_ctc( self, st_encoder_out: torch.Tensor, @@ -223,7 +224,7 @@ def forward_ctc( reduction="sum", ) return ctc_loss - + def forward_cr_ctc( self, encoder_out: torch.Tensor, @@ -311,7 +312,7 @@ def forward_st_cr_ctc( # target_lengths=target_lengths.cpu(), # reduction="sum", # ) - + # if not torch.isfinite(st_ctc_loss): # breakpoint() @@ -363,7 +364,6 @@ def forward_transducer( part """ # Now for the decoder, i.e., the prediction network - blank_id = self.decoder.blank_id st_blank_id = self.st_decoder.blank_id @@ -413,33 +413,32 @@ def forward_transducer( # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, - ) - st_simple_loss, (st_px_grad, st_py_grad) = k2.rnnt_loss_smoothed( - lm=st_lm.float(), - am=st_am.float(), - symbols=st_y_padded, - termination_symbol=st_blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=st_boundary, - reduction="sum", - return_grad=True, - ) - + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + st_simple_loss, (st_px_grad, st_py_grad) = k2.rnnt_loss_smoothed( + lm=st_lm.float(), + am=st_am.float(), + symbols=st_y_padded, + termination_symbol=st_blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=st_boundary, + reduction="sum", + return_grad=True, + ) # am_pruned : [B, T, prune_range, encoder_dim] # lm_pruned : [B, T, prune_range, decoder_dim] - + # ranges : [B, T, prune_range] ranges = k2.get_rnnt_prune_ranges( px_grad=px_grad, @@ -452,26 +451,26 @@ def forward_transducer( lm=self.joiner.decoder_proj(decoder_out), ranges=ranges, ) - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - use_hat_loss=self.use_hat, - ) - # logits : [B, T, prune_range, vocab_size] - + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + use_hat_loss=self.use_hat, + ) + # logits : [B, T, prune_range, vocab_size] + st_ranges = k2.get_rnnt_prune_ranges( - px_grad=st_px_grad, - py_grad=st_py_grad, - boundary=st_boundary, - s_range=prune_range, + px_grad=st_px_grad, + py_grad=st_py_grad, + boundary=st_boundary, + s_range=prune_range, ) st_am_pruned, st_lm_pruned = k2.do_rnnt_pruning( am=self.st_joiner.encoder_proj(encoder_out), @@ -482,16 +481,18 @@ def forward_transducer( st_logits = self.st_joiner(st_am_pruned, st_lm_pruned, project_input=False) # Compute HAT loss for st with torch.cuda.amp.autocast(enabled=False): - pruned_st_loss = k2.rnnt_loss_pruned( - logits=st_logits.float(), - symbols=st_y.pad(mode="constant", padding_value=blank_id).to(torch.int64), - ranges=st_ranges, - termination_symbol=st_blank_id, - boundary=st_boundary, - reduction="sum", - use_hat_loss=self.use_hat, - ) - + pruned_st_loss = k2.rnnt_loss_pruned( + logits=st_logits.float(), + symbols=st_y.pad(mode="constant", padding_value=blank_id).to( + torch.int64 + ), + ranges=st_ranges, + termination_symbol=st_blank_id, + boundary=st_boundary, + reduction="sum", + use_hat_loss=self.use_hat, + ) + return simple_loss, st_simple_loss, pruned_loss, pruned_st_loss def forward( @@ -562,24 +563,24 @@ def forward( assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) if use_st_cr_ctc or use_asr_cr_ctc: - assert self.use_ctc or self.use_st_ctc - if use_spec_aug: - assert spec_augment is not None and spec_augment.time_warp_factor < 1 - # Apply time warping before input duplicating - assert supervision_segments is not None - x = time_warp( - x, - time_warp_factor=time_warp_factor, - supervision_segments=supervision_segments, - ) - # Independently apply frequency masking and time masking to the two copies - x = spec_augment(x.repeat(2, 1, 1)) - else: - x = x.repeat(2, 1, 1) - x_lens = x_lens.repeat(2) - y = k2.ragged.cat([y, y], axis=0) - if self.st_joiner != None and self.use_st_ctc: - st_y = k2.ragged.cat([st_y, st_y], axis=0) + assert self.use_ctc or self.use_st_ctc + if use_spec_aug: + assert spec_augment is not None and spec_augment.time_warp_factor < 1 + # Apply time warping before input duplicating + assert supervision_segments is not None + x = time_warp( + x, + time_warp_factor=time_warp_factor, + supervision_segments=supervision_segments, + ) + # Independently apply frequency masking and time masking to the two copies + x = spec_augment(x.repeat(2, 1, 1)) + else: + x = x.repeat(2, 1, 1) + x_lens = x_lens.repeat(2) + y = k2.ragged.cat([y, y], axis=0) + if self.st_joiner != None and self.use_st_ctc: + st_y = k2.ragged.cat([st_y, st_y], axis=0) # Compute encoder outputs @@ -590,39 +591,44 @@ def forward( st_y_lens = st_row_splits[1:] - st_row_splits[:-1] if self.use_transducer: - + # Compute transducer loss if self.st_joiner != None: - simple_loss, st_simple_loss, pruned_loss, st_pruned_loss = self.forward_transducer( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - y=y.to(x.device), - y_lens=y_lens, - st_y=st_y.to(x.device), - st_y_lens=st_y_lens, - prune_range=prune_range, - am_scale=am_scale, - lm_scale=lm_scale, - ) - if use_asr_cr_ctc: - simple_loss = simple_loss * 0.5 - pruned_loss = pruned_loss * 0.5 - if use_st_cr_ctc: - st_simple_loss = st_simple_loss * 0.5 - st_pruned_loss = st_pruned_loss * 0.5 + ( + simple_loss, + st_simple_loss, + pruned_loss, + st_pruned_loss, + ) = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(x.device), + y_lens=y_lens, + st_y=st_y.to(x.device), + st_y_lens=st_y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) + if use_asr_cr_ctc: + simple_loss = simple_loss * 0.5 + pruned_loss = pruned_loss * 0.5 + if use_st_cr_ctc: + st_simple_loss = st_simple_loss * 0.5 + st_pruned_loss = st_pruned_loss * 0.5 else: simple_loss, pruned_loss = self.forward_transducer( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - y=y.to(x.device), - y_lens=y_lens, - prune_range=prune_range, - am_scale=am_scale, - lm_scale=lm_scale, - ) + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(x.device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) if use_asr_cr_ctc: - simple_loss = simple_loss * 0.5 - pruned_loss = pruned_loss * 0.5 + simple_loss = simple_loss * 0.5 + pruned_loss = pruned_loss * 0.5 st_simple_loss, st_pruned_loss = torch.empty(0), torch.empty(0) else: simple_loss = torch.empty(0) @@ -654,28 +660,37 @@ def forward( if self.use_st_ctc: st_targets = st_y.values if not use_st_cr_ctc: - st_ctc_loss = self.forward_st_ctc( + st_ctc_loss = self.forward_st_ctc( st_encoder_out=encoder_out, st_encoder_out_lens=encoder_out_lens, targets=st_targets, target_lengths=st_y_lens, ) - st_cr_loss = torch.empty(0) + st_cr_loss = torch.empty(0) else: - st_ctc_loss, st_cr_loss = self.forward_st_cr_ctc( - st_encoder_out=encoder_out, - st_encoder_out_lens=encoder_out_lens, - st_targets=st_targets, - st_target_lengths=st_y_lens, - # encoder_out=encoder_out, - # encoder_out_lens=encoder_out_lens, - # targets=targets, - # target_lengths=y_lens, - ) - st_ctc_loss = st_ctc_loss * 0.5 - st_cr_loss = st_cr_loss * 0.5 + st_ctc_loss, st_cr_loss = self.forward_st_cr_ctc( + st_encoder_out=encoder_out, + st_encoder_out_lens=encoder_out_lens, + st_targets=st_targets, + st_target_lengths=st_y_lens, + # encoder_out=encoder_out, + # encoder_out_lens=encoder_out_lens, + # targets=targets, + # target_lengths=y_lens, + ) + st_ctc_loss = st_ctc_loss * 0.5 + st_cr_loss = st_cr_loss * 0.5 else: st_ctc_loss = torch.empty(0) st_cr_loss = torch.empty(0) - return simple_loss, st_simple_loss, pruned_loss, st_pruned_loss, ctc_loss, st_ctc_loss, cr_loss, st_cr_loss + return ( + simple_loss, + st_simple_loss, + pruned_loss, + st_pruned_loss, + ctc_loss, + st_ctc_loss, + cr_loss, + st_cr_loss, + ) diff --git a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/profile.py b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/profile.py index 57f44a90a8..b1f1c0e4d3 100755 --- a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/profile.py +++ b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/profile.py @@ -22,24 +22,24 @@ import argparse import logging -import sentencepiece as spm -import torch - from typing import Tuple -from torch import Tensor, nn -from icefall.utils import make_pad_mask -from icefall.profiler import get_model_profile +import sentencepiece as spm +import torch from scaling import BiasNorm +from torch import Tensor, nn from train import ( + add_model_arguments, get_encoder_embed, get_encoder_model, get_joiner_model, - add_model_arguments, get_params, ) from zipformer import BypassModule +from icefall.profiler import get_model_profile +from icefall.utils import make_pad_mask + def get_parser(): parser = argparse.ArgumentParser( diff --git a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/scaling.py b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/scaling.py index c0f1e30873..29ac33c02b 100644 --- a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/scaling.py +++ b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/scaling.py @@ -15,15 +15,16 @@ # limitations under the License. -from typing import Optional, Tuple, Union import logging -import k2 -from torch.cuda.amp import custom_fwd, custom_bwd +import math import random +from typing import Optional, Tuple, Union + +import k2 import torch -import math import torch.nn as nn from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: diff --git a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/subsampling.py b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/subsampling.py index d16d87bac9..b2f769d3f6 100644 --- a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/subsampling.py +++ b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/subsampling.py @@ -16,11 +16,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import warnings +from typing import Tuple import torch -from torch import Tensor, nn from scaling import ( Balancer, BiasNorm, @@ -34,6 +33,7 @@ SwooshR, Whiten, ) +from torch import Tensor, nn class ConvNeXt(nn.Module): diff --git a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/train.py b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/train.py index b0dc751582..6c10c31893 100755 --- a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/train.py +++ b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/train.py @@ -97,8 +97,7 @@ from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union -from torch.optim import Optimizer -from lhotse.dataset import SpecAugment + import k2 import optim import sentencepiece as spm @@ -109,6 +108,7 @@ from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut +from lhotse.dataset import SpecAugment from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import StModel @@ -118,6 +118,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -363,7 +364,8 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--st-value-head-dim", type=str, default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.",) + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) parser.add_argument( "--st-pos-head-dim", type=str, @@ -436,6 +438,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="If True, use CTC head with ST encoder.", ) + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -808,6 +811,7 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module: ) return encoder_embed + def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = Zipformer2( output_downsampling_factor=2, @@ -826,9 +830,11 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: warmup_batches=4000.0, causal=params.causal, chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames),) + left_context_frames=_to_int_tuple(params.left_context_frames), + ) return encoder + def get_st_encoder_model(params: AttributeDict) -> nn.Module: st_encoder = Zipformer2( output_downsampling_factor=2, @@ -851,6 +857,7 @@ def get_st_encoder_model(params: AttributeDict) -> nn.Module: ) return st_encoder + def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, @@ -863,7 +870,9 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)),#int(params.encoder_dim.split(",")[-1]), + encoder_dim=max( + _to_int_tuple(params.encoder_dim) + ), # int(params.encoder_dim.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -880,15 +889,19 @@ def get_st_decoder_model(params: AttributeDict) -> nn.Module: ) return st_decoder + def get_st_joiner_model(params: AttributeDict) -> nn.Module: st_joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)),#int(params.encoder_dim.split(",")[-1]), + encoder_dim=max( + _to_int_tuple(params.encoder_dim) + ), # int(params.encoder_dim.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.st_joiner_dim, vocab_size=params.st_vocab_size, ) return st_joiner + def get_model(params: AttributeDict) -> nn.Module: assert params.use_transducer or params.use_ctc, ( f"At least one of them should be True, " @@ -912,7 +925,6 @@ def get_model(params: AttributeDict) -> nn.Module: decoder = None joiner = None - model = StModel( encoder_embed=encoder_embed, encoder=encoder, @@ -920,7 +932,9 @@ def get_model(params: AttributeDict) -> nn.Module: joiner=joiner, st_joiner=st_joiner, st_decoder=st_decoder, - encoder_dim=max(_to_int_tuple(params.encoder_dim)),#int(params.encoder_dim.split(",")[-1]), + encoder_dim=max( + _to_int_tuple(params.encoder_dim) + ), # int(params.encoder_dim.split(",")[-1]), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, st_vocab_size=params.st_vocab_size, @@ -949,6 +963,7 @@ def get_spec_augment(params: AttributeDict) -> SpecAugment: ) return spec_augment + def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, @@ -989,7 +1004,7 @@ def load_checkpoint_if_available( return None assert filename.is_file(), f"{filename} does not exist!" - + saved_params = load_checkpoint( filename, model=model, @@ -1094,7 +1109,7 @@ def compute_loss( spec_augment: The SpecAugment instance used only when use_cr_ctc is True. """ - + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) @@ -1108,7 +1123,7 @@ def compute_loss( warm_step = params.warm_step texts = batch["supervisions"]["text"] - st_texts = batch["supervisions"]['tgt_text']['en'] + st_texts = batch["supervisions"]["tgt_text"]["en"] texts = [utt.replace("| ", "") for utt in texts] st_y = st_sp.encode(st_texts, out_type=int) st_y = k2.RaggedTensor(st_y) @@ -1116,13 +1131,13 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) if params.st_scale != 1: - alpha_st = params.st_scale - alpha_asr = 1-params.st_scale + alpha_st = params.st_scale + alpha_asr = 1 - params.st_scale else: alpha_st, alpha_asr = 1, 1 - use_asr_cr_ctc, use_st_cr_ctc = params.use_asr_cr_ctc, params.use_st_cr_ctc + use_asr_cr_ctc, use_st_cr_ctc = params.use_asr_cr_ctc, params.use_st_cr_ctc use_spec_aug = (use_asr_cr_ctc or use_st_cr_ctc) and is_training - + if use_spec_aug: supervision_intervals = batch["supervisions"] supervision_segments = torch.stack( @@ -1136,7 +1151,16 @@ def compute_loss( else: supervision_segments = None with torch.set_grad_enabled(is_training): - simple_loss, st_simple_loss, pruned_loss, st_pruned_loss, ctc_loss, st_ctc_loss, cr_loss, st_cr_loss = model( + ( + simple_loss, + st_simple_loss, + pruned_loss, + st_pruned_loss, + ctc_loss, + st_ctc_loss, + cr_loss, + st_cr_loss, + ) = model( x=feature, x_lens=feature_lens, y=y, @@ -1186,14 +1210,17 @@ def compute_loss( else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if params.use_st_joiner: - loss += simple_loss_scale * st_simple_loss + pruned_loss_scale * st_pruned_loss + loss += ( + simple_loss_scale * st_simple_loss + + pruned_loss_scale * st_pruned_loss + ) if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss @@ -1218,7 +1245,7 @@ def compute_loss( info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_st_joiner: info["st_simple_loss"] = st_simple_loss.detach().cpu().item() - info["st_pruned_loss"] = st_pruned_loss.detach().cpu().item() + info["st_pruned_loss"] = st_pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.use_asr_cr_ctc: @@ -1573,7 +1600,7 @@ def run(rank, world_size, args): ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - + # if checkpoints and "optimizer" in checkpoints: # logging.info("Loading optimizer state dict") # optimizer.load_state_dict(checkpoints["optimizer"]) @@ -1608,7 +1635,7 @@ def remove_short_and_long_utt(c: Cut): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - if c.duration =< 0.1 or c.duration >= 30.0: + if c.duration <= 0.1 or c.duration >= 30.0: # logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" # ) @@ -1635,18 +1662,20 @@ def remove_short_and_long_utt(c: Cut): # return False if params.use_st_cr_ctc: T = ((c.num_frames - 7) // 2 + 1) // 2 - st_tokens = st_sp.encode(c.supervisions[0].custom['translated_text']['en'], out_type=str) + st_tokens = st_sp.encode( + c.supervisions[0].custom["translated_text"]["en"], out_type=str + ) if T <= len(st_tokens): - # logging.warning( - # f"Exclude cut with ID {c.id} from training. " - # f"Number of frames (before subsampling): {c.num_frames}. " - # f"Number of frames (after subsampling): {T}. " - # f"ST Text: {c.supervisions[0].custom['translated_text']['en']}. " - # f"ST Tokens: {st_tokens}. " - # f"Number of tokens: {len(st_tokens)}" - # ) + # logging.warning( + # f"Exclude cut with ID {c.id} from training. " + # f"Number of frames (before subsampling): {c.num_frames}. " + # f"Number of frames (after subsampling): {T}. " + # f"ST Text: {c.supervisions[0].custom['translated_text']['en']}. " + # f"ST Tokens: {st_tokens}. " + # f"Number of tokens: {len(st_tokens)}" + # ) return False - + if params.use_asr_cr_ctc: T = ((c.num_frames - 7) // 2 + 1) // 2 tokens = sp.encode(c.supervisions[0].text, out_type=str) @@ -1666,7 +1695,9 @@ def remove_short_and_long_utt(c: Cut): train_cuts = train_cuts.filter(remove_short_and_long_utt) if params.lang: - train_cuts = train_cuts.filter(lambda c: c.supervisions[0].language[0]==params.lang) + train_cuts = train_cuts.filter( + lambda c: c.supervisions[0].language[0] == params.lang + ) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1681,7 +1712,9 @@ def remove_short_and_long_utt(c: Cut): valid_cuts = multling.dev_all_cuts() if params.lang: - valid_cuts = valid_cuts.filter(lambda c: c.supervisions[0].language[0]==params.lang) + valid_cuts = valid_cuts.filter( + lambda c: c.supervisions[0].language[0] == params.lang + ) valid_dl = multling.valid_dataloaders(valid_cuts) if not params.print_diagnostics: @@ -1715,7 +1748,7 @@ def remove_short_and_long_utt(c: Cut): optimizer=optimizer, scheduler=scheduler, sp=sp, - st_sp = st_sp, + st_sp=st_sp, train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, @@ -1836,5 +1869,6 @@ def main(): else: run(rank=0, world_size=1, args=args) + if __name__ == "__main__": main() diff --git a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/zipformer.py b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/zipformer.py index e7b088ac17..5ebc7e7ad6 100644 --- a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/zipformer.py +++ b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/zipformer.py @@ -18,28 +18,33 @@ # limitations under the License. import copy +import logging import math +import random import warnings from typing import List, Optional, Tuple, Union -import logging + import torch -import random from encoder_interface import EncoderInterface from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. +) +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) +from scaling import ( + ActivationDropoutAndLinear, Balancer, BiasNorm, - Dropout2, ChunkCausalDepthwiseConv1d, - ActivationDropoutAndLinear, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + Dropout2, + FloatLike, + ScheduledFloat, Whiten, - Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. + convert_num_channels, + limit_param_value, penalize_abs_values_gt, softmax, - ScheduledFloat, - FloatLike, - limit_param_value, - convert_num_channels, ) from torch import Tensor, nn diff --git a/egs/multi_zh-hans/ASR/whisper/decode.py b/egs/multi_zh-hans/ASR/whisper/decode.py index 5b9665c5a5..b9014bc509 100755 --- a/egs/multi_zh-hans/ASR/whisper/decode.py +++ b/egs/multi_zh-hans/ASR/whisper/decode.py @@ -109,9 +109,13 @@ def average_checkpoints( for i in range(1, n): if "model" in torch.load(filenames[i], map_location=device, weights_only=False): - state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"] + state_dict = torch.load( + filenames[i], map_location=device, weights_only=False + )["model"] else: - state_dict = torch.load(filenames[i], map_location=device, weights_only=False) + state_dict = torch.load( + filenames[i], map_location=device, weights_only=False + ) for k in uniqued_names: avg[k] += state_dict[k] @@ -484,7 +488,9 @@ def main(): start = params.epoch - params.avg assert start >= 1, start checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False + f"{params.exp_dir}/epoch-{params.epoch}.pt", + map_location="cpu", + weights_only=False, ) if "model" not in checkpoint: # deepspeed converted checkpoint only contains model state_dict @@ -513,7 +519,9 @@ def main(): torch.save(model.state_dict(), filename) else: checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False + f"{params.exp_dir}/epoch-{params.epoch}.pt", + map_location="cpu", + weights_only=False, ) if "model" not in checkpoint: model.load_state_dict(checkpoint, strict=True) diff --git a/egs/multi_zh-hans/ASR/whisper/train.py b/egs/multi_zh-hans/ASR/whisper/train.py index 3ffaef2122..1ad89979a0 100755 --- a/egs/multi_zh-hans/ASR/whisper/train.py +++ b/egs/multi_zh-hans/ASR/whisper/train.py @@ -809,7 +809,9 @@ def run(rank, world_size, args): del model.alignment_heads if params.pretrained_model_path: - checkpoint = torch.load(params.pretrained_model_path, map_location="cpu", weights_only=False) + checkpoint = torch.load( + params.pretrained_model_path, map_location="cpu", weights_only=False + ) if "model" not in checkpoint: model.load_state_dict(checkpoint, strict=True) else: diff --git a/egs/multi_zh-hans/ASR/zipformer/pretrained.py b/egs/multi_zh-hans/ASR/zipformer/pretrained.py index 53be57fae2..4da9bb676f 100755 --- a/egs/multi_zh-hans/ASR/zipformer/pretrained.py +++ b/egs/multi_zh-hans/ASR/zipformer/pretrained.py @@ -335,7 +335,7 @@ def token_ids_to_words(token_ids: List[int]) -> str: byte_list.append(int(token[3:-1], 16)) else: byte_list += list(token.encode("utf-8")) - text = bytes(byte_list).decode("utf-8", errors='ignore') + text = bytes(byte_list).decode("utf-8", errors="ignore") return text.replace("▁", " ").strip() if params.method == "fast_beam_search": diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py index 7c3901c201..f43101a39a 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py @@ -114,9 +114,13 @@ def average_checkpoints( for i in range(1, n): if "model" in torch.load(filenames[i], map_location=device, weights_only=False): - state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"] + state_dict = torch.load( + filenames[i], map_location=device, weights_only=False + )["model"] else: - state_dict = torch.load(filenames[i], map_location=device, weights_only=False) + state_dict = torch.load( + filenames[i], map_location=device, weights_only=False + ) for k in uniqued_names: avg[k] += state_dict[k] @@ -548,7 +552,8 @@ def main(): # torch.save(avg_checkpoint, filename) else: checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin", weights_only=False, + f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin", + weights_only=False, map_location="cpu", ) model.load_state_dict(checkpoint, strict=False) diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index 7162af9589..c374c520f7 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -652,7 +652,9 @@ def run(rank, world_size, args): ) if params.pretrained_model_path: - checkpoint = torch.load(params.pretrained_model_path, map_location="cpu", weights_only=False) + checkpoint = torch.load( + params.pretrained_model_path, map_location="cpu", weights_only=False + ) missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) num_param = sum([p.numel() for p in model.parameters()]) @@ -704,7 +706,9 @@ def remove_short_and_long_utt(c: Cut): sampler_state_dict = None if params.sampler_state_dict_path: - sampler_state_dict = torch.load(params.sampler_state_dict_path, weights_only=False) + sampler_state_dict = torch.load( + params.sampler_state_dict_path, weights_only=False + ) sampler_state_dict["max_duration"] = params.max_duration train_dl = data_module.train_dataloaders( diff --git a/egs/speechio/ASR/whisper/decode.py b/egs/speechio/ASR/whisper/decode.py index 9ee3ecd048..a98c504654 100644 --- a/egs/speechio/ASR/whisper/decode.py +++ b/egs/speechio/ASR/whisper/decode.py @@ -110,9 +110,13 @@ def average_checkpoints( for i in range(1, n): if "model" in torch.load(filenames[i], map_location=device, weights_only=False): - state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"] + state_dict = torch.load( + filenames[i], map_location=device, weights_only=False + )["model"] else: - state_dict = torch.load(filenames[i], map_location=device, weights_only=False) + state_dict = torch.load( + filenames[i], map_location=device, weights_only=False + ) for k in uniqued_names: avg[k] += state_dict[k] @@ -447,7 +451,9 @@ def main(): start = params.epoch - params.avg assert start >= 1, start checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False + f"{params.exp_dir}/epoch-{params.epoch}.pt", + map_location="cpu", + weights_only=False, ) if "model" not in checkpoint: # deepspeed converted checkpoint only contains model state_dict @@ -476,7 +482,9 @@ def main(): torch.save(model.state_dict(), filename) else: checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False + f"{params.exp_dir}/epoch-{params.epoch}.pt", + map_location="cpu", + weights_only=False, ) if "model" not in checkpoint: model.load_state_dict(checkpoint, strict=True) diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py index 9e28043ab9..52c60ab67d 100755 --- a/egs/swbd/ASR/conformer_ctc/decode.py +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -698,7 +698,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -738,7 +740,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.method in [ diff --git a/egs/tedlium3/ASR/conformer_ctc2/decode.py b/egs/tedlium3/ASR/conformer_ctc2/decode.py index 220c7a6c1d..3dc6e544b9 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/decode.py +++ b/egs/tedlium3/ASR/conformer_ctc2/decode.py @@ -675,7 +675,9 @@ def main() -> None: H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py index 541ff09a01..50e0e824ca 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py @@ -398,7 +398,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False)) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -424,7 +426,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False + ) G = k2.Fsa.from_dict(d).to(device) if params.method == "whole-lattice-rescoring": diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py index 78b17558c5..7d1abffb2b 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py @@ -173,7 +173,9 @@ def main(): model.eval() logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False)) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -181,7 +183,9 @@ def main(): if params.method == "whole-lattice-rescoring": logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu", weights_only=False)) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = G.to(device) diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py index f3eebcc61f..90b6c6b712 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py @@ -397,7 +397,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False)) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -423,7 +425,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False + ) G = k2.Fsa.from_dict(d).to(device) if params.method == "whole-lattice-rescoring": diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py index a1e93b3296..159f6bc97a 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py @@ -173,7 +173,9 @@ def main(): model.eval() logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False)) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -181,7 +183,9 @@ def main(): if params.method == "whole-lattice-rescoring": logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu", weights_only=False)) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = G.to(device) diff --git a/egs/wenetspeech/ASR/whisper/decode.py b/egs/wenetspeech/ASR/whisper/decode.py index 2363a6992b..7126369966 100755 --- a/egs/wenetspeech/ASR/whisper/decode.py +++ b/egs/wenetspeech/ASR/whisper/decode.py @@ -107,9 +107,13 @@ def average_checkpoints( for i in range(1, n): if "model" in torch.load(filenames[i], map_location=device, weights_only=False): - state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"] + state_dict = torch.load( + filenames[i], map_location=device, weights_only=False + )["model"] else: - state_dict = torch.load(filenames[i], map_location=device, weights_only=False) + state_dict = torch.load( + filenames[i], map_location=device, weights_only=False + ) for k in uniqued_names: avg[k] += state_dict[k] @@ -435,7 +439,9 @@ def main(): start = params.epoch - params.avg assert start >= 1, start checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False + f"{params.exp_dir}/epoch-{params.epoch}.pt", + map_location="cpu", + weights_only=False, ) if "model" not in checkpoint: # deepspeed converted checkpoint only contains model state_dict @@ -464,7 +470,9 @@ def main(): torch.save(model.state_dict(), filename) else: checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False + f"{params.exp_dir}/epoch-{params.epoch}.pt", + map_location="cpu", + weights_only=False, ) if "model" not in checkpoint: model.load_state_dict(checkpoint, strict=True) diff --git a/egs/wenetspeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc.py b/egs/wenetspeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc.py deleted file mode 120000 index 8c203406b8..0000000000 --- a/egs/wenetspeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py index 0cc0bf240f..e271305092 100755 --- a/egs/wenetspeech4tts/TTS/f5-tts/train.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py @@ -987,7 +987,9 @@ def run(rank, world_size, args): model = get_model(params) if params.pretrained_model_path: - checkpoint = torch.load(params.pretrained_model_path, map_location="cpu", weights_only=False) + checkpoint = torch.load( + params.pretrained_model_path, map_location="cpu", weights_only=False + ) if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint: model = load_F5_TTS_pretrained_checkpoint( model, params.pretrained_model_path diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 4ab685684b..8d030cc33f 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -21,9 +21,13 @@ import os import re from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import torch + +if TYPE_CHECKING: + from torch.amp import GradScaler + import torch.nn as nn from lhotse.dataset.sampling.base import CutSampler from torch import Tensor diff --git a/icefall/lexicon.py b/icefall/lexicon.py index 6a157ffeaf..45d333af4c 100644 --- a/icefall/lexicon.py +++ b/icefall/lexicon.py @@ -166,7 +166,9 @@ def __init__( if (lang_dir / "Linv.pt").exists(): logging.info(f"Loading pre-compiled {lang_dir}/Linv.pt") - L_inv = k2.Fsa.from_dict(torch.load(lang_dir / "Linv.pt", weights_only=False)) + L_inv = k2.Fsa.from_dict( + torch.load(lang_dir / "Linv.pt", weights_only=False) + ) else: logging.info("Converting L.pt to Linv.pt") L = k2.Fsa.from_dict(torch.load(lang_dir / "L.pt", weights_only=False)) diff --git a/icefall/utils.py b/icefall/utils.py index 0d4e24db53..8af39c4a8e 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -630,8 +630,8 @@ def store_transcripts_and_timestamps( def store_translations( - filename: Pathlike, texts: Iterable[Tuple[str, str, str]], - lowercase: bool = True) -> None: + filename: Pathlike, texts: Iterable[Tuple[str, str, str]], lowercase: bool = True +) -> None: """Save predicted results and reference transcripts to a file. Args: @@ -644,15 +644,19 @@ def store_translations( Returns: Return None. """ + from sacrebleu.metrics import BLEU + bleu = BLEU(lowercase=lowercase) hyp_list = [] ref_list = [] dir_ = os.path.dirname(filename) - reftgt = os.path.join(dir_, "reftgt-" + str(os.path.basename(filename))) - refsrc = os.path.join(dir_, "refsrc-"+str(os.path.basename(filename))) - hyp = os.path.join(dir_, "hyp-"+str( os.path.basename(filename))) - bleu_file = os.path.join(dir_, "bleu-"+str( os.path.basename(filename))) - with open(filename, "w") as f, open(reftgt, "w") as f_tgt, open(hyp, "w") as f_hyp, open(refsrc, "w") as f_src: + reftgt = os.path.join(dir_, "reftgt-" + str(os.path.basename(filename))) + refsrc = os.path.join(dir_, "refsrc-" + str(os.path.basename(filename))) + hyp = os.path.join(dir_, "hyp-" + str(os.path.basename(filename))) + bleu_file = os.path.join(dir_, "bleu-" + str(os.path.basename(filename))) + with open(filename, "w") as f, open(reftgt, "w") as f_tgt, open( + hyp, "w" + ) as f_hyp, open(refsrc, "w") as f_src: for cut_id, ref, ref_tgt, hyp in texts: ref = " ".join(ref) ref_tgt = " ".join(ref_tgt) @@ -661,7 +665,6 @@ def store_translations( print(f"{cut_id}: ref_tgt {ref_tgt}", file=f) print(f"{cut_id}: hyp {hyp}", file=f) print("\n", file=f) - print(f"{ref}", file=f_src) print(f"{ref_tgt}", file=f_tgt) @@ -670,14 +673,14 @@ def store_translations( hyp_list.append(hyp) ref_list.append(ref_tgt) - with open(bleu_file, 'w') as b: + with open(bleu_file, "w") as b: print(str(bleu.corpus_score(hyp_list, [ref_list])), file=b) print(f"BLEU signiture: {str(bleu.get_signature())}", file=b) - + logging.info( - f"[{bleu.corpus_score(hyp_list, [ref_list])}] " - f"BLEU signiture: {str(bleu.get_signature())}" - ) + f"[{bleu.corpus_score(hyp_list, [ref_list])}] " + f"BLEU signiture: {str(bleu.get_signature())}" + ) def write_error_stats(