Skip to content

Commit 1dace6a

Browse files
authored
openai: use openai client for stt (#583)
1 parent bf9334b commit 1dace6a

File tree

6 files changed

+41
-64
lines changed

Diff for: .changeset/thin-apricots-end.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"livekit-plugins-openai": minor
3+
---
4+
5+
openai: use openai client for stt

Diff for: livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
from .log import logger
3131
from .models import ChatModels
32-
from .utils import AsyncAzureADTokenProvider, get_base_url
32+
from .utils import AsyncAzureADTokenProvider
3333

3434

3535
@dataclass
@@ -49,7 +49,7 @@ def __init__(
4949
self._opts = LLMOptions(model=model)
5050
self._client = client or openai.AsyncClient(
5151
api_key=api_key,
52-
base_url=get_base_url(base_url),
52+
base_url=base_url,
5353
http_client=httpx.AsyncClient(
5454
timeout=5.0,
5555
follow_redirects=True,

Diff for: livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
"gpt-4-turbo",
1111
"gpt-4-turbo-2024-04-09",
1212
"gpt-4-turbo-preview",
13-
"gpt-4-0125-preview" "gpt-4-1106-preview",
13+
"gpt-4-0125-preview",
14+
"gpt-4-1106-preview",
1415
"gpt-4-vision-preview",
1516
"gpt-4-1106-vision-preview",
1617
"gpt-4",
@@ -26,7 +27,6 @@
2627
"gpt-3.5-turbo-1106",
2728
"gpt-3.5-turbo-16k-0613",
2829
]
29-
3030
EmbeddingModels = Literal[
3131
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"
3232
]

Diff for: livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py

+29-50
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,24 @@
1616

1717
import dataclasses
1818
import io
19-
import os
2019
import wave
2120
from dataclasses import dataclass
22-
from pathlib import PurePosixPath
2321

24-
import aiohttp
22+
import httpx
2523
from livekit import agents
26-
from livekit.agents import stt, utils
24+
from livekit.agents import stt
2725
from livekit.agents.utils import AudioBuffer
2826

27+
import openai
28+
2929
from .models import WhisperModels
30-
from .utils import get_base_url
3130

3231

3332
@dataclass
3433
class _STTOptions:
3534
language: str
3635
detect_language: bool
3736
model: WhisperModels
38-
api_key: str
39-
endpoint: str
4037

4138

4239
class STT(stt.STT):
@@ -46,37 +43,35 @@ def __init__(
4643
language: str = "en",
4744
detect_language: bool = False,
4845
model: WhisperModels = "whisper-1",
49-
api_key: str | None = None,
5046
base_url: str | None = None,
51-
http_session: aiohttp.ClientSession | None = None,
47+
api_key: str | None = None,
48+
client: openai.AsyncClient | None = None,
5249
):
5350
super().__init__(
5451
capabilities=stt.STTCapabilities(streaming=False, interim_results=False)
5552
)
56-
api_key = api_key or os.environ.get("OPENAI_API_KEY")
57-
if not api_key:
58-
raise ValueError("OPENAI_API_KEY must be set")
59-
6053
if detect_language:
6154
language = ""
6255

63-
base = PurePosixPath(get_base_url(base_url))
64-
endpoint = str(base / "audio/transcriptions")
65-
6656
self._opts = _STTOptions(
6757
language=language,
6858
detect_language=detect_language,
6959
model=model,
70-
api_key=api_key,
71-
endpoint=endpoint,
7260
)
73-
self._session = http_session
74-
75-
def _ensure_session(self) -> aiohttp.ClientSession:
76-
if not self._session:
77-
self._session = utils.http_context.http_session()
7861

79-
return self._session
62+
self._client = client or openai.AsyncClient(
63+
api_key=api_key,
64+
base_url=base_url,
65+
http_client=httpx.AsyncClient(
66+
timeout=5.0,
67+
follow_redirects=True,
68+
limits=httpx.Limits(
69+
max_connections=1000,
70+
max_keepalive_connections=100,
71+
keepalive_expiry=120,
72+
),
73+
),
74+
)
8075

8176
def _sanitize_options(self, *, language: str | None = None) -> _STTOptions:
8277
config = dataclasses.replace(self._opts)
@@ -87,7 +82,6 @@ async def recognize(
8782
self, buffer: AudioBuffer, *, language: str | None = None
8883
) -> stt.SpeechEvent:
8984
config = self._sanitize_options(language=language)
90-
9185
buffer = agents.utils.merge_frames(buffer)
9286
io_buffer = io.BytesIO()
9387
with wave.open(io_buffer, "wb") as wav:
@@ -96,29 +90,14 @@ async def recognize(
9690
wav.setframerate(buffer.sample_rate)
9791
wav.writeframes(buffer.data)
9892

99-
form = aiohttp.FormData()
100-
form.add_field("file", io_buffer.getvalue(), filename="my_file.wav")
101-
form.add_field("model", config.model)
102-
103-
if config.language:
104-
form.add_field("language", config.language)
105-
106-
form.add_field("response_format", "json")
107-
108-
async with self._ensure_session().post(
109-
self._opts.endpoint,
110-
headers={"Authorization": f"Bearer {config.api_key}"},
111-
data=form,
112-
) as resp:
113-
data = await resp.json()
114-
if "text" not in data or "error" in data:
115-
raise ValueError(f"Unexpected response: {data}")
116-
117-
return _transcription_to_speech_event(data, config.language)
118-
93+
resp = await self._client.audio.transcriptions.create(
94+
file=("my_file.wav", io_buffer.getvalue(), "audio/wav"),
95+
model=self._opts.model,
96+
language=config.language,
97+
response_format="json",
98+
)
11999

120-
def _transcription_to_speech_event(transcription: dict, language) -> stt.SpeechEvent:
121-
return stt.SpeechEvent(
122-
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
123-
alternatives=[stt.SpeechData(text=transcription["text"], language=language)],
124-
)
100+
return stt.SpeechEvent(
101+
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
102+
alternatives=[stt.SpeechData(text=resp.text, language=language or "")],
103+
)

Diff for: livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/tts.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from .log import logger
2626
from .models import TTSModels, TTSVoices
27-
from .utils import AsyncAzureADTokenProvider, get_base_url
27+
from .utils import AsyncAzureADTokenProvider
2828

2929
OPENAI_TTS_SAMPLE_RATE = 24000
3030
OPENAI_TTS_CHANNELS = 1
@@ -58,7 +58,7 @@ def __init__(
5858

5959
self._client = client or openai.AsyncClient(
6060
api_key=api_key,
61-
base_url=get_base_url(base_url),
61+
base_url=base_url,
6262
http_client=httpx.AsyncClient(
6363
timeout=5.0,
6464
follow_redirects=True,

Diff for: livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/utils.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
import os
2-
from typing import Awaitable, Callable, Optional, Union
1+
from typing import Awaitable, Callable, Union
32

43
AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]]
5-
6-
7-
def get_base_url(base_url: Optional[str]) -> str:
8-
if not base_url:
9-
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
10-
return base_url

0 commit comments

Comments
 (0)