diff --git a/src/gpt_sovits/infer/__init__.py b/src/gpt_sovits/infer/__init__.py index 34fec415..58b44b79 100644 --- a/src/gpt_sovits/infer/__init__.py +++ b/src/gpt_sovits/infer/__init__.py @@ -1,3 +1,3 @@ -from gpt_sovits.infer.inference import GPTSoVITSInference +from gpt_sovits.infer.worker import GPTSoVITSInference __all__ = ["GPTSoVITSInference"] diff --git a/src/gpt_sovits/infer/api.py b/src/gpt_sovits/infer/interface.py similarity index 97% rename from src/gpt_sovits/infer/api.py rename to src/gpt_sovits/infer/interface.py index 4dc86eec..d03267d4 100644 --- a/src/gpt_sovits/infer/api.py +++ b/src/gpt_sovits/infer/interface.py @@ -1,4 +1,4 @@ -from gpt_sovits.infer.inference import GPTSoVITSInference +from gpt_sovits.infer.worker import GPTSoVITSInference from pydantic import BaseModel from typing import List, Tuple, Optional from pathlib import Path @@ -11,7 +11,7 @@ class ConfigData(BaseModel): prompts: List[str] -class GPTSoVITSAPI: +class GPTSoVITSInferenceSimple: config_data_base: Path config_data: ConfigData working_config: Tuple[str, str] @@ -141,7 +141,7 @@ def generate_stream( Lock(), ) - api = GPTSoVITSAPI( + api = GPTSoVITSInferenceSimple( config_data_base="config_data", inference_worker_and_lock=inference_worker_and_lock, ) diff --git a/src/gpt_sovits/infer/text_utils.py b/src/gpt_sovits/infer/text_utils.py new file mode 100644 index 00000000..36c3113a --- /dev/null +++ b/src/gpt_sovits/infer/text_utils.py @@ -0,0 +1,65 @@ +import re +from typing import List + + +splits = { + ",", + "。", + "?", + "!", + ",", + ".", + "?", + "!", + "~", + ":", + ":", + "—", + "…", +} + + +def cut5(inp: str): + """Cut one line of text into pieces.""" + items = re.split(f"([{''.join(splits)}])", inp) + if items[-1] == "": + items = items[:-1] + if len(items) % 2 == 1: + items.append(".") + + mergeitems: List[str] = [items[0]] + for item in items[1:]: + if item == "": + continue + if item not in splits: + mergeitems.append(item) + else: + mergeitems[-1] += item + + return mergeitems + + +def merge_short_texts(texts: List[str], threshold: int = 6): + """Merge short texts to longer ones. Texts are generated by cut5.""" + result: List[str] = [] + text = "" + for ele in texts: + text += ele + if len(text) >= threshold: + result.append(text) + text = "" + if text: + result.append(text) + return result + + +def clean_and_cut_text(text: str) -> List[str]: + lines = [line.strip() for line in text.split("\n") if line.strip()] + texts = [ + merged.strip() + for line in lines + for merged in merge_short_texts(cut5(line)) + if not all(char in splits for char in merged.strip()) + ] + texts = ["." + text if len(text) < 5 else text for text in texts] + return texts diff --git a/src/gpt_sovits/infer/inference.py b/src/gpt_sovits/infer/worker.py similarity index 88% rename from src/gpt_sovits/infer/inference.py rename to src/gpt_sovits/infer/worker.py index ccc3be96..7074e136 100644 --- a/src/gpt_sovits/infer/inference.py +++ b/src/gpt_sovits/infer/worker.py @@ -1,4 +1,3 @@ -import re import sys import LangSegment import torch @@ -8,7 +7,7 @@ import sys import importlib.util from contextlib import contextmanager -from typing import Optional, Any, TypeVar, cast, List +from typing import Optional, Any, TypeVar, cast, List, Tuple from queue import Queue from transformers import AutoModelForMaskedLM, AutoTokenizer @@ -19,6 +18,8 @@ from gpt_sovits.text.cleaner import clean_text from gpt_sovits.module.mel_processing import spectrogram_torch +from gpt_sovits.infer.text_utils import splits, clean_and_cut_text + class DictToAttrRecursive(dict): def __init__(self, input_dict): @@ -72,57 +73,6 @@ def clean_text_inf(text, language): return phones, word2ph, norm_text -splits = { - ",", - "。", - "?", - "!", - ",", - ".", - "?", - "!", - "~", - ":", - ":", - "—", - "…", -} - - -def cut5(inp: str): - """Cut one line of text into pieces.""" - items = re.split(f"([{''.join(splits)}])", inp) - if items[-1] == "": - items = items[:-1] - if len(items) % 2 == 1: - items.append(".") - - mergeitems: List[str] = [items[0]] - for item in items[1:]: - if item == "": - continue - if item not in splits: - mergeitems.append(item) - else: - mergeitems[-1] += item - - return mergeitems - - -def merge_short_texts(texts: List[str], threshold: int = 6): - """Merge short texts to longer ones. Texts are generated by cut5.""" - result: List[str] = [] - text = "" - for ele in texts: - text += ele - if len(text) >= threshold: - result.append(text) - text = "" - if text: - result.append(text) - return result - - class GPTSoVITSInference: device: str is_half: bool @@ -321,7 +271,7 @@ def _get_spepc(self): @property def zero_wav(self): return np.zeros( - int(self.sample_rate * 0.25), + int(self.sample_rate * 0.3), dtype=self.np_dtype, ) @@ -385,14 +335,14 @@ def set_prompt_audio( self.phones1 = None self.bert1 = None - def _get_tts_wav( + def get_tts_wav_piece( self, text: str, text_language: str = "auto", top_k=5, top_p=1, temperature=1, - ): + ) -> Tuple[int, np.ndarray]: phones2, bert2, norm_text2 = self._get_phones_and_bert(text, text_language) if self.prompt_text: bert = torch.cat([self.bert1, bert2], 1) @@ -434,7 +384,9 @@ def _get_tts_wav( max_audio = np.abs(audio).max() # 简单防止16bit爆音 if max_audio > 1: audio /= max_audio - return audio + return self.sample_rate, ( + np.concatenate((audio, self.zero_wav)) * 32768 + ).astype(np.int16) def produce_tts_wav( self, @@ -445,25 +397,16 @@ def produce_tts_wav( top_p=1, temperature=1, ): - lines = [line.strip() for line in text.split("\n") if line.strip()] - texts = [ - merged.strip() - for line in lines - for merged in merge_short_texts(cut5(line)) - if not all(char in splits for char in merged.strip()) - ] - texts = ["." + text if len(text) < 5 else text for text in texts] + texts = clean_and_cut_text(text) for text in texts: - audio = self._get_tts_wav( + _, audio = self.get_tts_wav_piece( text, text_language, top_k, top_p, temperature, ) - queue.put( - (np.concatenate((audio, self.zero_wav), 0) * 32768).astype(np.int16) - ) + queue.put(audio) queue.put(None) def get_tts_wav_stream( @@ -494,18 +437,24 @@ def get_tts_wav( text_language="auto", top_k=5, top_p=1, - temperature=0.9, + temperature=1, ): audio_opt = [] - for _, audio in self.get_tts_wav_stream( - text, - text_language, - top_k, - top_p, - temperature, - ): - audio_opt.append(audio) - return self.sample_rate, np.concatenate(audio_opt, 0) + texts = clean_and_cut_text(text) + audio_opt = [ + self.get_tts_wav_piece( + text, + text_language, + top_k, + top_p, + temperature, + )[1] + for text in texts + ] + return self.sample_rate, np.concatenate(audio_opt) + + +Worker = GPTSoVITSInference if __name__ == "__main__":