diff --git a/benchmarks/DASB/music4all/music4all_prepare.py b/benchmarks/DASB/music4all/music4all_prepare.py
new file mode 100644
index 000000000..5a79a923a
--- /dev/null
+++ b/benchmarks/DASB/music4all/music4all_prepare.py
@@ -0,0 +1,233 @@
+"""
+Music4All data preparation.
+Download: https://sites.google.com/view/contact4music4all
+
+Authors
+ * Pooneh Mousavi 2024
+"""
+
+import csv
+import json
+import logging
+import os
+import random
+
+from tqdm import tqdm
+
+from speechbrain.dataio.dataio import read_audio_info
+
+
+logger = logging.getLogger(__name__)
+
+METADATA_CSV = "id_metadata.csv"
+WAVS = "audios"
+DURATIONS = "durations"
+FROZEN_SPLIT = "frozen_split.json"
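+
+# Expected layout of `data_folder` (names follow the constants above; the tree
+# below is illustrative):
+#
+#   data_folder/
+#       id_metadata.csv   # tab-separated metadata with an `id` column
+#       audios/
+#           <track_id>.mp3
+#           ...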
+
+
+def prepare_music4all(
+    data_folder,
+    save_folder,
+    splits=["train", "valid", "test"],
+    split_ratio=[80, 10, 10],
+    seed=1234,
+    skip_prep=False,
+    frozen_split_path=None,
+    device="cpu",
+):
+    """
+    Prepares the JSON manifest files for the Music4All dataset.
+
+    Arguments
+    ---------
+    data_folder : str
+        Path to the folder where the original Music4All dataset is stored
+    save_folder : str
+        The directory where to store the JSON manifest files
+    splits : list
+        List of dataset splits to prepare
+    split_ratio : list
+        Proportion for dataset splits
+    seed : int
+        Random seed
+    skip_prep : bool
+        If True, skip preparation
+    frozen_split_path : str | path-like
+        The path to the frozen split file (used to standardize multiple
+        experiments)
+    device : str
+        Device to be used for computation (if required)
+
+    Returns
+    -------
+    None
+
+    Example
+    -------
+    >>> data_folder = 'data/music4all/'
+    >>> save_folder = 'save/'
+    >>> splits = ['train', 'valid', 'test']
+    >>> split_ratio = [80, 10, 10]
+    >>> seed = 1234
+    >>> prepare_music4all(data_folder, save_folder, splits, split_ratio, seed)
+    """
+    # Set seed for reproducible splits
+    random.seed(seed)
+
+    if skip_prep:
+        return
+
+    # Check if this phase is already done (if so, skip it)
+    if skip(splits, save_folder):
+        logger.info("Skipping preparation, completed in previous run.")
+        return
+
+    if not os.path.exists(save_folder):
+        os.makedirs(save_folder)
+
+    # Set input/output paths
+    meta_csv = os.path.join(data_folder, METADATA_CSV)
+    wavs_folder = os.path.join(data_folder, WAVS)
+    if frozen_split_path is None:
+        frozen_split_path = os.path.join(save_folder, FROZEN_SPLIT)
+
+    # Additional checks to make sure the metadata file and audio folder exist
+    assert os.path.exists(meta_csv), f"{METADATA_CSV} does not exist"
+    assert os.path.exists(wavs_folder), f"{WAVS}/ folder does not exist"
+
+    # Prepare data splits
+    logger.info("Creating JSON manifest files for the Music4All dataset...")
+
+    # Get the splits
+    splits_data = split_sets(meta_csv, splits, split_ratio, frozen_split_path)
+
+    # Create a JSON manifest file for each split
+    for split in splits_data:
+        logger.info(f"Start processing {split} data.")
+        save_json_path = os.path.join(save_folder, f"{split}.json")  # Dynamic filename
+        create_csv(splits_data[split], wavs_folder, save_json_path)
+        logger.info(f"Saved {split} data to {save_json_path}")
+
+
+def skip(splits, save_folder):
+    """
+    Detects if the Music4All data preparation has already been done.
+    If the preparation has been done, we can skip it.
+
+    Returns
+    -------
+    bool
+        if True, the preparation phase can be skipped.
+        if False, it must be done.
+    """
+    # Check that the JSON manifest of every split exists
+    skip = True
+    for split in splits:
+        if not os.path.isfile(os.path.join(save_folder, f"{split}.json")):
+            skip = False
+    return skip
+
+
+# Function to split IDs based on ratios
+def split_sets(meta_file, splits, split_ratio, frozen_split_path):
+    """
+    Splits the track IDs into train, valid, and test sets based on the given ratios.
+    If a frozen split file already exists, it is loaded and reused.
+
+    Parameters:
+        meta_file (str): Path to the metadata file (`id_metadata.csv`).
+        splits (list): List of split names, e.g., ["train", "valid", "test"].
+        split_ratio (list): Percentages for the train, valid, and test splits.
+        frozen_split_path (str): Path to save/load the frozen splits JSON file.
+
+    Returns:
+        dict: A dictionary mapping each split name to its list of track IDs.
+    """
+    # Check if frozen splits already exist
+    if os.path.exists(frozen_split_path):
+        logger.info(f"Loading frozen splits from {frozen_split_path}")
+        with open(frozen_split_path, "r") as f:
+            splits_data = json.load(f)
+        return splits_data
+
+    # Extract all IDs from the tab-separated metadata file
+    with open(meta_file, "r") as f:
+        reader = csv.DictReader(f, delimiter="\t")
+        all_ids = [row["id"] for row in reader]
+
+    # Shuffle the IDs and calculate split indices
+    random.shuffle(all_ids)
+    total = len(all_ids)
+    train_end = int(total * (split_ratio[0] / 100))
+    valid_end = train_end + int(total * (split_ratio[1] / 100))
+
+    # Generate new splits
+    splits_data = {
+        splits[0]: all_ids[:train_end],  # Train
+        splits[1]: all_ids[train_end:valid_end],  # Valid
+        splits[2]: all_ids[valid_end:],  # Test
+    }
+
+    # Save the splits to frozen_split_path
+    logger.info(f"Saving splits to {frozen_split_path}")
+    with open(frozen_split_path, "w") as f:
+        json.dump(splits_data, f, indent=4)
+
+    return splits_data
+
+
+# Function to create the JSON manifest files
+def create_csv(ids, audio_folder, output_json):
+    """
+    Creates a JSON manifest with `uttid`, `wav` (audio path), and `duration`
+    for the given IDs.
+
+    Parameters:
+        ids (list): List of IDs to include in the manifest.
+        audio_folder (str): Folder containing the audio files.
+        output_json (str): Path to the output JSON file.
+    """
+    json_dict = {}
+    for file_id in tqdm(ids, desc="Processing split", unit="file_id"):
+        audio_path = os.path.join(audio_folder, f"{file_id}.mp3")
+        if os.path.exists(audio_path):
+            try:
+                # Get the duration of the audio file
+                info = read_audio_info(audio_path)
+                duration = info.num_frames / info.sample_rate
+                json_dict[file_id] = {
+                    "uttid": file_id,
+                    "wav": audio_path,
+                    "duration": duration,
+                }
+            except Exception as e:
+                logger.warning(f"Error processing file {audio_path}: {e}")
+                continue
+
+    # Write the dictionary to the JSON file
+    with open(output_json, mode="w") as json_f:
+        json.dump(json_dict, json_f, indent=2)
+
+    logger.info(f"{output_json} successfully created!")
+
+
+# # Example Usage
+# if __name__ == "__main__":
+#     data_folder = '/home/ubuntu/music4all'
+#     save_folder = 'save/'
+#     splits = ["train", "valid", "test"]
+#     split_ratio = [80, 10, 10]  # Train, Valid, Test percentages
+#     seed = 1234
+#     prepare_music4all(data_folder, save_folder, splits, split_ratio, seed)
diff --git a/benchmarks/DASB/music4all/quantization/README.md b/benchmarks/DASB/music4all/quantization/README.md
new file mode 100644
index 000000000..5f371ddc0
--- /dev/null
+++ b/benchmarks/DASB/music4all/quantization/README.md
@@ -0,0 +1,65 @@
+# Quantization
+
+This folder contains recipes for training K-means quantizers on the Music4All dataset.
+The quantizer maps self-supervised representations from MERT into discrete representations.
+These discrete representations can then be used as input features for downstream tasks.
+
+You can download the Music4All dataset from https://sites.google.com/view/contact4music4all.
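+
+After data preparation (see `music4all_prepare.py`), each split is described by a JSON
+manifest that maps a track ID to its audio path and duration. The entries look roughly
+like this (the ID, path, and duration below are illustrative):
+
+```json
+{
+    "example_track_id": {
+        "uttid": "example_track_id",
+        "wav": "/path/to/music4all/audios/example_track_id.mp3",
+        "duration": 30.0
+    }
+}
+```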
+
+---------------------------------------------------------------------------------------------------------
+
+## Installing Extra Dependencies
+
+Before proceeding, ensure you have installed the necessary additional dependencies.
+To do so, simply run the following command in your terminal:
+
+```shell
+pip install -r extra_requirements.txt
+```
+
+---------------------------------------------------------------------------------------------------------
+
+## Running an Experiment
+
+```shell
+python train.py hparams/train_discrete_ssl.yaml --data_folder <path-to-music4all> \
+--n_clusters 1000 \
+--layer_id 7 \
+--experiment_name mert_K1000_L7
+```
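+
+The helper script `run_kmean.sh` wraps a similar command and loops over 24 iterations.
+A minimal sketch of a per-layer sweep (the layer range and naming scheme below are
+assumptions, following the example above):
+
+```shell
+for layer in $(seq 0 23); do
+    python train.py hparams/train_discrete_ssl.yaml --data_folder <path-to-music4all> \
+        --n_clusters 1000 --layer_id $layer --experiment_name mert_K1000_L$layer
+done
+```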
+---------------------------------------------------------------------------------------------------------
+
+## About SpeechBrain
+
+- Website: https://speechbrain.github.io/
+- Code: https://github.com/speechbrain/speechbrain/
+- HuggingFace: https://huggingface.co/speechbrain/
+
+---------------------------------------------------------------------------------------------------------
+
+## Citing SpeechBrain
+
+Please, cite SpeechBrain if you use it for your research or business.
+
+```bibtex
+@article{speechbrainV1,
+  author = {Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca {Della Libera} and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Ha Nguyen and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Ga{{\"e}}lle Laperri{{\`e}}re and Mickael Rouvier and Renato De Mori and Yannick Est{{\`e}}ve},
+  title = {Open-Source Conversational {AI} with {SpeechBrain} 1.0},
+  journal = {Journal of Machine Learning Research},
+  year = {2024},
+  volume = {25},
+  number = {333},
+  pages = {1--11},
+  url = {http://jmlr.org/papers/v25/24-0991.html}
+}
+```
+
+```bibtex
+@article{ravanelli2021speechbrain,
+  author = {Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
+  title = {{SpeechBrain}: A General-Purpose Speech Toolkit},
+  journal = {arXiv preprint arXiv:2106.04624},
+  year = {2021},
+  url = {https://arxiv.org/abs/2106.04624},
+}
+```
diff --git a/benchmarks/DASB/music4all/quantization/extra_requirements.txt b/benchmarks/DASB/music4all/quantization/extra_requirements.txt
new file mode 100644
index 000000000..bd6f06aa6
--- /dev/null
+++ b/benchmarks/DASB/music4all/quantization/extra_requirements.txt
@@ -0,0 +1,3 @@
+scikit-learn
+tgt
+unidecode
diff --git a/benchmarks/DASB/music4all/quantization/hparams/train_discrete_ssl.yaml b/benchmarks/DASB/music4all/quantization/hparams/train_discrete_ssl.yaml
new file mode 100644
index 000000000..b66160378
--- /dev/null
+++ b/benchmarks/DASB/music4all/quantization/hparams/train_discrete_ssl.yaml
@@ -0,0 +1,103 @@
+# ###########################################################################################
+# Model: K-means applied to SSL model
+# Authors: Pooneh Mousavi 2024
+# Adapted from: https://github.com/speechbrain/speechbrain/blob/v1.0.2/recipes/LJSpeech/quantization/hparams/train_discrete_ssl.yaml
+# ###########################################################################################
+
+experiment_name: mert_K1000_L7
+
+# Seed needs to be set at top of YAML
+seed: 1986
+__set_seed: !apply:speechbrain.utils.seed_everything [!ref <seed>]
+
+# Data preparation
+data_folder: !PLACEHOLDER
+splits: [train, valid, test]
+split_ratio: [80, 10, 10]
+skip_prep: False
+data_cache_folder: data_cache/
+train_json: !ref <data_cache_folder>/train.json
+valid_json: !ref <data_cache_folder>/valid.json
+test_json: !ref <data_cache_folder>/test.json
+
+# Output folders
+output_folder: !ref results/<experiment_name>/<seed>
+save_folder: !ref <output_folder>/save
+cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
+
+# Preprocessing parameters
+train_remove_if_longer: 60.0  # Seconds
+valid_remove_if_longer: 60.0  # Seconds
+test_remove_if_longer: 60.0  # Seconds
+sorting: random
+
+# Training parameters
+num_epochs: 1
+train_batch_size: 8
+valid_batch_size: 1
+test_batch_size: 1
+dataloader_workers: 4
+nonfinite_patience: 10
+precision: fp32
+ckpt_interval_steps: 4000
+keep_checkpoints: 2
+
+# SSL model parameters
+ssl_hub: m-a-p/MERT-v1-330M
+sample_rate: 24000  # NOTE: must match the SSL model sample rate
+layer_id: 1
+
+# Quantizer parameters (passed to the K-means wrapper below)
+n_clusters: 1000
+init: k-means++
+max_iter: 100
+kmeans_batch_size: 10000  # Should be >= n_clusters
+tol: 0.0
+max_no_improvement: 100
+n_init: 20
+reassignment_ratio: 0.0
+
+# Modules
+ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.mert.MERT
+    source: !ref <ssl_hub>
+    output_norm: False
+    freeze: True
+    freeze_feature_extractor: True
+    output_all_hiddens: True
+    save_path: !ref <cache_folder>
+
+quantizer: !new:speechbrain.lobes.models.kmeans.MiniBatchKMeansSklearn
+    n_clusters: !ref <n_clusters>
+    init: !ref <init>
+    max_iter: !ref <max_iter>
+    batch_size: !ref <kmeans_batch_size>
+    tol: !ref <tol>
+    max_no_improvement: !ref <max_no_improvement>
+    n_init: !ref <n_init>
+    reassignment_ratio: !ref <reassignment_ratio>
+    random_state: !ref <seed>
+    verbose: 1
+    compute_labels: True
+    init_size: null
+
+modules:
+    ssl_model: !ref <ssl_model>
+    quantizer: !ref <quantizer>
+
+# Counters, checkpointers, loggers, etc.
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <num_epochs>
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        quantizer: !ref <quantizer>
+        counter: !ref <epoch_counter>
+    custom_load_hooks:
+        quantizer: !name:speechbrain.lobes.models.kmeans.MiniBatchKMeansSklearn.load
+    custom_save_hooks:
+        quantizer: !name:speechbrain.lobes.models.kmeans.MiniBatchKMeansSklearn.save
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <output_folder>/train_log.txt
+    precision: 3
diff --git a/benchmarks/DASB/music4all/quantization/music4all_prepare.py b/benchmarks/DASB/music4all/quantization/music4all_prepare.py
new file mode 120000
index 000000000..7bebef740
--- /dev/null
+++ b/benchmarks/DASB/music4all/quantization/music4all_prepare.py
@@ -0,0 +1 @@
+../music4all_prepare.py
\ No newline at end of file
diff --git a/benchmarks/DASB/music4all/quantization/run_kmean.sh b/benchmarks/DASB/music4all/quantization/run_kmean.sh
new file mode 100644
index 000000000..52e0b2388
--- /dev/null
+++ b/benchmarks/DASB/music4all/quantization/run_kmean.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+# Ensure the script is executable with: chmod +x run_kmean.sh
+
+# Number of iterations
+iterations=24
+
+# Base command to run the experiment
+base_command="python train.py hparams/train_discrete_ssl.yaml --data_folder /home/ubuntu/music4all --data_cache_folder /home/ubuntu/music4all --n_clusters 1000"
+
+# Loop over the iterations
+for ((i=0; i<iterations; i++)); do
diff --git a/benchmarks/DASB/music4all/quantization/train.py b/benchmarks/DASB/music4all/quantization/train.py
new file mode 100644
--- /dev/null
+++ b/benchmarks/DASB/music4all/quantization/train.py
+"""Recipe for training a K-means quantizer on the Music4All dataset.
+
+To run this recipe:
+> python train.py hparams/train_discrete_ssl.yaml
+
+Authors
+ * Pooneh Mousavi 2024
+"""
+
+# Adapted from:
+# https://github.com/speechbrain/speechbrain/blob/v1.0.2/recipes/LJSpeech/quantization/train.py
+
+import sys
+
+import torch
+import torchaudio
+from hyperpyyaml import load_hyperpyyaml
+
+import speechbrain as sb
+from speechbrain.utils.distributed import if_main_process
+
+
+class Quantization(sb.Brain):
+    def compute_forward(self, batch, stage):
+        """Forward pass."""
+        batch = batch.to(self.device)
+        sig, lens = batch.sig  # [B, T]
+
+        # Extract features
+        with torch.no_grad():
+            self.modules.ssl_model.eval()
+            feats = self.modules.ssl_model(sig, lens)  # [K, B, N, H]
+            feats = feats[self.hparams.layer_id]  # [B, N, H]
+
+        return feats
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the objectives."""
+        feats = predictions  # [B, N, H]
+
+        if stage != sb.Stage.TRAIN:
+            # For K-means the validation/test loss is the inertia
+            # The lower the inertia, the better should be the clustering
+            # It is useful to monitor progress across epochs
+            # However, when saving checkpoints we always keep the last one
+            # (i.e. max_keys=["epoch"]) to keep backward compatibility
+            loss = self.hparams.quantizer.inertia(feats)
+            return loss
+
+        # If training, accumulate features (batch size used for K-means training
+        # should be much larger than batch size used for feature extraction)
+        feats = feats.flatten(end_dim=-2)  # [BN, H]
+        self.curr_feats.append(feats)
+        self.curr_batch_size += len(feats)
+        if self.curr_batch_size < self.hparams.kmeans_batch_size:
+            # If not enough features, leave average loss unchanged and go to next batch
+            # avg_loss is computed as: (avg_loss - avg_loss / self.step) + float(loss) / self.step
+            # If we set loss = avg_loss, avg_loss stays unchanged
+            loss = torch.tensor(self.avg_train_loss)
+            # Keep compatibility with standard supervised training
+            # (SpeechBrain expects a tensor with gradient)
+            loss.requires_grad_()
+            return loss
+        self.curr_feats = torch.cat(self.curr_feats)
+        feats = self.curr_feats[: self.hparams.kmeans_batch_size]
+
+        # Keep remaining features for next iteration
+        self.curr_feats = [self.curr_feats[self.hparams.kmeans_batch_size :]]
+        self.curr_batch_size = len(self.curr_feats[0])
+
+        # Retrieve current centroids
+        old_cluster_centers = self.hparams.quantizer.cluster_centers
+
+        # Partial fit on current batch
+        self.hparams.quantizer.partial_fit(feats)
+
+        # For K-means the training loss is the drift between current centroids and old centroids
+        # If close to 0, it means that the training has converged
+        curr_cluster_centers = self.hparams.quantizer.cluster_centers
+        loss = (curr_cluster_centers - old_cluster_centers).norm()
+
+        # Keep compatibility with standard supervised training
+        # (SpeechBrain expects a tensor with gradient)
+        loss.requires_grad_()
+        self.optimizer_step += 1
+        assert self.optimizer_step == self.modules.quantizer.n_steps, (
+            f"optimizer_step: {self.optimizer_step}",
+            f"quantizer.n_steps: {self.modules.quantizer.n_steps}",
+        )
+
+        return loss
+
+    def on_stage_start(self, stage, epoch=None):
+        """Gets called at the beginning of each epoch."""
+        if stage == sb.Stage.TRAIN:
+            # NOTE: not included in intra-epoch checkpoints
+            self.curr_feats = []
+            self.curr_batch_size = 0
+
+    def on_stage_end(self, stage, stage_loss, epoch=None):
+        """Gets called at the end of each epoch."""
+        # Compute/store important stats
+        current_epoch = self.hparams.epoch_counter.current
+        stage_stats = {"loss": stage_loss}
+
+        if stage == sb.Stage.TRAIN:
+            self.avg_train_loss = 0.0
+            self.train_stats = stage_stats
+            self.stats_meta = {"epoch": epoch, "steps": self.optimizer_step}
+            if if_main_process():
+                self.checkpointer.save_and_keep_only(
+                    meta={"loss": stage_stats["loss"], "epoch": epoch},
+                    max_keys=["epoch"],
+                    num_to_keep=self.hparams.keep_checkpoints,
+                )
+            self.hparams.train_logger.log_stats(
+                stats_meta=self.stats_meta,
+                train_stats=self.train_stats,
+            )
+
+        # Perform end-of-iteration operations, like annealing, logging, etc.
+        elif stage == sb.Stage.VALID:
+            self.hparams.train_logger.log_stats(
+                stats_meta=self.stats_meta,
+                train_stats=self.train_stats,
+                valid_stats=stage_stats,
+            )
+
+        elif stage == sb.Stage.TEST:
+            self.hparams.train_logger.log_stats(
+                stats_meta={"Epoch loaded": current_epoch},
+                test_stats=stage_stats,
+            )
+
+
+def dataio_prepare(hparams):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined functions.
+    """
+    train_data = sb.dataio.dataset.DynamicItemDataset.from_json(
+        json_path=hparams["train_json"],
+        replacements={"DATA_ROOT": hparams["data_folder"]},
+    )
+    # Sort training data to speed up training
+    train_data = train_data.filtered_sorted(
+        sort_key="duration",
+        reverse=hparams["sorting"] == "descending",
+        key_max_value={"duration": hparams["train_remove_if_longer"]},
+    )
+
+    valid_data = sb.dataio.dataset.DynamicItemDataset.from_json(
+        json_path=hparams["valid_json"],
+        replacements={"DATA_ROOT": hparams["data_folder"]},
+    )
+    # Sort validation data to speed up validation
+    valid_data = valid_data.filtered_sorted(
+        sort_key="duration",
+        reverse=True,
+        key_max_value={"duration": hparams["valid_remove_if_longer"]},
+    )
+
+    test_data = sb.dataio.dataset.DynamicItemDataset.from_json(
+        json_path=hparams["test_json"],
+        replacements={"DATA_ROOT": hparams["data_folder"]},
+    )
+    # Sort the test data to speed up testing
+    test_data = test_data.filtered_sorted(
+        sort_key="duration",
+        reverse=True,
+        key_max_value={"duration": hparams["test_remove_if_longer"]},
+    )
+
+    datasets = [train_data, valid_data, test_data]
+
+    # Define audio pipeline
+    takes = ["wav"]
+    provides = ["sig"]
+
+    def audio_pipeline(wav):
+        # Resample to the SSL model's expected sample rate
+        original_sample_rate = sb.dataio.dataio.read_audio_info(wav).sample_rate
+        sig = sb.dataio.dataio.read_audio(wav)
+        sig = torchaudio.functional.resample(
+            sig, original_sample_rate, hparams["sample_rate"]
+        )
+        # Keep only the first channel for multi-channel audio
+        if sig.ndim > 1:
+            sig = sig[:, 0]
+        yield sig.squeeze()
+
+    sb.dataio.dataset.add_dynamic_item(
+        datasets, audio_pipeline, takes, provides
+    )
+
+    # Set output
+    sb.dataio.dataset.set_output_keys(datasets, ["id"] + provides)
+
+    return datasets
+
+
+if __name__ == "__main__":
+    # Command-line interface
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+    with open(hparams_file, encoding="utf-8") as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # If --distributed_launch then create ddp_init_group with the right communication protocol
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Prepare data
+    from music4all_prepare import prepare_music4all
+
+    kwargs = {
+        "data_folder": hparams["data_folder"],
+        "save_folder": hparams["data_cache_folder"],
+        "splits": hparams["splits"],
+        "split_ratio": hparams["split_ratio"],
+        "seed": hparams["seed"],
+        "skip_prep": hparams["skip_prep"],
+    }
+    prepare_music4all(**kwargs)
+
+    # Create the dataset objects
+    train_data, valid_data, test_data = dataio_prepare(hparams)
+
+    # Trainer initialization
+    brain = Quantization(
+        modules=hparams["modules"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    # Train
+    brain.fit(
+        brain.hparams.epoch_counter,
+        train_data,
+        valid_data,
+        train_loader_kwargs=dict(
+            num_workers=hparams["dataloader_workers"],
+            batch_size=hparams["train_batch_size"],
+            shuffle=hparams["sorting"] == "random",
+            pin_memory=run_opts.get("device", "cpu") != "cpu",
+        ),
+        valid_loader_kwargs=dict(
+            num_workers=hparams["dataloader_workers"],
+            batch_size=hparams["valid_batch_size"],
+            pin_memory=run_opts.get("device", "cpu") != "cpu",
+        ),
+    )
+
+    # Test
+    brain.evaluate(
+        test_data,
+        max_key="epoch",
+        test_loader_kwargs=dict(
+            num_workers=hparams["dataloader_workers"],
+            batch_size=hparams["test_batch_size"],
+            pin_memory=run_opts.get("device", "cpu") != "cpu",
+        ),
+    )