diff --git a/app/src/App.tsx b/app/src/App.tsx index fbe29118..31b4309a 100644 --- a/app/src/App.tsx +++ b/app/src/App.tsx @@ -82,7 +82,6 @@ function App() { console.log('Dev mode: Skipping auto-start of server (run it separately)'); setServerReady(true); // Mark as ready so UI doesn't show loading screen // Mark that server was not started by app (so we don't try to stop it on close) - // @ts-expect-error - adding property to window window.__voiceboxServerStartedByApp = false; return; } @@ -103,13 +102,11 @@ function App() { useServerStore.getState().setServerUrl(serverUrl); setServerReady(true); // Mark that we started the server (so we know to stop it on close) - // @ts-expect-error - adding property to window window.__voiceboxServerStartedByApp = true; }) .catch((error) => { console.error('Failed to auto-start server:', error); serverStartingRef.current = false; - // @ts-expect-error - adding property to window window.__voiceboxServerStartedByApp = false; }); diff --git a/app/src/components/Generation/FloatingGenerateBox.tsx b/app/src/components/Generation/FloatingGenerateBox.tsx index a8d556a6..b3436638 100644 --- a/app/src/components/Generation/FloatingGenerateBox.tsx +++ b/app/src/components/Generation/FloatingGenerateBox.tsx @@ -13,7 +13,7 @@ import { } from '@/components/ui/select'; import { Textarea } from '@/components/ui/textarea'; import { useToast } from '@/components/ui/use-toast'; -import { LANGUAGE_OPTIONS } from '@/lib/constants/languages'; +import { LANGUAGE_OPTIONS, type LanguageCode } from '@/lib/constants/languages'; import { useGenerationForm } from '@/lib/hooks/useGenerationForm'; import { useProfile, useProfiles } from '@/lib/hooks/useProfiles'; import { useAddStoryItem, useStory } from '@/lib/hooks/useStories'; @@ -112,6 +112,13 @@ export function FloatingGenerateBox({ } }, [selectedProfileId, profiles, setSelectedProfileId]); + // Sync generation form language with selected profile's language + useEffect(() => { + if (selectedProfile?.language) { + form.setValue('language', selectedProfile.language as LanguageCode); + } + }, [selectedProfile, form]); + // Auto-resize textarea based on content (only when expanded) useEffect(() => { if (!isExpanded) { diff --git a/app/src/components/StoriesTab/StoryTrackEditor.tsx b/app/src/components/StoriesTab/StoryTrackEditor.tsx index 74dbde25..bc2e0401 100644 --- a/app/src/components/StoriesTab/StoryTrackEditor.tsx +++ b/app/src/components/StoriesTab/StoryTrackEditor.tsx @@ -313,7 +313,7 @@ export function StoryTrackEditor({ storyId, items }: StoryTrackEditorProps) { } }, [isResizing, handleResizeMove, handleResizeEnd]); - const handleTimelineClick = (e: React.MouseEvent) => { + const handleTimelineClick = (e: React.MouseEvent) => { if (!tracksRef.current || draggingItem || trimmingItem) return; const rect = tracksRef.current.getBoundingClientRect(); const x = e.clientX - rect.left + tracksRef.current.scrollLeft; diff --git a/app/src/components/VoiceProfiles/AudioSampleRecording.tsx b/app/src/components/VoiceProfiles/AudioSampleRecording.tsx index 4f2db4e3..3807306f 100644 --- a/app/src/components/VoiceProfiles/AudioSampleRecording.tsx +++ b/app/src/components/VoiceProfiles/AudioSampleRecording.tsx @@ -58,6 +58,7 @@ export function AudioSampleRecording({ // Request microphone access when component mounts useEffect(() => { if (!showWaveform) return; + if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) return; let stream: MediaStream | null = null; diff --git a/app/src/hooks/useAutoUpdater.tsx b/app/src/hooks/useAutoUpdater.tsx index 8a6351f6..27683e97 100644 --- a/app/src/hooks/useAutoUpdater.tsx +++ b/app/src/hooks/useAutoUpdater.tsx @@ -73,7 +73,7 @@ export function useAutoUpdater(options: boolean | UseAutoUpdaterOptions = false) } // Empty dependency array - only run once on mount // eslint-disable-next-line react-hooks/exhaustive-deps - }, [platform.metadata.isTauricheckOnMountcheckForUpdates]); + }, [platform.metadata.isTauri, checkOnMount, checkForUpdates]); // Show toast when update is available useEffect(() => { diff --git a/backend/backends/mlx_backend.py b/backend/backends/mlx_backend.py index c4ecc090..8651538a 100644 --- a/backend/backends/mlx_backend.py +++ b/backend/backends/mlx_backend.py @@ -14,6 +14,12 @@ from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback from ..utils.tasks import get_task_manager +LANGUAGE_CODE_TO_NAME = { + "zh": "chinese", "en": "english", "ja": "japanese", "ko": "korean", + "de": "german", "fr": "french", "ru": "russian", "pt": "portuguese", + "es": "spanish", "it": "italian", +} + class MLXTTSBackend: """MLX-based TTS backend using mlx-audio.""" @@ -316,7 +322,8 @@ def _generate_sync(): # MLX generate() returns a generator yielding GenerationResult objects audio_chunks = [] sample_rate = 24000 - + lang = LANGUAGE_CODE_TO_NAME.get(language, "auto") + # Set seed if provided (MLX uses numpy random) if seed is not None: import mlx.core as mx @@ -344,23 +351,23 @@ def _generate_sync(): sig = inspect.signature(self.model.generate) if "ref_audio" in sig.parameters: # Generate with voice cloning - for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text): + for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text, lang_code=lang): audio_chunks.append(np.array(result.audio)) sample_rate = result.sample_rate else: # Fallback: generate without voice cloning - for result in self.model.generate(text): + for result in self.model.generate(text, lang_code=lang): audio_chunks.append(np.array(result.audio)) sample_rate = result.sample_rate else: # No voice prompt, generate normally - for result in self.model.generate(text): + for result in self.model.generate(text, lang_code=lang): audio_chunks.append(np.array(result.audio)) sample_rate = result.sample_rate except Exception as e: # If voice cloning fails, try without it print(f"Warning: Voice cloning failed, generating without voice prompt: {e}") - for result in self.model.generate(text): + for result in self.model.generate(text, lang_code=lang): audio_chunks.append(np.array(result.audio)) sample_rate = result.sample_rate diff --git a/backend/backends/pytorch_backend.py b/backend/backends/pytorch_backend.py index 26f38726..1adeb22d 100644 --- a/backend/backends/pytorch_backend.py +++ b/backend/backends/pytorch_backend.py @@ -15,6 +15,12 @@ from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback from ..utils.tasks import get_task_manager +LANGUAGE_CODE_TO_NAME = { + "zh": "chinese", "en": "english", "ja": "japanese", "ko": "korean", + "de": "german", "fr": "french", "ru": "russian", "pt": "portuguese", + "es": "spanish", "it": "italian", +} + class PyTorchTTSBackend: """PyTorch-based TTS backend using Qwen3-TTS.""" @@ -335,6 +341,7 @@ def _generate_sync(): wavs, sample_rate = self.model.generate_voice_clone( text=text, voice_clone_prompt=voice_prompt, + language=LANGUAGE_CODE_TO_NAME.get(language, "auto"), instruct=instruct, ) return wavs[0], sample_rate diff --git a/backend/main.py b/backend/main.py index 59fb9e18..f23ae2a6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -36,10 +36,23 @@ version=__version__, ) -# CORS middleware +# CORS middleware - restrict to known local origins by default. +# Set VOICEBOX_CORS_ORIGINS env var to a comma-separated list of origins +# to allow additional origins (e.g. for remote server mode). +_default_origins = [ + "http://localhost:5173", # Vite dev server + "http://127.0.0.1:5173", + "http://localhost:17493", + "http://127.0.0.1:17493", + "tauri://localhost", # Tauri webview (macOS) + "https://tauri.localhost", # Tauri webview (Windows/Linux) +] +_env_origins = os.environ.get("VOICEBOX_CORS_ORIGINS", "") +_cors_origins = _default_origins + [o.strip() for o in _env_origins.split(",") if o.strip()] + app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Configure appropriately for production + allow_origins=_cors_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -542,12 +555,6 @@ async def generate_speech( if not profile: raise HTTPException(status_code=404, detail="Profile not found") - # Create voice prompt from profile - voice_prompt = await profiles.create_voice_prompt_for_profile( - data.profile_id, - db, - ) - # Generate audio tts_model = tts.get_tts_model() # Load the requested model size if different from current (async to not block) @@ -582,7 +589,15 @@ async def download_model_background(): } ) + # Load the requested model BEFORE creating voice prompt, + # so create_voice_prompt uses the correct model size await tts_model.load_model_async(model_size) + + # Create voice prompt from profile + voice_prompt = await profiles.create_voice_prompt_for_profile( + data.profile_id, + db, + ) audio, sample_rate = await tts_model.generate( data.text, voice_prompt, diff --git a/backend/tests/test_cors.py b/backend/tests/test_cors.py new file mode 100644 index 00000000..5762c74a --- /dev/null +++ b/backend/tests/test_cors.py @@ -0,0 +1,161 @@ +""" +Tests for CORS origin restrictions. + +Validates that the CORS middleware only allows known local origins +and respects the VOICEBOX_CORS_ORIGINS environment variable. + +Uses a minimal FastAPI app that mirrors the exact CORS configuration +from backend/main.py, so tests run without heavy ML dependencies. + +Usage: + pip install httpx pytest fastapi starlette + python -m pytest backend/tests/test_cors.py -v +""" + +import os +import pytest +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from starlette.testclient import TestClient + + +def _build_app(env_origins: str = "") -> FastAPI: + """ + Build a minimal FastAPI app with the same CORS logic as backend/main.py. + + This mirrors the exact code in main.py so the test validates the real + configuration without needing torch/numpy/transformers installed. + """ + app = FastAPI() + + _default_origins = [ + "http://localhost:5173", + "http://127.0.0.1:5173", + "http://localhost:17493", + "http://127.0.0.1:17493", + "tauri://localhost", + "https://tauri.localhost", + ] + _cors_origins = _default_origins + [o.strip() for o in env_origins.split(",") if o.strip()] + + app.add_middleware( + CORSMiddleware, + allow_origins=_cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.get("/health") + async def health(): + return {"status": "ok"} + + return app + + +@pytest.fixture() +def client(): + return TestClient(_build_app()) + + +@pytest.fixture() +def client_with_custom_origins(): + return TestClient(_build_app("https://custom.example.com,https://other.example.com")) + + +def _get_with_origin(client: TestClient, origin: str) -> dict: + """Send a GET with Origin header, return response headers.""" + response = client.get("/health", headers={"Origin": origin}) + return dict(response.headers) + + +def _preflight(client: TestClient, origin: str) -> dict: + """Send CORS preflight OPTIONS request, return response headers.""" + response = client.options( + "/health", + headers={ + "Origin": origin, + "Access-Control-Request-Method": "GET", + }, + ) + return dict(response.headers) + + +class TestCORSDefaultOrigins: + """CORS should allow known local origins and block everything else.""" + + @pytest.mark.parametrize("origin", [ + "http://localhost:5173", + "http://127.0.0.1:5173", + "http://localhost:17493", + "http://127.0.0.1:17493", + "tauri://localhost", + "https://tauri.localhost", + ]) + def test_allowed_origins(self, client, origin): + headers = _get_with_origin(client, origin) + assert headers.get("access-control-allow-origin") == origin + + @pytest.mark.parametrize("origin", [ + "http://evil.com", + "http://localhost:9999", + "https://attacker.example.com", + "null", + ]) + def test_blocked_origins(self, client, origin): + headers = _get_with_origin(client, origin) + assert "access-control-allow-origin" not in headers + + def test_preflight_allowed(self, client): + headers = _preflight(client, "http://localhost:5173") + assert headers.get("access-control-allow-origin") == "http://localhost:5173" + + def test_preflight_blocked(self, client): + headers = _preflight(client, "http://evil.com") + assert "access-control-allow-origin" not in headers + + def test_credentials_header_present(self, client): + headers = _get_with_origin(client, "http://localhost:5173") + assert headers.get("access-control-allow-credentials") == "true" + + +class TestCORSCustomOrigins: + """VOICEBOX_CORS_ORIGINS env var should extend the allowlist.""" + + def test_custom_origin_allowed(self, client_with_custom_origins): + headers = _get_with_origin(client_with_custom_origins, "https://custom.example.com") + assert headers.get("access-control-allow-origin") == "https://custom.example.com" + + def test_other_custom_origin_allowed(self, client_with_custom_origins): + headers = _get_with_origin(client_with_custom_origins, "https://other.example.com") + assert headers.get("access-control-allow-origin") == "https://other.example.com" + + def test_default_origins_still_work(self, client_with_custom_origins): + headers = _get_with_origin(client_with_custom_origins, "http://localhost:5173") + assert headers.get("access-control-allow-origin") == "http://localhost:5173" + + def test_unlisted_origin_still_blocked(self, client_with_custom_origins): + headers = _get_with_origin(client_with_custom_origins, "http://evil.com") + assert "access-control-allow-origin" not in headers + + +class TestCORSEnvVarParsing: + """Edge cases for VOICEBOX_CORS_ORIGINS parsing.""" + + def test_empty_env_var(self): + app = _build_app("") + client = TestClient(app) + headers = _get_with_origin(client, "http://evil.com") + assert "access-control-allow-origin" not in headers + + def test_whitespace_trimmed(self): + app = _build_app(" https://spaced.example.com ") + client = TestClient(app) + headers = _get_with_origin(client, "https://spaced.example.com") + assert headers.get("access-control-allow-origin") == "https://spaced.example.com" + + def test_trailing_comma_ignored(self): + app = _build_app("https://one.example.com,") + client = TestClient(app) + headers = _get_with_origin(client, "https://one.example.com") + assert headers.get("access-control-allow-origin") == "https://one.example.com"