Skip to content

Commit

Permalink
inference pool
Browse files Browse the repository at this point in the history
  • Loading branch information
BeautyyuYanli committed Jun 6, 2024
1 parent 844878c commit 402fc88
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/gpt_sovits/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from gpt_sovits.infer.worker import GPTSoVITSInference
from gpt_sovits.infer.inference import GPTSoVITSInference

__all__ = ["GPTSoVITSInference"]
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,6 @@ def get_tts_wav(
return self.sample_rate, np.concatenate(audio_opt)


Worker = GPTSoVITSInference


if __name__ == "__main__":

from scipy.io import wavfile
Expand Down
143 changes: 143 additions & 0 deletions src/gpt_sovits/infer/inference_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from typing import List, Tuple, Optional
import numpy as np
import tqdm
import torch
from gpt_sovits.infer.inference import GPTSoVITSInference
from gpt_sovits.infer.text_utils import clean_and_cut_text
from concurrent.futures import Future, ProcessPoolExecutor
from multiprocessing import current_process


def worker_init(
    bert_path: str,
    cnhubert_base_path: str,
    gpt_path: str,
    sovits_path: str,
    prompt_text: Optional[str],
    prompt_language: str = "auto",
    prompt_audio_path: Optional[str] = None,
    prompt_audio_data: Optional[np.ndarray] = None,
    prompt_audio_sr: Optional[int] = None,
    device: Optional[str] = None,
    is_half: bool = True,
):
    """Build the per-process inference engine and store it in the global ``worker``.

    Intended to run as the ``initializer`` of a ``ProcessPoolExecutor``: each
    process that runs it ends up with its own ``GPTSoVITSInference`` instance
    with the GPT and SoVITS weights loaded and the prompt audio configured.

    Args:
        bert_path: Forwarded to ``GPTSoVITSInference`` as ``bert_path``.
        cnhubert_base_path: Forwarded to ``GPTSoVITSInference`` as
            ``cnhubert_base_path``.
        gpt_path: Passed to ``load_gpt``.
        sovits_path: Passed to ``load_sovits``.
        prompt_text: First argument to ``set_prompt_audio``.
        prompt_language: Second argument to ``set_prompt_audio``.
        prompt_audio_path: Third argument to ``set_prompt_audio``.
        prompt_audio_data: Fourth argument to ``set_prompt_audio``.
        prompt_audio_sr: Fifth argument to ``set_prompt_audio``.
        device: Torch device string. When ``None`` it is chosen automatically:
            a CUDA device picked from the process index, else ``"mps"``, else
            ``"cpu"``.
        is_half: Forwarded to ``GPTSoVITSInference`` as ``is_half``
            (presumably selects half precision -- confirm against that class).
    """
    global worker

    if device is None:
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            # ``_identity`` is a private multiprocessing attribute. It is empty
            # when this function runs outside a child process (for example
            # when called directly), so fall back to index 0 instead of
            # raising IndexError.
            identity = getattr(current_process(), "_identity", ())
            process_index = identity[0] if identity else 0
            # Spread processes over the visible GPUs round-robin.
            device = f"cuda:{process_index % gpu_count}"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

    worker = GPTSoVITSInference(
        bert_path=bert_path,
        cnhubert_base_path=cnhubert_base_path,
        device=device,
        is_half=is_half,
    )
    worker.load_gpt(gpt_path)
    worker.load_sovits(sovits_path)
    worker.set_prompt_audio(
        prompt_text,
        prompt_language,
        prompt_audio_path,
        prompt_audio_data,
        prompt_audio_sr,
    )


def worker_get_tts_wav_piece(
    text: str,
    text_language="auto",
    top_k=5,
    top_p=1,
    temperature=1,
):
    """Synthesize one piece of text with this process's global ``worker``.

    This is the module-level entry point submitted to the process pool. It
    forwards every argument positionally, in the same order, to
    ``worker.get_tts_wav_piece`` and returns whatever that call returns.
    The global ``worker`` must already have been set (``worker_init`` does
    that); otherwise the lookup raises ``NameError``.
    """
    global worker
    # Group the sampling knobs so the forwarding call reads as text + options.
    sampling_args = (top_k, top_p, temperature)
    return worker.get_tts_wav_piece(text, text_language, *sampling_args)


class GPTSoVITSInferencePool:
    """Run text-to-speech pieces in parallel on a pool of worker processes.

    Every worker process is initialized by ``worker_init`` with the arguments
    given to the constructor. Text is split with ``clean_and_cut_text`` and
    each piece is submitted to the pool as a separate task.

    The pool owns operating-system processes, so release it with
    ``shutdown()`` or use the instance as a context manager::

        with GPTSoVITSInferencePool(...) as pool:
            sample_rate, audio = pool.get_tts_wav("some text")
    """

    pool: ProcessPoolExecutor

    def __init__(
        self,
        bert_path: str,
        cnhubert_base_path: str,
        gpt_path: str,
        sovits_path: str,
        prompt_text: Optional[str],
        prompt_language: str = "auto",
        prompt_audio_path: Optional[str] = None,
        prompt_audio_data: Optional[np.ndarray] = None,
        prompt_audio_sr: Optional[int] = None,
        device: Optional[str] = None,
        is_half: bool = True,
        max_workers: int = 4,
    ):
        """Create the process pool.

        All arguments except ``max_workers`` are passed unchanged, in order,
        to ``worker_init`` inside each worker process.

        Args:
            max_workers: Maximum number of worker processes in the pool.
        """
        # The tuple order must match the positional parameters of worker_init.
        self.pool = ProcessPoolExecutor(
            max_workers=max_workers,
            initializer=worker_init,
            initargs=(
                bert_path,
                cnhubert_base_path,
                gpt_path,
                sovits_path,
                prompt_text,
                prompt_language,
                prompt_audio_path,
                prompt_audio_data,
                prompt_audio_sr,
                device,
                is_half,
            ),
        )

    def shutdown(self, wait: bool = True) -> None:
        """Shut down the underlying process pool.

        Args:
            wait: Forwarded to ``ProcessPoolExecutor.shutdown``.
        """
        self.pool.shutdown(wait=wait)

    def __enter__(self):
        """Return the pool itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Shut the pool down when the ``with`` block exits."""
        self.shutdown()

    def get_tts_wav_stream(
        self,
        text: str,
        text_language="auto",
        top_k=5,
        top_p=1,
        temperature=1,
    ):
        """Yield the synthesized result of each text piece, in text order.

        All pieces are submitted to the pool when iteration starts; results
        are then yielded in submission order, blocking on each one. An
        exception raised in a worker is re-raised here by ``Future.result``.

        Yields:
            Whatever ``worker_get_tts_wav_piece`` returns for each piece;
            ``get_tts_wav`` treats each item as a ``(sample_rate, data)`` pair.
        """
        tasks = clean_and_cut_text(text)
        futures: List[Future] = [
            self.pool.submit(
                worker_get_tts_wav_piece,
                task,
                text_language,
                top_k,
                top_p,
                temperature,
            )
            for task in tasks
        ]
        try:
            for future in futures:
                yield future.result()
        finally:
            # If the consumer stops early or a piece fails, drop the pieces
            # that have not started yet so they do not occupy the workers.
            # cancel() is a harmless no-op on running or finished futures.
            for future in futures:
                future.cancel()

    def get_tts_wav(
        self,
        text: str,
        text_language="auto",
        top_k=5,
        top_p=1,
        temperature=1,
    ):
        """Synthesize the whole text and return it as one waveform.

        Returns:
            A ``(sample_rate, data)`` tuple: the sample rate reported with the
            first piece and all pieces' data concatenated along axis 0.

        Raises:
            ValueError: If the text produces no pieces to synthesize.
        """
        audio_list: List[Tuple[int, np.ndarray]] = list(
            tqdm.tqdm(
                self.get_tts_wav_stream(
                    text, text_language, top_k, top_p, temperature
                )
            )
        )
        if not audio_list:
            # Without any piece there is neither a sample rate nor audio data.
            raise ValueError("no text to synthesize after cleaning and cutting")
        sample_rate = audio_list[0][0]
        return sample_rate, np.concatenate(
            [data for _, data in audio_list], axis=0
        )
2 changes: 1 addition & 1 deletion src/gpt_sovits/infer/interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from gpt_sovits.infer.worker import GPTSoVITSInference
from gpt_sovits.infer.inference import GPTSoVITSInference
from pydantic import BaseModel
from typing import List, Tuple, Optional
from pathlib import Path
Expand Down

0 comments on commit 402fc88

Please sign in to comment.