Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/zipformer/export-onnx-ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def forward(
- log_probs_len, a 1-D int64 tensor of shape (N,)
"""
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
src_key_padding_mask = make_pad_mask(x_lens).to(torch.int32)
x = x.permute(1, 0, 2)
encoder_out, log_probs_len = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2)
Expand Down
14 changes: 8 additions & 6 deletions egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,15 +273,15 @@ def forward(
)
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)

src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool)
src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.int32)

# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).to(torch.int32).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
Expand Down Expand Up @@ -641,11 +641,13 @@ def main():
)
logging.info(f"Exported model to {model_filename}")

if params.enable_int8_quantization:
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
if not params.enable_int8_quantization:
return

logging.info("Generate int8 quantization models")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

logging.info("Generate int8 quantization models")

if params.use_external_data:
model_filename_int8 = f"ctc-{suffix}.int8.onnx"
Expand Down
33 changes: 29 additions & 4 deletions egs/librispeech/ASR/zipformer/export-onnx-streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ def get_parser():
help="1 to also export int8 onnx models.",
)

parser.add_argument(
"--use-int32-inputs",
type=int,
default=0,
help="""1 to use int32_t as input types if applicable. 0 to use
int64_t otherwise.""",
)

parser.add_argument(
"--epoch",
type=int,
Expand Down Expand Up @@ -285,15 +293,17 @@ def forward(
)
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)

src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool)
src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.int32)

# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
processed_mask = (
(processed_lens.unsqueeze(1) <= processed_mask).to(torch.int32).flip(1)
)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
Expand All @@ -318,7 +328,7 @@ def forward(

new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
new_processed_lens.to(states[-1].dtype),
]

return encoder_out, new_states
Expand Down Expand Up @@ -406,6 +416,7 @@ def export_encoder_model_onnx(
dynamic_batch: bool = True,
use_whisper_features: bool = False,
use_external_data: bool = False,
use_int32_inputs: int = 0,
) -> None:
encoder_model.encoder.__class__.forward = (
encoder_model.encoder.__class__.streaming_forward
Expand All @@ -418,6 +429,12 @@ def export_encoder_model_onnx(

x = torch.rand(1, T, feature_dim, dtype=torch.float32)
init_state = encoder_model.get_init_states()

if use_int32_inputs:
init_state = [
s if s.dtype != torch.int64 else s.to(torch.int32) for s in init_state
]

num_encoders = len(encoder_model.encoder.encoder_dim)
logging.info(f"num_encoders: {num_encoders}")
logging.info(f"len(init_state): {len(init_state)}")
Expand Down Expand Up @@ -566,6 +583,7 @@ def export_decoder_model_onnx(
decoder_filename: str,
opset_version: int = 11,
dynamic_batch: bool = True,
use_int32_inputs: int = 0,
) -> None:
"""Export the decoder model to ONNX format.

Expand All @@ -588,7 +606,11 @@ def export_decoder_model_onnx(
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size

y = torch.zeros(1, context_size, dtype=torch.int64)
if use_int32_inputs:
y = torch.zeros(1, context_size, dtype=torch.int32)
else:
y = torch.zeros(1, context_size, dtype=torch.int64)

decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
Expand Down Expand Up @@ -817,10 +839,12 @@ def main():
opset_version=opset_version,
feature_dim=params.feature_dim,
dynamic_batch=params.dynamic_batch == 1,
use_int32_inputs=params.use_int32_inputs,
use_whisper_features=params.use_whisper_features,
use_external_data=params.use_external_data,
)
logging.info(f"Exported encoder to {encoder_filename}")
return

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The unconditional return statement here makes the rest of the main() function unreachable. As a result, the script never exports the decoder or the joiner, and never performs FP16 conversion or quantization. This looks like a debugging leftover and should be removed.

Suggested change
logging.info(f"Exported encoder to {encoder_filename}")
return
logging.info(f"Exported encoder to {encoder_filename}")


logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
Expand All @@ -829,6 +853,7 @@ def main():
decoder_filename,
opset_version=opset_version,
dynamic_batch=params.dynamic_batch == 1,
use_int32_inputs=params.use_int32_inputs,
)
logging.info(f"Exported decoder to {decoder_filename}")

Expand Down
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/zipformer/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def forward(
- encoder_out_lens, A 1-D tensor of shape (N,)
"""
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens, x.shape[1])
src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]).to(torch.int32)
x = x.permute(1, 0, 2)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2)
Expand Down
Loading
Loading