diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md index 05dfb17dd1..630a72aea4 100644 --- a/examples/librispeech/s0/README.md +++ b/examples/librispeech/s0/README.md @@ -313,3 +313,67 @@ test other | ctc_greedy_search | 8.73 | 9.82 | 9.83 | | ctc prefix beam search | 8.70 | 9.81 | 9.79 | | attention rescoring | 8.05 | 9.08 | 9.10 | + + +## ChunkFormer U2++ Result + +* Model info: + * Encoder Params: 32,356,096 + * Downsample rate: dw_striding 8x + * encoder_dim 256, head 4, linear_units 2048 + * num_blocks 12, cnn_module_kernel 15 +* Feature info: using fbank feature, cmvn, dither, online speed perturb +* Training info: + * train_u2++_chunkformer_small.yaml, kernel size 15 + * dynamic batch size 120.000, 2 gpu, acc_grad 4, 200 epochs, dither 1.0 + * adamw, lr 1e-3, warmuplr, warmup_steps: 25000 + * specaug and speed perturb +* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 100, beam size 10 + +#### Full context training -> Chunk context inferencing: +⚠️ Attention Decoder does **not** support chunk-context inference due to cross-attention mismatch with full context training. Chunk-context training is required to resolve this mismatch. + +| Decoding Mode | Dev Clean | Dev Other | Test Clean | Test Other | +|------------------------|-----------|-----------|------------|------------| +| CTC Greedy Search | 3.05 | 8.84 | 3.27 | 8.54 | +| CTC Prefix Beam Search | 3.04 | 8.83 | 3.26 | 8.54 | +| Attention Decoder | 4.58 | 9.62 | 5.07 | 9.22 | +| Attention Rescoring | 2.83 | 8.39 | 2.97 | 8.02 | + +#### Full context training -> Full context inferencing: +| Decoding Mode | Dev Clean | Dev Other | Test Clean | Test Other | +|------------------------|-----------|-----------|------------|------------| +| CTC Greedy Search | 3.08 | 8.82 | 3.24 | 8.55 | +| CTC Prefix Beam Search | 3.06 | 8.80 | 3.23 | 8.53 | +| Attention Decoder | 2.92 | 8.28 | 3.03 | 8.05 | +| Attention Rescoring | 2.80 | 8.37 | 2.94 | 8.03 | + + +## ChunkFormer U2++ Result: Joint Full and Chunk Context Training +* Model info: + * Encoder Params: 32,356,096 + * Downsample rate: dw_striding 8x + * encoder_dim 256, head 4, linear_units 2048 + * num_blocks 12, cnn_module_kernel 15 + * dynamic_conv: true + * dynamic_chunk_sizes: [-1, -1, 64, 128, 256] + * dynamic_left_context_sizes: [64, 128, 256] + * dynamic_right_context_sizes: [64, 128, 256] +* Feature info: using fbank feature, cmvn, dither, online speed perturb +* Training info: + * train_u2++_chunkformer_small.yaml, kernel size 15 + * dynamic batch size 120.000, 2 gpu, acc_grad 4, 200 epochs, dither 1.0 + * adamw, lr 1e-3, warmuplr, warmup_steps: 25000 + * specaug and speed perturb +* Decoding info: + * ctc_weight 0.3, reverse weight 0.5, average_num 100, beam size 10 + * Chunk size, left context size, and right context size are represented as (c, l, r) + * Results on test-clean / test other + +| Decoding Mode | (-1, -1, -1) | (64, 128, 128) | (128, 128, 128) | (128, 256, 256) | (256, 64, 64) | (256, 128, 128) | +|-----------------------|----------------|----------------|------------------|------------------|----------------|------------------| +| ctc_greedy_search | 3.19 / 8.51* | 3.22 / 8.54 | 3.20 / 8.53 | 3.20 / 8.53 | 3.18* / 8.52 | 3.18* / 8.51* | +| ctc_prefix_beam_search| 3.17 / 8.50 | 3.20 / 8.53 | 3.18 / 8.51 | 3.19 / 8.51 | 3.16* / 8.50 | 3.16* / 8.49* | +| attention | 3.24* / 8.03* | 3.38 / 8.16 | 3.24* / 8.07 | 3.28 / 8.05 | 3.29 / 8.13 | 3.26 / 8.08 | +| attention_rescoring | 2.96 / 7.88* | 2.82* / 7.89 | 2.96 / 7.90 | 
2.97 / 7.90 | 2.95 / 7.89 | 2.95 / 7.88* | +| Average | 3.14* / 8.23* | 3.16 / 8.28 | 3.15 / 8.25 | 3.16 / 8.25 | 3.15 / 8.26 | 3.14* / 8.24 | \ No newline at end of file diff --git a/examples/librispeech/s0/conf/train_u2++_chunkformer_small.yaml b/examples/librispeech/s0/conf/train_u2++_chunkformer_small.yaml new file mode 100644 index 0000000000..2d121ab74b --- /dev/null +++ b/examples/librispeech/s0/conf/train_u2++_chunkformer_small.yaml @@ -0,0 +1,121 @@ +# network architecture +# encoder related +encoder: chunkformer +encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: dw_striding # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + cnn_module_kernel: 15 + use_cnn_module: True + activation_type: 'swish' + pos_enc_layer_type: 'chunk_rel_pos' + selfattention_layer_type: 'chunk_rel_seflattn' + dynamic_conv: false + cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster + # Enable the settings below for joint training on full and chunk context + # dynamic_conv: true + # dynamic_chunk_sizes: [-1, -1, 64, 128, 256] + # dynamic_left_context_sizes: [64, 128, 256] + # dynamic_right_context_sizes: [64, 128, 256] + +# decoder related +decoder: bitransformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 3 + r_num_blocks: 3 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + +tokenizer: bpe +tokenizer_conf: + symbol_table_path: 'data/lang_char/train_960_bpe5000_units.txt' + split_with_space: false + bpe_path: 'data/lang_char/train_960_bpe5000.model' + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train_960/global_cmvn' + is_json_cmvn: true + +# hybrid CTC/attention +model: asr_model +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + reverse_weight: 0.3 + +# dataset related +dataset: asr +dataset_conf: + filter_conf: + max_length: 40960 + min_length: 0 + token_max_length: 400 + token_min_length: 1 + # min_output_input_ratio: 0.0005 + # max_output_input_ratio: 0.1 + resample_conf: + resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 2 + num_f_mask: 2 + max_t: 50 + max_f: 10 + spec_sub: false + spec_sub_conf: + num_t_sub: 3 + max_t: 30 + shuffle: true + shuffle_conf: + shuffle_size: 1000 + sort: false + sort_conf: + sort_size: 2000 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'dynamic' # static or dynamic + max_frames_in_batch: 120000 + # At inference, pad_feat should be False to activate + # masked batch and chunk context decoding + pad_feat: True + +grad_clip: 5 +accum_grad: 4 +max_epoch: 200 +log_interval: 100 + +optim: adamw +optim_conf: + lr: 0.001 +scheduler: warmuplr # pytorch v1.1.0+ required +scheduler_conf: + warmup_steps: 25000 diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index e3654b0ace..f9275b6a0a 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -281,7 +281,10 @@ def main(): with torch.no_grad(): 
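# Illustrative sketch (not part of the patch): the recognize.py change below keeps
# batch["feats"] untouched when it is not a padded tensor, e.g. a list of
# per-utterance tensors produced when dataset_conf.pad_feat is false, which is
# the setting the config above recommends for masked-batch chunk-context decoding.
import torch

def feats_to_device(feats, device):
    """Move a padded (B, T, D) feature tensor to `device`; leave a list of
    per-utterance tensors as-is so the encoder can handle them chunk by chunk."""
    if isinstance(feats, torch.Tensor):
        return feats.to(device)
    return feats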
for batch_idx, batch in enumerate(test_data_loader): keys = batch["keys"] - feats = batch["feats"].to(device) + if type(batch["feats"]) is torch.Tensor: + feats = batch["feats"].to(device) + else: + feats = batch["feats"] target = batch["target"].to(device) feats_lengths = batch["feats_lengths"].to(device) target_lengths = batch["target_lengths"].to(device) diff --git a/wenet/chunkformer/attention.py b/wenet/chunkformer/attention.py new file mode 100644 index 0000000000..68122ee558 --- /dev/null +++ b/wenet/chunkformer/attention.py @@ -0,0 +1,305 @@ +"""Multi-Head Attention layer definition.""" + +import math +from typing import Tuple + +import torch +from torch import nn +from wenet.transformer.attention import MultiHeadedAttention + +class ChunkAttentionWithRelativeRightContext(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x, left_context_size: int = 0, right_context_size: int = 0): + """Compute relative positional encoding. The position should capture both + left and right context. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1+left_context_size). + time1 means the length of query vector. + left_context_size (int): Left context size for limited chunk context + right_context_size (int): Right context size for limited chunk context + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.size() + time2 = time1 + left_context_size + right_context_size + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time2), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def forward(self, query: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + chunk_size: int = 0, + left_context_size: int = 0, + right_context_size: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). 
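# Quick standalone check (illustrative, not part of the patch) of what the
# as_strided based rel_shift above computes: out[b, h, i, j] == x[b, h, i, j - i + time1 - 1],
# i.e. each query row i reads relative offsets j - i, covering left_context_size
# frames to the left and right_context_size frames to the right of the chunk.
import torch

def rel_shift(x, left_context_size=0, right_context_size=0):
    # Restatement of the method above for a self-contained check.
    batch, heads, time1, n = x.size()
    time2 = time1 + left_context_size + right_context_size
    return x.as_strided(
        (batch, heads, time1, time2),
        (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
        storage_offset=x.stride(3) * (time1 - 1),
    )

time1, left, right = 4, 2, 3
x = torch.randn(1, 1, time1, 2 * time1 - 1 + left + right)
out = rel_shift(x, left, right)
for i in range(time1):
    for j in range(time1 + left + right):
        assert out[0, 0, i, j] == x[0, 0, i, j - i + time1 - 1]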
+ cache (torch.Tensor): Cache tensor (B, 1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + chunk_size (int): Chunk size for limited chunk context + left_context_size (int): Left context size for limited chunk context + right_context_size (int): Right context size for limited chunk context + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + """ + bz = query.shape[0] + n_feat = query.shape[2] + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + limited_context_attn = (chunk_size > 0 + and left_context_size > 0 + and right_context_size > 0) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(2) > 0: + key_cache, value_cache = torch.split( + cache, cache.size(-1) // 2, dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. 
+ new_cache = torch.cat((k, v), dim=-1) + elif limited_context_attn: + # chunking query + # [B, time1, head, d_k] + q_size = q.size(1) + n_frames_pad = (chunk_size - ((q_size - chunk_size) % chunk_size)) + n_frames_pad = n_frames_pad % chunk_size + q = torch.nn.functional.pad(q, (0, 0, 0, 0, 0, n_frames_pad)) + # [B, n_chunks, head, d_k, q_size] + q = q.unfold(1, size=chunk_size, step=chunk_size) + # [B * n_chunks, head, d_k, q_size] + q = q.reshape(-1, q.size(2), q.size(3), q.size(4)) + # [B * n_chunks,q_size, head, d_k] + q = q.permute(0, 3, 1, 2) + + # Chunking key and value + # (batch, head, time1, d_k * 2) + kv = torch.cat([k, v], dim=-1) + kv = torch.nn.functional.pad( + kv, + (0, 0, left_context_size, n_frames_pad + right_context_size)) + # [B, head, n_chunks, d_k * 2, l + c + r] + kv = kv.unfold( + 2, + size=left_context_size + chunk_size + right_context_size, + step=chunk_size) + # [B, n_chunks, head, l + c + r, d_k * 2] + kv = kv.permute(0, 2, 1, 4, 3) + # [B * n_chunks, head, l + c + r, d_k * 2] + kv = kv.reshape(-1, kv.size(2), kv.size(3), kv.size(4)) + k, v = torch.split(kv, kv.size(-1) // 2, dim=-1) + + # Chunking mask for query + # [B, 1, T + n_frames_pad] + mask_q = torch.nn.functional.pad(mask, (0, n_frames_pad)) + # [B, 1, n_chunks, chunk_size] + mask_q = mask_q.unfold(-1, size=chunk_size, step=chunk_size) + # [B *n_chunks, chunk_size] + mask_q = mask_q.reshape(-1, mask_q.size(-1)) + + # Chunking mask for key and value + mask_kv = torch.nn.functional.pad( + mask, + (left_context_size, n_frames_pad + right_context_size)) + # [B, 1, n_chunks, chunk_size] + mask_kv = mask_kv.unfold( + -1, + size=left_context_size + chunk_size + right_context_size, + step=chunk_size) + # [B, * n_chunks, chunk_size] + mask_kv = mask_kv.reshape(-1, mask_kv.size(3)) + + # finalize mask + mask = mask_q.unsqueeze(-1) & mask_kv.unsqueeze(1) + + # return dummy new cache + new_cache = cache + else: + new_cache = cache + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + # Add relative shift with left and right context inclusion, it can stream + matrix_bd = self.rel_shift(matrix_bd, left_context_size, right_context_size) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + attn_output = self.forward_attention(v, scores, mask) + if limited_context_attn: + attn_output = attn_output.reshape(bz, -1, n_feat) + attn_output = attn_output[:, :q_size, :] + + return attn_output, new_cache + + def forward_parallel_chunk( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0)), + right_context_size: int = 0, + left_context_size: int = 0, + truncated_context_size: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 'Scaled Dot Product Attention' with rel. 
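# Toy illustration (not part of the patch) of the pad + unfold chunking used above
# for the keys/values: the time axis is padded with left_context_size frames on the
# left and n_frames_pad + right_context_size frames on the right, then unfolded into
# overlapping windows of length l + c + r with hop c, one window per query chunk.
import torch
import torch.nn.functional as F

chunk_size, left, right = 4, 2, 2
t = torch.arange(1, 11, dtype=torch.float)   # 10 toy "frames": 1..10
n_frames_pad = (chunk_size - ((t.numel() - chunk_size) % chunk_size)) % chunk_size
kv = F.pad(t, (left, n_frames_pad + right))  # zeros mark missing context
windows = kv.unfold(0, left + chunk_size + right, chunk_size)
print(n_frames_pad)   # 2
print(windows)
# tensor([[ 0.,  0.,  1.,  2.,  3.,  4.,  5.,  6.],
#         [ 3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
#         [ 7.,  8.,  9., 10.,  0.,  0.,  0.,  0.]])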
positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). + cache (torch.Tensor): Cache tensor (cache_t, head, d_k * 2), + where `cache_t == left_context_size` + and `head * d_k == size` + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + torch.Tensor: Cache tensor (cache_t, head, d_k * 2) + where `cache_t == left_context_size` + and `head * d_k == size` + """ + q, k, v = self.forward_qkv(query, key, value) + + q = q.transpose(1, 2) # (batch, time1, head, d_k) + cache_t = cache.size(0) + if cache_t == 0: + cache = torch.zeros( + (left_context_size, self.h, self.d_k * 2), + device=q.device, dtype=q.dtype + ) + # (B, head, time1, d_k * 2), + kv = torch.cat([k, v], dim=-1) + # [n_chunk * chunk_size, head, F] + kv = kv.transpose(1, 2).reshape(-1, self.h, self.d_k * 2) + + + # ----------Overlapping Chunk Transformation----------------------------------- + kv = torch.cat([cache, kv], dim=0) + + if cache_t > 0: + new_cache = kv[:truncated_context_size + cache.size(0)][-cache.size(0):] + else: + # Streaming long-form transcription is disabled if input cache is empty, + new_cache = torch.zeros((0, 0, 0), device=q.device, dtype=q.dtype) + kv = torch.nn.functional.pad(kv, (0, 0, 0, 0, 0, right_context_size)) + kv = kv.unfold( + 0, + left_context_size + q.shape[1] + right_context_size, + q.shape[1] + ) + # ----------------------------------------------------------------------------- + + # [n_chunk + 1, head, F, left_context_size] + kv = kv.transpose(2, 3) + k, v = torch.split( + kv, kv.size(-1) // 2, dim=-1) + + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + + # Add relative shift with left and right context inclusion, it can stream + matrix_bd = self.rel_shift(matrix_bd, left_context_size, right_context_size) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask), new_cache diff --git a/wenet/chunkformer/convolution.py b/wenet/chunkformer/convolution.py new file mode 100644 index 0000000000..9bd47891e3 --- /dev/null +++ b/wenet/chunkformer/convolution.py @@ -0,0 +1,254 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) + +"""ConvolutionModule definition.""" + +from typing import Tuple + +import torch +from torch import nn + +class ChunkConvolutionModule(nn.Module): + """ConvolutionModule in ChunkFormer model.""" + def __init__(self, + channels: int, + kernel_size: int = 15, + activation: nn.Module = nn.ReLU(), + norm: str = "batch_norm", + causal: bool = False, + bias: bool = True, + dynamic_conv: bool = False): + """Construct an ConvolutionModule object. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. + causal (int): Whether use causal convolution or not + """ + super().__init__() + self.dynamic_conv = dynamic_conv + self.channels = channels + self.kernel_size = kernel_size + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # self.lorder is used to distinguish if it's a causal convolution, + # if self.lorder > 0: it's a causal convolution, the input will be + # padded with self.lorder frames on the left in forward. + # else: it's a symmetrical convolution + if causal: + padding = 0 + self.lorder = kernel_size - 1 + elif dynamic_conv: + # kernel_size should be an odd number for none causal convolution + assert (kernel_size - 1) % 2 == 0 + padding = 0 + self.lorder = (kernel_size - 1) // 2 + else: + # kernel_size should be an odd number for none causal convolution + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + + assert norm in ['batch_norm', 'layer_norm'] + if norm == "batch_norm": + self.use_layer_norm = False + self.norm = nn.BatchNorm1d(channels) + else: + self.use_layer_norm = True + self.norm = nn.LayerNorm(channels) + + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = activation + + def forward( + self, + x: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + cache: torch.Tensor = torch.zeros((0, 0, 0)), + chunk_size: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), + (0, 0, 0) means fake mask. + cache (torch.Tensor): left context cache, it is only + used in causal convolution (#batch, channels, cache_t), + (0, 0, 0) meas fake cache. + chunk_size (int): Chunk size for dynamic chunk convolution. + Returns: + torch.Tensor: Output tensor (#batch, time, channels). 
+ """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) # (#batch, channels, time) + + if self.dynamic_conv and chunk_size <= 0: + chunk_size = x.size(2) + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad.to(torch.bool), 0.0) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + if self.lorder > 0: + if cache.size(2) == 0: # cache_t == 0 + x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) + else: + assert cache.size(0) == x.size(0) # equal batch + assert cache.size(1) == x.size(1) # equal channel + x = torch.cat((cache, x), dim=2) + assert (x.size(2) > self.lorder) + new_cache = x + else: + # It's better we just return None if no cache is required, + # However, for JIT export, here we just fake one tensor instead of + # None. + new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + + n_frames_pad = -1 + n_chunks = -1 + if self.dynamic_conv: + size = self.lorder + chunk_size + step = chunk_size + + n_frames_pad = (step - ((x.size(2) - size) % step)) % step + # (batch, 2*channel, dim + n_frames_pad) + x = torch.nn.functional.pad(x, (0, n_frames_pad)) + + n_chunks = ((x.size(2) - size) // step) + 1 + # [B, C, n_chunks, size] + x = x.unfold(-1, size=size, step=step) + # [B, n_chunks, C, size] + x = x.transpose(1, 2) + # [B * n_chunks, C, size] + x = x.reshape(-1, x.size(2), x.size(3)) + + # pad right for dynamic conv + x = nn.functional.pad(x, (0, self.lorder), 'constant', 0.0) + + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + + if self.dynamic_conv: + # [B, n_chunk, C, chunk_size] + x = x.reshape(-1, n_chunks, x.size(1), x.size(2)) + # [B, C, n_chunks, chunk_size] + x = x.transpose(1, 2) + # [B, C, n_chunks * chunk_size] + x = x.reshape(x.size(0), x.size(1), -1) + # remove padding + x = x[..., :x.size(2) - n_frames_pad] + + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.activation(self.norm(x)) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.pointwise_conv2(x) + + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad.to(torch.bool), 0.0) + return x.transpose(1, 2), new_cache + + + + def forward_parallel_chunk( + self, + x: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + cache: torch.Tensor = torch.zeros((0, 0)), + truncated_context_size: int = 0 + + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (time, channels). + mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), + (0, 0, 0) means fake mask. + cache (torch.Tensor): left context cache, it is only + used in causal convolution (channels, cache_t), + (0, 0) meas fake cache. + Returns: + torch.Tensor: Output tensor (time, channels). 
+ """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) # (#batch, channels, time) + lorder = self.kernel_size // 2 + chunk_size = x.shape[-1] + cache_t = cache.size(-1) + if cache_t == 0: + cache = torch.zeros(self.channels, lorder).to(x.device) + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # ----------Overlapping Chunk Transformation----------------------------------- + x = x.transpose(0, 1).reshape(self.channels, -1) # [C, n_chunk * T] + x = torch.cat([cache, x], dim=-1) + + # Streaming long-form transcription is disabled if input cache is empty + if cache_t > 0: + new_cache = x[:, :truncated_context_size + cache.size(-1)] + new_cache = new_cache[:, -cache.size(-1):] + else: + new_cache = torch.zeros((0, 0)) + + x = nn.functional.pad(x, (0, lorder), 'constant', 0.0) + x = x.unfold(-1, chunk_size + 2 * lorder, chunk_size).transpose(0, 1) + # [n_chunk +1, C, chunk_size + 2 * lorder] + # ----------------------------------------------------------------------------- + + if mask_pad.size(2) > 0: # time > 0 + x = torch.where(mask_pad, x, 0) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.activation(self.norm(x)) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.pointwise_conv2(x) + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad[:, :, lorder:-lorder], 0.0) + + return x.transpose(1, 2), new_cache diff --git a/wenet/chunkformer/embedding.py b/wenet/chunkformer/embedding.py new file mode 100644 index 0000000000..f747ef568e --- /dev/null +++ b/wenet/chunkformer/embedding.py @@ -0,0 +1,115 @@ +"""Positonal Encoding Module.""" + +import math +from typing import Tuple, Union + +import torch + +class RelPositionalEncodingWithRightContext(torch.nn.Module): + """Relative positional encoding module. + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncodingWithRightContext, self).__init__() + + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.xscale = math.sqrt(self.d_model) + self.max_len = max_len + self.extend_pe(max_len) + + def extend_pe(self, size: int) -> None: + """Reset the positional encodings.""" + # Suppose `i` means to the position of query vector and `j` means the + # position of key vector. 
We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i torch.Tensor: + + if isinstance(left_context_size, int): + assert left_context_size + chunk_size < self.max_len + x_size_1 = chunk_size + left_context_size + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x_size_1 + + 1 : self.pe.size(1) // 2 # noqa E203 + + chunk_size + right_context_size, + ] + else: + assert left_context_size + chunk_size < self.max_len + x_size_1 = chunk_size + left_context_size + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x_size_1 + + 1 : self.pe.size(1) // 2 # noqa E203 + + chunk_size + right_context_size, + ] + + return pos_emb + + def forward( + self, + x: torch.Tensor, + chunk_size: Union[int, torch.Tensor] = 0, + left_context_size: Union[int, torch.Tensor] = 0, + right_context_size: Union[int, torch.Tensor] = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + offset (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. + chunk_size (int): Chunk size for limited chunk context + left_context_size (int): Left context size for limited chunk context + right_context_size (int): Right context size for limited chunk context + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + x = x * self.xscale + chunk_size = x.size(1) if chunk_size <= 0 else chunk_size + pos_emb = self.position_encoding( + chunk_size=chunk_size, + left_context_size=left_context_size, + right_context_size=right_context_size, + apply_dropout=False + ).to(device=x.device, dtype=x.dtype) + return self.dropout(x), self.dropout(pos_emb) diff --git a/wenet/chunkformer/encoder.py b/wenet/chunkformer/encoder.py new file mode 100644 index 0000000000..c4223b0c0f --- /dev/null +++ b/wenet/chunkformer/encoder.py @@ -0,0 +1,499 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Modified from ESPnet(https://github.com/espnet/espnet) + +"""Encoder definition.""" +import random +from typing import Tuple + +import torch + + +from wenet.chunkformer.attention import ChunkAttentionWithRelativeRightContext +from wenet.chunkformer.convolution import ChunkConvolutionModule +from wenet.chunkformer.embedding import RelPositionalEncodingWithRightContext +from wenet.chunkformer.encoder_layer import ChunkFormerEncoderLayer +from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward +from wenet.chunkformer.subsampling import DepthwiseConvSubsampling +from wenet.utils.mask import make_pad_mask +from wenet.transformer.encoder import BaseEncoder +from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES, + WENET_NORM_CLASSES) +class ChunkFormerEncoder(BaseEncoder): + """ChunkFormer encoder module.""" + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "dw_striding", + pos_enc_layer_type: str = "chunk_rel_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + macaron_style: bool = True, + selfattention_layer_type: str = "chunk_rel_seflattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + causal: bool = False, + cnn_module_norm: str = "batch_norm", + dynamic_conv: bool = False, + layer_norm_type: str = 'layer_norm', + gradient_checkpointing: bool = False, + final_norm: bool = True, + norm_eps: float = 1e-5, + use_sdpa: bool = False, + dynamic_chunk_sizes: list = None, + dynamic_left_context_sizes: list = None, + dynamic_right_context_sizes: list = None, + ): + """Construct ChunkFormerEncoder + + Args: + input_size to use_dynamic_chunk, see in BaseEncoder + positionwise_conv_kernel_size (int): Kernel size of positionwise + conv1d layer. + macaron_style (bool): Whether to use macaron style for + positionwise layer. + selfattention_layer_type (str): Encoder attention layer type, + the parameter has no effect now, it's just for configure + compatibility. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): whether to use causal convolution or not. + dynamic_chunk_sizes (list): List of chunk sizes for dynamic chunking. + dynamic_left_context_sizes (list): List of left context sizes for + dynamic chunking. + dynamic_right_context_sizes (list): List of right context sizes for + dynamic chunking. 
+ """ + torch.nn.Module.__init__(self) + assert selfattention_layer_type == "chunk_rel_seflattn" + assert pos_enc_layer_type == "chunk_rel_pos" + assert input_layer == "dw_striding" + + self._output_size = output_size + self.global_cmvn = global_cmvn + + assert layer_norm_type in ['layer_norm', 'rms_norm'] + self.normalize_before = normalize_before + self.final_norm = final_norm + self.after_norm = WENET_NORM_CLASSES[layer_norm_type](output_size, + eps=norm_eps) + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + self.gradient_checkpointing = gradient_checkpointing + self.use_sdpa = use_sdpa + + self._output_size = output_size + self.global_cmvn = global_cmvn + # NOTE(Mddct): head_dim == output_size // attention_heads for most of + # speech tasks, but for other task (LLM), + # head_dim == hidden_size * attention_heads. refactor later + + assert layer_norm_type in ['layer_norm', 'rms_norm'] + self.normalize_before = normalize_before + self.final_norm = final_norm + self.after_norm = WENET_NORM_CLASSES[layer_norm_type](output_size, + eps=norm_eps) + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + self.gradient_checkpointing = gradient_checkpointing + + self.dynamic_chunk_sizes = dynamic_chunk_sizes + self.dynamic_left_context_sizes = dynamic_left_context_sizes + self.dynamic_right_context_sizes = dynamic_right_context_sizes + + self.cnn_module_kernel = cnn_module_kernel + activation = WENET_ACTIVATION_CLASSES[activation_type]() + self.num_blocks = num_blocks + self.dynamic_conv = dynamic_conv + self.input_size = input_size + self.attention_heads = attention_heads + + self.embed = DepthwiseConvSubsampling( + subsampling=input_layer, + subsampling_rate=8, + feat_in=input_size, + feat_out=output_size, + conv_channels=output_size, + pos_enc_class=RelPositionalEncodingWithRightContext( + output_size, positional_dropout_rate), + subsampling_conv_chunking_factor=1, + activation=torch.nn.ReLU(), + ) + + + encoder_selfattn_layer = ChunkAttentionWithRelativeRightContext + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + + # feed-forward module definition + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + # convolution module definition + convolution_layer = ChunkConvolutionModule + convolution_layer_args = (output_size, cnn_module_kernel, activation, + cnn_module_norm, causal, True, dynamic_conv) + + self.encoders = torch.nn.ModuleList([ + ChunkFormerEncoderLayer( + size=output_size, + self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args), + feed_forward=positionwise_layer(*positionwise_layer_args), + feed_forward_macaron=positionwise_layer( + *positionwise_layer_args) if macaron_style else None, + conv_module=convolution_layer( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate=dropout_rate, + normalize_before=normalize_before + ) for _ in range(num_blocks) + ]) + + def limited_context_selection(self): + full_context_training = True + if (self.dynamic_chunk_sizes is not None + and self.dynamic_left_context_sizes is not None + and self.dynamic_right_context_sizes is not None): + chunk_size = random.choice(self.dynamic_chunk_sizes) + left_context_size = random.choice(self.dynamic_left_context_sizes) + right_context_size = 
random.choice(self.dynamic_right_context_sizes) + full_context_training = not (chunk_size > 0 + and left_context_size > 0 + and right_context_size > 0) + + if full_context_training: + chunk_size, left_context_size, right_context_size = 0, 0, 0 + return chunk_size, left_context_size, right_context_size + + def forward_encoder( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + chunk_size: int = 0, + left_context_size: int = 0, + right_context_size: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed positions in tensor. + + Args: + xs: padded input tensor (B, T, D) + xs_lens: input length (B) + chunk_size (int): Chunk size for limited chunk context + left_context_size (int): Left context size for limited chunk context + right_context_size (int): Right context size for limited chunk context + Returns: + encoder output tensor xs, and subsampled masks + xs: padded output tensor (B, T' ~= T/subsample_rate, D) + masks: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + NOTE(xcsong): + We pass the `__call__` method of the modules instead of `forward` to the + checkpointing API because `__call__` attaches all the hooks of the module. + https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + """ + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + + xs, pos_emb, masks = self.embed( + xs, masks, + chunk_size=chunk_size, + left_context_size=left_context_size, + right_context_size=right_context_size + ) + mask_pad = masks # (B, 1, T/subsample_rate) + + xs = self.forward_layers( + xs, masks, pos_emb, mask_pad, + chunk_size=chunk_size, + left_context_size=left_context_size, + right_context_size=right_context_size, + ) + if self.normalize_before and self.final_norm: + xs = self.after_norm(xs) + # Here we assume the mask is not changed in encoder layers, so just + # return the masks before encoder layers, and the masks will be used + # for cross attention with decoder later + return xs, masks + + def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor, + chunk_size: int = 0, + left_context_size: int = 0, + right_context_size: int = 0) -> torch.Tensor: + for idx, layer in enumerate(self.encoders): + xs, chunk_masks, _, _ = layer( + xs, chunk_masks, pos_emb, mask_pad, + chunk_size=chunk_size, + left_context_size=left_context_size, + right_context_size=right_context_size, + ) + return xs + + def forward(self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + **kwargs): + """ + Main forward function that dispatches to either the standard + forward pass or the parallel chunk version based on the + model's training mode. + """ + # for masked batch chunk context inference + # should add a better flag to trigger + if decoding_chunk_size > 0 and num_decoding_left_chunks > 0: + # If both decoding_chunk_size and num_decoding_left_chunks + # are set, use the parallel chunk decoding. 
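# Illustrative sketch (not part of the patch) of the per-batch context sampling above.
# With the lists from train_u2++_chunkformer_small.yaml, a chunk size of -1 is drawn
# with probability 2/5, and since only a fully positive (chunk, left, right) triple
# enables limited-context attention, that batch falls back to full-context training.
import random

dynamic_chunk_sizes = [-1, -1, 64, 128, 256]
dynamic_left_context_sizes = [64, 128, 256]
dynamic_right_context_sizes = [64, 128, 256]

def sample_context():
    chunk = random.choice(dynamic_chunk_sizes)
    left = random.choice(dynamic_left_context_sizes)
    right = random.choice(dynamic_right_context_sizes)
    if not (chunk > 0 and left > 0 and right > 0):
        return 0, 0, 0           # 0 means full context
    return chunk, left, right

print([sample_context() for _ in range(3)])   # e.g. [(0, 0, 0), (128, 64, 256), (64, 64, 64)]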
+ return self.forward_parallel_chunk( + xs=xs, + xs_origin_lens=xs_lens, + chunk_size=decoding_chunk_size, + left_context_size=num_decoding_left_chunks, + # we assume left and right context are the same + right_context_size=num_decoding_left_chunks, + **kwargs + ) + else: + (chunk_size, + left_context_size, + right_context_size) = self.limited_context_selection() + return self.forward_encoder( + xs=xs, + xs_lens=xs_lens, + chunk_size=chunk_size, + left_context_size=left_context_size, + right_context_size=right_context_size, + **kwargs + ) + + def forward_parallel_chunk( + self, + xs, + xs_origin_lens, + chunk_size: int = -1, + left_context_size: int = -1, + right_context_size: int = -1, + att_cache: torch.Tensor = torch.zeros((0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0)), + truncated_context_size: int = 0, + offset: torch.Tensor = torch.zeros(0), + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed positions in tensor. + + Args: + xs: list of B input tensors (T, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + Returns: + encoder output tensor xs, and subsampled masks + xs: padded output tensor (B, T' ~= T/subsample_rate, D) + masks: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + """ + if offset.shape[0] == 0: + offset = torch.zeros( + len(xs), dtype=torch.long, + device=xs_origin_lens.device + ) + + # --------------------------Masked Batching---------------------------------- + subsampling = self.embed.subsampling_rate + context = self.embed.right_context + 1 # Add current frame + size = (chunk_size - 1) * subsampling + context + step = subsampling * chunk_size + device = xs_origin_lens.device + + conv_lorder = self.cnn_module_kernel // 2 + + upper_bounds_att = [] + lower_bounds_att = [] + upper_bounds_conv = [] + lower_bounds_conv = [] + x_pad = [] + xs_lens = [] + n_chunks = [] + for xs_origin_len, x, offs in zip(xs_origin_lens, xs, offset): + x = x.to(device) + + # padding for unfold + if x.size(0) >= size: + n_frames_pad = (step - ((x.size(0) - size) % step)) % step + else: + n_frames_pad = size - x.size(0) + x = torch.nn.functional.pad(x, (0, 0, 0, n_frames_pad)) # (T, 80) + n_chunk = ((x.size(0) - size) // step) + 1 + x = x.unfold(0, size=size, step=step) # [n_chunk, 80, size] + x = x.transpose(2, 1) + + # attention boundaries + max_len = 1 + (xs_origin_len - context) // subsampling + upper_bound_att = chunk_size + right_context_size + torch.arange( + 0, + 1 + (xs_origin_len + n_frames_pad - context) // subsampling, + 1 + (size - context) // subsampling, device=device + ) + lower_bound_att = upper_bound_att - max_len + upper_bound_att += offs + + # convolution boundaries + upper_bound_conv = chunk_size + conv_lorder + torch.arange( + 0, + 1 + (xs_origin_len + n_frames_pad - context) // subsampling, + 1 + (size - context) // subsampling, device=device + ) + lower_bound_conv = torch.maximum( + upper_bound_conv - max_len, + torch.full_like(upper_bound_conv, conv_lorder - right_context_size) + ) + upper_bound_conv += offs + + + xs_lens += [size] * (n_chunk - 1) + [size - n_frames_pad] + upper_bounds_att.append(upper_bound_att) + lower_bounds_att.append(lower_bound_att) + 
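# Worked example (illustrative numbers, not part of the patch) of the frame
# bookkeeping above for one utterance: each chunk window covers
# (chunk_size - 1) * subsampling + context raw fbank frames and hops by
# subsampling * chunk_size frames, so the utterance is padded up to a whole
# number of windows before unfolding.
subsampling, right_context = 8, 14
context = right_context + 1                        # receptive field of the dw_striding front-end
chunk_size = 64                                    # measured in subsampled frames
size = (chunk_size - 1) * subsampling + context    # raw frames per window
step = subsampling * chunk_size                    # raw frames between window starts

raw_frames = 2000                                  # hypothetical utterance length
if raw_frames >= size:
    n_frames_pad = (step - ((raw_frames - size) % step)) % step
else:
    n_frames_pad = size - raw_frames
n_chunk = ((raw_frames + n_frames_pad - size) // step) + 1
print(size, step, n_frames_pad, n_chunk)           # 519 512 55 4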
upper_bounds_conv.append(upper_bound_conv) + lower_bounds_conv.append(lower_bound_conv) + x_pad.append(x) + n_chunks.append(n_chunk) + + + xs = torch.cat(x_pad, dim=0).to(device) + xs_lens = torch.tensor(xs_lens).to(device) + masks = ~make_pad_mask(xs_lens, xs.size(1)).unsqueeze(1) # (B, 1, T) + upper_bounds_att = torch.cat(upper_bounds_att).unsqueeze(1).to(device) + lower_bounds_att = torch.cat(lower_bounds_att).unsqueeze(1).to(device) + upper_bounds_conv = torch.cat(upper_bounds_conv).unsqueeze(1).to(device) + lower_bounds_conv = torch.cat(lower_bounds_conv).unsqueeze(1).to(device) + + + # forward model + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + + + xs, pos_emb, masks = self.embed( + xs, masks, + offset=left_context_size, + right_context_size=right_context_size + ) + + # convolution mask + # [B, left_context_size + chunksize] + mask_pad = torch.arange( + 0, + conv_lorder + chunk_size + conv_lorder, + device=masks.device + ).unsqueeze(0).repeat(xs.size(0), 1) + mask_pad = (lower_bounds_conv <= mask_pad) & (mask_pad < upper_bounds_conv) + mask_pad = mask_pad.flip(-1).unsqueeze(1) + + # attention mask + # [B, left_context_size + chunksize] + att_mask = torch.arange( + 0, + left_context_size + chunk_size + right_context_size, + device=masks.device + ).unsqueeze(0).repeat(xs.size(0), 1) + att_mask = (lower_bounds_att <= att_mask) & (att_mask < upper_bounds_att) + att_mask = att_mask.flip(-1).unsqueeze(1) + + r_att_cache = [] + r_cnn_cache = [] + att_cache = att_cache.to(device) + cnn_cache = cnn_cache.to(device) + + for i, layer in enumerate(self.encoders): + xs, _, new_att_cache, new_cnn_cache = layer.forward_parallel_chunk( + xs, + att_mask, + pos_emb, + mask_pad=mask_pad, + right_context_size=right_context_size, + left_context_size=left_context_size, + att_cache=att_cache[i] if att_cache.size(0) > 0 else att_cache, + cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache, + truncated_context_size=truncated_context_size, + ) + + r_att_cache.append(new_att_cache) + r_cnn_cache.append(new_cnn_cache) + + if self.normalize_before: + xs = self.after_norm(xs) + + + # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2), + # ? may be larger than cache_t1, it depends on required_cache_size + r_att_cache = torch.stack(r_att_cache, dim=0) + # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2) + r_cnn_cache = torch.stack(r_cnn_cache, dim=0) + + # It would be no need to reconstruct (padding) in greedy search + # but for compatibility with Wenet, we reconstruct it here + xs_lens = self.embed.calc_length(xs_origin_lens) + xs, masks = self.reconstruct(xs, xs_lens, n_chunks) + offset += xs_lens + + return xs, masks + + def reconstruct( + self, + xs, + xs_lens, + n_chunks + ): + xs = xs.split(n_chunks, dim=0) + xs = [x.reshape(-1, self._output_size)[:x_len] for x, x_len in zip(xs, xs_lens)] + + xs = torch.nn.utils.rnn.pad_sequence( + xs, + batch_first=True, + padding_value=0 + ) + masks = ~make_pad_mask(xs_lens, xs.size(1)) + # (B, 1, T) + masks = masks.unsqueeze(1).to(xs.device) + return xs, masks diff --git a/wenet/chunkformer/encoder_layer.py b/wenet/chunkformer/encoder_layer.py new file mode 100644 index 0000000000..fb06d624bf --- /dev/null +++ b/wenet/chunkformer/encoder_layer.py @@ -0,0 +1,242 @@ +"""Encoder self-attention layer definition.""" + +from typing import Optional, Tuple + +import torch +from torch import nn + + +class ChunkFormerEncoderLayer(nn.Module): + """Encoder layer module. + Args: + size (int): Input dimension. 
+ self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module + instance. + `PositionwiseFeedForward` instance can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: use layer_norm after each sub-block. + """ + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: Optional[nn.Module] = None, + feed_forward_macaron: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + dropout_rate: float = 0.1, + normalize_before: bool = True, + ): + """Construct an EncoderLayer object.""" + super().__init__() + + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module + self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = nn.LayerNorm(size, + eps=1e-5) # for the CNN module + self.norm_final = nn.LayerNorm( + size, eps=1e-5) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + chunk_size: int = 0, + left_context_size: int = 0, + right_context_size: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): positional encoding, must not be None + for ChunkFormerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (batch, 1, head, cache_t1, d_k * 3), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in ChunkFormer layer + (batch, 1, size, cache_t2) + chunk_size (int): Chunk size for limited chunk context + left_context_size (int): Left context size for limited chunk context + right_context_size (int): Right context size for limited chunk context + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 3). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). 
+ """ + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + x_att, new_att_cache = self.self_attn( + x, x, x, mask, pos_emb, att_cache, + chunk_size=chunk_size, + left_context_size=left_context_size, + right_context_size=right_context_size + ) + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + # Fake new cnn cache here, and then change it in conv_module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + + x, new_cnn_cache = self.conv_module( + x, mask_pad, cnn_cache, + chunk_size=chunk_size) + x = residual + self.dropout(x) + + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + return x, mask, new_att_cache, new_cnn_cache + + def forward_parallel_chunk( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor, + att_cache: torch.Tensor = torch.zeros((0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0)), + right_context_size: int = 0, + left_context_size: int = 0, + truncated_context_size: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): positional encoding, must not be None + for ChunkFormerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (batch, 1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in ChunkFormer layer + (batch, 1, size, cache_t2) + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). 
+ """ + + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + + x_att, new_att_cache = self.self_attn.forward_parallel_chunk( + x, x, x, mask, pos_emb, att_cache, + right_context_size=right_context_size, + left_context_size=left_context_size, + truncated_context_size=truncated_context_size) + + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + # Fake new cnn cache here, and then change it in conv_module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + + x, new_cnn_cache = self.conv_module.forward_parallel_chunk( + x, mask_pad, + cnn_cache, + truncated_context_size=truncated_context_size) + + x = residual + self.dropout(x) + + if not self.normalize_before: + x = self.norm_conv(x) + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + + return x, mask, new_att_cache, new_cnn_cache diff --git a/wenet/chunkformer/subsampling.py b/wenet/chunkformer/subsampling.py new file mode 100644 index 0000000000..3557bd4e90 --- /dev/null +++ b/wenet/chunkformer/subsampling.py @@ -0,0 +1,292 @@ +"""Subsampling layer definition.""" + + +import torch +import math +from wenet.utils.mask import make_pad_mask + +class DepthwiseConvSubsampling(torch.nn.Module): + """ + Args: + subsampling (str): The subsampling technique + subsampling_rate (int): The subsampling factor which should be a power of 2 + subsampling_conv_chunking_factor (int): Input chunking factor + 1 (auto) or a power of 2. Default is 1 + feat_in (int): size of the input features + feat_out (int): size of the output features + conv_channels (int): Number of channels for the convolution layers. 
+ activation (Module): activation function, default is nn.ReLU() + """ + + def __init__( + self, + subsampling, + subsampling_rate, + feat_in, + feat_out, + conv_channels, + pos_enc_class: torch.nn.Module, + subsampling_conv_chunking_factor=1, + activation=torch.nn.ReLU(), + ): + super(DepthwiseConvSubsampling, self).__init__() + self._subsampling = subsampling + self._conv_channels = conv_channels + self._feat_in = feat_in + self._feat_out = feat_out + self.pos_enc = pos_enc_class + + if subsampling_rate % 2 != 0: + raise ValueError("Sampling factor should be a multiply of 2!") + self._sampling_num = int(math.log(subsampling_rate, 2)) + self.subsampling_rate = subsampling_rate + self.right_context = 14 + + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): + raise ValueError("""subsampling_conv_chunking_factor + "should be -1, 1, or a power of 2""") + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + + in_channels = 1 + layers = [] + + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + + self._left_padding = 0 + self._right_padding = 0 + self._max_cache_len = 0 + + # Layer 1 + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=0, + ) + ) + in_channels = conv_channels + layers.append(activation) + + for _ in range(self._sampling_num - 1): + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=0, + groups=in_channels, + ) + ) + + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + ) + ) + layers.append(activation) + in_channels = conv_channels + + in_length = torch.tensor(feat_in, dtype=torch.float) + out_length = self.calc_length( + lengths=in_length + ) + self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out) + self.conv2d_subsampling = True + + + self.conv = torch.nn.Sequential(*layers) + + def get_sampling_frames(self): + return [1, self.subsampling_rate] + + def get_streaming_cache_size(self): + return [0, self.subsampling_rate + 1] + + def forward(self, + x, + mask, + chunk_size: int = -1, + left_context_size: int = 0, + right_context_size: int = 0): + lengths = mask.sum(dim=-1).squeeze(-1) + lengths = self.calc_length( + lengths, + ) + + # Unsqueeze Channel Axis + if self.conv2d_subsampling: + x = x.unsqueeze(1) + # Transpose to Channel First mode + else: + x = x.transpose(1, 2) + + # split inputs if chunking_factor is set + if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling: + if self.subsampling_conv_chunking_factor == 1: + # if subsampling_conv_chunking_factor is 1, we split only if needed + # avoiding a bug / feature limiting indexing of tensors to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + x_ceil = 2 ** 31 / self._conv_channels * self._stride * self._stride + if torch.numel(x) > x_ceil: + need_to_split = True + else: + need_to_split = False + else: + # if subsampling_conv_chunking_factor > 1 we always split + need_to_split = True + + # need_to_split = False + if need_to_split: + x, success = self.conv_split_by_batch(x) + # success = False + if not success: # if unable to split by batch, try by channel + x = self.conv_split_by_channel(x) + else: + x = self.conv(x) + else: + x = 
+
+        # Flatten Channel and Frequency Axes
+        if self.conv2d_subsampling:
+            b, c, t, f = x.size()
+            x = self.out(x.transpose(1, 2).reshape(b, t, -1))
+        # Transpose to Channel Last mode
+        else:
+            x = x.transpose(1, 2)
+
+        x, pos_emb = self.pos_enc(
+            x,
+            chunk_size=chunk_size,
+            left_context_size=left_context_size,
+            right_context_size=right_context_size)
+        mask = ~make_pad_mask(lengths, x.size(1)).unsqueeze(1)
+        return x, pos_emb, mask
+
+    def reset_parameters(self):
+        # initialize weights
+        with torch.no_grad():
+            # init conv
+            scale = 1.0 / self._kernel_size
+            dw_max = (self._kernel_size ** 2) ** -0.5
+            pw_max = self._conv_channels ** -0.5
+
+            torch.nn.init.uniform_(self.conv[0].weight, -scale, scale)
+            torch.nn.init.uniform_(self.conv[0].bias, -scale, scale)
+
+            for idx in range(2, len(self.conv), 3):
+                torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max)
+                torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max)
+                torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max)
+                torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max)
+
+            fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5
+            torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
+            torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)
+
+    def conv_split_by_batch(self, x):
+        """ Tries to split input by batch, run conv and concat results """
+        b, _, _, _ = x.size()
+        if b == 1:  # can't split if batch size is 1
+            return x, False
+
+        if self.subsampling_conv_chunking_factor > 1:
+            cf = self.subsampling_conv_chunking_factor
+        else:
+            # avoiding a bug / feature limiting indexing of tensors to 2**31
+            # see https://github.com/pytorch/pytorch/issues/80020
+            x_ceil = 2 ** 31 / self._conv_channels * self._stride * self._stride
+            p = math.ceil(math.log(torch.numel(x) / x_ceil, 2))
+            cf = 2 ** p
+
+        new_batch_size = b // cf
+        if new_batch_size == 0:  # input is too big
+            return x, False
+
+        return torch.cat([self.conv(chunk)
+                          for chunk in torch.split(x, new_batch_size, 0)]), True
+
+    def conv_split_by_channel(self, x):
+        """ Splits the depthwise convs by channel and the pointwise convs by
+        time, runs conv and concats the results """
+        x = self.conv[0](x)  # full conv2D
+        x = self.conv[1](x)  # activation
+
+        for i in range(self._sampling_num - 1):
+            _, c, t, _ = x.size()
+
+            if self.subsampling_conv_chunking_factor > 1:
+                cf = self.subsampling_conv_chunking_factor
+            else:
+                # avoiding a bug / feature limiting indexing of tensors to 2**31
+                # see https://github.com/pytorch/pytorch/issues/80020
+                p = math.ceil(math.log(torch.numel(x) / 2 ** 31, 2))
+                cf = 2 ** p
+
+            new_c = int(c // cf)
+            if new_c == 0:
+                new_c = 1
+
+            new_t = int(t // cf)
+            if new_t == 0:
+                new_t = 1
+
+            x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, x)
+
+            # splitting pointwise convs by time
+            x = torch.cat([self.conv[i * 3 + 3](chunk)
+                           for chunk in torch.split(x, new_t, 2)], 2)
+            x = self.conv[i * 3 + 4](x)  # activation
+        return x
+
+    def channel_chunked_conv(self, conv, chunk_size, x):
+        """ Performs channel chunked convolution """
+        ind = 0
+        out_chunks = []
+        for chunk in torch.split(x, chunk_size, 1):
+            step = chunk.size()[1]
+            ch_out = torch.nn.functional.conv2d(
+                chunk,
+                conv.weight[ind:ind + step, :, :, :],
+                bias=conv.bias[ind:ind + step],
+                stride=self._stride,
+                padding=self._left_padding,
+                groups=step,
+            )
+            out_chunks.append(ch_out)
+            ind += step
+
+        return torch.cat(out_chunks, 1)
+
+    def calc_length(self, lengths):
+        """
+        Calculates the output length of a Tensor
+        passed through a convolution or max pooling layer
+        """
+        all_paddings = self._left_padding + self._right_padding
+        kernel_size = self._kernel_size
+        stride = self._stride
+        ceil_mode = self._ceil_mode
+        repeat_num = self._sampling_num
+        add_pad = all_paddings - kernel_size
+        one = 1.0
+        for i in range(repeat_num):
+            lengths = torch.div(
+                lengths.to(dtype=torch.float) + add_pad, stride) + one
+            if ceil_mode:
+                lengths = torch.ceil(lengths)
+            else:
+                lengths = torch.floor(lengths)
+        return lengths.to(dtype=torch.int)
diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py
index 95a3eafa97..d6773de5a7 100644
--- a/wenet/dataset/dataset.py
+++ b/wenet/dataset/dataset.py
@@ -132,11 +132,15 @@ def Dataset(data_type,
 
     batch_conf = conf.get('batch_conf', {})
     batch_type = batch_conf.get('batch_type', 'static')
+    pad_feat = batch_conf.get('pad_feat', True)
+
     assert batch_type in ['static', 'bucket', 'dynamic']
     if batch_type == 'static':
         assert 'batch_size' in batch_conf
         batch_size = batch_conf.get('batch_size', 16)
-        dataset = dataset.batch(batch_size, wrapper_class=processor.padding)
+        dataset = dataset.batch(
+            batch_size,
+            wrapper_class=lambda batch: processor.padding(batch, pad_feat))
     elif batch_type == 'bucket':
         assert 'bucket_boundaries' in batch_conf
         assert 'bucket_batch_sizes' in batch_conf
@@ -144,12 +148,12 @@ def Dataset(data_type,
             processor.feats_length_fn,
             batch_conf['bucket_boundaries'],
             batch_conf['bucket_batch_sizes'],
-            wrapper_class=processor.padding)
+            wrapper_class=lambda batch: processor.padding(batch, pad_feat))
     else:
         max_frames_in_batch = batch_conf.get('max_frames_in_batch', 12000)
         dataset = dataset.dynamic_batch(
             processor.DynamicBatchWindow(max_frames_in_batch),
-            wrapper_class=processor.padding,
+            wrapper_class=lambda batch: processor.padding(batch, pad_feat)
         )
 
     return dataset
diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py
index 4838aa5721..56649d1e42 100644
--- a/wenet/dataset/processor.py
+++ b/wenet/dataset/processor.py
@@ -527,7 +527,7 @@ def spec_trim(sample, max_t=20):
     return sample
 
 
-def padding(data):
+def padding(data, pad_feat=True):
     """ Padding the data into training data
 
         Args:
@@ -565,7 +565,7 @@ def padding(data):
 
     batch = {
         "keys": sorted_keys,
-        "feats": padded_feats,
+        "feats": padded_feats if pad_feat else sorted_feats,
         "target": padding_labels,
         "feats_lengths": feats_lengths,
         "target_lengths": label_lengths,
diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py
index ee05ab5d2b..8eda66af03 100644
--- a/wenet/transformer/asr_model.py
+++ b/wenet/transformer/asr_model.py
@@ -301,7 +301,7 @@ def decode(
 
         Returns: dict results of all decoding methods
         """
-        assert speech.shape[0] == speech_lengths.shape[0]
+        assert len(speech) == len(speech_lengths)
         assert decoding_chunk_size != 0
         encoder_out, encoder_mask = self._forward_encoder(
             speech, speech_lengths, decoding_chunk_size,
diff --git a/wenet/utils/class_utils.py b/wenet/utils/class_utils.py
index 59c67f816f..0f22fdf524 100644
--- a/wenet/utils/class_utils.py
+++ b/wenet/utils/class_utils.py
@@ -11,6 +11,8 @@
 from wenet.firered.subsampling import FireRedConv2dSubsampling4
 from wenet.paraformer.embedding import ParaformerPositinoalEncoding
 from wenet.squeezeformer.subsampling import DepthwiseConv2dSubsampling4
+from wenet.chunkformer.subsampling import DepthwiseConvSubsampling
+from wenet.chunkformer.embedding import RelPositionalEncodingWithRightContext
 from wenet.transformer.attention import (MultiHeadedAttention,
                                          MultiHeadedCrossAttention,
                                          RelPositionMultiHeadedAttention,
@@ -54,7 +56,8 @@
"conv2d8": Conv2dSubsampling8, 'paraformer_dummy': torch.nn.Identity, 'stack_n_frames': StackNFramesSubsampling, - 'firered_conv2d4': FireRedConv2dSubsampling4 + 'firered_conv2d4': FireRedConv2dSubsampling4, + 'dw_striding': DepthwiseConvSubsampling, } WENET_EMB_CLASSES = { @@ -66,7 +69,8 @@ "embed_learnable_pe": LearnablePositionalEncoding, "abs_pos_paraformer": ParaformerPositinoalEncoding, 'rope_pos': RopePositionalEncoding, - 'rel_pos_firered': FireRedRelPositionalEncoding + 'rel_pos_firered': FireRedRelPositionalEncoding, + 'chunk_rel_pos': RelPositionalEncodingWithRightContext } WENET_ATTENTION_CLASSES = { diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 1217c7dfb9..6d99988add 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -21,6 +21,7 @@ DualTransformerEncoder) from wenet.e_branchformer.encoder import EBranchformerEncoder from wenet.efficient_conformer.encoder import EfficientConformerEncoder +from wenet.chunkformer.encoder import ChunkFormerEncoder from wenet.finetune.lora.utils import (inject_lora_to_model, mark_only_lora_as_trainable) from wenet.firered.encoder import FireRedConformerEncoder @@ -45,6 +46,7 @@ from wenet.whisper.whisper import Whisper WENET_ENCODER_CLASSES = { + "chunkformer": ChunkFormerEncoder, "transformer": TransformerEncoder, "conformer": ConformerEncoder, "squeezeformer": SqueezeformerEncoder,