
Commit a34331e

clean output
BeautyyuYanli committed Jun 6, 2024
1 parent 6cf64fb commit a34331e
Showing 6 changed files with 28 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "GPT-SoVITS-Infer"
version = "0.2.0"
version = "0.2.1"
description = "Inference code for GPT-SoVITS"
authors = [
{name = "Yanli",email = "[email protected]"},
7 changes: 2 additions & 5 deletions src/gpt_sovits/AR/models/t2s_model.py
@@ -1,7 +1,6 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import torch
from tqdm import tqdm

from gpt_sovits.AR.models.utils import make_pad_mask
from gpt_sovits.AR.models.utils import (
@@ -271,7 +270,7 @@ def infer(
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
for _ in tqdm(range(1500)):
for _ in range(1500):
y_emb = self.ar_audio_embedding(y)
y_pos = self.ar_audio_position(y_emb)
# feed x together with the gradually growing y into the model
@@ -311,7 +310,6 @@ def infer(
if prompts.shape[1] == y.shape[1]:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
# the semantic_ids generated this step, together with the previous y, form the new y
# print(samples.shape)  # [1, 1]; the first 1 is the batch size
@@ -390,7 +388,7 @@ def infer_panel(
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(x.device)

for idx in tqdm(range(1500)):
for idx in range(1500):

xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(
@@ -425,7 +423,6 @@ def infer_panel(
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break

####################### update next step ###################################
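For context, both loops above already break out as soon as an EOS is predicted, so `range(1500)` is only an upper bound and a per-step progress bar over it rarely reflected real progress. A minimal sketch of that loop shape, using illustrative names rather than the library's actual internals:

```python
# Sketch only: `step_fn` stands in for a single autoregressive decode step and
# is not a real GPT-SoVITS function. The 1500-step range is just an upper
# bound; the loop normally stops at EOS well before reaching it.
def decode(step_fn, max_steps: int = 1500):
    tokens = []
    for _ in range(max_steps):
        token, is_eos = step_fn()
        tokens.append(token)
        if is_eos:  # early stop, analogous to the EOS break in infer / infer_panel
            break
    return tokens
```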
6 changes: 5 additions & 1 deletion src/gpt_sovits/infer/__init__.py
@@ -1,3 +1,7 @@
from gpt_sovits.infer.inference import GPTSoVITSInference
from gpt_sovits.infer.inference_pool import GPTSoVITSInferencePool

__all__ = ["GPTSoVITSInference"]
__all__ = [
"GPTSoVITSInference",
"GPTSoVITSInferencePool",
]
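With this change both classes are re-exported from the `infer` package, so downstream code can import either one from the same place. A small usage sketch (constructor arguments are omitted because they are not part of this diff):

```python
from gpt_sovits.infer import GPTSoVITSInference, GPTSoVITSInferencePool

import gpt_sovits.infer

# Both names are now part of the declared public API.
print(gpt_sovits.infer.__all__)  # ['GPTSoVITSInference', 'GPTSoVITSInferencePool']
```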
10 changes: 9 additions & 1 deletion src/gpt_sovits/infer/inference_pool.py
@@ -134,8 +134,16 @@ def get_tts_wav(
temperature=1,
):
audio_list: List[Tuple[int, np.ndarray]] = []
total = len(clean_and_cut_text(text))
for thing in tqdm.tqdm(
self.get_tts_wav_stream(text, text_language, top_k, top_p, temperature)
self.get_tts_wav_stream(
text,
text_language,
top_k,
top_p,
temperature,
),
total=total,
):
audio_list.append(thing)
return audio_list[0][0], np.concatenate(
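The new `total` argument matters because `get_tts_wav_stream` is a generator, and tqdm cannot infer a generator's length on its own; passing the segment count from `clean_and_cut_text(text)` lets the bar show real progress instead of a bare counter. A minimal sketch of that tqdm behavior, independent of GPT-SoVITS:

```python
from tqdm import tqdm

def segments():
    # Stand-in generator; tqdm cannot know its length by itself.
    for i in range(3):
        yield i

for _ in tqdm(segments(), total=3):  # `total` enables a proper percentage bar
    pass
```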
3 changes: 1 addition & 2 deletions src/gpt_sovits/module/core_vq.py
@@ -36,7 +36,6 @@
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm


def default(val: tp.Any, d: tp.Any) -> tp.Any:
@@ -75,7 +74,7 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10):
means = sample_vectors(samples, num_clusters)

print("kmeans start ... ")
for _ in tqdm(range(num_iters)):
for _ in range(num_iters):
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
dists = -(diffs**2).sum(dim=-1)

17 changes: 10 additions & 7 deletions src/gpt_sovits/utils.py
@@ -18,7 +18,6 @@

MATPLOTLIB_FLAG = False

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging


@@ -64,14 +63,18 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
)
return model, optimizer, learning_rate, iteration


from time import time as ttime
import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s.pth"%(ttime())
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))


def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
tmp_path = "%s.pth" % (ttime())
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))


def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
logger.info(
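Removing the `logging.basicConfig(...)` call means the library no longer configures the root logger (at DEBUG level, to stdout) merely by being imported; that decision is left to the application. A hedged sketch of application-side setup, assuming the module keeps writing through the standard `logging` module (the diff retains `logger = logging`):

```python
import logging
import sys

# Application-side configuration: the library no longer does this on import,
# so log output only appears if the application opts in.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
```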
