[Ready to merge] Pruned transducer stateless5 recipe for tal_csasr (mix Chinese chars and English BPE) (k2-fsa#428)

* add pruned transducer stateless5 recipe for tal_csasr

* do some changes for merging

* change for conformer.py

* add wer and cer for Chinese and English respectively

* fix an error in conformer.py
luomingshuang authored Jun 28, 2022
1 parent 6e609c6 commit 2cb1618
Showing 34 changed files with 4,975 additions and 0 deletions.
18 changes: 18 additions & 0 deletions README.md
@@ -32,6 +32,7 @@ We provide the following recipes:
- [WenetSpeech][wenetspeech]
- [Alimeeting][alimeeting]
- [Aishell4][aishell4]
- [TAL_CSASR][tal_csasr]

### yesno

@@ -286,6 +287,21 @@ The best CER(%) results:

We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)

### TAL_CSASR

We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TAL_CSASR_pruned_transducer_stateless5].

#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss

The best results, reported as CER(%) for Chinese and WER(%) for English (zh: Chinese, en: English):

| decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
|--|--|--|--|--|--|--|
| greedy_search | 7.30 | 6.48 | 19.19 | 7.39 | 6.66 | 19.13 |
| modified_beam_search | 7.15 | 6.35 | 18.95 | 7.22 | 6.50 | 18.70 |
| fast_beam_search | 7.18 | 6.39 | 18.90 | 7.27 | 6.55 | 18.77 |

We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing)

## Deployment with C++

Once you have trained a model in icefall, you may want to deploy it with C++,
@@ -315,6 +331,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2
[Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2
[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5
[TAL_CSASR_pruned_transducer_stateless5]: egs/tal_csasr/ASR/pruned_transducer_stateless5
[yesno]: egs/yesno/ASR
[librispeech]: egs/librispeech/ASR
[aishell]: egs/aishell/ASR
@@ -325,4 +342,5 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[wenetspeech]: egs/wenetspeech/ASR
[alimeeting]: egs/alimeeting/ASR
[aishell4]: egs/aishell4/ASR
[tal_csasr]: egs/tal_csasr/ASR
[k2]: https://github.com/k2-fsa/k2
19 changes: 19 additions & 0 deletions egs/tal_csasr/ASR/README.md
@@ -0,0 +1,19 @@

# Introduction

This recipe includes several ASR models trained on TAL_CSASR.

[./RESULTS.md](./RESULTS.md) contains the latest results.

# Transducers

There are various folders containing the name `transducer` in this folder.
The following table lists the differences among them.

| | Encoder | Decoder | Comment |
|---------------------------------------|---------------------|--------------------|-----------------------------|
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner|

The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.
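
To illustrate the idea, here is a dependency-free sketch of a stateless decoder: the output at each step depends only on the last `CONTEXT_SIZE` tokens, through an embedding lookup followed by a causal 1-D convolution. This is not the recipe's `decoder.py`. The toy embedding table, the per-channel (depthwise) kernel, and the names `stateless_decoder`, `EMBEDDING`, `KERNEL`, `BLANK_ID`, and `CONTEXT_SIZE` are all assumptions made for illustration.

```python
from typing import List

BLANK_ID = 0      # assumed id used for left padding
CONTEXT_SIZE = 2  # number of previous tokens the decoder can see

# Toy embedding table: vocabulary of 4 tokens, embedding dim 2 (made-up values).
EMBEDDING = [
    [0.0, 0.0],
    [1.0, 0.5],
    [-1.0, 2.0],
    [0.5, -0.5],
]

# Toy depthwise kernel: KERNEL[k][d] weights channel d of the k-th token
# in the window (k = 0 is the older token, k = 1 the most recent one).
KERNEL = [
    [0.5, 1.0],
    [1.0, -1.0],
]


def stateless_decoder(tokens: List[int]) -> List[List[float]]:
    """Return one output vector per position, using only the last CONTEXT_SIZE tokens."""
    padded = [BLANK_ID] * (CONTEXT_SIZE - 1) + tokens  # causal left padding
    dim = len(EMBEDDING[0])
    outputs = []
    for t in range(len(tokens)):
        window = padded[t : t + CONTEXT_SIZE]
        out = [
            sum(KERNEL[k][d] * EMBEDDING[window[k]][d] for k in range(CONTEXT_SIZE))
            for d in range(dim)
        ]
        outputs.append(out)
    return outputs


# Two histories that end in the same two tokens give the same final output:
# no recurrent state is carried along, which is what "stateless" means here.
a = stateless_decoder([1, 2, 3])
b = stateless_decoder([3, 3, 2, 3])
print(a[-1], b[-1])
```

Because the window is fixed, decoder outputs can be cached by their last `CONTEXT_SIZE` tokens during beam search, which is one practical motivation for this design.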
87 changes: 87 additions & 0 deletions egs/tal_csasr/ASR/RESULTS.md
@@ -0,0 +1,87 @@
## Results

### TAL_CSASR Mix Chars and BPEs training results (Pruned Transducer Stateless5)

#### 2022-06-22

Using the code from this PR: https://github.com/k2-fsa/icefall/pull/428.

The WERs are

|decoding-method | epoch(iter) | avg | dev | test |
|--|--|--|--|--|
|greedy_search | 30 | 24 | 7.49 | 7.58|
|modified_beam_search | 30 | 24 | 7.33 | 7.38|
|fast_beam_search | 30 | 24 | 7.32 | 7.42|
|greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 7.39|
|modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 7.22|
|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 7.27|
|greedy_search | 348000 | 30 | 7.46 | 7.54|
|modified_beam_search | 348000 | 30 | 7.24 | 7.36|
|fast_beam_search | 348000 | 30 | 7.25 | 7.39 |

The results, reported as CER(%) for Chinese and WER(%) for English (zh: Chinese, en: English):

| decoding-method | epoch(iter) | avg | dev | dev_zh | dev_en | test | test_zh | test_en |
|--|--|--|--|--|--|--|--|--|
| greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 6.48 | 19.19 | 7.39 | 6.66 | 19.13 |
| modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 6.35 | 18.95 | 7.22 | 6.50 | 18.70 |
| fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 6.39 | 18.90 | 7.27 | 6.55 | 18.77 |
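
A small sketch can show how per-language error rates of this kind may be computed. It is not the recipe's scoring code. It assumes that a transcript is tokenized into single Chinese characters and whole English words, and that the tokens of each language are then scored separately with Levenshtein distance. The helper names (`tokenize`, `is_chinese`, `edit_distance`, `error_rate`) and the token pattern are hypothetical.

```python
import re
from typing import List

# Assumed token pattern: one CJK character, or a run of Latin letters/apostrophes.
TOKEN_RE = re.compile(r"[\u4e00-\u9fff]|[A-Za-z']+")


def tokenize(text: str) -> List[str]:
    """Split mixed text into Chinese characters and English words."""
    return TOKEN_RE.findall(text)


def is_chinese(token: str) -> bool:
    """True if the token starts with a character in the basic CJK block."""
    return "\u4e00" <= token[0] <= "\u9fff"


def edit_distance(ref: List[str], hyp: List[str]) -> int:
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(
                min(
                    prev[j] + 1,             # deletion
                    cur[j - 1] + 1,          # insertion
                    prev[j - 1] + (r != h),  # substitution or match
                )
            )
        prev = cur
    return prev[-1]


def error_rate(ref: List[str], hyp: List[str]) -> float:
    """Number of errors divided by the reference length, in percent."""
    return 100.0 * edit_distance(ref, hyp) / max(len(ref), 1)


ref = tokenize("我想要一个 apple pie")
hyp = tokenize("我想要一个 apple pi")

overall = error_rate(ref, hyp)
zh = error_rate(
    [t for t in ref if is_chinese(t)], [t for t in hyp if is_chinese(t)]
)
en = error_rate(
    [t for t in ref if not is_chinese(t)], [t for t in hyp if not is_chinese(t)]
)
print(f"overall={overall:.2f} zh={zh:.2f} en={en:.2f}")
```

In this toy example the single English error is small relative to the whole utterance but large relative to the English tokens alone, which mirrors why the `*_en` columns can be much higher than the overall numbers.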

The training command for reproducing is given below:

```
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5"
./pruned_transducer_stateless5/train.py \
  --world-size 6 \
  --num-epochs 30 \
  --start-epoch 1 \
  --exp-dir pruned_transducer_stateless5/exp \
  --lang-dir data/lang_char \
  --max-duration 90
```

The tensorboard training log can be found at
https://tensorboard.dev/experiment/KaACzXOVR0OM6cy0qbN5hw/#scalars

The decoding command is:
```
epoch=30
avg=24
use_average_model=True

## greedy search
./pruned_transducer_stateless5/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir pruned_transducer_stateless5/exp \
  --lang-dir ./data/lang_char \
  --max-duration 800 \
  --use-averaged-model $use_average_model

## modified beam search
./pruned_transducer_stateless5/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir pruned_transducer_stateless5/exp \
  --lang-dir ./data/lang_char \
  --max-duration 800 \
  --decoding-method modified_beam_search \
  --beam-size 4 \
  --use-averaged-model $use_average_model

## fast beam search
./pruned_transducer_stateless5/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir ./pruned_transducer_stateless5/exp \
  --lang-dir ./data/lang_char \
  --max-duration 1500 \
  --decoding-method fast_beam_search \
  --beam 4 \
  --max-contexts 4 \
  --max-states 8 \
  --use-averaged-model $use_average_model
```
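
The `--epoch` and `--avg` options select which checkpoints contribute to the decoded model. As a rough illustration only, the sketch below averages parameter values over the last `avg` epoch checkpoints, with each checkpoint represented as a plain dict of floats. The real `decode.py` works on PyTorch state dicts, and with `--use-averaged-model` it may combine running averages saved during training rather than averaging the epoch checkpoints directly. The functions `average_checkpoints` and `epochs_to_average` and the toy values are hypothetical.

```python
from typing import Dict, List


def average_checkpoints(checkpoints: List[Dict[str, float]]) -> Dict[str, float]:
    """Element-wise mean of parameters over a list of checkpoints."""
    if not checkpoints:
        raise ValueError("need at least one checkpoint")
    n = len(checkpoints)
    return {
        name: sum(ckpt[name] for ckpt in checkpoints) / n
        for name in checkpoints[0]
    }


def epochs_to_average(epoch: int, avg: int) -> List[int]:
    """Epochs covered when averaging the last `avg` checkpoints up to `epoch`.

    Assumption: the window ends at `epoch` and spans `avg` consecutive epochs.
    """
    return list(range(epoch - avg + 1, epoch + 1))


# Under the assumption above, epoch=30 and avg=24 cover epochs 7 through 30.
selected = epochs_to_average(epoch=30, avg=24)

# Toy checkpoints: one scalar "weight" per epoch (made-up values).
toy = {e: {"weight": float(e)} for e in range(1, 31)}
averaged = average_checkpoints([toy[e] for e in selected])
print(selected[0], selected[-1], averaged["weight"])
```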

A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_tal-csasr_pruned_transducer_stateless5>
Empty file.
1 change: 1 addition & 0 deletions egs/tal_csasr/ASR/local/compute_fbank_musan.py
115 changes: 115 additions & 0 deletions egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
@@ -0,0 +1,115 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the tal_csasr dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""

import argparse
import logging
import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled, or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_tal_csasr(num_mel_bins: int = 80):
    src_dir = Path("data/manifests/tal_csasr")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())

    dataset_parts = (
        "train_set",
        "dev_set",
        "test_set",
    )
    prefix = "tal_csasr"
    suffix = "jsonl.gz"
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=src_dir,
        prefix=prefix,
        suffix=suffix,
    )
    assert manifests is not None

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
            if (output_dir / cuts_filename).is_file():
                logging.info(f"{partition} already exists - skipping.")
                continue
            logging.info(f"Processing {partition}")
            cut_set = CutSet.from_manifests(
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            if "train" in partition:
                cut_set = (
                    cut_set
                    + cut_set.perturb_speed(0.9)
                    + cut_set.perturb_speed(1.1)
                )
            cut_set = cut_set.compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                # when an executor is specified, make more partitions
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=LilcomChunkyWriter,
            )
            cut_set.to_file(output_dir / cuts_filename)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=80,
        help="""The number of mel bins for Fbank""",
    )

    return parser.parse_args()


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    compute_fbank_tal_csasr(num_mel_bins=args.num_mel_bins)
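
The script above adds two speed-perturbed copies (factors 0.9 and 1.1) of every training cut. The arithmetic below is a sketch of what that does to the amount of audio, assuming that playing audio `factor` times faster divides its duration by `factor`. The helper names `perturbed_durations` and `augmentation_ratio` and the 9.9 s example are hypothetical and not part of the recipe.

```python
from typing import List, Sequence


def perturbed_durations(
    duration: float, factors: Sequence[float] = (0.9, 1.1)
) -> List[float]:
    """Durations of the original cut followed by its speed-perturbed copies."""
    return [duration] + [duration / f for f in factors]


def augmentation_ratio(factors: Sequence[float] = (0.9, 1.1)) -> float:
    """Total audio after augmentation, relative to the original amount."""
    return 1.0 + sum(1.0 / f for f in factors)


durations = perturbed_durations(9.9)
ratio = augmentation_ratio()
print([round(d, 2) for d in durations], round(ratio, 4))
```

Under this assumption the training set has three times as many cuts and slightly more than three times as much audio, because the slowed-down copy gains more duration than the sped-up copy loses.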
96 changes: 96 additions & 0 deletions egs/tal_csasr/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,96 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.
See the function `remove_short_and_long_utt()`
in ../../../librispeech/ASR/transducer/train.py
for usage.
"""


from lhotse import load_manifest


def main():
    paths = [
        "./data/fbank/tal_csasr_cuts_train_set.jsonl.gz",
        "./data/fbank/tal_csasr_cuts_dev_set.jsonl.gz",
        "./data/fbank/tal_csasr_cuts_test_set.jsonl.gz",
    ]

    for path in paths:
        print(f"Displaying the statistics for {path}")
        cuts = load_manifest(path)
        cuts.describe()


if __name__ == "__main__":
    main()
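
The duration statistics this script prints are meant to guide the choice of minimum and maximum utterance durations for training. As a sketch only, the snippet below filters a list of cuts by duration. The `Cut` dataclass is a stand-in for lhotse's cut objects, and the 1.0 s and 20.0 s bounds are illustrative assumptions, not values taken from this recipe.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class Cut:
    """Minimal stand-in for a lhotse cut: only an id and a duration in seconds."""

    id: str
    duration: float


def remove_short_and_long_utt(
    cuts: Iterable[Cut], min_duration: float = 1.0, max_duration: float = 20.0
) -> List[Cut]:
    """Keep only cuts whose duration lies in [min_duration, max_duration]."""
    return [c for c in cuts if min_duration <= c.duration <= max_duration]


# Toy cuts with made-up ids and durations.
cuts = [Cut("a", 0.3), Cut("b", 4.4), Cut("c", 18.0), Cut("d", 36.5)]
kept = remove_short_and_long_utt(cuts)
print([c.id for c in kept])
```

Picking the upper bound near a high percentile of the duration distribution drops only a small fraction of utterances while avoiding very long batches.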

"""
Displaying the statistics for ./data/fbank/tal_csasr_cuts_train_set.jsonl.gz
Cuts count: 1050000
Total duration (hours): 1679.0
Speech duration (hours): 1679.0 (100.0%)
***
Duration statistics (seconds):
mean 5.8
std 4.1
min 0.3
25% 2.8
50% 4.4
75% 7.3
99% 18.0
99.5% 18.8
99.9% 20.8
max 36.5
Displaying the statistics for ./data/fbank/tal_csasr_cuts_dev_set.jsonl.gz
Cuts count: 5000
Total duration (hours): 8.0
Speech duration (hours): 8.0 (100.0%)
***
Duration statistics (seconds):
mean 5.8
std 4.0
min 0.5
25% 2.8
50% 4.5
75% 7.4
99% 17.0
99.5% 17.7
99.9% 19.5
max 21.5
Displaying the statistics for ./data/fbank/tal_csasr_cuts_test_set.jsonl.gz
Cuts count: 15000
Total duration (hours): 23.6
Speech duration (hours): 23.6 (100.0%)
***
Duration statistics (seconds):
mean 5.7
std 4.0
min 0.5
25% 2.8
50% 4.4
75% 7.2
99% 17.2
99.5% 17.9
99.9% 19.6
max 32.3
"""