21
21
import os
22
22
import wave
23
23
from dataclasses import dataclass
24
- from typing import List
24
+ from typing import List , Tuple
25
25
from urllib .parse import urlencode
26
26
27
27
import aiohttp
31
31
from .log import logger
32
32
from .models import DeepgramLanguages , DeepgramModels
33
33
34
+ BASE_URL = "https://api.deepgram.com/v1/listen"
35
+ BASE_URL_WS = "wss://api.deepgram.com/v1/listen"
36
+
34
37
35
38
@dataclass
36
39
class STTOptions :
@@ -45,6 +48,7 @@ class STTOptions:
45
48
filler_words : bool
46
49
sample_rate : int
47
50
num_channels : int
51
+ keywords : list [Tuple [str , float ]]
48
52
49
53
50
54
class STT (stt .STT ):
@@ -60,6 +64,7 @@ def __init__(
60
64
no_delay : bool = True ,
61
65
endpointing_ms : int = 25 ,
62
66
filler_words : bool = False ,
67
+ keywords : list [Tuple [str , float ]] = [],
63
68
api_key : str | None = None ,
64
69
http_session : aiohttp .ClientSession | None = None ,
65
70
) -> None :
@@ -87,6 +92,7 @@ def __init__(
87
92
filler_words = filler_words ,
88
93
sample_rate = 48000 ,
89
94
num_channels = 1 ,
95
+ keywords = keywords ,
90
96
)
91
97
self ._session = http_session
92
98
@@ -106,16 +112,11 @@ async def recognize(
106
112
"punctuate" : config .punctuate ,
107
113
"detect_language" : config .detect_language ,
108
114
"smart_format" : config .smart_format ,
115
+ "keywords" : self ._opts .keywords ,
109
116
}
110
117
if config .language :
111
118
recognize_config ["language" ] = config .language
112
119
113
- # seems like lower after encoding the parameters is needed
114
- # otherwise Deepgram returns a bad request
115
- url = (
116
- f"https://api.deepgram.com/v1/listen?{ urlencode (recognize_config ).lower ()} "
117
- )
118
-
119
120
buffer = merge_frames (buffer )
120
121
io_buffer = io .BytesIO ()
121
122
with wave .open (io_buffer , "wb" ) as wav :
@@ -127,7 +128,7 @@ async def recognize(
127
128
data = io_buffer .getvalue ()
128
129
129
130
async with self ._ensure_session ().post (
130
- url = url ,
131
+ url = _to_deepgram_url ( recognize_config ) ,
131
132
data = data ,
132
133
headers = {
133
134
"Authorization" : f"Token { self ._api_key } " ,
@@ -204,17 +205,16 @@ async def _run(self, max_retry: int) -> None:
204
205
if self ._opts .endpointing_ms == 0
205
206
else self ._opts .endpointing_ms ,
206
207
"filler_words" : self ._opts .filler_words ,
208
+ "keywords" : self ._opts .keywords ,
207
209
}
208
210
209
211
if self ._opts .language :
210
212
live_config ["language" ] = self ._opts .language
211
213
212
214
headers = {"Authorization" : f"Token { self ._api_key } " }
213
-
214
- url = (
215
- f"wss://api.deepgram.com/v1/listen?{ urlencode (live_config ).lower ()} "
215
+ ws = await self ._session .ws_connect (
216
+ _to_deepgram_url (live_config , websocket = True ), headers = headers
216
217
)
217
- ws = await self ._session .ws_connect (url , headers = headers )
218
218
retry_count = 0 # connected successfully, reset the retry_count
219
219
220
220
await self ._run_ws (ws )
@@ -411,3 +411,16 @@ def prerecorded_transcription_to_speech_event(
411
411
for alt in dg_alts
412
412
],
413
413
)
414
+
415
+
416
def _to_deepgram_url(opts: dict, *, websocket: bool = False) -> str:
    """Build the Deepgram ``/v1/listen`` URL with *opts* as its query string.

    Args:
        opts: Mapping of query parameter names to values. A truthy
            ``"keywords"`` entry is expected to be an iterable of
            ``(keyword, intensifier)`` pairs; each pair is sent as a
            separate ``keywords=keyword:intensifier`` parameter.
        websocket: When true, use the ``wss://`` base URL instead of the
            ``https://`` one.

    Returns:
        The base URL followed by ``?`` and the URL-encoded parameters.

    The caller's ``opts`` dict is left untouched: parameters are assembled
    in a fresh dict so the same options can safely be reused for another
    call (rewriting ``opts["keywords"]`` in place would break the second
    call, which would try to unpack already-formatted strings).
    """
    params = {}
    for key, value in opts.items():
        if key == "keywords" and value:
            # convert keywords to a list of "keyword:intensifier"
            params[key] = [
                f"{keyword}:{intensifier}" for (keyword, intensifier) in value
            ]
        elif isinstance(value, bool):
            # lowercase bools: "true"/"false" rather than Python's "True"/"False"
            params[key] = str(value).lower()
        else:
            params[key] = value

    base_url = BASE_URL_WS if websocket else BASE_URL
    # doseq=True expands list values into one query parameter per element
    return f"{base_url}?{urlencode(params, doseq=True)}"
0 commit comments