From 417b616bd40646eb165c8ebd9210949120e71e8c Mon Sep 17 00:00:00 2001 From: Aidan Pine Date: Sat, 30 Nov 2024 01:16:00 +0000 Subject: [PATCH 1/3] feat: add scoring option to synthesize cli command --- fs2/cli/synthesize.py | 40 ++++++++++++++ fs2/dataset.py | 12 ++++- fs2/loss.py | 58 +++++++++++--------- fs2/prediction_writing_callback.py | 85 ++++++++++++++++++++++++++++++ 4 files changed, 168 insertions(+), 27 deletions(-) diff --git a/fs2/cli/synthesize.py b/fs2/cli/synthesize.py index 5f38b59..93704e4 100644 --- a/fs2/cli/synthesize.py +++ b/fs2/cli/synthesize.py @@ -1,5 +1,6 @@ import sys import textwrap +from collections import Counter from pathlib import Path from typing import Any, Optional @@ -11,6 +12,7 @@ ) from everyvoice.utils import spinner from loguru import logger +from tqdm import tqdm from ..type_definitions import SynthesizeOutputFormats from ..utils import truncate_basename @@ -221,6 +223,7 @@ def synthesize_helper( vocoder_global_step: Optional[int] = None, vocoder_model=None, vocoder_config=None, + return_scores=False, ): """This is a helper to perform synthesis once the model has been loaded. It allows us to use the same command for synthesis via the CLI and @@ -258,6 +261,28 @@ def synthesize_helper( text_representation=text_representation, style_reference=style_reference, ) + if return_scores: + from nltk.util import ngrams + + token_counter = Counter() + trigram_counter = Counter() + for line in tqdm( + data, desc="calculating filelist statistics for score calculation" + ): + tokens = line[f"{text_representation.value[:-1]}_tokens"].split("/") + for t in tokens: + token_counter[t] += 1 + tokens.insert(0, "") + tokens.append("") + for trigram in ngrams(tokens, 3): + trigram_counter[trigram] += 1 + for line in tqdm(data, desc="scoring utterances"): + tokens = line[f"{text_representation.value[:-1]}_tokens"].split("/") + line["phone_coverage_score"] = sum((1 / token_counter[t]) for t in tokens) + line["trigram_coverage_score"] = sum( + (1 / trigram_counter[n]) for n in ngrams(tokens, 3) + ) + from pytorch_lightning import Trainer from ..prediction_writing_callback import get_synthesis_output_callbacks @@ -272,6 +297,7 @@ def synthesize_helper( vocoder_model=vocoder_model, vocoder_config=vocoder_config, vocoder_global_step=vocoder_global_step, + return_scores=return_scores, ) trainer = Trainer( logger=False, # We don't need to log things to tensorboard during inference @@ -284,6 +310,10 @@ def synthesize_helper( teacher_forcing = True model.config.preprocessing.save_dir = teacher_forcing_directory else: + if return_scores: + raise ValueError( + "In order to return the scores, we also need access to the directory containing your ground truth audio. Please pass this in using the --teacher-forcing-directory option. e.g. --teacher-forcing-directory ./preprocessed" + ) teacher_forcing = False # overwrite batch_size and num_workers model.config.training.batch_size = batch_size @@ -401,6 +431,12 @@ def synthesize( # noqa: C901 '**readalong-html**' will generate a single file Offline HTML ReadAlong that can be further edited in the ReadAlong Studio Editor, and opened by itself. Also implies '--output-type wav'. Requires --vocoder-path. """, ), + return_scores: bool = typer.Option( + False, + "--return-scores", + "-R", + help="ADVANCED. 
Setting this to True will change your batch size to 1 and output a PSV file with the losses for each synthesized audio along with a score of trigram density to measure the phonological importance of the utterance.", + ), teacher_forcing_directory: Path = typer.Option( None, "--teacher-forcing-directory", @@ -466,6 +502,9 @@ def synthesize( # noqa: C901 from ..model import FastSpeech2 + if return_scores: + batch_size = 1 + output_dir.mkdir(exist_ok=True, parents=True) # NOTE: We want to be able to put the vocoder on the proper accelerator for # it to be compatible with the vocoder's input device. @@ -530,4 +569,5 @@ def synthesize( # noqa: C901 vocoder_model=vocoder_model, vocoder_config=vocoder_config, vocoder_global_step=vocoder_global_step, + return_scores=return_scores, ) diff --git a/fs2/dataset.py b/fs2/dataset.py index d356dac..da317c3 100644 --- a/fs2/dataset.py +++ b/fs2/dataset.py @@ -185,7 +185,8 @@ def __getitem__(self, index): else: energy = None pitch = None - return { + + loaded_data = { "mel": mel, "mel_style_reference": mel_style_reference, "duration": duration, @@ -203,6 +204,15 @@ def __getitem__(self, index): "pitch": pitch, } + # used when returning scores + if "phone_coverage_score" in item: + loaded_data['phone_coverage_score'] = item['phone_coverage_score'] + + if "trigram_coverage_score" in item: + loaded_data['trigram_coverage_score'] = item['trigram_coverage_score'] + + return loaded_data + def __len__(self): return len(self.dataset) diff --git a/fs2/loss.py b/fs2/loss.py index ec50f4d..68b6a26 100644 --- a/fs2/loss.py +++ b/fs2/loss.py @@ -37,39 +37,45 @@ def forward(self, output, batch, current_epoch, frozen_components=None): # Don't calculate grad on target duration_target.requires_grad = False - energy_target.requires_grad = False spec_target.requires_grad = False - pitch_target.requires_grad = False losses = {} # Calculate pitch loss - if self.config.model.variance_predictors.pitch.level == "phone": - pitch_mask = src_mask - else: - pitch_mask = tgt_mask - - pitch_prediction = pitch_prediction * pitch_mask - pitch_target = pitch_target * pitch_mask - pitch_loss_fn = self.config.model.variance_predictors.pitch.loss - losses["pitch"] = ( - self.loss_fns[pitch_loss_fn](pitch_prediction, pitch_target) - * self.config.training.pitch_loss_weight - ) + if pitch_target is not None: + + pitch_target.requires_grad = False + + if self.config.model.variance_predictors.pitch.level == "phone": + pitch_mask = src_mask + else: + pitch_mask = tgt_mask + + pitch_prediction = pitch_prediction * pitch_mask + pitch_target = pitch_target * pitch_mask + pitch_loss_fn = self.config.model.variance_predictors.pitch.loss + losses["pitch"] = ( + self.loss_fns[pitch_loss_fn](pitch_prediction, pitch_target) + * self.config.training.pitch_loss_weight + ) # Calculate energy loss - if self.config.model.variance_predictors.energy.level == "phone": - energy_mask = src_mask - else: - energy_mask = tgt_mask - - energy_prediction = energy_prediction * energy_mask - energy_target = energy_target * energy_mask - energy_loss_fn = self.config.model.variance_predictors.energy.loss - losses["energy"] = ( - self.loss_fns[energy_loss_fn](energy_prediction, energy_target) - * self.config.training.energy_loss_weight - ) + if energy_target is not None: + + energy_target.requires_grad = False + + if self.config.model.variance_predictors.energy.level == "phone": + energy_mask = src_mask + else: + energy_mask = tgt_mask + + energy_prediction = energy_prediction * energy_mask + energy_target = energy_target * 
energy_mask + energy_loss_fn = self.config.model.variance_predictors.energy.loss + losses["energy"] = ( + self.loss_fns[energy_loss_fn](energy_prediction, energy_target) + * self.config.training.energy_loss_weight + ) # Calculate duration loss log_duration_target = torch.log(duration_target.float() + 1) * src_mask diff --git a/fs2/prediction_writing_callback.py b/fs2/prediction_writing_callback.py index 4502768..8d5e5fc 100644 --- a/fs2/prediction_writing_callback.py +++ b/fs2/prediction_writing_callback.py @@ -1,5 +1,6 @@ from __future__ import annotations +from csv import DictWriter from pathlib import Path from typing import Any, Optional, Sequence @@ -32,12 +33,20 @@ def get_synthesis_output_callbacks( vocoder_model: Optional[HiFiGAN] = None, vocoder_config: Optional[HiFiGANConfig] = None, vocoder_global_step: Optional[int] = None, + return_scores=False, ) -> dict[SynthesizeOutputFormats, Callback]: """ Given a list of desired output file formats, return the proper callbacks that will generate those files. """ callbacks: dict[SynthesizeOutputFormats, Callback] = {} + if return_scores: + callbacks['score'] = ScorerCallback( + config=config, + global_step=global_step, + output_dir=output_dir, + output_key=output_key, + ) if ( SynthesizeOutputFormats.wav in output_type or SynthesizeOutputFormats.readalong_html in output_type @@ -135,6 +144,82 @@ def get_filename( return str(path) +class ScorerCallback(Callback): + """ + This callback runs inference on a provided text-to-spec model and saves the resulting losses to disk. + """ + + def __init__( + self, + config: FastSpeech2Config, + global_step: int, + output_dir: Path, + output_key: str, + ): + self.global_step = global_step + self.save_dir = output_dir + self.output_key = output_key + self.config = config + logger.info(f"Saving pytorch output to {self.save_dir}") + self.scores = [] + + def _get_filename(self) -> Path: + path = self.save_dir / f"scores-{self.global_step}.psv" + path.parent.mkdir( + parents=True, exist_ok=True + ) # synthesizing spec allows nested outputs + return path + + def sort_scores(self): + self.scores.sort(key=lambda x: (-x["total"], x["trigram_coverage_score"])) + + def on_predict_epoch_end( + self, + _trainer, + model, + ): + self.sort_scores() + with open(self._get_filename(), "w") as f: + fieldnames = [ + "basename", + "speaker", + "language", + "total", + "trigram_coverage_score", + "duration", + "spec", + "postnet", + "attn_ctc", + "attn_bin", + "raw_text", + "phone_coverage_score", + ] + writer = DictWriter(f, fieldnames=fieldnames, delimiter="|") + writer.writeheader() + for score in self.scores: + writer.writerow(score) + + def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride] + self, + _trainer, + model, + outputs: dict[str, torch.Tensor | None], + batch: dict[str, Any], + _batch_idx: int, + _dataloader_idx: int = 0, + ): + with torch.no_grad(): + losses = model.loss(outputs, batch, model.current_epoch) + score = {k: float(v) for k, v in losses.items()} + score["basename"] = batch["basename"][0] + score["speaker"] = batch["speaker"][0] + score["language"] = batch["language"][0] + score["raw_text"] = batch["raw_text"][0] + score["phone_coverage_score"] = batch["phone_coverage_score"][0] + score["trigram_coverage_score"] = batch["trigram_coverage_score"][0] + self.scores.append(score) + + class PredictionWritingSpecCallback(PredictionWritingCallbackBase): """ This callback runs inference on a provided text-to-spec model and saves the resulting Mel spectrograms to disk as pytorch 
files. These can be used to fine-tune an EveryVoice spec-to-wav model. From 2e0a40beaeb14c6f779670d1fd21ff4a3b4d6ee6 Mon Sep 17 00:00:00 2001 From: Aidan Pine Date: Wed, 22 Jan 2025 03:06:25 +0000 Subject: [PATCH 2/3] refactor: refactor check data to combine with scorer --- fs2/cli/check_data.py | 211 +++++++++++++++++++++++++++++ fs2/cli/check_data_heavy.py | 125 +++++++++++++++++ fs2/cli/synthesize.py | 176 +++++++++++++----------- fs2/prediction_writing_callback.py | 48 +++---- fs2/tests/test_cli.py | 32 ++++- 5 files changed, 487 insertions(+), 105 deletions(-) create mode 100644 fs2/cli/check_data.py create mode 100644 fs2/cli/check_data_heavy.py diff --git a/fs2/cli/check_data.py b/fs2/cli/check_data.py new file mode 100644 index 0000000..7f23a5a --- /dev/null +++ b/fs2/cli/check_data.py @@ -0,0 +1,211 @@ +import json +import sys +from pathlib import Path +from typing import Optional + +import typer +from everyvoice.base_cli.interfaces import complete_path +from everyvoice.config.type_definitions import DatasetTextRepresentation +from everyvoice.utils import generic_psv_filelist_reader, spinner +from loguru import logger + +from .synthesize import get_global_step, synthesize_helper + + +def check_data_command( # noqa: C901 + config_file: Path = typer.Argument( + ..., + exists=True, + dir_okay=False, + file_okay=True, + help="The path to your text-to-spec model configuration file.", + shell_complete=complete_path, + ), + model_path: Optional[Path] = typer.Argument( + ..., + file_okay=True, + exists=True, + dir_okay=False, + help="The path to a trained text-to-spec (i.e., feature prediction) or e2e EveryVoice model.", + shell_complete=complete_path, + ), + output_dir: Path = typer.Option( + "checked_data", + "--output-dir", + "-o", + file_okay=False, + dir_okay=True, + help="The directory where your synthesized audio should be written", + shell_complete=complete_path, + ), + style_reference: Optional[Path] = typer.Option( + None, + "--style-reference", + "-S", + exists=True, + file_okay=True, + dir_okay=False, + help="The path to an audio file containing a style reference. Your text-to-spec must have been trained with the global style token module to use this feature.", + shell_complete=complete_path, + ), + accelerator: str = typer.Option("auto", "--accelerator", "-a"), + devices: str = typer.Option( + "auto", "--devices", "-d", help="The number of GPUs to use" + ), + filelist: Path = typer.Option( + None, + "--filelist", + "-f", + exists=True, + file_okay=True, + dir_okay=False, + help="The path to a file containing a list of utterances (a.k.a filelist). Use --text if you want to just synthesize one sample.", + shell_complete=complete_path, + ), + text_representation: DatasetTextRepresentation = typer.Option( + DatasetTextRepresentation.characters, + help="The representation of the text you are synthesizing. Can be either 'characters', 'phones', or 'arpabet'. 
The input type must be compatible with your model.",
+    ),
+    teacher_forcing_directory: Path = typer.Option(
+        "preprocessed",
+        "--preprocessed-directory",
+        "-p",
+        help="The path to the folder containing all of your preprocessed data.",
+        dir_okay=True,
+        file_okay=False,
+        shell_complete=complete_path,
+    ),
+    num_workers: int = typer.Option(
+        4,
+        "--num-workers",
+        "-n",
+        help="Number of workers to process the data.",
+    ),
+    calculate_stats: bool = typer.Option(
+        True,
+        "--calculate-stats/--no-calculate-stats",
+        help="Whether to calculate basic statistics on your dataset.",
+    ),
+    objective_evaluation: bool = typer.Option(
+        True,
+        "--objective-evaluation/--no-objective-evaluation",
+        help="Whether to perform objective evaluation on your dataset using TorchSquim. This is time-consuming.",
+    ),
+    clip_detection: bool = typer.Option(
+        False,
+        "--clip-detection/--no-clip-detection",
+        help="Whether to detect clipping in your audio. This is expensive so we do not do this by default.",
+    ),
+):
+    """
+    Given a filelist and some preprocessed data, check some basic statistics on the data.
+    If a checkpoint is provided, also calculate the loss for each datapoint with respect to the model.
+
+    Note: this function was written by restricting the synthesize command.
+    """
+
+    with spinner():
+        from everyvoice.base_cli.helpers import MODEL_CONFIGS, load_unknown_config
+        from everyvoice.preprocessor import Preprocessor
+        from everyvoice.utils.heavy import get_device_from_accelerator
+
+        from ..model import FastSpeech2
+        from .check_data_heavy import check_data_from_filelist
+        from .synthesize import load_data_from_filelist
+
+    config = load_unknown_config(config_file)
+    preprocessor = Preprocessor(config)
+    if not any((isinstance(config, x) for x in MODEL_CONFIGS)):
+        print(
+            "Sorry, your file does not appear to be a valid model configuration. Please choose another model config file."
+        )
+        sys.exit(1)
+
+    output_dir.mkdir(exist_ok=True, parents=True)
+
+    if filelist is None:
+        training_filelist = generic_psv_filelist_reader(
+            config.training.training_filelist
+        )
+        val_filelist = generic_psv_filelist_reader(config.training.validation_filelist)
+        combined_filelist_data = training_filelist + val_filelist
+    else:
+        combined_filelist_data = generic_psv_filelist_reader(filelist)
+
+    # process stats
+    if calculate_stats:
+        stats = check_data_from_filelist(
+            preprocessor,
+            combined_filelist_data,
+            heavy_clip_detection=clip_detection,
+            heavy_objective_evaluation=objective_evaluation,
+        )
+        if not stats:
+            print(
+                f"Sorry, the data at {config.training.training_filelist} and {config.training.validation_filelist} is empty so there is nothing to check."
+            )
+            sys.exit(1)
+        else:
+            with open(output_dir / "checked-data.json", "w", encoding="utf8") as f:
+                json.dump(stats, f)
+
+    if model_path:
+        # NOTE: We want to be able to put the vocoder on the proper accelerator for
+        # it to be compatible with the vocoder's input device.
+        # We could misuse the trainer's API and use the private variable
+        # trainer._accelerator_connector._accelerator_flag but that value is
+        # computed when instantiating a trainer and that is exactly when we need
+        # the information to create the callbacks.
+        device = get_device_from_accelerator(accelerator)
+
+        # Load checkpoints
+        print(f"Loading checkpoint from {model_path}", file=sys.stderr)
+
+        from pydantic import ValidationError
+
+        try:
+            model: FastSpeech2 = FastSpeech2.load_from_checkpoint(model_path).to(device)  # type: ignore
+        except (TypeError, ValidationError) as e:
+            logger.error(f"Unable to load {model_path}: {e}")
+            sys.exit(1)
+        model.eval()
+
+        if filelist is None:
+            training_filelist = load_data_from_filelist(
+                config.training.training_filelist, model, text_representation
+            )
+            val_filelist = load_data_from_filelist(
+                config.training.validation_filelist, model, text_representation
+            )
+            combined_filelist_data = training_filelist + val_filelist
+        else:
+            combined_filelist_data = None
+
+        # get global step
+        # We can't just use model.global_step because it gets reset by lightning
+        global_step = get_global_step(model_path)
+
+        synthesize_helper(
+            model=model,
+            texts=None,
+            style_reference=style_reference,
+            language=None,
+            speaker=None,
+            duration_control=1.0,
+            global_step=global_step,
+            output_type=[],
+            text_representation=text_representation,
+            accelerator=accelerator,
+            devices=devices,
+            device=device,
+            batch_size=1,
+            num_workers=num_workers,
+            filelist=filelist,
+            filelist_data=combined_filelist_data,
+            teacher_forcing_directory=teacher_forcing_directory,
+            output_dir=output_dir,
+            vocoder_model=None,
+            vocoder_config=None,
+            vocoder_global_step=None,
+            return_scores=True,
+        )
diff --git a/fs2/cli/check_data_heavy.py b/fs2/cli/check_data_heavy.py
new file mode 100644
index 0000000..3c6563d
--- /dev/null
+++ b/fs2/cli/check_data_heavy.py
@@ -0,0 +1,125 @@
+import torch
+import torchaudio
+from clipdetect import detect_clipping
+from tqdm import tqdm
+
+
+def check_datapoint(
+    item,
+    preprocessor,
+    evaluation_model,
+    word_seg_token=" ",
+    heavy_clip_detection=False,
+    heavy_objective_evaluation=False,
+):
+    # speaking rate (words/second, float, scatterplot or bar chart)
+    # speaking rate (characters/second, float, scatterplot or bar chart)
+    # articulation level (mean energy/speaking rate)
+    # unrecognized symbols (bool, list)
+    # duration (float, box plot)
+    # clipping (float, box plot)
+    # silence % (float, box plot)
+    data_point = {k: v for k, v in item.items()}
+    characters = item.get("characters")
+    character_tokens = item.get("character_tokens")
+    phones = item.get("phones")
+    phone_tokens = item.get("phone_tokens")
+    assert (
+        characters or phones
+    ), "Sorry, your data does not have characters or phones available in the filelist, so we can't check the data."
+    if character_tokens is None and phone_tokens is None:
+        character_tokens, phone_tokens, _ = preprocessor.process_text(
+            item, preprocessor.text_processor, use_pfs=False, encode_as_string=True
+        )
+    default_text = phones if phones is not None else characters
+    n_words = len(default_text.split(word_seg_token))
+    n_chars = len(character_tokens.split("/")) if character_tokens is not None else None
+    n_phones = len(phone_tokens.split("/")) if phone_tokens is not None else None
+    audio, sr = torchaudio.load(
+        str(
+            preprocessor.create_path(
+                item, "audio", f"audio-{preprocessor.input_sampling_rate}.wav"
+            )
+        )
+    )
+
+    if heavy_objective_evaluation:
+        # use objective metrics from https://pytorch.org/audio/main/tutorials/squim_tutorial.html
+        # keep the original audio when it is already 16 kHz so model_audio is always defined
+        model_audio = audio if sr == 16000 else torchaudio.functional.resample(audio, sr, 16000)
+        if len(model_audio.size()) == 1:  # must include channel
+            model_audio = model_audio.unsqueeze(0)
+        stoi_hyp, pesq_hyp, si_sdr_hyp = evaluation_model(model_audio)
+        data_point["stoi"] = float(stoi_hyp[0])
+        data_point["pesq"] = float(pesq_hyp[0])
+        data_point["si_sdr"] = float(si_sdr_hyp[0])
+
+    assert (
+        len(audio.size()) == 1 or audio.size(0) == 1
+    ), f"Audio has {audio.size(0)} channels, but should be mono"
+    audio = audio.squeeze()
+
+    if heavy_clip_detection:
+        _, total_clipping = detect_clipping(audio)
+    else:
+        # this isn't a great way of detecting clipping,
+        # but it's a lot faster than clipdetect
+        audio_max = audio.max()
+        audio_min = audio.min()
+        total_clipping = (
+            audio[audio >= audio_max].size(0) + audio[audio <= audio_min].size(0) - 2
+        )
+    pitch = torch.load(
+        preprocessor.create_path(item, "pitch", "pitch.pt"), weights_only=True
+    )
+    energy = torch.load(
+        preprocessor.create_path(item, "energy", "energy.pt"), weights_only=True
+    )
+    audio_length_s = len(audio) / preprocessor.input_sampling_rate
+    data_point["total_clipped_samples"] = total_clipping
+    data_point["pitch_min"] = float(pitch.min())
+    data_point["pitch_max"] = float(pitch.max())
+    data_point["pitch_mean"] = float(pitch.mean())
+    data_point["pitch_std"] = float(pitch.std())
+    data_point["energy_min"] = float(energy.min())
+    data_point["energy_max"] = float(energy.max())
+    data_point["energy_mean"] = float(energy.mean())
+    data_point["energy_std"] = float(energy.std())
+    data_point["duration"] = audio_length_s
+    data_point["speaking_rate_words_per_second"] = n_words / audio_length_s
+    if n_chars is not None:
+        data_point["speaking_rate_characters_per_second"] = n_chars / audio_length_s
+        data_point["n_chars"] = n_chars
+    if n_phones is not None:
+        data_point["speaking_rate_phones_per_second"] = n_phones / audio_length_s
+        data_point["n_phones"] = n_phones
+    data_point["n_missing_symbols"] = len(
+        preprocessor.text_processor.get_missing_symbols(default_text)
+    )
+    data_point["n_words"] = n_words
+    return data_point
+
+
+def check_data_from_filelist(
+    preprocessor,
+    filelist,
+    word_seg_token=" ",
+    heavy_clip_detection=False,
+    heavy_objective_evaluation=False,
+):
+    data = []
+    if heavy_objective_evaluation:
+        model = torchaudio.pipelines.SQUIM_OBJECTIVE.get_model()
+    else:
+        model = None
+    for item in tqdm(filelist, desc="Checking Data"):
+        data_point = check_datapoint(
+            item,
+            preprocessor,
+            model,
+            word_seg_token,
+            heavy_clip_detection,
+            heavy_objective_evaluation,
+        )
+        data.append(data_point)
+    return data
diff --git a/fs2/cli/synthesize.py b/fs2/cli/synthesize.py
index 93704e4..b2d3dc9 100644
--- a/fs2/cli/synthesize.py
+++ b/fs2/cli/synthesize.py
@@ -65,8 +65,75 @@ def 
validate_data_keys_with_model_keys( sys.exit(1) +def load_data_from_filelist( + filelist: Path, + # model is of type ..model.FastSpeech2, but we make it Any to keep the CLI + # fast and enable mocking in unit testing. + model: Any, + text_representation: DatasetTextRepresentation, + language: str | None = None, + speaker: str | None = None, + default_language: str | None = None, + default_speaker: str | None = None, +): + + if default_language is None: + default_language = next(iter(model.lang2id.keys()), None) + if default_speaker is None: + default_speaker = next(iter(model.speaker2id.keys()), None) + + from everyvoice.utils import slugify + + data = model.config.training.filelist_loader(filelist) + try: + data = [ + d + | { + "basename": d.get( + "basename", + truncate_basename(slugify(d[text_representation.value])), + ), # Only truncate the basename if the basename doesn't already exist in the filelist. + "language": language or d.get("language", default_language), + "speaker": speaker or d.get("speaker", default_speaker), + } + for d in data + ] + except KeyError: + # TODO: Errors should have better formatting: + # https://github.com/EveryVoiceTTS/FastSpeech2_lightning/issues/26 + logger.info( + textwrap.dedent( + """ + EveryVoice only accepts filelists in PSV format as in: + + basename|characters|language|speaker + LJ0001|Hello|eng|LJ + + Or in a format where each new line is an utterance: + + This is a sentence. + Here is another sentence. + + Your filelist did not contain the correct keys so we will assume it is in the plain text format. + Text can either be defined as 'characters' or 'phones'. + """ + ) + ) + with open(filelist, encoding="utf8") as f: + data = [ + { + "basename": truncate_basename(slugify(line.strip())), + text_representation.value: line.strip(), + "language": language or default_language, + "speaker": speaker or default_speaker, + } + for line in f + ] + return data + + def prepare_data( - texts: list[str], + texts: Optional[list[str]], language: str | None, speaker: str | None, filelist: Path, @@ -98,51 +165,15 @@ def prepare_data( for text in texts ] else: - data = model.config.training.filelist_loader(filelist) - try: - data = [ - d - | { - "basename": d.get( - "basename", - truncate_basename(slugify(d[text_representation.value])), - ), # Only truncate the basename if the basename doesn't already exist in the filelist. - "language": language or d.get("language", DEFAULT_LANGUAGE), - "speaker": speaker or d.get("speaker", DEFAULT_SPEAKER), - } - for d in data - ] - except KeyError: - # TODO: Errors should have better formatting: - # https://github.com/EveryVoiceTTS/FastSpeech2_lightning/issues/26 - logger.info( - textwrap.dedent( - """ - EveryVoice only accepts filelists in PSV format as in: - - basename|characters|language|speaker - LJ0001|Hello|eng|LJ - - Or in a format where each new line is an utterance: - - This is a sentence. - Here is another sentence. - - Your filelist did not contain the correct keys so we will assume it is in the plain text format. - Text can either be defined as 'characters' or 'phones'. 
- """ - ) - ) - with open(filelist, encoding="utf8") as f: - data = [ - { - "basename": truncate_basename(slugify(line.strip())), - text_representation.value: line.strip(), - "language": language or DEFAULT_LANGUAGE, - "speaker": speaker or DEFAULT_SPEAKER, - } - for line in f - ] + data = load_data_from_filelist( + filelist, + model, + text_representation, + language, + speaker, + DEFAULT_LANGUAGE, + DEFAULT_SPEAKER, + ) validate_data_keys_with_model_keys( data_keys=set(d["language"] for d in data), @@ -204,7 +235,7 @@ def get_global_step(model_path: Path) -> int: def synthesize_helper( model, - texts: list[str], + texts: Optional[list[str]], style_reference: Optional[Path], language: Optional[str], speaker: Optional[str], @@ -218,6 +249,7 @@ def synthesize_helper( batch_size: int, num_workers: int, filelist: Path, + filelist_data: Optional[list[dict]], output_dir: Path, teacher_forcing_directory: Path, vocoder_global_step: Optional[int] = None, @@ -229,7 +261,6 @@ def synthesize_helper( It allows us to use the same command for synthesis via the CLI and via the gradio demo. """ - from everyvoice.text.phonemizer import AVAILABLE_G2P_ENGINES from ..dataset import FastSpeech2SynthesisDataModule @@ -241,31 +272,25 @@ def synthesize_helper( raise ValueError( f"Your model was trained on {model.config.model.target_text_representation_level} but you provided {text_representation.value} which is incompatible." ) - if ( - model.config.model.target_text_representation_level - != TargetTrainingTextRepresentationLevel.characters - and text_representation == DatasetTextRepresentation.characters - and language not in AVAILABLE_G2P_ENGINES - ): - raise ValueError( - f"Your model was trained on {model.config.model.target_text_representation_level} but you provided {text_representation.value} and there is no available grapheme-to-phoneme engine available for {language}. Please see for more information on how to add one." - ) - data = prepare_data( - texts=texts, - language=language, - speaker=speaker, - duration_control=duration_control if duration_control else 1.0, - filelist=filelist, - model=model, - text_representation=text_representation, - style_reference=style_reference, - ) + if filelist_data is None: + data: list[dict[Any, Any]] = prepare_data( + texts=texts, + language=language, + speaker=speaker, + duration_control=duration_control if duration_control else 1.0, + filelist=filelist, + model=model, + text_representation=text_representation, + style_reference=style_reference, + ) + else: + data = filelist_data if return_scores: from nltk.util import ngrams - token_counter = Counter() - trigram_counter = Counter() + token_counter: Counter = Counter() + trigram_counter: Counter = Counter() for line in tqdm( data, desc="calculating filelist statistics for score calculation" ): @@ -312,7 +337,7 @@ def synthesize_helper( else: if return_scores: raise ValueError( - "In order to return the scores, we also need access to the directory containing your ground truth audio. Please pass this in using the --teacher-forcing-directory option. e.g. --teacher-forcing-directory ./preprocessed" + "In order to return the scores, we also need access to the directory containing your ground truth audio and preprocessed data. Please pass this in using the --teacher-forcing-directory option. e.g. 
--teacher-forcing-directory ./preprocessed" ) teacher_forcing = False # overwrite batch_size and num_workers @@ -431,12 +456,6 @@ def synthesize( # noqa: C901 '**readalong-html**' will generate a single file Offline HTML ReadAlong that can be further edited in the ReadAlong Studio Editor, and opened by itself. Also implies '--output-type wav'. Requires --vocoder-path. """, ), - return_scores: bool = typer.Option( - False, - "--return-scores", - "-R", - help="ADVANCED. Setting this to True will change your batch size to 1 and output a PSV file with the losses for each synthesized audio along with a score of trigram density to measure the phonological importance of the utterance.", - ), teacher_forcing_directory: Path = typer.Option( None, "--teacher-forcing-directory", @@ -502,9 +521,6 @@ def synthesize( # noqa: C901 from ..model import FastSpeech2 - if return_scores: - batch_size = 1 - output_dir.mkdir(exist_ok=True, parents=True) # NOTE: We want to be able to put the vocoder on the proper accelerator for # it to be compatible with the vocoder's input device. @@ -564,10 +580,10 @@ def synthesize( # noqa: C901 batch_size=batch_size, num_workers=num_workers, filelist=filelist, + filelist_data=None, teacher_forcing_directory=teacher_forcing_directory, output_dir=output_dir, vocoder_model=vocoder_model, vocoder_config=vocoder_config, vocoder_global_step=vocoder_global_step, - return_scores=return_scores, ) diff --git a/fs2/prediction_writing_callback.py b/fs2/prediction_writing_callback.py index 8d5e5fc..7233ddd 100644 --- a/fs2/prediction_writing_callback.py +++ b/fs2/prediction_writing_callback.py @@ -34,19 +34,19 @@ def get_synthesis_output_callbacks( vocoder_config: Optional[HiFiGANConfig] = None, vocoder_global_step: Optional[int] = None, return_scores=False, -) -> dict[SynthesizeOutputFormats, Callback]: +) -> dict[SynthesizeOutputFormats | str, Callback]: """ Given a list of desired output file formats, return the proper callbacks that will generate those files. 
""" - callbacks: dict[SynthesizeOutputFormats, Callback] = {} + callbacks: dict[SynthesizeOutputFormats | str, Callback] = {} if return_scores: - callbacks['score'] = ScorerCallback( - config=config, - global_step=global_step, - output_dir=output_dir, - output_key=output_key, - ) + callbacks["score"] = ScorerCallback( + config=config, + global_step=global_step, + output_dir=output_dir, + output_key=output_key, + ) if ( SynthesizeOutputFormats.wav in output_type or SynthesizeOutputFormats.readalong_html in output_type @@ -84,25 +84,25 @@ def get_synthesis_output_callbacks( output_key=output_key, ) if SynthesizeOutputFormats.readalong_xml in output_type: - callbacks[SynthesizeOutputFormats.readalong_xml] = ( - PredictionWritingReadAlongCallback( - config=config, - global_step=global_step, - output_dir=output_dir, - output_key=output_key, - ) + callbacks[ + SynthesizeOutputFormats.readalong_xml + ] = PredictionWritingReadAlongCallback( + config=config, + global_step=global_step, + output_dir=output_dir, + output_key=output_key, ) if SynthesizeOutputFormats.readalong_html in output_type: wav_callback = callbacks[SynthesizeOutputFormats.wav] assert isinstance(wav_callback, PredictionWritingWavCallback) - callbacks[SynthesizeOutputFormats.readalong_html] = ( - PredictionWritingOfflineRASCallback( - config=config, - global_step=global_step, - output_dir=output_dir, - output_key=output_key, - wav_callback=wav_callback, - ) + callbacks[ + SynthesizeOutputFormats.readalong_html + ] = PredictionWritingOfflineRASCallback( + config=config, + global_step=global_step, + output_dir=output_dir, + output_key=output_key, + wav_callback=wav_callback, ) return callbacks @@ -161,7 +161,7 @@ def __init__( self.output_key = output_key self.config = config logger.info(f"Saving pytorch output to {self.save_dir}") - self.scores = [] + self.scores: list[dict] = [] def _get_filename(self) -> Path: path = self.save_dir / f"scores-{self.global_step}.psv" diff --git a/fs2/tests/test_cli.py b/fs2/tests/test_cli.py index dee96e4..247809e 100644 --- a/fs2/tests/test_cli.py +++ b/fs2/tests/test_cli.py @@ -14,10 +14,12 @@ DatasetTextRepresentation, TargetTrainingTextRepresentationLevel, ) -from everyvoice.tests.stubs import silence_c_stderr +from everyvoice.tests.basic_test_case import BasicTestCase +from everyvoice.tests.stubs import capture_stderr, silence_c_stderr from everyvoice.utils import generic_psv_filelist_reader from typer.testing import CliRunner +from ..cli.check_data_heavy import check_data_from_filelist from ..cli.cli import app from ..cli.synthesize import prepare_data as prepare_synthesize_data from ..cli.synthesize import validate_data_keys_with_model_keys @@ -308,6 +310,34 @@ def setUp(self) -> None: "train", ) + def test_check_data(self): + filelist = generic_psv_filelist_reader(BasicTestCase.data_dir / "metadata.psv") + with capture_stderr(): + checked_data = check_data_from_filelist( + filelist, heavy_objective_evaluation=True + ) + self.assertIn("pesq", checked_data[0]) + self.assertIn("stoi", checked_data[0]) + self.assertIn("si_sdr", checked_data[0]) + self.assertGreater(checked_data[0]["pesq"], 3.0) + self.assertLess(checked_data[0]["pesq"], 5.0) + self.assertAlmostEqual(checked_data[0]["duration"], 5.17, 2) + + # def test_compute_stats(self): + # feat_prediction_config = EveryVoiceConfig.load_config_from_path().feature_prediction + # preprocessor = Preprocessor(feat_prediction_config) + # preprocessor.compute_stats() + # self.assertEqual( + # 
self.preprocessor.config["preprocessing"]["audio"]["mel_mean"], + # -4.018, + # places=3, + # ) + # self.assertEqual( + # self.preprocessor.config["preprocessing"]["audio"]["mel_std"], + # 4.017, + # places=3, + # ) + def test_commands_present(self): """ Each subcommand is present in the the command's help message. From ba627a7f798256c3446df73489d4d78e821a2096 Mon Sep 17 00:00:00 2001 From: Aidan Pine Date: Wed, 22 Jan 2025 04:30:37 +0000 Subject: [PATCH 3/3] fix(test): fix check data test --- fs2/tests/test_cli.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fs2/tests/test_cli.py b/fs2/tests/test_cli.py index 247809e..21241b1 100644 --- a/fs2/tests/test_cli.py +++ b/fs2/tests/test_cli.py @@ -15,6 +15,7 @@ TargetTrainingTextRepresentationLevel, ) from everyvoice.tests.basic_test_case import BasicTestCase +from everyvoice.tests.preprocessed_audio_fixture import PreprocessedAudioFixture from everyvoice.tests.stubs import capture_stderr, silence_c_stderr from everyvoice.utils import generic_psv_filelist_reader from typer.testing import CliRunner @@ -296,12 +297,13 @@ def test_not_multispeaker_with_speaker(self): ) -class CLITest(TestCase): +class CLITest(PreprocessedAudioFixture, BasicTestCase): """ Validate that all subcommands are accessible. """ def setUp(self) -> None: + super().setUp() self.runner = CliRunner() self.subcommands = ( "benchmark", @@ -314,7 +316,7 @@ def test_check_data(self): filelist = generic_psv_filelist_reader(BasicTestCase.data_dir / "metadata.psv") with capture_stderr(): checked_data = check_data_from_filelist( - filelist, heavy_objective_evaluation=True + self.preprocessor, filelist, heavy_objective_evaluation=True ) self.assertIn("pesq", checked_data[0]) self.assertIn("stoi", checked_data[0])