diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9e65f520..dcb78521 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -22,10 +22,10 @@ jobs: args: "--target x86_64-apple-darwin" python-version: "3.12" backend: "pytorch" - # - platform: 'ubuntu-22.04' - # args: '' - # python-version: '3.12' - # backend: 'pytorch' + - platform: "ubuntu-22.04" + args: "" + python-version: "3.12" + backend: "pytorch" - platform: "windows-latest" args: "" python-version: "3.12" diff --git a/backend/backends/pytorch_backend.py b/backend/backends/pytorch_backend.py index 8059dc09..729053b4 100644 --- a/backend/backends/pytorch_backend.py +++ b/backend/backends/pytorch_backend.py @@ -374,6 +374,7 @@ def _generate_sync(): "small": "openai/whisper-small", "medium": "openai/whisper-medium", "large": "openai/whisper-large-v3", + "turbo": "openai/whisper-large-v3-turbo", } @@ -591,21 +592,20 @@ def _transcribe_sync(): ) inputs = inputs.to(self.device) - # Set language if provided - forced_decoder_ids = None + # Generate transcription + # If language is provided, force it; otherwise let Whisper auto-detect + generate_kwargs = {} if language: - # Support all languages from frontend: en, zh, ja, ko, de, fr, ru, pt, es, it - # Whisper supports these and many more forced_decoder_ids = self.processor.get_decoder_prompt_ids( language=language, task="transcribe", ) + generate_kwargs["forced_decoder_ids"] = forced_decoder_ids - # Generate transcription with torch.no_grad(): predicted_ids = self.model.generate( inputs["input_features"], - forced_decoder_ids=forced_decoder_ids, + **generate_kwargs, ) # Decode diff --git a/backend/main.py b/backend/main.py index 3d2ec359..04c9190f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -14,7 +14,6 @@ import asyncio import uvicorn import argparse -import torch import tempfile import io from pathlib import Path @@ -22,6 +21,18 @@ import asyncio import signal import os + +# Set 
HSA_OVERRIDE_GFX_VERSION for AMD GPUs that aren't officially listed in ROCm +# (e.g., RX 6600 is gfx1032 which maps to gfx1030 target) +# This must be set BEFORE any torch.cuda calls +if not os.environ.get("HSA_OVERRIDE_GFX_VERSION"): + os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0" + +# Suppress noisy MIOpen workspace warnings on AMD GPUs +if not os.environ.get("MIOPEN_LOG_LEVEL"): + os.environ["MIOPEN_LOG_LEVEL"] = "4" + +import torch from urllib.parse import quote @@ -1022,11 +1033,12 @@ async def transcribe_audio( # Transcribe whisper_model = transcribe.get_whisper_model() - # Check if Whisper model is downloaded (uses default size "base") + # Check if Whisper model is downloaded model_size = whisper_model.model_size - # Map model sizes to HF repo IDs (whisper-large needs -v3 suffix) + # Map model sizes to HF repo IDs (some need special suffixes) whisper_hf_repos = { "large": "openai/whisper-large-v3", + "turbo": "openai/whisper-large-v3-turbo", } model_name = whisper_hf_repos.get(model_size, f"openai/whisper-{model_size}") @@ -1490,6 +1502,13 @@ def check_chatterbox_loaded(): "model_size": "large", "check_loaded": lambda: check_whisper_loaded("large"), }, + { + "model_name": "whisper-turbo", + "display_name": "Whisper Turbo", + "hf_repo_id": "openai/whisper-large-v3-turbo", + "model_size": "turbo", + "check_loaded": lambda: check_whisper_loaded("turbo"), + }, ] # Build a mapping of model_name -> hf_repo_id so we can check if shared repos are downloading @@ -1684,6 +1703,10 @@ async def trigger_model_download(request: models.ModelDownloadRequest): "model_size": "large", "load_func": lambda: transcribe.get_whisper_model().load_model("large"), }, + "whisper-turbo": { + "model_size": "turbo", + "load_func": lambda: transcribe.get_whisper_model().load_model("turbo"), + }, } if request.model_name not in model_configs: @@ -1810,6 +1833,11 @@ async def delete_model(model_name: str): "model_size": "large", "model_type": "whisper", }, + "whisper-turbo": { + 
"hf_repo_id": "openai/whisper-large-v3-turbo", + "model_size": "turbo", + "model_type": "whisper", + }, } if model_name not in model_configs: @@ -2044,7 +2072,12 @@ def _get_gpu_status() -> str: """Get GPU availability status.""" backend_type = get_backend_type() if torch.cuda.is_available(): - return f"CUDA ({torch.cuda.get_device_name(0)})" + device_name = torch.cuda.get_device_name(0) + # Check if this is ROCm (AMD) or CUDA (NVIDIA) + is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None + if is_rocm: + return f"ROCm ({device_name})" + return f"CUDA ({device_name})" elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): return "MPS (Apple Silicon)" elif backend_type == "mlx": diff --git a/tauri/src-tauri/src/audio_capture/linux.rs b/tauri/src-tauri/src/audio_capture/linux.rs index 8af26e97..3cae59e9 100644 --- a/tauri/src-tauri/src/audio_capture/linux.rs +++ b/tauri/src-tauri/src/audio_capture/linux.rs @@ -1,16 +1,312 @@ use crate::audio_capture::AudioCaptureState; +use base64::{engine::general_purpose, Engine as _}; +use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; +use cpal::{SampleFormat, StreamConfig}; +use hound::{WavSpec, WavWriter}; +use std::io::Cursor; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::thread; +/// Start capturing system audio on Linux using PulseAudio monitor sources. +/// +/// PulseAudio exposes "monitor" devices that mirror the output of each sink, +/// allowing us to capture whatever audio is currently playing on the system. +/// We use `cpal` with the default host (which will be PulseAudio or PipeWire +/// on modern Linux) and look for monitor input devices. 
pub async fn start_capture( state: &AudioCaptureState, max_duration_secs: u32, ) -> Result<(), String> { - todo!("implement Linux audio capture") + // Reset previous samples + state.reset(); + + let samples = state.samples.clone(); + let sample_rate_arc = state.sample_rate.clone(); + let channels_arc = state.channels.clone(); + let stop_tx = state.stop_tx.clone(); + let error_arc = state.error.clone(); + + // Use AtomicBool for stop signal (works across threads) + let stop_flag = Arc::new(AtomicBool::new(false)); + let stop_flag_clone = stop_flag.clone(); + + // Create tokio channel and spawn a task to bridge it to the AtomicBool + let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(1); + *stop_tx.lock().unwrap() = Some(tx); + + tokio::spawn(async move { + rx.recv().await; + stop_flag_clone.store(true, Ordering::Relaxed); + }); + + // Spawn capture on a dedicated thread + thread::spawn(move || { + let host = cpal::default_host(); + + // Try to find a monitor device for system audio capture. + // On PulseAudio/PipeWire, monitor sources have "monitor" in their name. 
+ let device = { + let mut monitor_device = None; + + if let Ok(devices) = host.input_devices() { + for d in devices { + if let Ok(name) = d.name() { + let name_lower = name.to_lowercase(); + if name_lower.contains("monitor") { + eprintln!("Linux audio capture: Found monitor device: {}", name); + monitor_device = Some(d); + break; + } + } + } + } + + match monitor_device { + Some(d) => d, + None => { + // Fallback to default input device (microphone) + eprintln!("Linux audio capture: No monitor device found, falling back to default input"); + match host.default_input_device() { + Some(d) => d, + None => { + let error_msg = "No audio input device available".to_string(); + eprintln!("{}", error_msg); + *error_arc.lock().unwrap() = Some(error_msg); + return; + } + } + } + } + }; + + let device_name = device.name().unwrap_or_else(|_| "unknown".to_string()); + eprintln!("Linux audio capture: Using device: {}", device_name); + + // Get supported config + let config = match device.default_input_config() { + Ok(c) => c, + Err(e) => { + let error_msg = format!("Failed to get default input config: {}", e); + eprintln!("{}", error_msg); + *error_arc.lock().unwrap() = Some(error_msg); + return; + } + }; + + let sample_rate = config.sample_rate().0; + let channels = config.channels(); + let sample_format = config.sample_format(); + + eprintln!( + "Linux audio capture: Config - {}Hz, {} channels, format: {:?}", + sample_rate, channels, sample_format + ); + + *sample_rate_arc.lock().unwrap() = sample_rate; + *channels_arc.lock().unwrap() = channels; + + let stream_config = StreamConfig { + channels, + sample_rate: cpal::SampleRate(sample_rate), + buffer_size: cpal::BufferSize::Default, + }; + + let samples_clone = samples.clone(); + let error_arc_clone = error_arc.clone(); + let stop_flag_for_stream = stop_flag.clone(); + + let err_fn = { + let error_arc = error_arc.clone(); + move |err: cpal::StreamError| { + let error_msg = format!("Stream error: {}", err); + eprintln!("{}", 
error_msg); + *error_arc.lock().unwrap() = Some(error_msg); + } + }; + + let stream = match sample_format { + SampleFormat::F32 => { + let samples = samples_clone.clone(); + let stop = stop_flag_for_stream.clone(); + device.build_input_stream( + &stream_config, + move |data: &[f32], _: &cpal::InputCallbackInfo| { + if stop.load(Ordering::Relaxed) { + return; + } + let mut guard = samples.lock().unwrap(); + guard.extend_from_slice(data); + }, + err_fn, + None, + ) + } + SampleFormat::I16 => { + let samples = samples_clone.clone(); + let stop = stop_flag_for_stream.clone(); + device.build_input_stream( + &stream_config, + move |data: &[i16], _: &cpal::InputCallbackInfo| { + if stop.load(Ordering::Relaxed) { + return; + } + let mut guard = samples.lock().unwrap(); + for &s in data { + guard.push(s as f32 / 32768.0); + } + }, + err_fn, + None, + ) + } + SampleFormat::U16 => { + let samples = samples_clone.clone(); + let stop = stop_flag_for_stream.clone(); + device.build_input_stream( + &stream_config, + move |data: &[u16], _: &cpal::InputCallbackInfo| { + if stop.load(Ordering::Relaxed) { + return; + } + let mut guard = samples.lock().unwrap(); + for &s in data { + guard.push((s as f32 / 32768.0) - 1.0); + } + }, + err_fn, + None, + ) + } + _ => { + let error_msg = format!("Unsupported sample format: {:?}", sample_format); + eprintln!("{}", error_msg); + *error_arc_clone.lock().unwrap() = Some(error_msg); + return; + } + }; + + let stream = match stream { + Ok(s) => s, + Err(e) => { + let error_msg = format!("Failed to build input stream: {}", e); + eprintln!("{}", error_msg); + *error_arc_clone.lock().unwrap() = Some(error_msg); + return; + } + }; + + if let Err(e) = stream.play() { + let error_msg = format!("Failed to start stream: {}", e); + eprintln!("{}", error_msg); + *error_arc_clone.lock().unwrap() = Some(error_msg); + return; + } + + eprintln!("Linux audio capture: Stream started successfully"); + + // Keep thread alive until stop signal + loop { + if 
stop_flag.load(Ordering::Relaxed) { + break; + } + std::thread::sleep(std::time::Duration::from_millis(100)); + } + + // Stream will be dropped here, stopping capture + eprintln!("Linux audio capture: Stream stopped"); + }); + + // Spawn timeout task + let stop_tx_clone = state.stop_tx.clone(); + tokio::spawn(async move { + tokio::time::sleep(tokio::time::Duration::from_secs(max_duration_secs as u64)).await; + let tx = stop_tx_clone.lock().unwrap().take(); + if let Some(tx) = tx { + let _ = tx.send(()).await; + } + }); + + Ok(()) } pub async fn stop_capture(state: &AudioCaptureState) -> Result<String, String> { - todo!("implement Linux audio capture stop") + // Signal stop + if let Some(tx) = state.stop_tx.lock().unwrap().take() { + let _ = tx.try_send(()); + } + + // Wait a bit for capture to stop + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + + // Check if there was an error during capture + if let Some(error) = state.error.lock().unwrap().as_ref() { + return Err(error.clone()); + } + + // Get samples + let samples = state.samples.lock().unwrap().clone(); + let sample_rate = *state.sample_rate.lock().unwrap(); + let channels = *state.channels.lock().unwrap(); + + if samples.is_empty() { + return Err( + "No audio samples captured. Make sure audio is playing on your system during recording."
+ .to_string(), + ); + } + + // Convert to WAV + let wav_data = samples_to_wav(&samples, sample_rate, channels)?; + + // Encode to base64 + let base64_data = general_purpose::STANDARD.encode(&wav_data); + + Ok(base64_data) } pub fn is_supported() -> bool { - false + // Check if we can find a monitor device for system audio capture + let host = cpal::default_host(); + if let Ok(devices) = host.input_devices() { + for d in devices { + if let Ok(name) = d.name() { + if name.to_lowercase().contains("monitor") { + return true; + } + } + } + } + // Even without a monitor, basic input capture is available + host.default_input_device().is_some() +} + +fn samples_to_wav(samples: &[f32], sample_rate: u32, channels: u16) -> Result<Vec<u8>, String> { + let mut buffer = Vec::new(); + let cursor = Cursor::new(&mut buffer); + + let spec = WavSpec { + channels, + sample_rate, + bits_per_sample: 16, + sample_format: hound::SampleFormat::Int, + }; + + let mut writer = + WavWriter::new(cursor, spec).map_err(|e| format!("Failed to create WAV writer: {}", e))?; + + // Convert f32 samples to i16 + for sample in samples { + let clamped = sample.clamp(-1.0, 1.0); + let i16_sample = (clamped * 32767.0) as i16; + writer + .write_sample(i16_sample) + .map_err(|e| format!("Failed to write sample: {}", e))?; + } + + writer + .finalize() + .map_err(|e| format!("Failed to finalize WAV: {}", e))?; + + Ok(buffer) } diff --git a/tauri/src-tauri/src/main.rs b/tauri/src-tauri/src/main.rs index 84a2f01a..157fee9e 100644 --- a/tauri/src-tauri/src/main.rs +++ b/tauri/src-tauri/src/main.rs @@ -792,7 +792,9 @@ pub fn run() { }); // Wait for frontend response or timeout - tokio::spawn(async move { + // Use tauri::async_runtime::spawn instead of tokio::spawn to avoid + // panics when the Tokio runtime is being dropped during app shutdown + tauri::async_runtime::spawn(async move { tokio::select! { _ = rx.recv() => { // Frontend responded, close window