diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c7a30f9..67d7a1f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,7 +26,7 @@ repos:
rev: 6.0.0
hooks:
- id: flake8
- args: ['--max-line-length=100', '--extend-ignore=E203,W503,B008,C416,EXE001,E741']
+ args: ['--max-line-length=100', '--extend-ignore=E203,W503,B008,C416,EXE001,E741,E731']
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
diff --git a/README.md b/README.md
index d44552f..94d9fce 100644
--- a/README.md
+++ b/README.md
@@ -96,8 +96,29 @@ feature, feature_len = model.encode(
print("feature: ", feature.shape)
print("feature_len: ", feature_len)
```
+### Python API
+#### Classification
+
+ChunkFormer also supports speech classification tasks (e.g., gender, dialect, emotion, age recognition).
-### Python API Transcription
+```python
+from chunkformer import ChunkFormerModel
+
+# Load a pre-trained classification model from Hugging Face or local directory
+model = ChunkFormerModel.from_pretrained("path/to/classification/model")
+
+# Single audio classification
+result = model.classify_audio(
+ audio_path="path/to/audio.wav",
+ chunk_size=-1, # -1 for full attention
+ left_context_size=-1,
+ right_context_size=-1,
+)
+
+print(result)
+```
+
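+The result is a dictionary keyed by task name. Label names are resolved from the model's `label_mapping.json` when it is available; otherwise the numeric label ID is used. An illustrative result (actual tasks and labels depend on the model):
+
+```python
+{
+    "gender": {"label": "female", "label_id": 0, "prob": 0.95},
+    "emotion": {"label": "neutral", "label_id": 5, "prob": 0.80},
+}
+```
+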
+#### Transcription
```python
from chunkformer import ChunkFormerModel
@@ -130,13 +151,13 @@ for i, transcription in enumerate(transcriptions):
```
-### Command Line Transcription
-#### Long-Form Audio Testing
+### Command Line
+#### Long-Form Audio Transcription
To test the model with a single [long-form audio file](samples/audios/audio_1.wav). Audio file extensions ".mp3", ".wav", ".flac", ".m4a", ".aac" are accepted:
```bash
chunkformer-decode \
--model_checkpoint path/to/hf/checkpoint/repo \
- --long_form_audio path/to/audio.wav \
+ --audio_file path/to/audio.wav \
--total_batch_duration 14400 \
--chunk_size 64 \
--left_context_size 128 \
@@ -148,7 +169,7 @@ Example Output:
[00:00:02.500] - [00:00:03.700]: testing the long-form audio
```
-#### Batch Transcription Testing
+#### Batch Audio Transcription
The [data.tsv](samples/data.tsv) file must have at least one column named **wav**. Optionally, a column named **txt** can be included to compute the **Word Error Rate (WER)**. Output will be saved to the same file.
```bash
@@ -165,6 +186,14 @@ Example Output:
WER: 0.1234
```
+#### Classification
+To classify a single audio file:
+```bash
+chunkformer-decode \
+ --model_checkpoint path/to/classification/model \
+ --audio_file path/to/audio.wav
+```
+
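+Example Output (the tasks and labels shown are illustrative; they depend on the loaded model):
+```
+Classification Results for: path/to/audio.wav
+======================================================================
+Gender:
+  Label: female
+  Label ID: 0
+  Probability: 0.9500
+
+======================================================================
+```
+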
---
diff --git a/chunkformer/bin/classify.py b/chunkformer/bin/classify.py
new file mode 100755
index 0000000..71c30bf
--- /dev/null
+++ b/chunkformer/bin/classify.py
@@ -0,0 +1,185 @@
+#!/usr/bin/env python3
+# Copyright (c) 2024 ChunkFormer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Inference script for speech classification tasks."""
+
+import argparse
+import json
+import logging
+import os
+from contextlib import nullcontext
+
+import torch
+import yaml
+from tqdm import tqdm
+
+from chunkformer.dataset.dataset import Dataset
+from chunkformer.utils.checkpoint import load_checkpoint
+from chunkformer.utils.init_model import init_speech_model
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="Classification inference")
+ parser.add_argument("--gpu", type=int, default=0, help="GPU id, -1 for CPU")
+ parser.add_argument("--config", required=True, help="Config file")
+ parser.add_argument("--data_type", default="raw", choices=["raw", "shard"], help="Data type")
+ parser.add_argument("--test_data", required=True, help="Test data list file")
+ parser.add_argument("--checkpoint", required=True, help="Model checkpoint")
+ parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
+ parser.add_argument("--result_dir", required=True, help="Result directory")
+ parser.add_argument("--chunk_size", type=int, default=-1, help="Chunk size for encoder")
+ parser.add_argument("--left_context_size", type=int, default=-1, help="Left context size")
+ parser.add_argument("--right_context_size", type=int, default=-1, help="Right context size")
+ parser.add_argument(
+ "--dtype", default="fp32", choices=["fp16", "fp32"], help="Data type for inference"
+ )
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
+
+ # Setup device
+ if args.gpu < 0:
+ device = torch.device("cpu")
+ else:
+ device = torch.device(f"cuda:{args.gpu}")
+
+ # Load config
+ with open(args.config, "r") as f:
+ configs = yaml.load(f, Loader=yaml.FullLoader)
+
+ # Get tasks
+ tasks = list(configs.get("model_conf", {}).get("tasks", {}).keys())
+ if not tasks:
+ logging.error("No tasks defined in config")
+ return
+
+ logging.info(f"Tasks: {tasks}")
+
+ # Initialize model
+ model, _ = init_speech_model(args, configs)
+ load_checkpoint(model, args.checkpoint)
+ model = model.to(device)
+ model.eval()
+
+ logging.info(f"Model loaded from {args.checkpoint}")
+
+ # Setup dataset
+ dataset_conf = configs.get("dataset_conf", {})
+ dataset_conf["shuffle"] = False
+ dataset_conf["sort"] = False
+ dataset_conf["batch_size"] = args.batch_size
+ dataset_conf["batch_type"] = "static"
+
+ test_dataset = Dataset(
+ args.data_type,
+ args.test_data,
+ tokenizer=None,
+ conf=dataset_conf,
+ partition=False,
+ )
+
+ # Dataset already handles batching internally via padding function
+ test_data_loader = torch.utils.data.DataLoader(
+ test_dataset,
+ batch_size=None,
+ num_workers=0,
+ )
+
+ # Create output directory
+ os.makedirs(args.result_dir, exist_ok=True)
+
+ # Output files
+ pred_file = os.path.join(args.result_dir, "predictions.tsv")
+ detail_file = os.path.join(args.result_dir, "predictions_detail.jsonl")
+
+ # Setup dtype
+ if args.dtype == "fp16":
+ dtype = torch.float16
+ autocast_context = torch.cuda.amp.autocast
+ else:
+ dtype = torch.float32
+ autocast_context = nullcontext
+
+ # Run inference
+ logging.info("Starting inference...")
+ all_predictions = []
+
+ with torch.no_grad(), autocast_context():
+ for batch_idx, batch in enumerate(tqdm(test_data_loader)):
+ # Get keys from batch
+ keys = batch.get(
+ "keys", [f"utt_{batch_idx}_{i}" for i in range(batch["feats"].size(0))]
+ )
+
+ # Move to device
+ feats = batch["feats"].to(device, dtype=dtype)
+ feats_lengths = batch["feats_lengths"].to(device)
+
+ # Forward pass
+ results = model.classify(
+ feats,
+ feats_lengths,
+ chunk_size=args.chunk_size,
+ left_context_size=args.left_context_size,
+ right_context_size=args.right_context_size,
+ )
+
+ # Process results
+ batch_size = feats.size(0)
+ for i in range(batch_size):
+ key = keys[i] if i < len(keys) else f"utt_{batch_idx}_{i}"
+
+ pred_dict = {"key": key}
+
+ for task in tasks:
+ pred_key = f"{task}_prediction"
+ prob_key = f"{task}_probability"
+
+ prediction = results[pred_key][i].item()
+ pred_dict[task] = prediction
+
+ probability = results[prob_key][i].item()
+ pred_dict[f"{task}_prob"] = probability
+
+ all_predictions.append(pred_dict)
+
+ # Save predictions in TSV format
+ logging.info(f"Saving predictions to {pred_file}")
+ with open(pred_file, "w", encoding="utf-8") as f:
+ # Write header
+ header = ["key"] + tasks
+ f.write("\t".join(header) + "\n")
+
+ # Write predictions
+ for pred in all_predictions:
+ row = [pred["key"]] + [str(pred.get(task, "-1")) for task in tasks]
+ f.write("\t".join(row) + "\n")
+
+ # Save detailed predictions in JSONL format
+ logging.info(f"Saving detailed predictions to {detail_file}")
+ with open(detail_file, "w", encoding="utf-8") as f:
+ for pred in all_predictions:
+ f.write(json.dumps(pred, ensure_ascii=False) + "\n")
+
+ logging.info("Inference complete!")
+ logging.info(f"Total samples: {len(all_predictions)}")
+ logging.info(f"Results saved to: {args.result_dir}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/chunkformer/bin/recognize.py b/chunkformer/bin/recognize.py
index 56b20b8..aed3039 100644
--- a/chunkformer/bin/recognize.py
+++ b/chunkformer/bin/recognize.py
@@ -278,7 +278,6 @@ def main():
else:
feats = batch["feats"]
feats_lengths = batch["feats_lengths"].to(device)
- infos = {"tasks": batch["tasks"], "langs": batch["langs"]}
results = model.decode(
args.modes,
feats,
@@ -294,7 +293,6 @@ def main():
blank_id=blank_id,
blank_penalty=args.blank_penalty,
length_penalty=args.length_penalty,
- infos=infos,
)
for i, key in enumerate(keys):
for mode, hyps in results.items():
diff --git a/chunkformer/bin/train.py b/chunkformer/bin/train.py
index edacf5e..d769777 100644
--- a/chunkformer/bin/train.py
+++ b/chunkformer/bin/train.py
@@ -100,19 +100,22 @@ def main():
if len(args.override_config) > 0:
configs = override_config(configs, args.override_config)
- # init tokenizer
- tokenizer = init_tokenizer(configs)
+ model_type = configs.get("model", "asr_model")
+ # init tokenizer (not needed for classification)
+ tokenizer = None if model_type == "classification" else init_tokenizer(configs)
# Init env for ddp OR deepspeed
_, _, rank = init_distributed(args)
- # Get dataset & dataloader
+ # Get dataset & dataloader (unified for both ASR and classification)
train_dataset, cv_dataset, train_data_loader, cv_data_loader = init_dataset_and_dataloader(
args, configs, tokenizer
)
-
# Do some sanity checks and save config to arsg.model_dir
- configs = check_modify_and_save_config(args, configs, tokenizer.symbol_table)
+ if model_type == "classification":
+ configs = check_modify_and_save_config(args, configs, symbol_table=None)
+ else:
+ configs = check_modify_and_save_config(args, configs, tokenizer.symbol_table)
# Init asr model from configs
model, configs = init_model(args, configs)
diff --git a/chunkformer/chunkformer_model.py b/chunkformer/chunkformer_model.py
index e65abb2..c3b693f 100644
--- a/chunkformer/chunkformer_model.py
+++ b/chunkformer/chunkformer_model.py
@@ -3,6 +3,7 @@
"""
import argparse
+import json
import os
from contextlib import nullcontext
from typing import List, Optional, Union
@@ -19,14 +20,13 @@
from transformers import PretrainedConfig, PreTrainedModel
from transformers.utils import logging
+from chunkformer.modules.classification_model import SpeechClassificationModel
from chunkformer.transducer.search.greedy_search import batch_greedy_search, optimized_search
from chunkformer.utils.checkpoint import load_checkpoint
from chunkformer.utils.file_utils import read_symbol_table
from chunkformer.utils.init_model import init_speech_model
from chunkformer.utils.model_utils import get_output, get_output_with_timestamps
-# Import ChunkFormer modules
-
logger = logging.get_logger(__name__)
@@ -76,6 +76,8 @@ def __init__(self, config):
# Initialize the model components directly (avoiding file path dependencies)
self.model = self._init_model_from_config()
self.char_dict = None # Will be set when loading symbol table
+ self.label_mapping = None # Will be set when loading label_mapping.json
+ self.is_classification = isinstance(self.model, SpeechClassificationModel)
# Post-init
self.post_init()
@@ -195,6 +197,13 @@ def from_pretrained(
symbol_table = read_symbol_table(vocab_path)
model.char_dict = {v: k for k, v in symbol_table.items()} # type: ignore[assignment]
+ # Load label mapping for classification models
+ label_mapping_path = os.path.join(pretrained_model_name_or_path, "label_mapping.json")
+ if os.path.exists(label_mapping_path):
+ with open(label_mapping_path, "r") as f:
+ model.label_mapping = json.load(f)
+ logger.info(f"Loaded label mapping from: {label_mapping_path}")
+
return model
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): # type: ignore
@@ -232,6 +241,18 @@ def get_ctc(self):
"""Get the CTC module."""
return self.model.ctc
+ def get_classification_heads(self):
+ """Get the classification heads (for classification models only)."""
+ if self.is_classification and hasattr(self.model, "classification_heads"):
+ return self.model.classification_heads
+ return None
+
+ def get_tasks(self):
+ """Get classification tasks (for classification models only)."""
+ if self.is_classification and hasattr(self.model, "tasks"):
+ return self.model.tasks
+ return None
+
def encode(
self,
xs: torch.Tensor,
@@ -530,6 +551,94 @@ def batch_decode(
return decodes
+ @torch.no_grad()
+ def classify_audio(
+ self,
+ audio_path: str,
+ chunk_size: Optional[int] = -1,
+ left_context_size: Optional[int] = -1,
+ right_context_size: Optional[int] = -1,
+ ):
+ """
+ Perform classification on a single audio file.
+
+ Args:
+ audio_path: Path to audio file
+ chunk_size: Chunk size for processing (-1 for full attention)
+ left_context_size: Left context size
+ right_context_size: Right context size
+
+ Returns:
+ Dictionary containing predictions for each task in the format:
+ {
+ task_name: {
+ "label": str, # Human-readable label name
+ "label_id": int, # Numeric label ID
+ "prob": float # Probability of predicted class
+ }
+ }
+
+ Example:
+ {
+ "gender": {
+ "label": "female",
+ "label_id": 0,
+ "prob": 0.95
+ },
+ "emotion": {
+ "label": "neutral",
+ "label_id": 5,
+ "prob": 0.80
+ }
+ }
+ """
+ if not self.is_classification:
+ raise ValueError(
+ "This model is not a classification model. Use ASR decoding methods instead."
+ )
+
+ device = next(self.parameters()).device
+
+ # Load audio and extract features
+ xs, xs_len = self._load_audio_and_extract_features(audio_path)
+ xs = xs.unsqueeze(0).to(device)
+ xs_lens = torch.tensor([xs_len], dtype=torch.long, device=device)
+
+ # Classify
+ results = self.model.classify(
+ speech=xs,
+ speech_lengths=xs_lens,
+ chunk_size=chunk_size,
+ left_context_size=left_context_size,
+ right_context_size=right_context_size,
+ )
+
+ # Convert to desired format with label names
+ output = {}
+ for key, value in results.items():
+ if not key.endswith("_prediction"):
+ continue
+
+ task_name = key.replace("_prediction", "")
+ label_id = int(value.item())
+
+ # Get label name from label_mapping if available
+ label_name = str(label_id) # Default to label_id as string
+
+ if self.label_mapping and task_name in self.label_mapping:
+ # Direct lookup: label_mapping is already {id: label}
+ label_name = self.label_mapping[task_name].get(str(label_id), str(label_id))
+
+ # Get probability
+ prob_key = f"{task_name}_probability"
+ probability = 0.0
+ if prob_key in results:
+ probability = results[prob_key].item()
+
+ output[task_name] = {"label": label_name, "label_id": label_id, "prob": probability}
+
+ return output
+
# Register the configuration and model
ChunkFormerConfig.register_for_auto_class()
@@ -540,7 +649,7 @@ def main():
"""Main function for command line interface."""
# Create argument parser
parser = argparse.ArgumentParser(
- description="ChunkFormer ASR inference with command line interface."
+ description="ChunkFormer ASR and Classification inference with command line interface."
)
# Add arguments with default values
@@ -552,10 +661,13 @@ def main():
type=int,
default=1800,
help="The total audio duration (in second) in a batch \
- that your GPU memory can handle at once. Default is 1800s",
+ that your GPU memory can handle at once. Default is 1800s (ASR only)",
)
parser.add_argument(
- "--chunk_size", type=int, default=64, help="Size of the chunks (default: 64)"
+ "--chunk_size",
+ type=int,
+ default=64,
+ help="Size of the chunks (default: 64, -1 for full attention)",
)
parser.add_argument(
"--left_context_size", type=int, default=128, help="Size of the left context (default: 128)"
@@ -567,17 +679,17 @@ def main():
help="Size of the right context (default: 128)",
)
parser.add_argument(
- "--long_form_audio",
+ "--audio_file",
type=str,
default=None,
- help="Path to the long audio file (default: None)",
+ help="Path to a single audio file (for both ASR long-form and classification)",
)
parser.add_argument(
"--audio_list",
type=str,
default=None,
required=False,
- help="Path to the TSV file containing the audio list. \
+ help="Path to the TSV file containing the audio list (ASR only). \
The TSV file must have one column named 'wav'. \
If 'txt' column is provided, Word Error Rate (WER) is computed",
)
@@ -615,13 +727,11 @@ def main():
print(f"Chunk Size: {args.chunk_size}")
print(f"Left Context Size: {args.left_context_size}")
print(f"Right Context Size: {args.right_context_size}")
- print(f"Long Form Audio Path: {args.long_form_audio}")
+ print(f"Audio File: {args.audio_file}")
print(f"Audio List Path: {args.audio_list}")
assert args.model_checkpoint is not None, "You must specify the path to the model"
- assert (
- args.long_form_audio or args.audio_list
- ), "`long_form_audio` or `audio_list` must be activated"
+    assert args.audio_file or args.audio_list, "`audio_file` or `audio_list` must be provided"
# Load model using HuggingFace interface
print("Loading model using HuggingFace interface...")
@@ -631,38 +741,75 @@ def main():
# Perform inference
with torch.autocast(device.type, dtype) if dtype is not None else nullcontext():
- if args.long_form_audio:
- decode = model.endless_decode(
- args.long_form_audio,
- chunk_size=args.chunk_size,
- left_context_size=args.left_context_size,
- right_context_size=args.right_context_size,
- total_batch_duration=args.total_batch_duration,
- )
- for item in decode:
- start = f"{Fore.RED}{item['start']}{Style.RESET_ALL}"
- end = f"{Fore.RED}{item['end']}{Style.RESET_ALL}"
- print(f"{start} - {end}: {item['decode']}")
+ if not model.is_classification:
+ # ASR model
+ if args.audio_file:
+ # Long-form audio decoding
+ decode = model.endless_decode(
+ args.audio_file,
+ chunk_size=args.chunk_size,
+ left_context_size=args.left_context_size,
+ right_context_size=args.right_context_size,
+ total_batch_duration=args.total_batch_duration,
+ )
+ for item in decode:
+ start = f"{Fore.RED}{item['start']}{Style.RESET_ALL}"
+ end = f"{Fore.RED}{item['end']}{Style.RESET_ALL}"
+ print(f"{start} - {end}: {item['decode']}")
+ else:
+ # Batch decode using audio list
+ df = pd.read_csv(args.audio_list, sep="\t")
+ audio_paths = df["wav"].to_list()
+
+ decodes = model.batch_decode(
+ audio_paths,
+ chunk_size=args.chunk_size,
+ left_context_size=args.left_context_size,
+ right_context_size=args.right_context_size,
+ total_batch_duration=args.total_batch_duration,
+ )
+ df["decode"] = decodes
+ if "txt" in df.columns:
+ wer = jiwer.wer(df["txt"].to_list(), decodes)
+ print(f"Word Error Rate (WER): {wer:.4f}")
+
+ # Save results
+ df.to_csv(args.audio_list, sep="\t", index=False)
+ print(f"Results saved to {args.audio_list}")
+
else:
- # Batch decode using HF model interface
- df = pd.read_csv(args.audio_list, sep="\t")
- audio_paths = df["wav"].to_list()
+ # Classification model
+ assert args.audio_file is not None, "`audio_file` must be provided for classification"
- decodes = model.batch_decode(
- audio_paths,
+ print(f"Audio File: {args.audio_file}")
+
+ # Get tasks
+ tasks = model.get_tasks()
+ print(f"Classification tasks: {list(tasks.keys())}")
+
+ # Classify single audio file
+ result = model.classify_audio(
+ args.audio_file,
chunk_size=args.chunk_size,
left_context_size=args.left_context_size,
right_context_size=args.right_context_size,
- total_batch_duration=args.total_batch_duration,
)
- df["decode"] = decodes
- if "txt" in df.columns:
- wer = jiwer.wer(df["txt"].to_list(), decodes)
- print(f"Word Error Rate (WER): {wer:.4f}") # noqa: E231
-
- # Save results
- df.to_csv(args.audio_list, sep="\t", index=False)
- print(f"Results saved to {args.audio_list}")
+
+ # Print results
+ print(f"\nClassification Results for: {args.audio_file}")
+ print("=" * 70)
+ for task_name, task_result in result.items():
+ label = task_result.get("label", "N/A")
+ label_id = task_result.get("label_id", -1)
+ prob = task_result.get("prob")
+
+ print(f"{task_name.capitalize()}:")
+ print(f" Label: {label}")
+ print(f" Label ID: {label_id}")
+ if prob is not None:
+ print(f" Probability: {prob:.4f}")
+ print()
+ print("=" * 70)
if __name__ == "__main__":
diff --git a/chunkformer/dataset/dataset.py b/chunkformer/dataset/dataset.py
index b112de1..ceb6eac 100644
--- a/chunkformer/dataset/dataset.py
+++ b/chunkformer/dataset/dataset.py
@@ -78,6 +78,14 @@ def Dataset(
if tokenizer is not None:
dataset = dataset.map(partial(processor.tokenize, tokenizer=tokenizer))
+ # Classification-specific processing
+ dataset_type = conf.get("dataset_type", "asr")
+ if dataset_type == "classification" or "tasks" in conf:
+ tasks = conf.get("tasks", [])
+ if tasks:
+ # Parse classification labels (convert strings to integers)
+ dataset = dataset.map(partial(processor.parse_classification_labels, tasks=tasks))
+
filter_conf = conf.get("filter_conf", {})
dataset = dataset.filter(partial(processor.filter, **filter_conf))
@@ -114,10 +122,6 @@ def Dataset(
spec_trim_conf = conf.get("spec_trim_conf", {})
dataset = dataset.map(partial(processor.spec_trim, **spec_trim_conf))
- language_conf = conf.get("language_conf", {"limited_langs": ["zh", "en"]})
- dataset = dataset.map(partial(processor.detect_language, **language_conf))
- dataset = dataset.map(processor.detect_task)
-
shuffle = conf.get("shuffle", True)
if shuffle:
shuffle_conf = conf.get("shuffle_conf", {})
@@ -132,13 +136,12 @@ def Dataset(
batch_type = batch_conf.get("batch_type", "static")
pad_feat = batch_conf.get("pad_feat", "True")
+ wrapper_func = lambda batch: processor.padding(batch, pad_feat=pad_feat)
assert batch_type in ["static", "bucket", "dynamic"]
if batch_type == "static":
assert "batch_size" in batch_conf
batch_size = batch_conf.get("batch_size", 16)
- dataset = dataset.batch(
- batch_size, wrapper_class=lambda batch: processor.padding(batch, pad_feat)
- )
+ dataset = dataset.batch(batch_size, wrapper_class=wrapper_func)
elif batch_type == "bucket":
assert "bucket_boundaries" in batch_conf
assert "bucket_batch_sizes" in batch_conf
@@ -146,13 +149,13 @@ def Dataset(
processor.feats_length_fn,
batch_conf["bucket_boundaries"],
batch_conf["bucket_batch_sizes"],
- wrapper_class=lambda batch: processor.padding(batch, pad_feat),
+ wrapper_class=wrapper_func,
)
else:
max_frames_in_batch = batch_conf.get("max_frames_in_batch", 12000)
dataset = dataset.dynamic_batch(
processor.DynamicBatchWindow(max_frames_in_batch),
- wrapper_class=lambda batch: processor.padding(batch, pad_feat),
+ wrapper_class=wrapper_func,
)
return dataset
diff --git a/chunkformer/dataset/processor.py b/chunkformer/dataset/processor.py
index 22bc8ea..e1577cb 100644
--- a/chunkformer/dataset/processor.py
+++ b/chunkformer/dataset/processor.py
@@ -26,7 +26,6 @@
import torch.nn.functional as F
import torchaudio
import torchaudio.compliance.kaldi as kaldi
-from langid.langid import LanguageIdentifier, model
from torch.nn.utils.rnn import pad_sequence
from chunkformer.text.base_tokenizer import BaseTokenizer
@@ -35,7 +34,6 @@
AUDIO_FORMAT_SETS = set(["flac", "mp3", "m4a", "ogg", "opus", "wav", "wma"])
-lid = LanguageIdentifier.from_modelstring(model, norm_probs=True)
logging.getLogger("langid").setLevel(logging.INFO)
@@ -103,28 +101,6 @@ def parse_speaker(sample, speaker_dict):
return sample
-def detect_language(sample, limited_langs):
- assert "txt" in sample
- # NOTE(xcsong): Because language classification may not be very accurate
- # (for example, Chinese being classified as Japanese), our workaround,
- # given we know for certain that the training data only consists of
- # Chinese and English, is to limit the classification results to reduce
- # the impact of misclassification.
- lid.set_languages(limited_langs)
- # i.e., ('zh', 0.9999999909903544)
- sample["lang"] = lid.classify(sample["txt"])[0]
- return sample
-
-
-def detect_task(sample):
- # TODO(xcsong): Currently, the task is hard-coded to 'transcribe'.
- # In the future, we could dynamically determine the task based on
- # the contents of sample. For instance, if a sample contains both
- # 'txt_en' and 'txt_zh', the task should be set to 'translate'.
- sample["task"] = "transcribe"
- return sample
-
-
def decode_wav(sample):
"""Parse key/wav/txt from json line
@@ -533,43 +509,69 @@ def spec_trim(sample, max_t=20):
def padding(data, pad_feat=True):
"""Padding the data into training data
+ Automatically detects and supports both ASR and classification tasks.
+ - ASR data: has "label" key (token IDs)
+ - Classification data: has "{task}_label" keys (e.g., "gender_label", "emotion_label")
+
Args:
- data: List[{key, feat, label}
+ data: List[{key, feat, label, ...}] for ASR or
+ List[{key, feat, {task}_label, ...}] for classification
+ pad_feat: Whether to pad features
Returns:
- Tuple(keys, feats, labels, feats lengths, label lengths)
+ Batched dictionary for ASR or classification
"""
sample = data
assert isinstance(sample, list)
+ assert len(sample) > 0, "Empty batch"
+
feats_length = torch.tensor([x["feat"].size(0) for x in sample], dtype=torch.int32)
order = torch.argsort(feats_length, descending=True)
feats_lengths = torch.tensor([sample[i]["feat"].size(0) for i in order], dtype=torch.int32)
sorted_feats = [sample[i]["feat"] for i in order]
sorted_keys = [sample[i]["key"] for i in order]
- sorted_labels = [torch.tensor(sample[i]["label"], dtype=torch.int64) for i in order]
- sorted_wavs = [sample[i]["wav"].squeeze(0) for i in order]
- langs = [sample[i]["lang"] for i in order]
- tasks = [sample[i]["task"] for i in order]
- label_lengths = torch.tensor([x.size(0) for x in sorted_labels], dtype=torch.int32)
- wav_lengths = torch.tensor([x.size(0) for x in sorted_wavs], dtype=torch.int32)
padded_feats = pad_sequence(sorted_feats, batch_first=True, padding_value=0)
- padding_labels = pad_sequence(sorted_labels, batch_first=True, padding_value=-1)
- padded_wavs = pad_sequence(sorted_wavs, batch_first=True, padding_value=0)
-
- batch = {
- "keys": sorted_keys,
- "feats": padded_feats if pad_feat else sorted_feats,
- "target": padding_labels,
- "feats_lengths": feats_lengths,
- "target_lengths": label_lengths,
- "pcm": padded_wavs,
- "pcm_length": wav_lengths,
- "langs": langs,
- "tasks": tasks,
- }
- if "speaker" in sample[0]:
- speaker = torch.tensor([sample[i]["speaker"] for i in order], dtype=torch.int32)
- batch["speaker"] = speaker
+
+ # Detect data type: ASR has "label" key, classification has "*_label" keys
+ is_asr = "label" in sample[0]
+
+ if is_asr:
+ # ASR-specific batching
+ sorted_labels = [torch.tensor(sample[i]["label"], dtype=torch.int64) for i in order]
+ sorted_wavs = [sample[i]["wav"].squeeze(0) for i in order]
+ label_lengths = torch.tensor([x.size(0) for x in sorted_labels], dtype=torch.int32)
+ wav_lengths = torch.tensor([x.size(0) for x in sorted_wavs], dtype=torch.int32)
+ padding_labels = pad_sequence(sorted_labels, batch_first=True, padding_value=-1)
+ padded_wavs = pad_sequence(sorted_wavs, batch_first=True, padding_value=0)
+
+ batch = {
+ "keys": sorted_keys,
+ "feats": padded_feats if pad_feat else sorted_feats,
+ "target": padding_labels,
+ "feats_lengths": feats_lengths,
+ "target_lengths": label_lengths,
+ "pcm": padded_wavs,
+ "pcm_length": wav_lengths,
+ }
+ if "speaker" in sample[0]:
+ speaker = torch.tensor([sample[i]["speaker"] for i in order], dtype=torch.int32)
+ batch["speaker"] = speaker
+ else:
+ # Classification-specific batching
+ # Automatically detect all *_label keys
+ batch = {
+ "keys": sorted_keys,
+ "feats": padded_feats if pad_feat else sorted_feats,
+ "feats_lengths": feats_lengths,
+ }
+
+ # Find all classification label keys (e.g., gender_label, emotion_label, etc.)
+ label_keys = [k for k in sample[0].keys() if k.endswith("_label")]
+
+ # Add each classification task's labels to the batch
+ for label_key in label_keys:
+ labels = torch.tensor([sample[i][label_key] for i in order], dtype=torch.int64)
+ batch[label_key] = labels
return batch
@@ -590,3 +592,28 @@ def __call__(self, sample, buffer_size):
self.longest_frames = new_sample_frames
return True
return False
+
+
+def parse_classification_labels(sample: dict, tasks: list) -> dict:
+ """Parse classification labels from sample.
+
+ Args:
+ sample: Dictionary containing label information
+ tasks: List of task names to parse (e.g., ['gender', 'emotion', 'region'])
+
+ Returns:
+ Updated sample with parsed labels
+
+ Example:
+ Input sample: {'key': 'utt1', 'gender_label': '0', 'emotion_label': '3'}
+ tasks: ['gender', 'emotion']
+ Output: {'key': 'utt1', 'gender_label': 0, 'emotion_label': 3}
+ """
+ for task in tasks:
+ label_key = f"{task}_label"
+ if label_key in sample:
+ sample[label_key] = int(sample[label_key])
+ else:
+            # Every configured task must have a label in the sample
+ raise KeyError(f"Label {label_key} not found in sample {sample}")
+ return sample
diff --git a/chunkformer/modules/asr_model.py b/chunkformer/modules/asr_model.py
index a85a937..6a83200 100644
--- a/chunkformer/modules/asr_model.py
+++ b/chunkformer/modules/asr_model.py
@@ -112,7 +112,6 @@ def forward(
encoder_mask,
text,
text_lengths,
- {"langs": batch["langs"], "tasks": batch["tasks"]},
)
else:
loss_att = None
diff --git a/chunkformer/modules/classification_model.py b/chunkformer/modules/classification_model.py
new file mode 100644
index 0000000..0a1a11b
--- /dev/null
+++ b/chunkformer/modules/classification_model.py
@@ -0,0 +1,291 @@
+# Copyright (c) 2024 ChunkFormer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Speech Classification Model using ChunkFormer encoder."""
+
+from typing import Dict, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from chunkformer.modules.encoder import ChunkFormerEncoder
+
+
+class ClassificationHead(torch.nn.Module):
+ """Simple linear classification head with optional dropout."""
+
+ def __init__(
+ self,
+ input_size: int,
+ num_classes: int,
+ dropout_rate: float = 0.1,
+ ):
+ """
+ Args:
+ input_size: Input feature dimension
+ num_classes: Number of output classes
+ dropout_rate: Dropout rate before classification layer
+ """
+ super().__init__()
+ self.dropout = torch.nn.Dropout(dropout_rate)
+ self.linear = torch.nn.Linear(input_size, num_classes)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x: (batch, input_size)
+ Returns:
+ logits: (batch, num_classes)
+ """
+ x = self.dropout(x)
+ return self.linear(x)
+
+
+class SpeechClassificationModel(torch.nn.Module):
+ """Speech Classification Model using ChunkFormer encoder.
+
+ Supports both single-task and multi-task classification.
+ Uses average pooling over encoder outputs for feature aggregation.
+ """
+
+ def __init__(
+ self,
+ encoder: ChunkFormerEncoder,
+ tasks: Dict[str, int],
+ dropout_rate: float = 0.1,
+ label_smoothing: float = 0.0,
+ ):
+ """
+ Args:
+ encoder: ChunkFormer encoder
+ tasks: Dictionary mapping task names to number of classes
+ e.g., {'gender': 2, 'emotion': 7, 'region': 5}
+ For single task: {'gender': 2}
+ dropout_rate: Dropout rate for classification heads
+ label_smoothing: Label smoothing factor (0.0 = no smoothing, typically 0.1)
+ """
+ super().__init__()
+
+ if not tasks:
+ raise ValueError("At least one classification task must be defined")
+
+ self.encoder = encoder
+ self.tasks = tasks
+ self.task_names = list(tasks.keys())
+ self.num_tasks = len(tasks)
+ self.label_smoothing = label_smoothing
+
+ # Create classification head for each task
+ encoder_output_size = encoder.output_size()
+ self.classification_heads = torch.nn.ModuleDict(
+ {
+ task_name: ClassificationHead(
+ input_size=encoder_output_size,
+ num_classes=num_classes,
+ dropout_rate=dropout_rate,
+ )
+ for task_name, num_classes in tasks.items()
+ }
+ )
+
+ def forward(
+ self,
+ batch: dict,
+ device: torch.device,
+ ) -> Dict[str, torch.Tensor]:
+ """Forward pass for training.
+
+ Args:
+ batch: Dictionary containing:
+ - feats: (batch, time, feat_dim)
+ - feats_lengths: (batch,)
+ - {task_name}_label: (batch,) for each task
+ device: Device to run on
+
+ Returns:
+ Dictionary containing:
+ - loss: Total loss (averaged across tasks)
+ - loss_{task_name}: Loss for each task
+ - acc_{task_name}: Accuracy for each task
+ - logits_{task_name}: Logits for each task
+ """
+ speech = batch["feats"].to(device)
+ speech_lengths = batch["feats_lengths"].to(device)
+
+ # 1. Encoder forward
+ encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
+ # encoder_out: (batch, time, encoder_dim)
+ # encoder_mask: (batch, 1, time)
+
+ # 2. Average pooling over time dimension
+ pooled_features = self._average_pooling(encoder_out, encoder_mask)
+ # pooled_features: (batch, encoder_dim)
+
+ # 3. Classification for each task
+ outputs = {}
+ total_loss = 0.0
+ num_valid_tasks = 0
+
+ for task_name in self.task_names:
+ label_key = f"{task_name}_label"
+
+ # Skip task if labels not provided (useful for multi-task with partial labels)
+ if label_key not in batch:
+ continue
+
+ labels = batch[label_key].to(device)
+
+ # Get logits from classification head
+ logits = self.classification_heads[task_name](pooled_features)
+
+ # Compute cross-entropy loss with label smoothing
+ loss = F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
+
+ # Compute accuracy
+ predictions = torch.argmax(logits, dim=-1)
+ accuracy = (predictions == labels).float().mean()
+
+ # Store outputs
+ outputs[f"loss_{task_name}"] = loss
+ outputs[f"acc_{task_name}"] = accuracy
+
+ total_loss += loss
+ num_valid_tasks += 1
+
+ # Average loss across all tasks
+ if num_valid_tasks > 0:
+ outputs["loss"] = total_loss / num_valid_tasks
+ else:
+ outputs["loss"] = torch.tensor(0.0, device=device)
+
+ return outputs
+
+ def _average_pooling(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ """Average pooling over time dimension, considering padding mask.
+
+ Args:
+ encoder_out: (batch, time, dim)
+ encoder_mask: (batch, 1, time) - True for valid frames, False for padding
+
+ Returns:
+ pooled: (batch, dim)
+ """
+ # encoder_mask: (batch, 1, time) -> (batch, time, 1)
+ mask = encoder_mask.transpose(1, 2).float()
+
+ # Sum over valid frames
+ masked_sum = (encoder_out * mask).sum(dim=1) # (batch, dim)
+
+ # Count valid frames
+ valid_counts = mask.sum(dim=1) # (batch, 1)
+
+ # Average
+ pooled = masked_sum / (valid_counts + 1e-10) # (batch, dim)
+
+ return pooled
+
+ def encode(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ chunk_size: int = -1,
+ left_context_size: int = -1,
+ right_context_size: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode audio to features.
+
+ Args:
+ speech: (batch, time, feat_dim)
+ speech_lengths: (batch,)
+ chunk_size: Chunk size for chunked processing
+ left_context_size: Left context size
+ right_context_size: Right context size
+
+ Returns:
+ encoder_out: (batch, time, encoder_dim)
+ encoder_mask: (batch, 1, time)
+ """
+ encoder_out, encoder_mask = self.encoder(
+ speech,
+ speech_lengths,
+ chunk_size=chunk_size,
+ left_context_size=left_context_size,
+ right_context_size=right_context_size,
+ )
+ return encoder_out, encoder_mask
+
+ def classify(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ chunk_size: int = -1,
+ left_context_size: int = -1,
+ right_context_size: int = -1,
+ ) -> Dict[str, torch.Tensor]:
+ """Inference: classify audio samples.
+
+ Args:
+ speech: (batch, time, feat_dim)
+ speech_lengths: (batch,)
+ chunk_size: Chunk size for chunked processing
+ left_context_size: Left context size
+ right_context_size: Right context size
+
+ Returns:
+ Dictionary containing for each task:
+ - {task_name}_prediction: (batch,) predicted class indices
+ - {task_name}_probability: (batch,) probability of predicted class
+ """
+ # Encode
+ encoder_out, encoder_mask = self.encode(
+ speech,
+ speech_lengths,
+ chunk_size,
+ left_context_size,
+ right_context_size,
+ )
+
+ # Pool
+ pooled_features = self._average_pooling(encoder_out, encoder_mask)
+
+ # Classify for each task
+ results = {}
+ for task_name in self.task_names:
+ logits = self.classification_heads[task_name](pooled_features)
+
+ # Get predictions
+ predictions = torch.argmax(logits, dim=-1)
+ results[f"{task_name}_prediction"] = predictions
+
+ # Get probabilities and extract only the predicted class probability
+ probabilities = F.softmax(logits, dim=-1)
+ # Get probability of the predicted class for each sample in batch
+ predicted_probs = probabilities.gather(1, predictions.unsqueeze(1)).squeeze(1)
+ results[f"{task_name}_probability"] = predicted_probs
+
+ return results
+
+ def get_num_classes(self, task_name: str) -> int:
+ """Get number of classes for a specific task."""
+ if task_name not in self.tasks:
+ raise ValueError(f"Task '{task_name}' not found. Available tasks: {self.task_names}")
+ return self.tasks[task_name]
+
+ def is_multi_task(self) -> bool:
+ """Check if model is multi-task."""
+ return self.num_tasks > 1
diff --git a/chunkformer/utils/executor.py b/chunkformer/utils/executor.py
index b73c06c..c535837 100644
--- a/chunkformer/utils/executor.py
+++ b/chunkformer/utils/executor.py
@@ -78,7 +78,7 @@ def train(
if wenet_join(group_join, info_dict):
break
- if batch_dict["target_lengths"].size(0) == 0:
+ if "target_lengths" in batch_dict and batch_dict["target_lengths"].size(0) == 0:
continue
context = None
@@ -138,6 +138,7 @@ def cv(self, model, cv_data_loader, configs):
model.eval()
info_dict = copy.deepcopy(configs)
num_seen_utts, loss_dict, total_acc = 1, {}, [] # avoid division by 0
+ acc_dict = {} # For accumulating accuracies (classification or ASR)
with torch.no_grad():
for batch_idx, batch_dict in enumerate(cv_data_loader):
info_dict["tag"] = "CV"
@@ -145,7 +146,7 @@ def cv(self, model, cv_data_loader, configs):
info_dict["batch_idx"] = batch_idx
info_dict["cv_step"] = batch_idx
- num_utts = batch_dict["target_lengths"].size(0)
+ num_utts = batch_dict["feats"].size(0)
if num_utts == 0:
continue
@@ -158,17 +159,32 @@ def cv(self, model, cv_data_loader, configs):
if _dict.get("th_accuracy", None) is not None
else 0.0
)
- for loss_name, loss_value in _dict.items():
- if (
- loss_value is not None
- and "loss" in loss_name
- and torch.isfinite(loss_value)
- ):
- loss_value = loss_value.item()
- loss_dict[loss_name] = loss_dict.get(loss_name, 0) + loss_value * num_utts
+
+ # Accumulate all losses and accuracies
+ for key, value in _dict.items():
+ if value is None or not torch.isfinite(value):
+ continue
+
+ value_item = value.item()
+
+ # Accumulate losses
+ if "loss" in key:
+ loss_dict[key] = loss_dict.get(key, 0) + value_item * num_utts
+
+ # Accumulate accuracies (for classification: acc_gender, acc_emotion, etc.)
+ if "acc" in key:
+ acc_dict[key] = acc_dict.get(key, 0) + value_item * num_utts
+
# write cv: log
log_per_step(writer=None, info_dict=info_dict, timer=self.cv_step_timer)
- for loss_name, loss_value in loss_dict.items():
+
+ # Average all accumulated losses
+ for loss_name in loss_dict:
loss_dict[loss_name] = loss_dict[loss_name] / num_seen_utts
+
+ # Average all accumulated accuracies
+ for acc_name in acc_dict:
+ loss_dict[acc_name] = acc_dict[acc_name] / num_seen_utts
+
loss_dict["acc"] = sum(total_acc) / len(total_acc)
return loss_dict
diff --git a/chunkformer/utils/init_dataset.py b/chunkformer/utils/init_dataset.py
index 08658c7..711fe6c 100644
--- a/chunkformer/utils/init_dataset.py
+++ b/chunkformer/utils/init_dataset.py
@@ -20,7 +20,7 @@ def init_dataset(
partition=True,
split="train",
):
- assert dataset_type in ["asr", "ssl"]
+ assert dataset_type in ["asr", "ssl", "classification"]
if split != "train":
cv_conf = copy.deepcopy(conf)
@@ -33,8 +33,17 @@ def init_dataset(
cv_conf["list_shuffle"] = False
conf = cv_conf
+ # Add dataset_type to conf so Dataset function knows how to batch
+ conf = copy.deepcopy(conf)
+ conf["dataset_type"] = dataset_type
+
if dataset_type == "asr":
return init_asr_dataset(data_type, data_list_file, tokenizer, conf, partition)
+ elif dataset_type == "classification":
+ # Classification uses the same Dataset class but without tokenizer
+ return init_asr_dataset(
+ data_type, data_list_file, tokenizer=None, conf=conf, partition=partition
+ )
else:
from chunkformer.ssl.init_dataset import init_dataset as init_ssl_dataset
diff --git a/chunkformer/utils/init_model.py b/chunkformer/utils/init_model.py
index a5e5e02..b702c23 100644
--- a/chunkformer/utils/init_model.py
+++ b/chunkformer/utils/init_model.py
@@ -17,6 +17,7 @@
import torch
from ..modules.asr_model import ASRModel
+from ..modules.classification_model import SpeechClassificationModel
from ..modules.cmvn import GlobalCMVN
from ..modules.ctc import CTC
from ..modules.decoder import BiTransformerDecoder, TransformerDecoder
@@ -53,6 +54,7 @@
CHUNKFORMER_MODEL_CLASSES = {
"asr_model": ASRModel,
"transducer": Transducer,
+ "classification": SpeechClassificationModel,
}
@@ -67,7 +69,12 @@ def init_speech_model(args, configs):
global_cmvn = None
input_dim = configs["input_dim"]
- vocab_size = configs["output_dim"]
+
+ # Get model type early to determine what components to create
+ model_type = configs.get("model", "asr_model")
+
+ # vocab_size is only needed for ASR models
+ vocab_size = configs.get("output_dim", 0) if model_type != "classification" else 0
# ChunkFormer only supports chunkformer encoder
encoder_type = configs.get("encoder", "chunkformer")
@@ -79,21 +86,36 @@ def init_speech_model(args, configs):
input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
)
- # Create decoder
- decoder = CHUNKFORMER_DECODER_CLASSES[decoder_type](
- vocab_size, encoder.output_size(), **configs["decoder_conf"]
- )
+ # Create decoder and CTC only for ASR models
+ decoder = None
+ ctc = None
- # Create CTC
- ctc = CHUNKFORMER_CTC_CLASSES[ctc_type](
- vocab_size,
- encoder.output_size(),
- blank_id=configs["ctc_conf"]["ctc_blank_id"] if "ctc_conf" in configs else 0,
- )
+ if model_type != "classification":
+ # Create decoder
+ decoder = CHUNKFORMER_DECODER_CLASSES[decoder_type](
+ vocab_size, encoder.output_size(), **configs["decoder_conf"]
+ )
+
+ # Create CTC
+ ctc = CHUNKFORMER_CTC_CLASSES[ctc_type](
+ vocab_size,
+ encoder.output_size(),
+ blank_id=configs["ctc_conf"]["ctc_blank_id"] if "ctc_conf" in configs else 0,
+ )
# Create model based on type
- model_type = configs.get("model", "asr_model")
- if model_type == "transducer":
+ if model_type == "classification":
+ # Classification model only needs encoder
+ tasks = configs["model_conf"].get("tasks", {})
+ if not tasks:
+ raise ValueError("Classification model requires 'tasks' in model_conf")
+
+ model = CHUNKFORMER_MODEL_CLASSES[model_type](
+ encoder=encoder,
+ tasks=tasks,
+ **{k: v for k, v in configs["model_conf"].items() if k != "tasks"},
+ )
+ elif model_type == "transducer":
predictor_type = configs.get("predictor", "rnn")
joint_type = configs.get("joint", "transducer_joint")
predictor = CHUNKFORMER_PREDICTOR_CLASSES[predictor_type](
diff --git a/chunkformer/utils/train_utils.py b/chunkformer/utils/train_utils.py
index 5f6facb..9236ee3 100644
--- a/chunkformer/utils/train_utils.py
+++ b/chunkformer/utils/train_utils.py
@@ -328,21 +328,21 @@ def check_modify_and_save_config(args, configs, symbol_table):
configs["lora_conf"]["lora_alpha"] = args.lora_alpha
configs["lora_conf"]["lora_dropout"] = args.lora_dropout
- if configs["model"] == "asr_model" or configs["model"] == "transducer":
- if "input_dim" not in configs:
- if "fbank_conf" in configs["dataset_conf"]:
- input_dim = configs["dataset_conf"]["fbank_conf"]["num_mel_bins"]
- elif "log_mel_spectrogram_conf" in configs["dataset_conf"]:
- input_dim = configs["dataset_conf"]["log_mel_spectrogram_conf"]["num_mel_bins"]
- else:
- input_dim = configs["dataset_conf"]["mfcc_conf"]["num_mel_bins"]
+ # Set input_dim if not present (for both ASR and classification models)
+ if "input_dim" not in configs:
+ if "fbank_conf" in configs["dataset_conf"]:
+ configs["input_dim"] = configs["dataset_conf"]["fbank_conf"]["num_mel_bins"]
+ elif "log_mel_spectrogram_conf" in configs["dataset_conf"]:
+ configs["input_dim"] = configs["dataset_conf"]["log_mel_spectrogram_conf"][
+ "num_mel_bins"
+ ]
else:
- input_dim = configs["input_dim"]
-
- configs["input_dim"] = input_dim
+ configs["input_dim"] = configs["dataset_conf"]["mfcc_conf"]["num_mel_bins"]
- configs, _ = get_blank_id(configs, symbol_table)
- configs["output_dim"] = configs["vocab_size"]
+ # Only process symbol_table for ASR models
+ if symbol_table is not None:
+ configs, _ = get_blank_id(configs, symbol_table)
+ configs["output_dim"] = configs["vocab_size"]
configs["train_engine"] = args.train_engine
configs["use_amp"] = args.use_amp
@@ -373,7 +373,11 @@ def init_dataset_and_dataloader(args, configs, tokenizer, seed=777):
configs["dataset_conf"]["cycle"] = configs.get("max_epoch", 100)
conf = configs["dataset_conf"]
dataset_type = configs.get("dataset", "asr")
- configs["vocab_size"] = tokenizer.vocab_size()
+
+ # Set vocab_size for ASR models
+ if dataset_type != "classification":
+ configs["vocab_size"] = tokenizer.vocab_size()
+
train_dataset = init_dataset(
dataset_type, args.data_type, args.train_data, tokenizer, conf, True, split="train"
)
@@ -855,14 +859,31 @@ def log_per_epoch(writer, info_dict):
lrs = info_dict["lrs"]
rank = int(os.environ.get("RANK", 0))
step = info_dict["step"]
+
+ # Get accuracy: either general "acc" or compute average from task-specific accuracies
+ if "acc" in loss_dict:
+ acc_value = tensor_to_scalar(loss_dict["acc"])
+ acc_str = f"acc {acc_value}"
+ else:
+ # For classification: collect all acc_* keys and average them
+ acc_keys = [k for k in loss_dict.keys() if k.startswith("acc_")]
+ if acc_keys:
+ avg_acc = sum(tensor_to_scalar(loss_dict[k]) for k in acc_keys) / len(acc_keys)
+ acc_str = f"acc {avg_acc}"
+ # Also add individual task accuracies
+ task_accs = " ".join(f"{k} {tensor_to_scalar(loss_dict[k]):.4f}" for k in acc_keys)
+ acc_str = f"{acc_str} ({task_accs})"
+ else:
+ acc_str = ""
+
logging.info(
- "Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}".format(
+ "Epoch {} Step {} CV info lr {} cv_loss {} rank {} {}".format(
epoch,
step,
lrs_to_str(lrs),
tensor_to_scalar(loss_dict["loss"]),
rank,
- tensor_to_scalar(loss_dict["acc"]),
+ acc_str,
)
)
diff --git a/examples/classification/README.md b/examples/classification/README.md
new file mode 100644
index 0000000..d098f59
--- /dev/null
+++ b/examples/classification/README.md
@@ -0,0 +1,176 @@
+# Speech Classification with ChunkFormer
+
+Complete implementation for speech classification tasks using the ChunkFormer encoder. Supports both **single-task** and **multi-task** classification (gender, emotion, region, accent, age, etc.).
+
+```
+Audio → Fbank → ChunkFormer Encoder → Average Pooling → Classification Heads
+```
+
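+A minimal sketch of what the classification step does (simplified from `chunkformer/modules/classification_model.py`; assumes encoder outputs and a boolean frame mask are already computed):
+
+```python
+import torch
+
+def classify_from_encoder(encoder_out, encoder_mask, heads):
+    """encoder_out: (B, T, D); encoder_mask: (B, 1, T), True for valid frames;
+    heads: dict mapping task name -> torch.nn.Linear(D, num_classes)."""
+    mask = encoder_mask.transpose(1, 2).float()                    # (B, T, 1)
+    pooled = (encoder_out * mask).sum(1) / (mask.sum(1) + 1e-10)   # masked mean over time
+    return {task: head(pooled).argmax(-1) for task, head in heads.items()}
+```
+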
+## Quick Start
+
+### 1. Installation
+
+```bash
+cd /path/to/chunkformer
+pip install -e .
+```
+
+### 2. Data Format
+
+Create `data.tsv` files:
+
+```tsv
+key wav gender_label
+utt001 /path/to/audio1.wav 0
+utt002 /path/to/audio2.wav 1
+```
+
+For multi-task, add more label columns:
+```tsv
+key wav gender_label emotion_label region_label
+utt001 /path/to/audio1.wav 0 1 2
+utt002 /path/to/audio2.wav 1 3 0
+```
+
+**Note**:
+- Labels must be integers starting from 0
+- `key` column is optional (auto-generated from wav path if missing)
+
+### 3. Directory Structure
+
+```
+examples/classification/
+├── data/
+│ ├── train/data.tsv
+│ ├── dev/data.tsv
+│ └── test/data.tsv
+├── conf/
+│ ├── single_task.yaml
+│ └── multi_task.yaml
+└── run.sh
+```
+
+### 4. Configuration
+
+**Single-task** (`conf/single_task.yaml`):
+```yaml
+model: classification
+model_conf:
+ tasks:
+ gender: 2 # 2 classes
+ dropout_rate: 0.1
+ label_smoothing: 0.1 # Optional
+
+dataset: classification
+dataset_conf:
+ tasks: ['gender']
+ batch_conf:
+ batch_type: static
+ batch_size: 8
+```
+
+**Multi-task** (`conf/multi_task.yaml`):
+```yaml
+model_conf:
+ tasks:
+ gender: 2
+ emotion: 7
+ region: 5
+dataset_conf:
+ tasks: ['gender', 'emotion', 'region']
+```
+
+### 5. Training Pipeline
+
+```bash
+cd examples/classification
+
+# Full pipeline (data prep + training)
+./run.sh --stage 0 --stop-stage 3
+
+# Or run individual stages:
+./run.sh --stage 0 --stop-stage 0 # Convert TSV to list format
+./run.sh --stage 1 --stop-stage 1 # Compute CMVN
+./run.sh --stage 2 --stop-stage 2 # Analyze labels
+./run.sh --stage 3 --stop-stage 3 # Train model
+```
+
+**Multi-GPU training**:
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+./run.sh --stage 3 --stop-stage 3
+```
+
+**Resume training**:
+```bash
+./run.sh --stage 3 --stop-stage 3 --checkpoint exp/classification_v1/10.pt
+```
+
+**Transfer learning from ASR**:
+```bash
+# Load encoder weights from pre-trained ASR model
+./run.sh --stage 3 --stop-stage 3 --checkpoint /path/to/asr_model.pt
+```
+Note: Only encoder weights are loaded; classification heads are randomly initialized.
+
+**Monitor training**:
+```bash
+tensorboard --logdir tensorboard/classification_v1 --port 6006
+```
+
+### 6. Evaluation
+
+```bash
+# Inference on test set
+./run.sh --stage 4 --stop-stage 4
+```
+
+Results in `exp/classification_v1/test/metrics.txt`:
+```
+======================================================================
+Classification Metrics
+======================================================================
+
+GENDER:
+ Samples: 4457
+
+ Classification Report:
+ precision recall f1-score support
+
+ 0 0.98 0.99 0.98 2243
+ 1 0.99 0.98 0.98 2214
+
+ accuracy 0.98 4457
+ macro avg 0.98 0.98 0.98 4457
+ weighted avg 0.98 0.98 0.98 4457
+
+ Confusion Matrix:
+ 0: [2213, 30]
+ 1: [41, 2173]
+```
+
+## Tools
+
+### Convert Text Labels to Integers
+
+```bash
+# Basic usage - auto-detects tasks from columns ending with '_label'
+python tools/convert_text_labels_to_int.py \
+ --input data.tsv
+
+# Or specify tasks explicitly
+python tools/convert_text_labels_to_int.py \
+ --input data.tsv \
+ --tasks gender emotion region
+```
+
+### Split Train/Test
+
+```bash
+python tools/split_classification_data.py \
+ --input data/train/data.tsv \
+ --train_output data/train/data.tsv \
+ --test_output data/test/data.tsv \
+ --test_ratio 0.2 \
+ --seed 42
+```
diff --git a/examples/classification/RESULTS.md b/examples/classification/RESULTS.md
new file mode 100644
index 0000000..bfcbdf3
--- /dev/null
+++ b/examples/classification/RESULTS.md
@@ -0,0 +1,196 @@
+# Classification Results
+
+Comparison of multi-task classification performance with and without transfer learning from pre-trained ASR model.
+
+> **⚠️ Important Note**: The dataset exhibits extreme class imbalance across Age, Dialect, and Emotion tasks. This severe imbalance causes highly unstable training when training from scratch, as the model tends to collapse to predicting only the majority class. Further investigation is needed to determine whether this instability is primarily due to:
+> 1. **Data imbalance** - insufficient samples for minority classes
+> 2. **Model architecture** - inadequate capacity or regularization for imbalanced learning
+> 3. **Training strategy** - need for class weighting, focal loss, or other imbalance-handling techniques
+>
+> The dramatic improvement with transfer learning suggests that pre-trained representations provide crucial initialization that prevents majority class collapse.
+
+## Experimental Setup
+
+- **Dataset**: [LSVSC](https://huggingface.co/datasets/doof-ferb/LSVSC) Vietnamese speech dataset
+- **Training Set Size**: 40,102 samples
+- **Dev/Test Set Size**: 4,457 samples each
+- **Tasks**: 4 classification tasks (Gender, Age, Dialect, Emotion)
+- **Model**: ChunkFormer encoder
+- **Configuration**: `conf/multi_task.yaml`
+- **Training**:
+ - **From Pretrain**: Initialized with pre-trained ASR encoder (`khanhld/chunkformer-rnnt-large-vie`)
+ - **From Scratch**: Randomly initialized encoder
+- **Checkpoint**: [khanhld/chunkformer-gender-emotion-dialect-age-classification](https://huggingface.co/khanhld/chunkformer-gender-emotion-dialect-age-classification)
+
+## Dataset Statistics
+
+### Training Set (40,102 samples)
+
+| Task | Classes | Distribution | Balance |
+|------|---------|--------------|---------|
+| **Gender** | 2 | Male: 49.65% / Female: 50.35% | ✅ Balanced |
+| **Age** | 5 | 0: 0.08% / 1: 42.62% / 2: 5.30% / 3: 0.62% / 4: 51.38% | ❌ Highly Imbalanced |
+| **Dialect** | 5 | 0: 3.39% / 1: 0.70% / 2: 0.05% / 3: 88.10% / 4: 7.76% | ❌ Highly Imbalanced |
+| **Emotion** | 8 | 0: 0.21% / 1: 0.03% / 2: 0.08% / 3: 0.06% / 4: 0.66% / 5: 98.57% / 6: 0.35% / 7: 0.04% | ❌ Highly Imbalanced |
+
+
+### Test Set (4,457 samples)
+
+| Task | Classes | Distribution | Balance |
+|------|---------|--------------|---------|
+| **Gender** | 2 | Male: 50.33% / Female: 49.67% | ✅ Balanced |
+| **Age** | 5 | 0: 0.07% / 1: 42.00% / 2: 5.12% / 3: 0.45% / 4: 52.37% | ❌ Highly Imbalanced |
+| **Dialect** | 5 | 0: 3.46% / 1: 0.47% / 2: 0.07% / 3: 88.45% / 4: 7.56% | ❌ Highly Imbalanced |
+| **Emotion** | 8 | 0: 0.04% / 1: 0.04% / 2: 0.04% / 3: 0.07% / 4: 0.56% / 5: 98.83% / 6: 0.38% / 7: 0.02% | ❌ Highly Imbalanced |
+
+
+
+## Overall Results Summary
+
+| Task | Metric | From Pretrain | From Scratch | Improvement |
+|------|--------|--------------|--------------|-------------|
+| **Gender** | Accuracy | **98.4%** | 51.5% | +46.9% |
+| | Weighted F1 | **0.98** | 0.51 | +0.47 |
+| **Age** | Accuracy | **80.5%** | 52.4% | +28.1% |
+| | Weighted F1 | **0.80** | 0.36 | +0.44 |
+| **Dialect** | Accuracy | **95.5%** | 88.5% | +7.0% |
+| | Weighted F1 | **0.95** | 0.83 | +0.12 |
+| **Emotion** | Accuracy | **98.9%** | 98.8% | +0.1% |
+| | Weighted F1 | **0.99** | 0.98 | +0.01 |
+
+## Detailed Results by Task
+
+### 1. Gender Classification (2 classes: Male/Female)
+
+**From Pretrain** (Best Performance):
+```
+ precision recall f1-score support
+
+ 0 0.98 0.99 0.98 2243
+ 1 0.99 0.98 0.98 2214
+
+ accuracy 0.98 4457
+ macro avg 0.98 0.98 0.98 4457
+ weighted avg 0.98 0.98 0.98 4457
+```
+
+**From Scratch**:
+```
+ precision recall f1-score support
+
+ 0 0.52 0.56 0.54 2243
+ 1 0.51 0.47 0.49 2214
+
+ accuracy 0.51 4457
+ macro avg 0.51 0.51 0.51 4457
+ weighted avg 0.51 0.51 0.51 4457
+```
+
+---
+
+### 2. Age Classification (5 classes)
+
+**From Pretrain**:
+```
+ precision recall f1-score support
+
+ 0 0.75 1.00 0.86 3
+ 1 0.78 0.75 0.76 1872
+ 2 0.88 0.71 0.78 228
+ 3 0.86 0.90 0.88 20
+ 4 0.81 0.85 0.83 2334
+
+ accuracy 0.80 4457
+ macro avg 0.81 0.84 0.82 4457
+ weighted avg 0.80 0.80 0.80 4457
+```
+
+**From Scratch**:
+```
+ precision recall f1-score support
+
+ 0 0.00 0.00 0.00 3
+ 1 0.00 0.00 0.00 1872
+ 2 0.00 0.00 0.00 228
+ 3 0.00 0.00 0.00 20
+ 4 0.52 1.00 0.69 2334
+
+ accuracy 0.52 4457
+ macro avg 0.10 0.20 0.14 4457
+ weighted avg 0.27 0.52 0.36 4457
+```
+
+---
+
+### 3. Dialect Classification (5 classes)
+
+**From Pretrain**:
+```
+ precision recall f1-score support
+
+ 0.0 0.71 0.56 0.63 154
+ 1.0 0.65 0.52 0.58 21
+ 2.0 1.00 0.67 0.80 3
+ 3.0 0.97 0.99 0.98 3942
+ 4.0 0.84 0.81 0.82 337
+
+ accuracy 0.96 4457
+ macro avg 0.83 0.71 0.76 4457
+ weighted avg 0.95 0.96 0.95 4457
+```
+
+**From Scratch**:
+```
+ precision recall f1-score support
+
+ 0.0 0.00 0.00 0.00 154
+ 1.0 0.00 0.00 0.00 21
+ 2.0 0.00 0.00 0.00 3
+ 3.0 0.88 1.00 0.94 3942
+ 4.0 0.00 0.00 0.00 337
+
+ accuracy 0.88 4457
+ macro avg 0.18 0.20 0.19 4457
+ weighted avg 0.78 0.88 0.83 4457
+```
+---
+
+### 4. Emotion Classification (8 classes)
+
+**From Pretrain**:
+```
+ precision recall f1-score support
+
+ 0.0 0.00 0.00 0.00 2
+ 1.0 0.00 0.00 0.00 2
+ 2.0 0.00 0.00 0.00 2
+ 3.0 0.00 0.00 0.00 3
+ 4.0 0.76 0.52 0.62 25
+ 5.0 0.99 1.00 0.99 4405
+ 6.0 0.27 0.18 0.21 17
+ 7.0 1.00 1.00 1.00 1
+
+ accuracy 0.99 4457
+ macro avg 0.38 0.34 0.35 4457
+ weighted avg 0.99 0.99 0.99 4457
+```
+
+**From Scratch**:
+```
+ precision recall f1-score support
+
+ 0.0 0.00 0.00 0.00 2
+ 1.0 0.00 0.00 0.00 2
+ 2.0 0.00 0.00 0.00 2
+ 3.0 0.00 0.00 0.00 3
+ 4.0 0.00 0.00 0.00 25
+ 5.0 0.99 1.00 0.99 4405
+ 6.0 0.00 0.00 0.00 17
+ 7.0 0.00 0.00 0.00 1
+
+ accuracy 0.99 4457
+ macro avg 0.12 0.12 0.12 4457
+ weighted avg 0.98 0.99 0.98 4457
+```
+
+---
diff --git a/examples/classification/conf/multi_task.yaml b/examples/classification/conf/multi_task.yaml
new file mode 100644
index 0000000..e8678eb
--- /dev/null
+++ b/examples/classification/conf/multi_task.yaml
@@ -0,0 +1,101 @@
+# ChunkFormer Speech Classification Configuration
+# Multi-task example: Gender + Emotion + Dialect + Age classification
+
+# Encoder configuration
+encoder: chunkformer
+encoder_conf:
+ output_size: 512 # dimension of attention
+ attention_heads: 4
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.1
+ input_layer: dw_striding
+ normalize_before: true
+ cnn_module_kernel: 15
+ use_cnn_module: True
+ activation_type: 'swish'
+ pos_enc_layer_type: 'chunk_rel_pos'
+ selfattention_layer_type: 'chunk_rel_seflattn'
+ cnn_module_norm: 'layer_norm'
+ # Enable dynamic chunking for varied context training
+ dynamic_conv: true
+ dynamic_chunk_sizes: [-1, -1, 64, 128, 256]
+ dynamic_left_context_sizes: [64, 128, 256]
+ dynamic_right_context_sizes: [64, 128, 256]
+
+# Model configuration
+model: classification
+model_conf:
+ # Classification tasks: task_name -> num_classes
+ tasks:
+ gender: 2
+ emotion: 8
+ dialect: 5
+ age: 5
+ dropout_rate: 0.1
+ label_smoothing: 0.2
+
+# CMVN (Cepstral Mean and Variance Normalization)
+cmvn: global_cmvn
+cmvn_conf:
+ cmvn_file: 'data/train/global_cmvn'
+ is_json_cmvn: true
+
+# Dataset configuration
+dataset: classification
+dataset_conf:
+ # Task names (must match model_conf.tasks)
+ tasks: ['gender', 'emotion', 'dialect', 'age']
+
+ filter_conf:
+ max_length: 40960 # Maximum audio length in samples
+ min_length: 0 # Minimum audio length in samples
+
+ resample_conf:
+ resample_rate: 16000
+
+ fbank_conf:
+ num_mel_bins: 80
+ frame_shift: 10
+ frame_length: 25
+ dither: 1.0
+
+ speed_perturb: true # Speed perturbation
+ spec_aug: true # SpecAugment
+ spec_aug_conf:
+ num_t_mask: 2
+ num_f_mask: 2
+ max_t: 50
+ max_f: 10
+
+ shuffle: true
+ shuffle_conf:
+ shuffle_size: 1000
+
+ sort: false
+ sort_conf:
+ sort_size: 500
+
+ batch_conf:
+ batch_type: 'dynamic' # static or dynamic
+ max_frames_in_batch: 80000
+ batch_size: 4
+ pad_feat: True
+
+# Training configuration
+grad_clip: 5.0
+accum_grad: 1
+max_epoch: 100
+log_interval: 100
+
+# Optimizer
+optim: adamw
+optim_conf:
+ lr: 0.001
+
+# Scheduler
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 5000
diff --git a/examples/classification/conf/single_task.yaml b/examples/classification/conf/single_task.yaml
new file mode 100644
index 0000000..7c0a24f
--- /dev/null
+++ b/examples/classification/conf/single_task.yaml
@@ -0,0 +1,98 @@
+# ChunkFormer Speech Classification Configuration
+# Single-task example: Gender classification
+
+# Encoder configuration
+encoder: chunkformer
+encoder_conf:
+ output_size: 512 # dimension of attention
+ attention_heads: 4
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.1
+ input_layer: dw_striding
+ normalize_before: true
+ cnn_module_kernel: 15
+ use_cnn_module: True
+ activation_type: 'swish'
+ pos_enc_layer_type: 'chunk_rel_pos'
+ selfattention_layer_type: 'chunk_rel_seflattn'
+ cnn_module_norm: 'layer_norm'
+ # Enable dynamic chunking for varied context training
+ dynamic_conv: true
+ dynamic_chunk_sizes: [-1, -1, 64, 128, 256]
+ dynamic_left_context_sizes: [64, 128, 256]
+ dynamic_right_context_sizes: [64, 128, 256]
+
+# Model configuration
+model: classification
+model_conf:
+ # Classification tasks: task_name -> num_classes
+ tasks:
+ gender: 2
+ dropout_rate: 0.1
+ label_smoothing: 0.2
+
+# CMVN (Cepstral Mean and Variance Normalization)
+cmvn: global_cmvn
+cmvn_conf:
+ cmvn_file: 'data/train/global_cmvn'
+ is_json_cmvn: true
+
+# Dataset configuration
+dataset: classification
+dataset_conf:
+ # Task names (must match model_conf.tasks)
+ tasks: ['gender']
+
+ filter_conf:
+ max_length: 40960 # Maximum audio length in samples
+ min_length: 0 # Minimum audio length in samples
+
+ resample_conf:
+ resample_rate: 16000
+
+ fbank_conf:
+ num_mel_bins: 80
+ frame_shift: 10
+ frame_length: 25
+ dither: 1.0
+
+ speed_perturb: true # Speed perturbation
+ spec_aug: true # SpecAugment
+ spec_aug_conf:
+ num_t_mask: 2
+ num_f_mask: 2
+ max_t: 50
+ max_f: 10
+
+ shuffle: true
+ shuffle_conf:
+ shuffle_size: 1000
+
+ sort: false
+ sort_conf:
+ sort_size: 500
+
+ batch_conf:
+ batch_type: 'dynamic' # static or dynamic
+ max_frames_in_batch: 80000
+ batch_size: 4
+ pad_feat: True
+
+# Training configuration
+grad_clip: 5.0
+accum_grad: 1
+max_epoch: 100
+log_interval: 100
+
+# Optimizer
+optim: adamw
+optim_conf:
+ lr: 0.001
+
+# Scheduler
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 5000
diff --git a/examples/classification/path.sh b/examples/classification/path.sh
new file mode 100755
index 0000000..1bc37ac
--- /dev/null
+++ b/examples/classification/path.sh
@@ -0,0 +1,8 @@
+export CHUNKFORMER_DIR=$PWD/../..
+export BUILD_DIR=${CHUNKFORMER_DIR}/runtime/libtorch/build
+export OPENFST_BIN=${BUILD_DIR}/../fc_base/openfst-build/src
+export PATH=$PWD:${BUILD_DIR}/bin:${BUILD_DIR}/kaldi:${OPENFST_BIN}/bin:$PATH
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${CHUNKFORMER_DIR}:$PYTHONPATH
diff --git a/examples/classification/run.sh b/examples/classification/run.sh
new file mode 100755
index 0000000..677b881
--- /dev/null
+++ b/examples/classification/run.sh
@@ -0,0 +1,333 @@
+#!/bin/bash
+
+# Speech Classification Training Pipeline
+# Stages:
+# Stage 0: Data Format Conversion (TSV to list format)
+# Stage 1: Feature Generation (CMVN computation)
+# Stage 2: Label Statistics and Validation
+# Stage 3: Training
+# Stage 4: Evaluation
+# Stage 5: Export Model for Inference
+# Stage 6: Push Model to Hugging Face Hub (optional)
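+#
+# Example usage (variables such as stage/stop_stage below can be overridden on the
+# command line via tools/parse_options.sh):
+#   bash run.sh --stage 0 --stop_stage 4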
+
+. ./path.sh || exit 1;
+
+# GPU Configuration
+export CUDA_VISIBLE_DEVICES="0"
+echo "CUDA_VISIBLE_DEVICES is ${CUDA_VISIBLE_DEVICES}"
+
+# Stage control
+stage=0
+stop_stage=6
+
+# Multi-machine training settings
+HOST_NODE_ADDR="localhost:0"
+num_nodes=1
+job_id=2024
+
+# Data directory
+wave_data=data
+data_type=raw
+
+# Training configuration
+# Choose one of:
+# - conf/single_task.yaml: Single-task gender classification
+# - conf/multi_task.yaml: Multi-task (gender + emotion + dialect + age)
+train_config=conf/multi_task.yaml
+
+# Training settings
+checkpoint=
+num_workers=4
+dir=exp/multi_task
+tensorboard_dir=tensorboard
+
+# Model averaging
+average_checkpoint=true
+decode_checkpoint=$dir/final.pt
+average_num=10
+
+# Hugging Face Hub upload settings (optional)
+# To enable upload, set these variables:
+hf_token="hf_xxxxxxxxxxxxxxxxxxxxxxxxx" # Your Hugging Face token
+hf_repo_id="username/chunkformer-model" # Your repository ID
+
+# Dataset names (folder names under data/)
+train_set=train
+dev_set=dev
+test_set=test
+
+# Training engine
+train_engine=torch_ddp
+
+set -e
+set -u
+set -o pipefail
+
+. tools/parse_options.sh || exit 1;
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "============================================"
+ echo "Stage 0: Data Format Conversion"
+ echo "============================================"
+
+ # Convert data.tsv files to required format for training
+ for dataset in $train_set $dev_set $test_set; do
+ if [ -f "$wave_data/$dataset/data.tsv" ]; then
+ echo "Converting $wave_data/$dataset/data.tsv"
+ python tools/tsv_to_list.py $wave_data/$dataset/data.tsv
+ else
+ echo "Warning: $wave_data/$dataset/data.tsv not found, skipping..."
+ fi
+ done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "============================================"
+ echo "Stage 1: Feature Generation (CMVN)"
+ echo "============================================"
+
+ tools/compute_cmvn_stats.py \
+ --num_workers 16 \
+ --train_config $train_config \
+ --in_scp $wave_data/$train_set/wav.scp \
+ --out_cmvn $wave_data/$train_set/global_cmvn
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "============================================"
+ echo "Stage 2: Label Statistics and Validation"
+ echo "============================================"
+
+ python tools/compute_label_stats.py \
+ --config $train_config \
+ --train_data $wave_data/$train_set/data.list \
+ --dev_data $wave_data/$dev_set/data.list \
+ --test_data $wave_data/$test_set/data.list \
+ --output_dir $wave_data
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "============================================"
+ echo "Stage 3: Model Training"
+ echo "============================================"
+
+ mkdir -p $dir
+ num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+ dist_backend="nccl"
+
+ echo "Number of nodes: $num_nodes, GPUs per node: $num_gpus"
+
+ torchrun --nnodes=$num_nodes --nproc_per_node=$num_gpus \
+ --rdzv_endpoint=$HOST_NODE_ADDR \
+ --rdzv_id=$job_id --rdzv_backend="c10d" \
+ ${CHUNKFORMER_DIR}/chunkformer/bin/train.py \
+ --use_amp \
+ --train_engine ${train_engine} \
+ --config $train_config \
+ --data_type ${data_type} \
+ --train_data $wave_data/$train_set/data.list \
+ --cv_data $wave_data/$dev_set/data.list \
+ ${checkpoint:+--checkpoint $checkpoint} \
+ --model_dir $dir \
+ --tensorboard_dir ${tensorboard_dir} \
+ --ddp.dist_backend $dist_backend \
+ --num_workers ${num_workers} \
+ --pin_memory
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "============================================"
+ echo "Stage 4: Model Evaluation"
+ echo "============================================"
+
+ mkdir -p $dir/test
+
+ # Model averaging
+ if [ ${average_checkpoint} == true ]; then
+ decode_checkpoint=$dir/avg_${average_num}.pt
+ echo "Averaging last $average_num checkpoints -> $decode_checkpoint"
+ python ${CHUNKFORMER_DIR}/chunkformer/bin/average_model.py \
+ --dst_model $decode_checkpoint \
+ --src_path $dir \
+ --num ${average_num}
+ fi
+
+ # Chunking settings (optional)
+ chunk_size=
+ left_context_size=
+ right_context_size=
+
+ # Evaluate on test set
+ for test in $test_set; do
+ result_dir=$dir/${test}
+ mkdir -p $result_dir
+
+ echo "Evaluating on $test set..."
+ python ${CHUNKFORMER_DIR}/chunkformer/bin/classify.py \
+ --gpu 0 \
+ --config $dir/train.yaml \
+ --data_type raw \
+ --test_data $wave_data/$test/data.list \
+ --checkpoint $decode_checkpoint \
+ --batch_size 32 \
+ --result_dir $result_dir \
+ ${chunk_size:+--chunk_size $chunk_size} \
+ ${left_context_size:+--left_context_size $left_context_size} \
+ ${right_context_size:+--right_context_size $right_context_size}
+
+ # Compute metrics
+ python tools/compute_classification_metrics.py \
+ --config $dir/train.yaml \
+ --predictions $result_dir/predictions.tsv \
+ --labels $wave_data/$test/data.list \
+ --output $result_dir/metrics.txt
+
+ echo "Results saved to $result_dir"
+ echo "Metrics:"
+ cat $result_dir/metrics.txt
+ done
+fi
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "============================================"
+ echo "Stage 5: Export Model for Inference"
+ echo "============================================"
+
+ # Recreate average checkpoint if needed
+ if [ ${average_checkpoint} == true ]; then
+ decode_checkpoint=$dir/avg_${average_num}.pt
+ if [ ! -f "$decode_checkpoint" ]; then
+ echo "Creating averaged checkpoint..."
+ python ${CHUNKFORMER_DIR}/chunkformer/bin/average_model.py \
+ --dst_model $decode_checkpoint \
+ --src_path $dir \
+ --num ${average_num}
+ fi
+ checkpoint_name="avg_${average_num}"
+ else
+ decode_checkpoint=$dir/final.pt
+ checkpoint_name="final"
+ fi
+
+ # Create inference model directory
+ inference_model_dir=$dir/model_checkpoint_${checkpoint_name}
+ mkdir -p $inference_model_dir
+
+ echo "Creating inference model directory: $inference_model_dir"
+
+ # Copy model checkpoint
+ if [ -f "$decode_checkpoint" ]; then
+ cp $decode_checkpoint $inference_model_dir/pytorch_model.pt
+ echo "✓ Copied model checkpoint"
+ else
+ echo "✗ Warning: Model checkpoint not found at $decode_checkpoint"
+ fi
+
+ # Copy training configuration
+ if [ -f "$dir/train.yaml" ]; then
+ cp $dir/train.yaml $inference_model_dir/config.yaml
+ echo "✓ Copied training config"
+ else
+ echo "✗ Warning: Training config not found"
+ fi
+
+ # Copy CMVN statistics
+ if [ -f "$wave_data/$train_set/global_cmvn" ]; then
+ cp $wave_data/$train_set/global_cmvn $inference_model_dir/global_cmvn
+ echo "✓ Copied CMVN stats"
+ else
+ echo "✗ Warning: CMVN statistics not found"
+ fi
+
+ # Copy label mapping JSON
+ if [ -f "$wave_data/$train_set/label_mapping.json" ]; then
+ cp $wave_data/$train_set/label_mapping.json $inference_model_dir/label_mapping.json
+ echo "✓ Copied label_mapping.json"
+ else
+ echo "✗ Warning: label_mapping.json not found in $wave_data/$train_set"
+ fi
+
+ echo ""
+ echo "============================================"
+ echo "Model Export Complete!"
+ echo "============================================"
+ echo "Model directory: $inference_model_dir"
+ echo ""
+ echo "Directory contents:"
+ ls -lh $inference_model_dir
+ echo ""
+ echo "You can now use this model for inference:"
+ echo " python chunkformer/bin/classify.py --checkpoint $inference_model_dir ..."
+fi
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ echo "============================================"
+ echo "Stage 6: Push Model to Hugging Face Hub"
+ echo "============================================"
+
+ # Determine inference model directory (in case stage 5 was skipped)
+ if [ ${average_checkpoint} == true ]; then
+ checkpoint_name="avg_${average_num}"
+ else
+ checkpoint_name="final"
+ fi
+ inference_model_dir=$dir/model_checkpoint_${checkpoint_name}
+
+ # Check if Hugging Face token and repo_id are provided
+ if [ -z "$hf_token" ] || [ -z "$hf_repo_id" ]; then
+ echo "Skipping Hugging Face upload: hf_token or hf_repo_id not provided"
+ echo ""
+ echo "To enable upload, set the following variables in this script:"
+ echo " hf_token=\"your_huggingface_token\""
+ echo " hf_repo_id=\"username/repository-name\""
+ echo ""
+ echo "You can also upload manually later using:"
+ echo " cd ../../.." # Go to chunkformer root
+ echo " python tools/push_model_hf.py \\"
+ echo " --model_dir $inference_model_dir \\"
+ echo " --repo_id username/repo-name \\"
+ echo " --token your_token"
+ else
+ echo "Uploading classification model to Hugging Face Hub..."
+ echo "Repository: $hf_repo_id"
+ echo "Model directory: $inference_model_dir"
+
+ # Run the upload script
+ python tools/push_model_hf.py \
+ --model_dir "$inference_model_dir" \
+ --repo_id "$hf_repo_id" \
+ --token "$hf_token" \
+ --commit_message "Upload ChunkFormer Classification Model"
+
+ upload_status=$?
+
+ if [ $upload_status -eq 0 ]; then
+ echo ""
+ echo "🎉 Classification model successfully uploaded to Hugging Face Hub!"
+ echo "Model URL: https://huggingface.co/$hf_repo_id"
+ echo ""
+ echo "You can now load your model from anywhere with:"
+ echo "from chunkformer import ChunkFormerModel"
+ echo "model = ChunkFormerModel.from_pretrained('$hf_repo_id')"
+ echo ""
+ echo "Example usage:"
+ echo "result = model.classify_audio("
+ echo " audio_path='path/to/audio.wav'"
+ echo ")"
+ else
+ echo "❌ Failed to upload model to Hugging Face Hub"
+ echo ""
+ echo "You can try uploading manually with:"
+ echo " cd ../../.."
+ echo " python tools/push_model_hf.py \\"
+ echo " --model_dir $inference_model_dir \\"
+ echo " --repo_id $hf_repo_id \\"
+ echo " --token $hf_token"
+ fi
+ fi
+fi
+
+echo ""
+echo "============================================"
+echo "Training Pipeline Complete!"
+echo "============================================"
diff --git a/pyproject.toml b/pyproject.toml
index dbe9994..90062d9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -115,7 +115,7 @@ filterwarnings = [
[tool.flake8]
max-line-length = 100
-extend-ignore = ["E203", "W503", "B008", "C416", "EXE001", "E741"]
+extend-ignore = ["E203", "W503", "B008", "C416", "EXE001", "E741", "E731"]
exclude = [
".git",
"__pycache__",
diff --git a/tests/test_classification_output.py b/tests/test_classification_output.py
new file mode 100644
index 0000000..c8ea36f
--- /dev/null
+++ b/tests/test_classification_output.py
@@ -0,0 +1,242 @@
+#!/usr/bin/env python3
+"""
+Test classification output format for ChunkFormer classification models.
+"""
+
+from pathlib import Path
+
+import pytest
+
+from chunkformer import ChunkFormerModel
+
+
+class TestClassificationOutput:
+ """Test cases for classification output format validation."""
+
+ @classmethod
+ def setup_class(cls):
+ """Set up test fixtures."""
+ cls.model_name = "khanhld/chunkformer-gender-emotion-dialect-age-classification"
+ cls.model = ChunkFormerModel.from_pretrained(cls.model_name)
+ cls.data_dir = Path(__file__).parent.parent
+
+ # Use sample audio from ASR tests if available
+ cls.sample_audio_dir = cls.data_dir / "samples" / "audios"
+
+ # Find first available audio file
+ cls.test_audio_path = None
+ if cls.sample_audio_dir.exists():
+ for audio_file in cls.sample_audio_dir.glob("*.wav"):
+ if audio_file.exists():
+ cls.test_audio_path = str(audio_file)
+ break
+
+ # If not found, try samples directory directly
+ if cls.test_audio_path is None:
+ sample_dir_alt = cls.data_dir / "samples"
+ if sample_dir_alt.exists():
+ for audio_file in sample_dir_alt.glob("*.wav"):
+ if audio_file.exists():
+ cls.test_audio_path = str(audio_file)
+ break
+
+ # Assert that we found an audio file
+ assert (
+ cls.test_audio_path is not None
+ ), f"No test audio file found in {cls.sample_audio_dir} or {cls.data_dir / 'samples'}"
+
+ print(f"Using test audio: {cls.test_audio_path}")
+
+ def test_model_is_classification(self):
+ """Test that the model is correctly identified as a classification model."""
+ assert self.model.is_classification, "Model should be identified as classification model"
+
+ def test_model_has_tasks(self):
+ """Test that the model has classification tasks defined."""
+ tasks = self.model.get_tasks()
+ assert tasks is not None, "Model should have tasks defined"
+ assert len(tasks) > 0, "Model should have at least one task"
+
+ # Expected tasks for this model
+ expected_tasks = ["gender", "emotion", "dialect", "age"]
+ for task in expected_tasks:
+ assert task in tasks, f"Task '{task}' should be in model tasks"
+
+ print(f"Model tasks: {list(tasks.keys())}")
+ for task_name, num_classes in tasks.items():
+ print(f" {task_name}: {num_classes} classes")
+
+ def test_model_has_label_mapping(self):
+ """Test that the model has label mapping loaded."""
+ assert self.model.label_mapping is not None, "Model should have label_mapping loaded"
+
+ # Check that label mapping has all tasks
+ tasks = self.model.get_tasks()
+ for task_name in tasks.keys():
+ assert (
+ task_name in self.model.label_mapping
+ ), f"Task '{task_name}' should be in label_mapping"
+
+ # Check that label mapping uses id:label format
+ task_mapping = self.model.label_mapping[task_name]
+ assert isinstance(task_mapping, dict), f"Label mapping for {task_name} should be a dict"
+
+ # Check that keys are string IDs
+ for key in task_mapping.keys():
+ assert isinstance(
+ key, str
+ ), f"Label mapping keys should be strings, got {type(key)}"
+ assert key.isdigit(), f"Label mapping keys should be numeric strings, got '{key}'"
+
+ print("Label mapping loaded successfully")
+ for task_name, mapping in self.model.label_mapping.items():
+ print(f" {task_name}: {len(mapping)} labels")
+
+ def test_classification_output_format(self):
+ """Test that classification output has the correct format."""
+ print(f"Testing with audio: {self.test_audio_path}")
+
+ # Perform classification
+ result = self.model.classify_audio(
+ audio_path=self.test_audio_path,
+ chunk_size=-1, # Full attention
+ left_context_size=-1,
+ right_context_size=-1,
+ )
+
+ # Verify result is a dictionary
+ assert isinstance(result, dict), f"Result should be a dict, got {type(result)}"
+
+ # Verify each task is in the result
+ tasks = self.model.get_tasks()
+ for task_name in tasks.keys():
+ assert task_name in result, f"Task '{task_name}' should be in result"
+
+ task_result = result[task_name]
+
+ # Verify task result structure
+ assert isinstance(
+ task_result, dict
+ ), f"Result for {task_name} should be a dict, got {type(task_result)}"
+
+ # Verify required keys
+ required_keys = ["label", "label_id", "prob"]
+ for key in required_keys:
+ assert key in task_result, f"Key '{key}' should be in result for task '{task_name}'"
+
+ # Verify types
+ assert isinstance(
+ task_result["label"], str
+ ), f"label should be str, got {type(task_result['label'])}"
+ assert isinstance(
+ task_result["label_id"], int
+ ), f"label_id should be int, got {type(task_result['label_id'])}"
+ assert isinstance(
+ task_result["prob"], float
+ ), f"prob should be float, got {type(task_result['prob'])}"
+
+ # Verify probability is in valid range
+ assert (
+ 0.0 <= task_result["prob"] <= 1.0
+ ), f"prob should be in [0, 1], got {task_result['prob']}"
+
+ # Verify label_id matches label_mapping
+ label_id_str = str(task_result["label_id"])
+ assert (
+ label_id_str in self.model.label_mapping[task_name]
+ ), f"label_id {label_id_str} not found in label_mapping for {task_name}"
+
+ expected_label = self.model.label_mapping[task_name][label_id_str]
+ assert (
+ task_result["label"] == expected_label
+ ), f"label mismatch: got '{task_result['label']}', expected '{expected_label}'"
+
+ # Print results
+ print("\nClassification Results:")
+ print("=" * 70)
+ for task_name, task_result in result.items():
+ print(f"{task_name.capitalize()}:")
+ print(f" Label: {task_result['label']}")
+ print(f" Label ID: {task_result['label_id']}")
+ print(f" Probability: {task_result['prob']:.4f}")
+
+ def test_classification_with_chunking(self):
+ """Test classification with chunking enabled."""
+ # Perform classification with chunking
+ result = self.model.classify_audio(
+ audio_path=self.test_audio_path,
+ chunk_size=64,
+ left_context_size=128,
+ right_context_size=128,
+ )
+
+ # Verify output format is still correct
+ assert isinstance(result, dict), "Result should be a dict"
+
+ tasks = self.model.get_tasks()
+ for task_name in tasks.keys():
+ assert task_name in result, f"Task '{task_name}' should be in result"
+ task_result = result[task_name]
+
+ # Verify structure
+ assert "label" in task_result
+ assert "label_id" in task_result
+ assert "prob" in task_result
+
+ # Verify types and ranges
+ assert isinstance(task_result["label"], str)
+ assert isinstance(task_result["label_id"], int)
+ assert isinstance(task_result["prob"], float)
+ assert 0.0 <= task_result["prob"] <= 1.0
+
+ print("\nClassification with chunking completed successfully")
+
+ def test_probabilities_always_returned(self):
+ """Test that probabilities are always returned (no parameter needed)."""
+ # Call classify_audio without any probability parameter
+ result = self.model.classify_audio(
+ audio_path=self.test_audio_path,
+ chunk_size=-1,
+ )
+
+ # Verify probabilities are present
+ for task_name, task_result in result.items():
+ assert (
+ "prob" in task_result
+ ), f"Probability should always be present for task '{task_name}'"
+ assert (
+ task_result["prob"] > 0.0
+ ), f"Probability should be > 0 for predicted class in task '{task_name}'"
+
+ print("\nProbabilities are always returned by default")
+
+ def test_multiple_audio_consistency(self):
+ """Test that classification is consistent across multiple calls."""
+ # Classify the same audio twice
+ result1 = self.model.classify_audio(
+ audio_path=self.test_audio_path,
+ chunk_size=-1,
+ )
+
+ result2 = self.model.classify_audio(
+ audio_path=self.test_audio_path,
+ chunk_size=-1,
+ )
+
+ # Results should be identical
+ assert result1.keys() == result2.keys(), "Tasks should be identical"
+
+ for task_name in result1.keys():
+ assert (
+ result1[task_name]["label_id"] == result2[task_name]["label_id"]
+ ), f"Predicted label_id should be consistent for task '{task_name}'"
+
+ # Probabilities should be very close (allowing for minor floating point differences)
+ prob_diff = abs(result1[task_name]["prob"] - result2[task_name]["prob"])
+ assert prob_diff < 1e-6, f"Probabilities should be consistent for task '{task_name}'"
+
+ print("\nClassification is consistent across multiple calls")
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v", "-s"])
diff --git a/tests/test_wer_ctc_performance.py b/tests/test_wer_ctc_performance.py
index 4b0a9bf..48ebfd1 100644
--- a/tests/test_wer_ctc_performance.py
+++ b/tests/test_wer_ctc_performance.py
@@ -328,7 +328,7 @@ def test_command_line_long_form_audio(self):
"chunkformer-decode",
"--model_checkpoint",
self.model_name,
- "--long_form_audio",
+ "--audio_file",
test_audio_path,
"--total_batch_duration",
"14400",
diff --git a/tests/test_wer_rnnt_performance.py b/tests/test_wer_rnnt_performance.py
index 7b443a1..e4ce228 100644
--- a/tests/test_wer_rnnt_performance.py
+++ b/tests/test_wer_rnnt_performance.py
@@ -328,7 +328,7 @@ def test_command_line_long_form_audio(self):
"chunkformer-decode",
"--model_checkpoint",
self.model_name,
- "--long_form_audio",
+ "--audio_file",
test_audio_path,
"--total_batch_duration",
"14400",
diff --git a/tools/compute_classification_metrics.py b/tools/compute_classification_metrics.py
new file mode 100755
index 0000000..3cc8c98
--- /dev/null
+++ b/tools/compute_classification_metrics.py
@@ -0,0 +1,192 @@
+#!/usr/bin/env python3
+# Copyright (c) 2024 ChunkFormer Authors
+
+"""
+Compute classification metrics (accuracy, precision, recall, F1).
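+
+Input formats (as read by load_ground_truth/load_predictions below):
+- --labels: a data.list file with one JSON object per line; each object carries a
+  "<task>_label" field per task (labels < 0 are skipped as invalid).
+- --predictions: a TSV file with a header line, then one row per utterance of the
+  form "key<TAB>pred_1<TAB>...", one predicted label id per task in config order.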
+"""
+
+import argparse
+import json
+
+import numpy as np
+import yaml
+from sklearn.metrics import classification_report, confusion_matrix
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="Compute classification metrics")
+ parser.add_argument("--config", required=True, help="Training config YAML")
+ parser.add_argument("--predictions", required=True, help="Predictions TSV file")
+ parser.add_argument("--labels", required=True, help="Ground truth labels (data.list)")
+ parser.add_argument("--output", required=True, help="Output metrics file")
+ return parser.parse_args()
+
+
+def load_ground_truth(labels_file, tasks):
+ """Load ground truth labels."""
+ gt = {}
+ with open(labels_file, "r", encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+
+ try:
+ data = json.loads(line)
+ except json.JSONDecodeError:
+ continue
+
+ key = data["key"]
+ gt[key] = {}
+ for task in tasks:
+ label_key = f"{task}_label"
+ if label_key in data:
+ label = data[label_key]
+ if isinstance(label, str):
+ label = int(label)
+ gt[key][task] = label
+
+ return gt
+
+
+def load_predictions(pred_file, tasks):
+ """Load predictions."""
+ predictions = {}
+ with open(pred_file, "r", encoding="utf-8") as f:
+ # Skip header if exists
+ f.readline()
+
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+
+ parts = line.split("\t")
+ if len(parts) < 1 + len(tasks):
+ continue
+
+ key = parts[0]
+ predictions[key] = {}
+
+ for i, task in enumerate(tasks):
+ pred = int(parts[i + 1])
+ predictions[key][task] = pred
+
+ return predictions
+
+
+def compute_metrics(gt, predictions, tasks):
+ """Compute metrics for each task."""
+ results = {}
+
+ for task in tasks:
+ y_true = []
+ y_pred = []
+
+ # Collect predictions and ground truth
+ for key in gt:
+ if key not in predictions:
+ continue
+
+ if task not in gt[key] or task not in predictions[key]:
+ continue
+
+ gt_label = gt[key][task]
+ pred_label = predictions[key][task]
+
+ if gt_label < 0: # Skip invalid labels
+ continue
+
+ y_true.append(gt_label)
+ y_pred.append(pred_label)
+
+ if not y_true:
+ print(f"Warning: No valid samples for task {task}")
+ continue
+
+ # Classification report (includes all metrics)
+ class_report = classification_report(y_true, y_pred, zero_division=0)
+
+ # Confusion matrix
+ cm = confusion_matrix(y_true, y_pred)
+
+ results[task] = {
+ "classification_report": class_report,
+ "confusion_matrix": cm.tolist(),
+ "num_samples": len(y_true),
+ }
+
+ return results
+
+
+def print_and_save_results(results, tasks, output_file):
+ """Print and save results."""
+ lines = []
+
+ lines.append("=" * 70)
+ lines.append("Classification Metrics")
+ lines.append("=" * 70)
+
+ for task in tasks:
+ if task not in results:
+ continue
+
+ r = results[task]
+ lines.append(f"\n{task.upper()}:")
+ lines.append(f" Samples: {r['num_samples']}")
+
+ lines.append("\n Classification Report:")
+ # Indent the classification report
+ report_lines = r["classification_report"].strip().split("\n")
+ for report_line in report_lines:
+ lines.append(f" {report_line}")
+
+ lines.append("\n Confusion Matrix:")
+ cm = np.array(r["confusion_matrix"])
+ for i, row in enumerate(cm):
+ lines.append(f" {i}: {row.tolist()}")
+
+ lines.append("\n" + "=" * 70)
+
+ # Print to console
+ for line in lines:
+ print(line)
+
+ # Save to file
+ with open(output_file, "w", encoding="utf-8") as f:
+ f.write("\n".join(lines))
+
+ print(f"\nMetrics saved to: {output_file}")
+
+
+def main():
+ args = get_args()
+
+ # Load config
+ with open(args.config, "r") as f:
+ config = yaml.load(f, Loader=yaml.FullLoader)
+
+ tasks = list(config.get("model_conf", {}).get("tasks", {}).keys())
+ if not tasks:
+ print("Error: No tasks defined in config")
+ return
+
+ print(f"Tasks: {tasks}")
+
+ # Load data
+ print(f"Loading ground truth from: {args.labels}")
+ gt = load_ground_truth(args.labels, tasks)
+
+ print(f"Loading predictions from: {args.predictions}")
+ predictions = load_predictions(args.predictions, tasks)
+
+ # Compute metrics
+ print("\nComputing metrics...")
+ results = compute_metrics(gt, predictions, tasks)
+
+ # Print and save
+ print_and_save_results(results, tasks, args.output)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/compute_label_stats.py b/tools/compute_label_stats.py
new file mode 100755
index 0000000..c279fb6
--- /dev/null
+++ b/tools/compute_label_stats.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+# Copyright (c) 2024 ChunkFormer Authors
+
+"""
+Compute label statistics for classification tasks.
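+
+Each data list line is expected to be a JSON object with a "<task>_label" field per
+task (string or integer); negative values are treated as invalid and excluded.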
+"""
+
+import argparse
+import json
+import sys
+from collections import Counter
+
+import yaml
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="Compute label statistics")
+ parser.add_argument("--config", required=True, help="Training config")
+ parser.add_argument("--train_data", required=True, help="Training data list")
+ parser.add_argument("--dev_data", help="Development data list")
+ parser.add_argument("--test_data", help="Test data list")
+ parser.add_argument("--output_dir", required=True, help="Output directory")
+ return parser.parse_args()
+
+
+def load_data(data_file, tasks):
+ """Load data and collect label statistics."""
+ label_counts = {task: Counter() for task in tasks}
+ total_samples = 0
+
+ with open(data_file, "r", encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+
+ try:
+ data = json.loads(line)
+ except json.JSONDecodeError:
+ continue
+
+ total_samples += 1
+
+ for task in tasks:
+ label_key = f"{task}_label"
+ if label_key in data:
+ label = data[label_key]
+ if isinstance(label, str):
+ label = int(label)
+ if label >= 0: # Valid label
+ label_counts[task][label] += 1
+
+ return label_counts, total_samples
+
+
+def print_statistics(name, label_counts, total_samples, tasks):
+ """Print label statistics."""
+ print(f"\n{'='*60}")
+ print(f"{name} Statistics")
+ print(f"{'='*60}")
+ print(f"Total samples: {total_samples}")
+
+ for task in tasks:
+ print(f"\n{task.upper()}:")
+ counts = label_counts[task]
+
+ if not counts:
+ print(" No valid labels found")
+ continue
+
+ total_labeled = sum(counts.values())
+ print(f" Labeled samples: {total_labeled}")
+ print(" Label distribution:")
+
+ for label in sorted(counts.keys()):
+ count = counts[label]
+ percentage = (count / total_labeled) * 100
+ print(f" Label {label}: {count:6d} ({percentage:5.2f}%)")
+
+
+def save_label_mappings(output_dir, label_counts, tasks):
+ """Save label mappings to files."""
+ for task in tasks:
+ counts = label_counts[task]
+ if not counts:
+ continue
+
+ mapping_file = f"{output_dir}/{task}_labels.txt"
+ with open(mapping_file, "w", encoding="utf-8") as f:
+ for label in sorted(counts.keys()):
+ # Format: label_id label_name (placeholder, user should edit)
+ f.write(f"{label} class_{label}\n")
+
+ print(f"\nCreated label mapping: {mapping_file}")
+ print(" (Please edit this file to add meaningful class names)")
+
+
+def main():
+ args = get_args()
+
+ # Load config
+ with open(args.config, "r") as f:
+ config = yaml.load(f, Loader=yaml.FullLoader)
+
+ tasks = config.get("dataset_conf", {}).get("tasks", [])
+ if not tasks:
+ print("Error: No tasks defined in config")
+ sys.exit(1)
+
+ # Process training data
+ print(f"Processing training data: {args.train_data}")
+ train_counts, train_total = load_data(args.train_data, tasks)
+ print_statistics("Training", train_counts, train_total, tasks)
+
+ # Process dev data
+ if args.dev_data:
+ print(f"\nProcessing dev data: {args.dev_data}")
+ dev_counts, dev_total = load_data(args.dev_data, tasks)
+ print_statistics("Development", dev_counts, dev_total, tasks)
+
+ # Process test data
+ if args.test_data:
+ print(f"\nProcessing test data: {args.test_data}")
+ test_counts, test_total = load_data(args.test_data, tasks)
+ print_statistics("Test", test_counts, test_total, tasks)
+
+ # Save label mappings
+ save_label_mappings(args.output_dir, train_counts, tasks)
+
+ # Check for label consistency
+ print(f"\n{'='*60}")
+ print("Label Consistency Check")
+ print(f"{'='*60}")
+
+ for task in tasks:
+ train_labels = set(train_counts[task].keys())
+
+ if args.dev_data:
+ dev_labels = set(dev_counts[task].keys())
+ dev_only = dev_labels - train_labels
+ if dev_only:
+ print(f"\nWarning ({task}): Dev set has labels not in train: {dev_only}")
+
+ if args.test_data:
+ test_labels = set(test_counts[task].keys())
+ test_only = test_labels - train_labels
+ if test_only:
+ print(f"\nWarning ({task}): Test set has labels not in train: {test_only}")
+
+ print("\nStatistics computation complete!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/convert_text_labels_to_int.py b/tools/convert_text_labels_to_int.py
new file mode 100755
index 0000000..04912fc
--- /dev/null
+++ b/tools/convert_text_labels_to_int.py
@@ -0,0 +1,263 @@
+#!/usr/bin/env python3
+# Copyright (c) 2024 ChunkFormer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Convert text classification labels to integer labels.
+
+This script converts classification labels from text format to integer format
+and automatically creates a label_mapping.json file.
+
+Example:
+ Input TSV:
+ key wav gender_label emotion_label
+ utt1 a.wav male happy
+ utt2 b.wav female sad
+
+ Output TSV (always data.tsv):
+ key wav gender_label emotion_label
+        utt1 a.wav 1 0
+        utt2 b.wav 0 1
+
+ Note: If input file is named "data.tsv", it will be renamed to "data_original.tsv"
+
+ Label mappings (label_mapping.json, automatically created):
+ {
+ "gender": {
+ "0": "male",
+ "1": "female"
+ },
+ "emotion": {
+ "0": "happy",
+ "1": "sad"
+ }
+ }
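+
+Example invocation (hypothetical paths):
+    python tools/convert_text_labels_to_int.py -i data/train/data.tsv -t gender emotion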
+"""
+
+import argparse
+import json
+import os
+from collections import defaultdict
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Convert text classification labels to integer labels"
+ )
+ parser.add_argument("--input", "-i", required=True, help="Input TSV file with text labels")
+ parser.add_argument(
+ "--tasks",
+ "-t",
+ nargs="+",
+ help="List of task names (e.g., gender emotion region). "
+ "If not specified, will auto-detect from column names ending with '_label'",
+ )
+ return parser.parse_args()
+
+
+def read_tsv(input_file):
+ """Read TSV file and return header and rows."""
+ with open(input_file, "r", encoding="utf-8") as f:
+ lines = f.readlines()
+
+ if not lines:
+ raise ValueError(f"Empty input file: {input_file}")
+
+ # Parse header
+ header = lines[0].strip().split("\t")
+
+ # Parse data rows
+ rows = []
+ for line in lines[1:]:
+ line = line.strip()
+ if not line:
+ continue
+ parts = line.split("\t")
+ if len(parts) != len(header):
+ print(f"Warning: Skipping malformed line: {line}")
+ continue
+ rows.append(dict(zip(header, parts)))
+
+ return header, rows
+
+
+def detect_tasks(header):
+ """Detect task names from header columns ending with '_label'."""
+ tasks = []
+ for col in header:
+ if col.endswith("_label"):
+ task_name = col[:-6] # Remove '_label' suffix
+ tasks.append(task_name)
+ return tasks
+
+
+def load_label_mapping(label_mapping_file):
+ """Load existing label mapping from JSON file."""
+ if not os.path.exists(label_mapping_file):
+ return None
+
+ with open(label_mapping_file, "r", encoding="utf-8") as f:
+ return json.load(f)
+
+
+def save_label_mapping(label_mapping_file, all_mappings):
+ """Save all task label mappings to a single JSON file."""
+ label_dir = os.path.dirname(label_mapping_file)
+ os.makedirs(label_dir, exist_ok=True)
+
+ with open(label_mapping_file, "w", encoding="utf-8") as f:
+ json.dump(all_mappings, f, indent=2, ensure_ascii=False)
+
+
+def create_label_mapping(rows, task):
+ """Create label mapping from data by collecting all unique labels."""
+ label_key = f"{task}_label"
+ unique_labels = set()
+
+ for row in rows:
+ if label_key in row:
+ label = row[label_key].strip()
+ if label: # Skip empty labels
+ unique_labels.add(label)
+
+ # Sort labels alphabetically for consistency
+ sorted_labels = sorted(unique_labels)
+
+    # Create mapping as {id: label}
+ mapping = {str(idx): label for idx, label in enumerate(sorted_labels)}
+
+ return mapping
+
+
+def convert_labels(rows, tasks, label_mappings):
+ """Convert text labels to integer labels."""
+ converted_rows = []
+ missing_labels = defaultdict(set)
+
+ for row in rows:
+ new_row = row.copy()
+ for task in tasks:
+ label_key = f"{task}_label"
+ if label_key not in row:
+ print(f"Warning: Missing {label_key} in row with key {row.get('key', 'unknown')}")
+ continue
+
+ text_label = row[label_key].strip()
+ if not text_label:
+ print(f"Warning: Empty {label_key} in row with key {row.get('key', 'unknown')}")
+ continue
+
+ # Now label_mappings[task] is {id: label}, so we need to reverse lookup
+ # Create reverse mapping: label -> id
+ reverse_mapping = {
+ label: int(label_id) for label_id, label in label_mappings[task].items()
+ }
+
+ if text_label not in reverse_mapping:
+ missing_labels[task].add(text_label)
+ continue
+
+ new_row[label_key] = str(reverse_mapping[text_label])
+
+ converted_rows.append(new_row)
+
+ # Report missing labels
+ if missing_labels:
+ print("\nWarning: Some text labels not found in mappings:")
+ for task, labels in missing_labels.items():
+ print(f" Task '{task}': {', '.join(sorted(labels))}")
+
+ return converted_rows
+
+
+def write_tsv(output_file, header, rows):
+ """Write TSV file."""
+ output_dir = os.path.dirname(output_file)
+ os.makedirs(output_dir, exist_ok=True)
+
+ with open(output_file, "w", encoding="utf-8") as f:
+ # Write header
+ f.write("\t".join(header) + "\n")
+
+ # Write data rows
+ for row in rows:
+ values = [row.get(col, "") for col in header]
+ f.write("\t".join(values) + "\n")
+
+
+def main():
+ args = parse_args()
+
+ # Generate output filenames based on input
+ input_dir = os.path.dirname(args.input) or "."
+ input_basename = os.path.basename(args.input)
+
+ # Output file is always named "data.tsv"
+ output_file = os.path.join(input_dir, "data.tsv")
+ label_mapping_file = os.path.join(input_dir, "label_mapping.json")
+
+ # If input is already "data.tsv", rename it to "data_original.tsv"
+ if input_basename == "data.tsv":
+ original_input_file = os.path.join(input_dir, "data_original.tsv")
+ print(f"Input file is 'data.tsv', renaming to: {original_input_file}")
+ os.rename(args.input, original_input_file)
+ args.input = original_input_file
+
+ print(f"Reading input file: {args.input}")
+ header, rows = read_tsv(args.input)
+ print(f" Found {len(rows)} rows")
+
+ # Detect or use specified tasks
+ if args.tasks:
+ tasks = args.tasks
+ else:
+ tasks = detect_tasks(header)
+
+ if not tasks:
+ raise ValueError(
+ "No tasks found. Please specify --tasks or ensure columns end with '_label'"
+ )
+
+ print(f"Tasks to process: {', '.join(tasks)}")
+
+ # Create label mappings from data
+ all_label_mappings = {}
+
+ print("\nCreating label mappings for all tasks...")
+ for task in tasks:
+ print(f" Task '{task}'...")
+ all_label_mappings[task] = create_label_mapping(rows, task)
+ labels_list = [
+ f"{label_id}: {label}" for label_id, label in all_label_mappings[task].items()
+ ]
+ print(f" Found {len(all_label_mappings[task])} unique labels: {labels_list}")
+
+ save_label_mapping(label_mapping_file, all_label_mappings)
+ print(f"\nSaved all label mappings to: {label_mapping_file}")
+
+ # Convert labels
+ print("\nConverting labels...")
+ converted_rows = convert_labels(rows, tasks, all_label_mappings)
+
+ # Write output (always named "data.tsv")
+ print(f"Writing output file: {output_file}")
+ write_tsv(output_file, header, converted_rows)
+ print(f" Wrote {len(converted_rows)} rows")
+
+ print("\nDone!")
+ print(f" Output TSV: {output_file}")
+ print(f" Label mapping: {label_mapping_file}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/push_model_hf.py b/tools/push_model_hf.py
index f53a3b9..182407c 100755
--- a/tools/push_model_hf.py
+++ b/tools/push_model_hf.py
@@ -10,6 +10,7 @@
import sys
from typing import Optional
+import yaml
from huggingface_hub import HfApi, create_repo, upload_folder
from huggingface_hub.utils import RepositoryNotFoundError
@@ -27,18 +28,49 @@ def __init__(self, token: Optional[str] = None):
self.api = HfApi(token=token)
self.token = token
- def create_model_card(self, repo_id: str) -> str:
+ def detect_model_type(self, model_dir: str) -> tuple[str, dict]:
"""
- Create a model card for the ChunkFormer model.
+ Detect whether the model is ASR or Classification based on config.
Args:
model_dir: Directory containing the model files
- repo_id: Repository ID on Hugging Face
- config: ChunkFormer configuration
Returns:
- Model card content as string
+ Tuple of (model_type, tasks_info)
+ - model_type: "asr" or "classification"
+ - tasks_info: Dictionary with task information (for classification)
"""
+ config_path = os.path.join(model_dir, "config.yaml")
+
+ if not os.path.exists(config_path):
+ print(f"Warning: config.yaml not found in {model_dir}, assuming ASR model")
+ return "asr", {}
+
+ try:
+ with open(config_path, "r") as f:
+ config = yaml.load(f, Loader=yaml.FullLoader)
+
+ # Check if it's a classification model
+ model_type_str = config.get("model", "asr_model")
+
+ if "classification" in model_type_str.lower():
+ # Extract task information
+ tasks_info = {}
+ if "model_conf" in config:
+ tasks_conf = config["model_conf"].get("tasks", {})
+ for task_name, num_classes in tasks_conf.items():
+ tasks_info[task_name] = num_classes
+
+ return "classification", tasks_info
+ else:
+ return "asr", {}
+
+ except Exception as e:
+ print(f"Warning: Error reading config.yaml: {e}, assuming ASR model")
+ return "asr", {}
+
+ def create_asr_model_card(self, repo_id: str) -> str:
+ """Create model card for ASR model."""
model_card = f"""---
tags:
- speech-recognition
@@ -55,7 +87,7 @@ def create_model_card(self, repo_id: str) -> str:
pipeline_tag: automatic-speech-recognition
---
-# ChunkFormer Model
+# ChunkFormer ASR Model
+[](https://github.com/khanld/chunkformer)
+[](https://arxiv.org/abs/2502.14673)
+
+This model performs speech classification tasks such as gender recognition, dialect identification, emotion detection, and age classification.
+{tasks_desc}
+
+## Usage
+
+Install the package:
+
+```bash
+pip install chunkformer
+```
+
+### Single Audio Classification
+
+```python
+from chunkformer import ChunkFormerModel
+
+# Load the model
+model = ChunkFormerModel.from_pretrained("{repo_id}")
+
+# Classify a single audio file
+result = model.classify_audio(
+ audio_path="path/to/your/audio.wav",
+ chunk_size=-1, # -1 for full attention
+ left_context_size=-1,
+ right_context_size=-1
+)
+
+print(result)
+# Output example:
+# {{
+# 'gender': {{
+# 'label': 'female',
+# 'label_id': 0,
+# 'prob': 0.95
+# }},
+# 'dialect': {{
+# 'label': 'northern dialect',
+# 'label_id': 3,
+# 'prob': 0.70
+# }},
+# 'emotion': {{
+# 'label': 'neutral',
+# 'label_id': 5,
+# 'prob': 0.80
+# }}
+# }}
+```
+
+### Command Line Usage
+
+```bash
+chunkformer-decode \\
+ --model_checkpoint {repo_id} \\
+ --audio_file path/to/audio.wav
```
## Training
@@ -124,6 +284,30 @@ def create_model_card(self, repo_id: str) -> str:
""" # noqa: E501
return model_card
+ def create_model_card(self, model_dir: str, repo_id: str) -> str:
+ """
+ Create a model card for the ChunkFormer model (ASR or Classification).
+
+ Args:
+ model_dir: Directory containing the model files
+ repo_id: Repository ID on Hugging Face
+
+ Returns:
+ Model card content as string
+ """
+ # Detect model type
+ model_type, tasks_info = self.detect_model_type(model_dir)
+
+ print(f"Detected model type: {model_type}")
+ if tasks_info:
+ print(f"Classification tasks: {tasks_info}")
+
+ # Generate appropriate model card
+ if model_type == "classification":
+ return self.create_classification_model_card(repo_id, tasks_info)
+ else:
+ return self.create_asr_model_card(repo_id)
+
def create_repository(self, repo_id: str, private: bool = False) -> bool:
"""
Create a new repository on Hugging Face Hub.
@@ -170,7 +354,7 @@ def upload_model(
"""
try:
# Create model card
- model_card_content = self.create_model_card(repo_id)
+ model_card_content = self.create_model_card(model_dir, repo_id)
model_card_path = os.path.join(model_dir, "README.md")
with open(model_card_path, "w", encoding="utf-8") as f:
f.write(model_card_content)
diff --git a/tools/split_train_test.py b/tools/split_train_test.py
new file mode 100755
index 0000000..2bc40eb
--- /dev/null
+++ b/tools/split_train_test.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python3
+# Copyright (c) 2024 ChunkFormer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Split TSV file into train and test sets randomly."""
+
+import argparse
+import random
+import sys
+from pathlib import Path
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Split TSV file into train and test sets randomly")
+ parser.add_argument("-i", "--input", required=True, type=str, help="Input TSV file path")
+ parser.add_argument(
+ "-o",
+ "--output-dir",
+ required=True,
+ type=str,
+ help="Output directory for train and test files",
+ )
+ parser.add_argument(
+ "--test-ratio",
+ type=float,
+ default=0.2,
+ help="Ratio of test set (default: 0.2 for 20%%)",
+ )
+ parser.add_argument(
+ "--dev-ratio",
+ type=float,
+ default=0.0,
+ help="Ratio of dev set (default: 0.0, no dev set)",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
+ parser.add_argument("--shuffle", action="store_true", help="Shuffle the data before splitting")
+
+ args = parser.parse_args()
+
+ # Validate ratios
+ if args.test_ratio < 0 or args.test_ratio >= 1:
+ print("Error: test_ratio must be between 0 and 1")
+ sys.exit(1)
+ if args.dev_ratio < 0 or args.dev_ratio >= 1:
+ print("Error: dev_ratio must be between 0 and 1")
+ sys.exit(1)
+ if args.test_ratio + args.dev_ratio >= 1:
+ print("Error: test_ratio + dev_ratio must be less than 1")
+ sys.exit(1)
+
+ # Set random seed
+ random.seed(args.seed)
+
+ # Read input file
+ input_path = Path(args.input)
+ if not input_path.exists():
+ print(f"Error: Input file {args.input} does not exist")
+ sys.exit(1)
+
+ print(f"Reading data from: {args.input}")
+ with open(input_path, "r", encoding="utf-8") as f:
+ lines = f.readlines()
+
+ if len(lines) < 2:
+ print("Error: Input file must have at least a header and one data line")
+ sys.exit(1)
+
+ # Separate header and data
+ header = lines[0]
+ data_lines = lines[1:]
+
+ print(f"Total samples: {len(data_lines)}")
+
+ # Shuffle if requested
+ if args.shuffle:
+ print("Shuffling data...")
+ random.shuffle(data_lines)
+
+ # Calculate split sizes
+ total_samples = len(data_lines)
+ test_size = int(total_samples * args.test_ratio)
+ dev_size = int(total_samples * args.dev_ratio)
+ train_size = total_samples - test_size - dev_size
+
+ print(f"Train samples: {train_size} ({train_size/total_samples*100:.1f}%)")
+ if dev_size > 0:
+ print(f"Dev samples: {dev_size} ({dev_size/total_samples*100:.1f}%)")
+ print(f"Test samples: {test_size} ({test_size/total_samples*100:.1f}%)")
+
+ # Split the data
+ train_lines = data_lines[:train_size]
+ if dev_size > 0:
+ dev_lines = data_lines[train_size : train_size + dev_size]
+ test_lines = data_lines[train_size + dev_size :]
+ else:
+ dev_lines = []
+ test_lines = data_lines[train_size:]
+
+ # Create output directory structure
+ output_dir = Path(args.output_dir)
+ train_dir = output_dir / "train"
+ test_dir = output_dir / "test"
+ train_dir.mkdir(parents=True, exist_ok=True)
+ test_dir.mkdir(parents=True, exist_ok=True)
+
+ if dev_size > 0:
+ dev_dir = output_dir / "dev"
+ dev_dir.mkdir(parents=True, exist_ok=True)
+
+ # Write train file
+ train_output = train_dir / "data.tsv"
+ print(f"Writing train data to: {train_output}")
+ with open(train_output, "w", encoding="utf-8") as f:
+ f.write(header)
+ f.writelines(train_lines)
+
+ # Write dev file if needed
+ if dev_size > 0:
+ dev_output = dev_dir / "data.tsv"
+ print(f"Writing dev data to: {dev_output}")
+ with open(dev_output, "w", encoding="utf-8") as f:
+ f.write(header)
+ f.writelines(dev_lines)
+
+ # Write test file
+ test_output = test_dir / "data.tsv"
+ print(f"Writing test data to: {test_output}")
+ with open(test_output, "w", encoding="utf-8") as f:
+ f.write(header)
+ f.writelines(test_lines)
+
+ print("Split completed successfully!")
+
+ # Print statistics
+ print("\n=== Summary ===")
+ print(f"Input: {args.input}")
+ print(f"Output directory: {args.output_dir}")
+ print(f"Train: {train_output} ({len(train_lines)} samples)")
+ if dev_size > 0:
+ print(f"Dev: {dev_output} ({len(dev_lines)} samples)")
+ print(f"Test: {test_output} ({len(test_lines)} samples)")
+ print(f"Random seed: {args.seed}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/tsv_to_list.py b/tools/tsv_to_list.py
index ae7ace8..2564eb4 100644
--- a/tools/tsv_to_list.py
+++ b/tools/tsv_to_list.py
@@ -1,3 +1,17 @@
+#!/usr/bin/env python3
+# Copyright (c) 2024 ChunkFormer Authors
+
+"""
+Convert TSV format to list format.
+
+Supports both ASR and classification tasks:
+- ASR: key, wav, txt columns
+- Classification: key, wav, and any number of label columns (e.g., gender_label, emotion_label)
+
+Usage:
+    python tsv_to_list.py <input_tsv>
+"""
+
import json
import os
import sys
@@ -14,33 +28,51 @@ def main():
base_dir = os.path.dirname(input_file)
base_name = os.path.splitext(os.path.basename(input_file))[0]
list_file = os.path.join(base_dir, f"{base_name}.list")
- text_file = os.path.join(base_dir, "text")
wav_scp_file = os.path.join(base_dir, "wav.scp")
# Read the .tsv file into a pandas DataFrame
- df = pd.read_csv(input_file, sep="\t")
+ df = pd.read_csv(input_file, sep="\t", comment="#")
df = df.dropna()
- # Generate the "key" and "wav" columns
- df["key"] = df["wav"]
+ print(f"Read {len(df)} samples from {input_file}")
+ print(f"Columns: {list(df.columns)}")
+
+ # Check if this is ASR or classification
+ has_txt = "txt" in df.columns
+ has_key = "key" in df.columns
+ has_wav = "wav" in df.columns
- # Write the .list file
+ if not has_wav:
+ print("Error: 'wav' column is required")
+ sys.exit(1)
+
+ # Generate the "key" column if not present
+ if not has_key:
+ df["key"] = df["wav"]
+ print("Generated 'key' column from 'wav'")
+
+ # Write the .list file (JSON format with all columns)
with open(list_file, "w", encoding="utf-8") as list_out:
for _, row in df.iterrows():
- row_dict = {"key": row["key"], "wav": row["wav"], "txt": row["txt"]}
+ row_dict = row.to_dict()
list_out.write(json.dumps(row_dict, ensure_ascii=False) + "\n")
- # Write the text file (key txt)
- df["txt"] = [str(txt).strip() for txt in df["txt"]]
- with open(text_file, "w", encoding="utf-8") as text_out:
- for _, row in df.iterrows():
- text_out.write(f"{row['key']} {row['txt']}\n")
+ # Write the text file (only for ASR with 'txt' column)
+ if has_txt:
+ text_file = os.path.join(base_dir, "text")
+ df["txt"] = [str(txt).strip() for txt in df["txt"]]
+ with open(text_file, "w", encoding="utf-8") as text_out:
+ for _, row in df.iterrows():
+ text_out.write(f"{row['key']} {row['txt']}\n")
+ print(f"Output written to {list_file}, {text_file}, and {wav_scp_file}")
+ else:
+ print(f"Output written to {list_file} and {wav_scp_file}")
+ print("(No 'txt' column found, skipped text file generation)")
# Write the wav.scp file (key wav)
with open(wav_scp_file, "w", encoding="utf-8") as wav_out:
for _, row in df.iterrows():
wav_out.write(f"{row['key']} {row['wav']}\n")
- print(f"Output written to {list_file}, {text_file}, and {wav_scp_file}")
if __name__ == "__main__":