Skip to content

Allow Streaming Zipformer Models to Be Exported as Non-Streaming Models#2086

Merged
csukuangfj merged 6 commits into
k2-fsa:masterfrom
csukuangfj:non-streaming-zipformer-qnn
Jun 5, 2026
Merged

Allow Streaming Zipformer Models to Be Exported as Non-Streaming Models#2086
csukuangfj merged 6 commits into
k2-fsa:masterfrom
csukuangfj:non-streaming-zipformer-qnn

Conversation

@csukuangfj

@csukuangfj csukuangfj commented Jun 5, 2026

Copy link
Copy Markdown
Collaborator

Summary by CodeRabbit

  • New Features

    • Standalone non-streaming ONNX export CLI.
    • Opt-in int32-input support for ONNX exports.
    • Optional FP16 and int8 quantized export outputs.
    • Increased default positional length limit (1000 → 2000).
  • Bug Fixes

    • Prevented unintended int8 file generation when quantization is disabled.
  • Refactor

    • Unified masking and tensor-type handling across export and runtime paths.
    • Added ONNX-friendly non-streaming convolution wrapper and adjusted padding/attention behavior for export.

@coderabbitai

coderabbitai Bot commented Jun 5, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a3bd90ed-da26-4e53-848b-0b7a652e81d6

📥 Commits

Reviewing files that changed from the base of the PR and between e80ccb6 and 99086d2.

📒 Files selected for processing (2)
  • egs/librispeech/ASR/zipformer/export-onnx-streaming.py
  • egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py
💤 Files with no reviewable changes (1)
  • egs/librispeech/ASR/zipformer/export-onnx-streaming.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py

📝 Walkthrough

Walkthrough

Standardize ONNX export masking/input dtypes to int32, add a streaming --use-int32-inputs flag that casts streaming states and decoder inputs, introduce a non-streaming ONNX exporter with encoder/decoder/joiner wrappers and quantization helpers, and apply ONNX-oriented Zipformer module adjustments for scripting and mask handling.

Changes

ONNX Export Dtype Standardization and Non-Streaming Export

Layer / File(s) Summary
Zipformer core masking and attention updates
egs/librispeech/ASR/zipformer/zipformer.py
SimpleDownsample padding branches adjusted for tracing/causal modes; CompactRelPositionalEncoding.max_len default increased and scripting path avoids extend_pe(); attention/convolution masked_fill mask handling now casts to torch.bool.
Mask dtype standardization to int32 across export scripts
egs/librispeech/ASR/zipformer/export-onnx-ctc.py, egs/librispeech/ASR/zipformer/export-onnx.py, egs/librispeech/ASR/zipformer/export.py, egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
Export forward paths now convert padding/processed masks to torch.int32 before passing them into Zipformer encoder/on export wrappers.
Streaming ONNX CLI int32 input propagation
egs/librispeech/ASR/zipformer/export-onnx-streaming.py
Adds --use-int32-inputs, propagates it into encoder/decoder export functions to cast init_state/y to int32 and aligns processed_lens dtype when exporting.
Streaming CTC export control flow
egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
Streaming CTC exporter builds int32 masks and adds an early return to skip int8 quantization/output generation when quantization is disabled.
New non-streaming ONNX export script
egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py
New CLI script adding parser, checkpoint averaging, ONNX wrapper modules for encoder/decoder/joiner, export functions with dynamic axes and optional int32 inputs, fp16 conversion, int8 dynamic quantization hooks, and metadata injection.
Scaling converter ONNX wrapper
egs/librispeech/ASR/zipformer/scaling_converter.py
Adds NonStreamingChunkCausalDepthwiseConv1d, imports ChunkCausalDepthwiseConv1d, and updates ONNX replacement scripting to handle grouped modules and the chunk convolution wrapper.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

  • k2-fsa/icefall#2084: Both PRs add/propagate --use-int32-inputs option to ONNX export and cast exported streaming state tensors/decoder inputs from int64 to int32.

Poem

