Skip to content

Commit

Permalink
add interface
Browse files Browse the repository at this point in the history
  • Loading branch information
BeautyyuYanli committed Jun 13, 2024
1 parent 0f3a367 commit 2ec4e85
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 85 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "GPT-SoVITS-Infer"
version = "0.2.4"
version = "0.2.5"
description = "Inference code for GPT-SoVITS"
authors = [
{name = "Yanli",email = "[email protected]"},
Expand Down
2 changes: 2 additions & 0 deletions src/gpt_sovits/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from gpt_sovits.infer.inference import GPTSoVITSInference
from gpt_sovits.infer.inference_pool import GPTSoVITSInferencePool
from gpt_sovits.infer.interface import GPTSoVITSInferenceSimple

__all__ = [
"GPTSoVITSInference",
"GPTSoVITSInferencePool",
"GPTSoVITSInferenceSimple",
]
6 changes: 5 additions & 1 deletion src/gpt_sovits/infer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,16 @@ def set_prompt_audio(
path=prompt_audio_path, sr=None
)
else:
if not prompt_audio_data or not prompt_audio_sr:
if (prompt_audio_data is None) or (prompt_audio_sr is None):
raise ValueError(
"When prompt_audio_path is not given, prompt_audio_data and prompt_audio_sr must be given."
)
self.prompt_audio_sr = cast(int, prompt_audio_sr)
self.prompt_audio_data = cast(np.ndarray, prompt_audio_data)
if self.prompt_audio_data.dtype == np.int16:
self.prompt_audio_data = (
self.prompt_audio_data.astype(self.np_dtype) / 32768
)

with torch.no_grad():
wav16k = librosa.resample(
Expand Down
165 changes: 82 additions & 83 deletions src/gpt_sovits/infer/interface.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,125 @@
from gpt_sovits.infer.inference import GPTSoVITSInference
from pydantic import BaseModel
from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, Dict, Union, TYPE_CHECKING
from pathlib import Path
from threading import Lock
import os, sys

if TYPE_CHECKING:
import numpy as np


class ConfigData(BaseModel):
models: List[str]
prompts: List[str]
bert_path: Path
cnhubert_base_path: Path
models: Dict[str, Tuple[Path, Path]]
prompts: Dict[str, Tuple[Path, Path]]


def read_config_data(config_data_base: str) -> ConfigData:
config_data_base = Path(config_data_base)
model_files = os.listdir(config_data_base / "models")
models = {
model_file.split(".")[0]: (
config_data_base / "models" / model_file,
config_data_base / "models" / model_file.replace(".pth", ".ckpt"),
)
for model_file in model_files
if model_file.endswith(".pth")
and model_file.replace(".pth", ".ckpt") in model_files
}
prompt_files = os.listdir(Path(config_data_base) / "prompts")
prompts = {
prompt_file.split(".")[0]: (
config_data_base / "prompts" / prompt_file,
config_data_base / "prompts" / prompt_file.replace(".txt", ".wav"),
)
for prompt_file in prompt_files
if prompt_file.endswith(".txt")
and prompt_file.replace(".txt", ".wav") in prompt_files
}
cnhubert_base_path = str(Path(config_data_base) / "chinese-hubert-base")
bert_path = str(Path(config_data_base) / "chinese-roberta-wwm-ext-large")
return ConfigData(
bert_path=bert_path,
cnhubert_base_path=cnhubert_base_path,
models=models,
prompts=prompts,
)


PromptType = Union[str, Tuple[str, int, "np.ndarray"]]


class GPTSoVITSInferenceSimple:
config_data_base: Path
config_data: ConfigData
working_config: Tuple[str, str]
working_model: Optional[str]
working_prompt: Optional[str]

inference_worker: GPTSoVITSInference
inference_worker_lock: Lock

def __init__(
self,
config_data_base: str,
inference_worker_and_lock: Tuple[GPTSoVITSInference, Lock],
model_name: Optional[str] = None,
prompt_name: Optional[str] = None,
device: Optional[str] = None,
is_half: Optional[bool] = False,
):
self.config_data_base = Path(config_data_base)
model_files = os.listdir(str(self.config_data_base / "models"))
models = [
model_file.split(".")[0]
for model_file in model_files
if model_file.endswith(".pth")
and model_file.replace(".pth", ".ckpt") in model_files
]
prompt_files = os.listdir(str(self.config_data_base / "prompts"))
prompts = [
prompt_file.split(".")[0]
for prompt_file in prompt_files
if prompt_file.endswith(".txt")
and prompt_file.replace(".txt", ".wav") in prompt_files
]
self.config_data = ConfigData(models=models, prompts=prompts)

self.inference_worker = inference_worker_and_lock[0]
self.inference_worker_lock = inference_worker_and_lock[1]
self.working_config = (
model_name if model_name else models[0],
prompt_name if prompt_name else prompts[0],
self.config_data = read_config_data(config_data_base)
self.working_model = None
self.working_prompt = None
self.inference_worker = GPTSoVITSInference(
self.config_data.bert_path,
self.config_data.cnhubert_base_path,
device=device,
is_half=is_half,
)
self._load_model(self.working_config[0])
self._load_prompt(self.working_config[1])
self.inference_worker_lock = Lock()

def _load_model(self, model_name: str):
self.inference_worker.load_sovits(
str(self.config_data_base / "models" / f"{model_name}.pth")
self.config_data.models[model_name][0],
)
self.inference_worker.load_gpt(
str(self.config_data_base / "models" / f"{model_name}.ckpt")
self.config_data.models[model_name][1],
)

def _load_prompt(self, prompt_name: str):
with open(self.config_data_base / "prompts" / f"{prompt_name}.txt", "r") as f:
with open(self.config_data.prompts[prompt_name][0], "r") as f:
prompt_text = f.read().strip()
self.inference_worker.set_prompt_audio(
prompt_text=prompt_text,
prompt_audio_path=str(
self.config_data_base / "prompts" / f"{prompt_name}.wav"
),
prompt_audio_path=self.config_data.prompts[prompt_name][1],
)

def _load_things(self, model_name: str, prompt: PromptType):
if model_name != self.working_model:
self._load_model(model_name)
self.working_model = model_name
if isinstance(prompt, str):
if prompt != self.working_model:
self._load_prompt(prompt)
self.working_prompt = prompt
else:
self.inference_worker.set_prompt_audio(
prompt_text=prompt[0],
prompt_audio_data=prompt[2],
prompt_audio_sr=prompt[1],
)
self.working_prompt = None

def generate(
self,
model_name: str,
prompt: PromptType,
text: str,
text_language="auto",
top_k=5,
top_p=1,
temperature=1,
model_name: Optional[str] = None,
prompt_name: Optional[str] = None,
):
config = (
model_name if model_name else self.working_config[0],
prompt_name if prompt_name else self.working_config[1],
)
with self.inference_worker_lock:
if config[0] != self.working_config[0]:
self._load_model(config[0])
if config[1] != self.working_config[1]:
self._load_prompt(config[1])
self.working_config = config
self._load_things(model_name, prompt)
return self.inference_worker.get_tts_wav(
text=text,
text_language=text_language,
Expand All @@ -100,24 +130,16 @@ def generate(

def generate_stream(
self,
model_name: str,
prompt: PromptType,
text: str,
text_language="auto",
top_k=5,
top_p=1,
temperature=1,
model_name: Optional[str] = None,
prompt_name: Optional[str] = None,
):
config = (
model_name if model_name else self.working_config[0],
prompt_name if prompt_name else self.working_config[1],
)
with self.inference_worker_lock:
if config[0] != self.working_config[0]:
self._load_model(config[0])
if config[1] != self.working_config[1]:
self._load_prompt(config[1])
self.working_config = config
self._load_things(model_name, prompt)
for thing in self.inference_worker.get_tts_wav_stream(
text=text,
text_language=text_language,
Expand All @@ -128,26 +150,3 @@ def generate_stream(
yield thing

return None


if __name__ == "__main__":
from scipy.io import wavfile

inference_worker_and_lock = (
GPTSoVITSInference(
bert_path="pretrained_models/chinese-roberta-wwm-ext-large",
cnhubert_base_path="pretrained_models/chinese-hubert-base",
),
Lock(),
)

api = GPTSoVITSInferenceSimple(
config_data_base="config_data",
inference_worker_and_lock=inference_worker_and_lock,
)
for idx, (sr, data) in enumerate(
api.generate_stream("鲁迅为什么暴打周树人?这是一个问题")
):
wavfile.write(f"playground/output/output{idx}.wav", sr, data)
sr, data = api.generate("鲁迅为什么暴打周树人?这是一个问题")
wavfile.write("playground/output.wav", sr, data)

0 comments on commit 2ec4e85

Please sign in to comment.