137 changes: 137 additions & 0 deletions examples/foundational/07af-interruptible-hathora.py
@@ -0,0 +1,137 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import os

from dotenv import load_dotenv
from loguru import logger

from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.hathora.stt import ParakeetTDTSTTService
from pipecat.services.hathora.tts import ChatterboxTTSService, KokoroTTSService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams

load_dotenv(override=True)

# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(),
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(),
),
}


async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
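    """Run a voice bot built on Hathora-hosted STT, LLM, and TTS services."""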
logger.info(f"Starting bot")

# See https://models.hathora.dev/model/nvidia-parakeet-tdt-0.6b-v3
stt = ParakeetTDTSTTService(
base_url="https://app-1c7bebb9-6977-4101-9619-833b251b86d1.app.hathora.dev/v1/transcribe",
api_key=os.getenv("HATHORA_API_KEY")
)

# See https://models.hathora.dev/model/hexgrad-kokoro-82m
tts = KokoroTTSService(
base_url="https://app-01312daf-6e53-4b9d-a4ad-13039f35adc4.app.hathora.dev/synthesize",
api_key=os.getenv("HATHORA_API_KEY"),
)

# See https://models.hathora.dev/model/resemble-ai-chatterbox
# tts = ChatterboxTTSService(
# base_url="https://app-efbc8fe2-df55-4f96-bbe3-74f6ea9d986b.app.hathora.dev/v1/generate",
# api_key=os.getenv("HATHORA_API_KEY")
# )

# See https://models.hathora.dev/model/qwen3-30b-a3b
llm = OpenAILLMService(
base_url="https://app-362f7ca1-6975-4e18-a605-ab202bf2c315.app.hathora.dev/v1",
api_key=os.getenv("HATHORA_API_KEY"),
model=None,
)

messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
},
]

context = LLMContext(messages)
context_aggregator = LLMContextAggregatorPair(context)

pipeline = Pipeline(
[
transport.input(), # Transport user input
stt,
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)

task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)

@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([LLMRunFrame()])

@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()

runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)

await runner.run(task)


async def bot(runner_args: RunnerArguments):
"""Main bot entry point compatible with Pipecat Cloud."""
transport = await create_transport(runner_args, transport_params)
await run_bot(transport, runner_args)


if __name__ == "__main__":
from pipecat.runner.run import main

main()
14 changes: 14 additions & 0 deletions src/pipecat/services/hathora/__init__.py
@@ -0,0 +1,14 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import sys

from pipecat.services import DeprecatedModuleProxy

from .stt import *
from .tts import *

sys.modules[__name__] = DeprecatedModuleProxy(globals(), "hathora", "hathora.[stt,tts]")
107 changes: 107 additions & 0 deletions src/pipecat/services/hathora/stt.py
@@ -0,0 +1,107 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""[Hathora-hosted](https://models.hathora.dev) speech-to-text services."""

import os
from typing import AsyncGenerator, Optional

import aiohttp
from loguru import logger

from pipecat.frames.frames import (
    ErrorFrame,
    Frame,
    TranscriptionFrame,
)
from pipecat.services.stt_service import SegmentedSTTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601


class ParakeetTDTSTTService(SegmentedSTTService):
    """Speech-to-text service for the Hathora-hosted Parakeet TDT model.

    Parakeet TDT is a multilingual automatic speech recognition model with
    word-level timestamps. This service calls the Hathora-hosted Parakeet
    model via its HTTP API.

    [Documentation](https://models.hathora.dev/model/nvidia-parakeet-tdt-0.6b-v3)
    """

def __init__(
self,
*,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
start_time: Optional[int] = None,
end_time: Optional[int] = None,
**kwargs,
):
"""Initialize the Hathora-hosted Parakeet STT service.

Args:
base_url: Base URL for the Hathora Parakeet STT API.
api_key: API key for authentication with the Hathora service;
                provision one [here](https://models.hathora.dev/tokens).
start_time: Start time in seconds for the time window.
end_time: End time in seconds for the time window.
"""
super().__init__(
**kwargs,
)
self._base_url = base_url
self._api_key = api_key
self._start_time = start_time
self._end_time = end_time

def can_generate_metrics(self) -> bool:
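        """Indicate that this service can generate processing and TTFB metrics."""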
return True

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
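        """Transcribe an audio segment via the Hathora-hosted Parakeet API.

        Args:
            audio: Audio segment data to transcribe.

        Yields:
            Frame: A TranscriptionFrame with the recognized text, or an
                ErrorFrame if the request fails.
        """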
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()

url = f"{self._base_url}"

url_query_params = []
if self._start_time is not None:
url_query_params.append(f"start_time={self._start_time}")
if self._end_time is not None:
url_query_params.append(f"end_time={self._end_time}")
url_query_params.append(f"sample_rate={self.sample_rate}")

if len(url_query_params) > 0:
url += "?" + "&".join(url_query_params)

api_key = self._api_key or os.getenv("HATHORA_API_KEY")

form_data = aiohttp.FormData()
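            # Send the audio segment as a multipart form field named "file",
            # which is how this endpoint receives the uploaded audio.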
form_data.add_field("file", audio, filename="audio.wav", content_type="application/octet-stream")

async with aiohttp.ClientSession() as session:
async with session.post(
url,
headers={"Authorization": f"Bearer {api_key}"},
data=form_data,
) as resp:
response = await resp.json()

if response and "text" in response:
text = response["text"].strip()
if text: # Only yield non-empty text
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
logger.debug(f"Transcription: [{text}]")
yield TranscriptionFrame(
text,
self._user_id,
time_now_iso8601(),
Language("en"), # TODO: the parakeet hathora API doesn't accept a language but says it's multilingual
result=response,
)

except Exception as e:
logger.error(f"Hathora error: {e}")
yield ErrorFrame(f"Hathora error: {str(e)}")