64 changes: 64 additions & 0 deletions examples/librispeech/s0/README.md
@@ -313,3 +313,67 @@ test other
| ctc_greedy_search | 8.73 | 9.82 | 9.83 |
| ctc prefix beam search | 8.70 | 9.81 | 9.79 |
| attention rescoring | 8.05 | 9.08 | 9.10 |


## ChunkFormer U2++ Result

* Model info:
    * Encoder Params: 32,356,096
    * Downsample rate: dw_striding 8x
    * encoder_dim 256, head 4, linear_units 2048
    * num_blocks 12, cnn_module_kernel 15
* Feature info: using fbank feature, cmvn, dither, online speed perturb
* Training info:
    * train_u2++_chunkformer_small.yaml, kernel size 15
    * dynamic batch size 120,000 frames, 2 GPUs, acc_grad 4, 200 epochs, dither 1.0
    * adamw, lr 1e-3, warmuplr, warmup_steps: 25000
    * specaug and speed perturb
* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 100, beam size 10

#### Full context training -> Chunk context inference:
⚠️ The attention decoder does **not** support chunk-context inference after full-context training: its cross-attention was trained on full-context encoder outputs and mismatches the chunked encoder outputs produced at inference time. Chunk-context training is required to resolve this mismatch.

| Decoding Mode | Dev Clean | Dev Other | Test Clean | Test Other |
|------------------------|-----------|-----------|------------|------------|
| CTC Greedy Search | 3.05 | 8.84 | 3.27 | 8.54 |
| CTC Prefix Beam Search | 3.04 | 8.83 | 3.26 | 8.54 |
| Attention Decoder | 4.58 | 9.62 | 5.07 | 9.22 |
| Attention Rescoring | 2.83 | 8.39 | 2.97 | 8.02 |

#### Full context training -> Full context inference:
| Decoding Mode | Dev Clean | Dev Other | Test Clean | Test Other |
|------------------------|-----------|-----------|------------|------------|
| CTC Greedy Search | 3.08 | 8.82 | 3.24 | 8.55 |
| CTC Prefix Beam Search | 3.06 | 8.80 | 3.23 | 8.53 |
| Attention Decoder | 2.92 | 8.28 | 3.03 | 8.05 |
| Attention Rescoring | 2.80 | 8.37 | 2.94 | 8.03 |
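
The `average_num 100` above refers to checkpoint averaging: the decoding model is the element-wise average of the parameters from 100 saved checkpoints (WeNet ships a script for this in `wenet/bin/average_model.py`). The sketch below only illustrates the idea; the paths and epoch range are hypothetical, and it assumes each checkpoint file holds a flat `state_dict`.

```python
# Illustrative checkpoint averaging (paths and epoch range are hypothetical).
import torch

ckpt_paths = [f"exp/chunkformer/epoch_{n}.pt" for n in range(100, 200)]

avg_state = None
for path in ckpt_paths:
    state = torch.load(path, map_location="cpu")
    if avg_state is None:
        avg_state = {k: v.clone().float() for k, v in state.items()}
    else:
        for k, v in state.items():
            avg_state[k] += v.float()

# Divide the accumulated parameters by the number of checkpoints.
for k in avg_state:
    avg_state[k] /= len(ckpt_paths)

torch.save(avg_state, "exp/chunkformer/avg_100.pt")
```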


## ChunkFormer U2++ Result: Joint Full and Chunk Context Training
* Model info:
    * Encoder Params: 32,356,096
    * Downsample rate: dw_striding 8x
    * encoder_dim 256, head 4, linear_units 2048
    * num_blocks 12, cnn_module_kernel 15
    * dynamic_conv: true
    * dynamic_chunk_sizes: [-1, -1, 64, 128, 256]
    * dynamic_left_context_sizes: [64, 128, 256]
    * dynamic_right_context_sizes: [64, 128, 256]
* Feature info: using fbank feature, cmvn, dither, online speed perturb
* Training info:
    * train_u2++_chunkformer_small.yaml, kernel size 15
    * dynamic batch size 120,000 frames, 2 GPUs, acc_grad 4, 200 epochs, dither 1.0
    * adamw, lr 1e-3, warmuplr, warmup_steps: 25000
    * specaug and speed perturb
* Decoding info:
    * ctc_weight 0.3, reverse weight 0.5, average_num 100, beam size 10
    * Chunk size, left context size, and right context size are denoted (c, l, r); a worked example of the audio span each setting covers follows the table below
    * Results on test-clean / test-other (* marks the best result in each row)

| Decoding Mode | (-1, -1, -1) | (64, 128, 128) | (128, 128, 128) | (128, 256, 256) | (256, 64, 64) | (256, 128, 128) |
|-----------------------|----------------|----------------|------------------|------------------|----------------|------------------|
| ctc_greedy_search | 3.19 / 8.51* | 3.22 / 8.54 | 3.20 / 8.53 | 3.20 / 8.53 | 3.18* / 8.52 | 3.18* / 8.51* |
| ctc_prefix_beam_search| 3.17 / 8.50 | 3.20 / 8.53 | 3.18 / 8.51 | 3.19 / 8.51 | 3.16* / 8.50 | 3.16* / 8.49* |
| attention | 3.24* / 8.03* | 3.38 / 8.16 | 3.24* / 8.07 | 3.28 / 8.05 | 3.29 / 8.13 | 3.26 / 8.08 |
| attention_rescoring | 2.96 / 7.88* | 2.82* / 7.89 | 2.96 / 7.90 | 2.97 / 7.90 | 2.95 / 7.89 | 2.95 / 7.88* |
| Average | 3.14* / 8.23* | 3.16 / 8.28 | 3.15 / 8.25 | 3.16 / 8.25 | 3.15 / 8.26 | 3.14* / 8.24 |
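
Since the encoder downsamples by 8x (`dw_striding 8x`) and fbank frames are computed every 10 ms, each (c, l, r) value maps to an approximate span of audio. A small sketch of that arithmetic, using a few of the settings from the table:

```python
# Rough audio span covered by a (chunk, left, right) setting, given the
# 8x dw_striding downsampling and 10 ms fbank frame shift used above.
FRAME_SHIFT_MS = 10
DOWNSAMPLE = 8

def to_seconds(encoder_frames: int) -> float:
    """Seconds of audio covered by a number of post-downsampling frames."""
    return encoder_frames * DOWNSAMPLE * FRAME_SHIFT_MS / 1000.0

for c, l, r in [(64, 128, 128), (128, 128, 128), (256, 128, 128)]:
    print(f"(c={c}, l={l}, r={r}) -> chunk {to_seconds(c):.2f}s, "
          f"left {to_seconds(l):.2f}s, right {to_seconds(r):.2f}s")
# e.g. a 128-frame chunk corresponds to roughly 10.24 s of audio
```
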
121 changes: 121 additions & 0 deletions examples/librispeech/s0/conf/train_u2++_chunkformer_small.yaml
@@ -0,0 +1,121 @@
# network architecture
# encoder related
encoder: chunkformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: dw_striding    # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: true
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'chunk_rel_pos'
    selfattention_layer_type: 'chunk_rel_seflattn'
    dynamic_conv: false
    cnn_module_norm: 'layer_norm'   # using nn.LayerNorm makes model converge faster
    # Enable the settings below for joint training on full and chunk context
    # dynamic_conv: true
    # dynamic_chunk_sizes: [-1, -1, 64, 128, 256]
    # dynamic_left_context_sizes: [64, 128, 256]
    # dynamic_right_context_sizes: [64, 128, 256]

# decoder related
decoder: bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 3
    r_num_blocks: 3
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1

tokenizer: bpe
tokenizer_conf:
    symbol_table_path: 'data/lang_char/train_960_bpe5000_units.txt'
    split_with_space: false
    bpe_path: 'data/lang_char/train_960_bpe5000.model'
    non_lang_syms_path: null
    is_multilingual: false
    num_languages: 1
    special_tokens:
        <blank>: 0
        <unk>: 1
        <sos>: 2
        <eos>: 2

ctc: ctc
ctc_conf:
    ctc_blank_id: 0

cmvn: global_cmvn
cmvn_conf:
    cmvn_file: 'data/train_960/global_cmvn'
    is_json_cmvn: true

# hybrid CTC/attention
model: asr_model
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
    reverse_weight: 0.3

# dataset related
dataset: asr
dataset_conf:
    filter_conf:
        max_length: 40960
        min_length: 0
        token_max_length: 400
        token_min_length: 1
        # min_output_input_ratio: 0.0005
        # max_output_input_ratio: 0.1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 1.0
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    spec_sub: false
    spec_sub_conf:
        num_t_sub: 3
        max_t: 30
    shuffle: true
    shuffle_conf:
        shuffle_size: 1000
    sort: false
    sort_conf:
        sort_size: 2000  # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'dynamic'   # static or dynamic
        max_frames_in_batch: 120000
        # At inference, pad_feat should be False to activate
        # masked batch and chunk context decoding
        pad_feat: True

grad_clip: 5
accum_grad: 4
max_epoch: 200
log_interval: 100

optim: adamw
optim_conf:
    lr: 0.001
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 25000
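
The commented `dynamic_*` settings above drive the joint full and chunk context training reported in the second result section. The sketch below is only an illustration of how such per-batch sampling could work; it is not the actual WeNet/ChunkFormer implementation, and it assumes `-1` means full context, with the duplicate entry raising the chance of drawing a full-context batch.

```python
import random

# Assumed semantics: -1 = full context; duplicates raise the sampling probability.
DYNAMIC_CHUNK_SIZES = [-1, -1, 64, 128, 256]
DYNAMIC_LEFT_CONTEXT_SIZES = [64, 128, 256]
DYNAMIC_RIGHT_CONTEXT_SIZES = [64, 128, 256]

def sample_context():
    """Draw a (chunk, left, right) configuration for one training batch."""
    chunk = random.choice(DYNAMIC_CHUNK_SIZES)
    if chunk == -1:
        # Full-context batch: left/right limits are irrelevant.
        return -1, -1, -1
    left = random.choice(DYNAMIC_LEFT_CONTEXT_SIZES)
    right = random.choice(DYNAMIC_RIGHT_CONTEXT_SIZES)
    return chunk, left, right

for _ in range(3):
    print(sample_context())
```
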
5 changes: 4 additions & 1 deletion wenet/bin/recognize.py
@@ -281,7 +281,10 @@ def main():
     with torch.no_grad():
         for batch_idx, batch in enumerate(test_data_loader):
             keys = batch["keys"]
-            feats = batch["feats"].to(device)
+            if isinstance(batch["feats"], torch.Tensor):
+                feats = batch["feats"].to(device)
+            else:
+                feats = batch["feats"]
             target = batch["target"].to(device)
             feats_lengths = batch["feats_lengths"].to(device)
             target_lengths = batch["target_lengths"].to(device)
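
This change pairs with the `pad_feat` switch in the config above: when `pad_feat` is `False` at inference, `batch["feats"]` is no longer a single padded tensor, so it cannot simply be moved to the device. A rough illustration of the two batch shapes; the dimensions below are made up and the actual WeNet dataset code may differ.

```python
import torch

# pad_feat: True  -> one zero-padded tensor, shape (batch, max_frames, num_mel_bins)
padded_feats = torch.zeros(3, 1200, 80)

# pad_feat: False -> a list of per-utterance tensors kept at their true lengths,
# which is what enables masked batch and chunk-context decoding downstream.
unpadded_feats = [torch.zeros(980, 80), torch.zeros(1200, 80), torch.zeros(731, 80)]

# The new branch in recognize.py only calls .to(device) when feats is a Tensor.
```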