diff --git a/fireredtts2/codec/model.py b/fireredtts2/codec/model.py
index 1bf9476..a20b6e0 100644
--- a/fireredtts2/codec/model.py
+++ b/fireredtts2/codec/model.py
@@ -211,6 +211,9 @@ def __init__(self, codec: RedCodec):
     def from_pretrained(cls, conf_path: str, ckpt_path: str) -> "RedCodecInfer":
         with open(conf_path, "r") as f:
             codec = RedCodec.from_config(conf_path)
-        ckpt = torch.load(ckpt_path)["generator"]
+        # Support CPU-only hosts: remap the checkpoint tensors onto the CPU
+        # when CUDA is unavailable instead of failing during deserialization.
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        ckpt = torch.load(ckpt_path, map_location=device)["generator"]
         codec.load_state_dict(ckpt)
         return cls(codec)
diff --git a/fireredtts2/fireredtts2.py b/fireredtts2/fireredtts2.py
index 4e35c38..ed7c943 100644
--- a/fireredtts2/fireredtts2.py
+++ b/fireredtts2/fireredtts2.py
@@ -58,6 +58,9 @@ def __init__(self, pretrained_dir, gen_type, device):
 
     def load_prompt_audio(self, audio_path) -> torch.Tensor:
         audio, audio_sr = torchaudio.load(audio_path)
+        # Down-mix multi-channel audio to mono (average over dim 0, kept as 1 row)
+        if audio.shape[0] > 1:
+            audio = torch.mean(audio, dim=0, keepdim=True)
         audio16k = torchaudio.functional.resample(audio, audio_sr, 16000)
         return audio16k
 