Skip to content

Commit 8680cda

Browse files
authored
deepgram: add support for keywords boost/penalty (#599)
1 parent e69d25a commit 8680cda

File tree

2 files changed

+30
-12
lines changed

2 files changed

+30
-12
lines changed

.changeset/chilly-days-rhyme.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"livekit-plugins-deepgram": patch
3+
---
4+
5+
deepgram: add support for keywords boost/penalty

livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py

+25-12
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import os
2222
import wave
2323
from dataclasses import dataclass
24-
from typing import List
24+
from typing import List, Tuple
2525
from urllib.parse import urlencode
2626

2727
import aiohttp
@@ -31,6 +31,9 @@
3131
from .log import logger
3232
from .models import DeepgramLanguages, DeepgramModels
3333

34+
BASE_URL = "https://api.deepgram.com/v1/listen"
35+
BASE_URL_WS = "wss://api.deepgram.com/v1/listen"
36+
3437

3538
@dataclass
3639
class STTOptions:
@@ -45,6 +48,7 @@ class STTOptions:
4548
filler_words: bool
4649
sample_rate: int
4750
num_channels: int
51+
keywords: list[Tuple[str, float]]
4852

4953

5054
class STT(stt.STT):
@@ -60,6 +64,7 @@ def __init__(
6064
no_delay: bool = True,
6165
endpointing_ms: int = 25,
6266
filler_words: bool = False,
67+
keywords: list[Tuple[str, float]] = [],
6368
api_key: str | None = None,
6469
http_session: aiohttp.ClientSession | None = None,
6570
) -> None:
@@ -87,6 +92,7 @@ def __init__(
8792
filler_words=filler_words,
8893
sample_rate=48000,
8994
num_channels=1,
95+
keywords=keywords,
9096
)
9197
self._session = http_session
9298

@@ -106,16 +112,11 @@ async def recognize(
106112
"punctuate": config.punctuate,
107113
"detect_language": config.detect_language,
108114
"smart_format": config.smart_format,
115+
"keywords": self._opts.keywords,
109116
}
110117
if config.language:
111118
recognize_config["language"] = config.language
112119

113-
# seems like lower after encoding the parameters is needed
114-
# otherwise Deepgram returns a bad request
115-
url = (
116-
f"https://api.deepgram.com/v1/listen?{urlencode(recognize_config).lower()}"
117-
)
118-
119120
buffer = merge_frames(buffer)
120121
io_buffer = io.BytesIO()
121122
with wave.open(io_buffer, "wb") as wav:
@@ -127,7 +128,7 @@ async def recognize(
127128
data = io_buffer.getvalue()
128129

129130
async with self._ensure_session().post(
130-
url=url,
131+
url=_to_deepgram_url(recognize_config),
131132
data=data,
132133
headers={
133134
"Authorization": f"Token {self._api_key}",
@@ -204,17 +205,16 @@ async def _run(self, max_retry: int) -> None:
204205
if self._opts.endpointing_ms == 0
205206
else self._opts.endpointing_ms,
206207
"filler_words": self._opts.filler_words,
208+
"keywords": self._opts.keywords,
207209
}
208210

209211
if self._opts.language:
210212
live_config["language"] = self._opts.language
211213

212214
headers = {"Authorization": f"Token {self._api_key}"}
213-
214-
url = (
215-
f"wss://api.deepgram.com/v1/listen?{urlencode(live_config).lower()}"
215+
ws = await self._session.ws_connect(
216+
_to_deepgram_url(live_config, websocket=True), headers=headers
216217
)
217-
ws = await self._session.ws_connect(url, headers=headers)
218218
retry_count = 0 # connected successfully, reset the retry_count
219219

220220
await self._run_ws(ws)
@@ -411,3 +411,16 @@ def prerecorded_transcription_to_speech_event(
411411
for alt in dg_alts
412412
],
413413
)
414+
415+
416+
def _to_deepgram_url(opts: dict, *, websocket: bool = False) -> str:
    """Build the Deepgram listen URL for the given query options.

    The caller's ``opts`` dict is left untouched; all normalization happens
    on a shallow copy, so the same dict can be passed in again safely.

    Args:
        opts: Query parameters to encode. A truthy ``"keywords"`` entry is
            treated as an iterable of ``(keyword, intensifier)`` pairs and
            each pair is emitted as its own ``keywords=keyword:intensifier``
            query parameter. ``bool`` values are serialized as lowercase
            ``"true"`` / ``"false"``. All other values are passed to
            ``urlencode`` unchanged.
        websocket: When true, use ``BASE_URL_WS`` instead of ``BASE_URL``.

    Returns:
        The selected base URL followed by ``?`` and the URL-encoded query.
    """
    # Shallow copy: rewriting "keywords" in place would corrupt the caller's
    # dict and make a second call fail when unpacking the formatted strings.
    params = dict(opts)

    keywords = params.get("keywords")
    if keywords:
        # convert keywords to a list of "keyword:intensifier"
        params["keywords"] = [
            f"{keyword}:{intensifier}" for keyword, intensifier in keywords
        ]

    # lowercase bools (str(True) would otherwise be encoded as "True")
    params = {
        key: str(value).lower() if isinstance(value, bool) else value
        for key, value in params.items()
    }

    base_url = BASE_URL_WS if websocket else BASE_URL
    # doseq=True expands list values into repeated query parameters
    return f"{base_url}?{urlencode(params, doseq=True)}"

0 commit comments

Comments
 (0)