16
16
17
17
import dataclasses
18
18
import io
19
- import os
20
19
import wave
21
20
from dataclasses import dataclass
22
- from pathlib import PurePosixPath
23
21
24
- import aiohttp
22
+ import httpx
25
23
from livekit import agents
26
- from livekit .agents import stt , utils
24
+ from livekit .agents import stt
27
25
from livekit .agents .utils import AudioBuffer
28
26
27
+ import openai
28
+
29
29
from .models import WhisperModels
30
- from .utils import get_base_url
31
30
32
31
33
32
@dataclass
34
33
class _STTOptions :
35
34
language : str
36
35
detect_language : bool
37
36
model : WhisperModels
38
- api_key : str
39
- endpoint : str
40
37
41
38
42
39
class STT (stt .STT ):
@@ -46,37 +43,35 @@ def __init__(
46
43
language : str = "en" ,
47
44
detect_language : bool = False ,
48
45
model : WhisperModels = "whisper-1" ,
49
- api_key : str | None = None ,
50
46
base_url : str | None = None ,
51
- http_session : aiohttp .ClientSession | None = None ,
47
+ api_key : str | None = None ,
48
+ client : openai .AsyncClient | None = None ,
52
49
):
53
50
super ().__init__ (
54
51
capabilities = stt .STTCapabilities (streaming = False , interim_results = False )
55
52
)
56
- api_key = api_key or os .environ .get ("OPENAI_API_KEY" )
57
- if not api_key :
58
- raise ValueError ("OPENAI_API_KEY must be set" )
59
-
60
53
if detect_language :
61
54
language = ""
62
55
63
- base = PurePosixPath (get_base_url (base_url ))
64
- endpoint = str (base / "audio/transcriptions" )
65
-
66
56
self ._opts = _STTOptions (
67
57
language = language ,
68
58
detect_language = detect_language ,
69
59
model = model ,
70
- api_key = api_key ,
71
- endpoint = endpoint ,
72
60
)
73
- self ._session = http_session
74
-
75
- def _ensure_session (self ) -> aiohttp .ClientSession :
76
- if not self ._session :
77
- self ._session = utils .http_context .http_session ()
78
61
79
- return self ._session
62
+ self ._client = client or openai .AsyncClient (
63
+ api_key = api_key ,
64
+ base_url = base_url ,
65
+ http_client = httpx .AsyncClient (
66
+ timeout = 5.0 ,
67
+ follow_redirects = True ,
68
+ limits = httpx .Limits (
69
+ max_connections = 1000 ,
70
+ max_keepalive_connections = 100 ,
71
+ keepalive_expiry = 120 ,
72
+ ),
73
+ ),
74
+ )
80
75
81
76
def _sanitize_options (self , * , language : str | None = None ) -> _STTOptions :
82
77
config = dataclasses .replace (self ._opts )
@@ -87,7 +82,6 @@ async def recognize(
87
82
self , buffer : AudioBuffer , * , language : str | None = None
88
83
) -> stt .SpeechEvent :
89
84
config = self ._sanitize_options (language = language )
90
-
91
85
buffer = agents .utils .merge_frames (buffer )
92
86
io_buffer = io .BytesIO ()
93
87
with wave .open (io_buffer , "wb" ) as wav :
@@ -96,29 +90,14 @@ async def recognize(
96
90
wav .setframerate (buffer .sample_rate )
97
91
wav .writeframes (buffer .data )
98
92
99
- form = aiohttp .FormData ()
100
- form .add_field ("file" , io_buffer .getvalue (), filename = "my_file.wav" )
101
- form .add_field ("model" , config .model )
102
-
103
- if config .language :
104
- form .add_field ("language" , config .language )
105
-
106
- form .add_field ("response_format" , "json" )
107
-
108
- async with self ._ensure_session ().post (
109
- self ._opts .endpoint ,
110
- headers = {"Authorization" : f"Bearer { config .api_key } " },
111
- data = form ,
112
- ) as resp :
113
- data = await resp .json ()
114
- if "text" not in data or "error" in data :
115
- raise ValueError (f"Unexpected response: { data } " )
116
-
117
- return _transcription_to_speech_event (data , config .language )
118
-
93
+ resp = await self ._client .audio .transcriptions .create (
94
+ file = ("my_file.wav" , io_buffer .getvalue (), "audio/wav" ),
95
+ model = self ._opts .model ,
96
+ language = config .language ,
97
+ response_format = "json" ,
98
+ )
119
99
120
- def _transcription_to_speech_event (transcription : dict , language ) -> stt .SpeechEvent :
121
- return stt .SpeechEvent (
122
- type = stt .SpeechEventType .FINAL_TRANSCRIPT ,
123
- alternatives = [stt .SpeechData (text = transcription ["text" ], language = language )],
124
- )
100
+ return stt .SpeechEvent (
101
+ type = stt .SpeechEventType .FINAL_TRANSCRIPT ,
102
+ alternatives = [stt .SpeechData (text = resp .text , language = language or "" )],
103
+ )
0 commit comments