diff --git a/recipe/README.md b/recipe/README.md
index 29fb403841a..fb006a0bc9d 100644
--- a/recipe/README.md
+++ b/recipe/README.md
@@ -24,3 +24,4 @@ The help the community reproduce experiments, verl team provides a snapshot of t
- [cognitive-behaviors](https://github.com/kanishkg/cognitive-behaviors): Cognitive Behaviors that Enable Self-Improving Reasoners, or, Four Habits of Highly Effective STaRs 
- [deepscaler](https://github.com/agentica-project/rllm/tree/deepscaler): iterative context scaling with GRPO 
- [DAPO](https://dapo-sia.github.io/): the fully open source SOTA RL algorithm that beats DeepSeek-R1-zero-32B 
+- [CosyVoice-TTS-GRPO](https://github.com/FunAudioLLM/CosyVoice/tree/main): Cosyvoice TTS GRPO fine-tuning recipe 
diff --git a/recipe/cosyvoice_tts/README.md b/recipe/cosyvoice_tts/README.md
new file mode 100644
index 00000000000..41cce1c1049
--- /dev/null
+++ b/recipe/cosyvoice_tts/README.md
@@ -0,0 +1,142 @@
+# CosyVoice2 LLM Reinforcement Learning Recipe
+
+This recipe shows how to train the **CosyVoice2** large language model with reinforcement learning algorithms such as **GRPO** using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the Seed-TTS test_zh set from 1.81% to 1.06%.
+
+We initialize the model from a Supervised Fine-Tuned (SFT) version of Qwen2-0.5B-Instruct and then continue training with reinforcement learning. Given an input sentence, the model predicts the corresponding CosyVoice2 speech tokens. For the SFT training recipe please refer to [PR #1887](https://github.com/k2-fsa/icefall/pull/1887).
+
+## Table of Contents
+
+- [Environment Setup](#environment-setup)
+- [Data Preparation](#data-preparation)
+- [Reward Function & ASR Server](#reward-function--asr-server)
+- [Training](#training)
+- [Evaluation](#evaluation)
+- [Single-Utterance Inference](#single-utterance-inference)
+- [Results](#results)
+- [Acknowledgement](#acknowledgement)
+
+## Environment Setup
+
+Stage `-1` of `run.sh` installs all required dependencies:
+
+```bash
+bash run.sh -1 -1 # run only stage -1
+```
+
+The script performs the following tasks:
+
+1. Clones and installs **veRL** (without Megatron).
+2. Checks out the **CosyVoice** source code to `/workspace/CosyVoice` and installs the Python packages from `requirements-cosyvoice.txt`.
+3. Downloads the TTS codec model `iic/CosyVoice2-0.5B` from **ModelScope** into `/workspace/CosyVoice2-0.5B`.
+4. Installs **PytritonSensevoice** together with **Pytriton**.
+5. Downloads the SFT-finetuned CosyVoice2-0.5B LLM whose vocabulary was extended on Emilia-Zh data.
+
+> [!TIP]
+> The **veRL** repository evolves quickly. To reproduce our results you can checkout this [specific commit](https://github.com/yuekaizhang/verl/tree/thread).
+
+## Data Preparation
+
+`prepare_data.py` expects a JSON/JSONL file with at least the following schema:
+
+```jsonc
+{
+ "text": "An example sentence to be synthesized."
+}
+```
+You can download the JSONL files from the metadata directory of the [SparkAudio/voxbox](https://huggingface.co/datasets/SparkAudio/voxbox/tree/main/metadata) dataset on Hugging Face.
+
+Stage `0` converts raw JSONL files into the parquet format expected by veRL:
+
+```bash
+bash run.sh 0 0
+```
+Create two JSONL files – `train.jsonl` and `test.jsonl`.
+The script will generate two parquet files:
+
+```
+data/parquet_tiny/train.parquet
+data/parquet_tiny/test.parquet
+```
+
+Each sample is automatically wrapped into a chat-style prompt with the special system token `<|SPEECH_GENERATION_START|>` so that the LLM learns to output CosyVoice2 speech tokens.
+
+> [!TIP]
+> For the `prompt_template` we recommend using the same configuration as during SFT training. See the corresponding setup [here](https://github.com/yuekaizhang/icefall/blob/emilia/egs/emilia/TTS/llasa_cosyvoice2_token/train.py#L84).
+
+## Reward Function & ASR Server
+
+To compute rewards we run a lightweight server that:
+
+1. Converts generated speech tokens back to a 16 kHz waveform with the **CosyVoice2** pretrained U-Net model.
+2. Transcribes the waveform with **SenseVoice** ASR.
+3. Calculates the pinyin-level error rate against the ground-truth text and maps it to a score in the range \[0 … 1\].
+
+Start the server (stage `1`) in a dedicated terminal / GPU:
+
+```bash
+bash run.sh 1 1
+# Triton server listens on ports 8000/8001/8002
+```
+
+The custom reward implementation lives in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score.
+
+## Training
+
+Run stage `2` to start GRPO training:
+
+```bash
+bash run.sh 2 2
+```
+
+Key CLI arguments passed to `verl.trainer.main_ppo`:
+
+* `algorithm.adv_estimator=grpo` – use GRPO instead of PPO.
+* `data.train_files=data/parquet_aishell3/train.parquet` and `data.val_files=data/parquet_aishell3/test.parquet`
+* `actor_rollout_ref.model.path=/workspace/rl/llasa_cosyvoice2_token_qwen_0.5b/checkpoint-885000` – path to the pretrained CosyVoice2 LLM.
+* `custom_reward_function.path=reward_tts.py` – custom reward function described above.
+* `trainer.total_epochs=1` – number of training epochs (adjust as needed).
+
+Tune `CUDA_VISIBLE_DEVICES`, batch sizes and other hyper-parameters according to your hardware.
+
+## Evaluation
+
+After training finishes we gather the sharded FSDP weights and export a HuggingFace-style checkpoint (stage `3`):
+
+```bash
+bash run.sh 3 3 # merges weights into $llm_path/merged_hf_model
+```
+
+We can then evaluate the model on the CosyVoice3 zero-shot Chinese test set (stage `4`):
+
+```bash
+bash run.sh 4 4
+```
+
+This command launches distributed inference via `infer_dist.py` and computes WER with `scripts/compute_wer.sh`.
+
+> [!TIP]
+> The script also supports the Seed-TTS test set by setting `dataset=test_zh`.
+
+## Single-Utterance Inference
+
+For a quick demo run stage `5`:
+
+```bash
+bash run.sh 5 5
+```
+
+The script synthesizes a tongue-twister using the merged checkpoint and prints the path of the generated audio file.
+
+## Results
+
+| Model | Seed-TTS `test_zh` CER | Cosyvoice3 `zero_shot_zh` |Comment |
+|-|------------------------------------------------------|------------------------|--------------------------------------------------------------------------------|
+| Official CosyVoice2 LLM | 1.45 % |4.08%| See the [paper](https://arxiv.org/abs/2412.10117) |
+| SFT (initialized from Qwen2-0.5B-Instruct) | 1.81 % |4.83%| See [PR #1887](https://github.com/k2-fsa/icefall/pull/1887) |
+| GRPO (this work, trained on AIShell-3) | **1.06 %** |4.03%| |
+
+## Acknowledgement
+
+This work is inspired by the implementation in
+https://github.com/channel-io/ch-tts-llasa-rl-grpo
+
diff --git a/recipe/cosyvoice_tts/assets/prompt_audio.wav b/recipe/cosyvoice_tts/assets/prompt_audio.wav
new file mode 100644
index 00000000000..a6481ea611b
Binary files /dev/null and b/recipe/cosyvoice_tts/assets/prompt_audio.wav differ
diff --git a/recipe/cosyvoice_tts/infer.py b/recipe/cosyvoice_tts/infer.py
new file mode 100644
index 00000000000..c9d9d4164ed
--- /dev/null
+++ b/recipe/cosyvoice_tts/infer.py
@@ -0,0 +1,185 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import soundfile as sf
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from cosyvoice.utils.file_utils import load_wav
+from argparse import ArgumentParser
+import sys
+
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"
+
+def get_args():
+ parser = ArgumentParser()
+
+ parser.add_argument(
+ "--token2wav-path",
+ type=str,
+ default=None,
+ help="Token2Wav path, default to %(default)r",
+ )
+ parser.add_argument(
+ "--prompt-text",
+ type=str,
+ default="Romeo and Juliet might be the most famous act of William Shakespeare.",
+ help="The prompt text",
+ )
+
+ parser.add_argument(
+ "--prompt-speech-path",
+ type=str,
+ default="./assets/common_voice_en_2586258.wav",
+ help="The path to the prompt speech",
+ )
+ parser.add_argument(
+ "--input-text",
+ type=str,
+ default='突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"',
+ help="The input text",
+ )
+ parser.add_argument(
+ "--model-path",
+ type=str,
+ default='/workspace/rl/llasa_cosyvoice2_token_qwen_0.5b/checkpoint-885000',
+ help="The path to the model",
+ )
+ args = parser.parse_args()
+ return args
+
+args = get_args()
+
+def audio_decode_cosyvoice2(
+ audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
+):
+ """
+ Generate audio from tokens with optional tone and prompt embedding.
+
+ Args:
+ audio_tokens (list): List of audio tokens to be processed.
+ model_config: Configuration object containing vocab settings.
+ codec_decoder: Codec decoder for generating audio.
+ tone_dir (str): The tone directory or setting.
+ audio_prompt_path (str, optional): Path to the audio prompt file. Required when tone_dir is not "default_tone".
+ code_layer (int, optional): Number of code layers. Defaults to 1.
+ num_latency_tokens (int, optional): Number of latency tokens to ignore. Defaults to 0.
+ speed (float, optional): Speed factor for audio generation. Defaults to 1.0.
+
+ Returns:
+ torch.Tensor: Generated audio waveform.
+ """
+ model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
+ "empty", prompt_text, prompt_speech_16k, 24000
+ )
+ tts_mel, _ = codec_decoder.model.flow.inference(
+ token=audio_tokens.to(codec_decoder.model.device),
+ token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
+ codec_decoder.model.device
+ ),
+ prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
+ codec_decoder.model.device
+ ),
+ prompt_token_len=torch.tensor(
+ [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
+ ).to(codec_decoder.model.device),
+ prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
+ codec_decoder.model.device
+ ),
+ prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
+ codec_decoder.model.device
+ ),
+ embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
+ finalize=True,
+ )
+
+ audio_hat, _ = codec_decoder.model.hift.inference(
+ speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+ )
+
+ return audio_hat
+
+def extract_speech_ids(speech_tokens_str):
+
+ speech_ids = []
+ for token_str in speech_tokens_str:
+ if token_str.startswith('<|s_') and token_str.endswith('|>'):
+ num_str = token_str[4:-2]
+
+ num = int(num_str)
+ speech_ids.append(num)
+ else:
+ print(f"Unexpected token: {token_str}")
+ return speech_ids
+
+
+
+tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+model = AutoModelForCausalLM.from_pretrained(args.model_path)
+model.eval()
+model.to('cuda')
+
+token2wav_model = CosyVoice2(
+ args.token2wav_path, load_jit=False, load_trt=False, fp16=False
+)
+
+prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
+
+with torch.no_grad():
+ # Tokenize the text
+ chat = [
+ {"role": "user", "content": f"{args.input_text}"},
+ {"role": "assistant", "content": ""}
+ ]
+ if 'system' in tokenizer.chat_template:
+ tokenizer.chat_template = TEMPLATE
+ input_ids = tokenizer.apply_chat_template(
+ chat,
+ tokenize=True,
+ return_tensors='pt',
+ continue_final_message=True
+ )
+ input_ids = input_ids.to('cuda')
+
+ # Generate the speech autoregressively
+ outputs = model.generate(
+ input_ids,
+ max_length=2048, # We trained our model with a max length of 2048
+ do_sample=True,
+ top_p=1, # Adjusts the diversity of generated content
+ temperature=0.8, # Controls randomness in output
+ )
+ # Extract the speech tokens
+ generated_ids = outputs[0][input_ids.shape[1]:-1]
+
+ speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+
+ # Convert token <|s_23456|> to int 23456
+ speech_tokens = extract_speech_ids(speech_tokens)
+
+ speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0)
+
+
+ audio_hat = audio_decode_cosyvoice2(
+ speech_tokens,
+ args.prompt_text,
+ prompt_speech_16k,
+ token2wav_model,
+ )
+
+ audio = audio_hat.squeeze(0).cpu().numpy()
+
+
+sf.write("gen.wav", audio, 24000)
diff --git a/recipe/cosyvoice_tts/infer_dataset.py b/recipe/cosyvoice_tts/infer_dataset.py
new file mode 100644
index 00000000000..40c968d27af
--- /dev/null
+++ b/recipe/cosyvoice_tts/infer_dataset.py
@@ -0,0 +1,399 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Example Usage
+dataset=zero_shot_zh
+output_dir=./outputs_rl_aishell3_step${step}_${dataset}_jit_trt_fp16_reward_tts
+
+token2wav_path=/workspace/CosyVoice2-0.5B
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+torchrun --nproc_per_node=8 \
+ infer_dataset.py \
+ --output-dir $output_dir \
+ --llm-model-name-or-path $llm_path/merged_hf_model \
+ --token2wav-path $token2wav_path \
+ --split-name ${dataset} || exit 1
+"""
+
+import argparse
+import json
+import os
+import sys
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import torchaudio
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from cosyvoice.utils.file_utils import load_wav
+from datasets import load_dataset
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from torch.utils.data import DataLoader, Dataset, DistributedSampler
+from tqdm import tqdm
+import soundfile as sf
+import s3tokenizer
+from functools import partial
+
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+try:
+ torch.multiprocessing.set_start_method("spawn")
+except RuntimeError:
+ pass
+
+
+TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"
+
+
+def audio_decode_cosyvoice2(
+ audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
+):
+ """
+ Generate audio from tokens with optional tone and prompt embedding.
+ """
+ model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
+ "empty", prompt_text, prompt_speech_16k, 24000
+ )
+ tts_mel, _ = codec_decoder.model.flow.inference(
+ token=audio_tokens.to(codec_decoder.model.device),
+ token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
+ codec_decoder.model.device
+ ),
+ prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
+ codec_decoder.model.device
+ ),
+ prompt_token_len=torch.tensor(
+ [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
+ ).to(codec_decoder.model.device),
+ prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
+ codec_decoder.model.device
+ ),
+ prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
+ codec_decoder.model.device
+ ),
+ embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
+ finalize=True,
+ )
+
+ audio_hat, _ = codec_decoder.model.hift.inference(
+ speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+ )
+
+ return audio_hat
+
+
+def extract_speech_ids(speech_tokens_str):
+ """Extract speech IDs from token strings like <|s_23456|>"""
+ speech_ids = []
+ for token_str in speech_tokens_str:
+ if token_str.startswith('<|s_') and token_str.endswith('|>'):
+ num_str = token_str[4:-2]
+ num = int(num_str)
+ speech_ids.append(num)
+ else:
+ print(f"Unexpected token: {token_str}")
+ return speech_ids
+
+def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
+ """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
+ speech_id_str = ""
+ for token in cosy2_tokens:
+ speech_id_str += f"<|s_{token}|>"
+ return speech_id_str
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
+ parser.add_argument(
+ "--split-name",
+ type=str,
+ default="wenetspeech4tts",
+ help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
+ )
+ parser.add_argument(
+ "--output-dir", required=True, type=str, help="dir to save result"
+ )
+ parser.add_argument(
+ "--batch-size",
+ default=1,
+ type=int,
+ help="batch size (per-device) for inference",
+ )
+ parser.add_argument(
+ "--num-workers", type=int, default=1, help="workers for dataloader"
+ )
+ parser.add_argument(
+ "--prefetch", type=int, default=5, help="prefetch for dataloader"
+ )
+ parser.add_argument(
+ "--llm-model-name-or-path",
+ required=True,
+ type=str,
+ help="LLM model path (includes both model and tokenizer)",
+ )
+ parser.add_argument(
+ "--token2wav-path",
+ required=True,
+ type=str,
+ help="CosyVoice2 token2wav model path",
+ )
+ parser.add_argument(
+ "--prompt-text",
+ type=str,
+ default=None,
+ help="The prompt text for CosyVoice2",
+ )
+ parser.add_argument(
+ "--prompt-speech-path",
+ type=str,
+ default=None,
+ help="The path to the prompt speech for CosyVoice2",
+ )
+ parser.add_argument(
+ "--top-p",
+ type=float,
+ default=0.95,
+ help="top p for sampling",
+ )
+ parser.add_argument(
+ "--temperature",
+ type=float,
+ default=0.8,
+ help="temperature for sampling",
+ )
+ parser.add_argument(
+ "--top-k",
+ type=int,
+ default=50,
+ help="top k for sampling",
+ )
+ args = parser.parse_args()
+ return args
+
+
+
+def data_collator(batch, tokenizer, s3_tokenizer):
+ """Simplified data collator for batch_size=1 processing"""
+ target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
+ device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
+ input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
+ mels, prompt_audio_cosy2tokens_list = [], []
+ for i, item in enumerate(batch):
+ prompt_text, target_text = (
+ item["prompt_text"],
+ item["target_text"],
+ )
+ prompt_text_list.append(prompt_text)
+ # Combine prompt and target text
+ full_text = prompt_text + target_text
+
+ # get prompt audio for CosyVoice2 (convert to 16kHz)
+ ref_audio_org, ref_sr = (
+ item["prompt_audio"]["array"],
+ item["prompt_audio"]["sampling_rate"],
+ )
+ ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
+ # ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
+ print(ref_audio_org.shape)
+
+ if ref_sr != target_sample_rate:
+ resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
+ ref_audio = resampler(ref_audio_org)
+ else:
+ ref_audio = ref_audio_org
+
+ prompt_audio_list.append(ref_audio)
+
+ if "prompt_audio_cosy2_tokens" in item:
+ prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
+ prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
+ else:
+ # convert to float first
+ mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
+
+ if len(mels) > 0:
+ mels, mels_lens = s3tokenizer.padding(mels)
+ codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
+ for i in range(len(codes)):
+ prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
+ for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
+ prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
+ # Create chat template for LLM generation
+ chat = [
+ {"role": "user", "content": full_text},
+ {"role": "assistant", "content": prompt_audio_cosy2_id_str}
+ ]
+ if 'system' in tokenizer.chat_template:
+ tokenizer.chat_template = TEMPLATE
+ input_ids = tokenizer.apply_chat_template(
+ chat,
+ tokenize=True,
+ return_tensors='pt',
+ continue_final_message=True
+ )
+ input_ids_list.append(input_ids.squeeze(0))
+
+
+ # For batch_size=1, no need to pad
+ if len(input_ids_list) == 1:
+ input_ids = input_ids_list[0].unsqueeze(0)
+ else:
+ # Handle batch > 1 if needed
+ max_len = max([len(input_ids) for input_ids in input_ids_list])
+ input_ids_list = [
+ torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
+ for input_ids in input_ids_list
+ ]
+ input_ids = torch.stack(input_ids_list)
+
+ ids = [item["id"] for item in batch]
+
+ return {
+ "input_ids": input_ids,
+ "ids": ids,
+ "prompt_text": prompt_text_list,
+ "prompt_audio_list": prompt_audio_list,
+ }
+
+
+def init_distributed():
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
+ rank = int(os.environ.get("RANK", 0))
+ print(
+ "Inference on multiple gpus, this gpu {}".format(local_rank)
+ + ", rank {}, world_size {}".format(rank, world_size)
+ )
+ torch.cuda.set_device(local_rank)
+ dist.init_process_group("nccl")
+ return world_size, local_rank, rank
+
+
+def main():
+ args = get_args()
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ assert torch.cuda.is_available()
+ world_size, local_rank, rank = init_distributed()
+ device = torch.device(f"cuda:{local_rank}")
+
+ # Load LLM model and tokenizer directly
+ tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
+ model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
+ model.eval()
+ model.to(device)
+
+ cosyvoice_codec = CosyVoice2(
+ args.token2wav_path, load_jit=True, load_trt=True, fp16=True
+ )
+ if args.prompt_speech_path:
+ prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
+ else:
+ prompt_speech_16k = None
+ s3_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").to(device) if 'zero' in args.split_name else None
+ dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
+ dataset = load_dataset(
+ dataset_name,
+ split=args.split_name,
+ trust_remote_code=True,
+ )
+
+ sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ sampler=sampler,
+ shuffle=False,
+ num_workers=args.num_workers,
+ prefetch_factor=args.prefetch,
+ collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
+ )
+
+ total_steps = len(dataset)
+
+ if rank == 0:
+ progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
+
+ for batch in dataloader:
+ with torch.no_grad():
+ input_ids = batch["input_ids"].to(device)
+
+ # Generate speech tokens using LLM
+ outputs = model.generate(
+ input_ids,
+ max_new_tokens=2048, # Max length for generation
+ do_sample=True,
+ top_p=args.top_p,
+ temperature=args.temperature,
+ top_k=args.top_k,
+ )
+
+ # Process each sample in the batch
+ for i in range(len(batch["ids"])):
+ # Extract generated tokens (excluding input)
+ input_length = input_ids[i].shape[0]
+ generated_ids = outputs[i][input_length:-1] # Remove last token if needed
+ speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+
+ # Extract speech IDs from token strings like <|s_23456|>
+ speech_ids = extract_speech_ids(speech_tokens_str)
+
+ if len(speech_ids) == 0:
+ print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
+ continue
+
+ # Convert to tensor for CosyVoice2
+ audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
+
+ if args.prompt_text is not None:
+ current_prompt_text = args.prompt_text
+ current_prompt_audio = prompt_speech_16k
+ else:
+ current_prompt_text = batch["prompt_text"][i]
+ current_prompt_audio = batch["prompt_audio_list"][i]
+
+ if current_prompt_audio is not None:
+ # Generate audio using CosyVoice2
+ audio_hat = audio_decode_cosyvoice2(
+ audio_tokens,
+ current_prompt_text,
+ current_prompt_audio,
+ cosyvoice_codec,
+ )
+
+ # Convert to numpy and save
+ generated_wave = audio_hat.squeeze(0).cpu().numpy()
+ target_sample_rate = 24000
+
+ utt = batch["ids"][i]
+ sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
+
+ print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
+ else:
+ print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
+
+
+ if rank == 0:
+ progress_bar.update(world_size * len(batch["ids"]))
+
+ if rank == 0:
+ progress_bar.close()
+
+ dist.barrier()
+ dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/recipe/cosyvoice_tts/prepare_data.py b/recipe/cosyvoice_tts/prepare_data.py
new file mode 100644
index 00000000000..e63ae47f7da
--- /dev/null
+++ b/recipe/cosyvoice_tts/prepare_data.py
@@ -0,0 +1,88 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Preprocess the Text to Speech dataset to parquet format
+"""
+
+import argparse
+import os
+import re
+
+import datasets
+
+from verl.utils.hdfs_io import copy, makedirs
+
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")
+ parser.add_argument("--test_file", required=True, help="Path to test JSON/JSONL file")
+ parser.add_argument("--local_dir", default=None, required=True)
+ parser.add_argument("--hdfs_dir", default=None)
+
+ args = parser.parse_args()
+
+ # Load datasets from local JSON files
+ train_dataset = datasets.load_dataset("json", data_files=args.train_file)['train']
+ test_dataset = datasets.load_dataset("json", data_files=args.test_file)['train']
+
+ # add a row to each data item that represents a unique id
+ def make_map_fn(split):
+ def process_fn(example, idx):
+ text = example.pop("text")
+
+ # use cosyvoice2 official huggingface compatible checkpoint template
+ question = text
+ answer = ""
+
+ data = {
+ "data_source": f"{args.train_file}_{args.test_file}", # Use file names as data source
+ "prompt": [
+ {
+ "role": "user",
+ "content": question,
+ },
+ {
+ "role": "assistant",
+ "content": answer,
+ },
+ ],
+ "ability": "text-to-speech",
+ "reward_model": {"style": "rule", "ground_truth": text},
+ "extra_info": {
+ "split": split,
+ "index": idx,
+ "text": text,
+ },
+ }
+ return data
+
+ return process_fn
+
+ train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
+ test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
+
+ local_dir = args.local_dir
+ hdfs_dir = args.hdfs_dir
+
+ print(train_dataset)
+ print(test_dataset)
+ train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
+ test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
+
+ if hdfs_dir is not None:
+ makedirs(hdfs_dir)
+
+ copy(src=local_dir, dst=hdfs_dir)
diff --git a/recipe/cosyvoice_tts/pretrained_to_huggingface.py b/recipe/cosyvoice_tts/pretrained_to_huggingface.py
new file mode 100644
index 00000000000..5034f5626ef
--- /dev/null
+++ b/recipe/cosyvoice_tts/pretrained_to_huggingface.py
@@ -0,0 +1,124 @@
+
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage: Instruct TTS
+ python3 infer.py \
+ --token2wav-path /workspace/CosyVoice2-0.5B \
+ --prompt-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
+ --prompt-speech-path ./assets/prompt_audio.wav \
+ --model-path ./transformers_cosyvoice2_llm \
+ --input-text "用四川话说<|endofprompt|>扁担长,板凳宽,扁担绑在板凳上。吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮。"
+"""
+from cosyvoice.cli.cosyvoice import CosyVoice2
+import sys
+from argparse import ArgumentParser
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+
+
+
+def get_args():
+ parser = ArgumentParser()
+
+ parser.add_argument(
+ "--pretrained-cosyvoice2-path",
+ type=str,
+ default="/workspace/CosyVoice2-0.5B",
+ help="Token2Wav path, default to %(default)r",
+ )
+ parser.add_argument(
+ "--save-path",
+ type=str,
+ default='./transformers_cosyvoice2_llm',
+ help="The path to save the model",
+ )
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == "__main__":
+ args = get_args()
+ cosy2_model = CosyVoice2(
+ args.pretrained_cosyvoice2_path, load_jit=False, load_trt=False, fp16=False
+ )
+
+ llm = cosy2_model.model.llm.llm.model
+
+ speech_embedding = cosy2_model.model.llm.speech_embedding
+ llm_decoder = cosy2_model.model.llm.llm_decoder
+ llm_embedding = cosy2_model.model.llm.llm_embedding
+
+ tokenizer = AutoTokenizer.from_pretrained(f"{args.pretrained_cosyvoice2_path}/CosyVoice-BlankEN")
+ special_tokens = {
+ 'eos_token': '<|endoftext|>',
+ 'pad_token': '<|endoftext|>',
+ 'additional_special_tokens': [
+ '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+ '[breath]', '', '', '[noise]',
+ '[laughter]', '[cough]', '[clucking]', '[accent]',
+ '[quick_breath]',
+ "", "",
+ "[hissing]", "[sigh]", "[vocalized-noise]",
+ "[lipsmack]", "[mn]"
+ ]
+ }
+ tokenizer.add_special_tokens(special_tokens)
+
+ original_tokenizer_vocab_size = len(tokenizer)
+ cosyvoice2_token_size = 6561
+ new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
+ "<|eos1|>", "<|eos2|>", "<|eos3|>", "<|sos|>", "<|task_id|>"
+ ]
+ num_added_tokens = tokenizer.add_tokens(new_tokens)
+
+ llm.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=128)
+ vocab_size = llm.get_input_embeddings().weight.shape[0]
+
+ feature_size = speech_embedding.embedding_dim
+ new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=vocab_size, bias=True)
+
+ with torch.no_grad():
+ # set the weight and bias of the new lm_head to 0
+ new_lm_head.weight.data.zero_()
+ new_lm_head.bias.data.zero_()
+ new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.weight
+ new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.bias
+
+ llm.lm_head = new_lm_head
+ input_embeddings = llm.get_input_embeddings()
+
+ with torch.no_grad():
+ input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = speech_embedding.weight
+ input_embeddings.weight[original_tokenizer_vocab_size+cosyvoice2_token_size+3:original_tokenizer_vocab_size+cosyvoice2_token_size+3+2] = llm_embedding.weight
+
+ eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size, original_tokenizer_vocab_size + cosyvoice2_token_size + 1, original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
+ llm.generation_config.eos_token_id = eos_token_ids
+ llm.generation_config.temperature = 1.0
+ llm.generation_config.top_p = 0.8
+ llm.generation_config.top_k = 25
+
+ llm.config.eos_token_id = original_tokenizer_vocab_size + cosyvoice2_token_size
+ llm.config.vocab_size = vocab_size
+ llm.config.tie_word_embeddings = False
+ llm.config.use_bias = True
+ llm.to(torch.bfloat16)
+ llm.save_pretrained(args.save_path)
+
+ TEMPLATE = "{%- for message in messages %}{%- if message['role'] == 'user' %}{{- '<|sos|>' + message['content'] + '<|task_id|>' }}{%- elif message['role'] == 'assistant' %}{{- message['content']}}{%- endif %}{%- endfor %}"
+ tokenizer.chat_template = TEMPLATE
+ tokenizer.save_pretrained(args.save_path)
\ No newline at end of file
diff --git a/recipe/cosyvoice_tts/requirements-cosyvoice.txt b/recipe/cosyvoice_tts/requirements-cosyvoice.txt
new file mode 100644
index 00000000000..73a449dd26a
--- /dev/null
+++ b/recipe/cosyvoice_tts/requirements-cosyvoice.txt
@@ -0,0 +1,29 @@
+conformer==0.3.2
+diffusers==0.29.0
+gdown==5.1.0
+gradio
+hydra-core==1.3.2
+HyperPyYAML==1.2.2
+inflect==7.3.1
+librosa==0.10.2
+lightning==2.2.4
+matplotlib==3.7.5
+modelscope==1.15.0
+networkx==3.1
+omegaconf==2.3.0
+onnx==1.16.0
+onnxruntime-gpu==1.18.0
+protobuf==4.25
+pydantic==2.7.0
+pyworld==0.3.4
+rich==13.7.1
+soundfile==0.12.1
+tensorboard==2.14.0
+wget==3.2
+WeTextProcessing==1.0.3
+s3tokenizer
+tensorrt
+sherpa_onnx
+jiwer
+zhon
+numpy==1.25.2
\ No newline at end of file
diff --git a/recipe/cosyvoice_tts/reward_tts.py b/recipe/cosyvoice_tts/reward_tts.py
new file mode 100644
index 00000000000..f49dc6d0dce
--- /dev/null
+++ b/recipe/cosyvoice_tts/reward_tts.py
@@ -0,0 +1,230 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Reward calculation for CosyVoice2-0.5B.
+"""
+
+from __future__ import annotations
+
+import os, re, warnings, json, time, argparse
+from typing import List
+
+import numpy as np
+import requests
+
+
+REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
+
+
+def _parse_ids(token_str: str) -> List[int]:
+ return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
+
+def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
+ """Send token IDs and ground-truth text to the Triton server and get reward."""
+
+ tokens_arr = np.array(tokens, dtype=np.int32).reshape(1, -1)
+ lens_arr = np.array([[tokens_arr.shape[1]]], dtype=np.int32)
+
+ gt_arr = np.array([ground_truth.encode("utf-8")], dtype=object)
+
+ payload = {
+ "inputs": [
+ {
+ "name": "TOKENS",
+ "shape": list(tokens_arr.shape),
+ "datatype": "INT32",
+ "data": tokens_arr.tolist(),
+ },
+ {
+ "name": "TOKEN_LENS",
+ "shape": list(lens_arr.shape),
+ "datatype": "INT32",
+ "data": lens_arr.tolist(),
+ },
+ {
+ "name": "GT_TEXT",
+ "shape": [1, 1],
+ "datatype": "BYTES",
+ "data": [ground_truth],
+ },
+ ]
+ }
+ rsp = requests.post(
+ REWARD_SERVER_URL,
+ headers={"Content-Type": "application/json"},
+ json=payload,
+ timeout=timeout,
+ verify=False,
+ params={"request_id": "0"},
+ )
+ rsp.raise_for_status()
+ result = rsp.json()
+
+ try:
+ # Reward is returned as the first output
+ return float(result["outputs"][0]["data"][0])
+ except (KeyError, IndexError, TypeError):
+ return 0.0
+
+
+def compute_score(
+ data_source: str,
+ solution_str: str,
+ ground_truth: str,
+ extra_info: dict | None = None,
+ *,
+ debug_dump: bool = False,
+) -> float:
+ """Return reward in [0, 1] using the Triton ASR service.
+
+ The reward is based on the pinyin-level WER between the ASR transcript
+ produced from *solution_str* and the provided *ground_truth* text.
+ """
+
+ # Decode token IDs
+ ids = _parse_ids(solution_str)
+
+ # Query remote server for reward
+ try:
+ reward = _remote_reward(ids, ground_truth)
+ except Exception as e:
+ warnings.warn(f"Remote reward server error: {e}; returning 0.0")
+ reward = 0.0
+
+ if debug_dump:
+ print(
+ f"\033[92m[{data_source}] Remote reward: {reward:.4f}\033[0m"
+ )
+
+ return reward
+
+# CLI quick test
+if __name__ == "__main__":
+ import sys
+
+ def get_args():
+ """Parse command line arguments."""
+ parser = argparse.ArgumentParser(
+ description="Test TTS CER scoring with data from JSONL file",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--input", "-i",
+ type=str,
+ default="data/emilia_zh-cosy-tiny-test.jsonl",
+ help="Path to input JSONL file"
+ )
+
+ parser.add_argument(
+ "--max-samples", "-n",
+ type=int,
+ default=None,
+ help="Maximum number of samples to process (default: all)"
+ )
+
+ parser.add_argument(
+ "--no-interactive",
+ action="store_true",
+ help="Run in non-interactive mode (process all samples without prompts)"
+ )
+
+
+ parser.add_argument(
+ "--debug",
+ action="store_true",
+ help="Enable debug mode"
+ )
+
+ return parser.parse_args()
+
+ def load_jsonl(file_path: str):
+ """Load data from jsonl file."""
+ data = []
+ with open(file_path, 'r', encoding='utf-8') as f:
+ for line in f:
+ data.append(json.loads(line.strip()))
+ return data
+
+ def code_to_solution_str(code_list: List[int]) -> str:
+ """Convert code list to solution string format."""
+ return ''.join([f"<|s_{code}|>" for code in code_list])
+
+ # Parse command line arguments
+ args = get_args()
+
+ try:
+ # Load data from jsonl file
+ print(f"Loading data from: {args.input}")
+ data_list = load_jsonl(args.input)
+ print(f"Loaded {len(data_list)} samples")
+
+ # Limit samples if specified
+ if args.max_samples is not None:
+ data_list = data_list[:args.max_samples]
+ print(f"Processing first {len(data_list)} samples (limited by --max-samples)")
+
+ # Process each sample
+ begin_time = time.time()
+ for i, sample in enumerate(data_list):
+ print(f"\n--- Sample {i+1}/{len(data_list)} ---")
+ print(f"Index: {sample.get('index', 'unknown')}")
+ print(f"Text: {sample['text']}")
+
+ # Extract required fields
+ code_list = sample['code']
+ ground_truth = sample['text']
+ data_source = sample.get('index', f'sample_{i}') # Use index as data_source
+
+ # Convert code list to solution string
+ solution_str = code_to_solution_str(code_list)
+ print(f"Solution tokens: {len(code_list)} tokens")
+ if args.debug:
+ print(f"Solution string: {solution_str}")
+ else:
+ print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}")
+
+ # Call compute_score function
+ try:
+ score = compute_score(
+ data_source=data_source,
+ solution_str=solution_str,
+ ground_truth=ground_truth,
+ extra_info=None,
+ debug_dump=args.debug
+ )
+ print(f"Final Score: {score:.4f}")
+ except Exception as e:
+ print(f"Error computing score: {e}")
+
+ # Ask user if they want to continue (for interactive mode)
+ if not args.no_interactive and i < len(data_list) - 1:
+ try:
+ response = input("\nPress Enter to continue or 'q' to quit: ").strip().lower()
+ if response == 'q':
+ break
+ except KeyboardInterrupt:
+ print("\nStopped by user")
+ break
+
+ print(f"\nProcessed {min(i+1, len(data_list))} samples")
+ end_time = time.time()
+ print(f"Time taken: {end_time - begin_time} seconds")
+ except FileNotFoundError:
+ print(f"Error: File not found - {args.input}")
+ print("Please check the file path or use --input to specify correct path")
+ print("Run with --help for usage information")
+ except Exception as e:
+ print(f"Error: {e}")
diff --git a/recipe/cosyvoice_tts/run.sh b/recipe/cosyvoice_tts/run.sh
new file mode 100644
index 00000000000..27ed6206fe8
--- /dev/null
+++ b/recipe/cosyvoice_tts/run.sh
@@ -0,0 +1,170 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=$1
+stop_stage=$2
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+
+export PYTHONPATH=/workspace/CosyVoice
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+ log "stage -1: install vllm and CosyVoice"
+ # install verl
+ git clone https://github.com/volcengine/verl.git
+ cd verl
+ USE_MEGATRON=0 USE_SGLANG=0 bash scripts/install_vllm_sglang_mcore.sh
+ pip install -r requirements.txt
+ pip install --no-deps -e .
+
+ # install CosyVoice
+ git clone https://github.com/FunAudioLLM/CosyVoice.git /workspace/CosyVoice
+ pip install -r ./requirements-cosyvoice.txt
+
+ # download CosyVoice2-0.5B for token2wav
+ modelscope download --model iic/CosyVoice2-0.5B --local-dir /workspace/CosyVoice2-0.5B
+
+ # install PytritonSenseVoice
+ git clone https://github.com/yuekaizhang/PytritonSenseVoice.git /workspace/PytritonSenseVoice
+ cd /workspace/PytritonSenseVoice
+ pip install -e .
+
+ # install Pytriton
+ pip install -U nvidia-pytriton
+
+ # download custom CosyVoice2-0.5B LLM
+ huggingface-cli download --local-dir /workspace/llasa_cosyvoice2_token_qwen_0.5b yuekai/llasa_cosyvoice2_token_qwen_0.5b
+
+ # download official CosyVoice2-0.5B LLM
+ # First, obtained the huggingface compatible checkpoint. You could directly download the checkpoint from yuekai/cosyvoice2_llm
+ huggingface-cli download --local-dir ./transformers_cosyvoice2_llm yuekai/cosyvoice2_llm
+ # Or, you could use the following command to convert the pretrained model to huggingface compatible checkpoint
+ # python3 pretrained_to_huggingface.py \
+ # --pretrained-cosyvoice2-path /workspace/CosyVoice2-0.5B \
+ # --save-path ./transformers_cosyvoice2_llm
+ # If you would like to use the official CosyVoice2-0.5B LLM and do RL training, please see run_official.sh
+fi
+
+data_dir=data/parquet_aishell3
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "stage 0: prepare data into verl format"
+ # You could download the aishell3 data from https://huggingface.co/datasets/SparkAudio/voxbox/blob/main/metadata/aishell-3.jsonl
+ mkdir -p $data_dir
+ tail -n 80000 data/aishell-3.jsonl > data/train.jsonl
+ tail -n 100 data/aishell-3.jsonl > data/test.jsonl
+ python prepare_data.py \
+ --train_file data/train.jsonl \
+ --test_file data/test.jsonl \
+ --local_dir $data_dir
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "stage 1: start token2wav asr server for reward function"
+ python3 token2wav_asr_server.py --number-of-devices 8
+
+ # log "Test the reward server"
+ # python3 reward_tts.py \
+ # --input data/emilia_zh-cosy-tiny-test.jsonl \
+ # --no-interactive --debug
+
+ # async test the reward server
+ # python3 token2wav_asr_client.py
+fi
+
+sft_model_path=/workspace/rl/llasa_cosyvoice2_token_qwen_0.5b/checkpoint-885000
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "stage 2: grpo train"
+ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+ export MKL_SERVICE_FORCE_INTEL=TRUE
+ n_gpus_per_node=8
+ micro_batch_size=4
+ train_batch_size=32
+ python3 -m verl.trainer.main_ppo \
+ algorithm.adv_estimator=grpo \
+ data.train_files=$data_dir/train.parquet \
+ data.val_files=$data_dir/test.parquet \
+ data.train_batch_size=$train_batch_size \
+ data.max_prompt_length=1024 \
+ data.max_response_length=1024 \
+ data.truncation='error' \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.model.path=$sft_model_path \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.actor.ppo_mini_batch_size=16 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_batch_size \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_batch_size \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+ actor_rollout_ref.rollout.do_sample=true \
+ actor_rollout_ref.rollout.temperature=0.8 \
+ actor_rollout_ref.rollout.top_p=0.9 \
+ actor_rollout_ref.rollout.n=4 \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=true \
+ actor_rollout_ref.rollout.val_kwargs.temperature=0.8 \
+ actor_rollout_ref.rollout.val_kwargs.top_p=0.9 \
+ reward_model.reward_manager=prime \
+ custom_reward_function.path=reward_tts.py \
+ custom_reward_function.name=compute_score \
+ trainer.project_name='llasa_tts_grpo' \
+ trainer.experiment_name='aishell3' \
+ trainer.logger=['console','wandb'] \
+ trainer.n_gpus_per_node=$n_gpus_per_node \
+ trainer.nnodes=1 \
+ trainer.save_freq=100 \
+ trainer.test_freq=100 \
+ trainer.resume_mode='auto' \
+ trainer.total_epochs=1 \
+ trainer.val_before_train=False
+fi
+
+step=1600
+llm_path=./checkpoints/llasa_tts_grpo/aishell3/global_step_${step}
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "stage 3: merge the model"
+ python -m verl.model_merger merge \
+ --backend fsdp \
+ --local_dir $llm_path/actor \
+ --target_dir $llm_path/merged_hf_model || exit 1
+
+fi
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "stage 4: Test the model"
+ datasets=(zero_shot_zh test_zh)
+ for dataset in ${datasets[@]}; do
+ output_dir=./outputs_rl_emilia_zh_step${step}_${dataset}
+
+ token2wav_path=/workspace/CosyVoice2-0.5B
+ model_path=$llm_path/merged_hf_model
+
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+ torchrun --nproc_per_node=8 \
+ infer_dataset.py \
+ --output-dir $output_dir \
+ --llm-model-name-or-path $model_path \
+ --token2wav-path $token2wav_path \
+ --split-name ${dataset} || exit 1
+ bash scripts/compute_wer.sh $output_dir ${dataset}
+ done
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "stage 5: Infer with single case"
+ python3 infer.py \
+ --token2wav-path /workspace/CosyVoice2-0.5B \
+ --prompt-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
+ --prompt-speech-path ./assets/prompt_audio.wav \
+ --model-path $llm_path/merged_hf_model \
+ --input-text "扁担长,板凳宽,扁担绑在板凳上。吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮。"
+fi
diff --git a/recipe/cosyvoice_tts/run_official.sh b/recipe/cosyvoice_tts/run_official.sh
new file mode 100644
index 00000000000..d7f23acff74
--- /dev/null
+++ b/recipe/cosyvoice_tts/run_official.sh
@@ -0,0 +1,120 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=$1
+stop_stage=$2
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+
+export PYTHONPATH=/workspace/CosyVoice
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "stage 0: prepare data into verl format"
+ # yuekai/llasa_cosyvoice2_token_qwen_0.5b is the emilia zh trained model, please set use_custom_template=True to use the custom template
+ # yuekai/cosyvoice2_llm is the official cosyvoice2 llm model, please set use_custom_template=False to use the official template
+ python prepare_data.py \
+ --train_file data/aishell-3-cosy.jsonl \
+ --test_file data/emilia_test.jsonl \
+ --local_dir data/parquet_aishell3
+fi
+
+sft_model_path=./transformers_cosyvoice2_llm
+exp_name=official_llm_aishell3_reward_tts_prime
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "stage 2: grpo train"
+ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+ export MKL_SERVICE_FORCE_INTEL=TRUE
+ n_gpus_per_node=8
+ micro_batch_size=4
+ train_batch_size=32
+ python3 -m verl.trainer.main_ppo \
+ algorithm.adv_estimator=grpo \
+ data.train_files=data/parquet_aishell3/train.parquet \
+ data.val_files=data/parquet_aishell3/test.parquet \
+ data.train_batch_size=$train_batch_size \
+ data.max_prompt_length=1024 \
+ data.max_response_length=512 \
+ data.truncation='error' \
+ actor_rollout_ref.model.use_remove_padding=False \
+ actor_rollout_ref.model.path=$sft_model_path \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.actor.ppo_mini_batch_size=32 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_batch_size \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_batch_size \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+ actor_rollout_ref.rollout.do_sample=true \
+ actor_rollout_ref.rollout.temperature=0.8 \
+ actor_rollout_ref.rollout.top_p=0.95 \
+ actor_rollout_ref.rollout.top_k=25 \
+ actor_rollout_ref.rollout.n=4 \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=true \
+ actor_rollout_ref.rollout.val_kwargs.temperature=0.8 \
+ actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
+ actor_rollout_ref.rollout.val_kwargs.top_k=25 \
+ reward_model.reward_manager=prime \
+ custom_reward_function.path=reward_tts.py \
+ custom_reward_function.name=compute_score \
+ trainer.project_name='llasa_tts_grpo' \
+ trainer.experiment_name=$exp_name \
+ trainer.logger=['console','wandb'] \
+ trainer.n_gpus_per_node=$n_gpus_per_node \
+ trainer.nnodes=1 \
+ trainer.save_freq=100 \
+ trainer.test_freq=100 \
+ trainer.resume_mode='auto' \
+ trainer.total_epochs=1 \
+ trainer.val_before_train=False
+fi
+
+step=100
+llm_path=./checkpoints/llasa_tts_grpo/$exp_name/global_step_${step}
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "stage 3: merge the model"
+ python -m verl.model_merger merge \
+ --backend fsdp \
+ --local_dir $llm_path/actor \
+ --target_dir $llm_path/merged_hf_model || exit 1
+
+fi
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "stage 4: Test the model"
+ dataset=zero_shot_zh
+ # dataset=test_zh
+ output_dir=./outputs_${exp_name}_${step}_${dataset}
+
+ token2wav_path=/workspace/CosyVoice2-0.5B
+ model_path=$llm_path/merged_hf_model
+
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+ torchrun --nproc_per_node=8 \
+ infer_dataset.py \
+ --output-dir $output_dir \
+ --llm-model-name-or-path $model_path \
+ --token2wav-path $token2wav_path \
+ --split-name ${dataset} || exit 1
+
+ bash scripts/compute_wer.sh $output_dir ${dataset}
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "stage 5: Infer with single case"
+ python3 infer.py \
+ --token2wav-path /workspace/CosyVoice2-0.5B \
+ --prompt-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
+ --prompt-speech-path ./assets/prompt_audio.wav \
+ --model-path ./transformers_cosyvoice2_llm \
+ --input-text "用四川话说<|endofprompt|>扁担长,板凳宽,扁担绑在板凳上。吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮。"
+fi
\ No newline at end of file
diff --git a/recipe/cosyvoice_tts/scripts/compute_wer.sh b/recipe/cosyvoice_tts/scripts/compute_wer.sh
new file mode 100644
index 00000000000..55ae1a73780
--- /dev/null
+++ b/recipe/cosyvoice_tts/scripts/compute_wer.sh
@@ -0,0 +1,32 @@
+wav_dir=$1
+wav_files=$(ls $wav_dir/*.wav)
+# if wav_files is empty, then exit
+if [ -z "$wav_files" ]; then
+ exit 1
+fi
+split_name=$2
+model_path=models/sherpa-onnx-paraformer-zh-2023-09-14
+
+if [ ! -d $model_path ]; then
+ pip install sherpa-onnx
+ wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
+ tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C models
+fi
+
+python3 scripts/offline-decode-files.py \
+ --tokens=$model_path/tokens.txt \
+ --paraformer=$model_path/model.int8.onnx \
+ --num-threads=2 \
+ --decoding-method=greedy_search \
+ --debug=false \
+ --sample-rate=24000 \
+ --log-dir $wav_dir \
+ --feature-dim=80 \
+ --split-name $split_name \
+ --name sherpa_onnx \
+ $wav_files
+
+# python3 scripts/paraformer-pytriton-client.py \
+# --log-dir $wav_dir \
+# --split-name $split_name \
+# $wav_files
\ No newline at end of file
diff --git a/recipe/cosyvoice_tts/scripts/offline-decode-files.py b/recipe/cosyvoice_tts/scripts/offline-decode-files.py
new file mode 100644
index 00000000000..35fc03da1e5
--- /dev/null
+++ b/recipe/cosyvoice_tts/scripts/offline-decode-files.py
@@ -0,0 +1,753 @@
+#!/usr/bin/env python3
+#
+# Copyright (c) 2023 by manyeyes
+# Copyright (c) 2023 Xiaomi Corporation
+
+"""
+This file demonstrates how to use sherpa-onnx Python API to transcribe
+file(s) with a non-streaming model.
+
+(1) For paraformer
+
+ ./python-api-examples/offline-decode-files.py \
+ --tokens=/path/to/tokens.txt \
+ --paraformer=/path/to/paraformer.onnx \
+ --num-threads=2 \
+ --decoding-method=greedy_search \
+ --debug=false \
+ --sample-rate=16000 \
+ --feature-dim=80 \
+ /path/to/0.wav \
+ /path/to/1.wav
+
+(2) For transducer models from icefall
+
+ ./python-api-examples/offline-decode-files.py \
+ --tokens=/path/to/tokens.txt \
+ --encoder=/path/to/encoder.onnx \
+ --decoder=/path/to/decoder.onnx \
+ --joiner=/path/to/joiner.onnx \
+ --num-threads=2 \
+ --decoding-method=greedy_search \
+ --debug=false \
+ --sample-rate=16000 \
+ --feature-dim=80 \
+ /path/to/0.wav \
+ /path/to/1.wav
+
+(3) For CTC models from NeMo
+
+python3 ./python-api-examples/offline-decode-files.py \
+ --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
+ --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
+ --num-threads=2 \
+ --decoding-method=greedy_search \
+ --debug=false \
+ ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
+ ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
+ ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
+
+(4) For Whisper models
+
+python3 ./python-api-examples/offline-decode-files.py \
+ --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
+ --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
+ --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
+ --whisper-task=transcribe \
+ --num-threads=1 \
+ ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
+ ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
+ ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
+
+(5) For CTC models from WeNet
+
+python3 ./python-api-examples/offline-decode-files.py \
+ --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
+ --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
+ ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
+ ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
+ ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
+
+(6) For tdnn models of the yesno recipe from icefall
+
+python3 ./python-api-examples/offline-decode-files.py \
+ --sample-rate=8000 \
+ --feature-dim=23 \
+ --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
+ --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
+ ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
+ ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
+ ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
+
+Please refer to
+https://k2-fsa.github.io/sherpa/onnx/index.html
+to install sherpa-onnx and to download non-streaming pre-trained models
+used in this file.
+"""
+import argparse
+import time
+import wave
+from pathlib import Path
+from typing import List, Tuple, Dict, Iterable, TextIO, Union
+
+import numpy as np
+import sherpa_onnx
+import soundfile as sf
+from datasets import load_dataset
+import logging
+from collections import defaultdict
+import kaldialign
+from zhon.hanzi import punctuation
+import string
+punctuation_all = punctuation + string.punctuation
+Pathlike = Union[str, Path]
+
+def remove_punctuation(text: str) -> str:
+ for x in punctuation_all:
+ if x == '\'':
+ continue
+ text = text.replace(x, '')
+ return text
+
+def store_transcripts(
+ filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
+) -> None:
+ """Save predicted results and reference transcripts to a file.
+
+ Args:
+ filename:
+ File to save the results to.
+ texts:
+ An iterable of tuples. The first element is the cur_id, the second is
+ the reference transcript and the third element is the predicted result.
+ If it is a multi-talker ASR system, the ref and hyp may also be lists of
+ strings.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf8") as f:
+ for cut_id, ref, hyp in texts:
+ if char_level:
+ ref = list("".join(ref))
+ hyp = list("".join(hyp))
+ print(f"{cut_id}:\tref={ref}", file=f)
+ print(f"{cut_id}:\thyp={hyp}", file=f)
+
+
+def write_error_stats(
+ f: TextIO,
+ test_set_name: str,
+ results: List[Tuple[str, str]],
+ enable_log: bool = True,
+ compute_CER: bool = False,
+ sclite_mode: bool = False,
+) -> float:
+ """Write statistics based on predicted results and reference transcripts.
+
+ It will write the following to the given file:
+
+ - WER
+ - number of insertions, deletions, substitutions, corrects and total
+ reference words. For example::
+
+ Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
+ reference words (2337 correct)
+
+ - The difference between the reference transcript and predicted result.
+ An instance is given below::
+
+ THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
+
+ The above example shows that the reference word is `EDISON`,
+ but it is predicted to `ADDISON` (a substitution error).
+
+ Another example is::
+
+ FOR THE FIRST DAY (SIR->*) I THINK
+
+ The reference word `SIR` is missing in the predicted
+ results (a deletion error).
+ results:
+ An iterable of tuples. The first element is the cut_id, the second is
+ the reference transcript and the third element is the predicted result.
+ enable_log:
+ If True, also print detailed WER to the console.
+ Otherwise, it is written only to the given file.
+ Returns:
+ Return None.
+ """
+ subs: Dict[Tuple[str, str], int] = defaultdict(int)
+ ins: Dict[str, int] = defaultdict(int)
+ dels: Dict[str, int] = defaultdict(int)
+
+ # `words` stores counts per word, as follows:
+ # corr, ref_sub, hyp_sub, ins, dels
+ words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
+ num_corr = 0
+ ERR = "*"
+
+ if compute_CER:
+ for i, res in enumerate(results):
+ cut_id, ref, hyp = res
+ ref = list("".join(ref))
+ hyp = list("".join(hyp))
+ results[i] = (cut_id, ref, hyp)
+
+ for cut_id, ref, hyp in results:
+ ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
+ for ref_word, hyp_word in ali:
+ if ref_word == ERR:
+ ins[hyp_word] += 1
+ words[hyp_word][3] += 1
+ elif hyp_word == ERR:
+ dels[ref_word] += 1
+ words[ref_word][4] += 1
+ elif hyp_word != ref_word:
+ subs[(ref_word, hyp_word)] += 1
+ words[ref_word][1] += 1
+ words[hyp_word][2] += 1
+ else:
+ words[ref_word][0] += 1
+ num_corr += 1
+ ref_len = sum([len(r) for _, r, _ in results])
+ sub_errs = sum(subs.values())
+ ins_errs = sum(ins.values())
+ del_errs = sum(dels.values())
+ tot_errs = sub_errs + ins_errs + del_errs
+ tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
+
+ if enable_log:
+ logging.info(
+ f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
+ f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
+ f"{del_errs} del, {sub_errs} sub ]"
+ )
+
+ print(f"%WER = {tot_err_rate}", file=f)
+ print(
+ f"Errors: {ins_errs} insertions, {del_errs} deletions, "
+ f"{sub_errs} substitutions, over {ref_len} reference "
+ f"words ({num_corr} correct)",
+ file=f,
+ )
+ print(
+ "Search below for sections starting with PER-UTT DETAILS:, "
+ "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
+ file=f,
+ )
+
+ print("", file=f)
+ print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
+ for cut_id, ref, hyp in results:
+ ali = kaldialign.align(ref, hyp, ERR)
+ combine_successive_errors = True
+ if combine_successive_errors:
+ ali = [[[x], [y]] for x, y in ali]
+ for i in range(len(ali) - 1):
+ if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
+ ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
+ ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
+ ali[i] = [[], []]
+ ali = [
+ [
+ list(filter(lambda a: a != ERR, x)),
+ list(filter(lambda a: a != ERR, y)),
+ ]
+ for x, y in ali
+ ]
+ ali = list(filter(lambda x: x != [[], []], ali))
+ ali = [
+ [
+ ERR if x == [] else " ".join(x),
+ ERR if y == [] else " ".join(y),
+ ]
+ for x, y in ali
+ ]
+
+ print(
+ f"{cut_id}:\t"
+ + " ".join(
+ (
+ ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
+ for ref_word, hyp_word in ali
+ )
+ ),
+ file=f,
+ )
+
+ print("", file=f)
+ print("SUBSTITUTIONS: count ref -> hyp", file=f)
+
+ for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+ print(f"{count} {ref} -> {hyp}", file=f)
+
+ print("", file=f)
+ print("DELETIONS: count ref", file=f)
+ for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
+ print(f"{count} {ref}", file=f)
+
+ print("", file=f)
+ print("INSERTIONS: count hyp", file=f)
+ for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
+ print(f"{count} {hyp}", file=f)
+
+ print("", file=f)
+ print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
+ for _, word, counts in sorted(
+ [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
+ ):
+ (corr, ref_sub, hyp_sub, ins, dels) = counts
+ tot_errs = ref_sub + hyp_sub + ins + dels
+ ref_count = corr + ref_sub + dels
+ hyp_count = corr + hyp_sub + ins
+
+ print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
+ return float(tot_err_rate)
+
+def get_args():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ help="Path to tokens.txt",
+ )
+
+ parser.add_argument(
+ "--hotwords-file",
+ type=str,
+ default="",
+ help="""
+ The file containing hotwords, one words/phrases per line, like
+ HELLO WORLD
+ 你好世界
+ """,
+ )
+
+ parser.add_argument(
+ "--hotwords-score",
+ type=float,
+ default=1.5,
+ help="""
+ The hotword score of each token for biasing word/phrase. Used only if
+ --hotwords-file is given.
+ """,
+ )
+
+ parser.add_argument(
+ "--modeling-unit",
+ type=str,
+ default="",
+ help="""
+ The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
+ Used only when hotwords-file is given.
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-vocab",
+ type=str,
+ default="",
+ help="""
+ The path to the bpe vocabulary, the bpe vocabulary is generated by
+ sentencepiece, you can also export the bpe vocabulary through a bpe model
+ by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
+ and modeling-unit is bpe or cjkchar+bpe.
+ """,
+ )
+
+ parser.add_argument(
+ "--encoder",
+ default="",
+ type=str,
+ help="Path to the encoder model",
+ )
+
+ parser.add_argument(
+ "--decoder",
+ default="",
+ type=str,
+ help="Path to the decoder model",
+ )
+
+ parser.add_argument(
+ "--joiner",
+ default="",
+ type=str,
+ help="Path to the joiner model",
+ )
+
+ parser.add_argument(
+ "--paraformer",
+ default="",
+ type=str,
+ help="Path to the model.onnx from Paraformer",
+ )
+
+ parser.add_argument(
+ "--nemo-ctc",
+ default="",
+ type=str,
+ help="Path to the model.onnx from NeMo CTC",
+ )
+
+ parser.add_argument(
+ "--wenet-ctc",
+ default="",
+ type=str,
+ help="Path to the model.onnx from WeNet CTC",
+ )
+
+ parser.add_argument(
+ "--tdnn-model",
+ default="",
+ type=str,
+ help="Path to the model.onnx for the tdnn model of the yesno recipe",
+ )
+
+ parser.add_argument(
+ "--num-threads",
+ type=int,
+ default=1,
+ help="Number of threads for neural network computation",
+ )
+
+ parser.add_argument(
+ "--whisper-encoder",
+ default="",
+ type=str,
+ help="Path to whisper encoder model",
+ )
+
+ parser.add_argument(
+ "--whisper-decoder",
+ default="",
+ type=str,
+ help="Path to whisper decoder model",
+ )
+
+ parser.add_argument(
+ "--whisper-language",
+ default="",
+ type=str,
+ help="""It specifies the spoken language in the input audio file.
+ Example values: en, fr, de, zh, jp.
+ Available languages for multilingual models can be found at
+ https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
+ If not specified, we infer the language from the input audio file.
+ """,
+ )
+
+ parser.add_argument(
+ "--whisper-task",
+ default="transcribe",
+ choices=["transcribe", "translate"],
+ type=str,
+ help="""For multilingual models, if you specify translate, the output
+ will be in English.
+ """,
+ )
+
+ parser.add_argument(
+ "--whisper-tail-paddings",
+ default=-1,
+ type=int,
+ help="""Number of tail padding frames.
+ We have removed the 30-second constraint from whisper, so you need to
+ choose the amount of tail padding frames by yourself.
+ Use -1 to use a default value for tail padding.
+ """,
+ )
+
+ parser.add_argument(
+ "--blank-penalty",
+ type=float,
+ default=0.0,
+ help="""
+ The penalty applied on blank symbol during decoding.
+ Note: It is a positive value that would be applied to logits like
+ this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+ [batch_size, vocab] and blank id is 0).
+ """,
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="Valid values are greedy_search and modified_beam_search",
+ )
+ parser.add_argument(
+ "--debug",
+ type=bool,
+ default=False,
+ help="True to show debug messages",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="""Sample rate of the feature extractor. Must match the one
+ expected by the model. Note: The input sound files can have a
+ different sample rate from this argument.""",
+ )
+
+ parser.add_argument(
+ "--feature-dim",
+ type=int,
+ default=80,
+ help="Feature dimension. Must match the one expected by the model",
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to decode. Each file must be of WAVE"
+ "format with a single channel, and each sample has 16-bit, "
+ "i.e., int16_t. "
+ "The sample rate of the file can be arbitrary and does not need to "
+ "be 16 kHz",
+ )
+
+ parser.add_argument(
+ "--name",
+ type=str,
+ default="",
+ help="The directory containing the input sound files to decode",
+ )
+
+ parser.add_argument(
+ "--log-dir",
+ type=str,
+ default="",
+ help="The directory containing the input sound files to decode",
+ )
+
+ parser.add_argument(
+ "--label",
+ type=str,
+ default=None,
+ help="wav_base_name label",
+ )
+
+ # Dataset related arguments for loading labels when label file is not provided
+ parser.add_argument(
+ "--dataset-name",
+ type=str,
+ default="yuekai/seed_tts_cosy2",
+ help="Huggingface dataset name for loading labels",
+ )
+
+ parser.add_argument(
+ "--split-name",
+ type=str,
+ default="wenetspeech4tts",
+ help="Dataset split name for loading labels",
+ )
+
+ return parser.parse_args()
+
+
+def assert_file_exists(filename: str):
+ assert Path(filename).is_file(), (
+ f"{filename} does not exist!\n"
+ "Please refer to "
+ "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
+ )
+
+
+def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+ """
+ Args:
+ wave_filename:
+ Path to a wave file. It should be single channel and can be of type
+ 32-bit floating point PCM. Its sample rate does not need to be 24kHz.
+
+ Returns:
+ Return a tuple containing:
+ - A 1-D array of dtype np.float32 containing the samples,
+ which are normalized to the range [-1, 1].
+ - Sample rate of the wave file.
+ """
+
+ samples, sample_rate = sf.read(wave_filename, dtype="float32")
+ assert (
+ samples.ndim == 1
+ ), f"Expected single channel, but got {samples.ndim} channels."
+
+ samples_float32 = samples.astype(np.float32)
+
+ return samples_float32, sample_rate
+
+
+def normalize_text_alimeeting(text: str) -> str:
+ """
+ Text normalization similar to M2MeT challenge baseline.
+ See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
+ """
+ import re
+ text = text.replace('\u00A0', '') # test_hard
+ text = text.replace(" ", "")
+ text = text.replace("", "")
+ text = text.replace("<%>", "")
+ text = text.replace("<->", "")
+ text = text.replace("<$>", "")
+ text = text.replace("<#>", "")
+ text = text.replace("<_>", "")
+ text = text.replace("", "")
+ text = text.replace("`", "")
+ text = text.replace("&", "")
+ text = text.replace(",", "")
+ if re.search("[a-zA-Z]", text):
+ text = text.upper()
+ text = text.replace("A", "A")
+ text = text.replace("a", "A")
+ text = text.replace("b", "B")
+ text = text.replace("c", "C")
+ text = text.replace("k", "K")
+ text = text.replace("t", "T")
+ text = text.replace(",", "")
+ text = text.replace("丶", "")
+ text = text.replace("。", "")
+ text = text.replace("、", "")
+ text = text.replace("?", "")
+ text = remove_punctuation(text)
+ return text
+
+
+def main():
+ args = get_args()
+ assert_file_exists(args.tokens)
+ assert args.num_threads > 0, args.num_threads
+
+ assert len(args.nemo_ctc) == 0, args.nemo_ctc
+ assert len(args.wenet_ctc) == 0, args.wenet_ctc
+ assert len(args.whisper_encoder) == 0, args.whisper_encoder
+ assert len(args.whisper_decoder) == 0, args.whisper_decoder
+ assert len(args.tdnn_model) == 0, args.tdnn_model
+
+ assert_file_exists(args.paraformer)
+
+ recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
+ paraformer=args.paraformer,
+ tokens=args.tokens,
+ num_threads=args.num_threads,
+ sample_rate=args.sample_rate,
+ feature_dim=args.feature_dim,
+ decoding_method=args.decoding_method,
+ debug=args.debug,
+ )
+
+ print("Started!")
+ start_time = time.time()
+
+ streams, results = [], []
+ total_duration = 0
+
+ for i, wave_filename in enumerate(args.sound_files):
+ assert_file_exists(wave_filename)
+ samples, sample_rate = read_wave(wave_filename)
+ duration = len(samples) / sample_rate
+ total_duration += duration
+ s = recognizer.create_stream()
+ s.accept_waveform(sample_rate, samples)
+
+ streams.append(s)
+ if i % 10 == 0:
+ recognizer.decode_streams(streams)
+ results += [s.result.text for s in streams]
+ streams = []
+ print(f"Processed {i} files")
+ # process the last batch
+ if streams:
+ recognizer.decode_streams(streams)
+ results += [s.result.text for s in streams]
+ end_time = time.time()
+ print("Done!")
+
+ results_dict = {}
+ for wave_filename, result in zip(args.sound_files, results):
+ print(f"{wave_filename}\n{result}")
+ print("-" * 10)
+ wave_basename = Path(wave_filename).stem
+ results_dict[wave_basename] = result
+
+ elapsed_seconds = end_time - start_time
+ rtf = elapsed_seconds / total_duration
+ print(f"num_threads: {args.num_threads}")
+ print(f"decoding_method: {args.decoding_method}")
+ print(f"Wave duration: {total_duration:.3f} s")
+ print(f"Elapsed time: {elapsed_seconds:.3f} s")
+ print(
+ f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
+ )
+
+ # Load labels either from file or from dataset
+ labels_dict = {}
+
+ if args.label:
+ # Load labels from file (original functionality)
+ print(f"Loading labels from file: {args.label}")
+ with open(args.label, "r") as f:
+ for line in f:
+ # fields = line.strip().split(" ")
+ # fields = [item for item in fields if item]
+ # assert len(fields) == 4
+ # prompt_text, prompt_audio, text, audio_path = fields
+
+ fields = line.strip().split("|")
+ fields = [item for item in fields if item]
+ assert len(fields) == 4
+ audio_path, prompt_text, prompt_audio, text = fields
+ labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
+ else:
+ # Load labels from dataset (new functionality)
+ print(f"Loading labels from dataset: {args.dataset_name}, split: {args.split_name}")
+ if 'zero' in args.split_name:
+ dataset_name = "yuekai/CV3-Eval"
+ else:
+ dataset_name = "yuekai/seed_tts_cosy2"
+ dataset = load_dataset(
+ dataset_name,
+ split=args.split_name,
+ trust_remote_code=True,
+ )
+
+ for item in dataset:
+ audio_id = item["id"]
+ labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
+
+ print(f"Loaded {len(labels_dict)} labels from dataset")
+
+ # Perform evaluation if labels are available
+ if labels_dict:
+
+ final_results = []
+ for key, value in results_dict.items():
+ if key in labels_dict:
+ final_results.append((key, labels_dict[key], value))
+ else:
+ print(f"Warning: No label found for {key}, skipping...")
+
+ if final_results:
+ store_transcripts(
+ filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
+ )
+ with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
+ write_error_stats(f, "test-set", final_results, enable_log=True)
+
+ with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
+ print(f.readline()) # WER
+ print(f.readline()) # Detailed errors
+ else:
+ print("No matching labels found for evaluation")
+ else:
+ print("No labels available for evaluation")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/recipe/cosyvoice_tts/token2wav_asr_client.py b/recipe/cosyvoice_tts/token2wav_asr_client.py
new file mode 100644
index 00000000000..9a8033746de
--- /dev/null
+++ b/recipe/cosyvoice_tts/token2wav_asr_client.py
@@ -0,0 +1,164 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import requests
+import soundfile as sf
+import json
+import numpy as np
+import argparse
+import time
+import asyncio
+import aiohttp
+
+def get_args():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--server-url",
+ type=str,
+ default="localhost:8000",
+ help="Address of the server",
+ )
+
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="token2wav_asr",
+ choices=[
+ "token2wav_asr"
+ ],
+ help="triton model_repo module name to request",
+ )
+
+ parser.add_argument(
+ "--concurrent-job",
+ type=int,
+ default=10,
+ help="Number of concurrent requests to send in parallel",
+ )
+
+ parser.add_argument(
+ "--data-path",
+ type=str,
+ default="./data/emilia_zh-cosy-tiny-test.jsonl",
+ help="Path to the data file",
+ )
+ return parser.parse_args()
+
+def prepare_request(tokens, token_lens, gt_text):
+ """Construct HTTP/JSON inference request body."""
+
+ data = {
+ "inputs": [
+ {
+ "name": "TOKENS",
+ "shape": list(tokens.shape),
+ "datatype": "INT32",
+ "data": tokens.tolist(),
+ },
+ {
+ "name": "TOKEN_LENS",
+ "shape": list(token_lens.shape),
+ "datatype": "INT32",
+ "data": token_lens.tolist(),
+ },
+ {
+ "name": "GT_TEXT",
+ "shape": [1, 1],
+ "datatype": "BYTES",
+ "data": [gt_text],
+ },
+ ]
+ }
+
+ return data
+
+def load_jsonl(file_path: str):
+ """Load data from jsonl file."""
+ data = []
+ with open(file_path, 'r', encoding='utf-8') as f:
+ for line in f:
+ data.append(json.loads(line.strip()))
+ return data
+
+
+async def process_sample(idx, total, sample, session, url, semaphore):
+ """Send a single request to the inference server and log the response."""
+ async with semaphore:
+ # Prepare request body
+ code_list = sample["code"]
+ tokens = np.array(code_list, dtype=np.int32).reshape(1, -1)
+ token_lens = np.array([[len(tokens[0])]], dtype=np.int32)
+ gt_text = sample["text"]
+ data = prepare_request(tokens, token_lens, gt_text)
+
+ # Send HTTP POST
+ async with session.post(
+ url,
+ headers={"Content-Type": "application/json"},
+ json=data,
+ params={"request_id": "0"},
+ ) as rsp:
+ result = await rsp.json()
+
+ # Parse outputs (order: REWARDS, TRANSCRIPTS)
+ rewards = None
+ transcripts = None
+ for out in result.get("outputs", []):
+ if out["name"] == "REWARDS":
+ rewards = out["data"][0]
+ elif out["name"] == "TRANSCRIPTS":
+ transcripts = out["data"][0]
+
+ # Output summary (prints may interleave across tasks)
+ print(f"\n--- Sample {idx}/{total} ---")
+ print(f"GT Text: {gt_text}")
+ print(f"Tokens shape: {tokens.shape}, Token_lens shape: {token_lens.shape}")
+ print(f"Transcript: {transcripts}")
+ print(f"Reward: {rewards}")
+
+
+async def main_async():
+ args = get_args()
+
+ server_url = args.server_url
+ if not server_url.startswith(("http://", "https://")):
+ server_url = f"http://{server_url}"
+
+ url = f"{server_url}/v2/models/{args.model_name}/infer"
+
+ # Load dataset
+ data_list = load_jsonl(args.data_path)
+
+ # Concurrency primitives
+ semaphore = asyncio.Semaphore(max(1, args.concurrent_job))
+ connector = aiohttp.TCPConnector(ssl=False)
+
+ start_time = time.time()
+ async with aiohttp.ClientSession(connector=connector) as session:
+ tasks = [
+ asyncio.create_task(
+ process_sample(i + 1, len(data_list), sample, session, url, semaphore)
+ )
+ for i, sample in enumerate(data_list)
+ ]
+ await asyncio.gather(*tasks)
+ end_time = time.time()
+ print(f"Time taken: {end_time - start_time} seconds")
+
+
+if __name__ == "__main__":
+ asyncio.run(main_async())
\ No newline at end of file
diff --git a/recipe/cosyvoice_tts/token2wav_asr_server.py b/recipe/cosyvoice_tts/token2wav_asr_server.py
new file mode 100644
index 00000000000..1273c188ce0
--- /dev/null
+++ b/recipe/cosyvoice_tts/token2wav_asr_server.py
@@ -0,0 +1,348 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Pytriton server for token2wav conversion and ASR"""
+
+import argparse
+import io
+import logging
+from typing import Any, List
+import numpy as np
+import torch
+from scipy.signal import resample
+import sys
+import random
+import re
+from jiwer import wer
+from pypinyin import lazy_pinyin, Style
+from tn.chinese.normalizer import Normalizer as ZhNormalizer
+
+# Chinese text normalizer (cached globally)
+zh_tn_model = ZhNormalizer(
+ cache_dir="./cache",
+ remove_erhua=False,
+ remove_interjections=False,
+ remove_puncts=True,
+ overwrite_cache=True,
+)
+
+from pytriton.decorators import batch
+from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
+from pytriton.triton import Triton, TritonConfig
+from pytriton.proxy.types import Request
+
+from omnisense.models import OmniSenseVoiceSmall
+from cosyvoice.cli.cosyvoice import CosyVoice2
+
+from datasets import load_dataset
+
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+
+logger = logging.getLogger("token2wav_asr_server")
+
+
+class _ASR_Server:
+ """Wraps a single OmniSenseVoiceSmall model instance for Triton."""
+
+ def __init__(self, device_id: int):
+ self._model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
+
+ @batch
+ def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray):
+ """
+ WAV: np.ndarray, WAV_LENS: np.ndarray
+ LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
+ See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
+ """
+ logger.debug("WAV: %s, WAV_LENS: %s, shapes: %s %s", type(WAV), type(WAV_LENS), WAV.shape, WAV_LENS.shape)
+ wavs = [WAV[i, :WAV_LENS[i, 0]] for i in range(len(WAV))]
+
+ results = self._model.transcribe_single_batch(
+ wavs,
+ language="zh",
+ textnorm="woitn",
+ )
+ texts = [result.text for result in results]
+ transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
+ return {"TRANSCRIPTS": transcripts}
+
+
+
+def audio_decode_cosyvoice2(
+ audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
+):
+ """
+ Generate audio from tokens with optional tone and prompt embedding.
+ """
+ model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
+ "empty", prompt_text, prompt_speech_16k, 24000
+ )
+ tts_mel, _ = codec_decoder.model.flow.inference(
+ token=audio_tokens.to(codec_decoder.model.device),
+ token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
+ codec_decoder.model.device
+ ),
+ prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
+ codec_decoder.model.device
+ ),
+ prompt_token_len=torch.tensor(
+ [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
+ ).to(codec_decoder.model.device),
+ prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
+ codec_decoder.model.device
+ ),
+ prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
+ codec_decoder.model.device
+ ),
+ embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
+ finalize=True,
+ )
+
+ audio_hat, _ = codec_decoder.model.hift.inference(
+ speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+ )
+
+ return audio_hat
+
+
+def get_random_prompt_from_dataset(dataset):
+ """
+ Get random prompt text and speech from the pre-loaded dataset.
+ Returns (prompt_text, prompt_speech_16k)
+ """
+ random_idx = random.randint(0, len(dataset) - 1)
+ sample = dataset[random_idx]
+
+ # Extract audio data
+ audio_data = sample["audio"]
+ audio_array = audio_data["array"]
+ sample_rate = audio_data["sampling_rate"]
+
+ # Convert audio to 16kHz if needed
+ if sample_rate != 16000:
+ num_samples = int(len(audio_array) * (16000 / sample_rate))
+ audio_array = resample(audio_array, num_samples)
+
+ # Convert to torch tensor
+ prompt_speech_16k = torch.from_numpy(audio_array).float().unsqueeze(0)
+ prompt_text = sample["text"]
+ # remove space in prompt_text
+ prompt_text = prompt_text.replace(" ", "")
+ return prompt_text, prompt_speech_16k
+
+class _Token2Wav_ASR:
+ """Wraps a single OmniSenseVoiceSmall model instance for Triton."""
+
+ def __init__(self, device_id: int):
+ self.asr_model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
+ self.dataset = load_dataset("yuekai/aishell", "test", trust_remote_code=True)["test"]
+
+ # Make sure the CosyVoice2 decoder lives on the same GPU as the ASR model
+ # CosyVoice2 internally uses generic "cuda" device, so we first switch the
+ # current CUDA context to the desired card before the object is created.
+ # Afterwards, all parameters loaded with the generic "cuda" device will
+ # reside on this GPU. We keep the selected id in `self.device_id` and
+ # will set the context again for every forward call to avoid race
+ # conditions when several instances are used in the same process.
+
+ self.device_id = device_id
+
+ # Construct the TTS codec decoder under the correct CUDA device context
+ with torch.cuda.device(self.device_id):
+ self.codec_decoder = CosyVoice2(
+ "/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
+ )
+ @batch
+ def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
+ """
+ WAV: np.ndarray, WAV_LENS: np.ndarray
+ LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
+ See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
+ """
+ # Ensure the default CUDA device is set correctly for this invocation
+ torch.cuda.set_device(self.device_id)
+
+ if self.device_id == 0:
+ print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}")
+
+ tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))]
+
+ # Decode ground-truth text strings (BYTES → str)
+ if GT_TEXT.ndim == 2:
+ gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))]
+ else:
+ gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))]
+
+ wavs = []
+ for tokens in tokens_list:
+ prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset)
+ audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0)
+ audio_hat = audio_decode_cosyvoice2(
+ audio_tokens,
+ prompt_text,
+ prompt_speech_16k,
+ self.codec_decoder,
+ )
+ # resample to 16000 using soundfile
+ audio_hat = audio_hat.squeeze(0).float().cpu()
+ audio_hat = audio_hat.numpy()
+ num_samples = int(len(audio_hat) * (16000 / 24000))
+ audio_hat = resample(audio_hat, num_samples)
+ wavs.append(audio_hat)
+
+ results = self.asr_model.transcribe_single_batch(
+ wavs,
+ language="zh",
+ textnorm="woitn",
+ )
+ texts = [result.text for result in results]
+
+ # ---------------- Reward computation ----------------
+ rewards = []
+ for gt_text, hyp_text in zip(gt_texts, texts):
+ gt_norm = zh_tn_model.normalize(gt_text).lower()
+ hyp_norm = zh_tn_model.normalize(hyp_text).lower()
+
+ gt_pinyin = lazy_pinyin(
+ gt_norm,
+ style=Style.TONE3,
+ tone_sandhi=True,
+ neutral_tone_with_five=True,
+ )
+ hyp_pinyin = lazy_pinyin(
+ hyp_norm,
+ style=Style.TONE3,
+ tone_sandhi=True,
+ neutral_tone_with_five=True,
+ )
+
+ c = float(wer(" ".join(gt_pinyin), " ".join(hyp_pinyin)))
+ reward_val = 1.0 - np.tanh(3.0 * c)
+ reward_val = max(0.0, min(1.0, reward_val))
+ rewards.append(reward_val)
+ print(f"gt_text: {gt_text}, hyp_text: {hyp_text}, reward_val: {reward_val}")
+
+ transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
+ rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
+
+
+ return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}
+
+
+def _infer_function_factory(device_ids: List[int], model_name: str):
+ """Creates a list of inference functions, one for each requested device ID."""
+ infer_funcs = []
+ for device_id in device_ids:
+ if model_name == "sensevoice":
+ infer_funcs.append(_ASR_Server(device_id=device_id))
+ else:
+ infer_funcs.append(_Token2Wav_ASR(device_id=device_id))
+ return infer_funcs
+
+
+def main():
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "--max-batch-size",
+ type=int,
+ default=32,
+ help="Batch size of request.",
+ required=False,
+ )
+ parser.add_argument(
+ "--verbose",
+ action="store_true",
+ default=False,
+ )
+ parser.add_argument(
+ "--number-of-instances-per-device",
+ type=int,
+ default=1,
+ help="Number of model instances to load.",
+ required=False,
+ )
+ parser.add_argument(
+ "--number-of-devices",
+ type=int,
+ default=8,
+ help="Number of devices to use.",
+ )
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="token2wav_asr",
+ choices=["token2wav_asr", "sensevoice"],
+ help="Model name.",
+ )
+
+ args = parser.parse_args()
+
+ log_level = logging.DEBUG if args.verbose else logging.INFO
+ logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
+
+ triton_config = TritonConfig(
+ http_port=8000,
+ grpc_port=8001,
+ metrics_port=8002,
+ )
+
+ device_ids = [i for i in range(args.number_of_devices)]
+ device_ids = device_ids * args.number_of_instances_per_device
+
+ with Triton(config=triton_config) as triton:
+ logger.info("Loading SenseVoice model on device ids: %s", device_ids)
+ if args.model_name == "sensevoice":
+ triton.bind(
+ model_name="sensevoice",
+ infer_func=_infer_function_factory(device_ids, args.model_name),
+ inputs=[
+ Tensor(name="WAV", dtype=np.float32, shape=(-1,)),
+ Tensor(name="WAV_LENS", dtype=np.int32, shape=(-1,)),
+ Tensor(name="LANGUAGE", dtype=np.int32, shape=(-1,)),
+ Tensor(name="TEXT_NORM", dtype=np.int32, shape=(-1,)),
+ ],
+ outputs=[
+ Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
+ ],
+ config=ModelConfig(
+ max_batch_size=args.max_batch_size,
+ batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
+ ),
+ strict=True,
+ )
+ else:
+ triton.bind(
+ model_name="token2wav_asr",
+ infer_func=_infer_function_factory(device_ids, args.model_name),
+ inputs=[
+ Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)),
+ Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(-1,)),
+ Tensor(name="GT_TEXT", dtype=bytes, shape=(-1,)),
+ ],
+ outputs=[
+ Tensor(name="REWARDS", dtype=np.float32, shape=(-1,)),
+ Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
+ ],
+ config=ModelConfig(
+ max_batch_size=args.max_batch_size,
+ batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
+ ),
+ strict=True,
+ )
+ logger.info("Serving inference")
+ triton.serve()
+
+
+if __name__ == "__main__":
+ main()