Skip to content

Commit

Permalink
init api
Browse files Browse the repository at this point in the history
  • Loading branch information
BeautyyuYanli committed Jun 5, 2024
1 parent c57a782 commit fa05041
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 4 deletions.
153 changes: 153 additions & 0 deletions src/gpt_sovits/infer/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from gpt_sovits.infer.inference import GPTSoVITSInference
from pydantic import BaseModel
from typing import List, Tuple, Optional
from pathlib import Path
from threading import Lock
import os, sys


class ConfigData(BaseModel):
    """Inventory of model/prompt assets discovered under the config base dir."""

    # Stems of model files for which both <name>.pth and <name>.ckpt exist.
    models: List[str]
    # Stems of prompt files for which both <name>.txt and <name>.wav exist.
    prompts: List[str]


class GPTSoVITSAPI:
    """Serve TTS generation from a directory of model/prompt assets.

    Expected layout under ``config_data_base``:

    - ``models/<name>.pth``  + ``models/<name>.ckpt``  (paired model weights)
    - ``prompts/<name>.txt`` + ``prompts/<name>.wav``  (reference text + audio)

    All access to the shared inference worker is serialized through
    ``inference_worker_lock`` so that a model/prompt swap never interleaves
    with a generation running in another thread.
    """

    config_data_base: Path
    config_data: ConfigData
    # (model_name, prompt_name) currently loaded into the worker.
    working_config: Tuple[str, str]

    inference_worker: GPTSoVITSInference
    inference_worker_lock: Lock

    def __init__(
        self,
        config_data_base: str,
        inference_worker_and_lock: Tuple[GPTSoVITSInference, Lock],
        model_name: Optional[str] = None,
        prompt_name: Optional[str] = None,
    ):
        """Scan the asset directory and load the initial model and prompt.

        Args:
            config_data_base: Root directory containing ``models/`` and
                ``prompts/`` subdirectories.
            inference_worker_and_lock: The shared inference worker and the
                lock guarding it.
            model_name: Model to load initially; defaults to the first
                discovered model.
            prompt_name: Prompt to load initially; defaults to the first
                discovered prompt.

        Raises:
            ValueError: If no complete model (or prompt) pair is found and
                no explicit name was given.
        """
        self.config_data_base = Path(config_data_base)
        # A model is usable only when both the .pth and the .ckpt file exist.
        # Strip the suffix by slicing (not split(".")) so that names which
        # themselves contain dots resolve back to the real files on load.
        model_files = os.listdir(str(self.config_data_base / "models"))
        models = [
            model_file[: -len(".pth")]
            for model_file in model_files
            if model_file.endswith(".pth")
            and model_file[: -len(".pth")] + ".ckpt" in model_files
        ]
        prompt_files = os.listdir(str(self.config_data_base / "prompts"))
        prompts = [
            prompt_file[: -len(".txt")]
            for prompt_file in prompt_files
            if prompt_file.endswith(".txt")
            and prompt_file[: -len(".txt")] + ".wav" in prompt_files
        ]
        self.config_data = ConfigData(models=models, prompts=prompts)

        self.inference_worker = inference_worker_and_lock[0]
        self.inference_worker_lock = inference_worker_and_lock[1]
        # Fail with a clear message instead of an IndexError when the asset
        # directories hold no usable pair and no explicit name was supplied.
        if model_name is None and not models:
            raise ValueError(
                f"no .pth/.ckpt model pair found in {self.config_data_base / 'models'}"
            )
        if prompt_name is None and not prompts:
            raise ValueError(
                f"no .txt/.wav prompt pair found in {self.config_data_base / 'prompts'}"
            )
        self.working_config = (
            model_name if model_name else models[0],
            prompt_name if prompt_name else prompts[0],
        )
        self._load_model(self.working_config[0])
        self._load_prompt(self.working_config[1])

    def _load_model(self, model_name: str):
        """Load the SoVITS (.pth) and GPT (.ckpt) weights for *model_name*."""
        self.inference_worker.load_sovits(
            str(self.config_data_base / "models" / f"{model_name}.pth")
        )
        self.inference_worker.load_gpt(
            str(self.config_data_base / "models" / f"{model_name}.ckpt")
        )

    def _load_prompt(self, prompt_name: str):
        """Set the worker's reference prompt text (.txt) and audio (.wav)."""
        with open(
            self.config_data_base / "prompts" / f"{prompt_name}.txt",
            "r",
            encoding="utf-8",
        ) as f:
            prompt_text = f.read().strip()
        self.inference_worker.set_prompt_audio(
            prompt_text=prompt_text,
            prompt_audio_path=str(
                self.config_data_base / "prompts" / f"{prompt_name}.wav"
            ),
        )

    def _switch_config(
        self, model_name: Optional[str], prompt_name: Optional[str]
    ) -> None:
        """Swap the worker to the requested model/prompt if they differ.

        ``None`` keeps the current model/prompt. Must be called while
        holding ``inference_worker_lock``.
        """
        config = (
            model_name if model_name else self.working_config[0],
            prompt_name if prompt_name else self.working_config[1],
        )
        if config[0] != self.working_config[0]:
            self._load_model(config[0])
        if config[1] != self.working_config[1]:
            self._load_prompt(config[1])
        self.working_config = config

    def generate(
        self,
        text: str,
        text_language="auto",
        top_k=5,
        top_p=1,
        temperature=1,
        model_name: Optional[str] = None,
        prompt_name: Optional[str] = None,
    ):
        """Synthesize *text* in one shot; returns the worker's (sr, audio) result."""
        with self.inference_worker_lock:
            self._switch_config(model_name, prompt_name)
            return self.inference_worker.get_tts_wav(
                text=text,
                text_language=text_language,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
            )

    def generate_stream(
        self,
        text: str,
        text_language="auto",
        top_k=5,
        top_p=1,
        temperature=1,
        model_name: Optional[str] = None,
        prompt_name: Optional[str] = None,
    ):
        """Synthesize *text* chunk by chunk, yielding the worker's stream items.

        The lock is held for the whole stream (not just the model swap):
        releasing it early would let another thread reload weights while
        this generator is still producing audio from the old ones.
        """
        with self.inference_worker_lock:
            self._switch_config(model_name, prompt_name)
            yield from self.inference_worker.get_tts_wav_stream(
                text=text,
                text_language=text_language,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
            )


