Skip to content

Commit

Permalink
rename things
Browse files Browse the repository at this point in the history
  • Loading branch information
BeautyyuYanli committed Jun 6, 2024
1 parent fa05041 commit 7f5c29b
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 83 deletions.
2 changes: 1 addition & 1 deletion src/gpt_sovits/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from gpt_sovits.infer.inference import GPTSoVITSInference
from gpt_sovits.infer.worker import GPTSoVITSInference

__all__ = ["GPTSoVITSInference"]
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from gpt_sovits.infer.inference import GPTSoVITSInference
from gpt_sovits.infer.worker import GPTSoVITSInference
from pydantic import BaseModel
from typing import List, Tuple, Optional
from pathlib import Path
Expand All @@ -11,7 +11,7 @@ class ConfigData(BaseModel):
prompts: List[str]


class GPTSoVITSAPI:
class GPTSoVITSInferenceSimple:
config_data_base: Path
config_data: ConfigData
working_config: Tuple[str, str]
Expand Down Expand Up @@ -141,7 +141,7 @@ def generate_stream(
Lock(),
)

api = GPTSoVITSAPI(
api = GPTSoVITSInferenceSimple(
config_data_base="config_data",
inference_worker_and_lock=inference_worker_and_lock,
)
Expand Down
65 changes: 65 additions & 0 deletions src/gpt_sovits/infer/text_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import re
from typing import List


# Punctuation marks treated as sentence/clause boundaries by the text cutter.
# Each mark that has both a full-width (CJK) and an ASCII form is listed in
# both forms so mixed Chinese/English input is split consistently.
# NOTE(review): the full-width entries are written as \u escapes so that
# width-normalising tools cannot silently collapse them into duplicates of
# the ASCII entries - confirm this matches the intended punctuation set.
splits = {
    "\uff0c",  # full-width comma
    "。",
    "\uff1f",  # full-width question mark
    "\uff01",  # full-width exclamation mark
    ",",
    ".",
    "?",
    "!",
    "~",
    "\uff1a",  # full-width colon
    ":",
    "—",
    "…",
}


def cut5(inp: str):
    """Cut one line of text into pieces, each ending with its punctuation.

    The line is split at every character found in ``splits``; each
    punctuation mark is glued back onto the piece that precedes it.  When the
    line does not end with a punctuation mark, a "." is appended to the last
    piece.  For example ``"Hi, there"`` becomes ``["Hi,", " there."]``.

    Args:
        inp: A single line of text.

    Returns:
        The list of pieces; an empty list for empty input.
    """
    # An empty line would leave nothing to seed the merge loop below
    # (re.split yields [""] which is then trimmed to []).
    if not inp:
        return []

    # Escape every mark so that regex metacharacters added to ``splits``
    # later (e.g. "]", "^", "-") cannot corrupt the character class.
    pattern = "([" + "".join(re.escape(mark) for mark in splits) + "])"
    items = re.split(pattern, inp)
    if items[-1] == "":
        items = items[:-1]
    # With a capturing group, text and separators alternate; an odd count
    # means the final text piece has no trailing punctuation.
    if len(items) % 2 == 1:
        items.append(".")

    mergeitems: List[str] = [items[0]]
    for item in items[1:]:
        if item == "":
            # Empty text between two consecutive punctuation marks.
            continue
        if item not in splits:
            mergeitems.append(item)
        else:
            mergeitems[-1] += item

    return mergeitems


def merge_short_texts(texts: List[str], threshold: int = 6):
    """Concatenate consecutive pieces until each chunk reaches ``threshold``.

    Pieces are accumulated in order; as soon as the accumulated length is at
    least ``threshold`` characters the chunk is emitted and accumulation
    restarts.  A non-empty remainder at the end is emitted as a final,
    possibly shorter, chunk.  Intended for the output of ``cut5``.
    """
    merged: List[str] = []
    pending: List[str] = []
    pending_len = 0
    for piece in texts:
        pending.append(piece)
        pending_len += len(piece)
        if pending_len < threshold:
            continue
        merged.append("".join(pending))
        pending = []
        pending_len = 0
    # Flush whatever never reached the threshold.
    if pending_len:
        merged.append("".join(pending))
    return merged


def clean_and_cut_text(text: str) -> List[str]:
    """Turn raw multi-line text into a list of cleaned pieces.

    Each non-blank line is stripped, cut at punctuation with ``cut5`` and
    re-merged with ``merge_short_texts``.  Chunks consisting solely of
    punctuation (or nothing) are discarded, and any chunk shorter than five
    characters gets a leading "." prepended.
    """
    pieces: List[str] = []
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        for chunk in merge_short_texts(cut5(line)):
            chunk = chunk.strip()
            # Skip chunks that carry no actual text.
            if all(char in splits for char in chunk):
                continue
            pieces.append(chunk)
    # NOTE(review): the "." prefix presumably pads very short inputs for the
    # downstream model - confirm the intent with the caller.
    return [piece if len(piece) >= 5 else "." + piece for piece in pieces]
107 changes: 28 additions & 79 deletions src/gpt_sovits/infer/inference.py → src/gpt_sovits/infer/worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import re
import sys
import LangSegment
import torch
Expand All @@ -8,7 +7,7 @@
import sys
import importlib.util
from contextlib import contextmanager
from typing import Optional, Any, TypeVar, cast, List
from typing import Optional, Any, TypeVar, cast, List, Tuple
from queue import Queue
from transformers import AutoModelForMaskedLM, AutoTokenizer

Expand All @@ -19,6 +18,8 @@
from gpt_sovits.text.cleaner import clean_text
from gpt_sovits.module.mel_processing import spectrogram_torch

from gpt_sovits.infer.text_utils import splits, clean_and_cut_text


class DictToAttrRecursive(dict):
def __init__(self, input_dict):
Expand Down Expand Up @@ -72,57 +73,6 @@ def clean_text_inf(text, language):
return phones, word2ph, norm_text


# Punctuation marks treated as sentence/clause boundaries by the text cutter.
# Each mark that has both a full-width (CJK) and an ASCII form is listed in
# both forms so mixed Chinese/English input is split consistently.
# NOTE(review): the full-width entries are written as \u escapes so that
# width-normalising tools cannot silently collapse them into duplicates of
# the ASCII entries - confirm this matches the intended punctuation set.
splits = {
    "\uff0c",  # full-width comma
    "。",
    "\uff1f",  # full-width question mark
    "\uff01",  # full-width exclamation mark
    ",",
    ".",
    "?",
    "!",
    "~",
    "\uff1a",  # full-width colon
    ":",
    "—",
    "…",
}


def cut5(inp: str):
    """Cut one line of text into pieces, each ending with its punctuation.

    The line is split at every character found in ``splits``; each
    punctuation mark is glued back onto the piece that precedes it.  When the
    line does not end with a punctuation mark, a "." is appended to the last
    piece.  For example ``"Hi, there"`` becomes ``["Hi,", " there."]``.

    Args:
        inp: A single line of text.

    Returns:
        The list of pieces; an empty list for empty input.
    """
    # An empty line would leave nothing to seed the merge loop below
    # (re.split yields [""] which is then trimmed to []).
    if not inp:
        return []

    # Escape every mark so that regex metacharacters added to ``splits``
    # later (e.g. "]", "^", "-") cannot corrupt the character class.
    pattern = "([" + "".join(re.escape(mark) for mark in splits) + "])"
    items = re.split(pattern, inp)
    if items[-1] == "":
        items = items[:-1]
    # With a capturing group, text and separators alternate; an odd count
    # means the final text piece has no trailing punctuation.
    if len(items) % 2 == 1:
        items.append(".")

    mergeitems: List[str] = [items[0]]
    for item in items[1:]:
        if item == "":
            # Empty text between two consecutive punctuation marks.
            continue
        if item not in splits:
            mergeitems.append(item)
        else:
            mergeitems[-1] += item

    return mergeitems


def merge_short_texts(texts: List[str], threshold: int = 6):
    """Concatenate consecutive pieces until each chunk reaches ``threshold``.

    Pieces are accumulated in order; as soon as the accumulated length is at
    least ``threshold`` characters the chunk is emitted and accumulation
    restarts.  A non-empty remainder at the end is emitted as a final,
    possibly shorter, chunk.  Intended for the output of ``cut5``.
    """
    merged: List[str] = []
    pending: List[str] = []
    pending_len = 0
    for piece in texts:
        pending.append(piece)
        pending_len += len(piece)
        if pending_len < threshold:
            continue
        merged.append("".join(pending))
        pending = []
        pending_len = 0
    # Flush whatever never reached the threshold.
    if pending_len:
        merged.append("".join(pending))
    return merged


class GPTSoVITSInference:
device: str
is_half: bool
Expand Down Expand Up @@ -321,7 +271,7 @@ def _get_spepc(self):
@property
def zero_wav(self):
return np.zeros(
int(self.sample_rate * 0.25),
int(self.sample_rate * 0.3),
dtype=self.np_dtype,
)

Expand Down Expand Up @@ -385,14 +335,14 @@ def set_prompt_audio(
self.phones1 = None
self.bert1 = None

def _get_tts_wav(
def get_tts_wav_piece(
self,
text: str,
text_language: str = "auto",
top_k=5,
top_p=1,
temperature=1,
):
) -> Tuple[int, np.ndarray]:
phones2, bert2, norm_text2 = self._get_phones_and_bert(text, text_language)
if self.prompt_text:
bert = torch.cat([self.bert1, bert2], 1)
Expand Down Expand Up @@ -434,7 +384,9 @@ def _get_tts_wav(
max_audio = np.abs(audio).max() # 简单防止16bit爆音
if max_audio > 1:
audio /= max_audio
return audio
return self.sample_rate, (
np.concatenate((audio, self.zero_wav)) * 32768
).astype(np.int16)

def produce_tts_wav(
self,
Expand All @@ -445,25 +397,16 @@ def produce_tts_wav(
top_p=1,
temperature=1,
):
lines = [line.strip() for line in text.split("\n") if line.strip()]
texts = [
merged.strip()
for line in lines
for merged in merge_short_texts(cut5(line))
if not all(char in splits for char in merged.strip())
]
texts = ["." + text if len(text) < 5 else text for text in texts]
texts = clean_and_cut_text(text)
for text in texts:
audio = self._get_tts_wav(
_, audio = self.get_tts_wav_piece(
text,
text_language,
top_k,
top_p,
temperature,
)
queue.put(
(np.concatenate((audio, self.zero_wav), 0) * 32768).astype(np.int16)
)
queue.put(audio)
queue.put(None)

def get_tts_wav_stream(
Expand Down Expand Up @@ -494,18 +437,24 @@ def get_tts_wav(
text_language="auto",
top_k=5,
top_p=1,
temperature=0.9,
temperature=1,
):
audio_opt = []
for _, audio in self.get_tts_wav_stream(
text,
text_language,
top_k,
top_p,
temperature,
):
audio_opt.append(audio)
return self.sample_rate, np.concatenate(audio_opt, 0)
texts = clean_and_cut_text(text)
audio_opt = [
self.get_tts_wav_piece(
text,
text_language,
top_k,
top_p,
temperature,
)[1]
for text in texts
]
return self.sample_rate, np.concatenate(audio_opt)


Worker = GPTSoVITSInference


if __name__ == "__main__":
Expand Down

0 comments on commit 7f5c29b

Please sign in to comment.