export a non-stream onnx model from a streaming pytorch model #1576

Description

@1215thebqtic

Hi,

I'm trying to export a non-streaming ONNX model from a streaming PyTorch zipformer2 model. Training a non-streaming zipformer2 model from scratch takes a long time, so I decided to export the streaming model with "--chunk-size -1 --left-context-frames -1" and use it as a non-streaming model.

The streaming model was trained using causal=1.

This is the command I used to export the non-streaming ONNX model from the streaming PyTorch model:

./zipformer/export-onnx.py \
  --tokens $tokenfile \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  --exp-dir zipformer/exp_L_causal_context_2 \
  --num-encoder-layers "2,2,3,4,3,2" \
  --downsampling-factor "1,2,4,8,4,2" \
  --feedforward-dim "512,768,1024,1536,1024,768" \
  --num-heads "4,4,4,8,4,4" \
  --encoder-dim "192,256,384,512,384,256" \
  --query-head-dim 32 \
  --value-head-dim 12 \
  --pos-head-dim 4 \
  --pos-dim 48 \
  --encoder-unmasked-dim "192,192,256,256,256,192" \
  --cnn-module-kernel "31,31,15,15,15,31" \
  --decoder-dim 512 \
  --joiner-dim 512 \
  --causal True \
  --chunk-size -1 \
  --left-context-frames -1

I then decoded with the exported ONNX model using the following command:

./zipformer/onnx_pretrained.py \
  --encoder-model-filename $repo/encoder-epoch-99-avg-1.onnx \
  --decoder-model-filename $repo/decoder-epoch-99-avg-1.onnx \
  --joiner-model-filename $repo/joiner-epoch-99-avg-1.onnx \
  --tokens $tokenfile \
  icefall-asr-zipformer-streaming-wenetspeech-20230615/test_wavs/DEV_T0000000001.wav

An error occurred:
[screenshot: broadcasting_error]

The node that fails, as shown in Netron:
[screenshot: onnx_node]

Judging from the Netron graph and the zipformer code, I think the cause is the broadcasting in

x_chunk = x_chunk * chunk_scale

x_chunk's shape is (batch_size, num_channels, chunk_size), and chunk_scale's shape is (num_channels, chunk_size).
I noticed that streaming_forward contains the same line (x_chunk = x_chunk * chunk_scale), but exporting the streaming ONNX model does not produce any errors.
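
To make the shapes concrete, here is a minimal sketch of the trailing-dimension broadcasting rule that NumPy and PyTorch follow, written in plain Python. The helper `broadcast_shape` and the sizes below are hypothetical and only for illustration. This is not code from icefall, and it does not model what the ONNX exporter actually does with this multiply.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the shape produced by broadcasting shapes `a` and `b`.

    Dimensions are aligned from the right; a missing leading dimension
    is treated as 1. Two dimensions are compatible if they are equal or
    if one of them is 1. Raises ValueError otherwise.
    """
    result = []
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"cannot broadcast {a} with {b}")
    return tuple(reversed(result))


# Hypothetical sizes, chosen only for the example.
batch_size, num_channels, chunk_size = 2, 192, 16

x_chunk_shape = (batch_size, num_channels, chunk_size)
chunk_scale_shape = (num_channels, chunk_size)

# chunk_scale gains a leading dimension of 1 and is repeated over the batch.
print(broadcast_shape(x_chunk_shape, chunk_scale_shape))  # (2, 192, 16)

# If the last dimensions disagree, the multiply is invalid.
try:
    broadcast_shape(x_chunk_shape, (num_channels, 8))
except ValueError as err:
    print("broadcast failed:", err)
```

So the multiply is valid as long as chunk_scale's last dimension matches x_chunk's. A broadcasting error at this node would be consistent with the two operands disagreeing on that dimension at run time, though I have not confirmed that this is what happens in the exported graph.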

After I deleted this line, the wave files decode successfully, but the WERs on my test set differ slightly: 5.89 (PyTorch) versus 5.61 (ONNX). The PyTorch decoding script is ./zipformer/pretrained.py and the ONNX decoding script is ./zipformer/onnx_pretrained.py.

My questions are:

  1. Why does the broadcasting cause an ONNX error in non-streaming mode, but not in the streaming ONNX model?
  2. How should I change this line to avoid the error and get the same WER as the PyTorch model?

Thanks!
