diff --git a/Makefile b/Makefile index 2b3705974a..35345b8c1f 100644 --- a/Makefile +++ b/Makefile @@ -23,7 +23,6 @@ test_xtts: test_aux: ## run aux tests. coverage run -m pytest -x -v --durations=0 tests/aux_tests - ./run_bash_tests.sh test_zoo: ## run zoo tests. coverage run -m pytest -x -v --durations=0 tests/zoo_tests/test_models.py diff --git a/TTS/api.py b/TTS/api.py index 7720530823..fdc26502ff 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -273,11 +273,11 @@ def _check_arguments( def tts( self, text: str, - speaker: str = None, - language: str = None, - speaker_wav: str = None, - emotion: str = None, - speed: float = None, + speaker: Optional[str] = None, + language: Optional[str] = None, + speaker_wav: Optional[str] = None, + emotion: Optional[str] = None, + speed: Optional[float] = None, split_sentences: bool = True, **kwargs, ): @@ -322,10 +322,10 @@ def tts( def tts_to_file( self, text: str, - speaker: str = None, - language: str = None, - speaker_wav: str = None, - emotion: str = None, + speaker: Optional[str] = None, + language: Optional[str] = None, + speaker_wav: Optional[str] = None, + emotion: Optional[str] = None, speed: float = 1.0, pipe_out=None, file_path: str = "output.wav", @@ -418,9 +418,9 @@ def voice_conversion_to_file( def tts_with_vc( self, text: str, - language: str = None, - speaker_wav: str = None, - speaker: str = None, + language: Optional[str] = None, + speaker_wav: Optional[str] = None, + speaker: Optional[str] = None, split_sentences: bool = True, ): """Convert text to speech with voice conversion. @@ -460,10 +460,10 @@ def tts_with_vc( def tts_with_vc_to_file( self, text: str, - language: str = None, - speaker_wav: str = None, + language: Optional[str] = None, + speaker_wav: Optional[str] = None, file_path: str = "output.wav", - speaker: str = None, + speaker: Optional[str] = None, split_sentences: bool = True, pipe_out=None, ) -> str: diff --git a/TTS/server/server.py b/TTS/server/server.py index 6a4642f9a2..cb4ed4d9b2 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -8,7 +8,6 @@ import logging import os import sys -from pathlib import Path from threading import Lock from typing import Union from urllib.parse import parse_qs @@ -19,10 +18,9 @@ msg = "Server requires requires flask, use `pip install coqui-tts[server]`" raise ImportError(msg) from e -from TTS.config import load_config +from TTS.api import TTS from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.utils.manage import ModelManager -from TTS.utils.synthesizer import Synthesizer logger = logging.getLogger(__name__) setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter()) @@ -60,6 +58,7 @@ def create_argparser() -> argparse.ArgumentParser: parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None) parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None) parser.add_argument("--port", type=int, default=5002, help="port to listen on.") + parser.add_argument("--device", type=str, help="Device to run model on.", default="cpu") parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, default=False, help="true to use CUDA.") parser.add_argument( "--debug", action=argparse.BooleanOptionalAction, default=False, help="true to enable Flask debug mode." @@ -73,8 +72,7 @@ def create_argparser() -> argparse.ArgumentParser: # parse the args args = create_argparser().parse_args() -path = Path(__file__).parent / "../.models.json" -manager = ModelManager(path) +manager = ModelManager(models_file=TTS.get_models_file_path()) # update in-use models to the specified released models. model_path = None @@ -86,51 +84,27 @@ def create_argparser() -> argparse.ArgumentParser: # CASE1: list pre-trained TTS models if args.list_models: manager.list_models() - sys.exit() - -# CASE2: load pre-trained model paths -if args.model_name is not None and not args.model_path: - model_path, config_path, model_item = manager.download_model(args.model_name) - args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name - -if args.vocoder_name is not None and not args.vocoder_path: - vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) - -# CASE3: set custom model paths -if args.model_path is not None: - model_path = args.model_path - config_path = args.config_path - speakers_file_path = args.speakers_file_path - -if args.vocoder_path is not None: - vocoder_path = args.vocoder_path - vocoder_config_path = args.vocoder_config_path - -# load models -synthesizer = Synthesizer( - tts_checkpoint=model_path, - tts_config_path=config_path, - tts_speakers_file=speakers_file_path, - tts_languages_file=None, - vocoder_checkpoint=vocoder_path, - vocoder_config=vocoder_config_path, - encoder_checkpoint="", - encoder_config="", - use_cuda=args.use_cuda, -) - -use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and ( - synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None -) -speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None) - -use_multi_language = hasattr(synthesizer.tts_model, "num_languages") and ( - synthesizer.tts_model.num_languages > 1 or synthesizer.tts_languages_file is not None -) -language_manager = getattr(synthesizer.tts_model, "language_manager", None) + sys.exit(0) + +device = args.device +if args.use_cuda: + device = "cuda" + +# CASE2: load models +model_name = args.model_name if args.model_path is None else None +api = TTS( + model_name=model_name, + model_path=args.model_path, + config_path=args.config_path, + vocoder_name=args.vocoder_name, + vocoder_path=args.vocoder_path, + vocoder_config_path=args.vocoder_config_path, + speakers_file_path=args.speakers_file_path, + # language_ids_file_path=args.language_ids_file_path, +).to(device) # TODO: set this from SpeakerManager -use_gst = synthesizer.tts_config.get("use_gst", False) +use_gst = api.synthesizer.tts_config.get("use_gst", False) app = Flask(__name__) @@ -158,27 +132,18 @@ def index(): return render_template( "index.html", show_details=args.show_details, - use_multi_speaker=use_multi_speaker, - use_multi_language=use_multi_language, - speaker_ids=speaker_manager.name_to_id if speaker_manager is not None else None, - language_ids=language_manager.name_to_id if language_manager is not None else None, + use_multi_speaker=api.is_multi_speaker, + use_multi_language=api.is_multi_lingual, + speaker_ids=api.speakers, + language_ids=api.languages, use_gst=use_gst, ) @app.route("/details") def details(): - if args.config_path is not None and os.path.isfile(args.config_path): - model_config = load_config(args.config_path) - elif args.model_name is not None: - model_config = load_config(config_path) - - if args.vocoder_config_path is not None and os.path.isfile(args.vocoder_config_path): - vocoder_config = load_config(args.vocoder_config_path) - elif args.vocoder_name is not None: - vocoder_config = load_config(vocoder_config_path) - else: - vocoder_config = None + model_config = api.synthesizer.tts_config + vocoder_config = api.synthesizer.vocoder_config or None return render_template( "details.html", @@ -196,17 +161,23 @@ def details(): def tts(): with lock: text = request.headers.get("text") or request.values.get("text", "") - speaker_idx = request.headers.get("speaker-id") or request.values.get("speaker_id", "") - language_idx = request.headers.get("language-id") or request.values.get("language_id", "") + speaker_idx = ( + request.headers.get("speaker-id") or request.values.get("speaker_id", "") if api.is_multi_speaker else None + ) + language_idx = ( + request.headers.get("language-id") or request.values.get("language_id", "") + if api.is_multi_lingual + else None + ) style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "") style_wav = style_wav_uri_to_dict(style_wav) logger.info("Model input: %s", text) logger.info("Speaker idx: %s", speaker_idx) logger.info("Language idx: %s", language_idx) - wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav) + wavs = api.tts(text, speaker=speaker_idx, language=language_idx, style_wav=style_wav) out = io.BytesIO() - synthesizer.save_wav(wavs, out) + api.synthesizer.save_wav(wavs, out) return send_file(out, mimetype="audio/wav") @@ -248,9 +219,9 @@ def mary_tts_api_process(): else: text = request.args.get("INPUT_TEXT", "") logger.info("Model input: %s", text) - wavs = synthesizer.tts(text) + wavs = api.tts(text) out = io.BytesIO() - synthesizer.save_wav(wavs, out) + api.synthesizer.save_wav(wavs, out) return send_file(out, mimetype="audio/wav") diff --git a/docs/source/server.md b/docs/source/server.md index 3fa211d0d7..69bdace27b 100644 --- a/docs/source/server.md +++ b/docs/source/server.md @@ -4,8 +4,7 @@ You can boot up a demo 🐸TTS server to run an inference with your models (make sure to install the additional dependencies with `pip install coqui-tts[server]`). -Note that the server is not optimized for performance and does not support all -Coqui models yet. +Note that the server is not optimized for performance. The demo server provides pretty much the same interface as the CLI command. @@ -15,7 +14,8 @@ tts-server --list_models # list the available models. ``` Run a TTS model, from the release models list, with its default vocoder. -If the model you choose is a multi-speaker TTS model, you can select different speakers on the Web interface and synthesize +If the model you choose is a multi-speaker or multilingual TTS model, you can +select different speakers and languages on the Web interface and synthesize speech. ```bash diff --git a/pyproject.toml b/pyproject.toml index ba28618d0a..20f15cde9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -173,7 +173,6 @@ exclude = [ "/.readthedocs.yml", "/Makefile", "/dockerfiles", - "/run_bash_tests.sh", "/scripts", "/tests", ] diff --git a/run_bash_tests.sh b/run_bash_tests.sh deleted file mode 100755 index 5f6cd43f68..0000000000 --- a/run_bash_tests.sh +++ /dev/null @@ -1,6 +0,0 @@ -set -e -TF_CPP_MIN_LOG_LEVEL=3 - -# runtime bash based tests -# TODO: move these to python -./tests/bash_tests/test_demo_server.sh diff --git a/tests/aux_tests/test_server.py b/tests/aux_tests/test_server.py new file mode 100644 index 0000000000..1b691f9596 --- /dev/null +++ b/tests/aux_tests/test_server.py @@ -0,0 +1,47 @@ +import os +import signal +import socket +import subprocess +import time +import wave + +import pytest +import requests + +PORT = 5003 + + +def wait_for_server(host, port, timeout=30): + start_time = time.time() + while time.time() - start_time < timeout: + try: + with socket.create_connection((host, port), timeout=2): + return True + except (OSError, ConnectionRefusedError): + time.sleep(1) + raise TimeoutError(f"Server at {host}:{port} did not start within {timeout} seconds.") + + +@pytest.fixture(scope="module", autouse=True) +def start_flask_server(): + server_process = subprocess.Popen( + ["python", "-m", "TTS.server.server", "--port", str(PORT)], + ) + wait_for_server("localhost", PORT) + yield + os.kill(server_process.pid, signal.SIGTERM) + server_process.wait() + + +def test_flask_server(tmp_path): + url = f"http://localhost:{PORT}/api/tts?text=synthesis%20schmynthesis" + response = requests.get(url) + assert response.status_code == 200, f"Request failed with status code {response.status_code}" + + wav_path = tmp_path / "output.wav" + with wav_path.open("wb") as f: + f.write(response.content) + + with wave.open(str(wav_path), "rb") as wav_file: + num_frames = wav_file.getnframes() + assert num_frames > 0, "WAV file contains no frames." diff --git a/tests/bash_tests/test_demo_server.sh b/tests/bash_tests/test_demo_server.sh deleted file mode 100755 index ebd0bc8b89..0000000000 --- a/tests/bash_tests/test_demo_server.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash -set -xe - -python -m TTS.server.server & -SERVER_PID=$! - -echo 'Waiting for server...' -sleep 30 - -curl -o /tmp/audio.wav "http://localhost:5002/api/tts?text=synthesis%20schmynthesis" -python -c 'import sys; import wave; print(wave.open(sys.argv[1]).getnframes())' /tmp/audio.wav - -kill $SERVER_PID - -rm /tmp/audio.wav