🐰 Hopping through tensors, masks swap to three,
Scripts export smoother, ONNX sings with glee.
Encoders and joiners now stride in a row,
A rabbit’s small dance for the models we know,
Cheers to clean exports — hop, skip, and go!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 60.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately reflects the main objective: enabling streaming Zipformer models to be exported as non-streaming ONNX models, which is demonstrated by the new export-streaming-as-non-streaming-onnx.py script and related changes across multiple export files.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new script export-streaming-as-non-streaming-onnx.py and updates existing export scripts and model components to support exporting streaming Zipformer models as non-streaming ONNX models, including support for int32 inputs. Key feedback highlights several critical issues: debugging leftovers like unconditional return statements in export-onnx-streaming.py and a print(help(...)) statement in the new script should be removed; a potential device mismatch when creating x_lens on CPU should be resolved; hardcoded feature dimensions of 80 should be parameterized; potential runtime errors in NonStreamingChunkCausalDepthwiseConv1d when seq_len < self.kernel_size must be handled; and a redundant elif block in zipformer.py should be simplified.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +846 to +847
logging.info(f"Exported encoder to {encoder_filename}")
return

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The unconditional return statement here makes the rest of the main() function unreachable. This prevents the script from exporting the decoder, joiner, and performing FP16 conversion or quantization. This seems to be a debugging leftover and should be removed.

Suggested change
logging.info(f"Exported encoder to {encoder_filename}")
return
logging.info(f"Exported encoder to {encoder_filename}")

Comment on lines +858 to +859
logging.info(f"Exported decoder to {decoder_filename}")
return

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Another unconditional return statement is added here, which prevents the script from exporting the joiner and performing FP16/INT8 quantization. This should be removed.

Suggested change
logging.info(f"Exported decoder to {decoder_filename}")
return
logging.info(f"Exported decoder to {decoder_filename}")

self.encoder_proj = encoder_proj

def forward2(self, x: torch.Tensor):
x_lens = torch.tensor([x.shape[1]], dtype=torch.int32)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Creating x_lens unconditionally on CPU will cause a device mismatch error if the input tensor x is on a GPU (e.g., during validation or testing on GPU). It should be created on the same device as x.

Suggested change
x_lens = torch.tensor([x.shape[1]], dtype=torch.int32)
x_lens = torch.tensor([x.shape[1]], dtype=torch.int32, device=x.device)

Comment on lines +326 to +334
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
max_len: int,
dynamic_axes: int,
use_int32_inputs: int,
keep_x_lens: int = 1,
opset_version: int = 13,
) -> None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Please add feature_dim as an argument to export_encoder_model_onnx to avoid hardcoding the feature dimension to 80 when creating the dummy input tensor x.

def export_encoder_model_onnx(
    encoder_model: OnnxEncoder,
    encoder_filename: str,
    max_len: int,
    dynamic_axes: int,
    use_int32_inputs: int,
    feature_dim: int = 80,
    keep_x_lens: int = 1,
    opset_version: int = 13,
) -> None:

Comment on lines +355 to +358
if max_len > 0:
x = torch.zeros(1, max_len, 80, dtype=torch.float32)
else:
x = torch.zeros(1, 3000, 80, dtype=torch.float32)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The feature dimension 80 is hardcoded here. This will cause a runtime shape mismatch error if the model was trained with a different feature dimension (e.g., 128 for Whisper features or 40 for other configurations). Please use the feature_dim argument here.

Suggested change
if max_len > 0:
x = torch.zeros(1, max_len, 80, dtype=torch.float32)
else:
x = torch.zeros(1, 3000, 80, dtype=torch.float32)
if max_len > 0:
x = torch.zeros(1, max_len, feature_dim, dtype=torch.float32)
else:
x = torch.zeros(1, 3000, feature_dim, dtype=torch.float32)

