-
Notifications
You must be signed in to change notification settings - Fork 414
Add QNN-compatible ONNX export for non-streaming zipformer transducer. #2088
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -89,6 +89,42 @@ def get_parser(): | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||||||
| ) | ||||||
|
|
||||||
| parser.add_argument( | ||||||
| "--max-len", | ||||||
| type=int, | ||||||
| help="Use for export the model for static shapes", | ||||||
| default=-1, | ||||||
| ) | ||||||
|
|
||||||
| parser.add_argument( | ||||||
| "--keep-x-lens", | ||||||
| type=int, | ||||||
| default=-1, | ||||||
| help="1 to keep the encoder input x_lens. 0 to discard it", | ||||||
| ) | ||||||
|
|
||||||
| parser.add_argument( | ||||||
| "--use-int32-inputs", | ||||||
| type=int, | ||||||
| default=0, | ||||||
| help="""1 to use int32_t as input types if applicable. 0 to use | ||||||
| int64_t otherwise.""", | ||||||
| ) | ||||||
|
|
||||||
| parser.add_argument( | ||||||
| "--dynamic-axes", | ||||||
| type=int, | ||||||
| default=1, | ||||||
| help="1 to support dynamic axes. 0 to diable dynamic axes", | ||||||
| ) | ||||||
|
|
||||||
| parser.add_argument( | ||||||
| "--enable-int8-quantization", | ||||||
| type=int, | ||||||
| default=1, | ||||||
| help="1 to also export int8 onnx models.", | ||||||
| ) | ||||||
|
|
||||||
| parser.add_argument( | ||||||
| "--epoch", | ||||||
| type=int, | ||||||
|
|
@@ -208,6 +244,33 @@ def __init__( | |||||
| self.encoder_embed = encoder_embed | ||||||
| self.encoder_proj = encoder_proj | ||||||
|
|
||||||
| def forward2( | ||||||
| self, | ||||||
| x: torch.Tensor, | ||||||
| ) -> torch.Tensor: | ||||||
| """Please see the help information of Zipformer.forward | ||||||
|
|
||||||
| Args: | ||||||
| x: | ||||||
| A 3-D tensor of shape (N, T, C) | ||||||
| x_lens: | ||||||
| A 1-D tensor of shape (N,). Its dtype is torch.int64 | ||||||
| Returns: | ||||||
| Return a tuple containing: | ||||||
| - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) | ||||||
| - encoder_out_lens, A 1-D tensor of shape (N,) | ||||||
| """ | ||||||
|
Comment on lines
+251
to
+262
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The docstring for """Please see the help information of Zipformer.forward
Args:
x:
A 3-D tensor of shape (N, T, C)
Returns:
Return encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
""" |
||||||
| assert x.shape[0] == 1, x.shape | ||||||
| x_lens = torch.tensor([x.shape[1]], dtype=torch.int32, device=x.device) | ||||||
| x, x_lens = self.encoder_embed(x, x_lens) | ||||||
| x = x.permute(1, 0, 2) | ||||||
| encoder_out, encoder_out_lens = self.encoder(x, x_lens) | ||||||
| encoder_out = encoder_out.permute(1, 0, 2) | ||||||
| encoder_out = self.encoder_proj(encoder_out) | ||||||
| # Now encoder_out is of shape (N, T, joiner_dim) | ||||||
|
|
||||||
| return encoder_out | ||||||
|
|
||||||
| def forward( | ||||||
| self, | ||||||
| x: torch.Tensor, | ||||||
|
|
@@ -289,18 +352,24 @@ def forward( | |||||
| def export_encoder_model_onnx( | ||||||
| encoder_model: OnnxEncoder, | ||||||
| encoder_filename: str, | ||||||
| max_len: int, | ||||||
| dynamic_axes: int, | ||||||
| use_int32_inputs: int, | ||||||
| keep_x_lens: int = 1, | ||||||
| opset_version: int = 11, | ||||||
| ) -> None: | ||||||
| """Export the given encoder model to ONNX format. | ||||||
| The exported model has two inputs: | ||||||
| If keep_x_lens is 1: | ||||||
| The exported model has two inputs | ||||||
|
|
||||||
| - x, a tensor of shape (N, T, C); dtype is torch.float32 | ||||||
| - x_lens, a tensor of shape (N,); dtype is torch.int64 | ||||||
| - x, a tensor of shape (N, T, C); dtype is torch.float32 | ||||||
| - x_lens, a tensor of shape (N,); dtype is torch.int64 if | ||||||
| use_int32_inputs is 0; otherwise its dtype is torch.int32 | ||||||
|
|
||||||
| and it has two outputs: | ||||||
| and it has two outputs: | ||||||
|
|
||||||
| - encoder_out, a tensor of shape (N, T', joiner_dim) | ||||||
| - encoder_out_lens, a tensor of shape (N,) | ||||||
| - encoder_out, a tensor of shape (N, T', joiner_dim) | ||||||
| - encoder_out_lens, a tensor of shape (N,) | ||||||
|
|
||||||
| Args: | ||||||
| encoder_model: | ||||||
|
|
@@ -310,25 +379,48 @@ def export_encoder_model_onnx( | |||||
| opset_version: | ||||||
| The opset version to use. | ||||||
| """ | ||||||
| x = torch.zeros(1, 100, 80, dtype=torch.float32) | ||||||
| x_lens = torch.tensor([100], dtype=torch.int64) | ||||||
| if max_len > 0: | ||||||
| x = torch.zeros(1, max_len, 80, dtype=torch.float32) | ||||||
| else: | ||||||
| x = torch.zeros(1, 300, 80, dtype=torch.float32) | ||||||
|
|
||||||
| encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) | ||||||
| if use_int32_inputs: | ||||||
| x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) | ||||||
| else: | ||||||
| x_lens = torch.tensor([x.shape[1]], dtype=torch.int64) | ||||||
|
|
||||||
| if keep_x_lens: | ||||||
| inputs = (x, x_lens) | ||||||
| input_names = ["x", "x_lens"] | ||||||
| output_names = ["encoder_out", "encoder_out_lens"] | ||||||
| dynamic_axes_dict = { | ||||||
| "x": {0: "N", 1: "T"}, | ||||||
| "x_lens": {0: "N"}, | ||||||
| "encoder_out": {0: "N", 1: "T"}, | ||||||
| "encoder_out_lens": {0: "N"}, | ||||||
| } | ||||||
| else: | ||||||
| encoder_model.__class__.forward = encoder_model.__class__.forward2 | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Modifying the class-level
Suggested change
|
||||||
|
|
||||||
| inputs = (x,) | ||||||
| input_names = ["x"] | ||||||
| output_names = ["encoder_out"] | ||||||
| dynamic_axes_dict = { | ||||||
| "x": {0: "N", 1: "T"}, | ||||||
| "encoder_out": {0: "N", 1: "T"}, | ||||||
| } | ||||||
|
|
||||||
| encoder_model = torch.jit.trace(encoder_model, inputs) | ||||||
|
|
||||||
| torch.onnx.export( | ||||||
| encoder_model, | ||||||
| (x, x_lens), | ||||||
| inputs, | ||||||
| encoder_filename, | ||||||
| verbose=False, | ||||||
| opset_version=opset_version, | ||||||
| input_names=["x", "x_lens"], | ||||||
| output_names=["encoder_out", "encoder_out_lens"], | ||||||
| dynamic_axes={ | ||||||
| "x": {0: "N", 1: "T"}, | ||||||
| "x_lens": {0: "N"}, | ||||||
| "encoder_out": {0: "N", 1: "T"}, | ||||||
| "encoder_out_lens": {0: "N"}, | ||||||
| }, | ||||||
| input_names=input_names, | ||||||
| output_names=output_names, | ||||||
| dynamic_axes=dynamic_axes_dict if dynamic_axes else {}, | ||||||
| ) | ||||||
|
|
||||||
| meta_data = { | ||||||
|
|
@@ -345,6 +437,8 @@ def export_encoder_model_onnx( | |||||
| def export_decoder_model_onnx( | ||||||
| decoder_model: OnnxDecoder, | ||||||
| decoder_filename: str, | ||||||
| use_int32_inputs, | ||||||
| dynamic_axes: int, | ||||||
|
Comment on lines
+440
to
+441
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
| opset_version: int = 11, | ||||||
| ) -> None: | ||||||
| """Export the decoder model to ONNX format. | ||||||
|
|
@@ -368,7 +462,11 @@ def export_decoder_model_onnx( | |||||
| context_size = decoder_model.decoder.context_size | ||||||
| vocab_size = decoder_model.decoder.vocab_size | ||||||
|
|
||||||
| y = torch.zeros(10, context_size, dtype=torch.int64) | ||||||
| if use_int32_inputs: | ||||||
| y = torch.zeros(1, context_size, dtype=torch.int32) | ||||||
| else: | ||||||
| y = torch.zeros(1, context_size, dtype=torch.int64) | ||||||
|
|
||||||
| decoder_model = torch.jit.script(decoder_model) | ||||||
| torch.onnx.export( | ||||||
| decoder_model, | ||||||
|
|
@@ -381,7 +479,9 @@ def export_decoder_model_onnx( | |||||
| dynamic_axes={ | ||||||
| "y": {0: "N"}, | ||||||
| "decoder_out": {0: "N"}, | ||||||
| }, | ||||||
| } | ||||||
| if dynamic_axes | ||||||
| else {}, | ||||||
| ) | ||||||
|
|
||||||
| meta_data = { | ||||||
|
|
@@ -394,6 +494,7 @@ def export_decoder_model_onnx( | |||||
| def export_joiner_model_onnx( | ||||||
| joiner_model: nn.Module, | ||||||
| joiner_filename: str, | ||||||
| dynamic_axes: int, | ||||||
| opset_version: int = 11, | ||||||
| ) -> None: | ||||||
| """Export the joiner model to ONNX format. | ||||||
|
|
@@ -409,8 +510,8 @@ def export_joiner_model_onnx( | |||||
| joiner_dim = joiner_model.output_linear.weight.shape[1] | ||||||
| logging.info(f"joiner dim: {joiner_dim}") | ||||||
|
|
||||||
| projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) | ||||||
| projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) | ||||||
| projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) | ||||||
| projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) | ||||||
|
|
||||||
| torch.onnx.export( | ||||||
| joiner_model, | ||||||
|
|
@@ -427,7 +528,9 @@ def export_joiner_model_onnx( | |||||
| "encoder_out": {0: "N"}, | ||||||
| "decoder_out": {0: "N"}, | ||||||
| "logit": {0: "N"}, | ||||||
| }, | ||||||
| } | ||||||
| if dynamic_axes | ||||||
| else {}, | ||||||
| ) | ||||||
| meta_data = { | ||||||
| "joiner_dim": str(joiner_dim), | ||||||
|
|
@@ -578,6 +681,10 @@ def main(): | |||||
| export_encoder_model_onnx( | ||||||
| encoder, | ||||||
| encoder_filename, | ||||||
| max_len=params.max_len, | ||||||
| dynamic_axes=params.dynamic_axes, | ||||||
| use_int32_inputs=params.use_int32_inputs, | ||||||
| keep_x_lens=params.keep_x_lens, | ||||||
| opset_version=opset_version, | ||||||
| ) | ||||||
| logging.info(f"Exported encoder to {encoder_filename}") | ||||||
|
|
@@ -587,6 +694,8 @@ def main(): | |||||
| export_decoder_model_onnx( | ||||||
| decoder, | ||||||
| decoder_filename, | ||||||
| dynamic_axes=params.dynamic_axes, | ||||||
| use_int32_inputs=params.use_int32_inputs, | ||||||
| opset_version=opset_version, | ||||||
| ) | ||||||
| logging.info(f"Exported decoder to {decoder_filename}") | ||||||
|
|
@@ -596,6 +705,7 @@ def main(): | |||||
| export_joiner_model_onnx( | ||||||
| joiner, | ||||||
| joiner_filename, | ||||||
| dynamic_axes=params.dynamic_axes, | ||||||
| opset_version=opset_version, | ||||||
| ) | ||||||
| logging.info(f"Exported joiner to {joiner_filename}") | ||||||
|
|
@@ -612,6 +722,9 @@ def main(): | |||||
| joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" | ||||||
| export_onnx_fp16(joiner_filename, joiner_filename_fp16) | ||||||
|
|
||||||
| if not params.enable_int8_quantization: | ||||||
| return | ||||||
|
|
||||||
| # Generate int8 quantization models | ||||||
| # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection | ||||||
|
|
||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a typo in the help message for
--dynamic-axes:diableshould bedisable.