From 6d2280d88544f906cdcc40c7cc65e64774b187d3 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Fri, 19 Dec 2025 15:17:04 +0100 Subject: [PATCH] minimum length of encoder output for CTC training MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit There might be consecutive repetition of symbol in the reference, and for this the CTC alignment must put a blank in between, so the reverse mapping of aligned symbols produces the original reference. I realised it recently while playing with CTC aligner from torachaudio with the noisy yodas2 dataset. To illustrate: "a a b c c d e f" - len(tokens) is 8 - but, because of duplications 'a a', 'c c' - the minimum length of encoder output is 10 - the shortest valid CTC alignment is: "a ∅ a b c ∅ c d e f" --- egs/librispeech/ASR/zipformer/ctc_align.py | 101 ++++++++++++++------- egs/librispeech/ASR/zipformer/train.py | 22 +++-- 2 files changed, 85 insertions(+), 38 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/ctc_align.py b/egs/librispeech/ASR/zipformer/ctc_align.py index fff05146fe..bcdc2e546c 100755 --- a/egs/librispeech/ASR/zipformer/ctc_align.py +++ b/egs/librispeech/ASR/zipformer/ctc_align.py @@ -44,7 +44,7 @@ import math from collections import defaultdict from pathlib import Path, PurePath -from typing import Dict, List, Tuple +from typing import Dict, List, Set, Tuple import numpy as np import sentencepiece as spm @@ -52,6 +52,7 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule as AsrDataModule from lhotse import set_caching_enabled +from lhotse.cut import Cut from torchaudio.functional import ( forced_align, merge_tokens, @@ -166,11 +167,16 @@ def get_parser(): ) parser.add_argument( - "dataset_manifests", + "--max-utt-duration", + type=float, + default=60.0, + help="Maximal duration of an utterance in seconds, used in cut-set filtering.", + ) + + parser.add_argument( + "dataset_manifest", type=str, - nargs="+", - help="CutSet manifests to be aligned (CutSet with features and transcripts). " - "Each CutSet as a separate arg : `manifest1 mainfest2 ...`", + help="CutSet manifests to be aligned (CutSet with features and transcripts).", ) add_model_arguments(parser) @@ -393,8 +399,9 @@ def align_dataset( def save_alignment_output( params: AttributeDict, - test_set_name: str, + dataset_name: str, results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + removed_cut_ids: Set[str], ): """ Save the token alignments and per-utterance confidences. @@ -402,7 +409,7 @@ def save_alignment_output( for key, results in results_dict.items(): - alignments_filename = params.res_dir / f"alignments-{test_set_name}.txt" + alignments_filename = params.res_dir / f"alignments-{dataset_name}.txt" time_step = 0.04 @@ -425,7 +432,7 @@ def save_alignment_output( # --------------------------- - confidences_filename = params.res_dir / f"confidences-{test_set_name}.txt" + confidences_filename = params.res_dir / f"confidences-{dataset_name}.txt" with open(confidences_filename, "w", encoding="utf8") as fd: print( @@ -458,6 +465,15 @@ def save_alignment_output( file=fd, ) + # previously removed by `cuts.filter(remove_long_transcripts)` + for utterance_key in removed_cut_ids: + print(f"{utterance_key} -2.0 -2.0 " + "-2.0 " + "(0,0,0,0,0) " + "(0,0)", + file=fd, + ) + logging.info(f"The confidences are stored in `{confidences_filename}`") @@ -605,37 +621,58 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - # we need cut ids to display recognition results. + # we need cut_ids to display recognition results. args.return_cuts = True asr_datamodule = AsrDataModule(args) - # create array of dataloaders (one per test-set) - testset_labels = [] - testset_dataloaders = [] - for testset_manifest in args.dataset_manifests: - label = PurePath(testset_manifest).name # basename - label = label.replace(".jsonl.gz", "") + dataset_label = PurePath(args.dataset_manifest).name # basename + dataset_label = dataset_label.replace(".jsonl.gz", "") - test_cuts = asr_datamodule.load_manifest(testset_manifest) - test_dataloader = asr_datamodule.test_dataloaders(test_cuts) + dataset_cuts = asr_datamodule.load_manifest(args.dataset_manifest) - testset_labels.append(label) - testset_dataloaders.append(test_dataloader) + def remove_long_transcripts(c: Cut): - # align - for test_set, test_dl in zip(testset_labels, testset_dataloaders): - results_dict = align_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - ) + if c.duration > params.max_utt_duration: + logging.warning( + f"Exclude cut with ID {c.id} from alignment. Duration: {c.duration}" + ) + return False + + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = np.array(sp.encode(c.supervisions[0].text, out_type=str)) + num_repeats = np.sum(tokens[1:] == tokens[:-1]) + + # For CTC `num_tokens + num_repeats` is needed. otherwise inf. in loss appears. + if T < (len(tokens) + num_repeats): + logging.warning( + f"Exclude cut with ID {c.id} from alignment (too many supervision tokens). " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Number of tokens: {len(tokens)}" + ) + return False - save_alignment_output( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) + return True + + cut_ids_orig = set(list(dataset_cuts.ids)) + dataset_cuts = dataset_cuts.filter(remove_long_transcripts) + cut_ids_removed = cut_ids_orig - set(list(dataset_cuts.ids)) + + dataset_dl = asr_datamodule.test_dataloaders(dataset_cuts) + + results_dict = align_dataset( + dl=dataset_dl, + params=params, + model=model, + sp=sp, + ) + + save_alignment_output( + params=params, + dataset_name=dataset_label, + results_dict=results_dict, + removed_cut_ids=cut_ids_removed, + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 98e3250020..8ee3a699fa 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -63,6 +63,7 @@ import k2 import optim +import numpy as np import sentencepiece as spm import torch import torch.multiprocessing as mp @@ -384,7 +385,10 @@ def get_parser(): ) parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." + "--base-lr", + type=float, + default=0.045, + help="The base learning rate.", ) parser.add_argument( @@ -1407,18 +1411,24 @@ def remove_short_and_long_utt(c: Cut): # In ./zipformer.py, the conv module uses the following expression # for subsampling T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) + tokens = np.array(sp.encode(c.supervisions[0].text, out_type=str)) + + if args.use_ctc: + # For CTC `T < num_tokens + num_repeats` is needed, blanks are added. + num_repeats = np.sum(tokens[1:] == tokens[:-1]) + min_T = len(tokens) + num_repeats + else: + # For Transducer `T < num_tokens` is okay. + min_T = len(tokens) - # For CTC `(T - 2) < len(tokens)` is needed. otherwise inf. in loss appears. - # For Transducer `T < len(tokens)` was okay. - if (T - 2) < len(tokens): + if T < min_T: logging.warning( f"Exclude cut with ID {c.id} from training (too many supervision tokens). " f"Number of frames (before subsampling): {c.num_frames}. " f"Number of frames (after subsampling): {T}. " f"Text: {c.supervisions[0].text}. " f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" + f"Number of tokens: {len(tokens)}, min_T: {min_T}" ) return False