Comment on lines +654 to +662
export_encoder_model_onnx(
encoder,
encoder_filename,
max_len=params.max_len,
dynamic_axes=params.dynamic_axes,
use_int32_inputs=params.use_int32_inputs,
opset_version=opset_version,
keep_x_lens=params.keep_x_lens,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Pass params.feature_dim to export_encoder_model_onnx to support exporting models trained with non-default feature dimensions.

Suggested change
export_encoder_model_onnx(
encoder,
encoder_filename,
max_len=params.max_len,
dynamic_axes=params.dynamic_axes,
use_int32_inputs=params.use_int32_inputs,
opset_version=opset_version,
keep_x_lens=params.keep_x_lens,
)
export_encoder_model_onnx(
encoder,
encoder_filename,
max_len=params.max_len,
dynamic_axes=params.dynamic_axes,
use_int32_inputs=params.use_int32_inputs,
feature_dim=params.feature_dim,
opset_version=opset_version,
keep_x_lens=params.keep_x_lens,
)

Comment on lines +75 to +83
# seq_len >= kernel_size in non-streaming mode, so we pad with zeros
t = seq_len - self.kernel_size
channels = left_edge.shape[0]
pad = torch.zeros(
channels, t, device=left_edge.device, dtype=left_edge.dtype
)
left_edge = torch.cat((left_edge, pad), dim=-1)
right_edge = torch.cat((pad, right_edge), dim=-1)
chunk_scale = 1.0 + (left_edge + right_edge)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If seq_len < self.kernel_size (which can happen with very short audio inputs or during testing/tracing with small dummy inputs), t = seq_len - self.kernel_size will be negative. This will cause torch.zeros to raise a RuntimeError due to a negative dimension. We should handle t < 0 gracefully by slicing left_edge and right_edge instead of padding them.

Suggested change
# seq_len >= kernel_size in non-streaming mode, so we pad with zeros
t = seq_len - self.kernel_size
channels = left_edge.shape[0]
pad = torch.zeros(
channels, t, device=left_edge.device, dtype=left_edge.dtype
)
left_edge = torch.cat((left_edge, pad), dim=-1)
right_edge = torch.cat((pad, right_edge), dim=-1)
chunk_scale = 1.0 + (left_edge + right_edge)
# seq_len >= kernel_size in non-streaming mode, so we pad with zeros
t = seq_len - self.kernel_size
channels = left_edge.shape[0]
if t > 0:
pad = torch.zeros(
channels, t, device=left_edge.device, dtype=left_edge.dtype
)
left_edge = torch.cat((left_edge, pad), dim=-1)
right_edge = torch.cat((pad, right_edge), dim=-1)
elif t < 0:
left_edge = left_edge[..., :seq_len]
right_edge = right_edge[..., -seq_len:]
chunk_scale = 1.0 + (left_edge + right_edge)

opset_version:
The opset version to use.
"""
print(help(encoder_model.forward))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This print(help(...)) statement appears to be a debugging leftover and should be removed to keep the export logs clean.

Comment on lines 1351 to 1364
# If we are exporting a streaming model, then we skip the if statement
if not self.causal or not torch.jit.is_tracing():
src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
src = torch.cat((src, src_extra), dim=0)

assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds)
if pad > 0:
src_extra = src[src.shape[0] - 1 :].expand(
pad, src.shape[1], src.shape[2]
)
src = torch.cat((src, src_extra), dim=0)
elif self.causal and torch.jit.is_scripting():
if pad > 0:
src_extra = src[src.shape[0] - 1 :].expand(
pad, src.shape[1], src.shape[2]
)
src = torch.cat((src, src_extra), dim=0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The elif self.causal and torch.jit.is_scripting(): block is redundant and unreachable. When scripting a causal model, torch.jit.is_tracing() is False, which means not torch.jit.is_tracing() is True. Therefore, the first if condition not self.causal or not torch.jit.is_tracing() will always evaluate to True, and the elif block will never be reached. You can simplify this by removing the redundant elif block.

Suggested change
# If we are exporting a streaming model, then we skip the if statement
if not self.causal or not torch.jit.is_tracing():
src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
src = torch.cat((src, src_extra), dim=0)
assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds)
if pad > 0:
src_extra = src[src.shape[0] - 1 :].expand(
pad, src.shape[1], src.shape[2]
)
src = torch.cat((src, src_extra), dim=0)
elif self.causal and torch.jit.is_scripting():
if pad > 0:
src_extra = src[src.shape[0] - 1 :].expand(
pad, src.shape[1], src.shape[2]
)
src = torch.cat((src, src_extra), dim=0)
# If we are exporting a streaming model, then we skip the if statement
if not self.causal or not torch.jit.is_tracing():
if pad > 0:
src_extra = src[src.shape[0] - 1 :].expand(
pad, src.shape[1], src.shape[2]
)
src = torch.cat((src, src_extra), dim=0)

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
egs/librispeech/ASR/zipformer/zipformer.py (1)

1430-1441: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Scripted positional encoding is now capped at 2000 frames.

Once forward() stops calling extend_pe() under scripting, scripted/ONNX exports are stuck with the constructor buffer from Line 1430. The new exporter traces a 3000-frame sample by default (--max-len=-1), so pos_emb becomes shorter than 2 * seq_len - 1 and the later reshape in RelPositionMultiheadAttentionWeights.forward() fails. This blocks the default non-streaming export path for long inputs or large left context.

Also applies to: 1503-1512

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@egs/librispeech/ASR/zipformer/zipformer.py` around lines 1430 - 1441, The
scripted positional embedding buffer is being capped at the constructor's
max_len (2000) so exports that can't call extend_pe() at runtime fail for long
inputs; update CompactRelPositionalEncoding to allocate a sufficiently large pe
in the constructor (instead of relying on forward to call extend_pe) — e.g.,
compute the required length from max_len and length_factor (or use the exporter
default like 3000 / a safe upper bound) and call extend_pe with a tensor of that
computed size; ensure this change affects the initialization path referenced by
CompactRelPositionalEncoding.__init__ and prevents
RelPositionMultiheadAttentionWeights.forward() from encountering a too-short pe
during reshape.
egs/librispeech/ASR/zipformer/export-onnx-streaming.py (1)

831-859: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Remove the early returns in main().

Line 847 exits after the encoder export, so the decoder/joiner exports and all FP16/int8 generation are unreachable. That breaks the script’s advertised behavior and this PR’s end-to-end streaming export flow.

Suggested fix
     logging.info(f"Exported encoder to {encoder_filename}")
-    return

     logging.info("Exporting decoder")
     decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
     export_decoder_model_onnx(
         decoder,
@@
         use_int32_inputs=params.use_int32_inputs,
     )
     logging.info(f"Exported decoder to {decoder_filename}")
-    return

     logging.info("Exporting joiner")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@egs/librispeech/ASR/zipformer/export-onnx-streaming.py` around lines 831 -
859, The main() function returns immediately after exporting the encoder which
prevents subsequent decoder/joiner exports and FP16/int8 generation; remove the
early "return" statements after the encoder export and after the decoder export
so execution continues to call export_decoder_model_onnx, export_joiner (if
present), and the FP16/int8 generation steps; verify export_encoder_model_onnx
and export_decoder_model_onnx are still called with the same identifiers
(encoder, encoder_filename, decoder, decoder_filename) and that
params.use_external_data handling for encoder_filename remains correct.
egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py (1)

644-670: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Keep FP16 export independent of the int8 flag.

Line 645 returns before the FP16 block, so --fp16 true --enable-int8-quantization 0 no longer produces an FP16 model. That makes --enable-int8-quantization accidentally gate a separate output format.

Suggested fix
-    if not params.enable_int8_quantization:
-        return
-
-    # Generate int8 quantization models
-    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
-
-    logging.info("Generate int8 quantization models")
-
-    if params.use_external_data:
-        model_filename_int8 = f"ctc-{suffix}.int8.onnx"
-    else:
-        model_filename_int8 = params.exp_dir / f"ctc-{suffix}.int8.onnx"
-
-    quantize_dynamic(
-        model_input=model_filename,
-        model_output=model_filename_int8,
-        op_types_to_quantize=["MatMul"],
-        weight_type=QuantType.QInt8,
-    )
-
     if params.fp16:
         if params.use_external_data:
             model_filename_fp16 = f"ctc-{suffix}.fp16.onnx"
             export_onnx_fp16_large_2gb(model_filename, model_filename_fp16)
         else:
             model_filename_fp16 = params.exp_dir / f"ctc-{suffix}.fp16.onnx"
             export_onnx_fp16(model_filename, model_filename_fp16)
+
+    if not params.enable_int8_quantization:
+        return
+
+    # Generate int8 quantization models
+    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
+
+    logging.info("Generate int8 quantization models")
+
+    if params.use_external_data:
+        model_filename_int8 = f"ctc-{suffix}.int8.onnx"
+    else:
+        model_filename_int8 = params.exp_dir / f"ctc-{suffix}.int8.onnx"
+
+    quantize_dynamic(
+        model_input=model_filename,
+        model_output=model_filename_int8,
+        op_types_to_quantize=["MatMul"],
+        weight_type=QuantType.QInt8,
+    )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py` around lines 644
- 670, The current early return on params.enable_int8_quantization prevents FP16
export when int8 is disabled; move the FP16 export block out of the int8-gated
section so FP16 is handled independently: keep the int8 generation using
quantize_dynamic(model_input=model_filename, model_output=model_filename_int8,
op_types_to_quantize=["MatMul"], weight_type=QuantType.QInt8) gated by
params.enable_int8_quantization, but always check params.fp16 afterwards and
call export_onnx_fp16_large_2gb(model_filename, model_filename_fp16) when
params.use_external_data is true or export_onnx_fp16(model_filename,
model_filename_fp16) otherwise, constructing model_filename_fp16 with
suffix/params.exp_dir analogous to model_filename_int8.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py`:
- Around line 238-246: forward2 currently hardcodes a single-length tensor so it
only works for batch size 1; change it to synthesize one length per batch item
(or explicitly enforce batch-1) so the padding mask matches N. Specifically, in
forward2 (and the similar code at the other occurrence) replace x_lens =
torch.tensor([x.shape[1]], dtype=torch.int32) with something that creates a
vector of length x.shape[0] (e.g. torch.full((x.shape[0],), x.shape[1],
dtype=torch.int32) or equivalent), so make_pad_mask(x_lens, x.shape[1]) produces
an (N, T) mask consistent with encoder_embed, encoder and encoder_proj.

In `@egs/librispeech/ASR/zipformer/scaling_converter.py`:
- Around line 75-83: The code fails when seq_len < self.kernel_size because t
becomes negative; update the logic in scaling_converter.py around
left_edge/right_edge/seq_len to match ChunkCausalDepthwiseConv1d's
short-sequence handling: compute pad_len = max(seq_len - self.kernel_size, 0)
and use that for torch.zeros, and when seq_len < self.kernel_size slice/truncate
left_edge and right_edge to seq_len before concatenation so shapes match; ensure
chunk_scale = 1.0 + (left_edge + right_edge) is computed after these
adjustments.

In `@egs/librispeech/ASR/zipformer/zipformer.py`:
- Around line 1351-1363: The padding branch is skipped during the traced causal
export, leaving src too short for the subsequent reshape; change the conditional
so padding runs in all cases except when both causal and tracing are true.
Concretely, replace the separate if/elif blocks around src padding with a single
check like "if pad > 0 and not (self.causal and torch.jit.is_tracing()):" (or
equivalent) so src_extra creation and torch.cat((src, src_extra), dim=0) always
execute when needed; update references to self.causal, torch.jit.is_tracing(),
torch.jit.is_scripting(), pad, and src accordingly to ensure reshape() sees the
correct element count.

---

Outside diff comments:
In `@egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py`:
- Around line 644-670: The current early return on
params.enable_int8_quantization prevents FP16 export when int8 is disabled; move
the FP16 export block out of the int8-gated section so FP16 is handled
independently: keep the int8 generation using
quantize_dynamic(model_input=model_filename, model_output=model_filename_int8,
op_types_to_quantize=["MatMul"], weight_type=QuantType.QInt8) gated by
params.enable_int8_quantization, but always check params.fp16 afterwards and
call export_onnx_fp16_large_2gb(model_filename, model_filename_fp16) when
params.use_external_data is true or export_onnx_fp16(model_filename,
model_filename_fp16) otherwise, constructing model_filename_fp16 with
suffix/params.exp_dir analogous to model_filename_int8.

In `@egs/librispeech/ASR/zipformer/export-onnx-streaming.py`:
- Around line 831-859: The main() function returns immediately after exporting
the encoder which prevents subsequent decoder/joiner exports and FP16/int8
generation; remove the early "return" statements after the encoder export and
after the decoder export so execution continues to call
export_decoder_model_onnx, export_joiner (if present), and the FP16/int8
generation steps; verify export_encoder_model_onnx and export_decoder_model_onnx
are still called with the same identifiers (encoder, encoder_filename, decoder,
decoder_filename) and that params.use_external_data handling for
encoder_filename remains correct.

In `@egs/librispeech/ASR/zipformer/zipformer.py`:
- Around line 1430-1441: The scripted positional embedding buffer is being
capped at the constructor's max_len (2000) so exports that can't call
extend_pe() at runtime fail for long inputs; update CompactRelPositionalEncoding
to allocate a sufficiently large pe in the constructor (instead of relying on
forward to call extend_pe) — e.g., compute the required length from max_len and
length_factor (or use the exporter default like 3000 / a safe upper bound) and
call extend_pe with a tensor of that computed size; ensure this change affects
the initialization path referenced by CompactRelPositionalEncoding.__init__ and
prevents RelPositionMultiheadAttentionWeights.forward() from encountering a
too-short pe during reshape.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 70b7827e-ee99-407f-8f50-6c8a3dde0f86

📥 Commits

Reviewing files that changed from the base of the PR and between bf478d3 and 98895fe.

📒 Files selected for processing (8)
  • egs/librispeech/ASR/zipformer/export-onnx-ctc.py
  • egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
  • egs/librispeech/ASR/zipformer/export-onnx-streaming.py
  • egs/librispeech/ASR/zipformer/export-onnx.py
  • egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py
  • egs/librispeech/ASR/zipformer/export.py
  • egs/librispeech/ASR/zipformer/scaling_converter.py
  • egs/librispeech/ASR/zipformer/zipformer.py

Comment on lines +238 to +246
def forward2(self, x: torch.Tensor):
x_lens = torch.tensor([x.shape[1]], dtype=torch.int32)
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]).to(torch.int32)
x = x.permute(1, 0, 2)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2)
encoder_out = self.encoder_proj(encoder_out)
return encoder_out

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

forward2() hardcodes batch size 1 while the exported signature keeps N dynamic.

x_lens = torch.tensor([x.shape[1]], ...) always has shape (1,), but this export mode still declares x/encoder_out with dynamic batch axis N. With --keep-x-lens 0, any N > 1 input will later build a (1, T) padding mask and trip the encoder's batch-size assertions. Either synthesize one length per batch item or make this mode explicitly batch-1-only.

Suggested fix
     def forward2(self, x: torch.Tensor):
-        x_lens = torch.tensor([x.shape[1]], dtype=torch.int32)
+        x_lens = torch.full(
+            (x.shape[0],),
+            x.shape[1],
+            dtype=torch.int32,
+            device=x.device,
+        )
         x, x_lens = self.encoder_embed(x, x_lens)

Also applies to: 381-383

🧰 Tools
🪛 Ruff (0.15.15)

[warning] 243-243: Unpacked variable encoder_out_lens is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py`
around lines 238 - 246, forward2 currently hardcodes a single-length tensor so
it only works for batch size 1; change it to synthesize one length per batch
item (or explicitly enforce batch-1) so the padding mask matches N.
Specifically, in forward2 (and the similar code at the other occurrence) replace
x_lens = torch.tensor([x.shape[1]], dtype=torch.int32) with something that
creates a vector of length x.shape[0] (e.g. torch.full((x.shape[0],),
x.shape[1], dtype=torch.int32) or equivalent), so make_pad_mask(x_lens,
x.shape[1]) produces an (N, T) mask consistent with encoder_embed, encoder and
encoder_proj.

Comment on lines +75 to +83
# seq_len >= kernel_size in non-streaming mode, so we pad with zeros
t = seq_len - self.kernel_size
channels = left_edge.shape[0]
pad = torch.zeros(
channels, t, device=left_edge.device, dtype=left_edge.dtype
)
left_edge = torch.cat((left_edge, pad), dim=-1)
right_edge = torch.cat((pad, right_edge), dim=-1)
chunk_scale = 1.0 + (left_edge + right_edge)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

The ONNX wrapper still breaks on short sequences.

seq_len here is the per-layer time axis after subsampling, so it can legitimately be smaller than kernel_size even in non-streaming inference. In that case t = seq_len - self.kernel_size goes negative and torch.zeros(channels, t, ...) raises. The replacement needs the same short-sequence handling as ChunkCausalDepthwiseConv1d, otherwise exported models reject short utterances.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@egs/librispeech/ASR/zipformer/scaling_converter.py` around lines 75 - 83, The
code fails when seq_len < self.kernel_size because t becomes negative; update
the logic in scaling_converter.py around left_edge/right_edge/seq_len to match
ChunkCausalDepthwiseConv1d's short-sequence handling: compute pad_len =
max(seq_len - self.kernel_size, 0) and use that for torch.zeros, and when
seq_len < self.kernel_size slice/truncate left_edge and right_edge to seq_len
before concatenation so shapes match; ensure chunk_scale = 1.0 + (left_edge +
right_edge) is computed after these adjustments.

Comment on lines 1351 to +1363
# If we are exporting a streaming model, then we skip the if statement
if not self.causal or not torch.jit.is_tracing():
src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
src = torch.cat((src, src_extra), dim=0)

assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds)
if pad > 0:
src_extra = src[src.shape[0] - 1 :].expand(
pad, src.shape[1], src.shape[2]
)
src = torch.cat((src, src_extra), dim=0)
elif self.causal and torch.jit.is_scripting():
if pad > 0:
src_extra = src[src.shape[0] - 1 :].expand(
pad, src.shape[1], src.shape[2]
)
src = torch.cat((src, src_extra), dim=0)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Causal tracing skips the padding that reshape() still depends on.

When self.causal and torch.jit.is_tracing(), the first branch is false and the elif can never run, so src is left shorter than d_seq_len * ds. Any traced causal export with a non-multiple seq_len will then hit Line 1365 with the wrong element count.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@egs/librispeech/ASR/zipformer/zipformer.py` around lines 1351 - 1363, The
padding branch is skipped during the traced causal export, leaving src too short
for the subsequent reshape; change the conditional so padding runs in all cases
except when both causal and tracing are true. Concretely, replace the separate
if/elif blocks around src padding with a single check like "if pad > 0 and not
(self.causal and torch.jit.is_tracing()):" (or equivalent) so src_extra creation
and torch.cat((src, src_extra), dim=0) always execute when needed; update
references to self.causal, torch.jit.is_tracing(), torch.jit.is_scripting(),
pad, and src accordingly to ensure reshape() sees the correct element count.

@csukuangfj csukuangfj force-pushed the non-streaming-zipformer-qnn branch from 98895fe to e80ccb6 Compare June 5, 2026 07:05

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
egs/librispeech/ASR/zipformer/export-onnx-streaming.py (1)

831-847: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Remove the unconditional return after encoder export.

Line 847 exits main() before the decoder/joiner exports and before fp16/int8 generation, so this script now only emits the encoder ONNX and breaks the normal transducer export flow.

Suggested fix
     logging.info(f"Exported encoder to {encoder_filename}")
-    return
 
     logging.info("Exporting decoder")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@egs/librispeech/ASR/zipformer/export-onnx-streaming.py` around lines 831 -
847, The unconditional return at the end of the encoder export block causes
main() to exit early and prevents subsequent decoder/joiner exports and
fp16/int8 generation; remove that return so execution continues past
export_encoder_model_onnx (refer to encoder_filename, export_encoder_model_onnx,
params and the surrounding encoder export block in main()) allowing the
decoder/joiner export and any fp16/int8 steps to run normally.
♻️ Duplicate comments (1)
egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py (1)

238-246: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

--keep-x-lens 0 still hardcodes a batch-1 length vector.

Line 239 always builds x_lens with shape (1,), but Lines 380-383 still export x/encoder_out with dynamic batch axis N. With --keep-x-lens 0 --dynamic-axes 1, the padding mask is derived from a single length and no longer matches batched inputs. Either synthesize one length per batch item here, or make this mode explicitly batch-1-only.

Suggested fix
     def forward2(self, x: torch.Tensor):
-        x_lens = torch.tensor([x.shape[1]], dtype=torch.int32)
+        x_lens = torch.full(
+            (x.shape[0],),
+            x.shape[1],
+            dtype=torch.int32,
+            device=x.device,
+        )
         x, x_lens = self.encoder_embed(x, x_lens)
         src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]).to(torch.int32)

Also applies to: 364-383

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py`
around lines 238 - 246, forward2 currently hardcodes x_lens =
torch.tensor([x.shape[1]], dtype=torch.int32) which creates a single-length
vector and breaks when ONNX exports allow dynamic batch sizes; change it to
synthesize one length per batch (e.g., torch.full((x.size(0),), x.size(1),
dtype=torch.int32, device=x.device)) so make_pad_mask(...) and
src_key_padding_mask match the actual batch size, keep dtype/device consistent,
and apply the same fix to the other occurrence (the export block that builds
x/encoder_out) so encoder_embed, make_pad_mask and encoder receive per-example
x_lens rather than a batch-1 vector.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@egs/librispeech/ASR/zipformer/export-onnx-streaming.py`:
- Around line 831-847: The unconditional return at the end of the encoder export
block causes main() to exit early and prevents subsequent decoder/joiner exports
and fp16/int8 generation; remove that return so execution continues past
export_encoder_model_onnx (refer to encoder_filename, export_encoder_model_onnx,
params and the surrounding encoder export block in main()) allowing the
decoder/joiner export and any fp16/int8 steps to run normally.

---

Duplicate comments:
In `@egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py`:
- Around line 238-246: forward2 currently hardcodes x_lens =
torch.tensor([x.shape[1]], dtype=torch.int32) which creates a single-length
vector and breaks when ONNX exports allow dynamic batch sizes; change it to
synthesize one length per batch (e.g., torch.full((x.size(0),), x.size(1),
dtype=torch.int32, device=x.device)) so make_pad_mask(...) and
src_key_padding_mask match the actual batch size, keep dtype/device consistent,
and apply the same fix to the other occurrence (the export block that builds
x/encoder_out) so encoder_embed, make_pad_mask and encoder receive per-example
x_lens rather than a batch-1 vector.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 09c853e9-510a-41bc-a142-4e357487d3ab

📥 Commits

Reviewing files that changed from the base of the PR and between 98895fe and e80ccb6.

📒 Files selected for processing (2)
  • egs/librispeech/ASR/zipformer/export-onnx-streaming.py
  • egs/librispeech/ASR/zipformer/export-streaming-as-non-streaming-onnx.py

@csukuangfj csukuangfj merged commit 27e14b3 into k2-fsa:master Jun 5, 2026
9 of 33 checks passed
@csukuangfj csukuangfj deleted the non-streaming-zipformer-qnn branch June 5, 2026 07:31
@nshmyrev

Copy link
Copy Markdown
Contributor

@csukuangfj Hey thanks for the great feature. Between, exported model works for long files, but fails for file smaller than 3.0 seconds:

2026-06-15 15:49:57.413157084 [E:onnxruntime:, sequential_executor.cc:572 ExecuteKernel] Non-zero status code returned while running ConstantOfShape node. Name:'/encoder/3/encoder/0/conv_module1/depthwise_conv/ConstantOfShape_1' Status Message: /onnxruntime_src/onnxruntime/core/framework/op_kernel.cc:83 virtual OrtValue* onnxruntime::OpKernelContext::OutputMLValue(int, const onnxruntime::TensorShape&) status.IsOK() was false. tensor.cc:57 CalculateTensorStorageSize Tensor shape.Size() must be >= 0

Traceback (most recent call last):
  File "/root/models/vosk-model-ar-0.63/./decode.py", line 45, in <module>
    main()
    ~~~~^^
  File "/root/models/vosk-model-ar-0.63/./decode.py", line 39, in main
    recognizer.decode_stream(s)
    ~~~~~~~~~~~~~~~~~~~~~~~~^^^
  File "/root/venvs/k2-venv/lib/python3.13/site-packages/sherpa_onnx/offline_recognizer.py", line 1781, in decode_stream
    self.recognizer.decode_stream(s)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^
RuntimeError: Non-zero status code returned while running ConstantOfShape node. Name:'/encoder/3/encoder/0/conv_module1/depthwise_conv/ConstantOfShape_1' Status Message: /onnxruntime_src/onnxruntime/core/framework/op_kernel.cc:83 virtual OrtValue* onnxruntime::OpKernelContext::OutputMLValue(int, const onnxruntime::TensorS

Not sure about the reason though

@csukuangfj

Copy link
Copy Markdown
Collaborator Author

@nshmyrev

Can you try the following commit
csukuangfj@bf4992a

If it works, we will create a PR from it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants