From 9da5b6cddff39e76470e7585895bf7ec6df7f2df Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 2 Jun 2026 12:22:49 +0800 Subject: [PATCH 1/6] Export non-streaming zipformer to qnn --- egs/librispeech/ASR/zipformer/export-onnx.py | 40 ++++++++++++------- .../ASR/zipformer/scaling_converter.py | 15 ++++++- egs/librispeech/ASR/zipformer/zipformer.py | 31 ++++++++------ 3 files changed, 56 insertions(+), 30 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index 03c7d6f820..d4b1a95e3e 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -211,8 +211,7 @@ def __init__( def forward( self, x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: """Please see the help information of Zipformer.forward Args: @@ -225,6 +224,7 @@ def forward( - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) - encoder_out_lens, A 1-D tensor of shape (N,) """ + x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) x, x_lens = self.encoder_embed(x, x_lens) src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]) x = x.permute(1, 0, 2) @@ -233,7 +233,7 @@ def forward( encoder_out = self.encoder_proj(encoder_out) # Now encoder_out is of shape (N, T, joiner_dim) - return encoder_out, encoder_out_lens + return encoder_out class OnnxDecoder(nn.Module): @@ -310,25 +310,30 @@ def export_encoder_model_onnx( opset_version: The opset version to use. """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) + x = torch.zeros(1, 500, 80, dtype=torch.float32) + x = torch.zeros(1, 1000, 80, dtype=torch.float32) + x = torch.zeros(1, 800, 80, dtype=torch.float32) + x = torch.zeros(1, 300, 80, dtype=torch.float32) + # x_lens = torch.tensor([488], dtype=torch.int64) - encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) + encoder_model = torch.jit.trace(encoder_model, x) torch.onnx.export( encoder_model, - (x, x_lens), + x, encoder_filename, verbose=False, opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["encoder_out", "encoder_out_lens"], + input_names=["x"], + output_names=["encoder_out"], dynamic_axes={ "x": {0: "N", 1: "T"}, "x_lens": {0: "N"}, "encoder_out": {0: "N", 1: "T"}, "encoder_out_lens": {0: "N"}, - }, + } + if False + else {}, ) meta_data = { @@ -368,7 +373,7 @@ def export_decoder_model_onnx( context_size = decoder_model.decoder.context_size vocab_size = decoder_model.decoder.vocab_size - y = torch.zeros(10, context_size, dtype=torch.int64) + y = torch.zeros(1, context_size, dtype=torch.int32) decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, @@ -381,7 +386,9 @@ def export_decoder_model_onnx( dynamic_axes={ "y": {0: "N"}, "decoder_out": {0: "N"}, - }, + } + if False + else {}, ) meta_data = { @@ -409,8 +416,8 @@ def export_joiner_model_onnx( joiner_dim = joiner_model.output_linear.weight.shape[1] logging.info(f"joiner dim: {joiner_dim}") - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) torch.onnx.export( joiner_model, @@ -427,7 +434,9 @@ def export_joiner_model_onnx( "encoder_out": {0: "N"}, "decoder_out": {0: "N"}, "logit": {0: "N"}, - }, + } + if False + else {}, ) meta_data = { "joiner_dim": str(joiner_dim), @@ -616,6 +625,7 @@ def main(): # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection logging.info("Generate int8 quantization models") + return encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" quantize_dynamic( diff --git a/egs/librispeech/ASR/zipformer/scaling_converter.py b/egs/librispeech/ASR/zipformer/scaling_converter.py index 1f95648a07..485efc6b6b 100644 --- a/egs/librispeech/ASR/zipformer/scaling_converter.py +++ b/egs/librispeech/ASR/zipformer/scaling_converter.py @@ -37,7 +37,11 @@ SwooshROnnx, Whiten, ) -from zipformer import CompactRelPositionalEncoding +from zipformer import ( + CompactRelPositionalEncoding, + ChunkCausalDepthwiseConv1d, + SimpleDownsample, +) # Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa @@ -89,7 +93,14 @@ def convert_scaled_to_non_scaled( d[name] = SwooshROnnx() elif is_onnx and isinstance(m, SwooshL): d[name] = SwooshLOnnx() - elif is_onnx and isinstance(m, CompactRelPositionalEncoding): + elif is_onnx and isinstance( + m, + ( + CompactRelPositionalEncoding, + ChunkCausalDepthwiseConv1d, + SimpleDownsample, + ), + ): # We want to recreate the positional encoding vector when # the input changes, so we have to use torch.jit.script() # to replace torch.jit.trace() diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index e83a894008..d269407db3 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1348,17 +1348,19 @@ def forward(self, src: Tensor) -> Tensor: # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len - if self.causal and torch.jit.is_tracing(): - assert ( - pad == 0 - ), f"pad should be zero for exporting streaming models. Given {pad}" - # If we are exporting a streaming model, then we skip the if statement if not self.causal or not torch.jit.is_tracing(): - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - - assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds) + if pad > 0: + src_extra = src[src.shape[0] - 1 :].expand( + pad, src.shape[1], src.shape[2] + ) + src = torch.cat((src, src_extra), dim=0) + elif self.causal and torch.jit.is_scripting(): + if pad > 0: + src_extra = src[src.shape[0] - 1 :].expand( + pad, src.shape[1], src.shape[2] + ) + src = torch.cat((src, src_extra), dim=0) src = src.reshape(d_seq_len, ds, batch_size, in_channels) @@ -1498,14 +1500,17 @@ def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: Returns: positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). """ - self.extend_pe(x, left_context_len) + if not torch.jit.is_scripting(): + self.extend_pe(x, left_context_len) + assert self.pe is not None + pe = self.pe x_size_left = x.size(0) + left_context_len # length of positive side: x.size(0) + left_context_len # length of negative side: x.size(0) - pos_emb = self.pe[ - self.pe.size(0) // 2 + pos_emb = pe[ + pe.size(0) // 2 - x_size_left - + 1 : self.pe.size(0) // 2 # noqa E203 + + 1 : pe.size(0) // 2 # noqa E203 + x.size(0), :, ] From 90ed141ca26493c9d3070854508ecdee052796b4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 2 Jun 2026 18:18:20 +0800 Subject: [PATCH 2/6] First working version --- egs/librispeech/ASR/zipformer/export-onnx.py | 13 +++++++++---- egs/librispeech/ASR/zipformer/zipformer.py | 10 +++------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index d4b1a95e3e..92a8bc2eb7 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -117,6 +117,12 @@ def get_parser(): "'--epoch' and '--iter'", ) + parser.add_argument( + "--max-len", + type=int, + required=True, + ) + parser.add_argument( "--use-averaged-model", type=str2bool, @@ -289,6 +295,7 @@ def forward( def export_encoder_model_onnx( encoder_model: OnnxEncoder, encoder_filename: str, + max_len: int, opset_version: int = 11, ) -> None: """Export the given encoder model to ONNX format. @@ -310,10 +317,7 @@ def export_encoder_model_onnx( opset_version: The opset version to use. """ - x = torch.zeros(1, 500, 80, dtype=torch.float32) - x = torch.zeros(1, 1000, 80, dtype=torch.float32) - x = torch.zeros(1, 800, 80, dtype=torch.float32) - x = torch.zeros(1, 300, 80, dtype=torch.float32) + x = torch.zeros(1, max_len, 80, dtype=torch.float32) # x_lens = torch.tensor([488], dtype=torch.int64) encoder_model = torch.jit.trace(encoder_model, x) @@ -587,6 +591,7 @@ def main(): export_encoder_model_onnx( encoder, encoder_filename, + max_len=params.max_len, opset_version=opset_version, ) logging.info(f"Exported encoder to {encoder_filename}") diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d269407db3..d92836e05a 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1507,13 +1507,9 @@ def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: x_size_left = x.size(0) + left_context_len # length of positive side: x.size(0) + left_context_len # length of negative side: x.size(0) - pos_emb = pe[ - pe.size(0) // 2 - - x_size_left - + 1 : pe.size(0) // 2 # noqa E203 - + x.size(0), - :, - ] + start_pos = pe.size(0) // 2 - x_size_left + 1 + end_pos = pe.size(0) // 2 + x.size(0) + pos_emb = pe[start_pos:end_pos] pos_emb = pos_emb.unsqueeze(0) return self.dropout(pos_emb) From be3e455025e6d7250a9355faa53c218676e96fff Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 3 Jun 2026 11:08:44 +0800 Subject: [PATCH 3/6] restore export-onnx.py --- egs/librispeech/ASR/zipformer/export-onnx.py | 45 +++++++------------- 1 file changed, 15 insertions(+), 30 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index 92a8bc2eb7..03c7d6f820 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -117,12 +117,6 @@ def get_parser(): "'--epoch' and '--iter'", ) - parser.add_argument( - "--max-len", - type=int, - required=True, - ) - parser.add_argument( "--use-averaged-model", type=str2bool, @@ -217,7 +211,8 @@ def __init__( def forward( self, x: torch.Tensor, - ) -> torch.Tensor: + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: """Please see the help information of Zipformer.forward Args: @@ -230,7 +225,6 @@ def forward( - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) - encoder_out_lens, A 1-D tensor of shape (N,) """ - x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) x, x_lens = self.encoder_embed(x, x_lens) src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]) x = x.permute(1, 0, 2) @@ -239,7 +233,7 @@ def forward( encoder_out = self.encoder_proj(encoder_out) # Now encoder_out is of shape (N, T, joiner_dim) - return encoder_out + return encoder_out, encoder_out_lens class OnnxDecoder(nn.Module): @@ -295,7 +289,6 @@ def forward( def export_encoder_model_onnx( encoder_model: OnnxEncoder, encoder_filename: str, - max_len: int, opset_version: int = 11, ) -> None: """Export the given encoder model to ONNX format. @@ -317,27 +310,25 @@ def export_encoder_model_onnx( opset_version: The opset version to use. """ - x = torch.zeros(1, max_len, 80, dtype=torch.float32) - # x_lens = torch.tensor([488], dtype=torch.int64) + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) - encoder_model = torch.jit.trace(encoder_model, x) + encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) torch.onnx.export( encoder_model, - x, + (x, x_lens), encoder_filename, verbose=False, opset_version=opset_version, - input_names=["x"], - output_names=["encoder_out"], + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], dynamic_axes={ "x": {0: "N", 1: "T"}, "x_lens": {0: "N"}, "encoder_out": {0: "N", 1: "T"}, "encoder_out_lens": {0: "N"}, - } - if False - else {}, + }, ) meta_data = { @@ -377,7 +368,7 @@ def export_decoder_model_onnx( context_size = decoder_model.decoder.context_size vocab_size = decoder_model.decoder.vocab_size - y = torch.zeros(1, context_size, dtype=torch.int32) + y = torch.zeros(10, context_size, dtype=torch.int64) decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, @@ -390,9 +381,7 @@ def export_decoder_model_onnx( dynamic_axes={ "y": {0: "N"}, "decoder_out": {0: "N"}, - } - if False - else {}, + }, ) meta_data = { @@ -420,8 +409,8 @@ def export_joiner_model_onnx( joiner_dim = joiner_model.output_linear.weight.shape[1] logging.info(f"joiner dim: {joiner_dim}") - projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) torch.onnx.export( joiner_model, @@ -438,9 +427,7 @@ def export_joiner_model_onnx( "encoder_out": {0: "N"}, "decoder_out": {0: "N"}, "logit": {0: "N"}, - } - if False - else {}, + }, ) meta_data = { "joiner_dim": str(joiner_dim), @@ -591,7 +578,6 @@ def main(): export_encoder_model_onnx( encoder, encoder_filename, - max_len=params.max_len, opset_version=opset_version, ) logging.info(f"Exported encoder to {encoder_filename}") @@ -630,7 +616,6 @@ def main(): # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection logging.info("Generate int8 quantization models") - return encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" quantize_dynamic( From 514faa2817206f8de17d8fb12f5c7fccda4ef307 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 3 Jun 2026 19:09:31 +0800 Subject: [PATCH 4/6] first streaming version for qnn --- .../ASR/zipformer/export-onnx-ctc.py | 2 +- .../zipformer/export-onnx-streaming-ctc.py | 14 +- .../ASR/zipformer/export-onnx-streaming.py | 34 +- egs/librispeech/ASR/zipformer/export-onnx.py | 2 +- .../export-streaming-as-non-streaming-onnx.py | 679 ++++++++++++++++++ egs/librispeech/ASR/zipformer/export.py | 6 +- .../ASR/zipformer/scaling_converter.py | 48 +- egs/librispeech/ASR/zipformer/zipformer.py | 10 +- 8 files changed, 773 insertions(+), 22 deletions(-) create mode 100755 egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py diff --git a/egs/librispeech/ASR/zipformer/export-onnx-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py index 413b5bb1ed..7bd271c9f6 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-ctc.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py @@ -231,7 +231,7 @@ def forward( - log_probs_len, a 1-D int64 tensor of shape (N,) """ x, x_lens = self.encoder_embed(x, x_lens) - src_key_padding_mask = make_pad_mask(x_lens) + src_key_padding_mask = make_pad_mask(x_lens).to(torch.int32) x = x.permute(1, 0, 2) encoder_out, log_probs_len = self.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py index 9a715eefd5..6fc505b4d1 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py @@ -273,7 +273,7 @@ def forward( ) assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size) - src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool) + src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.int32) # processed_mask is used to mask out initial states processed_mask = torch.arange(left_context_len, device=x.device).expand( @@ -281,7 +281,7 @@ def forward( ) processed_lens = states[-1] # (batch,) # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).to(torch.int32).flip(1) # Update processed lengths new_processed_lens = processed_lens + x_lens # (batch, left_context_size + chunk_size) @@ -641,11 +641,13 @@ def main(): ) logging.info(f"Exported model to {model_filename}") - if params.enable_int8_quantization: - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + if not params.enable_int8_quantization: + return - logging.info("Generate int8 quantization models") + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") if params.use_external_data: model_filename_int8 = f"ctc-{suffix}.int8.onnx" diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index daeb86f6af..1412c6e846 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -107,6 +107,14 @@ def get_parser(): help="1 to also export int8 onnx models.", ) + parser.add_argument( + "--use-int32-inputs", + type=int, + default=0, + help="""1 to use int32_t as input types if applicable. 0 to use + int64_t otherwise.""", + ) + parser.add_argument( "--epoch", type=int, @@ -285,7 +293,7 @@ def forward( ) assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size) - src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool) + src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.int32) # processed_mask is used to mask out initial states processed_mask = torch.arange(left_context_len, device=x.device).expand( @@ -293,7 +301,9 @@ def forward( ) processed_lens = states[-1] # (batch,) # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).to( + torch.int32 + ).flip(1) # Update processed lengths new_processed_lens = processed_lens + x_lens # (batch, left_context_size + chunk_size) @@ -318,7 +328,7 @@ def forward( new_states = new_encoder_states + [ new_cached_embed_left_pad, - new_processed_lens, + new_processed_lens.to(states[-1].dtype), ] return encoder_out, new_states @@ -406,6 +416,7 @@ def export_encoder_model_onnx( dynamic_batch: bool = True, use_whisper_features: bool = False, use_external_data: bool = False, + use_int32_inputs: int = 0, ) -> None: encoder_model.encoder.__class__.forward = ( encoder_model.encoder.__class__.streaming_forward @@ -418,6 +429,12 @@ def export_encoder_model_onnx( x = torch.rand(1, T, feature_dim, dtype=torch.float32) init_state = encoder_model.get_init_states() + + if use_int32_inputs: + init_state = [ + s if s.dtype != torch.int64 else s.to(torch.int32) for s in init_state + ] + num_encoders = len(encoder_model.encoder.encoder_dim) logging.info(f"num_encoders: {num_encoders}") logging.info(f"len(init_state): {len(init_state)}") @@ -566,6 +583,7 @@ def export_decoder_model_onnx( decoder_filename: str, opset_version: int = 11, dynamic_batch: bool = True, + use_int32_inputs: int = 0, ) -> None: """Export the decoder model to ONNX format. @@ -588,7 +606,11 @@ def export_decoder_model_onnx( context_size = decoder_model.decoder.context_size vocab_size = decoder_model.decoder.vocab_size - y = torch.zeros(1, context_size, dtype=torch.int64) + if use_int32_inputs: + y = torch.zeros(1, context_size, dtype=torch.int32) + else: + y = torch.zeros(1, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, @@ -817,10 +839,12 @@ def main(): opset_version=opset_version, feature_dim=params.feature_dim, dynamic_batch=params.dynamic_batch == 1, + use_int32_inputs=params.use_int32_inputs, use_whisper_features=params.use_whisper_features, use_external_data=params.use_external_data, ) logging.info(f"Exported encoder to {encoder_filename}") + return logging.info("Exporting decoder") decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" @@ -829,8 +853,10 @@ def main(): decoder_filename, opset_version=opset_version, dynamic_batch=params.dynamic_batch == 1, + use_int32_inputs=params.use_int32_inputs, ) logging.info(f"Exported decoder to {decoder_filename}") + return logging.info("Exporting joiner") joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index 03c7d6f820..d569f7878d 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -226,7 +226,7 @@ def forward( - encoder_out_lens, A 1-D tensor of shape (N,) """ x, x_lens = self.encoder_embed(x, x_lens) - src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]) + src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]).to(torch.int32) x = x.permute(1, 0, 2) encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) diff --git a/egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py b/egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py new file mode 100755 index 0000000000..9b31652819 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py @@ -0,0 +1,679 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script exports a transducer model from PyTorch to ONNX. + +If you train a streaming model and want to export a non-streaming version, +please use this script. + +Example: + + ./zipformer/export-streaming-as-non-streaming-onnx.py \ + --max-len -1 \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir ./exp \ + --tokens ./tokens.txt \ + \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --causal 1 \ + --use-int32-inputs 1 \ + --chunk-size "-1" \ + --left-context-frames "-1" +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--max-len", + type=int, + default=-1, + ) + + parser.add_argument( + "--dynamic-axes", + type=int, + default=1, + help="1 to support dynamic axes. 0 to diable dynamic axes", + ) + + parser.add_argument( + "--use-int32-inputs", + type=int, + default=0, + help="""1 to use int32_t as input types if applicable. 0 to use + int64_t otherwise.""", + ) + + parser.add_argument( + "--enable-int8-quantization", + type=int, + default=1, + help="1 to also export int8 onnx models.", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--fp16", + type=str2bool, + default=False, + help="Whether to export models in fp16", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__( + self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]).to(torch.int32) + x = x.permute(1, 0, 2) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + max_len: int, + dynamic_axes: int, + use_int32_inputs: int, + opset_version: int = 13, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + if max_len > 0: + x = torch.zeros(1, max_len, 80, dtype=torch.float32) + else: + x = torch.zeros(1, 3000, 80, dtype=torch.float32) + + if use_int32_inputs: + x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) + else: + x_lens = torch.tensor([x.shape[1]], dtype=torch.int64) + print("x_lens", x_lens) + + encoder_model = torch.jit.trace(encoder_model, x) + print("x_lens", x_lens) + + torch.onnx.export( + encoder_model, + x, + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["encoder_out"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + } + if dynamic_axes + else {}, + ) + + meta_data = { + "model_type": "zipformer2", + "version": "1", + "model_author": "k2-fsa", + "comment": "non-streaming zipformer2", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + use_int32_inputs, + dynamic_axes: int, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + if use_int32_inputs: + y = torch.zeros(1, context_size, dtype=torch.int32) + else: + y = torch.zeros(1, context_size, dtype=torch.int64) + + decoder_model = torch.jit.script(decoder_model) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + } + if dynamic_axes + else {}, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + dynamic_axes: int, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + } + if dynamic_axes + else {}, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + max_len=params.max_len, + dynamic_axes=params.dynamic_axes, + use_int32_inputs=params.use_int32_inputs, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + dynamic_axes=params.dynamic_axes, + use_int32_inputs=params.use_int32_inputs, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + dynamic_axes=params.dynamic_axes, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + if params.fp16: + logging.info("Generate fp16 models") + + encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" + export_onnx_fp16(encoder_filename, encoder_filename_fp16) + + decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx" + export_onnx_fp16(decoder_filename, decoder_filename_fp16) + + joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" + export_onnx_fp16(joiner_filename, joiner_filename_fp16) + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + if not params.enable_int8_quantization: + return + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/export.py b/egs/librispeech/ASR/zipformer/export.py index 1f3373cd83..6e46f7548c 100755 --- a/egs/librispeech/ASR/zipformer/export.py +++ b/egs/librispeech/ASR/zipformer/export.py @@ -278,7 +278,7 @@ def forward( """ x, x_lens = self.encoder_embed(features, feature_lengths) - src_key_padding_mask = make_pad_mask(x_lens) + src_key_padding_mask = make_pad_mask(x_lens).to(torch.int32) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) @@ -327,7 +327,7 @@ def forward( ) assert x.size(1) == chunk_size, (x.size(1), chunk_size) - src_key_padding_mask = make_pad_mask(x_lens) + src_key_padding_mask = make_pad_mask(x_lens).to(torch.int32) # processed_mask is used to mask out initial states processed_mask = torch.arange(left_context_len, device=x.device).expand( @@ -335,7 +335,7 @@ def forward( ) processed_lens = states[-1] # (batch,) # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).to(torch.int32).flip(1) # Update processed lengths new_processed_lens = processed_lens + x_lens diff --git a/egs/librispeech/ASR/zipformer/scaling_converter.py b/egs/librispeech/ASR/zipformer/scaling_converter.py index 485efc6b6b..1d24a159e4 100644 --- a/egs/librispeech/ASR/zipformer/scaling_converter.py +++ b/egs/librispeech/ASR/zipformer/scaling_converter.py @@ -29,6 +29,7 @@ import torch.nn as nn from scaling import ( Balancer, + ChunkCausalDepthwiseConv1d, Dropout3, ScaleGrad, SwooshL, @@ -39,11 +40,53 @@ ) from zipformer import ( CompactRelPositionalEncoding, - ChunkCausalDepthwiseConv1d, SimpleDownsample, ) +class NonStreamingChunkCausalDepthwiseConv1d(torch.nn.Module): + """A non-streaming replacement for ChunkCausalDepthwiseConv1d that avoids + dynamic-shape torch.zeros and conditionals, making it ONNX-export friendly. + + In non-streaming mode (chunk_size=-1), the entire sequence is one chunk, + so we simplify the forward pass accordingly. + """ + + def __init__(self, original: ChunkCausalDepthwiseConv1d): + super().__init__() + self.causal_conv = original.causal_conv + self.chunkwise_conv = original.chunkwise_conv + self.chunkwise_conv_scale = original.chunkwise_conv_scale + self.kernel_size = original.kernel_size + + def forward(self, x: torch.Tensor, chunk_size: int = -1) -> torch.Tensor: + (batch_size, num_channels, seq_len) = x.shape + left_pad = self.kernel_size // 2 + + x = torch.nn.functional.pad(x, (left_pad, 0)) + + x_causal = self.causal_conv(x[..., : left_pad + seq_len]) + + x_chunk = x[..., left_pad:] + x_chunk = self.chunkwise_conv(x_chunk) + + left_edge = self.chunkwise_conv_scale[0] + right_edge = self.chunkwise_conv_scale[1] + # seq_len >= kernel_size in non-streaming mode, so we pad with zeros + t = seq_len - self.kernel_size + channels = left_edge.shape[0] + pad = torch.zeros( + channels, t, device=left_edge.device, dtype=left_edge.dtype + ) + left_edge = torch.cat((left_edge, pad), dim=-1) + right_edge = torch.cat((pad, right_edge), dim=-1) + chunk_scale = 1.0 + (left_edge + right_edge) + + x_chunk = x_chunk * chunk_scale + + return x_chunk + x_causal + + # Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa # get_submodule was added to nn.Module at v1.9.0 def get_submodule(model, target): @@ -93,11 +136,12 @@ def convert_scaled_to_non_scaled( d[name] = SwooshROnnx() elif is_onnx and isinstance(m, SwooshL): d[name] = SwooshLOnnx() + elif is_onnx and isinstance(m, ChunkCausalDepthwiseConv1d): + d[name] = torch.jit.script(NonStreamingChunkCausalDepthwiseConv1d(m)) elif is_onnx and isinstance( m, ( CompactRelPositionalEncoding, - ChunkCausalDepthwiseConv1d, SimpleDownsample, ), ): diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d92836e05a..bd953c900b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1427,7 +1427,7 @@ def __init__( self, embed_dim: int, dropout_rate: FloatLike, - max_len: int = 1000, + max_len: int = 2000, length_factor: float = 1.0, ) -> None: """Construct a CompactRelPositionalEncoding object.""" @@ -1733,7 +1733,7 @@ def forward( seq_len, ), key_padding_mask.shape attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), + key_padding_mask.to(torch.bool).unsqueeze(1), -1000, ) @@ -1863,7 +1863,7 @@ def streaming_forward( if key_padding_mask is not None: assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), + key_padding_mask.to(torch.bool).unsqueeze(1), -1000, ) @@ -2354,7 +2354,7 @@ def forward( x = x.permute(1, 2, 0) # (#batch, channels, time). if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + x = x.masked_fill(src_key_padding_mask.to(torch.bool).unsqueeze(1).expand_as(x), 0.0) if ( not torch.jit.is_scripting() @@ -2408,7 +2408,7 @@ def streaming_forward( x = x.permute(1, 2, 0) # (#batch, channels, time). if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + x = x.masked_fill(src_key_padding_mask.to(torch.bool).unsqueeze(1).expand_as(x), 0.0) x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) From e80ccb60f96092aebbb75907afe4a36a8653ef83 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 5 Jun 2026 14:38:25 +0800 Subject: [PATCH 5/6] keep x_lens if needed --- .../ASR/zipformer/export-onnx-streaming.py | 7 +- .../export-streaming-as-non-streaming-onnx.py | 90 +++++++++++++++---- 2 files changed, 75 insertions(+), 22 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 1412c6e846..693b103c2b 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -301,9 +301,9 @@ def forward( ) processed_lens = states[-1] # (batch,) # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).to( - torch.int32 - ).flip(1) + processed_mask = ( + (processed_lens.unsqueeze(1) <= processed_mask).to(torch.int32).flip(1) + ) # Update processed lengths new_processed_lens = processed_lens + x_lens # (batch, left_context_size + chunk_size) @@ -856,7 +856,6 @@ def main(): use_int32_inputs=params.use_int32_inputs, ) logging.info(f"Exported decoder to {decoder_filename}") - return logging.info("Exporting joiner") joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" diff --git a/egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py b/egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py index 9b31652819..caaef4a82f 100755 --- a/egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py @@ -9,7 +9,7 @@ If you train a streaming model and want to export a non-streaming version, please use this script. -Example: +Example 1: Export a streaming model as a non-streaming model. ./zipformer/export-streaming-as-non-streaming-onnx.py \ --max-len -1 \ @@ -27,6 +27,29 @@ --use-int32-inputs 1 \ --chunk-size "-1" \ --left-context-frames "-1" + +Example 2: Export a streaming model as a non-streaming model suitable +for NPU (e.g., Qualcomm NPU) + + ./zipformer/export-streaming-as-non-streaming-onnx.py \ + --keep-x-lens 0 \ + --max-len 1000 \ + --dynamic-axes 0 \ + --use-int32-inputs 1 \ + --enable-int8-quantization 0 \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir ./exp \ + --tokens ./tokens.txt \ + \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --causal 1 \ + --chunk-size "-1" \ + --left-context-frames "-1" """ import argparse @@ -92,6 +115,13 @@ def get_parser(): default=-1, ) + parser.add_argument( + "--keep-x-lens", + type=int, + default=-1, + help="1 to keep the encoder input x_lens. 0 to discard it", + ) + parser.add_argument( "--dynamic-axes", type=int, @@ -205,10 +235,21 @@ def __init__( self.encoder_embed = encoder_embed self.encoder_proj = encoder_proj + def forward2(self, x: torch.Tensor): + x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]).to(torch.int32) + x = x.permute(1, 0, 2) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_proj(encoder_out) + return encoder_out + def forward( self, x: torch.Tensor, - ) -> torch.Tensor: + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: """Please see the help information of Zipformer.forward Args: @@ -221,7 +262,6 @@ def forward( - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) - encoder_out_lens, A 1-D tensor of shape (N,) """ - x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) x, x_lens = self.encoder_embed(x, x_lens) src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]).to(torch.int32) x = x.permute(1, 0, 2) @@ -230,7 +270,7 @@ def forward( encoder_out = self.encoder_proj(encoder_out) # Now encoder_out is of shape (N, T, joiner_dim) - return encoder_out + return encoder_out, encoder_out_lens class OnnxDecoder(nn.Module): @@ -289,6 +329,7 @@ def export_encoder_model_onnx( max_len: int, dynamic_axes: int, use_int32_inputs: int, + keep_x_lens: int = 1, opset_version: int = 13, ) -> None: """Export the given encoder model to ONNX format. @@ -319,27 +360,39 @@ def export_encoder_model_onnx( x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) else: x_lens = torch.tensor([x.shape[1]], dtype=torch.int64) - print("x_lens", x_lens) - encoder_model = torch.jit.trace(encoder_model, x) - print("x_lens", x_lens) + if keep_x_lens: + inputs = (x, x_lens) + input_names = ["x", "x_lens"] + output_names = ["encoder_out", "encoder_out_lens"] + dynamic_axes_dict = { + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + } + else: + encoder_model.__class__.forward = encoder_model.__class__.forward2 + + inputs = (x,) + input_names = ["x"] + output_names = ["encoder_out"] + dynamic_axes_dict = { + "x": {0: "N", 1: "T"}, + "encoder_out": {0: "N", 1: "T"}, + } + + encoder_model = torch.jit.trace(encoder_model, inputs) torch.onnx.export( encoder_model, - x, + inputs, encoder_filename, verbose=False, opset_version=opset_version, - input_names=["x"], - output_names=["encoder_out"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - } - if dynamic_axes - else {}, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes_dict if dynamic_axes else {}, ) meta_data = { @@ -604,6 +657,7 @@ def main(): dynamic_axes=params.dynamic_axes, use_int32_inputs=params.use_int32_inputs, opset_version=opset_version, + keep_x_lens=params.keep_x_lens, ) logging.info(f"Exported encoder to {encoder_filename}") From 99086d2ce48701c5b6ebef341d4b20f3b42983e5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 5 Jun 2026 15:18:39 +0800 Subject: [PATCH 6/6] Fixes after review --- egs/librispeech/ASR/zipformer/export-onnx-streaming.py | 1 - .../ASR/zipformer/export-streaming-as-non-streaming-onnx.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 693b103c2b..c2d20260e4 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -844,7 +844,6 @@ def main(): use_external_data=params.use_external_data, ) logging.info(f"Exported encoder to {encoder_filename}") - return logging.info("Exporting decoder") decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" diff --git a/egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py b/egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py index caaef4a82f..78804710f5 100755 --- a/egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py @@ -236,7 +236,7 @@ def __init__( self.encoder_proj = encoder_proj def forward2(self, x: torch.Tensor): - x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) + x_lens = torch.tensor([x.shape[1]], dtype=torch.int32, device=x.device) x, x_lens = self.encoder_embed(x, x_lens) src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]).to(torch.int32) x = x.permute(1, 0, 2)