From 2ec4e8590eb29af317adc04ef693502695bf4827 Mon Sep 17 00:00:00 2001 From: Yanli Date: Thu, 13 Jun 2024 16:59:45 +0800 Subject: [PATCH] add interface --- pyproject.toml | 2 +- src/gpt_sovits/infer/__init__.py | 2 + src/gpt_sovits/infer/inference.py | 6 +- src/gpt_sovits/infer/interface.py | 165 +++++++++++++++--------------- 4 files changed, 90 insertions(+), 85 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ef65d6e3..018df0fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "GPT-SoVITS-Infer" -version = "0.2.4" +version = "0.2.5" description = "Inference code for GPT-SoVITS" authors = [ {name = "Yanli",email = "mail@yanli.one"}, diff --git a/src/gpt_sovits/infer/__init__.py b/src/gpt_sovits/infer/__init__.py index 5b4de091..09f5e954 100644 --- a/src/gpt_sovits/infer/__init__.py +++ b/src/gpt_sovits/infer/__init__.py @@ -1,7 +1,9 @@ from gpt_sovits.infer.inference import GPTSoVITSInference from gpt_sovits.infer.inference_pool import GPTSoVITSInferencePool +from gpt_sovits.infer.interface import GPTSoVITSInferenceSimple __all__ = [ "GPTSoVITSInference", "GPTSoVITSInferencePool", + "GPTSoVITSInferenceSimple", ] diff --git a/src/gpt_sovits/infer/inference.py b/src/gpt_sovits/infer/inference.py index 668be3fb..13002115 100644 --- a/src/gpt_sovits/infer/inference.py +++ b/src/gpt_sovits/infer/inference.py @@ -298,12 +298,16 @@ def set_prompt_audio( path=prompt_audio_path, sr=None ) else: - if not prompt_audio_data or not prompt_audio_sr: + if (prompt_audio_data is None) or (prompt_audio_sr is None): raise ValueError( "When prompt_audio_path is not given, prompt_audio_data and prompt_audio_sr must be given." ) self.prompt_audio_sr = cast(int, prompt_audio_sr) self.prompt_audio_data = cast(np.ndarray, prompt_audio_data) + if self.prompt_audio_data.dtype == np.int16: + self.prompt_audio_data = ( + self.prompt_audio_data.astype(self.np_dtype) / 32768 + ) with torch.no_grad(): wav16k = librosa.resample( diff --git a/src/gpt_sovits/infer/interface.py b/src/gpt_sovits/infer/interface.py index 065c86a4..ce96f301 100644 --- a/src/gpt_sovits/infer/interface.py +++ b/src/gpt_sovits/infer/interface.py @@ -1,20 +1,60 @@ from gpt_sovits.infer.inference import GPTSoVITSInference from pydantic import BaseModel -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Dict, Union, TYPE_CHECKING from pathlib import Path from threading import Lock import os, sys +if TYPE_CHECKING: + import numpy as np + class ConfigData(BaseModel): - models: List[str] - prompts: List[str] + bert_path: Path + cnhubert_base_path: Path + models: Dict[str, Tuple[Path, Path]] + prompts: Dict[str, Tuple[Path, Path]] + + +def read_config_data(config_data_base: str) -> ConfigData: + config_data_base = Path(config_data_base) + model_files = os.listdir(config_data_base / "models") + models = { + model_file.split(".")[0]: ( + config_data_base / "models" / model_file, + config_data_base / "models" / model_file.replace(".pth", ".ckpt"), + ) + for model_file in model_files + if model_file.endswith(".pth") + and model_file.replace(".pth", ".ckpt") in model_files + } + prompt_files = os.listdir(Path(config_data_base) / "prompts") + prompts = { + prompt_file.split(".")[0]: ( + config_data_base / "prompts" / prompt_file, + config_data_base / "prompts" / prompt_file.replace(".txt", ".wav"), + ) + for prompt_file in prompt_files + if prompt_file.endswith(".txt") + and prompt_file.replace(".txt", ".wav") in prompt_files + } + cnhubert_base_path = str(Path(config_data_base) / "chinese-hubert-base") + bert_path = str(Path(config_data_base) / "chinese-roberta-wwm-ext-large") + return ConfigData( + bert_path=bert_path, + cnhubert_base_path=cnhubert_base_path, + models=models, + prompts=prompts, + ) + + +PromptType = Union[str, Tuple[str, int, "np.ndarray"]] class GPTSoVITSInferenceSimple: - config_data_base: Path config_data: ConfigData - working_config: Tuple[str, str] + working_model: Optional[str] + working_prompt: Optional[str] inference_worker: GPTSoVITSInference inference_worker_lock: Lock @@ -22,74 +62,64 @@ class GPTSoVITSInferenceSimple: def __init__( self, config_data_base: str, - inference_worker_and_lock: Tuple[GPTSoVITSInference, Lock], - model_name: Optional[str] = None, - prompt_name: Optional[str] = None, + device: Optional[str] = None, + is_half: Optional[bool] = False, ): - self.config_data_base = Path(config_data_base) - model_files = os.listdir(str(self.config_data_base / "models")) - models = [ - model_file.split(".")[0] - for model_file in model_files - if model_file.endswith(".pth") - and model_file.replace(".pth", ".ckpt") in model_files - ] - prompt_files = os.listdir(str(self.config_data_base / "prompts")) - prompts = [ - prompt_file.split(".")[0] - for prompt_file in prompt_files - if prompt_file.endswith(".txt") - and prompt_file.replace(".txt", ".wav") in prompt_files - ] - self.config_data = ConfigData(models=models, prompts=prompts) - - self.inference_worker = inference_worker_and_lock[0] - self.inference_worker_lock = inference_worker_and_lock[1] - self.working_config = ( - model_name if model_name else models[0], - prompt_name if prompt_name else prompts[0], + self.config_data = read_config_data(config_data_base) + self.working_model = None + self.working_prompt = None + self.inference_worker = GPTSoVITSInference( + self.config_data.bert_path, + self.config_data.cnhubert_base_path, + device=device, + is_half=is_half, ) - self._load_model(self.working_config[0]) - self._load_prompt(self.working_config[1]) + self.inference_worker_lock = Lock() def _load_model(self, model_name: str): self.inference_worker.load_sovits( - str(self.config_data_base / "models" / f"{model_name}.pth") + self.config_data.models[model_name][0], ) self.inference_worker.load_gpt( - str(self.config_data_base / "models" / f"{model_name}.ckpt") + self.config_data.models[model_name][1], ) def _load_prompt(self, prompt_name: str): - with open(self.config_data_base / "prompts" / f"{prompt_name}.txt", "r") as f: + with open(self.config_data.prompts[prompt_name][0], "r") as f: prompt_text = f.read().strip() self.inference_worker.set_prompt_audio( prompt_text=prompt_text, - prompt_audio_path=str( - self.config_data_base / "prompts" / f"{prompt_name}.wav" - ), + prompt_audio_path=self.config_data.prompts[prompt_name][1], ) + def _load_things(self, model_name: str, prompt: PromptType): + if model_name != self.working_model: + self._load_model(model_name) + self.working_model = model_name + if isinstance(prompt, str): + if prompt != self.working_model: + self._load_prompt(prompt) + self.working_prompt = prompt + else: + self.inference_worker.set_prompt_audio( + prompt_text=prompt[0], + prompt_audio_data=prompt[2], + prompt_audio_sr=prompt[1], + ) + self.working_prompt = None + def generate( self, + model_name: str, + prompt: PromptType, text: str, text_language="auto", top_k=5, top_p=1, temperature=1, - model_name: Optional[str] = None, - prompt_name: Optional[str] = None, ): - config = ( - model_name if model_name else self.working_config[0], - prompt_name if prompt_name else self.working_config[1], - ) with self.inference_worker_lock: - if config[0] != self.working_config[0]: - self._load_model(config[0]) - if config[1] != self.working_config[1]: - self._load_prompt(config[1]) - self.working_config = config + self._load_things(model_name, prompt) return self.inference_worker.get_tts_wav( text=text, text_language=text_language, @@ -100,24 +130,16 @@ def generate( def generate_stream( self, + model_name: str, + prompt: PromptType, text: str, text_language="auto", top_k=5, top_p=1, temperature=1, - model_name: Optional[str] = None, - prompt_name: Optional[str] = None, ): - config = ( - model_name if model_name else self.working_config[0], - prompt_name if prompt_name else self.working_config[1], - ) with self.inference_worker_lock: - if config[0] != self.working_config[0]: - self._load_model(config[0]) - if config[1] != self.working_config[1]: - self._load_prompt(config[1]) - self.working_config = config + self._load_things(model_name, prompt) for thing in self.inference_worker.get_tts_wav_stream( text=text, text_language=text_language, @@ -128,26 +150,3 @@ def generate_stream( yield thing return None - - -if __name__ == "__main__": - from scipy.io import wavfile - - inference_worker_and_lock = ( - GPTSoVITSInference( - bert_path="pretrained_models/chinese-roberta-wwm-ext-large", - cnhubert_base_path="pretrained_models/chinese-hubert-base", - ), - Lock(), - ) - - api = GPTSoVITSInferenceSimple( - config_data_base="config_data", - inference_worker_and_lock=inference_worker_and_lock, - ) - for idx, (sr, data) in enumerate( - api.generate_stream("鲁迅为什么暴打周树人?这是一个问题") - ): - wavfile.write(f"playground/output/output{idx}.wav", sr, data) - sr, data = api.generate("鲁迅为什么暴打周树人?这是一个问题") - wavfile.write("playground/output.wav", sr, data)