diff --git a/apps/realtime-asr/stream_asr.py b/apps/realtime-asr/stream_asr.py
index 91ad227..9a26db4 100644
--- a/apps/realtime-asr/stream_asr.py
+++ b/apps/realtime-asr/stream_asr.py
@@ -45,6 +45,12 @@ def __init__(self, config: StreamingConfig):
         self.offset = 0
         self.total_frames_processed = 0
         self.accumulated_text = ""  # Accumulate text across all chunks
+
+        # Transducer predictor caches
+        self.pred_cache_m = None
+        self.pred_cache_c = None
+        self.pred_input = None
+        self.reset_cache()
 
         # Audio capture - prefer PyAudio on macOS for better stability
@@ -134,6 +140,16 @@ def reset_cache(self):
         self.total_frames_processed = 0
         self.accumulated_text = ""  # Reset accumulated text
 
+        # Reset the transducer predictor cache if the model has a transducer
+        if hasattr(self.model.model, "predictor"):
+            predictor = self.model.model.predictor
+            self.pred_cache_m, self.pred_cache_c = predictor.init_state(
+                batch_size=1, method="zero", device=self.device
+            )
+            # Initialize with the blank token
+            blank_id = self.model.model.blank
+            self.pred_input = torch.tensor([blank_id]).reshape(1, 1).to(self.device)
+
     def extract_features(self, audio_chunk: np.ndarray) -> torch.Tensor:
         """Extract fbank features from audio chunk"""
         # Convert to torch tensor
@@ -188,21 +204,86 @@ def process_chunk(self, audio_chunk: np.ndarray) -> Tuple[torch.Tensor, str]:
     def decode(self, encoder_out: torch.Tensor) -> str:
         """Decode encoder output to text"""
         text: str
-        if hasattr(self.model.model, "ctc"):
+        if self.model.config.model == "asr_model":
             # CTC decoding
             ctc_probs = self.model.model.ctc.log_softmax(encoder_out)  # [B, T, vocab]
             topk = ctc_probs.argmax(dim=-1)  # [B, T]
             hyps = [hyp.tolist() for hyp in topk]
             text = str(get_output(hyps, self.model.char_dict, self.model.config.model)[0])
-        elif hasattr(self.model, "decoder"):
-            # Transducer or attention decoder
-            # Implement appropriate decoding here
-            text = "[Decoder output]"
+        elif self.model.config.model == "transducer":
+            # Transducer decoding using the streaming optimized search
+            hyps = self.decode_transducer_streaming(encoder_out)
+            text = str(get_output([hyps], self.model.char_dict, self.model.config.model)[0])
         else:
             text = "[Unknown decoder type]"
         return text
 
+    def decode_transducer_streaming(self, encoder_out: torch.Tensor, n_steps: int = 64) -> list:
+        """
+        Streaming transducer decoder based on optimized_search.
+
+        This function processes encoder output frame by frame and maintains
+        predictor state across chunks for streaming inference.
+
+        Args:
+            encoder_out: Encoder output tensor [B=1, T, E]
+            n_steps: Maximum non-blank predictions per frame
+
+        Returns:
+            List of predicted token IDs (without blanks)
+        """
+        model = self.model.model
+        blank_id = model.blank
+
+        max_len = encoder_out.size(1)
+
+        # Use the persistent predictor cache carried over from the previous chunk
+        cache_m = self.pred_cache_m
+        cache_c = self.pred_cache_c
+        pred_input = self.pred_input
+
+        # Output buffer for this chunk
+        chunk_hyps = []
+
+        # Process each frame
+        for t in range(max_len):
+            encoder_out_t = encoder_out[:, t : t + 1, :]  # [B=1, 1, E]
+
+            # Allow up to n_steps non-blank predictions per frame
+            for step in range(1, n_steps + 1):
+                # Forward through the predictor
+                pred_out_step, new_cache = model.predictor.forward_step(
+                    pred_input, (cache_m, cache_c)
+                )  # [B=1, 1, P]
+
+                # Forward through the joint network
+                joint_out_step = model.joint(encoder_out_t, pred_out_step)  # [B=1, 1, V]
+                joint_out_probs = joint_out_step.log_softmax(dim=-1)
+
+                # Get the best prediction
+                joint_out_max = joint_out_probs.argmax(dim=-1).squeeze()  # scalar
+                if joint_out_max == blank_id:
+                    # Blank prediction - move to the next frame
+                    break
+                else:
+                    # Non-blank prediction
+                    chunk_hyps.append(joint_out_max.item())
+
+                    # Update predictor input and cache for the next step
+                    pred_input = joint_out_max.reshape(1, 1)
+                    cache_m, cache_c = new_cache
+
+                # Check if we've reached the maximum number of steps per frame
+                if step >= n_steps:
+                    break
+
+        # Update the persistent cache for the next chunk
+        self.pred_cache_m = cache_m
+        self.pred_cache_c = cache_c
+        self.pred_input = pred_input
+        return chunk_hyps
+
     def run(self):
         """Main streaming loop"""
         print("\n" + "=" * 60)
diff --git a/examples/asr/rnnt/conf/chunkformer-rnnt-small-vie-stream-dct.yaml b/examples/asr/rnnt/conf/chunkformer-rnnt-small-vie-stream-dct.yaml
new file mode 100644
index 0000000..5827c6e
--- /dev/null
+++ b/examples/asr/rnnt/conf/chunkformer-rnnt-small-vie-stream-dct.yaml
@@ -0,0 +1,137 @@
+# network architecture
+# encoder related
+encoder: chunkformer
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: dw_striding  # encoder input type; you can choose conv2d, conv2d6 and conv2d8
+    normalize_before: true
+    cnn_module_kernel: 15
+    use_cnn_module: True
+    activation_type: 'swish'
+    pos_enc_layer_type: 'chunk_rel_pos'
+    selfattention_layer_type: 'chunk_rel_seflattn'
+    cnn_module_norm: 'layer_norm'  # using nn.LayerNorm makes model converge faster
+    # Enable the settings below for joint training on full and chunk context
+    dynamic_conv: true
+    dynamic_chunk_sizes: [4, 6, 8]  # 1 frame = 80 ms, i.e. chunks of [320, 480, 640] ms
+    # Note that the effective left context span depends on the number of encoder layers
+    dynamic_left_context_sizes: [40, 50, 60]
+    dynamic_right_context_sizes: [0]  # No right context for streaming
+    streaming: true
+
+joint: transducer_joint
+joint_conf:
+    enc_output_size: 256
+    pred_output_size: 256
+    join_dim: 512
+    prejoin_linear: True
+    postjoin_linear: false
+    joint_mode: 'add'
+    activation: 'tanh'
+
+predictor: rnn
+predictor_conf:
+    embed_size: 256
+    output_size: 256
+    embed_dropout: 0.1
+    hidden_size: 256
+    num_layers: 2
+    bias: true
+    rnn_type: 'lstm'
+    dropout: 0.1
+
+decoder: bitransformer
+decoder_conf:
+    attention_heads: 4
+    dropout_rate: 0.1
+    linear_units: 2048
+    num_blocks: 3
+    positional_dropout_rate: 0.1
+    r_num_blocks: 3
+    self_attention_dropout_rate: 0.1
+    src_attention_dropout_rate: 0.1
+
+tokenizer: bpe
+tokenizer_conf:
+    symbol_table_path: 'data/lang_char/train_hf_bpe1024_units.txt'
+    split_with_space: false
+    bpe_path: 'data/lang_char/train_hf_bpe1024.model'
+    non_lang_syms_path: null
+    is_multilingual: false
+    num_languages: 1
+    special_tokens:
+        <blank>: 0
+        <unk>: 1
+        <sos>: 2
+        <eos>: 2
+
+ctc: ctc
+ctc_conf:
+    ctc_blank_id: 0
+
+cmvn: global_cmvn
+cmvn_conf:
+    cmvn_file: 'data/train_hf/global_cmvn'
+    is_json_cmvn: true
+
+# hybrid transducer+ctc+attention
+model: transducer
+model_conf:
+    transducer_weight: 0.75
+    ctc_weight: 0.1
+    attention_weight: 0.15
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+    reverse_weight: 0.3
+    enable_k2: True
+
+dataset: asr
+dataset_conf:
+    filter_conf:
+        max_length: 40960
+        min_length: 0
+        token_max_length: 400
+        token_min_length: 1
+    resample_conf:
+        resample_rate: 16000
+    speed_perturb: true
+    fbank_conf:
+        num_mel_bins: 80
+        frame_shift: 10
+        frame_length: 25
+        dither: 1.0
+    spec_aug: true
+    spec_aug_conf:
+        num_t_mask: 2
+        num_f_mask: 2
+        max_t: 50
+        max_f: 10
+    shuffle: true
+    shuffle_conf:
+        shuffle_size: 1000
+    sort: False
+    sort_conf:
+        sort_size: 500  # sort_size should be less than shuffle_size
+    batch_conf:
+        batch_type: 'dynamic'  # static or dynamic
+        max_frames_in_batch: 300000
+        pad_feat: True
+
+
+grad_clip: 5
+accum_grad: 2
+max_epoch: 200
+log_interval: 100
+
+optim: adamw
+optim_conf:
+    lr: 0.001
+scheduler: warmuplr     # pytorch v1.1.0+ required
+scheduler_conf:
+    warmup_steps: 15000
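
For context, below is a minimal, self-contained sketch (not part of the patch) of the idea that decode_transducer_streaming implements: greedy transducer search in which the predictor state and the last emitted token persist across encoder chunks, so chunk-by-chunk decoding produces the same hypothesis as decoding the whole utterance at once. TinyPredictor, TinyJoint, and all dimensions here are invented stand-ins for illustration; the real predictor/joint modules, their forward_step/init_state signatures, and the blank id come from the repository's model code.

import torch
import torch.nn as nn

BLANK_ID, VOCAB, ENC_DIM, PRED_DIM = 0, 16, 8, 8


class TinyPredictor(nn.Module):
    """Stand-in predictor: embed the previous token and run one LSTM step."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, PRED_DIM)
        self.rnn = nn.LSTM(PRED_DIM, PRED_DIM, batch_first=True)

    def init_state(self, batch_size: int = 1):
        zeros = torch.zeros(1, batch_size, PRED_DIM)
        return (zeros, zeros.clone())

    def forward_step(self, token, state):
        out, new_state = self.rnn(self.embed(token), state)  # out: [1, 1, PRED_DIM]
        return out, new_state


class TinyJoint(nn.Module):
    """Stand-in joint network: add encoder and predictor frames, project to vocab."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(ENC_DIM, VOCAB)

    def forward(self, enc_t, pred_t):
        return self.proj(torch.tanh(enc_t + pred_t))  # [1, 1, VOCAB]


@torch.no_grad()
def greedy_decode_chunk(enc_chunk, predictor, joint, state, last_token, n_steps=4):
    """Greedy search over one encoder chunk, returning tokens plus the carried state."""
    hyps = []
    for t in range(enc_chunk.size(1)):
        enc_t = enc_chunk[:, t : t + 1, :]
        for _ in range(n_steps):  # cap the non-blank emissions per frame
            pred_t, new_state = predictor.forward_step(last_token, state)
            token = joint(enc_t, pred_t).log_softmax(dim=-1).argmax(dim=-1).item()
            if token == BLANK_ID:
                break  # blank: advance to the next encoder frame
            hyps.append(token)
            last_token = torch.tensor([[token]])  # feed the emission back in
            state = new_state                     # and keep the updated LSTM state
    return hyps, state, last_token


if __name__ == "__main__":
    torch.manual_seed(0)
    predictor, joint = TinyPredictor(), TinyJoint()
    full_utt = torch.randn(1, 12, ENC_DIM)  # pretend encoder output for one utterance

    # Decode the whole utterance in one call ...
    state, last = predictor.init_state(), torch.tensor([[BLANK_ID]])
    ref, _, _ = greedy_decode_chunk(full_utt, predictor, joint, state, last)

    # ... and again in 4-frame chunks, carrying state across chunk boundaries
    # the same way stream_asr.py does with pred_cache_m / pred_cache_c / pred_input.
    state, last = predictor.init_state(), torch.tensor([[BLANK_ID]])
    streamed = []
    for start in range(0, full_utt.size(1), 4):
        out, state, last = greedy_decode_chunk(
            full_utt[:, start : start + 4, :], predictor, joint, state, last
        )
        streamed.extend(out)

    assert streamed == ref  # identical because the predictor state persists
    print("chunked greedy search matches full-utterance search:", ref)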