From 64a00936ba2804bc5ca84745d13d3e15d0e0c4f9 Mon Sep 17 00:00:00 2001 From: Yanli Date: Sat, 8 Jun 2024 02:21:54 +0800 Subject: [PATCH] enhance text split --- pyproject.toml | 2 +- src/gpt_sovits/infer/inference.py | 4 +-- src/gpt_sovits/infer/text_utils.py | 55 +++++++++++++++++++++++------- 3 files changed, 46 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 85d79dbe..3734fe7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "GPT-SoVITS-Infer" -version = "0.2.2" +version = "0.2.3" description = "Inference code for GPT-SoVITS" authors = [ {name = "Yanli",email = "mail@yanli.one"}, diff --git a/src/gpt_sovits/infer/inference.py b/src/gpt_sovits/infer/inference.py index bc08fc94..f2e7e080 100644 --- a/src/gpt_sovits/infer/inference.py +++ b/src/gpt_sovits/infer/inference.py @@ -18,7 +18,7 @@ from gpt_sovits.text.cleaner import clean_text from gpt_sovits.module.mel_processing import spectrogram_torch -from gpt_sovits.infer.text_utils import splits, clean_and_cut_text +from gpt_sovits.infer.text_utils import clean_and_cut_text, full_splits class DictToAttrRecursive(dict): @@ -291,7 +291,7 @@ def set_prompt_audio( ): if prompt_text: prompt_text = prompt_text.strip("\n") - if prompt_text[-1] not in splits: + if prompt_text[-1] not in full_splits: prompt_text += "." self.prompt_text = prompt_text diff --git a/src/gpt_sovits/infer/text_utils.py b/src/gpt_sovits/infer/text_utils.py index db862068..4a635c1a 100644 --- a/src/gpt_sovits/infer/text_utils.py +++ b/src/gpt_sovits/infer/text_utils.py @@ -1,31 +1,60 @@ import re -from typing import List +from typing import List, Set - -splits = { - ",", +tier1_splits = { "。", "?", "!", - ",", ".", "?", "!", - "~", +} + +tier2_splits = { + ",", + ",", ":", ":", "—", "…", + "~", + "、", + ";", + ";", + "(", + "(", + ")", + ")", + "《", + "》", + "“", + "”", + "‘", + "’", + '"', + "'", + "【", + "】", + "[", + "]", + "「", + "」", + "『", + "』", + "<", + ">", } +full_splits = tier1_splits | tier2_splits -def cut5(inp: str): + +def cut5(inp: str, splits: Set[str], append_dot: str): """Cut one line of text into pieces.""" - items = re.split(f"([{''.join(splits)}])", inp) + items = re.split(f"([{''.join(re.escape(x) for x in splits)}])", inp) if items[-1] == "": items = items[:-1] if len(items) % 2 == 1: - items.append(".") + items.append(append_dot) mergeitems: List[str] = [items[0]] for item in items[1:]: @@ -55,11 +84,13 @@ def merge_short_texts(texts: List[str], threshold: int = 32): def clean_and_cut_text(text: str) -> List[str]: lines = [line.strip() for line in text.split("\n") if line.strip()] + sents = [ + sent for line in lines for sent in cut5(line, tier1_splits, ".") if sent.strip() + ] texts = [ merged.strip() - for line in lines - for merged in merge_short_texts(cut5(line)) - if not all(char in splits for char in merged.strip()) + for sent in sents + for merged in merge_short_texts(cut5(sent, tier2_splits, "")) ] texts = [("." + text) if len(text) < 5 else text for text in texts] return texts