54 changes: 45 additions & 9 deletions chunkformer/modules/attention.py
@@ -6,6 +6,8 @@
import torch
from torch import nn

from chunkformer.utils.common import unfold_with_loop


class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
@@ -265,6 +267,20 @@ def rel_shift(self, x, left_context_size: int = 0, right_context_size: int = 0):
storage_offset=n_stride * (time1 - 1),
)

def rel_shift_export(self, x, left_context_size: int = 0, right_context_size: int = 0):
"""Gather-based variant of `rel_shift` used when exporting to ONNX.

Builds explicit gather indexes instead of relying on `as_strided`,
which keeps the op traceable by torch.onnx.export.
"""
(batch_size, num_heads, time1, n) = x.size()
time2 = time1 + left_context_size + right_context_size

# Query step t (0-based) reads the time2 relative positions starting at
# offset (time1 - 1 - t), i.e. a band that shifts left by one per step.
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(time2)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols

x = x.reshape(-1, n)
x = torch.gather(x, dim=1, index=indexes.to(x.device))
x = x.reshape(batch_size, num_heads, time1, time2)
return x
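# Worked example (illustrative, not executed): with time1=2, left=1, right=1,
# time2 = 4 and the per-(batch, head) gather indexes are
#   [[1, 2, 3, 4],
#    [0, 1, 2, 3]]
# so each later query step reads a window shifted one position to the left,
# which is the banded selection the relative-shift trick is meant to produce.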

def forward(
self,
query: torch.Tensor,
@@ -276,6 +292,7 @@
chunk_size: int = 0,
left_context_size: int = 0,
right_context_size: int = 0,
export: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
@@ -337,7 +354,10 @@ def forward
n_frames_pad = n_frames_pad % chunk_size
q = torch.nn.functional.pad(q, (0, 0, 0, 0, 0, n_frames_pad))
# [B, n_chunks, head, d_k, q_size]
q = q.unfold(1, size=chunk_size, step=chunk_size)
if export:
q = unfold_with_loop(q, 1, chunk_size, chunk_size)
else:
q = q.unfold(1, size=chunk_size, step=chunk_size)
# [B * n_chunks, head, d_k, q_size]
q = q.reshape(-1, q.size(2), q.size(3), q.size(4))
# [B * n_chunks, q_size, head, d_k]
@@ -350,9 +370,14 @@
kv, (0, 0, left_context_size, n_frames_pad + right_context_size)
)
# [B, head, n_chunks, d_k * 2, l + c + r]
kv = kv.unfold(
2, size=left_context_size + chunk_size + right_context_size, step=chunk_size
)
if export:
kv = unfold_with_loop(
kv, 2, left_context_size + chunk_size + right_context_size, chunk_size
)
else:
kv = kv.unfold(
2, size=left_context_size + chunk_size + right_context_size, step=chunk_size
)
# [B, n_chunks, head, l + c + r, d_k * 2]
kv = kv.permute(0, 2, 1, 4, 3)
# [B * n_chunks, head, l + c + r, d_k * 2]
@@ -363,7 +388,10 @@
# [B, 1, T + n_frames_pad]
mask_q = torch.nn.functional.pad(mask, (0, n_frames_pad))
# [B, 1, n_chunks, chunk_size]
mask_q = mask_q.unfold(-1, size=chunk_size, step=chunk_size)
if export:
mask_q = unfold_with_loop(mask_q, -1, chunk_size, chunk_size)
else:
mask_q = mask_q.unfold(-1, size=chunk_size, step=chunk_size)
# [B * n_chunks, chunk_size]
mask_q = mask_q.reshape(-1, mask_q.size(-1))

@@ -372,9 +400,14 @@
mask, (left_context_size, n_frames_pad + right_context_size)
)
# [B, 1, n_chunks, chunk_size]
mask_kv = mask_kv.unfold(
-1, size=left_context_size + chunk_size + right_context_size, step=chunk_size
)
if export:
mask_kv = unfold_with_loop(
mask_kv, -1, left_context_size + chunk_size + right_context_size, chunk_size
)
else:
mask_kv = mask_kv.unfold(
-1, size=left_context_size + chunk_size + right_context_size, step=chunk_size
)
# [B * n_chunks, chunk_size]
mask_kv = mask_kv.reshape(-1, mask_kv.size(3))

@@ -405,7 +438,10 @@
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
# Apply the relative shift with left and right context included; this keeps the attention streamable
matrix_bd = self.rel_shift(matrix_bd, left_context_size, right_context_size)
if export:
matrix_bd = self.rel_shift_export(matrix_bd, left_context_size, right_context_size)
else:
matrix_bd = self.rel_shift(matrix_bd, left_context_size, right_context_size)
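# The two branches are intended to be numerically identical; rel_shift_export
# only avoids as_strided so the attention module can be traced for ONNX export.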

scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)

8 changes: 7 additions & 1 deletion chunkformer/modules/convolution.py
@@ -20,6 +20,8 @@
import torch
from torch import nn

from chunkformer.utils.common import unfold_with_loop


class ChunkConvolutionModule(nn.Module):
"""ConvolutionModule in ChunkFormer model."""
@@ -104,6 +106,7 @@ def forward(
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
cache: torch.Tensor = torch.zeros((0, 0, 0)),
chunk_size: int = 0,
export: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute convolution module.
Args:
@@ -157,7 +160,10 @@

n_chunks = ((x.size(2) - size) // step) + 1
# [B, C, n_chunks, size]
x = x.unfold(-1, size=size, step=step)
if export:
x = unfold_with_loop(x, -1, size=size, step=step)
else:
x = x.unfold(-1, size=size, step=step)
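# As in the attention module, the export branch is intended to match
# torch.Tensor.unfold exactly; it only swaps the unfold call for an
# export-friendly loop.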
# [B, n_chunks, C, size]
x = x.transpose(1, 2)
# [B * n_chunks, C, size]
4 changes: 4 additions & 0 deletions chunkformer/modules/encoder.py
@@ -217,6 +217,7 @@ def forward_encoder(
chunk_size: int = 0,
left_context_size: int = 0,
right_context_size: int = 0,
export: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed positions in tensor.

@@ -258,6 +259,7 @@
chunk_size=chunk_size,
left_context_size=left_context_size,
right_context_size=right_context_size,
export=export,
)
if self.normalize_before and self.final_norm:
xs = self.after_norm(xs)
@@ -275,6 +277,7 @@ def forward_layers(
chunk_size: int = 0,
left_context_size: int = 0,
right_context_size: int = 0,
export: bool = False,
) -> torch.Tensor:
for idx, layer in enumerate(self.encoders):
xs, chunk_masks, _, _ = layer(
@@ -285,6 +288,7 @@
chunk_size=chunk_size,
left_context_size=left_context_size,
right_context_size=right_context_size,
export=export,
)
return xs

6 changes: 5 additions & 1 deletion chunkformer/modules/encoder_layer.py
@@ -70,6 +70,7 @@ def forward(
chunk_size: int = 0,
left_context_size: int = 0,
right_context_size: int = 0,
export: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute encoded features.

@@ -118,6 +119,7 @@
chunk_size=chunk_size,
left_context_size=left_context_size,
right_context_size=right_context_size,
export=export,
)
x = residual + self.dropout(x_att)
if not self.normalize_before:
@@ -131,7 +133,9 @@
if self.normalize_before:
x = self.norm_conv(x)

x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache, chunk_size=chunk_size)
x, new_cnn_cache = self.conv_module(
x, mask_pad, cnn_cache, chunk_size=chunk_size, export=export
)
x = residual + self.dropout(x)

if not self.normalize_before:
36 changes: 35 additions & 1 deletion chunkformer/utils/common.py
@@ -16,14 +16,48 @@

import math
import time
from typing import List, Tuple
from typing import List, Tuple, Union

import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_ID = -1


def unfold_with_loop(x: torch.Tensor, dimension: int, size: int, step: int) -> torch.Tensor:
"""
Replicates the functionality of torch.Tensor.unfold using slicing and a for-loop.

Args:
x: The input tensor to unfold (e.g., shape (B, H, L, D)).
dimension: The dimension along which to unfold (e.g., 2 for L).
size: The size of the sliding window (W).
step: The step between adjacent windows (S).

Returns:
The unfolded tensor, matching torch.Tensor.unfold: the dimension of length L
is replaced by num_windows = (L - size) // step + 1, and a new dimension of
length `size` is appended as the last dimension.
"""
dimension = dimension if dimension >= 0 else dimension + x.ndim
permute_lst = list(range(len(x.shape))) + [dimension]
permute_lst.pop(dimension)
x = x.permute(permute_lst)

if step <= 0 or size <= 0:
raise ValueError("Size and step must be positive integers.")

unfolded_chunks: List[torch.Tensor] = []
for start_index in range(0, x.size(-1) - size + 1, step):
end_index = start_index + size
slices: List[Union[slice, int]] = [slice(None)] * x.ndim
slices[-1] = slice(start_index, end_index)
chunk = x[slices]
unfolded_chunks.append(chunk)

unfolded_tensor = torch.stack(unfolded_chunks, dim=dimension)
return unfolded_tensor


def pad_list(xs: List[torch.Tensor], pad_value: int):
"""Perform padding for the list of tensors.

2 changes: 1 addition & 1 deletion chunkformer/utils/mask.py
@@ -218,7 +218,7 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
[0, 0, 1, 1, 1]]
"""
batch_size = lengths.size(0)
max_len = max_len if max_len > 0 else lengths.max().item()
max_len = max_len if max_len > 0 else lengths.max()
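# NOTE: .item() is dropped above so max_len stays a tensor; presumably this keeps
# make_pad_mask traceable when the model is exported to ONNX.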
seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_length_expand = lengths.unsqueeze(-1)
19 changes: 14 additions & 5 deletions examples/asr/ctc/run.sh
@@ -42,6 +42,10 @@ decode_checkpoint=$dir/final.pt
# you may need to adjust this if you cannot reproduce the results in README.md
average_num=75
decode_modes="ctc_greedy_search"
# Specify decoding_chunk_size if it's a unified dynamic chunk trained model
# -1 for full chunk
decoding_chunk_size=64
num_decoding_left_chunks=128 # left_context_size = right_context_size

# bpemode (unigram or bpe)
nbpe=1024
@@ -51,7 +55,7 @@ bpemode=bpe
# To enable upload and set these variables:
hf_token="hf_xxxxxxxxxxxxxxxxxxxxxxxxx" # Your Hugging Face token
hf_repo_id="username/chunkformer-model" # Your repository ID

onnx=false
set -e
set -u
set -o pipefail
@@ -152,10 +156,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--src_path $dir \
--num ${average_num}
fi
# Specify decoding_chunk_size if it's a unified dynamic chunk trained model
# -1 for full chunk
decoding_chunk_size=64
num_decoding_left_chunks=128

ctc_weight=0.3
for test in $recog_set; do
@@ -266,6 +266,15 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Warning: $vocab not found for tokenizer folder."
fi

# Step 6: Export ONNX model if requested
if [ ${onnx} == true ]; then
echo "Exporting ONNX model using export_onnx.py..."
python chunkformer/tools/export_onnx.py "$inference_model_dir" \
--chunk_size ${decoding_chunk_size} \
--left_context_size ${num_decoding_left_chunks} \
--right_context_size ${num_decoding_left_chunks}
fi
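# Set onnx=true near the top of this script to trigger the export; the chunk and
# context sizes passed above reuse the decoding settings defined there.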

echo "Model setup completed. Directory structure:"
ls -la $inference_model_dir
echo ""
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -41,6 +41,10 @@ dependencies = [
"deepspeed>=0.14.0",
"librosa",
"langid",
"onnx==1.19.0",
"onnx-ir==0.1.9",
"onnxruntime==1.23.0",
"onnxscript==0.5.2",
]
keywords = [
"speech-recognition",