diff --git a/pyproject.toml b/pyproject.toml index 2772ad55..1d059afe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "GPT-SoVITS-Infer" -version = "0.2.0" +version = "0.2.1" description = "Inference code for GPT-SoVITS" authors = [ {name = "Yanli",email = "mail@yanli.one"}, diff --git a/src/gpt_sovits/AR/models/t2s_model.py b/src/gpt_sovits/AR/models/t2s_model.py index 1c485f82..df859250 100644 --- a/src/gpt_sovits/AR/models/t2s_model.py +++ b/src/gpt_sovits/AR/models/t2s_model.py @@ -1,7 +1,6 @@ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py # reference: https://github.com/lifeiteng/vall-e import torch -from tqdm import tqdm from gpt_sovits.AR.models.utils import make_pad_mask from gpt_sovits.AR.models.utils import ( @@ -271,7 +270,7 @@ def infer( x_len = x.shape[1] x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) stop = False - for _ in tqdm(range(1500)): + for _ in range(1500): y_emb = self.ar_audio_embedding(y) y_pos = self.ar_audio_position(y_emb) # x 和逐渐增长的 y 一起输入给模型 @@ -311,7 +310,6 @@ def infer( if prompts.shape[1] == y.shape[1]: y = torch.concat([y, torch.zeros_like(samples)], dim=1) print("bad zero prediction") - print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") break # 本次生成的 semantic_ids 和之前的 y 构成新的 y # print(samples.shape)#[1,1]#第一个1是bs @@ -390,7 +388,7 @@ def infer_panel( ) xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(x.device) - for idx in tqdm(range(1500)): + for idx in range(1500): xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache) logits = self.ar_predict_layer( @@ -425,7 +423,6 @@ def infer_panel( if y.shape[1] == 0: y = torch.concat([y, torch.zeros_like(samples)], dim=1) print("bad zero prediction") - print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") break ####################### update next step ################################### diff --git a/src/gpt_sovits/infer/__init__.py b/src/gpt_sovits/infer/__init__.py index 34fec415..5b4de091 100644 --- a/src/gpt_sovits/infer/__init__.py +++ b/src/gpt_sovits/infer/__init__.py @@ -1,3 +1,7 @@ from gpt_sovits.infer.inference import GPTSoVITSInference +from gpt_sovits.infer.inference_pool import GPTSoVITSInferencePool -__all__ = ["GPTSoVITSInference"] +__all__ = [ + "GPTSoVITSInference", + "GPTSoVITSInferencePool", +] diff --git a/src/gpt_sovits/infer/inference_pool.py b/src/gpt_sovits/infer/inference_pool.py index 194d59d9..e0fbe7d9 100644 --- a/src/gpt_sovits/infer/inference_pool.py +++ b/src/gpt_sovits/infer/inference_pool.py @@ -134,8 +134,16 @@ def get_tts_wav( temperature=1, ): audio_list: List[Tuple[int, np.ndarray]] = [] + total = len(clean_and_cut_text(text)) for thing in tqdm.tqdm( - self.get_tts_wav_stream(text, text_language, top_k, top_p, temperature) + self.get_tts_wav_stream( + text, + text_language, + top_k, + top_p, + temperature, + ), + total=total, ): audio_list.append(thing) return audio_list[0][0], np.concatenate( diff --git a/src/gpt_sovits/module/core_vq.py b/src/gpt_sovits/module/core_vq.py index a5e22d66..838d8308 100644 --- a/src/gpt_sovits/module/core_vq.py +++ b/src/gpt_sovits/module/core_vq.py @@ -36,7 +36,6 @@ import torch from torch import nn import torch.nn.functional as F -from tqdm import tqdm def default(val: tp.Any, d: tp.Any) -> tp.Any: @@ -75,7 +74,7 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10): means = sample_vectors(samples, num_clusters) print("kmeans start ... ") - for _ in tqdm(range(num_iters)): + for _ in range(num_iters): diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") dists = -(diffs**2).sum(dim=-1) diff --git a/src/gpt_sovits/utils.py b/src/gpt_sovits/utils.py index 7984b5a8..419fecc9 100644 --- a/src/gpt_sovits/utils.py +++ b/src/gpt_sovits/utils.py @@ -18,7 +18,6 @@ MATPLOTLIB_FLAG = False -logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) logger = logging @@ -64,14 +63,18 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False ) return model, optimizer, learning_rate, iteration + from time import time as ttime import shutil -def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path - dir=os.path.dirname(path) - name=os.path.basename(path) - tmp_path="%s.pth"%(ttime()) - torch.save(fea,tmp_path) - shutil.move(tmp_path,"%s/%s"%(dir,name)) + + +def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path + dir = os.path.dirname(path) + name = os.path.basename(path) + tmp_path = "%s.pth" % (ttime()) + torch.save(fea, tmp_path) + shutil.move(tmp_path, "%s/%s" % (dir, name)) + def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): logger.info(