diff --git a/client/src/pages/Conversation/hooks/useVoiceConversion.ts b/client/src/pages/Conversation/hooks/useVoiceConversion.ts new file mode 100644 index 00000000..51ae2176 --- /dev/null +++ b/client/src/pages/Conversation/hooks/useVoiceConversion.ts @@ -0,0 +1,306 @@ +/** + * Voice Conversion Hook + * + * Provides real-time voice conversion functionality using StreamVC-style + * architecture. Connects to the voice conversion WebSocket endpoint and + * streams audio for conversion. + * + * Usage: + * const { start, stop, status, setTargetVoice } = useVoiceConversion({ + * onConvertedAudio: (audio) => playAudio(audio), + * }); + */ + +import { useState, useCallback, useRef, useEffect } from "react"; +import Recorder from "opus-recorder"; + +export type VCStatus = + | "idle" + | "connecting" + | "awaiting_reference" + | "reference_ready" + | "converting" + | "error" + | "disconnected"; + +export interface VoiceConversionOptions { + /** Callback when converted audio is received */ + onConvertedAudio?: (audioData: ArrayBuffer) => void; + /** Callback when status changes */ + onStatusChange?: (status: VCStatus) => void; + /** Callback on error */ + onError?: (error: string) => void; + /** Server URL (defaults to current host with /api/vc path) */ + serverUrl?: string; + /** Sample rate (default: 24000) */ + sampleRate?: number; +} + +export interface VoiceConversionResult { + /** Current status */ + status: VCStatus; + /** Start voice conversion with optional target voice */ + start: (options?: { voice?: string; referenceMode?: boolean }) => Promise; + /** Stop voice conversion */ + stop: () => void; + /** Set target voice by name (requires voice embeddings on server) */ + setTargetVoice: (voiceName: string) => void; + /** Send reference audio for voice cloning */ + sendReferenceAudio: (audioData: ArrayBuffer) => void; + /** Signal end of reference audio collection */ + endReferenceCollection: () => void; + /** Whether currently recording/converting */ + isActive: boolean; + /** Available voices (fetched from server) */ + availableVoices: string[]; + /** Fetch available voices from server */ + fetchVoices: () => Promise; +} + +export const useVoiceConversion = ( + options: VoiceConversionOptions = {} +): VoiceConversionResult => { + const { + onConvertedAudio, + onStatusChange, + onError, + serverUrl, + sampleRate = 24000, + } = options; + + const [status, setStatus] = useState("idle"); + const [availableVoices, setAvailableVoices] = useState([]); + const [isActive, setIsActive] = useState(false); + + const socketRef = useRef(null); + const recorderRef = useRef(null); + const targetVoiceRef = useRef(null); + + // Update status and notify + const updateStatus = useCallback( + (newStatus: VCStatus) => { + setStatus(newStatus); + onStatusChange?.(newStatus); + }, + [onStatusChange] + ); + + // Get WebSocket URL + const getWsUrl = useCallback( + (voice?: string, referenceMode?: boolean) => { + const base = + serverUrl || + `${window.location.protocol === "https:" ? "wss:" : "ws:"}//${window.location.host}/api/vc`; + + const params = new URLSearchParams(); + if (voice) params.set("voice", voice); + if (referenceMode) params.set("reference_mode", "true"); + + const queryString = params.toString(); + return queryString ? `${base}?${queryString}` : base; + }, + [serverUrl] + ); + + // Fetch available voices from server + const fetchVoices = useCallback(async (): Promise => { + try { + const baseUrl = serverUrl?.replace(/^wss?:/, "http") || ""; + const url = baseUrl + ? 
`${baseUrl.replace("/api/vc", "")}/api/vc/voices` + : "/api/vc/voices"; + + const response = await fetch(url); + if (!response.ok) { + throw new Error(`Failed to fetch voices: ${response.statusText}`); + } + + const data = await response.json(); + const voices = data.voices || []; + setAvailableVoices(voices); + return voices; + } catch (error) { + console.error("Failed to fetch voices:", error); + return []; + } + }, [serverUrl]); + + // Set target voice + const setTargetVoice = useCallback((voiceName: string) => { + targetVoiceRef.current = voiceName; + }, []); + + // Send reference audio + const sendReferenceAudio = useCallback((audioData: ArrayBuffer) => { + if (socketRef.current && socketRef.current.readyState === WebSocket.OPEN) { + // Prepend message type byte (0x01 for audio) + const message = new Uint8Array(audioData.byteLength + 1); + message[0] = 0x01; + message.set(new Uint8Array(audioData), 1); + socketRef.current.send(message.buffer); + } + }, []); + + // Signal end of reference collection + const endReferenceCollection = useCallback(() => { + if (socketRef.current && socketRef.current.readyState === WebSocket.OPEN) { + // Send control message (0x03) with "end_reference" + const encoder = new TextEncoder(); + const text = encoder.encode("end_reference"); + const message = new Uint8Array(text.length + 1); + message[0] = 0x03; + message.set(text, 1); + socketRef.current.send(message.buffer); + } + }, []); + + // Start voice conversion + const start = useCallback( + async (startOptions?: { voice?: string; referenceMode?: boolean }) => { + const voice = startOptions?.voice || targetVoiceRef.current || undefined; + const referenceMode = startOptions?.referenceMode || false; + + try { + updateStatus("connecting"); + + // Create WebSocket connection + const wsUrl = getWsUrl(voice, referenceMode); + const ws = new WebSocket(wsUrl); + ws.binaryType = "arraybuffer"; + + ws.onopen = () => { + console.log("Voice conversion WebSocket connected"); + }; + + ws.onmessage = (event) => { + const data = new Uint8Array(event.data); + if (data.length === 0) return; + + const messageType = data[0]; + const payload = data.slice(1); + + switch (messageType) { + case 0x00: // Handshake + console.log("Voice conversion handshake received"); + updateStatus(referenceMode ? 
"awaiting_reference" : "converting"); + break; + + case 0x01: // Audio + onConvertedAudio?.(payload.buffer); + break; + + case 0x03: // Control + const controlMsg = new TextDecoder().decode(payload); + if (controlMsg === "awaiting_reference") { + updateStatus("awaiting_reference"); + } else if (controlMsg === "reference_ready") { + updateStatus("converting"); + } + break; + + case 0x05: // Error + const errorMsg = new TextDecoder().decode(payload); + onError?.(errorMsg); + updateStatus("error"); + break; + } + }; + + ws.onclose = () => { + console.log("Voice conversion WebSocket closed"); + updateStatus("disconnected"); + setIsActive(false); + }; + + ws.onerror = (error) => { + console.error("Voice conversion WebSocket error:", error); + onError?.("WebSocket connection error"); + updateStatus("error"); + }; + + socketRef.current = ws; + + // Start microphone recording + const encoderPath = new URL( + "opus-recorder/dist/encoderWorker.min.js", + import.meta.url + ).href; + + const recorder = new Recorder({ + encoderPath, + encoderSampleRate: sampleRate, + encoderFrameSize: 20, // 20ms frames + maxFramesPerPage: 2, // 40ms packets + numberOfChannels: 1, + encoderApplication: 2049, // VOIP mode + encoderComplexity: 0, // Low CPU + }); + + recorder.ondataavailable = (opusData: ArrayBuffer) => { + if ( + socketRef.current && + socketRef.current.readyState === WebSocket.OPEN && + status === "converting" + ) { + // Send audio with message type byte + const message = new Uint8Array(opusData.byteLength + 1); + message[0] = 0x01; + message.set(new Uint8Array(opusData), 1); + socketRef.current.send(message.buffer); + } + }; + + await recorder.start(); + recorderRef.current = recorder; + setIsActive(true); + + console.log("Voice conversion started"); + } catch (error) { + console.error("Failed to start voice conversion:", error); + onError?.(error instanceof Error ? error.message : "Unknown error"); + updateStatus("error"); + } + }, + [getWsUrl, onConvertedAudio, onError, sampleRate, status, updateStatus] + ); + + // Stop voice conversion + const stop = useCallback(() => { + // Stop recorder + if (recorderRef.current) { + recorderRef.current.stop(); + recorderRef.current = null; + } + + // Close WebSocket + if (socketRef.current) { + socketRef.current.close(); + socketRef.current = null; + } + + setIsActive(false); + updateStatus("idle"); + console.log("Voice conversion stopped"); + }, [updateStatus]); + + // Cleanup on unmount + useEffect(() => { + return () => { + stop(); + }; + }, [stop]); + + return { + status, + start, + stop, + setTargetVoice, + sendReferenceAudio, + endReferenceCollection, + isActive, + availableVoices, + fetchVoices, + }; +}; + +export default useVoiceConversion; diff --git a/moshi/moshi/models/voice_conversion.py b/moshi/moshi/models/voice_conversion.py new file mode 100644 index 00000000..ba3baf69 --- /dev/null +++ b/moshi/moshi/models/voice_conversion.py @@ -0,0 +1,768 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: MIT +# +# Real-time Voice Conversion Module +# StreamVC-style architecture for low-latency voice conversion +# +# Architecture based on: +# - StreamVC: Real-Time Low-Latency Voice Conversion (Google, ICASSP 2024) +# - SoundStream neural audio codec design +# - HuBERT soft speech units for content encoding + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, List +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..modules.streaming import StreamingModule, StreamingContainer +from ..modules.conv import StreamingConv1d, StreamingConvTranspose1d + + +@dataclass +class VoiceConversionState: + """State for streaming voice conversion.""" + content_buffer: torch.Tensor + f0_buffer: torch.Tensor + output_buffer: torch.Tensor + + def reset(self): + self.content_buffer.zero_() + self.f0_buffer.zero_() + self.output_buffer.zero_() + + +class CausalConvBlock(StreamingContainer): + """Causal convolutional block with residual connection. + + Based on SoundStream/StreamVC encoder architecture. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 7, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + ): + super().__init__() + self.conv = StreamingConv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + causal=True, + norm="weight_norm", + pad_mode="constant", + ) + self.activation = nn.ELU(alpha=1.0) + + # Residual connection if channels match + self.residual = ( + nn.Identity() if in_channels == out_channels and stride == 1 + else StreamingConv1d(in_channels, out_channels, kernel_size=1, stride=stride, causal=True) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = self.residual(x) + x = self.conv(x) + x = self.activation(x) + return x + residual + + +class CausalEncoderBlock(StreamingContainer): + """Encoder block with downsampling for content encoder. + + Each block: Conv -> ELU -> Conv (stride) -> ELU + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 2, + num_residual: int = 3, + ): + super().__init__() + + # Residual blocks before downsampling + layers = [] + for i in range(num_residual): + ch = in_channels if i == 0 else out_channels + layers.append(CausalConvBlock(ch, out_channels, kernel_size=7, dilation=3**i)) + + # Downsampling convolution + layers.append( + StreamingConv1d( + out_channels, + out_channels, + kernel_size=stride * 2, + stride=stride, + causal=True, + norm="weight_norm", + ) + ) + layers.append(nn.ELU(alpha=1.0)) + + self.layers = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layers(x) + + +class CausalDecoderBlock(StreamingContainer): + """Decoder block with upsampling for audio synthesis. 
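A quick worked example of the dilation schedule in CausalEncoderBlock (illustrative only, not part of the patch): with the default kernel_size=7 and num_residual=3, the residual stack uses dilations 1, 3, 9.

# Receptive field of the residual stack before the strided downsampling conv,
# assuming kernel_size=7 and num_residual=3 as in the defaults above.
kernel_size, num_residual = 7, 3
dilations = [3 ** i for i in range(num_residual)]              # [1, 3, 9]
receptive = 1 + sum((kernel_size - 1) * d for d in dilations)  # 79 input samples
print(dilations, receptive)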
+ + Each block: ConvTranspose (stride) -> ELU -> Conv -> ELU + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 2, + num_residual: int = 3, + ): + super().__init__() + + # Upsampling transposed convolution + self.upsample = StreamingConvTranspose1d( + in_channels, + out_channels, + kernel_size=stride * 2, + stride=stride, + causal=True, + norm="weight_norm", + ) + self.activation = nn.ELU(alpha=1.0) + + # Residual blocks after upsampling + layers = [] + for i in range(num_residual): + layers.append(CausalConvBlock(out_channels, out_channels, kernel_size=7, dilation=3**i)) + + self.residual_layers = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.upsample(x) + x = self.activation(x) + x = self.residual_layers(x) + return x + + +class FiLMLayer(nn.Module): + """Feature-wise Linear Modulation layer for speaker conditioning. + + Applies affine transformation: gamma * x + beta + where gamma and beta are derived from speaker embedding. + """ + + def __init__(self, channels: int, speaker_dim: int): + super().__init__() + self.gamma_proj = nn.Linear(speaker_dim, channels) + self.beta_proj = nn.Linear(speaker_dim, channels) + + # Initialize to identity transform + nn.init.ones_(self.gamma_proj.weight.data[:, 0]) + nn.init.zeros_(self.gamma_proj.weight.data[:, 1:]) + nn.init.zeros_(self.gamma_proj.bias.data) + nn.init.zeros_(self.beta_proj.weight.data) + nn.init.zeros_(self.beta_proj.bias.data) + + def forward(self, x: torch.Tensor, speaker_emb: torch.Tensor) -> torch.Tensor: + """ + Args: + x: [B, C, T] audio features + speaker_emb: [B, speaker_dim] speaker embedding + + Returns: + Modulated features [B, C, T] + """ + gamma = self.gamma_proj(speaker_emb).unsqueeze(-1) # [B, C, 1] + beta = self.beta_proj(speaker_emb).unsqueeze(-1) # [B, C, 1] + return gamma * x + beta + + +class ContentEncoder(StreamingContainer): + """Causal content encoder that extracts linguistic features. + + Based on SoundStream encoder architecture with HuBERT soft unit targets. + Produces content embeddings that are speaker-independent. 
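A minimal sketch of the FiLM modulation used by FiLMLayer above, with hypothetical shapes (B=2 utterances, C=64 channels, T=100 frames, speaker_dim=256); it only illustrates the gamma * x + beta broadcast, not the identity initialization.

import torch
import torch.nn as nn

B, C, T, speaker_dim = 2, 64, 100, 256
x = torch.randn(B, C, T)                                       # content features
speaker_emb = torch.randn(B, speaker_dim)                      # global speaker embedding
gamma = nn.Linear(speaker_dim, C)(speaker_emb).unsqueeze(-1)   # [B, C, 1]
beta = nn.Linear(speaker_dim, C)(speaker_emb).unsqueeze(-1)    # [B, C, 1]
y = gamma * x + beta                                           # [B, C, T], per-channel affine modulation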
+ + Architecture: + Input: 24kHz audio waveform + Output: Content embeddings at 50Hz (downsampled 480x) + + Strides: [2, 4, 5, 6] -> total 240x downsampling + At 24kHz: 240 samples = 10ms per frame -> 100Hz + With additional 2x downsample -> 50Hz + """ + + def __init__( + self, + in_channels: int = 1, + channels: int = 64, + latent_dim: int = 64, + num_residual: int = 3, + strides: List[int] = [2, 4, 5, 6], # Total: 240x + ): + super().__init__() + + self.in_channels = in_channels + self.channels = channels + self.latent_dim = latent_dim + self.strides = strides + self.hop_length = math.prod(strides) # 240 + + # Initial convolution + self.input_conv = StreamingConv1d( + in_channels, channels, kernel_size=7, causal=True, norm="weight_norm" + ) + self.input_activation = nn.ELU(alpha=1.0) + + # Encoder blocks with increasing channels + self.encoder_blocks = nn.ModuleList() + in_ch = channels + for i, stride in enumerate(strides): + out_ch = min(channels * (2 ** (i + 1)), 512) + self.encoder_blocks.append( + CausalEncoderBlock(in_ch, out_ch, stride=stride, num_residual=num_residual) + ) + in_ch = out_ch + + # Final projection to latent dimension + self.output_conv = StreamingConv1d( + in_ch, latent_dim, kernel_size=3, causal=True, norm="weight_norm" + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: [B, 1, T] input audio waveform + + Returns: + [B, latent_dim, T'] content embeddings (T' = T / hop_length) + """ + x = self.input_conv(x) + x = self.input_activation(x) + + for block in self.encoder_blocks: + x = block(x) + + x = self.output_conv(x) + return x + + +class F0Extractor(StreamingContainer): + """Causal F0 (pitch) extractor with whitening. + + Extracts fundamental frequency while preventing speaker identity leakage. + Uses a learned convolutional network instead of traditional DSP methods + for causal, low-latency operation. + + The whitening process normalizes F0 to remove speaker-specific pitch range + while preserving the pitch contour (prosody). 
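Worked numbers for the default encoder strides (illustrative): the product of [2, 4, 5, 6] gives the hop length used throughout the model.

import math
sample_rate = 24000
strides = [2, 4, 5, 6]
hop = math.prod(strides)                 # 240 samples per content frame
frame_ms = 1000 * hop / sample_rate      # 10.0 ms per frame
frame_rate = sample_rate / hop           # 100.0 frames per second
print(hop, frame_ms, frame_rate)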
+ """ + + def __init__( + self, + in_channels: int = 1, + hidden_channels: int = 64, + out_channels: int = 1, + hop_length: int = 240, # Match content encoder + ): + super().__init__() + + self.hop_length = hop_length + + # Causal F0 estimation network + self.conv1 = StreamingConv1d( + in_channels, hidden_channels, kernel_size=7, causal=True, norm="weight_norm" + ) + self.conv2 = StreamingConv1d( + hidden_channels, hidden_channels, kernel_size=7, stride=4, causal=True, norm="weight_norm" + ) + self.conv3 = StreamingConv1d( + hidden_channels, hidden_channels, kernel_size=7, stride=4, causal=True, norm="weight_norm" + ) + self.conv4 = StreamingConv1d( + hidden_channels, hidden_channels, kernel_size=7, stride=3, causal=True, norm="weight_norm" + ) + self.conv5 = StreamingConv1d( + hidden_channels, hidden_channels, kernel_size=7, stride=5, causal=True, norm="weight_norm" + ) + self.output_conv = StreamingConv1d( + hidden_channels, out_channels, kernel_size=3, causal=True + ) + + self.activation = nn.ELU(alpha=1.0) + + # Running statistics for whitening (momentum-based) + self.register_buffer('running_mean', torch.zeros(1)) + self.register_buffer('running_std', torch.ones(1)) + self.momentum = 0.1 + + def forward(self, x: torch.Tensor, update_stats: bool = True) -> torch.Tensor: + """ + Args: + x: [B, 1, T] input audio waveform + update_stats: Whether to update running statistics for whitening + + Returns: + [B, 1, T'] whitened F0 contour (T' = T / hop_length) + """ + # Extract F0 features + h = self.activation(self.conv1(x)) + h = self.activation(self.conv2(h)) + h = self.activation(self.conv3(h)) + h = self.activation(self.conv4(h)) + h = self.activation(self.conv5(h)) + f0 = self.output_conv(h) + + # Whitening: normalize to zero mean, unit variance + if update_stats and self.training: + batch_mean = f0.mean() + batch_std = f0.std() + 1e-6 + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean + self.running_std = (1 - self.momentum) * self.running_std + self.momentum * batch_std + + # Apply whitening + f0_whitened = (f0 - self.running_mean) / (self.running_std + 1e-6) + + return f0_whitened + + +class SpeakerEncoder(nn.Module): + """Speaker encoder that extracts speaker identity embedding. + + Produces a global speaker embedding from a reference audio clip. + Can also load pre-computed speaker embeddings. + + Architecture based on d-vector / x-vector speaker verification systems. 
+ """ + + def __init__( + self, + in_channels: int = 80, # Mel spectrogram channels + hidden_channels: int = 256, + speaker_dim: int = 256, + num_layers: int = 3, + ): + super().__init__() + + self.speaker_dim = speaker_dim + + # Frame-level feature extraction + self.frame_layers = nn.ModuleList() + in_ch = in_channels + for i in range(num_layers): + out_ch = hidden_channels + self.frame_layers.append(nn.Sequential( + nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2), + nn.BatchNorm1d(out_ch), + nn.ReLU(), + )) + in_ch = out_ch + + # Temporal pooling (attention-based) + self.attention = nn.Sequential( + nn.Linear(hidden_channels, hidden_channels), + nn.Tanh(), + nn.Linear(hidden_channels, 1), + ) + + # Speaker embedding projection + self.speaker_proj = nn.Sequential( + nn.Linear(hidden_channels, speaker_dim), + nn.BatchNorm1d(speaker_dim), + ) + + # Storage for pre-computed embeddings + self._cached_embedding: Optional[torch.Tensor] = None + + def forward(self, mel: torch.Tensor) -> torch.Tensor: + """ + Args: + mel: [B, mel_channels, T] mel spectrogram + + Returns: + [B, speaker_dim] speaker embedding + """ + # Frame-level processing + x = mel + for layer in self.frame_layers: + x = layer(x) + + # x: [B, hidden_channels, T] + x = x.transpose(1, 2) # [B, T, hidden_channels] + + # Attention-based pooling + attn_weights = F.softmax(self.attention(x), dim=1) # [B, T, 1] + x = (x * attn_weights).sum(dim=1) # [B, hidden_channels] + + # Project to speaker embedding + speaker_emb = self.speaker_proj(x) + + # L2 normalize + speaker_emb = F.normalize(speaker_emb, p=2, dim=-1) + + return speaker_emb + + def set_target_speaker(self, embedding: torch.Tensor): + """Set a pre-computed target speaker embedding.""" + self._cached_embedding = F.normalize(embedding, p=2, dim=-1) + + def get_target_speaker(self) -> Optional[torch.Tensor]: + """Get the cached target speaker embedding.""" + return self._cached_embedding + + def load_speaker_embedding(self, path: str, device: torch.device): + """Load a pre-computed speaker embedding from file.""" + embedding = torch.load(path, map_location=device) + if isinstance(embedding, dict): + embedding = embedding.get('speaker_embedding', embedding.get('embedding')) + self.set_target_speaker(embedding) + return embedding + + +class StreamingDecoder(StreamingContainer): + """Streaming decoder with FiLM conditioning for voice synthesis. + + Takes content embeddings, F0, and speaker embedding to synthesize + audio in the target voice. + + Architecture mirrors the content encoder with FiLM layers for + speaker conditioning at each resolution. 
+ """ + + def __init__( + self, + content_dim: int = 64, + f0_dim: int = 1, + speaker_dim: int = 256, + channels: int = 512, + out_channels: int = 1, + num_residual: int = 3, + strides: List[int] = [6, 5, 4, 2], # Reverse of encoder + ): + super().__init__() + + self.strides = strides + + # Input projection combining content and F0 + self.input_proj = StreamingConv1d( + content_dim + f0_dim, channels, kernel_size=7, causal=True, norm="weight_norm" + ) + self.input_activation = nn.ELU(alpha=1.0) + + # FiLM conditioning for input + self.input_film = FiLMLayer(channels, speaker_dim) + + # Decoder blocks with FiLM conditioning + self.decoder_blocks = nn.ModuleList() + self.film_layers = nn.ModuleList() + + in_ch = channels + for i, stride in enumerate(strides): + out_ch = max(channels // (2 ** (i + 1)), 64) + self.decoder_blocks.append( + CausalDecoderBlock(in_ch, out_ch, stride=stride, num_residual=num_residual) + ) + self.film_layers.append(FiLMLayer(out_ch, speaker_dim)) + in_ch = out_ch + + # Output convolution + self.output_conv = StreamingConv1d( + in_ch, out_channels, kernel_size=7, causal=True + ) + self.output_activation = nn.Tanh() + + def forward( + self, + content: torch.Tensor, + f0: torch.Tensor, + speaker_emb: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + content: [B, content_dim, T] content embeddings + f0: [B, 1, T] whitened F0 contour + speaker_emb: [B, speaker_dim] target speaker embedding + + Returns: + [B, 1, T'] synthesized audio waveform + """ + # Concatenate content and F0 + x = torch.cat([content, f0], dim=1) + + # Input projection with FiLM + x = self.input_proj(x) + x = self.input_activation(x) + x = self.input_film(x, speaker_emb) + + # Decoder blocks with FiLM conditioning + for block, film in zip(self.decoder_blocks, self.film_layers): + x = block(x) + x = film(x, speaker_emb) + + # Output + x = self.output_conv(x) + x = self.output_activation(x) + + return x + + +class StreamVCModel(StreamingModule[VoiceConversionState]): + """Complete StreamVC-style voice conversion model. + + End-to-end model for real-time voice conversion: + Input: Source audio waveform + Output: Audio waveform in target speaker's voice + + Components: + 1. Content Encoder: Extracts speaker-independent content + 2. F0 Extractor: Extracts and whitens pitch contour + 3. Speaker Encoder: Extracts/stores target speaker identity + 4. 
Decoder: Synthesizes audio with speaker conditioning + + Latency: ~70-100ms depending on configuration + """ + + def __init__( + self, + sample_rate: int = 24000, + channels: int = 64, + latent_dim: int = 64, + speaker_dim: int = 256, + strides: List[int] = [2, 4, 5, 6], + num_residual: int = 3, + ): + super().__init__() + + self.sample_rate = sample_rate + self.hop_length = math.prod(strides) # 240 for default strides + self.frame_size = self.hop_length # Minimum processing size + + # Components + self.content_encoder = ContentEncoder( + in_channels=1, + channels=channels, + latent_dim=latent_dim, + num_residual=num_residual, + strides=strides, + ) + + self.f0_extractor = F0Extractor( + in_channels=1, + hidden_channels=64, + out_channels=1, + hop_length=self.hop_length, + ) + + self.speaker_encoder = SpeakerEncoder( + in_channels=80, # Mel channels + hidden_channels=256, + speaker_dim=speaker_dim, + ) + + self.decoder = StreamingDecoder( + content_dim=latent_dim, + f0_dim=1, + speaker_dim=speaker_dim, + channels=512, + out_channels=1, + num_residual=num_residual, + strides=list(reversed(strides)), + ) + + # Mel spectrogram for speaker encoder + self.register_buffer('mel_basis', self._create_mel_basis()) + + def _create_mel_basis(self) -> torch.Tensor: + """Create mel filterbank.""" + import numpy as np + n_fft = 1024 + n_mels = 80 + fmin = 0 + fmax = self.sample_rate // 2 + + # Create mel filterbank using librosa-style computation + mel_low = self._hz_to_mel(fmin) + mel_high = self._hz_to_mel(fmax) + mel_points = torch.linspace(mel_low, mel_high, n_mels + 2) + hz_points = self._mel_to_hz(mel_points) + bin_points = torch.floor((n_fft + 1) * hz_points / self.sample_rate).long() + + mel_basis = torch.zeros(n_mels, n_fft // 2 + 1) + for i in range(n_mels): + for j in range(int(bin_points[i]), int(bin_points[i + 1])): + mel_basis[i, j] = (j - bin_points[i]) / (bin_points[i + 1] - bin_points[i]) + for j in range(int(bin_points[i + 1]), int(bin_points[i + 2])): + mel_basis[i, j] = (bin_points[i + 2] - j) / (bin_points[i + 2] - bin_points[i + 1]) + + return mel_basis + + @staticmethod + def _hz_to_mel(hz: torch.Tensor) -> torch.Tensor: + return 2595 * torch.log10(1 + hz / 700) + + @staticmethod + def _mel_to_hz(mel: torch.Tensor) -> torch.Tensor: + return 700 * (10 ** (mel / 2595) - 1) + + def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor: + """Compute mel spectrogram from audio.""" + # Simple STFT-based mel computation + n_fft = 1024 + hop = 256 + + # Pad for STFT + audio = F.pad(audio, (n_fft // 2, n_fft // 2), mode='reflect') + + # STFT + window = torch.hann_window(n_fft, device=audio.device) + spec = torch.stft( + audio.squeeze(1), + n_fft=n_fft, + hop_length=hop, + win_length=n_fft, + window=window, + return_complex=True, + ) + mag = spec.abs() + + # Apply mel filterbank + mel = torch.matmul(self.mel_basis.to(audio.device), mag) + mel = torch.log(mel.clamp(min=1e-5)) + + return mel + + def _init_streaming_state(self, batch_size: int) -> VoiceConversionState: + """Initialize streaming state.""" + device = next(self.parameters()).device + return VoiceConversionState( + content_buffer=torch.zeros(batch_size, 1, self.frame_size, device=device), + f0_buffer=torch.zeros(batch_size, 1, self.frame_size, device=device), + output_buffer=torch.zeros(batch_size, 1, self.frame_size, device=device), + ) + + def set_target_speaker_from_audio(self, reference_audio: torch.Tensor): + """Set target speaker from a reference audio clip. 
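A worked example of the HTK-style mel conversions defined by _hz_to_mel / _mel_to_hz above (illustrative only):

import torch
hz = torch.tensor([0.0, 1000.0, 12000.0])
mel = 2595 * torch.log10(1 + hz / 700)        # roughly [0, 1000, 3266] mel
back = 700 * (10 ** (mel / 2595) - 1)         # round-trips to [0, 1000, 12000] Hz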
+ + Args: + reference_audio: [1, 1, T] reference audio for target speaker + """ + mel = self._compute_mel(reference_audio) + speaker_emb = self.speaker_encoder(mel) + self.speaker_encoder.set_target_speaker(speaker_emb) + + def set_target_speaker_from_embedding(self, embedding: torch.Tensor): + """Set target speaker from a pre-computed embedding. + + Args: + embedding: [1, speaker_dim] or [speaker_dim] speaker embedding + """ + if embedding.dim() == 1: + embedding = embedding.unsqueeze(0) + self.speaker_encoder.set_target_speaker(embedding) + + def load_target_speaker(self, path: str): + """Load target speaker embedding from file.""" + device = next(self.parameters()).device + self.speaker_encoder.load_speaker_embedding(path, device) + + def forward( + self, + audio: torch.Tensor, + speaker_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Convert audio to target speaker's voice. + + Args: + audio: [B, 1, T] source audio waveform + speaker_emb: [B, speaker_dim] target speaker embedding. + If None, uses cached embedding from set_target_speaker_* + + Returns: + [B, 1, T] converted audio waveform + """ + # Get target speaker embedding + if speaker_emb is None: + speaker_emb = self.speaker_encoder.get_target_speaker() + if speaker_emb is None: + raise ValueError("No target speaker set. Call set_target_speaker_* first.") + # Expand to batch size + speaker_emb = speaker_emb.expand(audio.shape[0], -1) + + # Extract content (speaker-independent) + content = self.content_encoder(audio) + + # Extract and whiten F0 (prosody without speaker identity) + f0 = self.f0_extractor(audio) + + # Synthesize in target voice + converted = self.decoder(content, f0, speaker_emb) + + # Trim to match input length (accounting for any padding) + if converted.shape[-1] > audio.shape[-1]: + converted = converted[..., :audio.shape[-1]] + elif converted.shape[-1] < audio.shape[-1]: + converted = F.pad(converted, (0, audio.shape[-1] - converted.shape[-1])) + + return converted + + def convert_streaming( + self, + audio_chunk: torch.Tensor, + speaker_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Convert a single audio chunk in streaming mode. + + Args: + audio_chunk: [B, 1, frame_size] audio chunk + speaker_emb: Optional target speaker embedding + + Returns: + [B, 1, frame_size] converted audio chunk + """ + return self.forward(audio_chunk, speaker_emb) + + +def create_voice_converter( + sample_rate: int = 24000, + device: str = "cuda", + pretrained: bool = False, +) -> StreamVCModel: + """Factory function to create a voice conversion model. + + Args: + sample_rate: Audio sample rate (default: 24000 to match Mimi) + device: Device to load model on + pretrained: Whether to load pretrained weights (if available) + + Returns: + StreamVCModel instance + """ + model = StreamVCModel( + sample_rate=sample_rate, + channels=64, + latent_dim=64, + speaker_dim=256, + strides=[2, 4, 5, 6], + num_residual=3, + ) + + model = model.to(device) + + if pretrained: + # TODO: Load pretrained weights when available + pass + + return model diff --git a/moshi/moshi/vc_server.py b/moshi/moshi/vc_server.py new file mode 100644 index 00000000..83d33bc0 --- /dev/null +++ b/moshi/moshi/vc_server.py @@ -0,0 +1,434 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: MIT +# +# Real-time Voice Conversion Server +# Provides WebSocket endpoint for streaming voice conversion +# +# Architecture based on StreamVC for low-latency (~70-100ms) voice conversion +# Uses the same WebSocket protocol as the main PersonaPlex server for compatibility + +import argparse +import asyncio +from dataclasses import dataclass +import os +from pathlib import Path +import secrets +import sys +import time +from typing import Literal, Optional + +import aiohttp +from aiohttp import web +from huggingface_hub import hf_hub_download +import numpy as np +import sphn +import torch + +from .client_utils import make_log +from .models import loaders +from .models.voice_conversion import StreamVCModel, create_voice_converter +from .utils.connection import create_ssl_context, get_lan_ip +from .utils.logging import setup_logger, ColorizedLog + + +logger = setup_logger(__name__) +DeviceString = Literal["cuda"] | Literal["cpu"] + + +def torch_auto_device(requested: Optional[DeviceString] = None) -> torch.device: + """Return a torch.device based on the requested string or availability.""" + if requested is not None: + return torch.device(requested) + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +@dataclass +class VCServerState: + """State for the voice conversion server.""" + + voice_converter: StreamVCModel + lock: asyncio.Lock + sample_rate: int + frame_size: int + device: torch.device + voice_embeddings_dir: Optional[str] + + def __init__( + self, + voice_converter: StreamVCModel, + device: torch.device, + voice_embeddings_dir: Optional[str] = None, + ): + self.voice_converter = voice_converter + self.device = device + self.voice_embeddings_dir = voice_embeddings_dir + self.sample_rate = voice_converter.sample_rate + self.frame_size = voice_converter.hop_length # 240 samples at 24kHz = 10ms + self.lock = asyncio.Lock() + + # Put model in streaming mode + self.voice_converter.eval() + + def warmup(self): + """Warm up the model with dummy data.""" + logger.info("warming up voice converter...") + for _ in range(4): + chunk = torch.zeros(1, 1, self.frame_size * 8, dtype=torch.float32, device=self.device) + with torch.no_grad(): + _ = self.voice_converter(chunk) + + if self.device.type == 'cuda': + torch.cuda.synchronize() + logger.info("voice converter warmed up") + + def load_voice_embedding(self, voice_name: str) -> bool: + """Load a voice embedding by name from the embeddings directory.""" + if self.voice_embeddings_dir is None: + return False + + # Try different file extensions + for ext in ['.pt', '.pth', '.bin']: + path = os.path.join(self.voice_embeddings_dir, f"{voice_name}{ext}") + if os.path.exists(path): + try: + self.voice_converter.load_target_speaker(path) + logger.info(f"loaded voice embedding: {path}") + return True + except Exception as e: + logger.error(f"failed to load voice embedding {path}: {e}") + + return False + + def set_voice_from_audio(self, audio: np.ndarray): + """Set target voice from reference audio.""" + audio_tensor = torch.from_numpy(audio).float().to(self.device) + if audio_tensor.dim() == 1: + audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0) + elif audio_tensor.dim() == 2: + audio_tensor = audio_tensor.unsqueeze(0) + + with torch.no_grad(): + self.voice_converter.set_target_speaker_from_audio(audio_tensor) + + async def handle_voice_convert(self, request): + """ + WebSocket handler for real-time voice conversion. 
+ + Protocol: + - 0x00: Handshake (server -> client) + - 0x01: Audio data (bidirectional, Opus encoded) + - 0x02: Text message (reserved for future use) + - 0x03: Control message + - 0x04: Metadata (JSON) + - 0x05: Error + + Query parameters: + - voice: Name of target voice embedding to use + - reference_mode: If "true", first audio chunk is used as voice reference + """ + ws = web.WebSocketResponse() + await ws.prepare(request) + clog = ColorizedLog.randomize() + peer = request.remote + peer_port = request.transport.get_extra_info("peername")[1] + clog.log("info", f"Voice conversion connection from {peer}:{peer_port}") + + # Get parameters + voice_name = request.query.get("voice", None) + reference_mode = request.query.get("reference_mode", "false").lower() == "true" + + # Load target voice embedding if specified + if voice_name and not reference_mode: + if not self.load_voice_embedding(voice_name): + clog.log("warning", f"voice embedding '{voice_name}' not found, using default") + + close = False + reference_audio_buffer = [] + reference_collected = not reference_mode # Skip reference collection if not in reference mode + REFERENCE_DURATION_SAMPLES = self.sample_rate * 3 # 3 seconds of reference audio + + async def recv_loop(): + nonlocal close, reference_collected + try: + async for message in ws: + if message.type == aiohttp.WSMsgType.ERROR: + clog.log("error", f"{ws.exception()}") + break + elif message.type == aiohttp.WSMsgType.CLOSED: + break + elif message.type == aiohttp.WSMsgType.CLOSE: + break + elif message.type != aiohttp.WSMsgType.BINARY: + clog.log("error", f"unexpected message type {message.type}") + continue + + message_data = message.data + if not isinstance(message_data, bytes): + clog.log("error", f"unsupported message type {type(message_data)}") + continue + if len(message_data) == 0: + clog.log("warning", "empty message") + continue + + kind = message_data[0] + if kind == 1: # Audio + payload = message_data[1:] + opus_reader.append_bytes(payload) + elif kind == 3: # Control + control = message_data[1:].decode('utf-8') if len(message_data) > 1 else "" + if control == "end_reference": + # Signal that reference audio collection is complete + if not reference_collected and len(reference_audio_buffer) > 0: + ref_audio = np.concatenate(reference_audio_buffer) + self.set_voice_from_audio(ref_audio) + reference_collected = True + clog.log("info", f"reference audio collected: {len(ref_audio)} samples") + # Send acknowledgment + await ws.send_bytes(b"\x03reference_ready") + else: + clog.log("warning", f"unknown message kind {kind}") + finally: + close = True + clog.log("info", "recv loop closed") + + async def process_loop(): + nonlocal reference_collected + all_pcm_data = None + + while True: + if close: + return + await asyncio.sleep(0.001) + + pcm = opus_reader.read_pcm() + if pcm.shape[-1] == 0: + continue + + # In reference mode, collect audio until we have enough + if not reference_collected: + reference_audio_buffer.append(pcm.copy()) + total_samples = sum(len(buf) for buf in reference_audio_buffer) + if total_samples >= REFERENCE_DURATION_SAMPLES: + ref_audio = np.concatenate(reference_audio_buffer) + self.set_voice_from_audio(ref_audio) + reference_collected = True + clog.log("info", f"auto-collected reference audio: {len(ref_audio)} samples") + await ws.send_bytes(b"\x03reference_ready") + continue + + # Normal processing + if all_pcm_data is None: + all_pcm_data = pcm + else: + all_pcm_data = np.concatenate((all_pcm_data, pcm)) + + # Process when we have enough data 
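An illustrative client-side view of the byte protocol documented above (not part of the patch): every WebSocket frame starts with one type byte, mirroring recv_loop's parsing here and sendReferenceAudio in the hook.

def frame(kind: int, payload: bytes = b"") -> bytes:
    return bytes([kind]) + payload

handshake = frame(0x00)                        # server -> client on connect
audio_msg = frame(0x01, b"...opus packet...")  # Opus audio, either direction
end_ref = frame(0x03, b"end_reference")        # control: finish reference collection
kind, payload = audio_msg[0], audio_msg[1:]    # how recv_loop splits a frame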
+ # Use larger chunks for efficiency (80ms = 1920 samples at 24kHz) + process_size = self.frame_size * 8 # ~80ms + + while all_pcm_data.shape[-1] >= process_size: + chunk = all_pcm_data[:process_size] + all_pcm_data = all_pcm_data[process_size:] + + # Convert to tensor + chunk_tensor = torch.from_numpy(chunk).float().to(self.device) + chunk_tensor = chunk_tensor.unsqueeze(0).unsqueeze(0) # [1, 1, T] + + # Voice conversion + with torch.no_grad(): + converted = self.voice_converter(chunk_tensor) + + # Convert back to numpy + converted_pcm = converted[0, 0].cpu().numpy() + + # Send to opus writer + opus_writer.append_pcm(converted_pcm) + + async def send_loop(): + while True: + if close: + return + await asyncio.sleep(0.001) + msg = opus_writer.read_bytes() + if len(msg) > 0: + await ws.send_bytes(b"\x01" + msg) + + clog.log("info", "accepted voice conversion connection") + if voice_name: + clog.log("info", f"target voice: {voice_name}") + if reference_mode: + clog.log("info", "reference mode enabled - waiting for reference audio") + + async with self.lock: + opus_writer = sphn.OpusStreamWriter(self.sample_rate) + opus_reader = sphn.OpusStreamReader(self.sample_rate) + + # Send handshake + await ws.send_bytes(b"\x00") + clog.log("info", "sent handshake") + + if reference_mode: + # Send message indicating we're waiting for reference + await ws.send_bytes(b"\x03awaiting_reference") + + # Run processing loops + tasks = [ + asyncio.create_task(recv_loop()), + asyncio.create_task(process_loop()), + asyncio.create_task(send_loop()), + ] + + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + # Clean up + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + await ws.close() + clog.log("info", "voice conversion session closed") + + return ws + + async def handle_set_voice(self, request): + """ + HTTP endpoint to set target voice from uploaded audio. + + POST /api/vc/set_voice + Body: Raw audio bytes (WAV or raw PCM) + """ + clog = ColorizedLog.randomize() + + try: + data = await request.read() + + # Try to parse as WAV + try: + audio, sr = sphn.read_buffer(data) + if sr != self.sample_rate: + audio = sphn.resample(audio, sr, self.sample_rate) + except Exception: + # Assume raw PCM float32 + audio = np.frombuffer(data, dtype=np.float32) + + self.set_voice_from_audio(audio) + clog.log("info", f"set target voice from uploaded audio: {len(audio)} samples") + + return web.json_response({"status": "ok", "samples": len(audio)}) + except Exception as e: + clog.log("error", f"failed to set voice: {e}") + return web.json_response({"status": "error", "message": str(e)}, status=400) + + async def handle_list_voices(self, request): + """ + HTTP endpoint to list available voice embeddings. 
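Worked numbers for the processing chunk size used above (illustrative):

sample_rate = 24000
frame_size = 240                                # hop_length: 10 ms at 24 kHz
process_size = frame_size * 8                   # 1920 samples
chunk_ms = 1000 * process_size / sample_rate    # 80.0 ms per conversion call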
+ + GET /api/vc/voices + """ + voices = [] + if self.voice_embeddings_dir and os.path.exists(self.voice_embeddings_dir): + for f in os.listdir(self.voice_embeddings_dir): + if f.endswith(('.pt', '.pth', '.bin')): + voice_name = os.path.splitext(f)[0] + voices.append(voice_name) + + return web.json_response({"voices": sorted(voices)}) + + +def main(): + parser = argparse.ArgumentParser(description="Real-time Voice Conversion Server") + parser.add_argument("--host", default="localhost", type=str) + parser.add_argument("--port", default=8999, type=int) + parser.add_argument("--device", type=str, default="cuda", + help="Device on which to run, defaults to 'cuda'.") + parser.add_argument("--voice-embeddings-dir", type=str, + help="Directory containing pre-computed voice embeddings (.pt files)") + parser.add_argument("--ssl", type=str, + help="Directory containing key.pem and cert.pem for HTTPS") + parser.add_argument("--static", type=str, + help="Path to static files to serve") + parser.add_argument("--gradio-tunnel", action='store_true', + help='Activate a gradio tunnel for remote access.') + parser.add_argument("--gradio-tunnel-token", type=str, + help='Custom token for consistent tunnel URL.') + + args = parser.parse_args() + args.device = torch_auto_device(args.device) + + # Setup tunnel if requested + setup_tunnel = None + tunnel_token = '' + if args.gradio_tunnel: + try: + from gradio import networking + except ImportError: + logger.error("Cannot find gradio. Install with: pip install gradio") + sys.exit(1) + setup_tunnel = networking.setup_tunnel + tunnel_token = args.gradio_tunnel_token or secrets.token_urlsafe(32) + + # Create voice converter + logger.info("initializing voice converter...") + voice_converter = create_voice_converter( + sample_rate=24000, + device=str(args.device), + pretrained=False, # No pretrained weights yet + ) + logger.info("voice converter initialized") + + # Create server state + state = VCServerState( + voice_converter=voice_converter, + device=args.device, + voice_embeddings_dir=args.voice_embeddings_dir, + ) + + # Warm up + state.warmup() + + # Create app + app = web.Application() + + # Voice conversion WebSocket endpoint + app.router.add_get("/api/vc", state.handle_voice_convert) + + # HTTP endpoints for voice management + app.router.add_post("/api/vc/set_voice", state.handle_set_voice) + app.router.add_get("/api/vc/voices", state.handle_list_voices) + + # Static files + if args.static and os.path.exists(args.static): + async def handle_root(_): + return web.FileResponse(os.path.join(args.static, "index.html")) + + app.router.add_get("/", handle_root) + app.router.add_static("/", path=args.static, follow_symlinks=True, name="static") + logger.info(f"serving static files from {args.static}") + + # SSL setup + protocol = "http" + ssl_context = None + if args.ssl: + ssl_context, protocol = create_ssl_context(args.ssl) + + # Log access info + host_ip = args.host if args.host not in ("0.0.0.0", "::", "localhost") else get_lan_ip() + logger.info(f"Voice Conversion Server running at {protocol}://{host_ip}:{args.port}") + logger.info(f"WebSocket endpoint: {protocol}://{host_ip}:{args.port}/api/vc") + + if setup_tunnel: + tunnel = setup_tunnel('localhost', args.port, tunnel_token, None) + logger.info(f"Tunnel available at: {tunnel}") + + # Run server + web.run_app(app, host=args.host, port=args.port, ssl_context=ssl_context) + + +if __name__ == "__main__": + with torch.no_grad(): + main()
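An end-to-end usage sketch of the model API introduced in this diff (illustrative; with pretrained=False the weights are random, so the output is not meaningful audio, it only demonstrates the call sequence):

import torch
from moshi.models.voice_conversion import create_voice_converter

model = create_voice_converter(sample_rate=24000, device="cpu", pretrained=False)
model.eval()  # BatchNorm layers in SpeakerEncoder need eval mode for batch size 1

reference = torch.randn(1, 1, 24000 * 3)        # 3 s of hypothetical target-speaker audio
model.set_target_speaker_from_audio(reference)

source = torch.randn(1, 1, 1920)                # one 80 ms chunk of source audio
with torch.no_grad():
    converted = model.convert_streaming(source)
print(converted.shape)                           # torch.Size([1, 1, 1920])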