[Ready to be merged] Add RNN-LM to Conformer-CTC decoding (k2-fsa#439)
ezerhouni authored Jun 23, 2022
1 parent dc89b61 commit 0475d75
Showing 25 changed files with 2,659 additions and 42 deletions.
45 changes: 36 additions & 9 deletions egs/librispeech/ASR/RESULTS.md
@@ -1299,17 +1299,18 @@ You can find the tensorboard log at: <https://tensorboard.dev/experiment/D7NQc3x

#### 2021-11-09

The best WER, as of 2021-11-09, for the librispeech test dataset is below
(using HLG decoding + n-gram LM rescoring + attention decoder rescoring):
The best WER, as of 2022-06-20, for the LibriSpeech test datasets is below
(using HLG decoding + n-gram LM rescoring + attention decoder rescoring + RNN-LM rescoring):

| | test-clean | test-other |
|-----|------------|------------|
| WER | 2.42 | 5.73 |
| WER | 2.32 | 5.39 |

Scale values used in n-gram LM rescoring, attention-decoder rescoring, and RNN-LM rescoring for the best WERs are given below (a sketch of how these scales combine is shown after the table):
| ngram_lm_scale | attention_scale |
|----------------|-----------------|
| 2.0 | 2.0 |

| ngram_lm_scale | attention_scale | rnn_lm_scale |
|----------------|-----------------|--------------|
| 0.3 | 2.1 | 2.2 |
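
For intuition, here is a minimal sketch of how these three scales are typically combined when ranking the n-best paths during rescoring. It is an illustration only, not the exact icefall implementation; the tensor names (`am_scores`, `ngram_lm_scores`, `attention_scores`, `rnn_lm_scores`) are placeholders.

```python
import torch


def combine_scores(
    am_scores: torch.Tensor,         # acoustic scores of the n-best paths
    ngram_lm_scores: torch.Tensor,   # n-gram LM scores of the paths
    attention_scores: torch.Tensor,  # attention-decoder scores of the paths
    rnn_lm_scores: torch.Tensor,     # RNN-LM scores of the paths
    ngram_lm_scale: float = 0.3,
    attention_scale: float = 2.1,
    rnn_lm_scale: float = 2.2,
) -> torch.Tensor:
    """Total score used to pick the best of the n-best paths."""
    return (
        am_scores
        + ngram_lm_scale * ngram_lm_scores
        + attention_scale * attention_scores
        + rnn_lm_scale * rnn_lm_scores
    )
```

In practice the decoding script typically evaluates a grid of scale values and reports the WER for each combination; the triple above is simply the one that gave the best result.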


To reproduce the above result, use the following commands for training:
@@ -1330,11 +1331,27 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--start-epoch 0 \
--num-epochs 90
# Note: It trains for 90 epochs, but the best WER is at epoch-77.pt
# Train the RNN-LM
cd icefall
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./rnn_lm/train.py \
--exp-dir rnn_lm/exp_2048_3_tied \
--start-epoch 0 \
--world-size 4 \
--num-epochs 30 \
--use-fp16 1 \
--embedding-dim 2048 \
--hidden-dim 2048 \
--num-layers 3 \
--batch-size 500 \
--tie-weights true
```

and the following command for decoding

```
rnn_dir=$(git rev-parse --show-toplevel)/icefall/rnn_lm
./conformer_ctc/decode.py \
--exp-dir conformer_ctc/exp_500_att0.8 \
--lang-dir data/lang_bpe_500 \
@@ -1344,13 +1361,23 @@ and the following command for decoding
--num-paths 1000 \
--epoch 77 \
--avg 55 \
--method attention-decoder \
--nbest-scale 0.5
--nbest-scale 0.5 \
--rnn-lm-exp-dir ${rnn_dir}/exp_2048_3_tied \
--rnn-lm-epoch 29 \
--rnn-lm-avg 3 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights true \
--method rnn-lm
```
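
The `--tie-weights` (training) and `--rnn-lm-tie-weights` (decoding) flags share the parameters of the input embedding with the final output projection, which is why the embedding and hidden dimensions are both 2048 above. The following is a minimal PyTorch sketch of the idea; the class `TinyTiedLM` is hypothetical and not the actual `icefall.rnn_lm.model.RnnLmModel` code.

```python
import torch
import torch.nn as nn


class TinyTiedLM(nn.Module):
    """Toy LSTM LM illustrating weight tying (for illustration only)."""

    def __init__(self, vocab_size: int, dim: int = 2048, num_layers: int = 3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.rnn = nn.LSTM(dim, dim, num_layers=num_layers, batch_first=True)
        self.output = nn.Linear(dim, vocab_size)
        # Weight tying: reuse the embedding matrix as the output projection.
        # This requires embedding dim == hidden dim.
        self.output.weight = self.embedding.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        x = self.embedding(tokens)  # (N, T, dim)
        y, _ = self.rnn(x)          # (N, T, dim)
        return self.output(y)       # (N, T, vocab_size) logits
```

Tying the weights removes one of the two `vocab_size x 2048` matrices from the parameter count and often helps perplexity for LMs of this size.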

You can find the pre-trained model by visiting
You can find the Conformer-CTC pre-trained model by visiting
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09>

and the RNN-LM pre-trained model by visiting
<https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>

The tensorboard log for training is available at
<https://tensorboard.dev/experiment/hZDWrZfaSqOMqtW0NEfXKg/#scalars>

152 changes: 135 additions & 17 deletions egs/librispeech/ASR/conformer_ctc/decode.py
@@ -30,23 +30,27 @@
from conformer import Conformer

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.checkpoint import load_checkpoint
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_rnn_lm,
rescore_with_whole_lattice,
)
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
get_texts,
load_averaged_model,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)

@@ -93,7 +97,9 @@ def get_parser():
is the decoding result.
- (5) attention-decoder. Extract n paths from the LM rescored
lattice, the path with the highest score is the decoding result.
- (6) nbest-oracle. Its WER is the lower bound of any n-best
- (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
you have trained an RNN LM using ./rnn_lm/train.py
- (7) nbest-oracle. Its WER is the lower bound that any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
@@ -105,7 +111,7 @@
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
""",
)

@@ -116,7 +122,7 @@
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
A smaller value results in more unique paths.
""",
)
@@ -139,11 +145,67 @@
"--lm-dir",
type=str,
default="data/lm",
help="""The LM dir.
help="""The n-gram LM dir.
It should contain either G_4_gram.pt or G_4_gram.fst.txt
""",
)

parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is rnn-lm.
It specifies the path to RNN LM exp dir.
""",
)

parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is rnn-lm.
It specifies the checkpoint to use.
""",
)

parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is rnn-lm.
It specifies the number of checkpoints to average.
""",
)

parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)

parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)

parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers in the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)

return parser


@@ -173,6 +235,7 @@ def get_params() -> AttributeDict:
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
rnn_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
@@ -205,6 +268,8 @@ def decode_one_batch(
model:
The neural model.
rnn_lm_model:
The neural model for RNN LM.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
@@ -330,6 +395,7 @@ def decode_one_batch(
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
]

lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
@@ -357,8 +423,6 @@ def decode_one_batch(
G_with_epsilon_loops=G,
lm_scale_list=None,
)
# TODO: pass `lattice` instead of `rescored_lattice` to
# `rescore_with_attention_decoder`

best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
@@ -370,6 +434,26 @@
eos_id=eos_id,
nbest_scale=params.nbest_scale,
)
elif params.method == "rnn-lm":
# The lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
)

best_path_dict = rescore_with_rnn_lm(
lattice=rescored_lattice,
num_paths=params.num_paths,
rnn_lm_model=rnn_lm_model,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
blank_id=0,
nbest_scale=params.nbest_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"

@@ -388,6 +472,7 @@ def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
rnn_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
@@ -405,6 +490,8 @@
It is returned by :func:`get_params`.
model:
The neural model.
rnn_lm_model:
The neural model for RNN LM.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
@@ -442,6 +529,7 @@ def decode_dataset(
hyps_dict = decode_one_batch(
params=params,
model=model,
rnn_lm_model=rnn_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
@@ -490,7 +578,7 @@ def save_results(
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
if params.method == "attention-decoder":
if params.method in ("attention-decoder", "rnn-lm"):
# Set it to False since there are too many logs.
enable_log = False
else:
@@ -566,6 +654,10 @@ def main():
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id

params.num_classes = num_classes
params.sos_id = sos_id
params.eos_id = eos_id

if params.method == "ctc-decoding":
HLG = None
H = k2.ctc_topo(
@@ -590,6 +682,7 @@
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
@@ -621,7 +714,11 @@ def main():
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)

if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
if params.method in [
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
@@ -648,20 +745,40 @@
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model = load_averaged_model(
params.exp_dir, model, params.epoch, params.avg, device
)

model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")

rnn_lm_model = None
if params.method == "rnn-lm":
rnn_lm_model = RnnLmModel(
vocab_size=params.num_classes,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
if params.rnn_lm_avg == 1:
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
else:
rnn_lm_model = load_averaged_model(
params.rnn_lm_exp_dir,
rnn_lm_model,
params.rnn_lm_epoch,
params.rnn_lm_avg,
device,
)
rnn_lm_model.eval()

librispeech = LibriSpeechAsrDataModule(args)

test_clean_cuts = librispeech.test_clean_cuts()
@@ -678,6 +795,7 @@
dl=test_dl,
params=params,
model=model,
rnn_lm_model=rnn_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,