|
6 | 6 | import json
|
7 | 7 | import os
|
8 | 8 | from datetime import datetime
|
9 |
| -from typing import Any, Callable, Literal |
| 9 | +from typing import Annotated, Any, Callable, Literal, Union |
10 | 10 |
|
11 | 11 | import pydantic
|
12 | 12 | import websockets
|
|
52 | 52 | SessionTracingTracingConfiguration as OpenAISessionTracingConfiguration,
|
53 | 53 | SessionUpdateEvent as OpenAISessionUpdateEvent,
|
54 | 54 | )
|
55 |
| -from pydantic import TypeAdapter |
| 55 | +from pydantic import BaseModel, Field, TypeAdapter |
56 | 56 | from typing_extensions import assert_never
|
57 | 57 | from websockets.asyncio.client import ClientConnection
|
58 | 58 |
|
|
83 | 83 | RealtimeModelErrorEvent,
|
84 | 84 | RealtimeModelEvent,
|
85 | 85 | RealtimeModelExceptionEvent,
|
| 86 | + RealtimeModelInputAudioTimeoutTriggeredEvent, |
86 | 87 | RealtimeModelInputAudioTranscriptionCompletedEvent,
|
87 | 88 | RealtimeModelItemDeletedEvent,
|
88 | 89 | RealtimeModelItemUpdatedEvent,
|
@@ -128,6 +129,22 @@ async def get_api_key(key: str | Callable[[], MaybeAwaitable[str]] | None) -> st
|
128 | 129 | return os.getenv("OPENAI_API_KEY")
|
129 | 130 |
|
130 | 131 |
|
class _InputAudioBufferTimeoutTriggeredEvent(BaseModel):
    """Server event whose ``type`` is ``input_audio_buffer.timeout_triggered``.

    Declared locally and combined with ``OpenAIRealtimeServerEvent`` in
    ``AllRealtimeServerEvents`` so the websocket handler can validate it.
    NOTE(review): presumably the SDK union does not model this event yet;
    confirm, and drop this class once it does.
    """

    # Tag the discriminated union uses to select this model.
    type: Literal["input_audio_buffer.timeout_triggered"]
    event_id: str
    # Audio span bounds; the ``_ms`` suffix suggests milliseconds.
    audio_start_ms: int
    audio_end_ms: int
    item_id: str
| 138 | + |
# Every server event this module validates: the SDK-provided union plus the
# locally declared timeout event, selected by the value of the ``type`` field.
AllRealtimeServerEvents = Annotated[
    Union[OpenAIRealtimeServerEvent, _InputAudioBufferTimeoutTriggeredEvent],
    Field(discriminator="type"),
]
| 146 | + |
| 147 | + |
131 | 148 | class OpenAIRealtimeWebSocketModel(RealtimeModel):
|
132 | 149 | """A model that uses OpenAI's WebSocket API."""
|
133 | 150 |
|
@@ -462,8 +479,8 @@ async def _handle_ws_event(self, event: dict[str, Any]):
|
462 | 479 | try:
|
463 | 480 | if "previous_item_id" in event and event["previous_item_id"] is None:
|
464 | 481 | event["previous_item_id"] = "" # TODO (rm) remove
|
465 |
| - parsed: OpenAIRealtimeServerEvent = TypeAdapter( |
466 |
| - OpenAIRealtimeServerEvent |
| 482 | + parsed: AllRealtimeServerEvents = TypeAdapter( |
| 483 | + AllRealtimeServerEvents |
467 | 484 | ).validate_python(event)
|
468 | 485 | except pydantic.ValidationError as e:
|
469 | 486 | logger.error(f"Failed to validate server event: {event}", exc_info=True)
|
@@ -554,6 +571,12 @@ async def _handle_ws_event(self, event: dict[str, Any]):
|
554 | 571 | or parsed.type == "response.output_item.done"
|
555 | 572 | ):
|
556 | 573 | await self._handle_output_item(parsed.item)
|
| 574 | + elif parsed.type == "input_audio_buffer.timeout_triggered": |
| 575 | + await self._emit_event(RealtimeModelInputAudioTimeoutTriggeredEvent( |
| 576 | + item_id=parsed.item_id, |
| 577 | + audio_start_ms=parsed.audio_start_ms, |
| 578 | + audio_end_ms=parsed.audio_end_ms, |
| 579 | + )) |
557 | 580 |
|
558 | 581 | def _update_created_session(self, session: OpenAISessionObject) -> None:
|
559 | 582 | self._created_session = session
|
|
0 commit comments