Skip to content

Commit

Permalink
enh: Extending api to support chirp_2 model, and support passing location (#1089)
Browse files Browse the repository at this point in the history
  • Loading branch information
brightsparc authored Nov 17, 2024
1 parent 238d018 commit 383f102
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 4 deletions.
5 changes: 5 additions & 0 deletions .changeset/chatty-grapes-scream.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-plugins-google": minor
---

Add support for google STT chirp_2 model.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
# Speech to Text v2

SpeechModels = Literal[
    "long",
    "short",
    "telephony",
    "medical_dictation",
    "medical_conversation",
    "chirp",
    "chirp_2",
]

SpeechLanguages = Literal[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
utils,
)

from google.api_core.client_options import ClientOptions
from google.api_core.exceptions import Aborted, DeadlineExceeded, GoogleAPICallError
from google.auth import default as gauth_default
from google.auth.exceptions import DefaultCredentialsError
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(
punctuate: bool = True,
spoken_punctuation: bool = True,
model: SpeechModels = "long",
location: str = "global",
credentials_info: dict | None = None,
credentials_file: str | None = None,
keywords: List[tuple[str, float]] | None = None,
Expand All @@ -97,6 +99,7 @@ def __init__(
)

self._client: SpeechAsyncClient | None = None
self._location = location
self._credentials_info = credentials_info
self._credentials_file = credentials_file

Expand Down Expand Up @@ -132,9 +135,16 @@ def _ensure_client(self) -> SpeechAsyncClient:
self._client = SpeechAsyncClient.from_service_account_file(
self._credentials_file
)
else:
elif self._location == "global":
self._client = SpeechAsyncClient()

else:
# Add support for passing a specific location that matches recognizer
# see: https://cloud.google.com/speech-to-text/v2/docs/speech-to-text-supported-languages
self._client = SpeechAsyncClient(
client_options=ClientOptions(
api_endpoint=f"{self._location}-speech.googleapis.com"
)
)
assert self._client is not None
return self._client

Expand All @@ -150,7 +160,7 @@ def _recognizer(self) -> str:
from google.auth import default as ga_default

_, project_id = ga_default()
return f"projects/{project_id}/locations/{self._location}/recognizers/_"

def _sanitize_options(self, *, language: str | None = None) -> STTOptions:
config = dataclasses.replace(self._config)
Expand Down
1 change: 1 addition & 0 deletions tests/pytest.ini
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[pytest]
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function
log_cli = true
log_cli_level = DEBUG
12 changes: 12 additions & 0 deletions tests/test_stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def read_mp3_file(filename: str) -> rtc.AudioFrame:
RECOGNIZE_STT = [
deepgram.STT(),
google.STT(),
google.STT(
languages=["en-AU"],
model="chirp_2",
spoken_punctuation=False,
location="us-central1",
),
openai.STT(),
fal.WizperSTT(),
]
Expand All @@ -63,6 +69,12 @@ async def test_recognize(stt: agents.stt.STT):
assemblyai.STT(),
deepgram.STT(),
google.STT(),
google.STT(
languages=["en-AU"],
model="chirp_2",
spoken_punctuation=False,
location="us-central1",
),
agents.stt.StreamAdapter(stt=openai.STT(), vad=STREAM_VAD),
azure.STT(),
]
Expand Down

0 comments on commit 383f102

Please sign in to comment.