diff --git a/chunkformer/modules/attention.py b/chunkformer/modules/attention.py
index 69801ce..ce6cb10 100644
--- a/chunkformer/modules/attention.py
+++ b/chunkformer/modules/attention.py
@@ -6,6 +6,8 @@
 import torch
 from torch import nn
 
+from chunkformer.utils.common import unfold_with_loop
+
 
 class MultiHeadedAttention(nn.Module):
     """Multi-Head Attention layer.
@@ -265,6 +267,20 @@ def rel_shift(self, x, left_context_size: int = 0, right_context_size: int = 0):
             storage_offset=n_stride * (time1 - 1),
         )
 
+    def rel_shift_export(self, x, left_context_size: int = 0, right_context_size: int = 0):
+        (batch_size, num_heads, time1, n) = x.size()
+        time2 = time1 + left_context_size + right_context_size
+
+        rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+        cols = torch.arange(time2)
+        rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+        indexes = rows + cols
+
+        x = x.reshape(-1, n)
+        x = torch.gather(x, dim=1, index=indexes.to(x.device))
+        x = x.reshape(batch_size, num_heads, time1, time2)
+        return x
+
     def forward(
         self,
         query: torch.Tensor,
@@ -276,6 +292,7 @@ def forward(
         chunk_size: int = 0,
         left_context_size: int = 0,
         right_context_size: int = 0,
+        export: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
         Args:
@@ -337,7 +354,10 @@ def forward(
             n_frames_pad = n_frames_pad % chunk_size
             q = torch.nn.functional.pad(q, (0, 0, 0, 0, 0, n_frames_pad))
             # [B, n_chunks, head, d_k, q_size]
-            q = q.unfold(1, size=chunk_size, step=chunk_size)
+            if export:
+                q = unfold_with_loop(q, 1, chunk_size, chunk_size)
+            else:
+                q = q.unfold(1, size=chunk_size, step=chunk_size)
             # [B * n_chunks, head, d_k, q_size]
             q = q.reshape(-1, q.size(2), q.size(3), q.size(4))
             # [B * n_chunks,q_size, head, d_k]
@@ -350,9 +370,14 @@ def forward(
                 kv, (0, 0, left_context_size, n_frames_pad + right_context_size)
             )
             # [B, head, n_chunks, d_k * 2, l + c + r]
-            kv = kv.unfold(
-                2, size=left_context_size + chunk_size + right_context_size, step=chunk_size
-            )
+            if export:
+                kv = unfold_with_loop(
+                    kv, 2, left_context_size + chunk_size + right_context_size, chunk_size
+                )
+            else:
+                kv = kv.unfold(
+                    2, size=left_context_size + chunk_size + right_context_size, step=chunk_size
+                )
             # [B, n_chunks, head, l + c + r, d_k * 2]
             kv = kv.permute(0, 2, 1, 4, 3)
             # [B * n_chunks, head, l + c + r, d_k * 2]
@@ -363,7 +388,10 @@ def forward(
             # [B, 1, T + n_frames_pad]
             mask_q = torch.nn.functional.pad(mask, (0, n_frames_pad))
             # [B, 1, n_chunks, chunk_size]
-            mask_q = mask_q.unfold(-1, size=chunk_size, step=chunk_size)
+            if export:
+                mask_q = unfold_with_loop(mask_q, -1, chunk_size, chunk_size)
+            else:
+                mask_q = mask_q.unfold(-1, size=chunk_size, step=chunk_size)
             # [B *n_chunks, chunk_size]
             mask_q = mask_q.reshape(-1, mask_q.size(-1))
 
@@ -372,9 +400,14 @@ def forward(
                 mask, (left_context_size, n_frames_pad + right_context_size)
             )
             # [B, 1, n_chunks, chunk_size]
-            mask_kv = mask_kv.unfold(
-                -1, size=left_context_size + chunk_size + right_context_size, step=chunk_size
-            )
+            if export:
+                mask_kv = unfold_with_loop(
+                    mask_kv, -1, left_context_size + chunk_size + right_context_size, chunk_size
+                )
+            else:
+                mask_kv = mask_kv.unfold(
+                    -1, size=left_context_size + chunk_size + right_context_size, step=chunk_size
+                )
             # [B, * n_chunks, chunk_size]
             mask_kv = mask_kv.reshape(-1, mask_kv.size(3))
 
@@ -405,7 +438,10 @@ def forward(
         # (batch, head, time1, time2)
         matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
         # Add relative shift with left and right context inclusion, it can stream
-        matrix_bd = self.rel_shift(matrix_bd, left_context_size, right_context_size)
+        if export:
+            matrix_bd = self.rel_shift_export(matrix_bd, left_context_size, right_context_size)
+        else:
+            matrix_bd = self.rel_shift(matrix_bd, left_context_size, right_context_size)
 
         scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)  # (batch, head, time1, time2)
diff --git a/chunkformer/modules/convolution.py b/chunkformer/modules/convolution.py
index dd9560e..c8d97c0 100644
--- a/chunkformer/modules/convolution.py
+++ b/chunkformer/modules/convolution.py
@@ -20,6 +20,8 @@
 import torch
 from torch import nn
 
+from chunkformer.utils.common import unfold_with_loop
+
 
 class ChunkConvolutionModule(nn.Module):
     """ConvolutionModule in ChunkFormer model."""
@@ -104,6 +106,7 @@ def forward(
         mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
         cache: torch.Tensor = torch.zeros((0, 0, 0)),
         chunk_size: int = 0,
+        export: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute convolution module.
         Args:
@@ -157,7 +160,10 @@ def forward(
 
             n_chunks = ((x.size(2) - size) // step) + 1
             # [B, C, n_chunks, size]
-            x = x.unfold(-1, size=size, step=step)
+            if export:
+                x = unfold_with_loop(x, -1, size=size, step=step)
+            else:
+                x = x.unfold(-1, size=size, step=step)
             # [B, n_chunks, C, size]
             x = x.transpose(1, 2)
             # [B * n_chunks, C, size]
diff --git a/chunkformer/modules/encoder.py b/chunkformer/modules/encoder.py
index a344422..5287f03 100644
--- a/chunkformer/modules/encoder.py
+++ b/chunkformer/modules/encoder.py
@@ -217,6 +217,7 @@ def forward_encoder(
         chunk_size: int = 0,
         left_context_size: int = 0,
         right_context_size: int = 0,
+        export: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Embed positions in tensor.
 
@@ -258,6 +259,7 @@ def forward_encoder(
             chunk_size=chunk_size,
             left_context_size=left_context_size,
             right_context_size=right_context_size,
+            export=export,
         )
         if self.normalize_before and self.final_norm:
             xs = self.after_norm(xs)
@@ -275,6 +277,7 @@ def forward_layers(
         chunk_size: int = 0,
         left_context_size: int = 0,
         right_context_size: int = 0,
+        export: bool = False,
     ) -> torch.Tensor:
         for idx, layer in enumerate(self.encoders):
             xs, chunk_masks, _, _ = layer(
@@ -285,6 +288,7 @@ def forward_layers(
                 chunk_size=chunk_size,
                 left_context_size=left_context_size,
                 right_context_size=right_context_size,
+                export=export,
             )
         return xs
 
diff --git a/chunkformer/modules/encoder_layer.py b/chunkformer/modules/encoder_layer.py
index e6835ee..8ec7d66 100644
--- a/chunkformer/modules/encoder_layer.py
+++ b/chunkformer/modules/encoder_layer.py
@@ -70,6 +70,7 @@ def forward(
         chunk_size: int = 0,
         left_context_size: int = 0,
         right_context_size: int = 0,
+        export: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute encoded features.
 
@@ -118,6 +119,7 @@ def forward(
             chunk_size=chunk_size,
             left_context_size=left_context_size,
             right_context_size=right_context_size,
+            export=export,
         )
         x = residual + self.dropout(x_att)
         if not self.normalize_before:
@@ -131,7 +133,9 @@ def forward(
 
         if self.normalize_before:
             x = self.norm_conv(x)
-        x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache, chunk_size=chunk_size)
+        x, new_cnn_cache = self.conv_module(
+            x, mask_pad, cnn_cache, chunk_size=chunk_size, export=export
+        )
         x = residual + self.dropout(x)
 
         if not self.normalize_before:
diff --git a/chunkformer/utils/common.py b/chunkformer/utils/common.py
index 193a498..4d92853 100644
--- a/chunkformer/utils/common.py
+++ b/chunkformer/utils/common.py
@@ -16,7 +16,7 @@
 
 import math
 import time
-from typing import List, Tuple
+from typing import List, Tuple, Union
 
 import torch
 from torch.nn.utils.rnn import pad_sequence
@@ -24,6 +24,40 @@
 IGNORE_ID = -1
 
 
+def unfold_with_loop(x: torch.Tensor, dimension: int, size: int, step: int) -> torch.Tensor:
+    """
+    Replicates the functionality of torch.Tensor.unfold using slicing and a for-loop.
+
+    Args:
+        x: The input tensor to unfold (e.g., shape (B, H, L, D)).
+        dimension: The dimension along which to unfold (e.g., 2 for L).
+        size: The size of the sliding window (W).
+        step: The step between adjacent windows (S).
+
+    Returns:
+        The unfolded tensor. If original shape was (..., L, ...), the new shape
+        will be (..., num_windows, size, ...).
+    """
+    dimension = dimension if dimension >= 0 else dimension + x.ndim
+    permute_lst = list(range(len(x.shape))) + [dimension]
+    permute_lst.pop(dimension)
+    x = x.permute(permute_lst)
+
+    if step <= 0 or size <= 0:
+        raise ValueError("Size and step must be positive integers.")
+
+    unfolded_chunks: List[torch.Tensor] = []
+    for start_index in range(0, x.size(-1) - size + 1, step):
+        end_index = start_index + size
+        slices: List[Union[slice, int]] = [slice(None)] * x.ndim
+        slices[-1] = slice(start_index, end_index)
+        chunk = x[slices]
+        unfolded_chunks.append(chunk)
+
+    unfolded_tensor = torch.stack(unfolded_chunks, dim=dimension)
+    return unfolded_tensor
+
+
 def pad_list(xs: List[torch.Tensor], pad_value: int):
     """Perform padding for the list of tensors.
 
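The helper added above reimplements torch.Tensor.unfold with slicing and a Python loop, and the attention/convolution modules substitute it for the built-in unfold whenever export=True. A minimal standalone check (not part of the patch; it assumes the patch is applied so unfold_with_loop is importable from chunkformer.utils.common) that the two produce identical results:

# Sanity check: unfold_with_loop is expected to match torch.Tensor.unfold on this input.
import torch

from chunkformer.utils.common import unfold_with_loop

x = torch.arange(2 * 4 * 10, dtype=torch.float32).reshape(2, 4, 10)
expected = x.unfold(-1, size=3, step=2)           # shape (2, 4, 4, 3), window dim last
actual = unfold_with_loop(x, -1, size=3, step=2)  # same layout via slicing + stack
torch.testing.assert_close(actual, expected)
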
diff --git a/chunkformer/utils/mask.py b/chunkformer/utils/mask.py
index cf657b8..052632e 100644
--- a/chunkformer/utils/mask.py
+++ b/chunkformer/utils/mask.py
@@ -218,7 +218,7 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
              [0, 0, 1, 1, 1]]
     """
     batch_size = lengths.size(0)
-    max_len = max_len if max_len > 0 else lengths.max().item()
+    max_len = max_len if max_len > 0 else lengths.max()
     seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
     seq_length_expand = lengths.unsqueeze(-1)
diff --git a/examples/asr/ctc/run.sh b/examples/asr/ctc/run.sh
index d804112..5ad42dd 100644
--- a/examples/asr/ctc/run.sh
+++ b/examples/asr/ctc/run.sh
@@ -42,6 +42,10 @@ decode_checkpoint=$dir/final.pt
 # maybe you can try to adjust it if you can not get close results as README.md
 average_num=75
 decode_modes="ctc_greedy_search"
+# Specify decoding_chunk_size if it's a unified dynamic chunk trained model
+# -1 for full chunk
+decoding_chunk_size=64
+num_decoding_left_chunks=128 # left_context_size = right_context_size
 
 # bpemode (unigram or bpe)
 nbpe=1024
@@ -51,7 +55,7 @@ bpemode=bpe
 # To enable upload and set these variables:
 hf_token="hf_xxxxxxxxxxxxxxxxxxxxxxxxx" # Your Hugging Face token
 hf_repo_id="username/chunkformer-model" # Your repository ID
-
+onnx=false
 set -e
 set -u
 set -o pipefail
@@ -152,10 +156,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
             --src_path $dir \
             --num ${average_num}
     fi
-    # Specify decoding_chunk_size if it's a unified dynamic chunk trained model
-    # -1 for full chunk
-    decoding_chunk_size=64
-    num_decoding_left_chunks=128
     ctc_weight=0.3
 
     for test in $recog_set; do
@@ -266,6 +266,15 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
         echo "Warning: $vocab not found for tokenizer folder."
     fi
 
+    # Step 6: Export ONNX model if requested
+    if [ ${onnx} == true ]; then
+        echo "Exporting ONNX model using export_onnx.py..."
+        python chunkformer/tools/export_onnx.py "$inference_model_dir" \
+            --chunk_size ${decoding_chunk_size} \
+            --left_context_size ${num_decoding_left_chunks} \
+            --right_context_size ${num_decoding_left_chunks}
+    fi
+
     echo "Model setup completed. Directory structure:"
     ls -la $inference_model_dir
     echo ""
diff --git a/pyproject.toml b/pyproject.toml
index 7d85a42..5f12acb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,10 @@ dependencies = [
     "deepspeed>=0.14.0",
     "librosa",
     "langid",
+    "onnx==1.19.0",
+    "onnx-ir==0.1.9",
+    "onnxruntime==1.23.0",
+    "onnxscript==0.5.2",
 ]
 keywords = [
     "speech-recognition",
diff --git a/tools/export_onnx.py b/tools/export_onnx.py
new file mode 100644
index 0000000..1700191
--- /dev/null
+++ b/tools/export_onnx.py
@@ -0,0 +1,136 @@
+import argparse
+import os
+
+import numpy as np
+import onnxruntime as ort
+import torch
+
+import onnx
+from chunkformer.chunkformer_model import ChunkFormerModel
+
+
+class EncoderONNXWrapper(torch.nn.Module):
+    def __init__(self, encoder):
+        super().__init__()
+        self.encoder = encoder
+
+    def forward(self, xs, xs_lens, chunk_size, left_context_size, right_context_size):
+        return self.encoder.forward_encoder(
+            xs,
+            xs_lens,
+            chunk_size=chunk_size,
+            left_context_size=left_context_size,
+            right_context_size=right_context_size,
+            export=True,
+        )
+
+
+def get_encoder_from_inference_dir(inference_model_dir):
+    model = ChunkFormerModel.from_pretrained(inference_model_dir)
+    encoder = model.get_encoder()
+    encoder.eval()
+    return encoder
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Export ChunkFormerEncoder to ONNX.")
+    parser.add_argument(
+        "inference_model_dir", type=str, help="Path to the inference model directory."
+    )
+    parser.add_argument("--chunk_size", type=int, default=64, help="Chunk size for encoder export.")
+    parser.add_argument(
+        "--left_context_size", type=int, default=128, help="Left context size for encoder export."
+    )
+    parser.add_argument(
+        "--right_context_size", type=int, default=128, help="Right context size for encoder export."
+    )
+    args = parser.parse_args()
+
+    inference_model_dir = args.inference_model_dir
+    encoder = get_encoder_from_inference_dir(inference_model_dir)
+    encoder_onnx = EncoderONNXWrapper(encoder)
+    encoder_onnx.eval()
+
+    batch_size = 1
+    seq_len = 500 * 8 + 7
+    input_dim = 80
+    xs = torch.randn(batch_size, seq_len, input_dim)
+    xs_lens = torch.full((batch_size,), seq_len, dtype=torch.long)
+    chunk_size = args.chunk_size
+    left_context_size = args.left_context_size
+    right_context_size = args.right_context_size
+
+    onnx_dir = os.path.join(inference_model_dir, "onnx")
+    os.makedirs(onnx_dir, exist_ok=True)
+    onnx_path = os.path.join(onnx_dir, "chunkformer_encoder.onnx")
+
+    # call the encoder
+    with torch.no_grad():
+        encoder_out, encoder_mask = encoder_onnx(
+            xs,
+            xs_lens,
+            chunk_size=chunk_size,
+            left_context_size=left_context_size,
+            right_context_size=right_context_size,
+        )
+        print(
+            f"Encoder output shape: {encoder_out.shape}, Encoder mask shape: {encoder_mask.shape}"
+        )
+
+    onnx_model = torch.onnx.export(
+        encoder_onnx,
+        (xs, xs_lens, chunk_size, left_context_size, right_context_size),
+        onnx_path,
+        input_names=["xs", "xs_lens", "chunk_size", "left_context_size", "right_context_size"],
+        output_names=["encoder_out", "encoder_mask"],
+        opset_version=17,
+        dynamic_axes={
+            "xs": {0: "batch_size", 1: "seq_len"},
+            "xs_lens": {0: "batch_size"},
+            "encoder_out": {0: "batch_size", 1: "out_seq_len"},
+            "encoder_mask": {0: "batch_size", 2: "out_seq_len"},
+        },
+        # verbose=True,
+        dynamo=False,
+        report=False,
+        profile=False,
+        verbose=1,
+    )
+    print(f"Exported ChunkFormerEncoder to {onnx_path} using forward_encoder.")
+
+    # Verify the ONNX model
+    onnx_model = onnx.load(onnx_path)  # type: ignore[attr-defined]
+    onnx.checker.check_model(onnx_model)  # type: ignore[attr-defined]
+    print("ONNX model is valid.")
+
+    # Test inference with ONNX Runtime
+    ort_session = ort.InferenceSession(onnx_path)
+    ort_inputs = {
+        "xs": xs.numpy(),
+        "xs_lens": xs_lens.numpy(),
+        "chunk_size": np.array(chunk_size, dtype=np.int64),
+        "left_context_size": np.array(left_context_size, dtype=np.int64),
+        "right_context_size": np.array(right_context_size, dtype=np.int64),
+    }
+    ort_outs = ort_session.run(None, ort_inputs)
+    print("ONNX Runtime inference successful.")
+
+    # Compare with PyTorch output
+    with torch.no_grad():
+        torch_out, torch_mask = encoder_onnx(
+            xs,
+            xs_lens,
+            chunk_size=chunk_size,
+            left_context_size=left_context_size,
+            right_context_size=right_context_size,
+        )
+    print("PyTorch inference successful.")
+
+    # Compare the results
+    np.testing.assert_allclose(torch_out.numpy(), ort_outs[0], rtol=1e-05, atol=5e-06)
+    assert (torch_mask.numpy() == ort_outs[1]).all(), "Encoder masks do not match!"
+    print("The outputs from PyTorch and ONNX Runtime match!")
+
+    # Compute the difference
+    diff_out = np.abs(torch_out.numpy() - ort_outs[0])
+    print(f"Max difference in encoder_out: {diff_out.max()}")
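
For reference, rel_shift_export realizes the relative shift purely with torch.gather: query row i reads the window of columns [time1 - 1 - i, time1 - 1 - i + time2) from the raw position scores. A standalone sketch of that indexing pattern (not part of the patch; toy sizes chosen here), checked against plain slicing:

# Gather-based relative shift as in rel_shift_export, on a toy tensor.
import torch

batch, heads, time1 = 2, 4, 6
left_context, right_context = 3, 3
time2 = time1 + left_context + right_context
n = time1 + time2 - 1  # toy choice: minimum width so every gathered index is in range
x = torch.randn(batch, heads, time1, n)

rows = torch.arange(time1 - 1, -1, -1).repeat(batch * heads).unsqueeze(-1)
cols = torch.arange(time2)
shifted = torch.gather(x.reshape(-1, n), 1, rows + cols).reshape(batch, heads, time1, time2)

# Reference: the same windows taken with plain slicing, one query row at a time.
reference = torch.stack(
    [x[:, :, i, time1 - 1 - i : time1 - 1 - i + time2] for i in range(time1)], dim=2
)
torch.testing.assert_close(shifted, reference)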