if __name__ == "__main__":
    # Smoke-test driver: synthesize one sentence via both the streaming and
    # the one-shot API and dump the audio under playground/.
    from scipy.io import wavfile

    worker = GPTSoVITSInference(
        bert_path="pretrained_models/chinese-roberta-wwm-ext-large",
        cnhubert_base_path="pretrained_models/chinese-hubert-base",
    )
    api = GPTSoVITSAPI(
        config_data_base="config_data",
        inference_worker_and_lock=(worker, Lock()),
    )

    demo_text = "鲁迅为什么暴打周树人?这是一个问题"

    # Streaming path: one wav file per yielded chunk.
    for idx, (sr, data) in enumerate(api.generate_stream(demo_text)):
        wavfile.write(f"playground/output/output{idx}.wav", sr, data)

    # One-shot path: the full utterance in a single wav file.
    sr, data = api.generate(demo_text)
    wavfile.write("playground/output.wav", sr, data)
10 changes: 6 additions & 4 deletions src/gpt_sovits/infer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import librosa
import sys
import importlib.util
from contextlib import contextmanager
from typing import Optional, Any, TypeVar, cast, List
from queue import Queue
from transformers import AutoModelForMaskedLM, AutoTokenizer
Expand Down Expand Up @@ -47,7 +48,8 @@ def __delattr__(self, item):
raise AttributeError(f"Attribute {item} not found")


def _load_sovits(sovits_path: str):
@contextmanager
def _tmp_sys_path():
package_name = "gpt_sovits"
spec = importlib.util.find_spec(package_name)
if spec is not None:
Expand All @@ -60,9 +62,8 @@ def _load_sovits(sovits_path: str):
raise ModuleNotFoundError(f"Package {package_name} not found.")

sys.path.append(tmp_path)
dict_s2 = torch.load(sovits_path, map_location="cpu")
yield
sys.path.remove(tmp_path)
return dict_s2


def clean_text_inf(text, language):
Expand Down Expand Up @@ -167,7 +168,8 @@ def __init__(
self.ssl_model = self._prepare_torch(cnhubert.get_model())

def load_sovits(self, sovits_path: str):
dict_s2 = _load_sovits(sovits_path)
with _tmp_sys_path():
dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
Expand Down

0 comments on commit fa05041

Please sign in to comment.