Allow Streaming Zipformer Models to Be Exported as Non-Streaming Models#2086
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (2)
💤 Files with no reviewable changes (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 WalkthroughWalkthroughStandardize ONNX export masking/input dtypes to int32, add a streaming ChangesONNX Export Dtype Standardization and Non-Streaming Export
Estimated code review effort🎯 4 (Complex) | ⏱️ ~75 minutes Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces a new script export-streaming-as-non-streaming-onnx.py and updates existing export scripts and model components to support exporting streaming Zipformer models as non-streaming ONNX models, including support for int32 inputs. Key feedback highlights several critical issues: debugging leftovers like unconditional return statements in export-onnx-streaming.py and a print(help(...)) statement in the new script should be removed; a potential device mismatch when creating x_lens on CPU should be resolved; hardcoded feature dimensions of 80 should be parameterized; potential runtime errors in NonStreamingChunkCausalDepthwiseConv1d when seq_len < self.kernel_size must be handled; and a redundant elif block in zipformer.py should be simplified.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| logging.info(f"Exported encoder to {encoder_filename}") | ||
| return |
There was a problem hiding this comment.
The unconditional return statement here makes the rest of the main() function unreachable. This prevents the script from exporting the decoder, joiner, and performing FP16 conversion or quantization. This seems to be a debugging leftover and should be removed.
| logging.info(f"Exported encoder to {encoder_filename}") | |
| return | |
| logging.info(f"Exported encoder to {encoder_filename}") |
| logging.info(f"Exported decoder to {decoder_filename}") | ||
| return |
There was a problem hiding this comment.
Another unconditional return statement is added here, which prevents the script from exporting the joiner and performing FP16/INT8 quantization. This should be removed.
| logging.info(f"Exported decoder to {decoder_filename}") | |
| return | |
| logging.info(f"Exported decoder to {decoder_filename}") |
| self.encoder_proj = encoder_proj | ||
|
|
||
| def forward2(self, x: torch.Tensor): | ||
| x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) |
There was a problem hiding this comment.
Creating x_lens unconditionally on CPU will cause a device mismatch error if the input tensor x is on a GPU (e.g., during validation or testing on GPU). It should be created on the same device as x.
| x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) | |
| x_lens = torch.tensor([x.shape[1]], dtype=torch.int32, device=x.device) |
| def export_encoder_model_onnx( | ||
| encoder_model: OnnxEncoder, | ||
| encoder_filename: str, | ||
| max_len: int, | ||
| dynamic_axes: int, | ||
| use_int32_inputs: int, | ||
| keep_x_lens: int = 1, | ||
| opset_version: int = 13, | ||
| ) -> None: |
There was a problem hiding this comment.
Please add feature_dim as an argument to export_encoder_model_onnx to avoid hardcoding the feature dimension to 80 when creating the dummy input tensor x.
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
max_len: int,
dynamic_axes: int,
use_int32_inputs: int,
feature_dim: int = 80,
keep_x_lens: int = 1,
opset_version: int = 13,
) -> None:| if max_len > 0: | ||
| x = torch.zeros(1, max_len, 80, dtype=torch.float32) | ||
| else: | ||
| x = torch.zeros(1, 3000, 80, dtype=torch.float32) |
There was a problem hiding this comment.
The feature dimension 80 is hardcoded here. This will cause a runtime shape mismatch error if the model was trained with a different feature dimension (e.g., 128 for Whisper features or 40 for other configurations). Please use the feature_dim argument here.
| if max_len > 0: | |
| x = torch.zeros(1, max_len, 80, dtype=torch.float32) | |
| else: | |
| x = torch.zeros(1, 3000, 80, dtype=torch.float32) | |
| if max_len > 0: | |
| x = torch.zeros(1, max_len, feature_dim, dtype=torch.float32) | |
| else: | |
| x = torch.zeros(1, 3000, feature_dim, dtype=torch.float32) |
| export_encoder_model_onnx( | ||
| encoder, | ||
| encoder_filename, | ||
| max_len=params.max_len, | ||
| dynamic_axes=params.dynamic_axes, | ||
| use_int32_inputs=params.use_int32_inputs, | ||
| opset_version=opset_version, | ||
| keep_x_lens=params.keep_x_lens, | ||
| ) |
There was a problem hiding this comment.
Pass params.feature_dim to export_encoder_model_onnx to support exporting models trained with non-default feature dimensions.
| export_encoder_model_onnx( | |
| encoder, | |
| encoder_filename, | |
| max_len=params.max_len, | |
| dynamic_axes=params.dynamic_axes, | |
| use_int32_inputs=params.use_int32_inputs, | |
| opset_version=opset_version, | |
| keep_x_lens=params.keep_x_lens, | |
| ) | |
| export_encoder_model_onnx( | |
| encoder, | |
| encoder_filename, | |
| max_len=params.max_len, | |
| dynamic_axes=params.dynamic_axes, | |
| use_int32_inputs=params.use_int32_inputs, | |
| feature_dim=params.feature_dim, | |
| opset_version=opset_version, | |
| keep_x_lens=params.keep_x_lens, | |
| ) |
| # seq_len >= kernel_size in non-streaming mode, so we pad with zeros | ||
| t = seq_len - self.kernel_size | ||
| channels = left_edge.shape[0] | ||
| pad = torch.zeros( | ||
| channels, t, device=left_edge.device, dtype=left_edge.dtype | ||
| ) | ||
| left_edge = torch.cat((left_edge, pad), dim=-1) | ||
| right_edge = torch.cat((pad, right_edge), dim=-1) | ||
| chunk_scale = 1.0 + (left_edge + right_edge) |
There was a problem hiding this comment.
If seq_len < self.kernel_size (which can happen with very short audio inputs or during testing/tracing with small dummy inputs), t = seq_len - self.kernel_size will be negative. This will cause torch.zeros to raise a RuntimeError due to a negative dimension. We should handle t < 0 gracefully by slicing left_edge and right_edge instead of padding them.
| # seq_len >= kernel_size in non-streaming mode, so we pad with zeros | |
| t = seq_len - self.kernel_size | |
| channels = left_edge.shape[0] | |
| pad = torch.zeros( | |
| channels, t, device=left_edge.device, dtype=left_edge.dtype | |
| ) | |
| left_edge = torch.cat((left_edge, pad), dim=-1) | |
| right_edge = torch.cat((pad, right_edge), dim=-1) | |
| chunk_scale = 1.0 + (left_edge + right_edge) | |
| # seq_len >= kernel_size in non-streaming mode, so we pad with zeros | |
| t = seq_len - self.kernel_size | |
| channels = left_edge.shape[0] | |
| if t > 0: | |
| pad = torch.zeros( | |
| channels, t, device=left_edge.device, dtype=left_edge.dtype | |
| ) | |
| left_edge = torch.cat((left_edge, pad), dim=-1) | |
| right_edge = torch.cat((pad, right_edge), dim=-1) | |
| elif t < 0: | |
| left_edge = left_edge[..., :seq_len] | |
| right_edge = right_edge[..., -seq_len:] | |
| chunk_scale = 1.0 + (left_edge + right_edge) |
| opset_version: | ||
| The opset version to use. | ||
| """ | ||
| print(help(encoder_model.forward)) |
| # If we are exporting a streaming model, then we skip the if statement | ||
| if not self.causal or not torch.jit.is_tracing(): | ||
| src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) | ||
| src = torch.cat((src, src_extra), dim=0) | ||
|
|
||
| assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds) | ||
| if pad > 0: | ||
| src_extra = src[src.shape[0] - 1 :].expand( | ||
| pad, src.shape[1], src.shape[2] | ||
| ) | ||
| src = torch.cat((src, src_extra), dim=0) | ||
| elif self.causal and torch.jit.is_scripting(): | ||
| if pad > 0: | ||
| src_extra = src[src.shape[0] - 1 :].expand( | ||
| pad, src.shape[1], src.shape[2] | ||
| ) | ||
| src = torch.cat((src, src_extra), dim=0) | ||
|
|
There was a problem hiding this comment.
The elif self.causal and torch.jit.is_scripting(): block is redundant and unreachable. When scripting a causal model, torch.jit.is_tracing() is False, which means not torch.jit.is_tracing() is True. Therefore, the first if condition not self.causal or not torch.jit.is_tracing() will always evaluate to True, and the elif block will never be reached. You can simplify this by removing the redundant elif block.
| # If we are exporting a streaming model, then we skip the if statement | |
| if not self.causal or not torch.jit.is_tracing(): | |
| src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) | |
| src = torch.cat((src, src_extra), dim=0) | |
| assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds) | |
| if pad > 0: | |
| src_extra = src[src.shape[0] - 1 :].expand( | |
| pad, src.shape[1], src.shape[2] | |
| ) | |
| src = torch.cat((src, src_extra), dim=0) | |
| elif self.causal and torch.jit.is_scripting(): | |
| if pad > 0: | |
| src_extra = src[src.shape[0] - 1 :].expand( | |
| pad, src.shape[1], src.shape[2] | |
| ) | |
| src = torch.cat((src, src_extra), dim=0) | |
| # If we are exporting a streaming model, then we skip the if statement | |
| if not self.causal or not torch.jit.is_tracing(): | |
| if pad > 0: | |
| src_extra = src[src.shape[0] - 1 :].expand( | |
| pad, src.shape[1], src.shape[2] | |
| ) | |
| src = torch.cat((src, src_extra), dim=0) |
There was a problem hiding this comment.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
egs/librispeech/ASR/zipformer/zipformer.py (1)
1430-1441:⚠️ Potential issue | 🔴 Critical | ⚡ Quick winScripted positional encoding is now capped at 2000 frames.
Once
forward()stops callingextend_pe()under scripting, scripted/ONNX exports are stuck with the constructor buffer from Line 1430. The new exporter traces a 3000-frame sample by default (--max-len=-1), sopos_embbecomes shorter than2 * seq_len - 1and the later reshape inRelPositionMultiheadAttentionWeights.forward()fails. This blocks the default non-streaming export path for long inputs or large left context.Also applies to: 1503-1512
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@egs/librispeech/ASR/zipformer/zipformer.py` around lines 1430 - 1441, The scripted positional embedding buffer is being capped at the constructor's max_len (2000) so exports that can't call extend_pe() at runtime fail for long inputs; update CompactRelPositionalEncoding to allocate a sufficiently large pe in the constructor (instead of relying on forward to call extend_pe) — e.g., compute the required length from max_len and length_factor (or use the exporter default like 3000 / a safe upper bound) and call extend_pe with a tensor of that computed size; ensure this change affects the initialization path referenced by CompactRelPositionalEncoding.__init__ and prevents RelPositionMultiheadAttentionWeights.forward() from encountering a too-short pe during reshape.egs/librispeech/ASR/zipformer/export-onnx-streaming.py (1)
831-859:⚠️ Potential issue | 🔴 Critical | ⚡ Quick winRemove the early returns in
main().Line 847 exits after the encoder export, so the decoder/joiner exports and all FP16/int8 generation are unreachable. That breaks the script’s advertised behavior and this PR’s end-to-end streaming export flow.
Suggested fix
logging.info(f"Exported encoder to {encoder_filename}") - return logging.info("Exporting decoder") decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" export_decoder_model_onnx( decoder, @@ use_int32_inputs=params.use_int32_inputs, ) logging.info(f"Exported decoder to {decoder_filename}") - return logging.info("Exporting joiner")🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@egs/librispeech/ASR/zipformer/export-onnx-streaming.py` around lines 831 - 859, The main() function returns immediately after exporting the encoder which prevents subsequent decoder/joiner exports and FP16/int8 generation; remove the early "return" statements after the encoder export and after the decoder export so execution continues to call export_decoder_model_onnx, export_joiner (if present), and the FP16/int8 generation steps; verify export_encoder_model_onnx and export_decoder_model_onnx are still called with the same identifiers (encoder, encoder_filename, decoder, decoder_filename) and that params.use_external_data handling for encoder_filename remains correct.egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py (1)
644-670:⚠️ Potential issue | 🟠 Major | ⚡ Quick winKeep FP16 export independent of the int8 flag.
Line 645 returns before the FP16 block, so
--fp16 true --enable-int8-quantization 0no longer produces an FP16 model. That makes--enable-int8-quantizationaccidentally gate a separate output format.Suggested fix
- if not params.enable_int8_quantization: - return - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - if params.use_external_data: - model_filename_int8 = f"ctc-{suffix}.int8.onnx" - else: - model_filename_int8 = params.exp_dir / f"ctc-{suffix}.int8.onnx" - - quantize_dynamic( - model_input=model_filename, - model_output=model_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - if params.fp16: if params.use_external_data: model_filename_fp16 = f"ctc-{suffix}.fp16.onnx" export_onnx_fp16_large_2gb(model_filename, model_filename_fp16) else: model_filename_fp16 = params.exp_dir / f"ctc-{suffix}.fp16.onnx" export_onnx_fp16(model_filename, model_filename_fp16) + + if not params.enable_int8_quantization: + return + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + if params.use_external_data: + model_filename_int8 = f"ctc-{suffix}.int8.onnx" + else: + model_filename_int8 = params.exp_dir / f"ctc-{suffix}.int8.onnx" + + quantize_dynamic( + model_input=model_filename, + model_output=model_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py` around lines 644 - 670, The current early return on params.enable_int8_quantization prevents FP16 export when int8 is disabled; move the FP16 export block out of the int8-gated section so FP16 is handled independently: keep the int8 generation using quantize_dynamic(model_input=model_filename, model_output=model_filename_int8, op_types_to_quantize=["MatMul"], weight_type=QuantType.QInt8) gated by params.enable_int8_quantization, but always check params.fp16 afterwards and call export_onnx_fp16_large_2gb(model_filename, model_filename_fp16) when params.use_external_data is true or export_onnx_fp16(model_filename, model_filename_fp16) otherwise, constructing model_filename_fp16 with suffix/params.exp_dir analogous to model_filename_int8.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py`:
- Around line 238-246: forward2 currently hardcodes a single-length tensor so it
only works for batch size 1; change it to synthesize one length per batch item
(or explicitly enforce batch-1) so the padding mask matches N. Specifically, in
forward2 (and the similar code at the other occurrence) replace x_lens =
torch.tensor([x.shape[1]], dtype=torch.int32) with something that creates a
vector of length x.shape[0] (e.g. torch.full((x.shape[0],), x.shape[1],
dtype=torch.int32) or equivalent), so make_pad_mask(x_lens, x.shape[1]) produces
an (N, T) mask consistent with encoder_embed, encoder and encoder_proj.
In `@egs/librispeech/ASR/zipformer/scaling_converter.py`:
- Around line 75-83: The code fails when seq_len < self.kernel_size because t
becomes negative; update the logic in scaling_converter.py around
left_edge/right_edge/seq_len to match ChunkCausalDepthwiseConv1d's
short-sequence handling: compute pad_len = max(seq_len - self.kernel_size, 0)
and use that for torch.zeros, and when seq_len < self.kernel_size slice/truncate
left_edge and right_edge to seq_len before concatenation so shapes match; ensure
chunk_scale = 1.0 + (left_edge + right_edge) is computed after these
adjustments.
In `@egs/librispeech/ASR/zipformer/zipformer.py`:
- Around line 1351-1363: The padding branch is skipped during the traced causal
export, leaving src too short for the subsequent reshape; change the conditional
so padding runs in all cases except when both causal and tracing are true.
Concretely, replace the separate if/elif blocks around src padding with a single
check like "if pad > 0 and not (self.causal and torch.jit.is_tracing()):" (or
equivalent) so src_extra creation and torch.cat((src, src_extra), dim=0) always
execute when needed; update references to self.causal, torch.jit.is_tracing(),
torch.jit.is_scripting(), pad, and src accordingly to ensure reshape() sees the
correct element count.
---
Outside diff comments:
In `@egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py`:
- Around line 644-670: The current early return on
params.enable_int8_quantization prevents FP16 export when int8 is disabled; move
the FP16 export block out of the int8-gated section so FP16 is handled
independently: keep the int8 generation using
quantize_dynamic(model_input=model_filename, model_output=model_filename_int8,
op_types_to_quantize=["MatMul"], weight_type=QuantType.QInt8) gated by
params.enable_int8_quantization, but always check params.fp16 afterwards and
call export_onnx_fp16_large_2gb(model_filename, model_filename_fp16) when
params.use_external_data is true or export_onnx_fp16(model_filename,
model_filename_fp16) otherwise, constructing model_filename_fp16 with
suffix/params.exp_dir analogous to model_filename_int8.
In `@egs/librispeech/ASR/zipformer/export-onnx-streaming.py`:
- Around line 831-859: The main() function returns immediately after exporting
the encoder which prevents subsequent decoder/joiner exports and FP16/int8
generation; remove the early "return" statements after the encoder export and
after the decoder export so execution continues to call
export_decoder_model_onnx, export_joiner (if present), and the FP16/int8
generation steps; verify export_encoder_model_onnx and export_decoder_model_onnx
are still called with the same identifiers (encoder, encoder_filename, decoder,
decoder_filename) and that params.use_external_data handling for
encoder_filename remains correct.
In `@egs/librispeech/ASR/zipformer/zipformer.py`:
- Around line 1430-1441: The scripted positional embedding buffer is being
capped at the constructor's max_len (2000) so exports that can't call
extend_pe() at runtime fail for long inputs; update CompactRelPositionalEncoding
to allocate a sufficiently large pe in the constructor (instead of relying on
forward to call extend_pe) — e.g., compute the required length from max_len and
length_factor (or use the exporter default like 3000 / a safe upper bound) and
call extend_pe with a tensor of that computed size; ensure this change affects
the initialization path referenced by CompactRelPositionalEncoding.__init__ and
prevents RelPositionMultiheadAttentionWeights.forward() from encountering a
too-short pe during reshape.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 70b7827e-ee99-407f-8f50-6c8a3dde0f86
📒 Files selected for processing (8)
egs/librispeech/ASR/zipformer/export-onnx-ctc.pyegs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.pyegs/librispeech/ASR/zipformer/export-onnx-streaming.pyegs/librispeech/ASR/zipformer/export-onnx.pyegs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.pyegs/librispeech/ASR/zipformer/export.pyegs/librispeech/ASR/zipformer/scaling_converter.pyegs/librispeech/ASR/zipformer/zipformer.py
| def forward2(self, x: torch.Tensor): | ||
| x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) | ||
| x, x_lens = self.encoder_embed(x, x_lens) | ||
| src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]).to(torch.int32) | ||
| x = x.permute(1, 0, 2) | ||
| encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) | ||
| encoder_out = encoder_out.permute(1, 0, 2) | ||
| encoder_out = self.encoder_proj(encoder_out) | ||
| return encoder_out |
There was a problem hiding this comment.
forward2() hardcodes batch size 1 while the exported signature keeps N dynamic.
x_lens = torch.tensor([x.shape[1]], ...) always has shape (1,), but this export mode still declares x/encoder_out with dynamic batch axis N. With --keep-x-lens 0, any N > 1 input will later build a (1, T) padding mask and trip the encoder's batch-size assertions. Either synthesize one length per batch item or make this mode explicitly batch-1-only.
Suggested fix
def forward2(self, x: torch.Tensor):
- x_lens = torch.tensor([x.shape[1]], dtype=torch.int32)
+ x_lens = torch.full(
+ (x.shape[0],),
+ x.shape[1],
+ dtype=torch.int32,
+ device=x.device,
+ )
x, x_lens = self.encoder_embed(x, x_lens)Also applies to: 381-383
🧰 Tools
🪛 Ruff (0.15.15)
[warning] 243-243: Unpacked variable encoder_out_lens is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py`
around lines 238 - 246, forward2 currently hardcodes a single-length tensor so
it only works for batch size 1; change it to synthesize one length per batch
item (or explicitly enforce batch-1) so the padding mask matches N.
Specifically, in forward2 (and the similar code at the other occurrence) replace
x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) with something that
creates a vector of length x.shape[0] (e.g. torch.full((x.shape[0],),
x.shape[1], dtype=torch.int32) or equivalent), so make_pad_mask(x_lens,
x.shape[1]) produces an (N, T) mask consistent with encoder_embed, encoder and
encoder_proj.
| # seq_len >= kernel_size in non-streaming mode, so we pad with zeros | ||
| t = seq_len - self.kernel_size | ||
| channels = left_edge.shape[0] | ||
| pad = torch.zeros( | ||
| channels, t, device=left_edge.device, dtype=left_edge.dtype | ||
| ) | ||
| left_edge = torch.cat((left_edge, pad), dim=-1) | ||
| right_edge = torch.cat((pad, right_edge), dim=-1) | ||
| chunk_scale = 1.0 + (left_edge + right_edge) |
There was a problem hiding this comment.
The ONNX wrapper still breaks on short sequences.
seq_len here is the per-layer time axis after subsampling, so it can legitimately be smaller than kernel_size even in non-streaming inference. In that case t = seq_len - self.kernel_size goes negative and torch.zeros(channels, t, ...) raises. The replacement needs the same short-sequence handling as ChunkCausalDepthwiseConv1d, otherwise exported models reject short utterances.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@egs/librispeech/ASR/zipformer/scaling_converter.py` around lines 75 - 83, The
code fails when seq_len < self.kernel_size because t becomes negative; update
the logic in scaling_converter.py around left_edge/right_edge/seq_len to match
ChunkCausalDepthwiseConv1d's short-sequence handling: compute pad_len =
max(seq_len - self.kernel_size, 0) and use that for torch.zeros, and when
seq_len < self.kernel_size slice/truncate left_edge and right_edge to seq_len
before concatenation so shapes match; ensure chunk_scale = 1.0 + (left_edge +
right_edge) is computed after these adjustments.
| # If we are exporting a streaming model, then we skip the if statement | ||
| if not self.causal or not torch.jit.is_tracing(): | ||
| src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) | ||
| src = torch.cat((src, src_extra), dim=0) | ||
|
|
||
| assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds) | ||
| if pad > 0: | ||
| src_extra = src[src.shape[0] - 1 :].expand( | ||
| pad, src.shape[1], src.shape[2] | ||
| ) | ||
| src = torch.cat((src, src_extra), dim=0) | ||
| elif self.causal and torch.jit.is_scripting(): | ||
| if pad > 0: | ||
| src_extra = src[src.shape[0] - 1 :].expand( | ||
| pad, src.shape[1], src.shape[2] | ||
| ) | ||
| src = torch.cat((src, src_extra), dim=0) |
There was a problem hiding this comment.
Causal tracing skips the padding that reshape() still depends on.
When self.causal and torch.jit.is_tracing(), the first branch is false and the elif can never run, so src is left shorter than d_seq_len * ds. Any traced causal export with a non-multiple seq_len will then hit Line 1365 with the wrong element count.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@egs/librispeech/ASR/zipformer/zipformer.py` around lines 1351 - 1363, The
padding branch is skipped during the traced causal export, leaving src too short
for the subsequent reshape; change the conditional so padding runs in all cases
except when both causal and tracing are true. Concretely, replace the separate
if/elif blocks around src padding with a single check like "if pad > 0 and not
(self.causal and torch.jit.is_tracing()):" (or equivalent) so src_extra creation
and torch.cat((src, src_extra), dim=0) always execute when needed; update
references to self.causal, torch.jit.is_tracing(), torch.jit.is_scripting(),
pad, and src accordingly to ensure reshape() sees the correct element count.
98895fe to
e80ccb6
Compare
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
egs/librispeech/ASR/zipformer/export-onnx-streaming.py (1)
831-847:⚠️ Potential issue | 🔴 Critical | ⚡ Quick winRemove the unconditional
returnafter encoder export.Line 847 exits
main()before the decoder/joiner exports and before fp16/int8 generation, so this script now only emits the encoder ONNX and breaks the normal transducer export flow.Suggested fix
logging.info(f"Exported encoder to {encoder_filename}") - return logging.info("Exporting decoder")🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@egs/librispeech/ASR/zipformer/export-onnx-streaming.py` around lines 831 - 847, The unconditional return at the end of the encoder export block causes main() to exit early and prevents subsequent decoder/joiner exports and fp16/int8 generation; remove that return so execution continues past export_encoder_model_onnx (refer to encoder_filename, export_encoder_model_onnx, params and the surrounding encoder export block in main()) allowing the decoder/joiner export and any fp16/int8 steps to run normally.
♻️ Duplicate comments (1)
egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py (1)
238-246:⚠️ Potential issue | 🟠 Major | ⚡ Quick win
--keep-x-lens 0still hardcodes a batch-1 length vector.Line 239 always builds
x_lenswith shape(1,), but Lines 380-383 still exportx/encoder_outwith dynamic batch axisN. With--keep-x-lens 0 --dynamic-axes 1, the padding mask is derived from a single length and no longer matches batched inputs. Either synthesize one length per batch item here, or make this mode explicitly batch-1-only.Suggested fix
def forward2(self, x: torch.Tensor): - x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) + x_lens = torch.full( + (x.shape[0],), + x.shape[1], + dtype=torch.int32, + device=x.device, + ) x, x_lens = self.encoder_embed(x, x_lens) src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]).to(torch.int32)Also applies to: 364-383
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py` around lines 238 - 246, forward2 currently hardcodes x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) which creates a single-length vector and breaks when ONNX exports allow dynamic batch sizes; change it to synthesize one length per batch (e.g., torch.full((x.size(0),), x.size(1), dtype=torch.int32, device=x.device)) so make_pad_mask(...) and src_key_padding_mask match the actual batch size, keep dtype/device consistent, and apply the same fix to the other occurrence (the export block that builds x/encoder_out) so encoder_embed, make_pad_mask and encoder receive per-example x_lens rather than a batch-1 vector.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@egs/librispeech/ASR/zipformer/export-onnx-streaming.py`:
- Around line 831-847: The unconditional return at the end of the encoder export
block causes main() to exit early and prevents subsequent decoder/joiner exports
and fp16/int8 generation; remove that return so execution continues past
export_encoder_model_onnx (refer to encoder_filename, export_encoder_model_onnx,
params and the surrounding encoder export block in main()) allowing the
decoder/joiner export and any fp16/int8 steps to run normally.
---
Duplicate comments:
In `@egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py`:
- Around line 238-246: forward2 currently hardcodes x_lens =
torch.tensor([x.shape[1]], dtype=torch.int32) which creates a single-length
vector and breaks when ONNX exports allow dynamic batch sizes; change it to
synthesize one length per batch (e.g., torch.full((x.size(0),), x.size(1),
dtype=torch.int32, device=x.device)) so make_pad_mask(...) and
src_key_padding_mask match the actual batch size, keep dtype/device consistent,
and apply the same fix to the other occurrence (the export block that builds
x/encoder_out) so encoder_embed, make_pad_mask and encoder receive per-example
x_lens rather than a batch-1 vector.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 09c853e9-510a-41bc-a142-4e357487d3ab
📒 Files selected for processing (2)
egs/librispeech/ASR/zipformer/export-onnx-streaming.pyegs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py
|
@csukuangfj Hey thanks for the great feature. Between, exported model works for long files, but fails for file smaller than 3.0 seconds: Not sure about the reason though |
|
Can you try the following commit If it works, we will create a PR from it. |
Summary by CodeRabbit
New Features
Bug Fixes
Refactor