Support specifying input types in ONNX export script. #2084
📝 Walkthrough
Added a CLI flag
Changes
ONNX int32 input dtype flag
Code Review
This pull request adds a --use-int32-inputs option to the ONNX export script to support exporting models with int32 input types instead of int64. The reviewer suggests using str2bool instead of int for the command-line argument and updating the function type hints to bool for consistency. Additionally, the reviewer points out that Decoder.forward explicitly casts inputs to int64, which may introduce unwanted Cast nodes in the exported ONNX graph and defeat the purpose of using int32 inputs.
if use_int32_inputs:
    y = torch.zeros(1, context_size, dtype=torch.int32)
While setting y to dtype=torch.int32 changes the exported ONNX model's input type to int32, Decoder.forward in decoder.py explicitly casts y to torch.int64 via y = y.to(torch.int64). This will introduce a Cast to int64 node and subsequent int64 operations in the exported ONNX graph, which might defeat the purpose of using int32 inputs for hardware accelerators that do not support int64 data types. Consider updating Decoder.forward to avoid casting to int64 if int32 is preferred/supported by the embedding layer in the target ONNX runtime.
parser.add_argument(
    "--use-int32-inputs",
    type=int,
    default=0,
    help="""1 to use int32_t as input types if applicable. 0 to use
    int64_t otherwise.""",
)
Use str2bool instead of int for --use-int32-inputs to be consistent with other boolean flags in the repository (like --use-averaged-model) and to provide a more user-friendly command-line interface.
- parser.add_argument(
-     "--use-int32-inputs",
-     type=int,
-     default=0,
-     help="""1 to use int32_t as input types if applicable. 0 to use
-     int64_t otherwise.""",
- )
+ parser.add_argument(
+     "--use-int32-inputs",
+     type=str2bool,
+     default=False,
+     help="""True to use int32_t as input types if applicable. False to use
+     int64_t otherwise.""",
+ )
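As a rough illustration of the suggestion above, the sketch below wires a boolean converter into `argparse`. The repository's own `str2bool` is not shown in this conversation, so the helper here is a hypothetical stand-in with assumed accepted spellings, and `build_parser` is an illustrative name rather than anything from the export script.

```python
import argparse


def str2bool(value: str) -> bool:
    """Hypothetical stand-in for the repository's str2bool helper.

    The accepted spellings below are an assumption for illustration.
    """
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    # Illustrative parser containing only the flag under discussion.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use-int32-inputs",
        type=str2bool,
        default=False,
        help="True to use int32_t as input types if applicable. "
        "False to use int64_t otherwise.",
    )
    return parser


args = build_parser().parse_args(["--use-int32-inputs", "true"])
print(args.use_int32_inputs)
```

With a converter like this, the parsed value is a real `bool`, which lines up with changing the `use_int32_inputs` type hints from `int` to `bool`.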
encoder_filename: str,
opset_version: int = 11,
dynamic_batch: bool = True,
use_int32_inputs: int = 0,

decoder_filename: str,
opset_version: int = 11,
dynamic_batch: bool = True,
use_int32_inputs: int = 0,
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py (1)
427-452: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Update the decoder contract docs for the int32 path.
Lines 431-432 still say `y` is always `torch.int64`, but this branch now exports `y` as `torch.int32` when the flag is enabled. That leaves the script documenting the wrong ONNX input type.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py` around lines 427 - 452, The docstring for the decoder ONNX export incorrectly states the input 'y' is torch.int64 even though the code conditionally creates 'y' as torch.int32 when use_int32_inputs is true; update the decoder contract text in the function docstring to reflect that 'y' may be either torch.int32 or torch.int64 depending on the use_int32_inputs flag (mention use_int32_inputs, variable y, and decoder_model.decoder.context_size/context_size) so the documented ONNX input type matches the actual exported type.
🧹 Nitpick comments (1)
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py (1)
668-687: ⚡ Quick win
Encode the selected input dtype in the exported artifact names.
`--use-int32-inputs` changes the public ONNX input schema, but the filenames stay identical, so an int32 export overwrites the default int64 bundle and the saved artifacts are indistinguishable without opening the model.
Suggested fix
  suffix += f"-avg-{params.avg}"
  if params.use_averaged_model:
      suffix += "-with-averaged-model"
+ suffix += "-int32-inputs" if params.use_int32_inputs else "-int64-inputs"
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py` around lines 668 - 687, The exported filenames don’t reflect the input dtype so int32 exports can overwrite int64 ones; update the logic that builds encoder_filename and decoder_filename (used when calling export_encoder_model_onnx and export_decoder_model_onnx) to incorporate params.use_int32_inputs into the suffix (for example append a "-int32" marker when params.use_int32_inputs is true) so both encoder_filename and decoder_filename uniquely encode the selected input dtype.
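The naming concern can be sketched in isolation. The example below is a minimal illustration, assuming a filename pattern of `<component>-<suffix>.onnx`; the helper names `add_dtype_marker` and `artifact_name`, the pattern, and the sample base suffix are all hypothetical and are not taken from the export script.

```python
def add_dtype_marker(suffix: str, use_int32_inputs: bool) -> str:
    """Append a marker naming the integer input dtype to a filename suffix."""
    marker = "-int32-inputs" if use_int32_inputs else "-int64-inputs"
    return suffix + marker


def artifact_name(component: str, suffix: str) -> str:
    # Assumed naming scheme for illustration: "<component>-<suffix>.onnx".
    return f"{component}-{suffix}.onnx"


base = "epoch-99-avg-1"  # made-up base suffix
int32_name = artifact_name("decoder", add_dtype_marker(base, True))
int64_name = artifact_name("decoder", add_dtype_marker(base, False))
print(int32_name)
print(int64_name)
```

Because the marker is always appended, the two variants get distinct paths and an int32 export cannot overwrite an int64 one. Appending a marker only for the int32 case, as the prompt text proposes, would keep the default filenames unchanged, which may matter for existing users.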
📒 Files selected for processing (1)
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py
parser.add_argument(
    "--use-int32-inputs",
    type=int,
    default=0,
    help="""1 to use int32_t as input types if applicable. 0 to use
    int64_t otherwise.""",
)
Restrict --use-int32-inputs to the documented values.
type=int accepts any non-zero value here, so --use-int32-inputs 2 silently exports the int32 variant even though the help text documents only 0 and 1.
Suggested fix
  parser.add_argument(
      "--use-int32-inputs",
      type=int,
      default=0,
+     choices=(0, 1),
      help="""1 to use int32_t as input types if applicable. 0 to use
      int64_t otherwise.""",
  )
📝 Committable suggestion
- parser.add_argument(
-     "--use-int32-inputs",
-     type=int,
-     default=0,
-     help="""1 to use int32_t as input types if applicable. 0 to use
-     int64_t otherwise.""",
- )
+ parser.add_argument(
+     "--use-int32-inputs",
+     type=int,
+     default=0,
+     choices=(0, 1),
+     help="""1 to use int32_t as input types if applicable. 0 to use
+     int64_t otherwise.""",
+ )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py`
around lines 111 - 117, The --use-int32-inputs argument currently uses
parser.add_argument with type=int which accepts any non-zero integer; restrict
it to only the documented values by adding choices=(0,1) (or an equivalent
validation) on the parser.add_argument call for "--use-int32-inputs" so only 0
or 1 are accepted and adjust the help text if needed to match the enforced
choices.
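To see what `choices=(0, 1)` changes in practice, the sketch below builds a parser with only this flag and checks how a few argument lists are handled. The names `build_parser` and `try_parse` are illustrative helpers, not part of the export script. On an invalid choice `argparse` prints an error and raises `SystemExit`, which the helper catches here purely for demonstration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Illustrative parser containing only the flag under discussion.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use-int32-inputs",
        type=int,
        default=0,
        choices=(0, 1),
        help="1 to use int32_t as input types if applicable. "
        "0 to use int64_t otherwise.",
    )
    return parser


def try_parse(argv):
    """Return the parsed flag value, or None if argparse rejects argv."""
    try:
        return build_parser().parse_args(argv).use_int32_inputs
    except SystemExit:
        # argparse reports the invalid choice on stderr and exits.
        return None


print(try_parse(["--use-int32-inputs", "1"]))
print(try_parse(["--use-int32-inputs", "2"]))
```

Without `choices`, the second call would return `2`, and a truthiness check such as `if use_int32_inputs:` would treat it as enabling the int32 path.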
The ONNX export script previously used `int64_t` for all integer inputs. This PR adds support for both `int32_t` and `int64_t`, allowing the integer type to be selected as needed.