diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..98dbeb4
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+__pycache__/
+*.py[cod]
+.venv/
diff --git a/README.md b/README.md
index 20a84fb..8bf18e1 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,151 @@
-# repka
\ No newline at end of file
+# Neural Realtime Voice Changer (MVP)
+
+An MVP project in Python/PyTorch for real-time voice conversion, trained on your own audio samples.
+
+## What it can do
+
+- train a voice conversion model on folders of WAV/FLAC/OGG files;
+- build a target voice profile (`.pt`) from a set of samples;
+- change your voice in real time from a microphone;
+- convert a file offline (useful for checking quality).
+
+## MVP limitations
+
+- this is a working prototype, not commercial production-grade software;
+- quality depends heavily on the dataset (recording cleanliness, duration, variety);
+- real-time use needs a stable audio driver, and a GPU is strongly recommended.
+
+## Installation
+
+```bash
+python3 -m venv .venv
+source .venv/bin/activate
+pip install -U pip
+pip install -r requirements.txt
+```
+
+## AMD RX 470 8GB (Polaris): important settings
+
+RX 470 support in recent ROCm builds is often unstable, so a dedicated flag is provided:
+
+- `--amd-gfx-version 8.0.3` (sets `HSA_OVERRIDE_GFX_VERSION=8.0.3`)
+
+In addition, on older AMD cards it is best to disable mixed precision:
+
+- `--amp-mode off`
+
+If you already have ROCm / PyTorch ROCm set up, run the scripts with these parameters.
+If the GPU is not picked up, temporarily use `--device cpu`.
+
+## Dataset preparation
+
+Layout:
+
+```text
+data/
+  speaker_1/
+    a.wav
+    b.wav
+  speaker_2/
+    c.wav
+    d.wav
+  ...
+```
+
+Recommendations:
+
+- at least 2 speakers for decent cross-speaker conversion;
+- 5-15 minutes of clean speech per speaker;
+- matching sample rates are not required (the scripts resample everything to 16 kHz).
+
+## Training
+
+```bash
+python3 scripts/train_voice_converter.py \
+    --data-dir data \
+    --output-dir artifacts \
+    --epochs 40 \
+    --batch-size 12 \
+    --device auto \
+    --amp-mode auto
+```
+
+Checkpoints: `artifacts/checkpoints/latest.pt`, `artifacts/checkpoints/epoch_XXX.pt`.
+
+For the RX 470 (a safer starting point):
+
+```bash
+python3 scripts/train_voice_converter.py \
+    --data-dir data \
+    --output-dir artifacts_rx470 \
+    --epochs 40 \
+    --batch-size 4 \
+    --device cuda \
+    --amd-gfx-version 8.0.3 \
+    --amp-mode off
+```
+
+## Fine-tuning on new samples
+
+Add the new folders to `data/` and run:
+
+```bash
+python3 scripts/train_voice_converter.py \
+    --data-dir data \
+    --output-dir artifacts_finetune \
+    --resume-checkpoint artifacts/checkpoints/latest.pt \
+    --epochs 15 \
+    --device auto \
+    --amp-mode auto
+```
+
+## Building a target voice profile
+
+You can pass a directory or a single file:
+
+```bash
+python3 scripts/build_voice_profile.py \
+    --checkpoint artifacts/checkpoints/latest.pt \
+    --samples samples/target_voice \
+    --output-profile artifacts/profiles/target_profile.pt \
+    --device auto \
+    --amd-gfx-version 8.0.3
+```
+
+## Real-time mode (microphone -> speakers)
+
+First, you can list the audio devices:
+
+```bash
+python3 scripts/realtime_voice_changer.py --list-devices
+```
+
+Run:
+
+```bash
+python3 scripts/realtime_voice_changer.py \
+    --checkpoint artifacts/checkpoints/latest.pt \
+    --profile artifacts/profiles/target_profile.pt \
+    --block-size 1024 \
+    --device auto \
+    --amd-gfx-version 8.0.3
+```
+
+The smaller the `--block-size`, the lower the latency, but the higher the risk of artifacts.
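+
+As a rule of thumb, one block spans `block_size / sample_rate` seconds; real end-to-end latency is higher once device buffering and model compute time are added. A quick sketch (16 kHz is the default rate in this project's `SpectrogramConfig`):
+
+```python
+SAMPLE_RATE = 16_000  # default sample rate used across this repo
+
+for block_size in (512, 1024, 2048):
+    # one audio block covers block_size / SAMPLE_RATE seconds
+    print(f"block_size={block_size}: {1000 * block_size / SAMPLE_RATE:.0f} ms per block")
+```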
+
+## Offline file conversion
+
+```bash
+python3 scripts/convert_file.py \
+    --checkpoint artifacts/checkpoints/latest.pt \
+    --profile artifacts/profiles/target_profile.pt \
+    --input demo/input.wav \
+    --output demo/output.wav \
+    --device auto \
+    --amd-gfx-version 8.0.3
+```
+
+## Important notes
+
+- Use this system only with the consent of the voice's owner.
+- For streaming into Discord/OBS, a virtual audio cable is usually the most convenient option.
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..408693f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+torch
+torchaudio
+numpy
+tqdm
+sounddevice
diff --git a/scripts/build_voice_profile.py b/scripts/build_voice_profile.py
new file mode 100644
index 0000000..49b6785
--- /dev/null
+++ b/scripts/build_voice_profile.py
@@ -0,0 +1,82 @@
+#!/usr/bin/env python3
+"""Build target voice profile embedding from sample audios."""
+
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+import sys
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+if str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Create speaker profile from voice samples.")
+    parser.add_argument(
+        "--checkpoint",
+        required=True,
+        help="Path to trained checkpoint (*.pt)",
+    )
+    parser.add_argument(
+        "--samples",
+        required=True,
+        help="Directory or file with target speaker samples",
+    )
+    parser.add_argument(
+        "--output-profile",
+        default="artifacts/profiles/target_profile.pt",
+        help="Path to output profile .pt file",
+    )
+    parser.add_argument("--segment-frames", type=int, default=96)
+    parser.add_argument("--max-segments-per-file", type=int, default=8)
+    parser.add_argument(
+        "--device",
+        default="auto",
+        help='Device: "auto", "cpu", "cuda", "cuda:0"...',
+    )
+    parser.add_argument(
+        "--amd-gfx-version",
+        default=None,
+        help="Sets HSA_OVERRIDE_GFX_VERSION before torch import (RX 470: 8.0.3).",
+    )
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+    from voice_changer.runtime import configure_amd_runtime
+
+    applied_gfx_override = configure_amd_runtime(args.amd_gfx_version)
+    if applied_gfx_override:
+        print(f"HSA_OVERRIDE_GFX_VERSION={applied_gfx_override}")
+
+    from voice_changer.inference import (
+        build_voice_profile_embedding,
+        collect_audio_files,
+        load_inference_bundle,
+        save_voice_profile,
+    )
+
+    bundle = load_inference_bundle(args.checkpoint, device=args.device)
+    sample_files = collect_audio_files(args.samples)
+    embedding = build_voice_profile_embedding(
+        bundle=bundle,
+        sample_paths=sample_files,
+        segment_frames=args.segment_frames,
+        max_segments_per_file=args.max_segments_per_file,
+    )
+    save_voice_profile(
+        output_path=args.output_profile,
+        embedding=embedding,
+        source_files=sample_files,
+        sample_rate=bundle.processor.sample_rate,
+        checkpoint_path=str(Path(args.checkpoint).resolve()),
+    )
+    print(f"Saved profile: {Path(args.output_profile).resolve()}")
+    print(f"Used files: {len(sample_files)}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/convert_file.py b/scripts/convert_file.py
new file mode 100644
index 0000000..d68648a
--- /dev/null
+++ b/scripts/convert_file.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+"""Convert a WAV/FLAC/OGG file with the trained voice changer."""
+
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+import sys
+
+PROJECT_ROOT = 
Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Offline file conversion with target profile.") + parser.add_argument("--checkpoint", required=True) + parser.add_argument("--profile", required=True) + parser.add_argument("--input", required=True, help="Input audio file") + parser.add_argument("--output", required=True, help="Output audio file") + parser.add_argument("--chunk-frames", type=int, default=128) + parser.add_argument( + "--device", + default="auto", + help='Torch device: "auto", "cpu", "cuda"...', + ) + parser.add_argument( + "--amd-gfx-version", + default=None, + help="Sets HSA_OVERRIDE_GFX_VERSION before torch import (RX 470: 8.0.3).", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + from voice_changer.runtime import configure_amd_runtime + + applied_gfx_override = configure_amd_runtime(args.amd_gfx_version) + if applied_gfx_override: + print(f"HSA_OVERRIDE_GFX_VERSION={applied_gfx_override}") + + from voice_changer.convert import convert_file + from voice_changer.inference import load_inference_bundle, load_voice_profile + + bundle = load_inference_bundle(args.checkpoint, device=args.device) + target_embedding = load_voice_profile(args.profile, device=bundle.device) + convert_file( + bundle=bundle, + input_path=args.input, + output_path=args.output, + target_embedding=target_embedding, + chunk_frames=args.chunk_frames, + ) + print(f"Converted audio saved to: {Path(args.output).resolve()}") + + +if __name__ == "__main__": + main() diff --git a/scripts/realtime_voice_changer.py b/scripts/realtime_voice_changer.py new file mode 100644 index 0000000..e7cf9df --- /dev/null +++ b/scripts/realtime_voice_changer.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +"""Run real-time neural voice conversion from microphone.""" + +from __future__ import annotations + +import argparse +from pathlib import Path +import sys + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + + +def _parse_device_arg(value: str | None) -> int | str | None: + if value is None: + return None + stripped = value.strip() + if stripped == "": + return None + if stripped.isdigit(): + return int(stripped) + return stripped + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Realtime neural voice changer.") + parser.add_argument( + "--checkpoint", + help="Path to trained checkpoint (*.pt)", + ) + parser.add_argument( + "--profile", + help="Path to voice profile (*.pt)", + ) + parser.add_argument("--block-size", type=int, default=1024) + parser.add_argument("--input-device", default=None) + parser.add_argument("--output-device", default=None) + parser.add_argument( + "--latency", + default="low", + help='sounddevice latency: "low", "high", or float seconds', + ) + parser.add_argument( + "--device", + default="auto", + help='Torch device: "auto", "cpu", "cuda"...', + ) + parser.add_argument( + "--amd-gfx-version", + default=None, + help="Sets HSA_OVERRIDE_GFX_VERSION before torch import (RX 470: 8.0.3).", + ) + parser.add_argument( + "--list-devices", + action="store_true", + help="List available audio devices and exit", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + from voice_changer.runtime import configure_amd_runtime + + applied_gfx_override = 
configure_amd_runtime(args.amd_gfx_version)
+    if applied_gfx_override:
+        print(f"HSA_OVERRIDE_GFX_VERSION={applied_gfx_override}")
+
+    from voice_changer.inference import load_inference_bundle, load_voice_profile
+    from voice_changer.realtime import RealtimeVoiceChanger, list_audio_devices
+
+    if args.list_devices:
+        list_audio_devices()
+        return
+
+    if not args.checkpoint or not args.profile:
+        raise ValueError(
+            "Arguments --checkpoint and --profile are required unless --list-devices is used."
+        )
+
+    bundle = load_inference_bundle(args.checkpoint, device=args.device)
+    target_embedding = load_voice_profile(args.profile, device=bundle.device)
+
+    changer = RealtimeVoiceChanger(
+        bundle=bundle,
+        target_embedding=target_embedding,
+        block_size=args.block_size,
+        input_device=_parse_device_arg(args.input_device),
+        output_device=_parse_device_arg(args.output_device),
+        latency=args.latency,
+    )
+    changer.run()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train_voice_converter.py b/scripts/train_voice_converter.py
new file mode 100644
index 0000000..1c0bb33
--- /dev/null
+++ b/scripts/train_voice_converter.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python3
+"""Train neural real-time voice conversion model."""
+
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+import sys
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+if str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Train neural voice changer on multi-speaker WAV dataset."
+    )
+    parser.add_argument("--data-dir", required=True, help="Path to data/<speaker_name>/*.wav")
+    parser.add_argument("--output-dir", default="artifacts", help="Directory for checkpoints")
+    parser.add_argument("--epochs", type=int, default=40)
+    parser.add_argument("--batch-size", type=int, default=12)
+    parser.add_argument("--segment-frames", type=int, default=64)
+    parser.add_argument("--learning-rate", type=float, default=3e-4)
+    parser.add_argument("--weight-decay", type=float, default=1e-4)
+    parser.add_argument("--embedding-dim", type=int, default=128)
+    parser.add_argument("--speaker-hidden-dim", type=int, default=256)
+    parser.add_argument("--converter-hidden-dim", type=int, default=384)
+    parser.add_argument("--converter-layers", type=int, default=2)
+    parser.add_argument("--num-workers", type=int, default=0)
+    parser.add_argument("--sample-rate", type=int, default=16000)
+    parser.add_argument("--n-fft", type=int, default=512)
+    parser.add_argument("--hop-length", type=int, default=128)
+    parser.add_argument("--win-length", type=int, default=512)
+    parser.add_argument("--identity-loss-weight", type=float, default=1.0)
+    parser.add_argument("--cycle-loss-weight", type=float, default=0.75)
+    parser.add_argument("--speaker-loss-weight", type=float, default=0.4)
+    parser.add_argument("--transfer-speaker-loss-weight", type=float, default=0.8)
+    parser.add_argument("--gradient-clip", type=float, default=1.0)
+    parser.add_argument("--save-every", type=int, default=5)
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--cache-features", action="store_true")
+    parser.add_argument(
+        "--resume-checkpoint",
+        default=None,
+        help="Path to checkpoint for fine-tuning",
+    )
+    parser.add_argument(
+        "--device",
+        default="auto",
+        help='Device: "auto", "cpu", "cuda", "cuda:0"...',
+    )
+    parser.add_argument(
+        "--amp-mode",
+        default="auto",
+        choices=["auto", "on", "off"],
+        help='Automatic mixed 
precision mode (recommended for RX 470: "off").', + ) + parser.add_argument( + "--amd-gfx-version", + default=None, + help="Sets HSA_OVERRIDE_GFX_VERSION before torch import (RX 470: 8.0.3).", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + from voice_changer.runtime import configure_amd_runtime + + applied_gfx_override = configure_amd_runtime(args.amd_gfx_version) + if applied_gfx_override: + print(f"HSA_OVERRIDE_GFX_VERSION={applied_gfx_override}") + + from voice_changer.train import TrainConfig, train_voice_changer + + config = TrainConfig( + data_dir=args.data_dir, + output_dir=args.output_dir, + epochs=args.epochs, + batch_size=args.batch_size, + segment_frames=args.segment_frames, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + embedding_dim=args.embedding_dim, + speaker_hidden_dim=args.speaker_hidden_dim, + converter_hidden_dim=args.converter_hidden_dim, + converter_layers=args.converter_layers, + num_workers=args.num_workers, + sample_rate=args.sample_rate, + n_fft=args.n_fft, + hop_length=args.hop_length, + win_length=args.win_length, + identity_loss_weight=args.identity_loss_weight, + cycle_loss_weight=args.cycle_loss_weight, + speaker_loss_weight=args.speaker_loss_weight, + transfer_speaker_loss_weight=args.transfer_speaker_loss_weight, + gradient_clip=args.gradient_clip, + save_every=args.save_every, + seed=args.seed, + cache_features=args.cache_features, + resume_checkpoint=args.resume_checkpoint, + device=args.device, + amp_mode=args.amp_mode, + ) + latest_checkpoint = train_voice_changer(config) + print(f"Training complete. Latest checkpoint: {latest_checkpoint}") + print(f"Output dir: {Path(args.output_dir).resolve()}") + + +if __name__ == "__main__": + main() diff --git a/voice_changer/__init__.py b/voice_changer/__init__.py new file mode 100644 index 0000000..3e6e7a7 --- /dev/null +++ b/voice_changer/__init__.py @@ -0,0 +1,6 @@ +"""Lightweight neural voice changer package.""" + +from .audio import SpectrogramProcessor +from .models import SpeakerEncoder, VoiceConverter + +__all__ = ["SpectrogramProcessor", "SpeakerEncoder", "VoiceConverter"] diff --git a/voice_changer/audio.py b/voice_changer/audio.py new file mode 100644 index 0000000..7bb10c5 --- /dev/null +++ b/voice_changer/audio.py @@ -0,0 +1,135 @@ +"""Audio loading and STFT helpers for voice conversion.""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass +import random +from typing import Any + +import torch +import torch.nn.functional as F +import torchaudio + + +@dataclass(frozen=True) +class SpectrogramConfig: + sample_rate: int = 16_000 + n_fft: int = 512 + hop_length: int = 128 + win_length: int = 512 + + @property + def freq_bins(self) -> int: + return (self.n_fft // 2) + 1 + + +class SpectrogramProcessor: + """Converts waveform chunks to log-magnitude spectra and back.""" + + def __init__(self, config: SpectrogramConfig | None = None) -> None: + self.config = config or SpectrogramConfig() + + @property + def sample_rate(self) -> int: + return self.config.sample_rate + + @property + def freq_bins(self) -> int: + return self.config.freq_bins + + def to_dict(self) -> dict[str, Any]: + return asdict(self.config) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "SpectrogramProcessor": + return cls(SpectrogramConfig(**data)) + + def _window(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + return torch.hann_window( + self.config.win_length, + device=device, + dtype=dtype, + ) + + def 
waveform_to_log_magnitude(self, waveform: torch.Tensor) -> torch.Tensor:
+        """Returns [frames, freq_bins] log(1 + magnitude)."""
+        log_magnitude, _ = self.waveform_to_log_magnitude_and_phase(waveform)
+        return log_magnitude
+
+    def waveform_to_log_magnitude_and_phase(
+        self,
+        waveform: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Returns ([frames, freq_bins] log-magnitude, [freq_bins, frames] unit phase)."""
+        if waveform.dim() != 1:
+            raise ValueError("Expected 1D waveform tensor.")
+
+        window = self._window(waveform.device, waveform.dtype)
+        spec = torch.stft(
+            waveform,
+            n_fft=self.config.n_fft,
+            hop_length=self.config.hop_length,
+            win_length=self.config.win_length,
+            window=window,
+            center=True,
+            return_complex=True,
+        )
+        magnitude = spec.abs().clamp_min(1e-8)
+        phase = spec / magnitude
+        log_magnitude = torch.log1p(magnitude).transpose(0, 1)
+        return log_magnitude, phase
+
+    def log_magnitude_to_waveform(
+        self,
+        log_magnitude: torch.Tensor,
+        phase: torch.Tensor,
+        length: int,
+    ) -> torch.Tensor:
+        """Reconstructs waveform with source phase."""
+        if log_magnitude.dim() != 2:
+            raise ValueError("Expected [frames, freq_bins] log-magnitude tensor.")
+        magnitude = torch.expm1(log_magnitude).transpose(0, 1).clamp_min(0.0)
+        complex_spec = magnitude * phase
+        window = self._window(complex_spec.device, magnitude.dtype)
+        waveform = torch.istft(
+            complex_spec,
+            n_fft=self.config.n_fft,
+            hop_length=self.config.hop_length,
+            win_length=self.config.win_length,
+            window=window,
+            center=True,
+            length=length,
+        )
+        return waveform
+
+    def sample_segment(
+        self,
+        features: torch.Tensor,
+        segment_frames: int,
+    ) -> torch.Tensor:
+        """Samples or pads a random feature segment to fixed frame count."""
+        if features.dim() != 2:
+            raise ValueError("Expected [frames, freq_bins] features tensor.")
+
+        total_frames = features.size(0)
+        if total_frames >= segment_frames:
+            start = random.randint(0, total_frames - segment_frames)
+            return features[start : start + segment_frames]
+
+        pad_amount = segment_frames - total_frames
+        # F.pad's replicate mode supports padding the last two dims only for
+        # 3D/4D inputs, so temporarily add a batch dimension around the call.
+        padded = F.pad(
+            features.unsqueeze(0),
+            (0, 0, 0, pad_amount),
+            mode="replicate",
+        ).squeeze(0)
+        return padded
+
+
+def load_audio_file(path: str, sample_rate: int) -> torch.Tensor:
+    """Loads audio file, converts to mono float waveform at target sample rate."""
+    waveform, loaded_sample_rate = torchaudio.load(path)
+    if waveform.size(0) > 1:
+        waveform = waveform.mean(dim=0, keepdim=True)
+    if loaded_sample_rate != sample_rate:
+        waveform = torchaudio.functional.resample(
+            waveform,
+            orig_freq=loaded_sample_rate,
+            new_freq=sample_rate,
+        )
+    return waveform.squeeze(0).contiguous()
diff --git a/voice_changer/convert.py b/voice_changer/convert.py
new file mode 100644
index 0000000..8fd38ed
--- /dev/null
+++ b/voice_changer/convert.py
@@ -0,0 +1,68 @@
+"""Offline conversion utilities for audio files."""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import torch
+import torchaudio
+
+from .audio import load_audio_file
+from .inference import InferenceBundle
+
+
+@torch.inference_mode()
+def convert_waveform(
+    bundle: InferenceBundle,
+    waveform: torch.Tensor,
+    target_embedding: torch.Tensor,
+    chunk_frames: int = 128,
+) -> torch.Tensor:
+    """Converts a mono waveform to target speaker timbre."""
+    if waveform.dim() != 1:
+        raise ValueError("Expected mono waveform tensor.")
+    if target_embedding.dim() == 1:
+        target_embedding = target_embedding.unsqueeze(0)
+
+    log_magnitude, phase = bundle.processor.waveform_to_log_magnitude_and_phase(waveform)
+    
hidden_state: torch.Tensor | None = None
+    outputs: list[torch.Tensor] = []
+
+    for start in range(0, log_magnitude.size(0), chunk_frames):
+        segment = log_magnitude[start : start + chunk_frames]
+        segment_in = segment.unsqueeze(0).to(bundle.device)
+        converted, hidden_state = bundle.converter(
+            source_features=segment_in,
+            target_embedding=target_embedding.to(bundle.device),
+            hidden_state=hidden_state,
+        )
+        if hidden_state is not None:
+            hidden_state = hidden_state.detach()
+        outputs.append(converted.squeeze(0).cpu())
+
+    converted_log_magnitude = torch.cat(outputs, dim=0)
+    converted_waveform = bundle.processor.log_magnitude_to_waveform(
+        log_magnitude=converted_log_magnitude,
+        phase=phase,
+        length=waveform.numel(),
+    )
+    return converted_waveform.clamp(min=-1.0, max=1.0)
+
+
+def convert_file(
+    bundle: InferenceBundle,
+    input_path: str,
+    output_path: str,
+    target_embedding: torch.Tensor,
+    chunk_frames: int = 128,
+) -> None:
+    waveform = load_audio_file(input_path, sample_rate=bundle.processor.sample_rate)
+    converted = convert_waveform(
+        bundle=bundle,
+        waveform=waveform,
+        target_embedding=target_embedding,
+        chunk_frames=chunk_frames,
+    )
+    output = converted.unsqueeze(0).cpu()
+    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+    torchaudio.save(output_path, output, sample_rate=bundle.processor.sample_rate)
diff --git a/voice_changer/data.py b/voice_changer/data.py
new file mode 100644
index 0000000..d09edaf
--- /dev/null
+++ b/voice_changer/data.py
@@ -0,0 +1,136 @@
+"""Dataset utilities for multi-speaker training."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+import random
+from typing import Iterable
+
+import torch
+from torch.utils.data import Dataset
+
+from .audio import SpectrogramProcessor, load_audio_file
+
+AUDIO_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a"}
+
+
+@dataclass(frozen=True)
+class SampleRecord:
+    path: str
+    speaker_idx: int
+
+
+class MultiSpeakerSpectrogramDataset(Dataset[tuple[torch.Tensor, torch.Tensor, int]]):
+    """Returns source segment, reference segment and source speaker label."""
+
+    def __init__(
+        self,
+        data_root: str,
+        processor: SpectrogramProcessor,
+        segment_frames: int,
+        cache_features: bool = False,
+    ) -> None:
+        self.data_root = Path(data_root)
+        self.processor = processor
+        self.segment_frames = segment_frames
+        self.cache_features = cache_features
+        self._feature_cache: dict[str, torch.Tensor] = {}
+
+        (
+            self.entries,
+            self.speaker_to_idx,
+            self.idx_to_speaker,
+            self.indices_by_speaker,
+        ) = self._scan_dataset(self.data_root)
+
+        if not self.entries:
+            raise ValueError(
+                "Dataset is empty. Expected files in data/<speaker_name>/*.wav."
+            )
+
+    @staticmethod
+    def _audio_files(folder: Path) -> Iterable[Path]:
+        for path in folder.rglob("*"):
+            if path.is_file() and path.suffix.lower() in AUDIO_EXTENSIONS:
+                yield path
+
+    def _scan_dataset(
+        self,
+        root: Path,
+    ) -> tuple[
+        list[SampleRecord],
+        dict[str, int],
+        dict[int, str],
+        dict[int, list[int]],
+    ]:
+        speaker_dirs = sorted([path for path in root.iterdir() if path.is_dir()])
+        if not speaker_dirs:
+            raise ValueError(
+                f"No speaker folders found in '{root}'. "
+                "Use data/<speaker_name>/*.wav layout."
+ ) + + speaker_to_idx = {directory.name: idx for idx, directory in enumerate(speaker_dirs)} + idx_to_speaker = {idx: name for name, idx in speaker_to_idx.items()} + entries: list[SampleRecord] = [] + indices_by_speaker: dict[int, list[int]] = { + speaker_idx: [] for speaker_idx in idx_to_speaker + } + + for directory in speaker_dirs: + speaker_idx = speaker_to_idx[directory.name] + for file_path in sorted(self._audio_files(directory)): + record = SampleRecord(path=str(file_path), speaker_idx=speaker_idx) + indices_by_speaker[speaker_idx].append(len(entries)) + entries.append(record) + + return entries, speaker_to_idx, idx_to_speaker, indices_by_speaker + + @property + def num_speakers(self) -> int: + return len(self.speaker_to_idx) + + def __len__(self) -> int: + return len(self.entries) + + def _load_features(self, path: str) -> torch.Tensor: + if self.cache_features and path in self._feature_cache: + return self._feature_cache[path] + + waveform = load_audio_file(path, self.processor.sample_rate) + features = self.processor.waveform_to_log_magnitude(waveform).cpu() + if self.cache_features: + self._feature_cache[path] = features + return features + + def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, int]: + for _ in range(5): + sample = self.entries[index] + try: + source_features = self._load_features(sample.path) + source_segment = self.processor.sample_segment( + source_features, + self.segment_frames, + ) + + ref_index = random.choice(self.indices_by_speaker[sample.speaker_idx]) + ref_features = self._load_features(self.entries[ref_index].path) + ref_segment = self.processor.sample_segment( + ref_features, + self.segment_frames, + ) + return source_segment, ref_segment, sample.speaker_idx + except Exception: + index = random.randint(0, len(self.entries) - 1) + + raise RuntimeError("Failed to load valid audio sample after multiple retries.") + + +def collate_voice_batch( + batch: list[tuple[torch.Tensor, torch.Tensor, int]], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + source = torch.stack([item[0] for item in batch], dim=0) + reference = torch.stack([item[1] for item in batch], dim=0) + labels = torch.tensor([item[2] for item in batch], dtype=torch.long) + return source, reference, labels diff --git a/voice_changer/inference.py b/voice_changer/inference.py new file mode 100644 index 0000000..a81a506 --- /dev/null +++ b/voice_changer/inference.py @@ -0,0 +1,182 @@ +"""Checkpoint loading and speaker-profile utilities.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Iterable + +import torch +import torch.nn.functional as F + +from .audio import SpectrogramConfig, SpectrogramProcessor, load_audio_file +from .models import SpeakerEncoder, VoiceConverter +from .runtime import resolve_device + +AUDIO_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a"} + + +@dataclass +class InferenceBundle: + processor: SpectrogramProcessor + speaker_encoder: SpeakerEncoder + converter: VoiceConverter + speaker_to_idx: dict[str, int] + device: torch.device + checkpoint_payload: dict[str, Any] + +def _read_model_config(payload: dict[str, Any]) -> dict[str, int]: + model_config = payload.get("model_config", {}) + freq_bins = model_config.get("freq_bins") + embedding_dim = model_config.get("embedding_dim") + speaker_hidden_dim = model_config.get("speaker_hidden_dim", 256) + converter_hidden_dim = model_config.get("converter_hidden_dim", 384) + 
converter_layers = model_config.get("converter_layers", 2)
+
+    if freq_bins is None:
+        freq_bins = payload["speaker_encoder"]["conv_stack.0.weight"].shape[1]
+    if embedding_dim is None:
+        embedding_dim = payload["speaker_encoder"]["embedding_proj.weight"].shape[0]
+
+    return {
+        "freq_bins": int(freq_bins),
+        "embedding_dim": int(embedding_dim),
+        "speaker_hidden_dim": int(speaker_hidden_dim),
+        "converter_hidden_dim": int(converter_hidden_dim),
+        "converter_layers": int(converter_layers),
+    }
+
+
+def load_inference_bundle(
+    checkpoint_path: str,
+    device: str = "auto",
+) -> InferenceBundle:
+    target_device = resolve_device(device)
+    payload: dict[str, Any] = torch.load(checkpoint_path, map_location=target_device)
+
+    audio_config = payload.get("audio_config", {})
+    processor = SpectrogramProcessor(SpectrogramConfig(**audio_config))
+
+    model_config = _read_model_config(payload)
+
+    speaker_to_idx = payload.get("speaker_to_idx", {})
+    classifier_weight = payload["speaker_encoder"]["classifier.weight"]
+    num_speakers = int(classifier_weight.shape[0])
+    if not speaker_to_idx:
+        speaker_to_idx = {f"speaker_{idx}": idx for idx in range(num_speakers)}
+
+    speaker_encoder = SpeakerEncoder(
+        freq_bins=model_config["freq_bins"],
+        num_speakers=num_speakers,
+        embedding_dim=model_config["embedding_dim"],
+        hidden_dim=model_config["speaker_hidden_dim"],
+    ).to(target_device)
+    converter = VoiceConverter(
+        freq_bins=model_config["freq_bins"],
+        embedding_dim=model_config["embedding_dim"],
+        hidden_dim=model_config["converter_hidden_dim"],
+        num_layers=model_config["converter_layers"],
+    ).to(target_device)
+
+    speaker_encoder.load_state_dict(payload["speaker_encoder"], strict=True)
+    converter.load_state_dict(payload["converter"], strict=True)
+    speaker_encoder.eval()
+    converter.eval()
+
+    return InferenceBundle(
+        processor=processor,
+        speaker_encoder=speaker_encoder,
+        converter=converter,
+        speaker_to_idx=speaker_to_idx,
+        device=target_device,
+        checkpoint_payload=payload,
+    )
+
+
+def collect_audio_files(input_path: str) -> list[str]:
+    path = Path(input_path)
+    if path.is_file():
+        return [str(path)]
+    if not path.exists():
+        raise FileNotFoundError(f"Path does not exist: {input_path}")
+
+    files = [
+        str(file_path)
+        for file_path in sorted(path.rglob("*"))
+        if file_path.is_file() and file_path.suffix.lower() in AUDIO_EXTENSIONS
+    ]
+    if not files:
+        raise ValueError(f"No audio files found in: {input_path}")
+    return files
+
+
+def _split_segments(
+    features: torch.Tensor,
+    segment_frames: int,
+    max_segments: int,
+) -> Iterable[torch.Tensor]:
+    total_frames = features.size(0)
+    if total_frames <= segment_frames:
+        # F.pad's replicate mode needs a 3D input when padding the last two
+        # dims, so wrap the call in a temporary batch dimension.
+        yield F.pad(
+            features.unsqueeze(0),
+            (0, 0, 0, max(0, segment_frames - total_frames)),
+            mode="replicate",
+        ).squeeze(0)
+        return
+
+    step = max(1, (total_frames - segment_frames) // max(1, max_segments - 1))
+    starts = list(range(0, total_frames - segment_frames + 1, step))[:max_segments]
+    for start in starts:
+        yield features[start : start + segment_frames]
+
+
+@torch.inference_mode()
+def build_voice_profile_embedding(
+    bundle: InferenceBundle,
+    sample_paths: list[str],
+    segment_frames: int = 96,
+    max_segments_per_file: int = 8,
+) -> torch.Tensor:
+    embeddings: list[torch.Tensor] = []
+
+    for path in sample_paths:
+        waveform = load_audio_file(path, bundle.processor.sample_rate)
+        features = bundle.processor.waveform_to_log_magnitude(waveform).to(bundle.device)
+        for segment in _split_segments(features, segment_frames, max_segments_per_file):
+            segment = segment.unsqueeze(0)
+            
embedding, _ = bundle.speaker_encoder(segment) + embeddings.append(embedding.squeeze(0)) + + if not embeddings: + raise ValueError("Could not extract speaker embedding from provided samples.") + + mean_embedding = torch.stack(embeddings, dim=0).mean(dim=0) + mean_embedding = F.normalize(mean_embedding, dim=-1) + return mean_embedding + + +def save_voice_profile( + output_path: str, + embedding: torch.Tensor, + source_files: list[str], + sample_rate: int, + checkpoint_path: str, +) -> None: + payload = { + "embedding": embedding.detach().cpu(), + "source_files": source_files, + "sample_rate": sample_rate, + "checkpoint_path": checkpoint_path, + "created_at_utc": datetime.now(timezone.utc).isoformat(), + } + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + torch.save(payload, output_path) + + +def load_voice_profile(profile_path: str, device: torch.device) -> torch.Tensor: + payload = torch.load(profile_path, map_location=device) + embedding = payload["embedding"] + if embedding.dim() == 1: + embedding = embedding.unsqueeze(0) + return F.normalize(embedding.to(device), dim=-1) diff --git a/voice_changer/models.py b/voice_changer/models.py new file mode 100644 index 0000000..aae1971 --- /dev/null +++ b/voice_changer/models.py @@ -0,0 +1,104 @@ +"""Neural network modules used by the voice changer.""" + +from __future__ import annotations + +import torch +from torch import nn +import torch.nn.functional as F + + +class SpeakerEncoder(nn.Module): + """Encodes speaker identity from spectral frames.""" + + def __init__( + self, + freq_bins: int, + num_speakers: int, + embedding_dim: int = 128, + hidden_dim: int = 256, + ) -> None: + super().__init__() + self.conv_stack = nn.Sequential( + nn.Conv1d(freq_bins, hidden_dim, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.BatchNorm1d(hidden_dim), + nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.BatchNorm1d(hidden_dim), + nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.BatchNorm1d(hidden_dim), + ) + self.embedding_proj = nn.Linear(hidden_dim * 2, embedding_dim) + self.classifier = nn.Linear(embedding_dim, num_speakers) + + def forward(self, features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + features: [batch, frames, freq_bins] + Returns: + embedding: [batch, embedding_dim] + logits: [batch, num_speakers] + """ + x = features.transpose(1, 2) # [B, F, T] + x = self.conv_stack(x) + pooled_mean = x.mean(dim=-1) + pooled_std = x.std(dim=-1).clamp_min(1e-6) + pooled = torch.cat([pooled_mean, pooled_std], dim=-1) + embedding = F.normalize(self.embedding_proj(pooled), dim=-1) + logits = self.classifier(embedding) + return embedding, logits + + +class VoiceConverter(nn.Module): + """Sequence model that shifts spectral timbre toward target embedding.""" + + def __init__( + self, + freq_bins: int, + embedding_dim: int = 128, + hidden_dim: int = 384, + num_layers: int = 2, + max_log_magnitude: float = 12.0, + ) -> None: + super().__init__() + self.max_log_magnitude = max_log_magnitude + self.input_proj = nn.Linear(freq_bins + embedding_dim, hidden_dim) + self.gru = nn.GRU( + input_size=hidden_dim, + hidden_size=hidden_dim, + num_layers=num_layers, + batch_first=True, + ) + self.output_proj = nn.Linear(hidden_dim, freq_bins) + + def forward( + self, + source_features: torch.Tensor, + target_embedding: torch.Tensor, + hidden_state: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + source_features: 
[batch, frames, freq_bins] + target_embedding: [batch, embedding_dim] + hidden_state: optional GRU state for streaming + """ + if source_features.dim() != 3: + raise ValueError("source_features must be [batch, frames, freq_bins].") + if target_embedding.dim() != 2: + raise ValueError("target_embedding must be [batch, embedding_dim].") + if source_features.size(0) != target_embedding.size(0): + raise ValueError("Batch size mismatch between source and target embedding.") + + condition = target_embedding.unsqueeze(1).expand(-1, source_features.size(1), -1) + model_input = torch.cat([source_features, condition], dim=-1) + model_input = torch.tanh(self.input_proj(model_input)) + + if hidden_state is not None and hidden_state.size(1) != source_features.size(0): + hidden_state = None + + hidden_output, next_hidden_state = self.gru(model_input, hidden_state) + delta = self.output_proj(hidden_output) + converted = torch.clamp(source_features + delta, min=0.0, max=self.max_log_magnitude) + return converted, next_hidden_state diff --git a/voice_changer/realtime.py b/voice_changer/realtime.py new file mode 100644 index 0000000..8c560aa --- /dev/null +++ b/voice_changer/realtime.py @@ -0,0 +1,133 @@ +"""Real-time microphone voice conversion.""" + +from __future__ import annotations + +import threading +from typing import Any + +import numpy as np +import torch + +from .inference import InferenceBundle + +try: + import sounddevice as sd +except ImportError as error: # pragma: no cover - import guard for optional runtime dep + sd = None + _SOUNDDEVICE_IMPORT_ERROR = error +else: + _SOUNDDEVICE_IMPORT_ERROR = None + + +class RealtimeVoiceChanger: + """Processes microphone audio in chunks and outputs converted voice.""" + + def __init__( + self, + bundle: InferenceBundle, + target_embedding: torch.Tensor, + block_size: int = 1024, + input_device: int | str | None = None, + output_device: int | str | None = None, + latency: str | float = "low", + ) -> None: + if sd is None: + raise ImportError( + "sounddevice is required for realtime mode. " + "Install dependencies from requirements.txt." 
+ ) from _SOUNDDEVICE_IMPORT_ERROR + + if block_size <= 0: + raise ValueError("block_size must be positive.") + + self.bundle = bundle + self.block_size = block_size + self.input_device = input_device + self.output_device = output_device + self.latency = latency + + self.sample_rate = bundle.processor.sample_rate + self.target_embedding = target_embedding.to(bundle.device) + self.hidden_state: torch.Tensor | None = None + self._lock = threading.Lock() + + def reset_state(self) -> None: + with self._lock: + self.hidden_state = None + + @torch.inference_mode() + def process_block(self, audio_block: np.ndarray) -> np.ndarray: + if audio_block.ndim != 1: + raise ValueError("Expected mono block as 1D array.") + + if audio_block.dtype != np.float32: + audio_block = audio_block.astype(np.float32) + + waveform = torch.from_numpy(audio_block) + log_magnitude, phase = self.bundle.processor.waveform_to_log_magnitude_and_phase(waveform) + source = log_magnitude.unsqueeze(0).to(self.bundle.device) + + with self._lock: + converted, self.hidden_state = self.bundle.converter( + source_features=source, + target_embedding=self.target_embedding, + hidden_state=self.hidden_state, + ) + if self.hidden_state is not None: + self.hidden_state = self.hidden_state.detach() + + converted = converted.squeeze(0).cpu() + reconstructed = self.bundle.processor.log_magnitude_to_waveform( + log_magnitude=converted, + phase=phase, + length=audio_block.shape[0], + ) + reconstructed = reconstructed.clamp(min=-1.0, max=1.0) + return reconstructed.numpy().astype(np.float32) + + def _callback( + self, + indata: np.ndarray, + outdata: np.ndarray, + frames: int, + _time_info: Any, + status: Any, + ) -> None: + if status: + print(f"[sounddevice] {status}") + + input_mono = indata[:, 0].copy() + converted = self.process_block(input_mono) + + if converted.shape[0] < frames: + converted = np.pad(converted, (0, frames - converted.shape[0]), mode="constant") + elif converted.shape[0] > frames: + converted = converted[:frames] + + outdata[:, 0] = converted + + def run(self) -> None: + stream = sd.Stream( + samplerate=self.sample_rate, + blocksize=self.block_size, + dtype="float32", + channels=1, + callback=self._callback, + device=(self.input_device, self.output_device), + latency=self.latency, + ) + with stream: + print( + "Realtime voice changer running. " + "Press Ctrl+C to stop, or speak into microphone." + ) + while True: + sd.sleep(250) + + +def list_audio_devices() -> None: + if sd is None: + raise ImportError( + "sounddevice is required for listing devices." + ) from _SOUNDDEVICE_IMPORT_ERROR + print(sd.query_devices()) diff --git a/voice_changer/runtime.py b/voice_changer/runtime.py new file mode 100644 index 0000000..78f21ce --- /dev/null +++ b/voice_changer/runtime.py @@ -0,0 +1,77 @@ +"""Runtime helpers for backend selection and AMD compatibility.""" + +from __future__ import annotations + +from dataclasses import dataclass +import os + + +def configure_amd_runtime(amd_gfx_version: str | None) -> str | None: + """ + Sets HSA_OVERRIDE_GFX_VERSION for older AMD GPUs (e.g. RX 470 => 8.0.3). + Must be called before importing torch. 
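+
+    Usage sketch (the override value is hardware-specific; "8.0.3" targets
+    Polaris cards such as the RX 470):
+
+        configure_amd_runtime("8.0.3")  # set the env var first
+        import torch                    # only then import torch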
+ """ + if amd_gfx_version is None: + return os.environ.get("HSA_OVERRIDE_GFX_VERSION") + os.environ["HSA_OVERRIDE_GFX_VERSION"] = amd_gfx_version + return amd_gfx_version + + +def resolve_device(device_name: str = "auto") -> "torch.device": + import torch + + if device_name == "auto": + if torch.cuda.is_available(): + return torch.device("cuda") + mps_backend = getattr(torch.backends, "mps", None) + if mps_backend is not None and mps_backend.is_available(): + return torch.device("mps") + return torch.device("cpu") + return torch.device(device_name) + + +def backend_name_for(device: "torch.device") -> str: + import torch + + if device.type == "cuda" and getattr(torch.version, "hip", None): + return "rocm" + return device.type + + +def should_enable_amp(device: "torch.device", amp_mode: str) -> bool: + """ + Returns whether autocast/GradScaler should be used. + For ROCm we disable AMP by default for better compatibility. + """ + if amp_mode not in {"auto", "on", "off"}: + raise ValueError(f"amp_mode must be one of auto/on/off, got: {amp_mode}") + + if amp_mode == "off": + return False + + if amp_mode == "on": + return device.type == "cuda" + + if device.type != "cuda": + return False + + import torch + + is_rocm = bool(getattr(torch.version, "hip", None)) + if is_rocm: + return False + return True + + +@dataclass(frozen=True) +class RuntimeInfo: + device: "torch.device" + backend: str + amp_enabled: bool + + +def describe_runtime(device_name: str = "auto", amp_mode: str = "auto") -> RuntimeInfo: + device = resolve_device(device_name) + backend = backend_name_for(device) + amp_enabled = should_enable_amp(device, amp_mode=amp_mode) + return RuntimeInfo(device=device, backend=backend, amp_enabled=amp_enabled) diff --git a/voice_changer/train.py b/voice_changer/train.py new file mode 100644 index 0000000..6d0d3cb --- /dev/null +++ b/voice_changer/train.py @@ -0,0 +1,322 @@ +"""Training loop for the neural voice changer.""" + +from __future__ import annotations + +import contextlib +from dataclasses import asdict, dataclass +from pathlib import Path +import random +from typing import Any + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.utils.data import DataLoader +from tqdm import tqdm + +from .audio import SpectrogramConfig, SpectrogramProcessor +from .data import MultiSpeakerSpectrogramDataset, collate_voice_batch +from .models import SpeakerEncoder, VoiceConverter +from .runtime import describe_runtime + + +@dataclass +class TrainConfig: + data_dir: str + output_dir: str + epochs: int = 40 + batch_size: int = 12 + segment_frames: int = 64 + learning_rate: float = 3e-4 + weight_decay: float = 1e-4 + embedding_dim: int = 128 + speaker_hidden_dim: int = 256 + converter_hidden_dim: int = 384 + converter_layers: int = 2 + num_workers: int = 0 + sample_rate: int = 16_000 + n_fft: int = 512 + hop_length: int = 128 + win_length: int = 512 + identity_loss_weight: float = 1.0 + cycle_loss_weight: float = 0.75 + speaker_loss_weight: float = 0.4 + transfer_speaker_loss_weight: float = 0.8 + gradient_clip: float = 1.0 + save_every: int = 5 + seed: int = 42 + cache_features: bool = False + resume_checkpoint: str | None = None + device: str = "auto" + amp_mode: str = "auto" + +def _set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def 
_load_forgiving(module: nn.Module, state_dict: dict[str, torch.Tensor]) -> None: + current_state = module.state_dict() + filtered_state = {} + for key, value in state_dict.items(): + if key in current_state and current_state[key].shape == value.shape: + filtered_state[key] = value + module.load_state_dict(filtered_state, strict=False) + + +def _save_checkpoint( + path: str, + epoch: int, + speaker_encoder: SpeakerEncoder, + converter: VoiceConverter, + optimizer: torch.optim.Optimizer, + scheduler: CosineAnnealingLR, + train_config: TrainConfig, + processor: SpectrogramProcessor, + speaker_to_idx: dict[str, int], +) -> None: + payload: dict[str, Any] = { + "epoch": epoch, + "speaker_encoder": speaker_encoder.state_dict(), + "converter": converter.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "speaker_to_idx": speaker_to_idx, + "audio_config": processor.to_dict(), + "model_config": { + "embedding_dim": train_config.embedding_dim, + "speaker_hidden_dim": train_config.speaker_hidden_dim, + "converter_hidden_dim": train_config.converter_hidden_dim, + "converter_layers": train_config.converter_layers, + "freq_bins": processor.freq_bins, + }, + "train_config": asdict(train_config), + } + Path(path).parent.mkdir(parents=True, exist_ok=True) + torch.save(payload, path) + + +def train_voice_changer(config: TrainConfig) -> Path: + """Trains the model and returns path to latest checkpoint.""" + _set_seed(config.seed) + runtime = describe_runtime(device_name=config.device, amp_mode=config.amp_mode) + device = runtime.device + amp_enabled = runtime.amp_enabled and device.type == "cuda" + print( + f"Using device: {device} " + f"(backend={runtime.backend}, amp={'on' if amp_enabled else 'off'})" + ) + if runtime.backend == "rocm" and config.amp_mode == "auto": + print( + "ROCm backend detected: AMP disabled by default for compatibility. " + "You can force it with --amp-mode on if your stack is stable." + ) + + processor = SpectrogramProcessor( + SpectrogramConfig( + sample_rate=config.sample_rate, + n_fft=config.n_fft, + hop_length=config.hop_length, + win_length=config.win_length, + ) + ) + + dataset = MultiSpeakerSpectrogramDataset( + data_root=config.data_dir, + processor=processor, + segment_frames=config.segment_frames, + cache_features=config.cache_features, + ) + + if dataset.num_speakers < 2: + print( + "Warning: detected only one speaker in dataset. " + "Model can train, but cross-speaker conversion quality will be limited." + ) + + dataloader = DataLoader( + dataset, + batch_size=config.batch_size, + shuffle=True, + num_workers=config.num_workers, + drop_last=True, + collate_fn=collate_voice_batch, + pin_memory=device.type == "cuda", + ) + + if len(dataloader) == 0: + raise ValueError( + "Dataloader is empty. Increase dataset size or decrease batch size." 
+ ) + + speaker_encoder = SpeakerEncoder( + freq_bins=processor.freq_bins, + num_speakers=dataset.num_speakers, + embedding_dim=config.embedding_dim, + hidden_dim=config.speaker_hidden_dim, + ).to(device) + converter = VoiceConverter( + freq_bins=processor.freq_bins, + embedding_dim=config.embedding_dim, + hidden_dim=config.converter_hidden_dim, + num_layers=config.converter_layers, + ).to(device) + + optimizer = AdamW( + list(speaker_encoder.parameters()) + list(converter.parameters()), + lr=config.learning_rate, + weight_decay=config.weight_decay, + ) + scheduler = CosineAnnealingLR( + optimizer, + T_max=max(1, config.epochs * len(dataloader)), + ) + + start_epoch = 1 + if config.resume_checkpoint: + checkpoint = torch.load(config.resume_checkpoint, map_location=device) + if "speaker_encoder" in checkpoint: + _load_forgiving(speaker_encoder, checkpoint["speaker_encoder"]) + if "converter" in checkpoint: + _load_forgiving(converter, checkpoint["converter"]) + if "optimizer" in checkpoint: + try: + optimizer.load_state_dict(checkpoint["optimizer"]) + except Exception: + print("Warning: optimizer state not restored, continuing with new optimizer.") + if "scheduler" in checkpoint: + try: + scheduler.load_state_dict(checkpoint["scheduler"]) + except Exception: + print("Warning: scheduler state not restored, continuing with new scheduler.") + start_epoch = int(checkpoint.get("epoch", 0)) + 1 + print(f"Resumed from checkpoint: {config.resume_checkpoint} (epoch {start_epoch})") + + scaler = torch.amp.GradScaler(enabled=amp_enabled) + + output_dir = Path(config.output_dir) + checkpoint_dir = output_dir / "checkpoints" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + for epoch in range(start_epoch, config.epochs + 1): + speaker_encoder.train() + converter.train() + + running = { + "total": 0.0, + "identity": 0.0, + "cycle": 0.0, + "speaker": 0.0, + "transfer_speaker": 0.0, + } + + progress = tqdm( + dataloader, + desc=f"Epoch {epoch}/{config.epochs}", + dynamic_ncols=True, + ) + for source, reference, labels in progress: + source = source.to(device) + reference = reference.to(device) + labels = labels.to(device) + + permutation = torch.randperm(source.size(0), device=device) + target_reference = reference[permutation] + target_labels = labels[permutation] + + optimizer.zero_grad(set_to_none=True) + + if device.type == "cuda": + autocast_context = torch.autocast( + device_type="cuda", + enabled=amp_enabled, + ) + else: + autocast_context = contextlib.nullcontext() + + with autocast_context: + source_embedding, source_logits = speaker_encoder(reference) + target_embedding, _ = speaker_encoder(target_reference) + + identity_output, _ = converter(source, source_embedding) + transfer_output, _ = converter(source, target_embedding) + cycle_output, _ = converter(transfer_output, source_embedding.detach()) + _, transfer_logits = speaker_encoder(transfer_output) + + identity_loss = F.l1_loss(identity_output, source) + cycle_loss = F.l1_loss(cycle_output, source) + speaker_loss = F.cross_entropy(source_logits, labels) + transfer_speaker_loss = F.cross_entropy(transfer_logits, target_labels) + total_loss = ( + config.identity_loss_weight * identity_loss + + config.cycle_loss_weight * cycle_loss + + config.speaker_loss_weight * speaker_loss + + config.transfer_speaker_loss_weight * transfer_speaker_loss + ) + + scaler.scale(total_loss).backward() + scaler.unscale_(optimizer) + nn.utils.clip_grad_norm_( + list(speaker_encoder.parameters()) + list(converter.parameters()), + max_norm=config.gradient_clip, 
+ ) + scaler.step(optimizer) + scaler.update() + scheduler.step() + + running["total"] += float(total_loss.item()) + running["identity"] += float(identity_loss.item()) + running["cycle"] += float(cycle_loss.item()) + running["speaker"] += float(speaker_loss.item()) + running["transfer_speaker"] += float(transfer_speaker_loss.item()) + step = max(1, progress.n + 1) + progress.set_postfix( + loss=f"{running['total'] / step:.4f}", + id=f"{running['identity'] / step:.4f}", + cyc=f"{running['cycle'] / step:.4f}", + ) + + num_steps = len(dataloader) + print( + f"Epoch {epoch}: " + f"total={running['total'] / num_steps:.4f}, " + f"id={running['identity'] / num_steps:.4f}, " + f"cycle={running['cycle'] / num_steps:.4f}, " + f"spk={running['speaker'] / num_steps:.4f}, " + f"spk_xfer={running['transfer_speaker'] / num_steps:.4f}" + ) + + latest_path = checkpoint_dir / "latest.pt" + _save_checkpoint( + path=str(latest_path), + epoch=epoch, + speaker_encoder=speaker_encoder, + converter=converter, + optimizer=optimizer, + scheduler=scheduler, + train_config=config, + processor=processor, + speaker_to_idx=dataset.speaker_to_idx, + ) + + if epoch % config.save_every == 0 or epoch == config.epochs: + epoch_path = checkpoint_dir / f"epoch_{epoch:03d}.pt" + _save_checkpoint( + path=str(epoch_path), + epoch=epoch, + speaker_encoder=speaker_encoder, + converter=converter, + optimizer=optimizer, + scheduler=scheduler, + train_config=config, + processor=processor, + speaker_to_idx=dataset.speaker_to_idx, + ) + + return checkpoint_dir / "latest.pt"
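+
+
+# Programmatic usage sketch (hypothetical paths; scripts/train_voice_converter.py
+# wraps this same entry point for the CLI):
+#
+#     from voice_changer.train import TrainConfig, train_voice_changer
+#
+#     latest = train_voice_changer(TrainConfig(data_dir="data", output_dir="artifacts"))
+#     print(f"Latest checkpoint: {latest}")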