Commit

Refactor decode.py to make it more readable and more modular. (k2-fsa#44)

* Refactor decode.py to make it more readable and more modular.

* Fix an error.

Nbest.fsa should always have token IDs as labels and
word IDs as aux_labels.

* Add nbest decoding.

* Compute edit distance with k2. (A sketch follows this list.)

* Refactor nbest-oracle.

* Add rescore with nbest lists.

* Add whole-lattice rescoring.

* Add rescoring with attention decoder.

* Refactoring.

* Fixes after refactoring.

* Fix a typo.

* Minor fixes.

* Replace [] with () for shapes.

* Use k2 v1.9

* Use Levenshtein graphs/alignment from k2 v1.9

* [doc] Require k2 >= v1.9

* Minor fixes.
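
Aside: the "edit distance" and "Levenshtein graphs/alignment" bullets refer to the API that k2 gained in v1.9. The sketch below shows the general pattern; the word-ID sequences are made up for illustration, and the exact helper names used inside icefall are not part of this diff.

    import torch
    import k2

    # Hypothetical word-ID sequences. In the recipe, `hyps` come from
    # n-best paths and `refs` from the reference transcripts.
    refs = k2.levenshtein_graph([[1, 2, 3, 4]])
    hyps = k2.levenshtein_graph([[1, 2, 5, 4], [1, 2, 3]])

    # Score both hypotheses against reference 0.
    hyp_to_ref_map = torch.tensor([0, 0], dtype=torch.int32)

    alignment = k2.levenshtein_alignment(
        refs=refs,
        hyps=hyps,
        hyp_to_ref_map=hyp_to_ref_map,
        sorted_match_ref=True,
    )

    # Larger total scores mean smaller edit distance; up to the small
    # tie-breaking offset of k2's default ins_del_score, these are
    # negated edit distances.
    tot_scores = alignment.get_tot_scores(
        use_double_scores=False, log_semiring=False
    )
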
csukuangfj authored Sep 20, 2021
1 parent cc77cb3 commit a80e58e
Showing 20 changed files with 686 additions and 618 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-yesno-recipe.yml
@@ -34,7 +34,7 @@ jobs:
         os: [ubuntu-18.04]
         python-version: [3.8]
         torch: ["1.8.1"]
-        k2-version: ["1.8.dev20210917"]
+        k2-version: ["1.9.dev20210919"]
       fail-fast: false

     steps:
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -32,7 +32,7 @@ jobs:
         os: [ubuntu-18.04, macos-10.15]
         python-version: [3.6, 3.7, 3.8, 3.9]
         torch: ["1.8.1"]
-        k2-version: ["1.8.dev20210917"]
+        k2-version: ["1.9.dev20210919"]

       fail-fast: false
1 change: 0 additions & 1 deletion docs/source/installation/images/k2-v-1.7.svg

This file was deleted.

1 change: 1 addition & 0 deletions docs/source/installation/images/k2-v1.9-blueviolet.svg
4 changes: 2 additions & 2 deletions docs/source/installation/index.rst
@@ -21,7 +21,7 @@ Installation
 .. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
    :alt: Supported PyTorch versions

-.. |k2_versions| image:: ./images/k2-v-1.7.svg
+.. |k2_versions| image:: ./images/k2-v1.9-blueviolet.svg
    :alt: Supported k2 versions

 ``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and
@@ -40,7 +40,7 @@ to install ``k2``.

 .. CAUTION::

-  You need to install ``k2`` with a version at least **v1.7**.
+  You need to install ``k2`` with a version at least **v1.9**.

 .. HINT::
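
Since the minimum k2 version is now v1.9, and the commit message ties that to the new Levenshtein graph/alignment API, a feature check is a cheap way to fail early. This is an illustrative sketch, not code from this commit:

    import k2

    # Levenshtein graphs first appeared in k2 v1.9, so their absence
    # means the installed k2 is too old for these recipes.
    if not hasattr(k2, "levenshtein_graph"):
        raise ImportError("icefall requires k2 >= v1.9; please upgrade k2")
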
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/conformer_ctc/conformer.py
@@ -98,7 +98,7 @@ def run_encoder(
         """
         Args:
           x:
-            The model input. Its shape is [N, T, C].
+            The model input. Its shape is (N, T, C).
           supervisions:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
43 changes: 29 additions & 14 deletions egs/librispeech/ASR/conformer_ctc/decode.py
@@ -213,12 +213,12 @@ def decode_one_batch(
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)

     supervisions = batch["supervisions"]

     nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
-    # nnet_output is [N, T, C]
+    # nnet_output is (N, T, C)

     supervision_segments = torch.stack(
         (
@@ -244,14 +244,19 @@
         # Note: You can also pass rescored lattices to it.
         # We choose the HLG decoded lattice for speed reasons
         # as HLG decoding is faster and the oracle WER
-        # is slightly worse than that of rescored lattices.
-        return nbest_oracle(
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
             lattice=lattice,
             num_paths=params.num_paths,
             ref_texts=supervisions["text"],
             word_table=word_table,
-            scale=params.lattice_score_scale,
+            lattice_score_scale=params.lattice_score_scale,
+            oov="<UNK>",
         )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
+        return {key: hyps}

     if params.method in ["1best", "nbest"]:
         if params.method == "1best":
@@ -264,7 +269,7 @@
             lattice=lattice,
             num_paths=params.num_paths,
             use_double_scores=params.use_double_scores,
-            scale=params.lattice_score_scale,
+            lattice_score_scale=params.lattice_score_scale,
         )
         key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"  # noqa

@@ -288,17 +293,23 @@
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
-            scale=params.lattice_score_scale,
+            lattice_score_scale=params.lattice_score_scale,
         )
     elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
         )
     elif params.method == "attention-decoder":
         # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
         rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=None,
         )
+        # TODO: pass `lattice` instead of `rescored_lattice` to
+        # `rescore_with_attention_decoder`

         best_path_dict = rescore_with_attention_decoder(
             lattice=rescored_lattice,
@@ -308,16 +319,20 @@
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=sos_id,
             eos_id=eos_id,
-            scale=params.lattice_score_scale,
+            lattice_score_scale=params.lattice_score_scale,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"

     ans = dict()
-    for lm_scale_str, best_path in best_path_dict.items():
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        ans[lm_scale_str] = hyps
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            ans[lm_scale_str] = hyps
+    else:
+        for lm_scale in lm_scale_list:
+            # Rescoring produced no paths: emit one empty hypothesis
+            # per utterance in the lattice.
+            ans[str(lm_scale)] = [[] for _ in range(lattice.shape[0])]
     return ans


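Every decoding branch above now ends the same way: an FsaVec of best paths is converted to word-level hypotheses. A standalone sketch of that final step, assuming icefall's `get_texts` helper (the same one the hunks above call):

    from typing import List

    import k2
    from icefall.utils import get_texts  # import path assumed

    def best_path_to_hyps(
        best_path: k2.Fsa, word_table: k2.SymbolTable
    ) -> List[List[str]]:
        """Map an FsaVec of best paths to word hypotheses.

        Per the commit message, lattices keep token IDs as labels and
        word IDs as aux_labels, so get_texts() returns word IDs that
        the word table maps back to strings.
        """
        hyps = get_texts(best_path)  # one list of word IDs per utterance
        return [[word_table[i] for i in ids] for ids in hyps]
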
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -336,7 +336,7 @@ def main():
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=params.sos_id,
             eos_id=params.eos_id,
-            scale=params.lattice_score_scale,
+            lattice_score_scale=params.lattice_score_scale,
             ngram_lm_scale=params.ngram_lm_scale,
             attention_scale=params.attention_decoder_scale,
         )
32 changes: 16 additions & 16 deletions egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -22,8 +22,8 @@
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/4 length).

-    Convert an input of shape [N, T, idim] to an output
-    with shape [N, T', odim], where
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
     T' = ((T-1)//2 - 1)//2, which approximates T' == T//4

     It is based on
@@ -34,10 +34,10 @@ def __init__(self, idim: int, odim: int) -> None:
         """
         Args:
           idim:
-            Input dim. The input shape is [N, T, idim].
+            Input dim. The input shape is (N, T, idim).
             Caution: It requires: T >=7, idim >=7
           odim:
-            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
         """
         assert idim >= 7
         super().__init__()
@@ -58,18 +58,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         Args:
           x:
-            Its shape is [N, T, idim].
+            Its shape is (N, T, idim).

         Returns:
-          Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
         """
-        # On entry, x is [N, T, idim]
-        x = x.unsqueeze(1)  # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
         x = self.conv(x)
-        # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
+        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
-        # Now x is of shape [N, ((T-1)//2 - 1)//2, odim]
+        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
         return x


@@ -80,8 +80,8 @@ class VggSubsampling(nn.Module):
     This paper is not 100% explicit so I am guessing to some extent,
     and trying to compare with other VGG implementations.

-    Convert an input of shape [N, T, idim] to an output
-    with shape [N, T', odim], where
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
     T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
     """

@@ -93,10 +93,10 @@ def __init__(self, idim: int, odim: int) -> None:
         Args:
           idim:
-            Input dim. The input shape is [N, T, idim].
+            Input dim. The input shape is (N, T, idim).
             Caution: It requires: T >=7, idim >=7
           odim:
-            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
         """
         super().__init__()
@@ -149,10 +149,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         Args:
           x:
-            Its shape is [N, T, idim].
+            Its shape is (N, T, idim).

         Returns:
-          Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
         """
         x = x.unsqueeze(1)
         x = self.layers(x)
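
The output-length formula repeated in these docstrings is easy to sanity-check. A minimal sketch, assuming `Conv2dSubsampling` is importable from this recipe's `subsampling` module:

    import torch

    from subsampling import Conv2dSubsampling  # module path assumed

    model = Conv2dSubsampling(idim=80, odim=256)
    x = torch.randn(8, 100, 80)  # (N, T, idim) with T = 100

    y = model(x)

    t_prime = ((100 - 1) // 2 - 1) // 2  # = 24, close to T // 4 = 25
    assert y.shape == (8, t_prime, 256)  # (N, T', odim)
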
4 changes: 2 additions & 2 deletions egs/librispeech/ASR/conformer_ctc/train.py
@@ -310,14 +310,14 @@ def compute_loss(
     """
     device = graph_compiler.device
     feature = batch["inputs"]
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)

     supervisions = batch["supervisions"]
     with torch.set_grad_enabled(is_training):
         nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is [N, T, C]
+        # nnet_output is (N, T, C)

         # NOTE: We need `encode_supervisions` to sort sequences with
         # different duration in decreasing order, required by
44 changes: 22 additions & 22 deletions egs/librispeech/ASR/conformer_ctc/transformer.py
@@ -83,8 +83,8 @@ def __init__(
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")

-        # self.encoder_embed converts the input of shape [N, T, num_classes]
-        # to the shape [N, T//subsampling_factor, d_model].
+        # self.encoder_embed converts the input of shape (N, T, num_classes)
+        # to the shape (N, T//subsampling_factor, d_model).
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
         #   (2) embedding: num_classes -> d_model
@@ -162,7 +162,7 @@ def forward(
         """
         Args:
           x:
-            The input tensor. Its shape is [N, T, C].
+            The input tensor. Its shape is (N, T, C).
           supervision:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -171,17 +171,17 @@
         Returns:
           Return a tuple containing 3 tensors:
-            - CTC output for ctc decoding. Its shape is [N, T, C]
-            - Encoder output with shape [T, N, C]. It can be used as key and
+            - CTC output for ctc decoding. Its shape is (N, T, C)
+            - Encoder output with shape (T, N, C). It can be used as key and
               value for the decoder.
             - Encoder output padding mask. It can be used as
-              memory_key_padding_mask for the decoder. Its shape is [N, T].
+              memory_key_padding_mask for the decoder. Its shape is (N, T).
               It is None if `supervision` is None.
         """
         if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
+            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
+            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
         encoder_memory, memory_key_padding_mask = self.run_encoder(
             x, supervision
         )
@@ -195,7 +195,7 @@ def run_encoder(
         Args:
           x:
-            The model input. Its shape is [N, T, C].
+            The model input. Its shape is (N, T, C).
           supervisions:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -206,8 +206,8 @@
             padding mask for the decoder.
         Returns:
           Return a tuple with two tensors:
-            - The encoder output, with shape [T, N, C]
-            - encoder padding mask, with shape [N, T].
+            - The encoder output, with shape (T, N, C)
+            - encoder padding mask, with shape (N, T).
               The mask is None if `supervisions` is None.
               It is used as memory key padding mask in the decoder.
         """
@@ -225,11 +225,11 @@ def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
         Args:
           x:
             The output tensor from the transformer encoder.
-            Its shape is [T, N, C]
+            Its shape is (T, N, C)

         Returns:
           Return a tensor that can be used for CTC decoding.
-          Its shape is [N, T, C]
+          Its shape is (N, T, C)
         """
         x = self.encoder_output_layer(x)
         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
@@ -247,7 +247,7 @@ def decoder_forward(
         """
         Args:
           memory:
-            It's the output of the encoder with shape [T, N, C]
+            It's the output of the encoder with shape (T, N, C)
           memory_key_padding_mask:
             The padding mask from the encoder.
           token_ids:
@@ -312,7 +312,7 @@ def decoder_nll(
         """
         Args:
           memory:
-            It's the output of the encoder with shape [T, N, C]
+            It's the output of the encoder with shape (T, N, C)
           memory_key_padding_mask:
             The padding mask from the encoder.
           token_ids:
@@ -654,13 +654,13 @@ def __init__(self, d_model: int, dropout: float = 0.1) -> None:
     def extend_pe(self, x: torch.Tensor) -> None:
         """Extend the time t in the positional encoding if required.

-        The shape of `self.pe` is [1, T1, d_model]. The shape of the input x
-        is [N, T, d_model]. If T > T1, then we change the shape of self.pe
-        to [N, T, d_model]. Otherwise, nothing is done.
+        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+        is (N, T, d_model). If T > T1, then we change the shape of self.pe
+        to (N, T, d_model). Otherwise, nothing is done.

         Args:
           x:
-            It is a tensor of shape [N, T, C].
+            It is a tensor of shape (N, T, C).
         Returns:
           Return None.
         """
@@ -678,7 +678,7 @@ def extend_pe(self, x: torch.Tensor) -> None:
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(0)
-        # Now pe is of shape [1, T, d_model], where T is x.size(1)
+        # Now pe is of shape (1, T, d_model), where T is x.size(1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -687,10 +687,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         Args:
           x:
-            Its shape is [N, T, C]
+            Its shape is (N, T, C)

         Returns:
-          Return a tensor of shape [N, T, C]
+          Return a tensor of shape (N, T, C)
         """
         self.extend_pe(x)
         x = x * self.xscale + self.pe[:, : x.size(1), :]
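
For reference, the table that `extend_pe` maintains is the standard sinusoidal encoding. This minimal sketch reproduces just the sin/cos fill shown in the hunk above; the dropout and `xscale` details of the full class are omitted:

    import math

    import torch

    def sinusoidal_pe(T: int, d_model: int) -> torch.Tensor:
        """Return sinusoidal positional encodings of shape (1, T, d_model)."""
        pe = torch.zeros(T, d_model)
        position = torch.arange(0, T, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        return pe.unsqueeze(0)  # (1, T, d_model)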