forked from k2-fsa/icefall
[Ready to merge] Pruned transducer stateless5 recipe for tal_csasr (mix Chinese chars and English BPE) (k2-fsa#428)
* add pruned transducer stateless5 recipe for tal_csasr
* do some changes for merging
* change for conformer.py
* add wer and cer for Chinese and English respectively
* fix a error for conformer.py
1 parent 6e609c6, commit 2cb1618
Showing 34 changed files with 4,975 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
|
||
# Introduction | ||
|
||
This recipe contains ASR models trained on the TAL_CSASR dataset. | ||
|
||
[./RESULTS.md](./RESULTS.md) contains the latest results. | ||
|
||
# Transducers | ||
|
||
There are various folders containing the name `transducer` in this folder. | ||
The following table lists the differences among them. | ||
|
||
| | Encoder | Decoder | Comment | | ||
|---------------------------------------|---------------------|--------------------|-----------------------------| | ||
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner| | ||
|
||
The decoder in `transducer_stateless` is modified from the paper | ||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). | ||
We place an additional Conv1d layer right after the input embedding layer. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
## Results | ||
|
||
### TAL_CSASR Mix Chars and BPEs training results (Pruned Transducer Stateless5) | ||
|
||
#### 2022-06-22 | ||
|
||
Using the code from this PR https://github.com/k2-fsa/icefall/pull/428. | ||
|
||
The WERs are | ||
|
||
|decoding-method | epoch(iter) | avg | dev | test | | ||
|--|--|--|--|--| | ||
|greedy_search | 30 | 24 | 7.49 | 7.58| | ||
|modified_beam_search | 30 | 24 | 7.33 | 7.38| | ||
|fast_beam_search | 30 | 24 | 7.32 | 7.42| | ||
|greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 7.39| | ||
|modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 7.22| | ||
|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 7.27| | ||
|greedy_search | 348000 | 30 | 7.46 | 7.54| | ||
|modified_beam_search | 348000 | 30 | 7.24 | 7.36| | ||
|fast_beam_search | 348000 | 30 | 7.25 | 7.39 | | ||
|
||
The error rates (%) reported separately for Chinese (CER) and English (WER), where zh = Chinese and en = English: | ||
|decoding-method | epoch(iter) | avg | dev | dev_zh | dev_en | test | test_zh | test_en | | ||
|--|--|--|--|--|--|--|--|--| | ||
|greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13| | ||
|modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 | | ||
|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 6.39| 18.90 | 7.27| 6.55 | 18.77| | ||
|
||
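The split into `*_zh` and `*_en` columns can be illustrated with a minimal sketch. It is not the recipe's scoring code: it assumes that a Chinese token is a single character in the CJK Unified Ideographs range, that an English token is a run of ASCII letters or apostrophes (compared case-insensitively), and that each language is scored with a plain Levenshtein distance. The function names are hypothetical.

```python
import re

# Assumption: one CJK ideograph = one Chinese token; a run of ASCII
# letters/apostrophes = one English token.
ZH = re.compile(r"[\u4e00-\u9fff]")
EN = re.compile(r"[A-Za-z']+")


def split_by_language(text):
    """Return (chinese_chars, english_words) found in a mixed transcript."""
    return ZH.findall(text), [w.upper() for w in EN.findall(text)]


def edit_distance(ref, hyp):
    """Levenshtein distance between two token lists."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(
                min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (r != h))
            )
        prev = cur
    return prev[-1]


def error_rates(pairs):
    """pairs: iterable of (reference, hypothesis) strings.

    Returns (Chinese CER in %, English WER in %), each accumulated
    over all pairs before dividing by the reference length.
    """
    zh_err = zh_len = en_err = en_len = 0
    for ref, hyp in pairs:
        ref_zh, ref_en = split_by_language(ref)
        hyp_zh, hyp_en = split_by_language(hyp)
        zh_err += edit_distance(ref_zh, hyp_zh)
        zh_len += len(ref_zh)
        en_err += edit_distance(ref_en, hyp_en)
        en_len += len(ref_en)
    return 100.0 * zh_err / max(zh_len, 1), 100.0 * en_err / max(en_len, 1)
```

Scoring the two languages separately makes it visible that the English words are much harder for the model than the Chinese characters, which the single mixed error rate hides.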
The training command for reproducing is given below: | ||
|
||
``` | ||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5" | ||
./pruned_transducer_stateless5/train.py \ | ||
--world-size 6 \ | ||
--num-epochs 30 \ | ||
--start-epoch 1 \ | ||
--exp-dir pruned_transducer_stateless5/exp \ | ||
--lang-dir data/lang_char \ | ||
--max-duration 90 | ||
``` | ||
|
||
The tensorboard training log can be found at | ||
https://tensorboard.dev/experiment/KaACzXOVR0OM6cy0qbN5hw/#scalars | ||
|
||
The decoding command is: | ||
``` | ||
epoch=30 | ||
avg=24 | ||
use_average_model=True | ||
## greedy search | ||
./pruned_transducer_stateless5/decode.py \ | ||
--epoch $epoch \ | ||
--avg $avg \ | ||
--exp-dir pruned_transducer_stateless5/exp \ | ||
--lang-dir ./data/lang_char \ | ||
--max-duration 800 \ | ||
--use-averaged-model $use_average_model | ||
## modified beam search | ||
./pruned_transducer_stateless5/decode.py \ | ||
--epoch $epoch \ | ||
--avg $avg \ | ||
--exp-dir pruned_transducer_stateless5/exp \ | ||
--lang-dir ./data/lang_char \ | ||
--max-duration 800 \ | ||
--decoding-method modified_beam_search \ | ||
--beam-size 4 \ | ||
--use-averaged-model $use_average_model | ||
## fast beam search | ||
./pruned_transducer_stateless5/decode.py \ | ||
--epoch $epoch \ | ||
--avg $avg \ | ||
--exp-dir ./pruned_transducer_stateless5/exp \ | ||
--lang-dir ./data/lang_char \ | ||
--max-duration 1500 \ | ||
--decoding-method fast_beam_search \ | ||
--beam 4 \ | ||
--max-contexts 4 \ | ||
--max-states 8 \ | ||
--use-averaged-model $use_average_model | ||
``` | ||
|
||
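The three decoding commands differ only in a few flags, so they can be generated from one shared argument list. The sketch below is a hypothetical helper, not part of the recipe; every flag and value is copied from the commands above, and it only prints the command lines rather than running them.

```python
import shlex

# Arguments shared by all three decoding runs (values from the commands above).
COMMON = [
    "./pruned_transducer_stateless5/decode.py",
    "--epoch", "30",
    "--avg", "24",
    "--exp-dir", "pruned_transducer_stateless5/exp",
    "--lang-dir", "./data/lang_char",
    "--use-averaged-model", "True",
]

# Method-specific arguments. Greedy search passes no --decoding-method
# flag, exactly as in the command above.
METHODS = {
    "greedy_search": ["--max-duration", "800"],
    "modified_beam_search": [
        "--max-duration", "800",
        "--decoding-method", "modified_beam_search",
        "--beam-size", "4",
    ],
    "fast_beam_search": [
        "--max-duration", "1500",
        "--decoding-method", "fast_beam_search",
        "--beam", "4",
        "--max-contexts", "4",
        "--max-states", "8",
    ],
}


def build_command(method):
    """Return the argument list for one decoding method."""
    return COMMON + METHODS[method]


for method in METHODS:
    # Print only; pass the list to subprocess.run(..., check=True) to execute.
    print(shlex.join(build_command(method)))
```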
A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_tal-csasr_pruned_transducer_stateless5> |
Empty file.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../librispeech/ASR/local/compute_fbank_musan.py |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) | ||
# | ||
# See ../../../../LICENSE for clarification regarding multiple authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
""" | ||
This file computes fbank features of the tal_csasr dataset. | ||
It looks for manifests in the directory data/manifests. | ||
The generated fbank features are saved in data/fbank. | ||
""" | ||
|
||
import argparse | ||
import logging | ||
import os | ||
from pathlib import Path | ||
|
||
import torch | ||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter | ||
from lhotse.recipes.utils import read_manifests_if_cached | ||
|
||
from icefall.utils import get_executor | ||
|
||
# Torch's multithreaded behavior needs to be disabled or | ||
# it wastes a lot of CPU and slows things down. | ||
# Do this outside of main() in case it needs to take effect | ||
# even when we are not invoking the main (e.g. when spawning subprocesses). | ||
torch.set_num_threads(1) | ||
torch.set_num_interop_threads(1) | ||
|
||
|
||
def compute_fbank_tal_csasr(num_mel_bins: int = 80): | ||
    src_dir = Path("data/manifests/tal_csasr") | ||
    output_dir = Path("data/fbank") | ||
    # os.cpu_count() may return None, so fall back to a single job. | ||
    num_jobs = min(15, os.cpu_count() or 1) | ||
|
||
    dataset_parts = ( | ||
        "train_set", | ||
        "dev_set", | ||
        "test_set", | ||
    ) | ||
    prefix = "tal_csasr" | ||
    suffix = "jsonl.gz" | ||
    manifests = read_manifests_if_cached( | ||
        dataset_parts=dataset_parts, | ||
        output_dir=src_dir, | ||
        prefix=prefix, | ||
        suffix=suffix, | ||
    ) | ||
    assert manifests is not None | ||
|
||
    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) | ||
|
||
    with get_executor() as ex:  # Initialize the executor only once. | ||
        for partition, m in manifests.items(): | ||
            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" | ||
            if (output_dir / cuts_filename).is_file(): | ||
                logging.info(f"{partition} already exists - skipping.") | ||
                continue | ||
            logging.info(f"Processing {partition}") | ||
            cut_set = CutSet.from_manifests( | ||
                recordings=m["recordings"], | ||
                supervisions=m["supervisions"], | ||
            ) | ||
            if "train" in partition: | ||
                cut_set = ( | ||
                    cut_set | ||
                    + cut_set.perturb_speed(0.9) | ||
                    + cut_set.perturb_speed(1.1) | ||
                ) | ||
            cut_set = cut_set.compute_and_store_features( | ||
                extractor=extractor, | ||
                storage_path=f"{output_dir}/{prefix}_feats_{partition}", | ||
                # when an executor is specified, make more partitions | ||
                num_jobs=num_jobs if ex is None else 80, | ||
                executor=ex, | ||
                storage_type=LilcomChunkyWriter, | ||
            ) | ||
            cut_set.to_file(output_dir / cuts_filename) | ||
|
||
|
||
def get_args(): | ||
    parser = argparse.ArgumentParser() | ||
    parser.add_argument( | ||
        "--num-mel-bins", | ||
        type=int, | ||
        default=80, | ||
        help="""The number of mel bins for Fbank""", | ||
    ) | ||
|
||
    return parser.parse_args() | ||
|
||
|
||
if __name__ == "__main__": | ||
    formatter = ( | ||
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" | ||
    ) | ||
|
||
    logging.basicConfig(format=formatter, level=logging.INFO) | ||
|
||
    args = get_args() | ||
    compute_fbank_tal_csasr(num_mel_bins=args.num_mel_bins) |
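The script above concatenates the training partition with 0.9x and 1.1x speed-perturbed copies, which triples the number of training cuts. The following sketch works out what that implies for the statistics quoted later in this commit (1,050,000 training cuts, 1679.0 hours). It assumes that a speed factor `f` changes a cut's duration to roughly `duration / f`; the function names are hypothetical.

```python
# Factors present in the augmented training set: the original audio
# plus the two perturbed copies created by perturb_speed(0.9) and (1.1).
SPEED_FACTORS = (1.0, 0.9, 1.1)


def augmented_totals(num_cuts, hours, factors=SPEED_FACTORS):
    """Cut count and total hours after adding speed-perturbed copies.

    Assumption: a copy perturbed by factor f lasts hours / f.
    """
    total_cuts = num_cuts * len(factors)
    total_hours = sum(hours / f for f in factors)
    return total_cuts, total_hours


def original_totals(aug_cuts, aug_hours, factors=SPEED_FACTORS):
    """Invert augmented_totals to recover the unperturbed totals."""
    scale = sum(1.0 / f for f in factors)
    return aug_cuts // len(factors), aug_hours / scale


cuts, hours = original_totals(1_050_000, 1679.0)
print(f"unperturbed training set: {cuts} cuts, about {hours:.0f} hours")
```

Under that assumption, the reported totals correspond to 350,000 original training utterances.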
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang | ||
# Mingshuang Luo) | ||
# | ||
# See ../../../../LICENSE for clarification regarding multiple authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
""" | ||
This file displays duration statistics of utterances in a manifest. | ||
You can use the displayed value to choose minimum/maximum duration | ||
to remove short and long utterances during the training. | ||
See the function `remove_short_and_long_utt()` | ||
in ../../../librispeech/ASR/transducer/train.py | ||
for usage. | ||
""" | ||
|
||
|
||
from lhotse import load_manifest | ||
|
||
|
||
def main(): | ||
    paths = [ | ||
        "./data/fbank/tal_csasr_cuts_train_set.jsonl.gz", | ||
        "./data/fbank/tal_csasr_cuts_dev_set.jsonl.gz", | ||
        "./data/fbank/tal_csasr_cuts_test_set.jsonl.gz", | ||
    ] | ||
|
||
    for path in paths: | ||
        print(f"Displaying the statistics for {path}") | ||
        cuts = load_manifest(path) | ||
        cuts.describe() | ||
|
||
|
||
if __name__ == "__main__": | ||
    main() | ||
|
||
""" | ||
Displaying the statistics for ./data/fbank/tal_csasr_cuts_train_set.jsonl.gz | ||
Cuts count: 1050000 | ||
Total duration (hours): 1679.0 | ||
Speech duration (hours): 1679.0 (100.0%) | ||
*** | ||
Duration statistics (seconds): | ||
mean 5.8 | ||
std 4.1 | ||
min 0.3 | ||
25% 2.8 | ||
50% 4.4 | ||
75% 7.3 | ||
99% 18.0 | ||
99.5% 18.8 | ||
99.9% 20.8 | ||
max 36.5 | ||
Displaying the statistics for ./data/fbank/tal_csasr_cuts_dev_set.jsonl.gz | ||
Cuts count: 5000 | ||
Total duration (hours): 8.0 | ||
Speech duration (hours): 8.0 (100.0%) | ||
*** | ||
Duration statistics (seconds): | ||
mean 5.8 | ||
std 4.0 | ||
min 0.5 | ||
25% 2.8 | ||
50% 4.5 | ||
75% 7.4 | ||
99% 17.0 | ||
99.5% 17.7 | ||
99.9% 19.5 | ||
max 21.5 | ||
Displaying the statistics for ./data/fbank/tal_csasr_cuts_test_set.jsonl.gz | ||
Cuts count: 15000 | ||
Total duration (hours): 23.6 | ||
Speech duration (hours): 23.6 (100.0%) | ||
*** | ||
Duration statistics (seconds): | ||
mean 5.7 | ||
std 4.0 | ||
min 0.5 | ||
25% 2.8 | ||
50% 4.4 | ||
75% 7.2 | ||
99% 17.2 | ||
99.5% 17.9 | ||
99.9% 19.6 | ||
max 32.3 | ||
""" |
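The docstring at the top of this file says the statistics are meant for choosing minimum and maximum durations. A minimal sketch of such a filter follows. The thresholds of 1.0 and 20.0 seconds are illustrative assumptions (20 s is close to the 99.9th percentile of 20.8 s reported for the training set), and the `Utt` class is a hypothetical stand-in for any object with a `duration` attribute in seconds.

```python
from dataclasses import dataclass


@dataclass
class Utt:
    """Hypothetical stand-in for a cut: only the duration matters here."""

    id: str
    duration: float  # seconds


def keep_utterance(utt, min_seconds=1.0, max_seconds=20.0):
    """True if the utterance is neither too short nor too long.

    The default thresholds are illustrative, not the recipe's values.
    """
    return min_seconds <= utt.duration <= max_seconds


utts = [Utt("a", 0.3), Utt("b", 4.4), Utt("c", 36.5)]
kept = [u.id for u in utts if keep_utterance(u)]
print(kept)
```

Dropping the few very long utterances mainly bounds the memory needed for a batch, while dropping very short ones removes clips that carry little usable speech.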