diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index d569f7878d..a3970d2485 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -89,6 +89,42 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--max-len", + type=int, + help="Use for export the model for static shapes", + default=-1, + ) + + parser.add_argument( + "--keep-x-lens", + type=int, + default=-1, + help="1 to keep the encoder input x_lens. 0 to discard it", + ) + + parser.add_argument( + "--use-int32-inputs", + type=int, + default=0, + help="""1 to use int32_t as input types if applicable. 0 to use + int64_t otherwise.""", + ) + + parser.add_argument( + "--dynamic-axes", + type=int, + default=1, + help="1 to support dynamic axes. 0 to diable dynamic axes", + ) + + parser.add_argument( + "--enable-int8-quantization", + type=int, + default=1, + help="1 to also export int8 onnx models.", + ) + parser.add_argument( "--epoch", type=int, @@ -208,6 +244,33 @@ def __init__( self.encoder_embed = encoder_embed self.encoder_proj = encoder_proj + def forward2( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + assert x.shape[0] == 1, x.shape + x_lens = torch.tensor([x.shape[1]], dtype=torch.int32, device=x.device) + x, x_lens = self.encoder_embed(x, x_lens) + x = x.permute(1, 0, 2) + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out + def forward( self, x: torch.Tensor, @@ -289,18 +352,24 @@ def forward( def export_encoder_model_onnx( encoder_model: OnnxEncoder, encoder_filename: str, + max_len: int, + dynamic_axes: int, + use_int32_inputs: int, + keep_x_lens: int = 1, opset_version: int = 11, ) -> None: """Export the given encoder model to ONNX format. - The exported model has two inputs: + If keep_x_lens is 1: + The exported model has two inputs - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 if + use_int32_inputs is 0; otherwise its dtype is torch.int32 - and it has two outputs: + and it has two outputs: - - encoder_out, a tensor of shape (N, T', joiner_dim) - - encoder_out_lens, a tensor of shape (N,) + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) Args: encoder_model: @@ -310,25 +379,48 @@ def export_encoder_model_onnx( opset_version: The opset version to use. """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) + if max_len > 0: + x = torch.zeros(1, max_len, 80, dtype=torch.float32) + else: + x = torch.zeros(1, 300, 80, dtype=torch.float32) - encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) + if use_int32_inputs: + x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) + else: + x_lens = torch.tensor([x.shape[1]], dtype=torch.int64) + + if keep_x_lens: + inputs = (x, x_lens) + input_names = ["x", "x_lens"] + output_names = ["encoder_out", "encoder_out_lens"] + dynamic_axes_dict = { + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + } + else: + encoder_model.__class__.forward = encoder_model.__class__.forward2 + + inputs = (x,) + input_names = ["x"] + output_names = ["encoder_out"] + dynamic_axes_dict = { + "x": {0: "N", 1: "T"}, + "encoder_out": {0: "N", 1: "T"}, + } + + encoder_model = torch.jit.trace(encoder_model, inputs) torch.onnx.export( encoder_model, - (x, x_lens), + inputs, encoder_filename, verbose=False, opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes_dict if dynamic_axes else {}, ) meta_data = { @@ -345,6 +437,8 @@ def export_encoder_model_onnx( def export_decoder_model_onnx( decoder_model: OnnxDecoder, decoder_filename: str, + use_int32_inputs, + dynamic_axes: int, opset_version: int = 11, ) -> None: """Export the decoder model to ONNX format. @@ -368,7 +462,11 @@ def export_decoder_model_onnx( context_size = decoder_model.decoder.context_size vocab_size = decoder_model.decoder.vocab_size - y = torch.zeros(10, context_size, dtype=torch.int64) + if use_int32_inputs: + y = torch.zeros(1, context_size, dtype=torch.int32) + else: + y = torch.zeros(1, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, @@ -381,7 +479,9 @@ def export_decoder_model_onnx( dynamic_axes={ "y": {0: "N"}, "decoder_out": {0: "N"}, - }, + } + if dynamic_axes + else {}, ) meta_data = { @@ -394,6 +494,7 @@ def export_decoder_model_onnx( def export_joiner_model_onnx( joiner_model: nn.Module, joiner_filename: str, + dynamic_axes: int, opset_version: int = 11, ) -> None: """Export the joiner model to ONNX format. @@ -409,8 +510,8 @@ def export_joiner_model_onnx( joiner_dim = joiner_model.output_linear.weight.shape[1] logging.info(f"joiner dim: {joiner_dim}") - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) torch.onnx.export( joiner_model, @@ -427,7 +528,9 @@ def export_joiner_model_onnx( "encoder_out": {0: "N"}, "decoder_out": {0: "N"}, "logit": {0: "N"}, - }, + } + if dynamic_axes + else {}, ) meta_data = { "joiner_dim": str(joiner_dim), @@ -578,6 +681,10 @@ def main(): export_encoder_model_onnx( encoder, encoder_filename, + max_len=params.max_len, + dynamic_axes=params.dynamic_axes, + use_int32_inputs=params.use_int32_inputs, + keep_x_lens=params.keep_x_lens, opset_version=opset_version, ) logging.info(f"Exported encoder to {encoder_filename}") @@ -587,6 +694,8 @@ def main(): export_decoder_model_onnx( decoder, decoder_filename, + dynamic_axes=params.dynamic_axes, + use_int32_inputs=params.use_int32_inputs, opset_version=opset_version, ) logging.info(f"Exported decoder to {decoder_filename}") @@ -596,6 +705,7 @@ def main(): export_joiner_model_onnx( joiner, joiner_filename, + dynamic_axes=params.dynamic_axes, opset_version=opset_version, ) logging.info(f"Exported joiner to {joiner_filename}") @@ -612,6 +722,9 @@ def main(): joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" export_onnx_fp16(joiner_filename, joiner_filename_fp16) + if not params.enable_int8_quantization: + return + # Generate int8 quantization models # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained.py b/egs/librispeech/ASR/zipformer/onnx_pretrained.py index 662392b5fe..083c6c5dd1 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained.py @@ -140,6 +140,7 @@ def __init__( session_opts = ort.SessionOptions() session_opts.inter_op_num_threads = 1 session_opts.intra_op_num_threads = 4 + session_opts.log_severity_level = 3 self.session_opts = session_opts