Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ def get_parser():
You can specify --avg to use more checkpoints for model averaging.""",
)

parser.add_argument(
"--use-int32-inputs",
type=int,
default=0,
help="""1 to use int32_t as input types if applicable. 0 to use
int64_t otherwise.""",
)
Comment on lines +111 to +117

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Use str2bool instead of int for --use-int32-inputs to be consistent with other boolean flags in the repository (like --use-averaged-model) and to provide a more user-friendly command-line interface.

Suggested change
parser.add_argument(
"--use-int32-inputs",
type=int,
default=0,
help="""1 to use int32_t as input types if applicable. 0 to use
int64_t otherwise.""",
)
parser.add_argument(
"--use-int32-inputs",
type=str2bool,
default=False,
help="""True to use int32_t as input types if applicable. False to use
int64_t otherwise.""",
)

Comment on lines +111 to +117

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Restrict --use-int32-inputs to the documented values.

type=int accepts any integer here, and the flag is only checked for truthiness, so --use-int32-inputs 2 (or any other non-zero value) silently exports the int32 variant even though the help text documents only 0 and 1.

Suggested fix
     parser.add_argument(
         "--use-int32-inputs",
         type=int,
         default=0,
+        choices=(0, 1),
         help="""1 to use int32_t as input types if applicable. 0 to use
         int64_t otherwise.""",
     )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
parser.add_argument(
"--use-int32-inputs",
type=int,
default=0,
help="""1 to use int32_t as input types if applicable. 0 to use
int64_t otherwise.""",
)
parser.add_argument(
"--use-int32-inputs",
type=int,
default=0,
choices=(0, 1),
help="""1 to use int32_t as input types if applicable. 0 to use
int64_t otherwise.""",
)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py`
around lines 111 - 117, The --use-int32-inputs argument is declared via
parser.add_argument with type=int, which accepts any integer, and every non-zero
value enables the int32 export; restrict it to only the documented values by
adding choices=(0, 1) (or an equivalent validation) to the parser.add_argument
call for "--use-int32-inputs" so that only 0 or 1 is accepted, and adjust the
help text if needed to match the enforced choices.


parser.add_argument(
"--iter",
type=int,
Expand Down Expand Up @@ -272,6 +280,7 @@ def export_encoder_model_onnx(
encoder_filename: str,
opset_version: int = 11,
dynamic_batch: bool = True,
use_int32_inputs: int = 0,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Change the type hint of use_int32_inputs to bool to match the recommended str2bool argument type.

Suggested change
use_int32_inputs: int = 0,
use_int32_inputs: bool = False,

) -> None:
"""
Onnx model inputs:
Expand Down Expand Up @@ -307,6 +316,10 @@ def export_encoder_model_onnx(
x = torch.rand(1, T, 80, dtype=torch.float32)

init_state = encoder_model.encoder.get_init_state()
if use_int32_inputs:
init_state = [
s if s.dtype != torch.int64 else s.to(torch.int32) for s in init_state
]

num_encoders = encoder_model.encoder.num_encoders
logging.info(f"num_encoders: {num_encoders}")
Expand Down Expand Up @@ -409,6 +422,7 @@ def export_decoder_model_onnx(
decoder_filename: str,
opset_version: int = 11,
dynamic_batch: bool = True,
use_int32_inputs: int = 0,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Change the type hint of use_int32_inputs to bool to match the recommended str2bool argument type.

Suggested change
use_int32_inputs: int = 0,
use_int32_inputs: bool = False,

) -> None:
"""Export the decoder model to ONNX format.

Expand All @@ -432,7 +446,10 @@ def export_decoder_model_onnx(
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(1, context_size, dtype=torch.int64)
if use_int32_inputs:
y = torch.zeros(1, context_size, dtype=torch.int32)
Comment on lines +449 to +450

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Setting y to dtype=torch.int32 changes the exported ONNX model's input type to int32, but Decoder.forward in decoder.py explicitly casts y to torch.int64 via y = y.to(torch.int64). The exported ONNX graph will therefore contain a Cast-to-int64 node followed by int64 operations, which might defeat the purpose of using int32 inputs on hardware accelerators that do not support int64 data types. Consider updating Decoder.forward to skip the int64 cast when int32 indices are preferred and supported by the embedding layer in the target ONNX runtime.

else:
y = torch.zeros(1, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
Expand Down Expand Up @@ -655,6 +672,7 @@ def main():
encoder_filename,
opset_version=opset_version,
dynamic_batch=params.dynamic_batch == 1,
use_int32_inputs=params.use_int32_inputs,
)
logging.info(f"Exported encoder to {encoder_filename}")

Expand All @@ -665,6 +683,7 @@ def main():
decoder_filename,
opset_version=opset_version,
dynamic_batch=params.dynamic_batch == 1,
use_int32_inputs=params.use_int32_inputs,
)
logging.info(f"Exported decoder to {decoder_filename}")

Expand Down
Loading