diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index b6394cc4c..e6108e52d 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -300,7 +300,8 @@ def _prepare_generation_outputs(
 
         end_idx_int = end_idx.int()
         inputs_ids_lst = [
-            inputs_ids[idx].narrow(0, start_idx, int(i)) for idx, i in enumerate(end_idx_int)
+            inputs_ids[idx].narrow(0, start_idx, int(i))
+            for idx, i in enumerate(end_idx_int)
         ]
         if infer_text:
             inputs_ids_lst = [i.narrow(1, 0, 1).squeeze_(1) for i in inputs_ids_lst]
@@ -309,7 +310,8 @@ def _prepare_generation_outputs(
         if len(hiddens) > 0:
             hiddens_lst = torch.stack(hiddens, 1)
             hiddens_lst = [
-                hiddens_lst[idx].narrow(0, 0, int(i)) for idx, i in enumerate(end_idx_int)
+                hiddens_lst[idx].narrow(0, 0, int(i))
+                for idx, i in enumerate(end_idx_int)
             ]
 
         return self.GenerationOutputs(
@@ -341,7 +343,7 @@ def generate(
         manual_seed: Optional[int] = None,
         context=Context(),
     ):
-        
+
         self.logger.debug("start generate")
 
         attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] = []
@@ -353,7 +355,9 @@ def generate(
 
         )
         finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
-        self.logger.debug(f"set start_idx: {start_idx}, end_idx and finish with all zeros, len {inputs_ids.shape[0]}")
+        self.logger.debug(
+            f"set start_idx: {start_idx}, end_idx and finish with all zeros, len {inputs_ids.shape[0]}"
+        )
 
         old_temperature = temperature
 
@@ -364,7 +368,9 @@ def generate(
             .view(-1, 1)
         )
 
-        self.logger.debug(f"expand temperature from shape {old_temperature.shape} to {temperature.shape}")
+        self.logger.debug(
+            f"expand temperature from shape {old_temperature.shape} to {temperature.shape}"
+        )
 
         attention_mask_cache = torch.ones(
             (
@@ -374,7 +380,9 @@ def generate(
             dtype=torch.bool,
             device=inputs_ids.device,
         )
-        self.logger.debug(f"init attention_mask_cache with shape {attention_mask_cache.shape}")
+        self.logger.debug(
+            f"init attention_mask_cache with shape {attention_mask_cache.shape}"
+        )
         if attention_mask is not None:
             attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
                 attention_mask
@@ -391,7 +399,9 @@ def generate(
             device=inputs_ids.device,
         )
         inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids)
-        self.logger.debug(f"expand inputs_ids buf from shape {inputs_ids.shape} to {inputs_ids_buf.shape}")
+        self.logger.debug(
+            f"expand inputs_ids buf from shape {inputs_ids.shape} to {inputs_ids_buf.shape}"
+        )
 
         del inputs_ids
         inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
@@ -605,7 +615,7 @@ def generate(
                 yield result
                 del inputs_ids
                 return
-                
+
             self.logger.debug("start output")
 
             del idx_next
diff --git a/ChatTTS/utils/gpu.py b/ChatTTS/utils/gpu.py
index a9d106ef0..bc9578800 100644
--- a/ChatTTS/utils/gpu.py
+++ b/ChatTTS/utils/gpu.py
@@ -53,6 +53,7 @@ def select_device(min_memory=2047, experimental=False):
         if experimental:
             logger.get_logger().warning("experimental: using DML.")
             import torch_directml
+
             device = torch_directml.device(torch_directml.default_device())
         else:
             logger.get_logger().info("found DML, but use CPU.")
diff --git a/examples/web/funcs.py b/examples/web/funcs.py
index bbbe2158d..8143925a1 100644
--- a/examples/web/funcs.py
+++ b/examples/web/funcs.py
@@ -25,8 +25,8 @@
 has_interrupted = False
 is_in_generate = False
 
-enable_cache=True
-experimental=False
+enable_cache = True
+experimental = False
 
 seed_min = 1
 seed_max = 4294967295
@@ -64,12 +64,14 @@ def on_audio_seed_change(audio_seed_input):
         rand_spk = chat.sample_random_speaker()
     return rand_spk
 
+
 def set_params(en_cache, exp):
     global enable_cache, experimental
     enable_cache = en_cache
 
     experimental = exp
 
+
 def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:
     global enable_cache, experimental
 
@@ -78,7 +80,11 @@ def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:
     else:
         logger.info("local model path: %s", cust_path)
         ret = chat.load(
-            "custom", custom_path=cust_path, coef=coef, enable_cache=enable_cache, experimental=experimental
+            "custom",
+            custom_path=cust_path,
+            coef=coef,
+            enable_cache=enable_cache,
+            experimental=experimental,
         )
         global custom_path
         custom_path = cust_path