Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 69 additions & 32 deletions egs/librispeech/ASR/zipformer/ctc_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,15 @@
import math
from collections import defaultdict
from pathlib import Path, PurePath
from typing import Dict, List, Tuple
from typing import Dict, List, Set, Tuple

import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule as AsrDataModule
from lhotse import set_caching_enabled
from lhotse.cut import Cut
from torchaudio.functional import (
forced_align,
merge_tokens,
Expand Down Expand Up @@ -166,11 +167,16 @@ def get_parser():
)

parser.add_argument(
"dataset_manifests",
"--max-utt-duration",
type=float,
default=60.0,
help="Maximal duration of an utterance in seconds, used in cut-set filtering.",
)

parser.add_argument(
"dataset_manifest",
type=str,
nargs="+",
help="CutSet manifests to be aligned (CutSet with features and transcripts). "
"Each CutSet as a separate arg : `manifest1 mainfest2 ...`",
help="CutSet manifests to be aligned (CutSet with features and transcripts).",
)

add_model_arguments(parser)
Expand Down Expand Up @@ -393,16 +399,17 @@ def align_dataset(

def save_alignment_output(
params: AttributeDict,
test_set_name: str,
dataset_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
removed_cut_ids: Set[str],
):
Comment thread
coderabbitai[bot] marked this conversation as resolved.
"""
Save the token alignments and per-utterance confidences.
"""

for key, results in results_dict.items():

alignments_filename = params.res_dir / f"alignments-{test_set_name}.txt"
alignments_filename = params.res_dir / f"alignments-{dataset_name}.txt"

time_step = 0.04

Expand All @@ -425,7 +432,7 @@ def save_alignment_output(

# ---------------------------

confidences_filename = params.res_dir / f"confidences-{test_set_name}.txt"
confidences_filename = params.res_dir / f"confidences-{dataset_name}.txt"

with open(confidences_filename, "w", encoding="utf8") as fd:
print(
Expand Down Expand Up @@ -458,6 +465,15 @@ def save_alignment_output(
file=fd,
)

# previously removed by `cuts.filter(remove_long_transcripts)`
for utterance_key in removed_cut_ids:
print(f"{utterance_key} -2.0 -2.0 "
"-2.0 "
"(0,0,0,0,0) "
"(0,0)",
file=fd,
)

logging.info(f"The confidences are stored in `{confidences_filename}`")


Expand Down Expand Up @@ -605,37 +621,58 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")

# we need cut ids to display recognition results.
# we need cut_ids to display recognition results.
args.return_cuts = True
asr_datamodule = AsrDataModule(args)

# create array of dataloaders (one per test-set)
testset_labels = []
testset_dataloaders = []
for testset_manifest in args.dataset_manifests:
label = PurePath(testset_manifest).name # basename
label = label.replace(".jsonl.gz", "")
dataset_label = PurePath(args.dataset_manifest).name # basename
dataset_label = dataset_label.replace(".jsonl.gz", "")

test_cuts = asr_datamodule.load_manifest(testset_manifest)
test_dataloader = asr_datamodule.test_dataloaders(test_cuts)
dataset_cuts = asr_datamodule.load_manifest(args.dataset_manifest)

testset_labels.append(label)
testset_dataloaders.append(test_dataloader)
def remove_long_transcripts(c: Cut):

# align
for test_set, test_dl in zip(testset_labels, testset_dataloaders):
results_dict = align_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
)
if c.duration > params.max_utt_duration:
logging.warning(
f"Exclude cut with ID {c.id} from alignment. Duration: {c.duration}"
)
return False

T = ((c.num_frames - 7) // 2 + 1) // 2
tokens = np.array(sp.encode(c.supervisions[0].text, out_type=str))
num_repeats = np.sum(tokens[1:] == tokens[:-1])

# For CTC, `T >= num_tokens + num_repeats` is needed; otherwise inf. appears in the loss.
if T < (len(tokens) + num_repeats):
logging.warning(
f"Exclude cut with ID {c.id} from alignment (too many supervision tokens). "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Number of tokens: {len(tokens)}"
)
return False
Comment thread
coderabbitai[bot] marked this conversation as resolved.

save_alignment_output(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
return True

cut_ids_orig = set(list(dataset_cuts.ids))
dataset_cuts = dataset_cuts.filter(remove_long_transcripts)
cut_ids_removed = cut_ids_orig - set(list(dataset_cuts.ids))

dataset_dl = asr_datamodule.test_dataloaders(dataset_cuts)

results_dict = align_dataset(
dl=dataset_dl,
params=params,
model=model,
sp=sp,
)

save_alignment_output(
params=params,
dataset_name=dataset_label,
results_dict=results_dict,
removed_cut_ids=cut_ids_removed,
)

logging.info("Done!")

Expand Down
22 changes: 16 additions & 6 deletions egs/librispeech/ASR/zipformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@

import k2
import optim
import numpy as np
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
Expand Down Expand Up @@ -384,7 +385,10 @@ def get_parser():
)

parser.add_argument(
"--base-lr", type=float, default=0.045, help="The base learning rate."
"--base-lr",
type=float,
default=0.045,
help="The base learning rate.",
)

parser.add_argument(
Expand Down Expand Up @@ -1407,18 +1411,24 @@ def remove_short_and_long_utt(c: Cut):
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
tokens = np.array(sp.encode(c.supervisions[0].text, out_type=str))

if args.use_ctc:
# For CTC, `T >= num_tokens + num_repeats` is needed, because blanks are added between repeated tokens.
num_repeats = np.sum(tokens[1:] == tokens[:-1])
min_T = len(tokens) + num_repeats
else:
# For Transducer, `T >= num_tokens` is sufficient.
min_T = len(tokens)

# For CTC `(T - 2) < len(tokens)` is needed. otherwise inf. in loss appears.
# For Transducer `T < len(tokens)` was okay.
if (T - 2) < len(tokens):
if T < min_T:
logging.warning(
f"Exclude cut with ID {c.id} from training (too many supervision tokens). "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
f"Number of tokens: {len(tokens)}, min_T: {min_T}"
)
return False

Expand Down
Loading