diff --git a/src/gpt_sovits/infer/__init__.py b/src/gpt_sovits/infer/__init__.py
index 58b44b79..34fec415 100644
--- a/src/gpt_sovits/infer/__init__.py
+++ b/src/gpt_sovits/infer/__init__.py
@@ -1,3 +1,3 @@
-from gpt_sovits.infer.worker import GPTSoVITSInference
+from gpt_sovits.infer.inference import GPTSoVITSInference
 
 __all__ = ["GPTSoVITSInference"]
diff --git a/src/gpt_sovits/infer/worker.py b/src/gpt_sovits/infer/inference.py
similarity index 99%
rename from src/gpt_sovits/infer/worker.py
rename to src/gpt_sovits/infer/inference.py
index 633b4531..bc08fc94 100644
--- a/src/gpt_sovits/infer/worker.py
+++ b/src/gpt_sovits/infer/inference.py
@@ -456,9 +456,6 @@ def get_tts_wav(
         return self.sample_rate, np.concatenate(audio_opt)
 
 
-Worker = GPTSoVITSInference
-
-
 if __name__ == "__main__":
     from scipy.io import wavfile
 
diff --git a/src/gpt_sovits/infer/inference_pool.py b/src/gpt_sovits/infer/inference_pool.py
new file mode 100644
index 00000000..194d59d9
--- /dev/null
+++ b/src/gpt_sovits/infer/inference_pool.py
@@ -0,0 +1,143 @@
+from typing import List, Tuple, Optional
+import numpy as np
+import tqdm
+import torch
+from gpt_sovits.infer.inference import GPTSoVITSInference
+from gpt_sovits.infer.text_utils import clean_and_cut_text
+from concurrent.futures import Future, ProcessPoolExecutor
+from multiprocessing import current_process
+
+
+def worker_init(
+    bert_path: str,
+    cnhubert_base_path: str,
+    gpt_path: str,
+    sovits_path: str,
+    prompt_text: Optional[str],
+    prompt_language: str = "auto",
+    prompt_audio_path: Optional[str] = None,
+    prompt_audio_data: Optional[np.ndarray] = None,
+    prompt_audio_sr: Optional[int] = None,
+    device: Optional[str] = None,
+    is_half: bool = True,
+):
+    global worker
+    if device is None:
+        if torch.cuda.is_available():
+            cnt = torch.cuda.device_count()
+            num = current_process()._identity[0]
+            device = f"cuda:{num % cnt}"
+        elif torch.backends.mps.is_available():
+            device = "mps"
+        else:
+            device = "cpu"
+
+    worker = GPTSoVITSInference(
+        bert_path=bert_path,
+        cnhubert_base_path=cnhubert_base_path,
+        device=device,
+        is_half=is_half,
+    )
+    worker.load_gpt(gpt_path)
+    worker.load_sovits(sovits_path)
+    worker.set_prompt_audio(
+        prompt_text,
+        prompt_language,
+        prompt_audio_path,
+        prompt_audio_data,
+        prompt_audio_sr,
+    )
+
+
+def worker_get_tts_wav_piece(
+    text: str,
+    text_language="auto",
+    top_k=5,
+    top_p=1,
+    temperature=1,
+):
+    global worker
+    return worker.get_tts_wav_piece(
+        text,
+        text_language,
+        top_k,
+        top_p,
+        temperature,
+    )
+
+
+class GPTSoVITSInferencePool:
+    pool: ProcessPoolExecutor
+
+    def __init__(
+        self,
+        bert_path: str,
+        cnhubert_base_path: str,
+        gpt_path: str,
+        sovits_path: str,
+        prompt_text: Optional[str],
+        prompt_language: str = "auto",
+        prompt_audio_path: Optional[str] = None,
+        prompt_audio_data: Optional[np.ndarray] = None,
+        prompt_audio_sr: Optional[int] = None,
+        device: Optional[str] = None,
+        is_half: bool = True,
+        max_workers: int = 4,
+    ):
+        self.pool = ProcessPoolExecutor(
+            max_workers=max_workers,
+            initializer=worker_init,
+            initargs=(
+                bert_path,
+                cnhubert_base_path,
+                gpt_path,
+                sovits_path,
+                prompt_text,
+                prompt_language,
+                prompt_audio_path,
+                prompt_audio_data,
+                prompt_audio_sr,
+                device,
+                is_half,
+            ),
+        )
+
+    def get_tts_wav_stream(
+        self,
+        text: str,
+        text_language="auto",
+        top_k=5,
+        top_p=1,
+        temperature=1,
+    ):
+        tasks = clean_and_cut_text(text)
+        futures = [
+            self.pool.submit(
+                worker_get_tts_wav_piece,
+                task,
+                text_language,
+                top_k,
+                top_p,
+                temperature,
+            )
+            for task in tasks
+        ]
+        for future in futures:
+            yield future.result()
+
+    def get_tts_wav(
+        self,
+        text: str,
+        text_language="auto",
+        top_k=5,
+        top_p=1,
+        temperature=1,
+    ):
+        audio_list: List[Tuple[int, np.ndarray]] = []
+        for thing in tqdm.tqdm(
+            self.get_tts_wav_stream(text, text_language, top_k, top_p, temperature)
+        ):
+            audio_list.append(thing)
+        return audio_list[0][0], np.concatenate(
+            [data for _, data in audio_list], axis=0
+        )
diff --git a/src/gpt_sovits/infer/interface.py b/src/gpt_sovits/infer/interface.py
index d03267d4..065c86a4 100644
--- a/src/gpt_sovits/infer/interface.py
+++ b/src/gpt_sovits/infer/interface.py
@@ -1,4 +1,4 @@
-from gpt_sovits.infer.worker import GPTSoVITSInference
+from gpt_sovits.infer.inference import GPTSoVITSInference
 from pydantic import BaseModel
 from typing import List, Tuple, Optional
 from pathlib import Path
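
Usage note (not part of the diff above): a minimal sketch of how the new GPTSoVITSInferencePool might be driven, assuming the same checkpoint layout that GPTSoVITSInference expects. All file paths and the prompt text below are placeholders, not values from this change.

```python
# Sketch only: paths, prompt text, and output names are placeholders.
from scipy.io import wavfile

from gpt_sovits.infer.inference_pool import GPTSoVITSInferencePool

if __name__ == "__main__":  # main guard is required because workers are spawned processes
    pool = GPTSoVITSInferencePool(
        bert_path="pretrained/chinese-roberta-wwm-ext-large",  # placeholder
        cnhubert_base_path="pretrained/chinese-hubert-base",   # placeholder
        gpt_path="pretrained/gpt.ckpt",                        # placeholder
        sovits_path="pretrained/sovits.pth",                   # placeholder
        prompt_text="transcript of the reference audio",       # placeholder
        prompt_audio_path="prompt.wav",                        # placeholder
        max_workers=2,
    )

    # Blocking call: synthesize the whole text and write a single wav file.
    sample_rate, audio = pool.get_tts_wav("Hello, this is a pooled inference test.")
    wavfile.write("output.wav", sample_rate, audio)

    # Streaming call: pieces are yielded in text order as worker processes finish.
    for piece_sample_rate, piece in pool.get_tts_wav_stream("Another sentence to synthesize."):
        pass  # e.g. push each piece to a playback buffer
```

Each worker process loads its own copy of the models in worker_init and, when several GPUs are visible, picks a device based on its process index, so max_workers should be chosen with per-process VRAM usage in mind.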