26 changes: 18 additions & 8 deletions ChatTTS/model/gpt.py
@@ -300,7 +300,8 @@ def _prepare_generation_outputs(
         end_idx_int = end_idx.int()

         inputs_ids_lst = [
-            inputs_ids[idx].narrow(0, start_idx, int(i)) for idx, i in enumerate(end_idx_int)
+            inputs_ids[idx].narrow(0, start_idx, int(i))
+            for idx, i in enumerate(end_idx_int)
         ]
         if infer_text:
             inputs_ids_lst = [i.narrow(1, 0, 1).squeeze_(1) for i in inputs_ids_lst]
@@ -309,7 +310,8 @@ def _prepare_generation_outputs(
         if len(hiddens) > 0:
             hiddens_lst = torch.stack(hiddens, 1)
             hiddens_lst = [
-                hiddens_lst[idx].narrow(0, 0, int(i)) for idx, i in enumerate(end_idx_int)
+                hiddens_lst[idx].narrow(0, 0, int(i))
+                for idx, i in enumerate(end_idx_int)
             ]

         return self.GenerationOutputs(
@@ -341,7 +343,7 @@ def generate(
         manual_seed: Optional[int] = None,
         context=Context(),
     ):
-
+        self.logger.debug("start generate")

         attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] = []
@@ -353,7 +355,9 @@ def generate(
         )
         finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()

-        self.logger.debug(f"set start_idx: {start_idx}, end_idx and finish with all zeros, len {inputs_ids.shape[0]}")
+        self.logger.debug(
+            f"set start_idx: {start_idx}, end_idx and finish with all zeros, len {inputs_ids.shape[0]}"
+        )

         old_temperature = temperature

@@ -364,7 +368,9 @@ def generate(
             .view(-1, 1)
         )

-        self.logger.debug(f"expand temperature from shape {old_temperature.shape} to {temperature.shape}")
+        self.logger.debug(
+            f"expand temperature from shape {old_temperature.shape} to {temperature.shape}"
+        )

         attention_mask_cache = torch.ones(
             (
@@ -374,7 +380,9 @@ def generate(
             dtype=torch.bool,
             device=inputs_ids.device,
         )
-        self.logger.debug(f"init attention_mask_cache with shape {attention_mask_cache.shape}")
+        self.logger.debug(
+            f"init attention_mask_cache with shape {attention_mask_cache.shape}"
+        )
         if attention_mask is not None:
             attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
                 attention_mask
@@ -391,7 +399,9 @@ def generate(
             device=inputs_ids.device,
         )
         inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids)
-        self.logger.debug(f"expand inputs_ids buf from shape {inputs_ids.shape} to {inputs_ids_buf.shape}")
+        self.logger.debug(
+            f"expand inputs_ids buf from shape {inputs_ids.shape} to {inputs_ids_buf.shape}"
+        )
         del inputs_ids
         inputs_ids = inputs_ids_buf.narrow(1, 0, progress)

@@ -605,7 +615,7 @@ def generate(
                 yield result
                 del inputs_ids
             return
-
+        self.logger.debug("start output")

         del idx_next
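Note: every comprehension reformatted in this file is built on torch.Tensor.narrow, which returns a view of the buffer rather than a copy. A minimal, self-contained sketch of that trimming pattern follows; the names buf, start_idx and end_idx are illustrative, not the actual gpt.py locals.

import torch

# Padded batch buffer: 3 sequences, max length 10, shared prompt length 4.
batch, max_len, start_idx = 3, 10, 4
buf = torch.arange(batch * max_len).view(batch, max_len)
end_idx = torch.tensor([2, 5, 3])  # tokens generated per sample

# Same shape of comprehension as in _prepare_generation_outputs:
# row idx is trimmed to buf[idx, start_idx : start_idx + i], returned as a view.
trimmed = [
    buf[idx].narrow(0, start_idx, int(i))
    for idx, i in enumerate(end_idx.int())
]
print([t.shape for t in trimmed])  # [torch.Size([2]), torch.Size([5]), torch.Size([3])]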
1 change: 1 addition & 0 deletions ChatTTS/utils/gpu.py
@@ -53,6 +53,7 @@ def select_device(min_memory=2047, experimental=False):
         if experimental:
             logger.get_logger().warning("experimental: using DML.")
             import torch_directml
+
             device = torch_directml.device(torch_directml.default_device())
         else:
             logger.get_logger().info("found DML, but use CPU.")
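The added blank line is formatting-only, but for readers unfamiliar with this branch: torch_directml is imported lazily so the dependency is only needed when the experimental flag is set. A hedged sketch of that fallback pattern, where pick_device is a stand-in and not the real select_device:

import torch

def pick_device(experimental: bool = False) -> torch.device:
    # Lazy, guarded import: torch_directml only exists on Windows installs
    # with the DirectML backend, so fall back to CPU everywhere else.
    if experimental:
        try:
            import torch_directml

            return torch_directml.device(torch_directml.default_device())
        except ImportError:
            pass
    return torch.device("cpu")

print(pick_device())  # cpu unless experimental=True and torch_directml is installed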
12 changes: 9 additions & 3 deletions examples/web/funcs.py
@@ -25,8 +25,8 @@
 has_interrupted = False
 is_in_generate = False

-enable_cache=True
-experimental=False
+enable_cache = True
+experimental = False

 seed_min = 1
 seed_max = 4294967295
@@ -64,12 +64,14 @@ def on_audio_seed_change(audio_seed_input):
     rand_spk = chat.sample_random_speaker()
     return rand_spk

+
 def set_params(en_cache, exp):
     global enable_cache, experimental

     enable_cache = en_cache
     experimental = exp

+
 def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:
     global enable_cache, experimental

@@ -78,7 +80,11 @@ def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:
     else:
         logger.info("local model path: %s", cust_path)
         ret = chat.load(
-            "custom", custom_path=cust_path, coef=coef, enable_cache=enable_cache, experimental=experimental
+            "custom",
+            custom_path=cust_path,
+            coef=coef,
+            enable_cache=enable_cache,
+            experimental=experimental,
         )
         global custom_path
         custom_path = cust_path
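The funcs.py edits are formatting-only, but they sit on the path that wires the web UI's enable_cache/experimental switches into chat.load. A rough, self-contained sketch of that flow, with fake_load standing in for the real Chat.load("custom", ...) call:

from typing import Optional

enable_cache = True
experimental = False


def set_params(en_cache: bool, exp: bool) -> None:
    # The web UI calls this before the model is loaded.
    global enable_cache, experimental
    enable_cache = en_cache
    experimental = exp


def fake_load(custom_path: Optional[str], coef: Optional[str]) -> bool:
    # Stand-in for chat.load("custom", custom_path=..., coef=...,
    # enable_cache=enable_cache, experimental=experimental).
    print("load", custom_path, coef, enable_cache, experimental)
    return True


set_params(en_cache=False, exp=True)
fake_load("/path/to/model", None)  # load /path/to/model None False True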