diff --git a/examples/foundational/07af-interruptible-hathora.py b/examples/foundational/07af-interruptible-hathora.py
new file mode 100644
index 0000000000..5a06b7bf0a
--- /dev/null
+++ b/examples/foundational/07af-interruptible-hathora.py
@@ -0,0 +1,137 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import os
+
+from dotenv import load_dotenv
+from loguru import logger
+
+from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
+from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
+from pipecat.audio.vad.silero import SileroVADAnalyzer
+from pipecat.audio.vad.vad_analyzer import VADParams
+from pipecat.frames.frames import LLMRunFrame
+from pipecat.pipeline.pipeline import Pipeline
+from pipecat.pipeline.runner import PipelineRunner
+from pipecat.pipeline.task import PipelineParams, PipelineTask
+from pipecat.processors.aggregators.llm_context import LLMContext
+from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
+from pipecat.runner.types import RunnerArguments
+from pipecat.runner.utils import create_transport
+from pipecat.services.hathora.stt import ParakeetTDTSTTService
+from pipecat.services.hathora.tts import ChatterboxTTSService, KokoroTTSService
+from pipecat.services.openai.llm import OpenAILLMService
+from pipecat.transports.base_transport import BaseTransport, TransportParams
+from pipecat.transports.daily.transport import DailyParams
+from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
+
+load_dotenv(override=True)
+
+# We store functions so objects (e.g. SileroVADAnalyzer) don't get
+# instantiated. The function will be called when the desired transport gets
+# selected.
+transport_params = {
+    "daily": lambda: DailyParams(
+        audio_in_enabled=True,
+        audio_out_enabled=True,
+        vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
+        turn_analyzer=LocalSmartTurnAnalyzerV3(),
+    ),
+    "webrtc": lambda: TransportParams(
+        audio_in_enabled=True,
+        audio_out_enabled=True,
+        vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
+        turn_analyzer=LocalSmartTurnAnalyzerV3(),
+    ),
+}
+
+
+async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
+    logger.info("Starting bot")
+
+    # See https://models.hathora.dev/model/nvidia-parakeet-tdt-0.6b-v3
+    stt = ParakeetTDTSTTService(
+        base_url="https://app-1c7bebb9-6977-4101-9619-833b251b86d1.app.hathora.dev/v1/transcribe",
+        api_key=os.getenv("HATHORA_API_KEY")
+    )
+
+    # See https://models.hathora.dev/model/hexgrad-kokoro-82m
+    tts = KokoroTTSService(
+        base_url="https://app-01312daf-6e53-4b9d-a4ad-13039f35adc4.app.hathora.dev/synthesize",
+        api_key=os.getenv("HATHORA_API_KEY"),
+    )
+
+    # See https://models.hathora.dev/model/resemble-ai-chatterbox
+    # tts = ChatterboxTTSService(
+    #     base_url="https://app-efbc8fe2-df55-4f96-bbe3-74f6ea9d986b.app.hathora.dev/v1/generate",
+    #     api_key=os.getenv("HATHORA_API_KEY")
+    # )
+
+    # See https://models.hathora.dev/model/qwen3-30b-a3b
+    llm = OpenAILLMService(
+        base_url="https://app-362f7ca1-6975-4e18-a605-ab202bf2c315.app.hathora.dev/v1",
+        api_key=os.getenv("HATHORA_API_KEY"),
+        model=None,  # NOTE(review): endpoint serves a fixed model — confirm OpenAILLMService accepts model=None
+    )
+
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
+        },
+    ]
+
+    context = LLMContext(messages)
+    context_aggregator = LLMContextAggregatorPair(context)
+
+    pipeline = Pipeline(
+        [
+            transport.input(),  # Transport user input
+            stt,
+            context_aggregator.user(),  # User responses
+            llm,  # LLM
+            tts,  # TTS
+            transport.output(),  # Transport bot output
+            context_aggregator.assistant(),  # Assistant spoken responses
+        ]
+    )
+
+    task = PipelineTask(
+        pipeline,
+        params=PipelineParams(
+            enable_metrics=True,
+            enable_usage_metrics=True,
+        ),
+        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
+    )
+
+    @transport.event_handler("on_client_connected")
+    async def on_client_connected(transport, client):
+        logger.info("Client connected")
+        # Kick off the conversation.
+        messages.append({"role": "system", "content": "Please introduce yourself to the user."})
+        await task.queue_frames([LLMRunFrame()])
+
+    @transport.event_handler("on_client_disconnected")
+    async def on_client_disconnected(transport, client):
+        logger.info("Client disconnected")
+        await task.cancel()
+
+    runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
+
+    await runner.run(task)
+
+
+async def bot(runner_args: RunnerArguments):
+    """Main bot entry point compatible with Pipecat Cloud."""
+    transport = await create_transport(runner_args, transport_params)
+    await run_bot(transport, runner_args)
+
+
+if __name__ == "__main__":
+    from pipecat.runner.run import main
+
+    main()
diff --git a/src/pipecat/services/hathora/__init__.py b/src/pipecat/services/hathora/__init__.py
new file mode 100644
index 0000000000..11b7dbb555
--- /dev/null
+++ b/src/pipecat/services/hathora/__init__.py
@@ -0,0 +1,14 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .stt import *
+from .tts import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "hathora",  # NOTE(review): proxy is normally for legacy flat modules — confirm it's intended for a new package
+    "hathora.[stt,tts]")
diff --git a/src/pipecat/services/hathora/stt.py b/src/pipecat/services/hathora/stt.py
new file mode 100644
index 0000000000..ef4b209afa
--- /dev/null
+++ b/src/pipecat/services/hathora/stt.py
@@ -0,0 +1,107 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+"""[Hathora-hosted](https://models.hathora.dev) speech-to-text services."""
+
+import os
+from typing import Optional
+
+import aiohttp
+from loguru import logger
+
+from pipecat.frames.frames import (
+    ErrorFrame,
+    TranscriptionFrame,
+)
+from pipecat.services.stt_service import SegmentedSTTService
+from pipecat.transcriptions.language import Language
+from pipecat.utils.time import time_now_iso8601
+
+class ParakeetTDTSTTService(SegmentedSTTService):
+    """Parakeet TDT is a multilingual automatic speech recognition model
+    with word-level timestamps.
+
+    This service uses the Hathora-hosted Parakeet model via the HTTP API.
+
+    [Documentation](https://models.hathora.dev/model/nvidia-parakeet-tdt-0.6b-v3)
+    """
+
+    def __init__(
+        self,
+        *,
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
+        start_time: Optional[int] = None,
+        end_time: Optional[int] = None,
+        **kwargs,
+    ):
+        """Initialize the Hathora-hosted Parakeet STT service.
+
+        Args:
+            base_url: Base URL for the Hathora Parakeet STT API.
+            api_key: API key for authentication with the Hathora service;
+                provision one [here](https://models.hathora.dev/tokens).
+            start_time: Start time in seconds for the time window.
+            end_time: End time in seconds for the time window.
+        """
+        super().__init__(
+            **kwargs,
+        )
+        self._base_url = base_url
+        self._api_key = api_key
+        self._start_time = start_time
+        self._end_time = end_time
+
+    def can_generate_metrics(self) -> bool:
+        return True
+
+    async def run_stt(self, audio: bytes):
+        try:
+            await self.start_processing_metrics()
+            await self.start_ttfb_metrics()
+
+            url = f"{self._base_url}"
+
+            url_query_params = []
+            if self._start_time is not None:
+                url_query_params.append(f"start_time={self._start_time}")
+            if self._end_time is not None:
+                url_query_params.append(f"end_time={self._end_time}")
+            url_query_params.append(f"sample_rate={self.sample_rate}")
+
+            if len(url_query_params) > 0:
+                url += "?" + "&".join(url_query_params)
+
+            api_key = self._api_key or os.getenv("HATHORA_API_KEY")
+
+            form_data = aiohttp.FormData()
+            form_data.add_field("file", audio, filename="audio.wav", content_type="application/octet-stream")
+
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    url,
+                    headers={"Authorization": f"Bearer {api_key}"},
+                    data=form_data,
+                ) as resp:
+                    response = await resp.json()
+
+            if response and "text" in response:
+                text = response["text"].strip()
+                if text:  # Only yield non-empty text
+                    await self.stop_ttfb_metrics()
+                    await self.stop_processing_metrics()
+                    logger.debug(f"Transcription: [{text}]")
+                    yield TranscriptionFrame(
+                        text,
+                        self._user_id,
+                        time_now_iso8601(),
+                        Language("en"),  # TODO: the parakeet hathora API doesn't accept a language but says it's multilingual
+                        result=response,
+                    )
+
+        except Exception as e:
+            logger.error(f"Hathora error: {e}")
+            yield ErrorFrame(f"Hathora error: {str(e)}")
diff --git a/src/pipecat/services/hathora/tts.py b/src/pipecat/services/hathora/tts.py
new file mode 100644
index 0000000000..d9ec78793e
--- /dev/null
+++ b/src/pipecat/services/hathora/tts.py
@@ -0,0 +1,229 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+"""[Hathora-hosted](https://models.hathora.dev) text-to-speech services."""
+
+import io
+import os
+import wave
+from typing import Optional, Tuple
+
+import aiohttp
+from loguru import logger
+
+from pipecat.frames.frames import (
+    ErrorFrame,
+    TTSAudioRawFrame,
+    TTSStartedFrame,
+    TTSStoppedFrame,
+)
+from pipecat.services.tts_service import TTSService
+
+def _decode_audio_payload(
+    audio_bytes: bytes,
+    *,
+    fallback_sample_rate: int,
+    fallback_channels: int = 1,
+) -> Tuple[bytes, int, int]:
+    """Convert a WAV/PCM payload into raw PCM samples for TTSAudioRawFrame."""
+
+    try:
+        with wave.open(io.BytesIO(audio_bytes), "rb") as wav_reader:
+            channels = wav_reader.getnchannels()
+            sample_rate = wav_reader.getframerate()
+            frames = wav_reader.readframes(wav_reader.getnframes())
+            return frames, sample_rate, channels
+    except (wave.Error, EOFError):
+        # If the payload is already raw PCM, just pass it through.
+        return audio_bytes, fallback_sample_rate, fallback_channels
+
+class KokoroTTSService(TTSService):
+    """Kokoro is an open-weight TTS model with 82 million parameters.
+
+    This service uses the Hathora-hosted Kokoro model via the HTTP API.
+
+    [Documentation](https://models.hathora.dev/model/hexgrad-kokoro-82m)
+    """
+
+    def __init__(
+        self,
+        *,
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
+        voice: Optional[str] = None,
+        speed: Optional[float] = None,
+        **kwargs,
+    ):
+        """Initialize the Hathora-hosted Kokoro TTS service.
+
+        Args:
+            base_url: Base URL for the Hathora Kokoro TTS API.
+            api_key: API key for authentication with the Hathora service;
+                provision one [here](https://models.hathora.dev/tokens).
+            voice: Voice to use for synthesis (see the
+                [Hathora docs](https://models.hathora.dev/model/hexgrad-kokoro-82m)
+                for the default value; [list of voices](https://huggingface.co/hexgrad/Kokoro-82M/blob/main/VOICES.md)).
+            speed: Speech speed multiplier (0.5 = half speed, 2.0 = double speed, default: 1).
+        """
+        super().__init__(
+            **kwargs,
+        )
+        self._base_url = base_url
+        self._api_key = api_key
+        self._voice = voice
+        self._speed = speed
+
+    def can_generate_metrics(self) -> bool:
+        return True
+
+    async def run_tts(self, text: str):
+        try:
+            await self.start_processing_metrics()
+            await self.start_ttfb_metrics()
+
+            url = f"{self._base_url}"
+
+            api_key = self._api_key or os.getenv("HATHORA_API_KEY")
+
+            payload = {
+                "text": text
+            }
+
+            if self._voice is not None:
+                payload["voice"] = self._voice
+            if self._speed is not None:
+                payload["speed"] = self._speed
+
+            yield TTSStartedFrame()
+
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    url,
+                    headers={"Authorization": f"Bearer {api_key}", "Accept": "application/octet-stream"},
+                    json=payload,
+                ) as resp:
+                    audio_data = await resp.read()
+
+            pcm_audio, sample_rate, num_channels = _decode_audio_payload(
+                audio_data,
+                fallback_sample_rate=self.sample_rate or self._init_sample_rate or 24000,
+            )
+
+            await self.stop_ttfb_metrics()
+
+            frame = TTSAudioRawFrame(
+                audio=pcm_audio,
+                sample_rate=sample_rate,
+                num_channels=num_channels,
+            )
+
+            yield frame
+
+        except Exception as e:
+            logger.error(f"Hathora error: {e}")
+            yield ErrorFrame(f"Hathora error: {str(e)}")
+        finally:
+            await self.stop_ttfb_metrics()
+            await self.stop_processing_metrics()
+            yield TTSStoppedFrame()
+
+class ChatterboxTTSService(TTSService):
+    """Chatterbox is a public text-to-speech model optimized for natural and expressive voice synthesis.
+
+    This service uses the Hathora-hosted Chatterbox model via the HTTP API.
+
+    [Documentation](https://models.hathora.dev/model/resemble-ai-chatterbox)
+    """
+
+    def __init__(
+        self,
+        *,
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
+        exaggeration: Optional[float] = None,
+        audio_prompt: Optional[bytes] = None,
+        cfg_weight: Optional[float] = None,
+        **kwargs,
+    ):
+        """Initialize the Hathora-hosted Chatterbox TTS service.
+
+        Args:
+            base_url: Base URL for the Hathora Chatterbox TTS API.
+            api_key: API key for authentication with the Hathora service;
+                provision one [here](https://models.hathora.dev/tokens).
+            exaggeration: Controls emotional intensity (default: 0.5).
+            audio_prompt: Reference audio file for voice cloning.
+            cfg_weight: Controls adherence to reference voice (default: 0.5).
+        """
+
+        super().__init__(
+            **kwargs,
+        )
+        self._base_url = base_url
+        self._api_key = api_key
+        self._exaggeration = exaggeration
+        self._audio_prompt = audio_prompt
+        self._cfg_weight = cfg_weight
+
+    def can_generate_metrics(self) -> bool:
+        return True
+
+    async def run_tts(self, text: str):
+        try:
+            await self.start_ttfb_metrics()
+
+            url = f"{self._base_url}"
+
+            url_query_params = []
+            if self._exaggeration is not None:
+                url_query_params.append(f"exaggeration={self._exaggeration}")
+            if self._cfg_weight is not None:
+                url_query_params.append(f"cfg_weight={self._cfg_weight}")
+
+            if len(url_query_params) > 0:
+                url += "?" + "&".join(url_query_params)
+
+            api_key = self._api_key or os.getenv("HATHORA_API_KEY")
+
+            form_data = aiohttp.FormData()
+            form_data.add_field("text", text)
+
+            if self._audio_prompt is not None:
+                form_data.add_field("audio_prompt", self._audio_prompt, filename="audio.wav", content_type="application/octet-stream")
+
+            yield TTSStartedFrame()
+
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    url,
+                    headers={"Authorization": f"Bearer {api_key}"},
+                    data=form_data,
+                ) as resp:
+                    audio_data = await resp.read()
+
+            await self.start_tts_usage_metrics(text)
+
+            pcm_audio, sample_rate, num_channels = _decode_audio_payload(
+                audio_data,
+                fallback_sample_rate=self.sample_rate or self._init_sample_rate or 24000,
+            )
+
+            await self.stop_ttfb_metrics()
+
+            frame = TTSAudioRawFrame(
+                audio=pcm_audio,
+                sample_rate=sample_rate,
+                num_channels=num_channels,
+            )
+
+            yield frame
+
+        except Exception as e:
+            logger.error(f"Hathora error: {e}")
+            yield ErrorFrame(f"Hathora error: {str(e)}")
+        finally:
+            await self.stop_ttfb_metrics()
+            yield TTSStoppedFrame()