From 49fbabf699be9d9b8c29c648d7e234c07818bc7d Mon Sep 17 00:00:00 2001 From: Om Gupta Date: Thu, 19 Dec 2024 17:51:05 +0530 Subject: [PATCH 1/3] feat: fixes for friday --- ai01/providers/openai/audio_track.py | 3 +- .../openai/realtime/realtime_model.py | 50 +++++++++++-------- ai01/utils/socket.py | 2 +- poetry.lock | 28 ++++++----- pyproject.toml | 2 +- 5 files changed, 47 insertions(+), 38 deletions(-) diff --git a/ai01/providers/openai/audio_track.py b/ai01/providers/openai/audio_track.py index 56befb2..d6a3ee8 100644 --- a/ai01/providers/openai/audio_track.py +++ b/ai01/providers/openai/audio_track.py @@ -57,7 +57,7 @@ def __init__(self, options=AudioTrackOptions()): def __repr__(self) -> str: return f" sample_rate={self.sample_rate} channels={self.channels} sample_width={self.sample_width}>" - def enqueue_audio(self, base64_audio: str): + def enqueue_audio(self, id: str, base64_audio: str): """Process and add audio data directly to the AudioFifo""" if self.readyState != "live": return @@ -125,6 +125,7 @@ async def recv(self) -> AudioFrame: else: # Update frame properties frame.sample_rate = self.sample_rate + frame.time_base = fractions.Fraction(1, self.sample_rate) # Set frame PTS diff --git a/ai01/providers/openai/realtime/realtime_model.py b/ai01/providers/openai/realtime/realtime_model.py index 0fbe923..2d894c8 100644 --- a/ai01/providers/openai/realtime/realtime_model.py +++ b/ai01/providers/openai/realtime/realtime_model.py @@ -123,7 +123,7 @@ def __init__(self, agent: Agent, options: RealTimeModelOptions): # Logger for RealTimeModel. self._logger = logger.getChild(f"RealTimeModel-{self._opts.model}") - # Conversation is the Conversations which being are happening with the RealTimeModel. + # Conversation are all the Remote Tracks who are talking to the RealTimeModel. self._conversation: Conversation = Conversation(id = str(uuid.uuid4())) # Main Task is the Audio Append the RealTimeModel. @@ -156,7 +156,7 @@ async def connect(self): self._logger.info("Connected to OpenAI RealTime Model") - self._main_tsk = asyncio.create_task(self._main(), name="RealTimeModel-Main") + self._main_tsk = asyncio.create_task(self._main(), name="RealTimeModel-Loop") except _exceptions.RealtimeModelNotConnectedError: raise @@ -240,6 +240,9 @@ async def _handle_message(self, message: Union[str, bytes]): event: _api.ServerEventType = data.get("type", "unknown") + self._logger.info(f"Event: {event}") + self._logger.info(f"Data: {data}") + if event == "session.created": self._handle_session_created(data) elif event == "error": @@ -276,8 +279,8 @@ async def _handle_message(self, message: Union[str, bytes]): # self._handle_response_content_part_added(data) elif event == "response.audio.delta": self._handle_response_audio_delta(data) - # elif event == "response.audio.done": - # self._handle_response_audio_done(data) + elif event == "response.audio.done": + self._handle_response_audio_done(data) # elif event == "response.text.done": # self._handle_response_text_done(data) # elif event == "response.audio_transcript.done": @@ -288,21 +291,21 @@ async def _handle_message(self, message: Union[str, bytes]): # self._handle_response_output_item_done(data) # elif event == "response.done": # self._handle_response_done(data) - - self._logger.info(f"Unhandled Event: {event}") + else: + self._logger.error(f"Unhandled Event: {event}") def _handle_response_output_item_done(self, data: dict): """ Response Output Item Done is the Event Handler for the Response Output Item Done Event. 
""" - self._logger.info("Response Output Item Done") + self._logger.info("Response Output Item Done", data) def _handle_response_content_part_done(self, data: dict): """ Response Content Part Done is the Event Handler for the Response Content Part Done Event. """ - self._logger.info("Response Content Part Done") + self._logger.info("Response Content Part Done", data) def _handle_conversation_item_truncated(self, data: dict): """ @@ -314,19 +317,19 @@ def _handle_conversation_item_deleted(self, data: dict): """ Conversation Item Deleted is the Event Handler for the Conversation Item Deleted Event. """ - self._logger.info("Conversation Item Deleted") + self._logger.info("Conversation Item Deleted", data) def _handle_conversation_item_created(self, data: dict): """ Conversation Item Created is the Event Handler for the Conversation Item Created Event. """ - self._logger.info("Conversation Item Created") + self._logger.info("Conversation Item Created", data) def _handle_session_created(self, data: dict): """ Session Created is the Event Handler for the Session Created Event. """ - self._logger.info("Session Created") + self._logger.info("Session Created", data) def _handle_error(self, data: dict): """ @@ -338,7 +341,7 @@ def _handle_input_audio_buffer_speech_started(self, data: dict): """ Speech Started is the Event Handler for the Speech Started Event. """ - self._logger.info("Speech Started") + self._logger.info("Speech Started", data) if self.agent.audio_track: self.agent.audio_track.flush_audio() @@ -349,13 +352,13 @@ def _handle_input_audio_buffer_speech_stopped(self, data: dict): """ Speech Stopped is the Event Handler for the Speech Stopped Event. """ - self._logger.info("Speech Stopped") + self._logger.info("Speech Stopped", data) def _handle_input_audio_buffer_speech_committed(self, data: dict): """ Speech Committed is the Event Handler for the Speech Committed Event. """ - self._logger.info("Speech Committed") + self._logger.info("Speech Committed", data) def _handle_conversation_item_input_audio_transcription_completed(self, data: dict): """ @@ -373,37 +376,40 @@ def _handle_response_done(self, data: dict): """ Response Done is the Event Handler for the Response Done Event. """ - self._logger.info("Response Done") + self._logger.info("Response Done", data) def _handle_response_created(self, data: dict): """ Response Created is the Event Handler for the Response Created Event. """ - self._logger.info("Response Created") + self._logger.info("Response Created", data) def _handle_response_output_item_added(self, data: dict): """ Response Output Item Added is the Event Handler for the Response Output Item Added Event. """ - self._logger.info("Response Output Item Added") + self._logger.info("Response Output Item Added", data) def _handle_response_content_part_added(self, data: dict): """ Response Content Part Added is the Event Handler for the Response Content Part Added Event. """ - self._logger.info("Response Content Part Added") + self._logger.info("Response Content Part Added", data) - def _handle_response_audio_delta(self, data: dict): + def _handle_response_audio_delta(self, response_audio_delta: _api.ServerEvent.ResponseAudioDelta): """ Response Audio Delta is the Event Handler for the Response Audio Delta Event. 
""" self._logger.info("Response Audio Delta") - base64_audio = data.get("delta") + base64_audio = response_audio_delta['delta'] + + item_id = str(response_audio_delta.get("item_id")) if base64_audio and self.agent.audio_track: self.agent.emit(AgentsEvents.Speaking) - self.agent.audio_track.enqueue_audio(base64_audio=base64_audio) + + self.agent.audio_track.enqueue_audio(id=item_id,base64_audio=base64_audio) def _handle_response_audio_transcript_delta(self, data: dict): """ @@ -415,7 +421,7 @@ def _handle_response_audio_done(self, data: dict): """ Response Audio Done is the Event Handler for the Response Audio Done Event. """ - self._logger.info("Response Audio Done") + self._logger.info("Response Audio Done", data) def _handle_response_text_done(self, data: dict): """ diff --git a/ai01/utils/socket.py b/ai01/utils/socket.py index a4b63e2..7ecfdee 100644 --- a/ai01/utils/socket.py +++ b/ai01/utils/socket.py @@ -6,8 +6,8 @@ import websockets logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class SocketClient: """ diff --git a/poetry.lock b/poetry.lock index 469aaae..65a7f8e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -577,22 +577,24 @@ name = "huddle01" version = "1.0.5" description = "Python Real-Time-SDK for Huddle01 dRTC Network" optional = false -python-versions = "<4.0,>=3.12" -files = [ - {file = "huddle01-1.0.5-py3-none-any.whl", hash = "sha256:fd18266ca55d1f2fb89265760afb1828fd4065065c1cb3efdfc54a4309318a75"}, - {file = "huddle01-1.0.5.tar.gz", hash = "sha256:0b0abeefa4614569a401c3410159ab46e9f3b03f020e84c3f7c5c28fb75e586e"}, -] +python-versions = "^3.12" +files = [] +develop = true [package.dependencies] -aiohttp = ">=3.10.10,<4.0.0" -asyncio = ">=3.4.3,<4.0.0" -protobuf = ">=5.28.2,<6.0.0" +aiohttp = "^3.10.10" +asyncio = "^3.4.3" +protobuf = "^5.28.2" pydantic = ">=1.10.17,<3.0.0" pyee = "11.1.0" -pymediasoup = ">=1.0.1,<2.0.0" -python-dotenv = ">=1.0.1,<2.0.0" -ruff = ">=0.6.9,<0.7.0" -websockets = ">=13.1,<14.0" +pymediasoup = "^1.0.1" +python-dotenv = "^1.0.1" +ruff = "^0.6.9" +websockets = "^13.1" + +[package.source] +type = "directory" +url = "../shinigami/packages/huddle01-python" [[package]] name = "idna" @@ -1324,4 +1326,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "c554680895ecb9b7d3b51e0bdcacd1ce2953113e5c55a72dcba04e26b3de753b" +content-hash = "8f0f5301a5e75d97c258e6f58242759c442a48cfb824bde519e6f38d777f68f0" diff --git a/pyproject.toml b/pyproject.toml index db48e7e..b1d8317 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ pyee = "11.1.0" av = "11.0.0" numpy = "^2.1.3" websockets = "^13.1" -huddle01 = "1.0.5" +huddle01 = {path = "../shinigami/packages/huddle01-python", develop = true} uuid = "^1.30" [build-system] From 0fb61316f236d182989b02268a19558d44f5fcd3 Mon Sep 17 00:00:00 2001 From: Om Gupta Date: Mon, 23 Dec 2024 23:43:19 +0530 Subject: [PATCH 2/3] add: openai truncate and some other major improvements --- ai01/agent/__init__.py | 5 +- ai01/agent/_api.py | 16 + ai01/agent/_models.py | 14 - ai01/agent/agent.py | 33 +- ai01/providers/openai/audio_track.py | 155 ++++--- ai01/providers/openai/realtime/__init__.py | 5 +- ai01/providers/openai/realtime/_api.py | 173 +++++++- .../providers/openai/realtime/conversation.py | 25 +- .../openai/realtime/realtime_model.py | 382 ++++++++++-------- ai01/rtc/__init__.py | 3 +- ai01/rtc/utils.py | 41 ++ ai01/utils/emitter.py | 212 +++++++++- ai01/utils/socket.py | 2 + 
example/chatbot/main.py | 34 +- 14 files changed, 762 insertions(+), 338 deletions(-) create mode 100644 ai01/agent/_api.py delete mode 100644 ai01/agent/_models.py create mode 100644 ai01/rtc/utils.py diff --git a/ai01/agent/__init__.py b/ai01/agent/__init__.py index 51c7232..450a1e3 100644 --- a/ai01/agent/__init__.py +++ b/ai01/agent/__init__.py @@ -1,10 +1,11 @@ -from ._models import AgentsEvents +from ._api import AgentEventTypes, AgentState from .agent import Agent, AgentOptions __all__ = [ "Agent", "AgentOptions", - "AgentsEvents", + "AgentEventTypes", + "AgentState", ] # Cleanup docs of unexported modules diff --git a/ai01/agent/_api.py b/ai01/agent/_api.py new file mode 100644 index 0000000..870ac60 --- /dev/null +++ b/ai01/agent/_api.py @@ -0,0 +1,16 @@ +from typing import Literal + +AgentEventTypes = Literal[ + "listening", + "speaking", + "idle", + "connected", + "disconnected", + "error", +] + +AgentState = Literal[ + "speaking", + "listening", + "idle" +] \ No newline at end of file diff --git a/ai01/agent/_models.py b/ai01/agent/_models.py deleted file mode 100644 index bbc2e56..0000000 --- a/ai01/agent/_models.py +++ /dev/null @@ -1,14 +0,0 @@ -# AgentsEvents is the Enum for the different types of Events emitted by the Agent. -class AgentsEvents(str): - Connected: str = "Connected" - Disconnected: str = "Disconnected" - Speaking: str = "Speaking" - Listening: str = "Listening" - Thinking: str = "Thinking" - - - - - - - diff --git a/ai01/agent/agent.py b/ai01/agent/agent.py index e062753..e7c7afb 100644 --- a/ai01/agent/agent.py +++ b/ai01/agent/agent.py @@ -1,15 +1,17 @@ import logging +from dataclasses import dataclass from typing import Optional -from pydantic import BaseModel +from ai01 import RTC, RTCOptions +from ai01.providers.openai.audio_track import AudioTrack +from ai01.utils.emitter import EnhancedEventEmitter -from ..providers.openai.audio_track import AudioTrack -from ..rtc import RTC, RTCOptions -from ..utils.emitter import EnhancedEventEmitter +from . import _api from ._exceptions import RoomNotConnectedError, RoomNotCreatedError -class AgentOptions(BaseModel): +@dataclass +class AgentOptions: """ " Every Agent is created with a set of options that define the configuration for the Agent. @@ -32,9 +34,8 @@ class Config: arbitrary_types_allowed = True - logger = logging.getLogger("Agent") -class Agent(EnhancedEventEmitter): +class Agent(EnhancedEventEmitter[_api.AgentEventTypes]): """ Agents is defined as the higher level user which is its own entity and has exposed APIs to interact with different Models and Outer World using dRTC. @@ -46,6 +47,9 @@ class Agent(EnhancedEventEmitter): def __init__(self, options: AgentOptions): super(Agent, self).__init__() + + # State of the Agent. + self._state: _api.AgentState = 'idle' # Options is the configuration for the Agent. self.options = options @@ -78,6 +82,19 @@ def room(self): raise RoomNotCreatedError() return self.__rtc.room + + def _update_state(self, state: _api.AgentState): + """ + Update the State of the Agent. 
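+
+ Listeners subscribed via `on` are notified of the change, e.g. (sketch):
+
+ ```python
+ agent.on('speaking', lambda: print("Agent Speaking"))
+ agent._update_state('speaking')  # fires the 'speaking' listeners
+ ```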
+ """ + self._state = state + + if self._state == 'listening': + self.emit('listening') + elif self._state == 'speaking': + self.emit('speaking') + elif self._state == 'idle': + self.emit('idle') async def join(self): """ @@ -119,3 +136,5 @@ async def connect(self): raise RoomNotCreatedError() await room.connect() + + self.emit('connected') diff --git a/ai01/providers/openai/audio_track.py b/ai01/providers/openai/audio_track.py index d6a3ee8..f45e9b6 100644 --- a/ai01/providers/openai/audio_track.py +++ b/ai01/providers/openai/audio_track.py @@ -1,93 +1,94 @@ +from __future__ import annotations + import asyncio -import base64 -import fractions import logging import threading +from contextlib import contextmanager +from dataclasses import dataclass import numpy as np from aiortc.mediastreams import MediaStreamError, MediaStreamTrack from av import AudioFrame from av.audio.fifo import AudioFifo -from pydantic import BaseModel - -logger = logging.getLogger(__name__) - - -class AudioTrackOptions(BaseModel): - """Audio Track Options""" - sample_rate: int = 24000 - """ - Sample Rate is the number of samples of audio carried per second, measured in Hz, Default is 24000. - """ +from ai01 import rtc - channels: int = 1 - """ - Channels is the number of audio channels, Default is 1, which is mono. - """ - - sample_width: int = 2 - """ - Sample Width is the number of bytes per sample, Default is 2, which is 16 bits. - """ +logger = logging.getLogger(__name__) +# Constants +AUDIO_PTIME = 0.020 # 20ms +DEFAULT_SAMPLE_RATE = 24000 +DEFAULT_CHANNELS = 1 +DEFAULT_SAMPLE_WIDTH = 2 + +@dataclass +class AudioTrackOptions: + sample_rate: int = DEFAULT_SAMPLE_RATE + channels: int = DEFAULT_CHANNELS + sample_width: int = DEFAULT_SAMPLE_WIDTH + +class AudioFIFOManager: + def __init__(self): + self.fifo = AudioFifo() + self.lock = threading.Lock() + + @contextmanager + def fifo_operation(self): + with self.lock: + yield self.fifo + + def flush(self): + with self.lock: + self.fifo = AudioFifo() class AudioTrack(MediaStreamTrack): kind = "audio" def __init__(self, options=AudioTrackOptions()): - print("AudioTrack __init__") super().__init__() - - # Audio configuration self.sample_rate = options.sample_rate self.channels = options.channels - self.sample_width = options.sample_width # 2 bytes per sample (16 bits) + self.sample_width = options.sample_width # 2 bytes per sample (16-bit PCM) self._start = None self._timestamp = 0 + self.frame_samples = rtc.get_frame_size(self.sample_rate, AUDIO_PTIME) - self.AUDIO_PTIME = 0.020 # 20ms audio packetization - self.frame_samples = int(self.AUDIO_PTIME * self.sample_rate) + self._pushed_duration = 0.0 + self._total_played_time = None - # Audio FIFO buffer - self.audio_fifo = AudioFifo() - self.fifo_lock = threading.Lock() + self.fifo_manager = AudioFIFOManager() def __repr__(self) -> str: - return f" sample_rate={self.sample_rate} channels={self.channels} sample_width={self.sample_width}>" - - def enqueue_audio(self, id: str, base64_audio: str): - """Process and add audio data directly to the AudioFifo""" - if self.readyState != "live": - return - + return f" sample_rate={self.sample_rate} channels={self.channels} sample_width={self.sample_width}>" + + @property + def audio_samples(self) -> int: + """ + Audio Samples Returns the number of audio samples that have been played. 
+ """ + if self._total_played_time is not None: + return int(self._total_played_time * self.sample_rate) + queued_duration = self.fifo_manager.fifo.samples / self.sample_rate + + return int((self._pushed_duration - queued_duration) * self.sample_rate) + + def enqueue_audio(self, content_index:int, audio: AudioFrame): try: - audio_bytes = base64.b64decode(base64_audio) - audio_array = np.frombuffer(audio_bytes, dtype=np.int16) - audio_array = audio_array.reshape(self.channels, -1) - - frame = AudioFrame.from_ndarray( - audio_array, - format="s16", - layout="mono" if self.channels == 1 else "stereo", - ) - frame.sample_rate = self.sample_rate - frame.time_base = fractions.Fraction(1, self.sample_rate) - - with self.fifo_lock: - self.audio_fifo.write(frame) - + if self.readyState != "live": + return MediaStreamError("AudioTrack is not live") + + with self.fifo_manager.fifo_operation() as fifo: + fifo.write(audio) + self._pushed_duration += audio.samples / self.sample_rate except Exception as e: logger.error(f"Error in enqueue_audio: {e}", exc_info=True) def flush_audio(self): """Flush the audio FIFO buffer""" - with self.fifo_lock: - self.audio_fifo = AudioFifo() + self.fifo_manager.flush() async def recv(self) -> AudioFrame: - """Receive the next audio frame""" if self.readyState != "live": raise MediaStreamError @@ -95,49 +96,35 @@ async def recv(self) -> AudioFrame: self._start = asyncio.get_event_loop().time() self._timestamp = 0 - samples = self.frame_samples - self._timestamp += samples + self._timestamp += self.frame_samples target_time = self._start + (self._timestamp / self.sample_rate) current_time = asyncio.get_event_loop().time() + wait = target_time - current_time - if wait > 0: await asyncio.sleep(wait) try: - # Read frames from the FIFO buffer - with self.fifo_lock: - frame = self.audio_fifo.read(samples) + with self.fifo_manager.fifo_operation() as fifo: + frame = fifo.read(self.frame_samples) if frame is None: - # If no data is available, generate silence - frame = AudioFrame( - format="s16", - layout="mono" if self.channels == 1 else "stereo", - samples=samples, - ) - for p in frame.planes: - p.update(np.zeros(samples, dtype=np.int16).tobytes()) - - frame.sample_rate = self.sample_rate - frame.time_base = fractions.Fraction(1, self.sample_rate) - else: - # Update frame properties - frame.sample_rate = self.sample_rate + silence_buffer = np.zeros(self.frame_samples, dtype=np.int16).tobytes() - frame.time_base = fractions.Fraction(1, self.sample_rate) + frame = rtc.convert_to_audio_frame( + silence_buffer, + self.sample_rate, + self.channels, + len(silence_buffer) // 2 + ) - # Set frame PTS frame.pts = self._timestamp + + self._total_played_time = self._timestamp / self.sample_rate return frame except Exception as e: logger.error(f"Error in recv: {e}", exc_info=True) - raise MediaStreamError("Error processing audio frame") - - def stop(self) -> None: - """Stop the track""" - if self.readyState == "live": - super().stop() + raise MediaStreamError("Error processing audio frame") \ No newline at end of file diff --git a/ai01/providers/openai/realtime/__init__.py b/ai01/providers/openai/realtime/__init__.py index 5a0bd55..34525d8 100644 --- a/ai01/providers/openai/realtime/__init__.py +++ b/ai01/providers/openai/realtime/__init__.py @@ -3,14 +3,16 @@ AudioFormat, ClientEvent, ClientEventType, + EventTypes, Modality, + RealTimeModelOptions, RealTimeModels, ServerEvent, ServerEventType, ToolChoice, Voice, ) -from .realtime_model import RealTimeModel, RealTimeModelOptions +from 
.realtime_model import RealTimeModel __all__ = [ "api", @@ -25,6 +27,7 @@ "ToolChoice", "ClientEventType", "ServerEventType", + "EventTypes", ] diff --git a/ai01/providers/openai/realtime/_api.py b/ai01/providers/openai/realtime/_api.py index ce02794..9166e3f 100644 --- a/ai01/providers/openai/realtime/_api.py +++ b/ai01/providers/openai/realtime/_api.py @@ -1,12 +1,94 @@ from __future__ import annotations -from typing import Literal, Union +import asyncio +from dataclasses import dataclass +from typing import AsyncIterable, Literal, Optional, Union +import av +from pydantic import BaseModel from typing_extensions import NotRequired, TypedDict SAMPLE_RATE = 24000 NUM_CHANNELS = 1 +class RealTimeModelOptions(BaseModel): + """ + RealTimeModelOptions is the configuration for the RealTimeModel. + """ + + oai_api_key: str + """ + OpenAI API Key is the API Key for the OpenAI API. + """ + + model: RealTimeModels = "gpt-4o-realtime-preview" + """ + Model is the RealTimeModel to be used, defaults to gpt-4o-realtime-preview. + """ + + instructions: str = "" + """ + Instructions is the Initial Prompt given to the Model. + """ + + modalities: list[Modality] = ["text", "audio"] + """ + Modalities is the list of things to be used by the Model. + """ + + voice: Voice = "alloy" + """ + Voice is the of audio voices which will be generated and returned by the Model. + """ + + temperature: float = 0.8 + """ + Temperature is the randomness of the Model, to select the next token. + """ + + input_audio_format: AudioFormat = "pcm16" + """ + Input Audio Format is the format of the input audio, which is given to the Model. + """ + + output_audio_format: AudioFormat = "pcm16" + """ + Output Audio Format is the format of the audio, which is returned by the Model. + """ + + max_response_output_tokens: int | Literal["inf"] = 4096 + """ + Max Response Output Tokens is the maximum number of tokens in the response, defaults to 4096. + """ + + base_url: str = "wss://api.openai.com/v1/realtime" + """ + Base URL is the URL of the RealTime API, defaults to the OpenAI RealTime API. + """ + + tool_choice: ToolChoice = "auto" + """ + Tools are different other APIs which the Model can access, defaults to auto. + """ + + server_vad_opts: ServerVad = { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 1000, + "silence_duration_ms": 500, + } + """ + Server VAD which means Voice Activity Detection is the configuration for the VAD, to detect the voice activity. + """ + + loop: Optional[asyncio.AbstractEventLoop] = None + """ + Loop is the Event Loop to be used for the RealTimeModel, defaults to the current Event Loop. + """ + + class Config: + arbitrary_types_allowed = True + class FunctionToolChoice(TypedDict): type: Literal["function"] @@ -61,7 +143,6 @@ class FunctionToolChoice(TypedDict): ResponseStatus is used to specify the status of the response. 
""" - class TextContent(TypedDict): type: Literal["text"] text: str @@ -84,10 +165,9 @@ class InputAudioContent(TypedDict): Content = Union[InputTextContent, TextContent, AudioContent, InputAudioContent] - class ContentPart(TypedDict): type: Literal["text", "audio"] - audio: NotRequired[str] # b64 + audio: NotRequired[str] transcript: NotRequired[str] @@ -628,3 +708,88 @@ class RateLimitsUpdated: "response.function_call_arguments.done", "rate_limits.updated", ] + +EventTypes = Literal[ + "start_session", + "input_speech_started", + "input_speech_stopped", + "input_speech_committed", + "input_speech_transcription_completed", + "input_speech_transcription_failed", + "response_created", + "response_output_added", + "response_content_added", + "response_content_done", + "response_output_done", + "response_done", + "function_calls_collected", + "function_calls_finished", + "metrics_collected", + "error", +] + + +# Realtime Mode Internal Model Types +@dataclass +class RealtimeResponse: + id: str + """id of the message""" + status: ResponseStatus + """status of the response""" + status_details: ResponseStatusDetails | None + """details of the status (only with "incomplete, cancelled and failed")""" + output: list[RealtimeOutput] + """list of outputs""" + usage: Usage | None + """usage of the response""" + done_fut: asyncio.Future[None] + """future that will be set when the response is completed""" + created_timestamp: float + """timestamp when the response was created""" + first_token_timestamp: float | None = None + """timestamp when the first token was received""" + +@dataclass +class RealtimeOutput: + response_id: str + """id of the response""" + item_id: str + """id of the item""" + output_index: int + """index of the output""" + role: Role + """role of the message""" + type: Literal["message", "function_call"] + """type of the output""" + content: list[RealtimeContent] + """list of content""" + done_fut: asyncio.Future[None] + """future that will be set when the output is completed""" + +@dataclass +class RealtimeToolCall: + name: str + """name of the function""" + arguments: str + """accumulated arguments""" + tool_call_id: str + """id of the tool call""" + +@dataclass +class RealtimeContent: + response_id: str + """id of the response""" + item_id: str + """id of the item""" + output_index: int + """index of the output""" + content_index: int + """index of the content""" + text: str + """accumulated text content""" + audio: list[av.AudioFrame] + """accumulated audio content""" + tool_calls: list[RealtimeToolCall] + """pending tool calls""" + content_type: Modality + """type of the content""" \ No newline at end of file diff --git a/ai01/providers/openai/realtime/conversation.py b/ai01/providers/openai/realtime/conversation.py index a2a1916..a0637b7 100644 --- a/ai01/providers/openai/realtime/conversation.py +++ b/ai01/providers/openai/realtime/conversation.py @@ -50,18 +50,20 @@ def logger(self): def active(self): return self._active - def add_track(self, id: str, track: MediaStreamTrack): + def add_track(self, track: MediaStreamTrack): """ Add a Track to the Conversation, which streamlines conversation into one Audio Stream. - which can be later retrieved using the `recv_frame` method and feeded to the Model. + which can be later retrieved using the `recv` method and feeded to the Model. 
""" if track.kind != "audio": raise _exceptions.RealtimeModelTrackInvalidError() - if self._track_fut.get(id): + track_id = track.id + + if self._track_fut.get(track_id): raise _exceptions.RealtimeModelError("Track is already started.") - async def handle_audio_frame(): + async def handle_track(): try: while self._active and track.readyState != "ended": frame = await track.recv() @@ -72,12 +74,15 @@ async def handle_audio_frame(): continue self.audio_resampler.resample(frame) + + if task := self._track_fut.get(track_id): + task.cancel() + del self._track_fut[track_id] + except Exception as e: self.logger.error(f"Error in handling audio frame: {e}") - self._started_fut = asyncio.create_task(handle_audio_frame(), name=f"Conversation-{id}") - - self._track_fut[id] = self._started_fut + self._track_fut[track_id] = asyncio.create_task(handle_track(), name=f"Conversation-{id}") def stop(self): """ @@ -85,6 +90,12 @@ def stop(self): """ self._active = False + for task in self._track_fut.values(): + if not task.done(): + task.cancel() + + self._track_fut.clear() + self.audio_resampler.clear() def recv(self): diff --git a/ai01/providers/openai/realtime/realtime_model.py b/ai01/providers/openai/realtime/realtime_model.py index 2d894c8..aed3e06 100644 --- a/ai01/providers/openai/realtime/realtime_model.py +++ b/ai01/providers/openai/realtime/realtime_model.py @@ -2,15 +2,17 @@ import base64 import json import logging +import time import uuid -from typing import Literal, Optional, Union +from typing import Dict, Literal, Optional, Union -from pydantic import BaseModel +from aiortc.mediastreams import MediaStreamTrack -from ai01.agent import Agent, AgentsEvents +from ai01.agent import Agent +from ai01.rtc.utils import convert_to_audio_frame +from ai01.utils.emitter import EnhancedEventEmitter from ai01.utils.socket import SocketClient -from ....utils.emitter import EnhancedEventEmitter from . import _api, _exceptions from .conversation import Conversation @@ -18,87 +20,8 @@ logger = logging.getLogger(__name__) -class RealTimeModelOptions(BaseModel): - """ - RealTimeModelOptions is the configuration for the RealTimeModel. - """ - - oai_api_key: str - """ - OpenAI API Key is the API Key for the OpenAI API. - """ - - model: _api.RealTimeModels = "gpt-4o-realtime-preview" - """ - Model is the RealTimeModel to be used, defaults to gpt-4o-realtime-preview. - """ - - instructions: str = "" - """ - Instructions is the Initial Prompt given to the Model. - """ - - modalities: list[_api.Modality] = ["text", "audio"] - """ - Modalities is the list of things to be used by the Model. - """ - - voice: _api.Voice = "alloy" - """ - Voice is the of audio voices which will be generated and returned by the Model. - """ - - temperature: float = 0.8 - """ - Temperature is the randomness of the Model, to select the next token. - """ - - input_audio_format: _api.AudioFormat = "pcm16" - """ - Input Audio Format is the format of the input audio, which is given to the Model. - """ - - output_audio_format: _api.AudioFormat = "pcm16" - """ - Output Audio Format is the format of the audio, which is returned by the Model. - """ - - max_response_output_tokens: int | Literal["inf"] = 4096 - """ - Max Response Output Tokens is the maximum number of tokens in the response, defaults to 4096. - """ - - base_url: str = "wss://api.openai.com/v1/realtime" - """ - Base URL is the URL of the RealTime API, defaults to the OpenAI RealTime API. 
- """ - - tool_choice: _api.ToolChoice = "auto" - """ - Tools are different other APIs which the Model can access, defaults to auto. - """ - - server_vad_opts: _api.ServerVad = { - "type": "server_vad", - "threshold": 0.5, - "prefix_padding_ms": 1000, - "silence_duration_ms": 500, - } - """ - Server VAD which means Voice Activity Detection is the configuration for the VAD, to detect the voice activity. - """ - - loop: Optional[asyncio.AbstractEventLoop] = None - """ - Loop is the Event Loop to be used for the RealTimeModel, defaults to the current Event Loop. - """ - - class Config: - arbitrary_types_allowed = True - - -class RealTimeModel(EnhancedEventEmitter): - def __init__(self, agent: Agent, options: RealTimeModelOptions): +class RealTimeModel(EnhancedEventEmitter[_api.EventTypes]): + def __init__(self, agent: Agent, options: _api.RealTimeModelOptions): # Agent is the instance which is interacting with the RealTimeModel. self.agent = agent @@ -123,6 +46,9 @@ def __init__(self, agent: Agent, options: RealTimeModelOptions): # Logger for RealTimeModel. self._logger = logger.getChild(f"RealTimeModel-{self._opts.model}") + # Pending Responses which the Server will keep on generating. + self._pending_responses : Dict[str, _api.RealtimeResponse] = {} + # Conversation are all the Remote Tracks who are talking to the RealTimeModel. self._conversation: Conversation = Conversation(id = str(uuid.uuid4())) @@ -134,10 +60,12 @@ def __str__(self): def __repr__(self): return f"RealTimeModel: {self._opts.model}" - - @property - def conversation(self): - return self._conversation + + def add_track(self, track: MediaStreamTrack): + """ + Add A Track, which needs to Communicate with the RealTimeModel. + """ + self._conversation.add_track(track) async def connect(self): """ @@ -152,7 +80,7 @@ async def connect(self): asyncio.create_task(self._socket_listen(), name="Socket-Listen") - await self._session_create() + await self._session_update() self._logger.info("Connected to OpenAI RealTime Model") @@ -165,12 +93,36 @@ async def connect(self): self._logger.error(f"Error connecting to RealTime API: {e}") raise _exceptions.RealtimeModelSocketError() - async def _session_create(self): + async def close(self): + """ + Close the RealTimeModel. + """ + self._conversation.stop() + + if self._main_tsk: + self._main_tsk.cancel() + + self._logger.info("Closed RealTimeModel") + + async def truncate(self, *, item_id: str, content_index: int, audio_end_ms: int): + """ + Truncate the Conversation Item, which tells the models about how much of the conversation to consider. + """ + truncate_message: _api.ClientEvent.ConversationItemTruncate = { + "item_id": item_id, + "content_index": content_index, + "audio_end_ms": audio_end_ms, + "type": "conversation.item.truncate", + } + + await self.socket.send(truncate_message) + + async def _session_update(self): """ - Session Updated is the Event Handler for the Session Update Event. + Updates the session on the OpenAI RealTime API. 
""" try: - self._logger.info("Send Session Updated ") + self._logger.info("Send Session Updated") if not self.socket.connected: raise _exceptions.RealtimeModelNotConnectedError() @@ -218,41 +170,51 @@ async def _send_audio_append(self, audio_byte: bytes): } await self.socket.send(payload) - - async def _socket_listen(self): - """ - Listen to the WebSocket - """ - try: - if not self.socket.connected: - raise _exceptions.RealtimeModelNotConnectedError() - - async for message in self.socket.ws: - await self._handle_message(message) - except Exception as e: - logger.error(f"Error listening to WebSocket: {e}") - - raise _exceptions.RealtimeModelSocketError() - async def _handle_message(self, message: Union[str, bytes]): data = json.loads(message) event: _api.ServerEventType = data.get("type", "unknown") - self._logger.info(f"Event: {event}") - self._logger.info(f"Data: {data}") - + # Session Events if event == "session.created": self._handle_session_created(data) - elif event == "error": - self._handle_error(data) + + # Response Events + elif event == "response.created": + self._handle_response_created(data) + elif event == "response.output_item.added": + self._handle_response_output_item_added(data) + elif event == "response.content_part.added": + self._handle_response_content_part_added(data) + elif event == "response.audio.delta": + self._handle_response_audio_delta(data) + elif event == "response.audio.done": + self._handle_response_audio_done(data) + elif event == "response.text.done": + self._handle_response_text_done(data) + elif event == "response.audio_transcript.done": + self._handle_response_audio_transcript_done(data) + elif event == "response.content_part.done": + self._handle_response_content_part_done(data) + elif event == "response.output_item.done": + self._handle_response_output_item_done(data) + elif event == "response.done": + self._handle_response_done(data) + + # Input Audio Buffer Events, Append Events elif event == "input_audio_buffer.speech_started": self._handle_input_audio_buffer_speech_started(data) elif event == "input_audio_buffer.speech_stopped": self._handle_input_audio_buffer_speech_stopped(data) elif event == "response.audio_transcript.delta": self._handle_response_audio_transcript_delta(data) + + # Conversation Updated Events + elif event == "conversation.item.created": + self._handle_conversation_item_created(data) + elif event == "conversation.item.truncated": + self._handle_conversation_item_truncated(data) # elif event == "input_audio_buffer.committed": # self._handle_input_audio_buffer_speech_committed(data) # elif ( @@ -265,33 +227,12 @@ async def _handle_message(self, message: Union[str, bytes]): # self._handle_conversation_item_input_audio_transcription_failed( # data # ) - # elif event == "conversation.item.created": - # self._handle_conversation_item_created(data) # elif event == "conversation.item.deleted": # self._handle_conversation_item_deleted(data) - # elif event == "conversation.item.truncated": - # self._handle_conversation_item_truncated(data) - # elif event == "response.created": - # self._handle_response_created(data) - # elif event == "response.output_item.added": - # self._handle_response_output_item_added(data) - # elif event == "response.content_part.added": - # self._handle_response_content_part_added(data) - elif event == "response.audio.delta": - self._handle_response_audio_delta(data) - elif event == "response.audio.done": - self._handle_response_audio_done(data) - # elif event == "response.text.done": - # 
self._handle_response_text_done(data) - # elif event == "response.audio_transcript.done": - # self._handle_response_audio_transcript_done(data) - # elif event == "response.content_part.done": - # self._handle_response_content_part_done(data) - # elif event == "response.output_item.done": - # self._handle_response_output_item_done(data) - # elif event == "response.done": - # self._handle_response_done(data) - + + + elif event == "error": + self._handle_error(data) else: self._logger.error(f"Unhandled Event: {event}") @@ -319,11 +260,11 @@ def _handle_conversation_item_deleted(self, data: dict): """ self._logger.info("Conversation Item Deleted", data) - def _handle_conversation_item_created(self, data: dict): + def _handle_conversation_item_created(self, data: _api.ServerEvent.ResponseCreated): """ Conversation Item Created is the Event Handler for the Conversation Item Created Event. """ - self._logger.info("Conversation Item Created", data) + self._logger.warning("IMPLEMENT!!! Conversation Item Created", data) def _handle_session_created(self, data: dict): """ @@ -337,16 +278,24 @@ def _handle_error(self, data: dict): """ self._logger.error(f"Error: {data}") - def _handle_input_audio_buffer_speech_started(self, data: dict): + def _handle_input_audio_buffer_speech_started(self, speech_started: _api.ServerEvent.InputAudioBufferSpeechStarted): """ Speech Started is the Event Handler for the Speech Started Event. """ - self._logger.info("Speech Started", data) + self._logger.info("Speech Started") + + self.agent._update_state("listening") - if self.agent.audio_track: - self.agent.audio_track.flush_audio() + if self.agent.audio_track is not None and self.agent.audio_track.readyState == "live": + audio_end_ms = int(self.agent.audio_track.audio_samples / (_api.SAMPLE_RATE * 1000)) - self.agent.emit(AgentsEvents.Listening) + return asyncio.create_task( + self.truncate( + item_id=speech_started['item_id'], + content_index=0, + audio_end_ms=audio_end_ms, + ), name="Truncate-Task" + ) def _handle_input_audio_buffer_speech_stopped(self, data: dict): """ @@ -378,40 +327,121 @@ def _handle_response_done(self, data: dict): """ self._logger.info("Response Done", data) - def _handle_response_created(self, data: dict): + def _handle_response_created(self, reponse_created: _api.ServerEvent.ResponseCreated): """ Response Created is the Event Handler for the Response Created Event. """ - self._logger.info("Response Created", data) + self._logger.info("✅ Response Created") + + response = reponse_created['response'] + + status_details = response.get("status_details") + usage = response.get("usage") + + new_response = _api.RealtimeResponse( + id = response['id'], + done_fut=asyncio.Future(), + output=[], + status=response['status'], + status_details=status_details, + usage=usage, + created_timestamp=time.time(), + ) - def _handle_response_output_item_added(self, data: dict): + self._pending_responses[response['id']] = new_response + + self.emit('response_created', new_response) + + + def _handle_response_output_item_added(self, output_item: _api.ServerEvent.ResponseOutputItemAdded): """ Response Output Item Added is the Event Handler for the Response Output Item Added Event. 
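
Registers a new `RealtimeOutput` on the pending response referenced by
`response_id` and re-emits it as a `response_output_added` event.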
""" - self._logger.info("Response Output Item Added", data) + self._logger.info("✅ Response Output Item Added") + + response_id = output_item['response_id'] + response = self._pending_responses[response_id] + done_fut = asyncio.Future() + + item_data = output_item['item'] + + item_type: Literal['message', 'function_call'] = item_data['type'] # type: ignore + + item_rol: _api.Role = item_data.get('role', "assistane") + + new_output = _api.RealtimeOutput( + response_id=response_id, + item_id=item_data['id'], + output_index=output_item['output_index'], + type=item_type, + role=item_rol, + done_fut=done_fut, + content=[] + ) + + response.output.append(new_output) - def _handle_response_content_part_added(self, data: dict): + self.emit('response_output_added', new_output) + + def _handle_response_content_part_added(self, content_added: _api.ServerEvent.ResponseContentPartAdded): """ Response Content Part Added is the Event Handler for the Response Content Part Added Event. """ - self._logger.info("Response Content Part Added", data) + self._logger.info("✅ Response Content Part Added") + response_id = content_added['response_id'] + response = self._pending_responses[response_id] + + output_index = content_added['output_index'] + + output = response.output[output_index] + + content_type = content_added['part']['type'] + + new_content = _api.RealtimeContent( + response_id=response_id, + item_id=output.item_id, + output_index=output_index, + content_index=content_added['content_index'], + text="", + tool_calls=[], + content_type=content_type, + audio=[] + ) + + output.content.append(new_content) + + response.first_token_timestamp = time.time() + + self.emit('response_content_added', new_content) def _handle_response_audio_delta(self, response_audio_delta: _api.ServerEvent.ResponseAudioDelta): """ Response Audio Delta is the Event Handler for the Response Audio Delta Event. """ - self._logger.info("Response Audio Delta") + self._logger.info("✅ Response Audio Delta") - base64_audio = response_audio_delta['delta'] + response = self._pending_responses[response_audio_delta["response_id"]] + output = response.output[response_audio_delta["output_index"]] + content = output.content[response_audio_delta["content_index"]] - item_id = str(response_audio_delta.get("item_id")) + data = base64.b64decode(response_audio_delta["delta"]) - if base64_audio and self.agent.audio_track: - self.agent.emit(AgentsEvents.Speaking) + audio = convert_to_audio_frame( + data=data, + sample_rate=_api.SAMPLE_RATE, + num_channels=_api.NUM_CHANNELS, + samples_per_channel=len(data) // 2, + ) + + content.audio.append(audio) - self.agent.audio_track.enqueue_audio(id=item_id,base64_audio=base64_audio) + if track := self.agent.audio_track: + track.enqueue_audio( + content_index=content.content_index, + audio=audio, + ) - def _handle_response_audio_transcript_delta(self, data: dict): + def _handle_response_audio_transcript_delta(self, response_audio_delta: _api.ServerEvent.ResponseAudioTranscriptDelta): """ Response Audio Transcript Delta is the Event Handler for the Response Audio Transcript Delta Event. """ @@ -421,7 +451,7 @@ def _handle_response_audio_done(self, data: dict): """ Response Audio Done is the Event Handler for the Response Audio Done Event. 
""" - self._logger.info("Response Audio Done", data) + self._logger.info("✅ Response Audio Done", data) def _handle_response_text_done(self, data: dict): """ @@ -435,26 +465,38 @@ def _handle_response_audio_transcript_done(self, data: dict): """ self._logger.info("Response Audio Transcript Done") + async def _socket_listen(self): + """ + Listen to the WebSocket + """ + try: + if not self.socket.connected: + raise _exceptions.RealtimeModelNotConnectedError() + + async for message in self.socket.ws: + await self._handle_message(message) + except Exception as e: + logger.error(f"Error listening to WebSocket: {e}") + + raise _exceptions.RealtimeModelSocketError() + async def _main(self): + """ + Runs the Main Loop for the RealTimeModel. + """ if not self.socket.connected: raise _exceptions.RealtimeModelNotConnectedError() try: async def handle_audio_chunk(): - while True: - if not self.conversation.active: - await asyncio.sleep(0.01) - continue + while self._conversation.active: - audio_chunk = self.conversation.recv() - - if audio_chunk is None: - await asyncio.sleep(0.01) + if audio_chunk := self._conversation.recv(): + await self._send_audio_append(audio_chunk) continue - await self._send_audio_append(audio_chunk) + await asyncio.sleep(0.01) self._main_tsk = asyncio.create_task(handle_audio_chunk(), name="RealTimeModel-AudioAppend") except Exception as e: self._logger.error(f"Error in Main Loop: {e}") - diff --git a/ai01/rtc/__init__.py b/ai01/rtc/__init__.py index 83658fa..0d9c37c 100644 --- a/ai01/rtc/__init__.py +++ b/ai01/rtc/__init__.py @@ -4,8 +4,9 @@ from .audio_resampler import AudioFrame, AudioResampler from .rtc import RTC, HuddleClientOptions, RTCOptions +from .utils import convert_to_audio_frame, get_frame_size -__all__ = ["RTC", "RTCOptions", "AudioResampler", "AudioFrame", "HuddleClientOptions", "Role", "RoomEvents", "RoomEventsData", "ProduceOptions"] +__all__ = ["RTC", "RTCOptions", "AudioResampler", "AudioFrame", "HuddleClientOptions", "Role", "RoomEvents", "RoomEventsData", "ProduceOptions", "convert_to_audio_frame", "get_frame_size"] # Cleanup docs of unexported modules diff --git a/ai01/rtc/utils.py b/ai01/rtc/utils.py new file mode 100644 index 0000000..51f3d45 --- /dev/null +++ b/ai01/rtc/utils.py @@ -0,0 +1,41 @@ +import fractions +import logging +from typing import Union + +import numpy as np +from av import AudioFrame + +AUDIO_PTIME = 0.020 # 20ms + +logger = logging.getLogger(__name__) + +def get_frame_size(sample_rate: int, ptime: float) -> int: + """ + Frame size in samples, which is the number of samples in a frame. 
+ """ + return int(ptime * sample_rate) + +def convert_to_audio_frame( + data: Union[bytes, bytearray, memoryview], + sample_rate: int, + num_channels: int, + samples_per_channel: int, +) -> AudioFrame: + audio_array = np.frombuffer(data, dtype=np.int16) + + expected_length = num_channels * samples_per_channel + if len(audio_array) != expected_length: + raise ValueError(f"Data length mismatch: got {len(audio_array)}, expected {expected_length}") + + if samples_per_channel != get_frame_size(sample_rate, AUDIO_PTIME): + logger.warning("Unexpected frame duration") + + audio_array = audio_array.reshape(num_channels, samples_per_channel) + frame = AudioFrame.from_ndarray( + audio_array, + format="s16", + layout="mono" if num_channels == 1 else "stereo" + ) + frame.sample_rate = sample_rate + frame.time_base = fractions.Fraction(1, sample_rate) + return frame \ No newline at end of file diff --git a/ai01/utils/emitter.py b/ai01/utils/emitter.py index 1d1f1e9..d90fe6c 100644 --- a/ai01/utils/emitter.py +++ b/ai01/utils/emitter.py @@ -1,18 +1,194 @@ -from pyee import AsyncIOEventEmitter - - -class EnhancedEventEmitter(AsyncIOEventEmitter): - def __init__(self, loop=None): - super(EnhancedEventEmitter, self).__init__(loop=loop) - - async def emit_for_results(self, event, *args, **kwargs): - results = [] - for f in list(self._events[event].values()): - try: - result = await f(*args, **kwargs) - except Exception as exc: - self.emit("error", exc) - else: - if result: - results.append(result) - return results +import inspect +import logging +from typing import Callable, Dict, Generic, Optional, Set, TypeVar + +T_contra = TypeVar("T_contra", contravariant=True) + +logger = logging.getLogger("ai01") + +class EnhancedEventEmitter(Generic[T_contra]): + def __init__(self) -> None: + """ + Initialize a new instance of EventEmitter. + """ + self._events: Dict[T_contra, Set[Callable]] = dict() + + def emit(self, event: T_contra, *args) -> None: + """ + Trigger all callbacks associated with the given event. + + Args: + event (T): The event to emit. + *args: Positional arguments to pass to the callbacks. + + Example: + Basic usage of emit: + + ```python + emitter = EventEmitter[str]() + + def greet(name): + print(f"Hello, {name}!") + + emitter.on('greet', greet) + emitter.emit('greet', 'Alice') # Output: Hello, Alice! + ``` + """ + if event in self._events: + callables = self._events[event].copy() + for callback in callables: + try: + sig = inspect.signature(callback) + params = sig.parameters.values() + + has_varargs = any(p.kind == p.VAR_POSITIONAL for p in params) + if has_varargs: + callback(*args) + else: + positional_params = [ + p + for p in params + if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) + ] + num_params = len(positional_params) + num_args = min(len(args), num_params) + callback_args = args[:num_args] + + callback(*callback_args) + except TypeError: + raise + except Exception: + logger.exception(f"failed to emit event {event}") + + def once(self, event: T_contra, callback: Optional[Callable] = None) -> Callable: + """ + Register a callback to be called only once when the event is emitted. + + If a callback is provided, it registers the callback directly. + If no callback is provided, it returns a decorator for use with function definitions. + + Args: + event (T): The event to listen for. + callback (Callable, optional): The callback to register. Defaults to None. + + Returns: + Callable: The registered callback or a decorator if callback is None. 
+ + Example: + Using once with a direct callback: + + ```python + emitter = EventEmitter[str]() + + def greet_once(name): + print(f"Hello once, {name}!") + + emitter.once('greet', greet_once) + emitter.emit('greet', 'Bob') # Output: Hello once, Bob! + emitter.emit('greet', 'Bob') # No output, callback was removed after first call + ``` + + Using once as a decorator: + + ```python + emitter = EventEmitter[str]() + + @emitter.once('greet') + def greet_once(name): + print(f"Hello once, {name}!") + + emitter.emit('greet', 'Bob') # Output: Hello once, Bob! + emitter.emit('greet', 'Bob') # No output + ``` + """ + if callback is not None: + + def once_callback(*args, **kwargs): + self.off(event, once_callback) + callback(*args, **kwargs) + + return self.on(event, once_callback) + else: + + def decorator(callback: Callable) -> Callable: + self.once(event, callback) + return callback + + return decorator + + def on(self, event: T_contra, callback: Optional[Callable] = None) -> Callable: + """ + Register a callback to be called whenever the event is emitted. + + If a callback is provided, it registers the callback directly. + If no callback is provided, it returns a decorator for use with function definitions. + + Args: + event (T): The event to listen for. + callback (Callable, optional): The callback to register. Defaults to None. + + Returns: + Callable: The registered callback or a decorator if callback is None. + + Example: + Using on with a direct callback: + + ```python + emitter = EventEmitter[str]() + + def greet(name): + print(f"Hello, {name}!") + + emitter.on('greet', greet) + emitter.emit('greet', 'Charlie') # Output: Hello, Charlie! + ``` + + Using on as a decorator: + + ```python + emitter = EventEmitter[str]() + + @emitter.on('greet') + def greet(name): + print(f"Hello, {name}!") + + emitter.emit('greet', 'Charlie') # Output: Hello, Charlie! + ``` + """ + if callback is not None: + if event not in self._events: + self._events[event] = set() + self._events[event].add(callback) + return callback + else: + + def decorator(callback: Callable) -> Callable: + self.on(event, callback) + return callback + + return decorator + + def off(self, event: T_contra, callback: Callable) -> None: + """ + Unregister a callback from an event. + + Args: + event (T): The event to stop listening to. + callback (Callable): The callback to remove. 
+ + Example: + Removing a callback: + + ```python + emitter = EventEmitter[str]() + + def greet(name): + print(f"Hello, {name}!") + + emitter.on('greet', greet) + emitter.off('greet', greet) + emitter.emit('greet', 'Dave') # No output, callback was removed + ``` + """ + if event in self._events: + self._events[event].remove(callback) diff --git a/ai01/utils/socket.py b/ai01/utils/socket.py index 7ecfdee..b47c09a 100644 --- a/ai01/utils/socket.py +++ b/ai01/utils/socket.py @@ -77,6 +77,8 @@ async def send(self, message: Any): await self.__ws.send(dump_data) + return + except Exception as e: self._logger.error(f"Error sending message: {e}") raise diff --git a/example/chatbot/main.py b/example/chatbot/main.py index d8cac94..a7b0de1 100644 --- a/example/chatbot/main.py +++ b/example/chatbot/main.py @@ -4,7 +4,7 @@ from dotenv import load_dotenv -from ai01.agent import Agent, AgentOptions, AgentsEvents +from ai01.agent import Agent, AgentOptions from ai01.providers.openai import AudioTrack from ai01.providers.openai.realtime import RealTimeModel, RealTimeModelOptions from ai01.rtc import ( @@ -97,14 +97,9 @@ def on_room_joined(): def on_remote_consumer_added(data: RoomEventsData.NewConsumerAdded): logger.info(f"Remote Consumer Added: {data}") - if data['kind'] == 'audio': - track = data['consumer'].track - - if track is None: - logger.error("Consumer Track is None, This should never happen.") - return - - llm.conversation.add_track(data['consumer_id'], track) + if track := data['consumer'].track: + if track.kind == 'audio': + llm.add_track(track) # @room.on(RoomEvents.ConsumerClosed) # def on_remote_consumer_closed(data: RoomEventsData.ConsumerClosed): @@ -119,27 +114,6 @@ def on_remote_consumer_added(data: RoomEventsData.NewConsumerAdded): # logger.info(f"Remote Consumer Resumed: {data['consumer_id']}") - # # Agent Events - @agent.on(AgentsEvents.Connected) - def on_agent_connected(): - logger.info("Agent Connected") - - @agent.on(AgentsEvents.Disconnected) - def on_agent_disconnected(): - logger.info("Agent Disconnected") - - @agent.on(AgentsEvents.Speaking) - def on_agent_speaking(): - logger.info("Agent Speaking") - - @agent.on(AgentsEvents.Listening) - def on_agent_listening(): - logger.info("Agent Listening") - - @agent.on(AgentsEvents.Thinking) - def on_agent_thinking(): - logger.info("Agent Thinking") - # Connect to the LLM to the Room await llm.connect() From 6d019af2ecc0ff70d87645eec8bf7d8e43ef5745 Mon Sep 17 00:00:00 2001 From: Om Gupta Date: Wed, 25 Dec 2024 16:02:48 +0530 Subject: [PATCH 3/3] fix: logging --- ai01/agent/agent.py | 12 +++---- ai01/providers/openai/audio_track.py | 4 +-- .../providers/openai/realtime/conversation.py | 8 +++-- .../openai/realtime/realtime_model.py | 31 ++++++++++--------- ai01/rtc/rtc.py | 7 +---- ai01/rtc/utils.py | 4 +-- ai01/utils/__init__.py | 14 +++++++++ ai01/utils/emitter.py | 4 +-- ai01/utils/log.py | 4 +++ ai01/utils/socket.py | 12 ++----- example/chatbot/main.py | 9 ++---- 11 files changed, 57 insertions(+), 52 deletions(-) create mode 100644 ai01/utils/__init__.py create mode 100644 ai01/utils/log.py diff --git a/ai01/agent/agent.py b/ai01/agent/agent.py index e7c7afb..dc7c3f5 100644 --- a/ai01/agent/agent.py +++ b/ai01/agent/agent.py @@ -1,9 +1,9 @@ -import logging from dataclasses import dataclass from typing import Optional -from ai01 import RTC, RTCOptions from ai01.providers.openai.audio_track import AudioTrack +from ai01.rtc import RTC, RTCOptions +from ai01.utils import logger from ai01.utils.emitter import EnhancedEventEmitter from 
. import _api @@ -33,8 +33,6 @@ class AgentOptions: class Config: arbitrary_types_allowed = True - -logger = logging.getLogger("Agent") class Agent(EnhancedEventEmitter[_api.AgentEventTypes]): """ Agents is defined as the higher level user which is its own entity and has exposed APIs to @@ -115,8 +113,6 @@ def on_room_join(): print("Room successfully joined!") ``` """ - self.logger.info("Joining Agent to the dRTC Network") - room = await self.__rtc.join() if not room: @@ -128,8 +124,6 @@ async def connect(self): """ Connects the Agent to the Room, This is only available after the Agent is joined to the dRTC Network. """ - self.logger.info("Connecting Agent to the Room") - room = self.__rtc.room if not room: @@ -137,4 +131,6 @@ async def connect(self): await room.connect() + self.logger.info("🔔 Agent Connected to the Huddle01 Room") + self.emit('connected') diff --git a/ai01/providers/openai/audio_track.py b/ai01/providers/openai/audio_track.py index f45e9b6..a04802c 100644 --- a/ai01/providers/openai/audio_track.py +++ b/ai01/providers/openai/audio_track.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import logging import threading from contextlib import contextmanager from dataclasses import dataclass @@ -12,8 +11,7 @@ from av.audio.fifo import AudioFifo from ai01 import rtc - -logger = logging.getLogger(__name__) +from ai01.utils import logger # Constants AUDIO_PTIME = 0.020 # 20ms diff --git a/ai01/providers/openai/realtime/conversation.py b/ai01/providers/openai/realtime/conversation.py index a0637b7..7bd2643 100644 --- a/ai01/providers/openai/realtime/conversation.py +++ b/ai01/providers/openai/realtime/conversation.py @@ -1,13 +1,13 @@ import asyncio -import logging from typing import Dict from aiortc.mediastreams import MediaStreamTrack +from ai01.utils import logger + from ....rtc.audio_resampler import AudioResampler from . import _exceptions -logger = logging.getLogger(__name__) class Conversation: def __init__(self, id: str): @@ -67,6 +67,10 @@ async def handle_track(): try: while self._active and track.readyState != "ended": frame = await track.recv() + + frame_data = frame.to_ndarray() + + logger.info(f"Received Frame: {frame_data}") frame.pts = None diff --git a/ai01/providers/openai/realtime/realtime_model.py b/ai01/providers/openai/realtime/realtime_model.py index aed3e06..eaa0d9b 100644 --- a/ai01/providers/openai/realtime/realtime_model.py +++ b/ai01/providers/openai/realtime/realtime_model.py @@ -1,7 +1,6 @@ import asyncio import base64 import json -import logging import time import uuid from typing import Dict, Literal, Optional, Union @@ -10,15 +9,13 @@ from ai01.agent import Agent from ai01.rtc.utils import convert_to_audio_frame +from ai01.utils import logger from ai01.utils.emitter import EnhancedEventEmitter from ai01.utils.socket import SocketClient from . import _api, _exceptions from .conversation import Conversation -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - class RealTimeModel(EnhancedEventEmitter[_api.EventTypes]): def __init__(self, agent: Agent, options: _api.RealTimeModelOptions): @@ -44,7 +41,7 @@ def __init__(self, agent: Agent, options: _api.RealTimeModelOptions): self.turn_detection = options.server_vad_opts # Logger for RealTimeModel. - self._logger = logger.getChild(f"RealTimeModel-{self._opts.model}") + self._logger = logger.getChild("OpenAI RealtimeModel") # Pending Responses which the Server will keep on generating. 
self._pending_responses : Dict[str, _api.RealtimeResponse] = {}
@@ -72,17 +69,13 @@ async def connect(self):
         Connects the RealTimeModel to the RealTime API.
         """
         try:
-            self._logger.info(
-                f"Connecting to OpenAI RealTime Model at {self._opts.base_url}"
-            )
-
             await self.socket.connect()
 
             asyncio.create_task(self._socket_listen(), name="Socket-Listen")
 
-            await self._session_update()
+            self._logger.info("✅ Connected to OpenAI RealTime Model")
 
-            self._logger.info("Connected to OpenAI RealTime Model")
+            await self._session_update()
 
             self._main_tsk = asyncio.create_task(self._main(), name="RealTimeModel-Loop")
 
@@ -122,8 +115,6 @@ async def _session_update(self):
         Updates the session on the OpenAI RealTime API.
         """
         try:
-            self._logger.info("Send Session Updated")
-
             if not self.socket.connected:
                 raise _exceptions.RealtimeModelNotConnectedError()
 
@@ -150,6 +141,8 @@ async def _session_update(self):
 
             await self.socket.send(payload)
 
+            self._logger.info("✅ Sent Initial Session Update")
+
         except Exception as e:
             self._logger.error(f"Error Sending Session Update Event: {e}")
             raise
@@ -169,6 +162,8 @@ async def _send_audio_append(self, audio_byte: bytes):
             "audio": pcm_base64,
         }
 
+        self._logger.info("🔊 Sending Audio Append")
+
         await self.socket.send(payload)
 
     async def _handle_message(self, message: Union[str, bytes]):
@@ -179,6 +174,8 @@ async def _handle_message(self, message: Union[str, bytes]):
         # Session Events
         if event == "session.created":
             self._handle_session_created(data)
+        elif event == 'session.updated':
+            self.__handle_session_updated(data)
 
         # Response Events
         elif event == "response.created":
@@ -270,7 +267,13 @@ def _handle_session_created(self, data: dict):
         """
         Session Created is the Event Handler for the Session Created Event.
         """
-        self._logger.info("Session Created", data)
+        self._logger.debug("🔔 Session Created: %s", data)
+
+    def __handle_session_updated(self, session_updated: _api.ServerEvent.SessionUpdated):
+        """
+        Session Updated is the Event Handler for the Session Updated Event.
+        """
+        self._logger.debug("🔔 Session Updated: %s", session_updated)
 
     def _handle_error(self, data: dict):
         """
diff --git a/ai01/rtc/rtc.py b/ai01/rtc/rtc.py
index 52f3645..bd380f3 100644
--- a/ai01/rtc/rtc.py
+++ b/ai01/rtc/rtc.py
@@ -1,5 +1,4 @@
 import json
-import logging
 
 from huddle01 import (
     AccessToken,
@@ -12,9 +11,7 @@
 from huddle01.local_peer import ProduceOptions
 from pydantic import BaseModel
 
-logging.basicConfig(level=logging.INFO)
-
-logger = logging.getLogger("RTC")
+from ai01.utils import logger
 
 
 class RTCOptions(BaseModel):
@@ -137,8 +134,6 @@ def on_room_join():
         - `metadata`: Optional metadata for the Room (must be JSON serializable).
         - `role`: The role of the local user in the Room (e.g., "host", "guest").
""" - self._logger.info("Join Huddle01 dRTC Network") - accessTokenData = AccessTokenData( room_id=self._options.room_id, api_key=self._options.api_key, diff --git a/ai01/rtc/utils.py b/ai01/rtc/utils.py index 51f3d45..a45ce62 100644 --- a/ai01/rtc/utils.py +++ b/ai01/rtc/utils.py @@ -1,13 +1,13 @@ import fractions -import logging from typing import Union import numpy as np from av import AudioFrame +from ai01.utils import logger + AUDIO_PTIME = 0.020 # 20ms -logger = logging.getLogger(__name__) def get_frame_size(sample_rate: int, ptime: float) -> int: """ diff --git a/ai01/utils/__init__.py b/ai01/utils/__init__.py new file mode 100644 index 0000000..227c9d6 --- /dev/null +++ b/ai01/utils/__init__.py @@ -0,0 +1,14 @@ +from .log import logger + +__all__ = [ + "logger" +] + +# Cleanup docs of unexported modules +_module = dir() +NOT_IN_ALL = [m for m in _module if m not in __all__] + +__pdoc__ = {} + +for n in NOT_IN_ALL: + __pdoc__[n] = False diff --git a/ai01/utils/emitter.py b/ai01/utils/emitter.py index d90fe6c..3d637d8 100644 --- a/ai01/utils/emitter.py +++ b/ai01/utils/emitter.py @@ -1,10 +1,10 @@ import inspect -import logging from typing import Callable, Dict, Generic, Optional, Set, TypeVar +from ai01.utils import logger + T_contra = TypeVar("T_contra", contravariant=True) -logger = logging.getLogger("ai01") class EnhancedEventEmitter(Generic[T_contra]): def __init__(self) -> None: diff --git a/ai01/utils/log.py b/ai01/utils/log.py new file mode 100644 index 0000000..514b9e6 --- /dev/null +++ b/ai01/utils/log.py @@ -0,0 +1,4 @@ +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("ai01") \ No newline at end of file diff --git a/ai01/utils/socket.py b/ai01/utils/socket.py index b47c09a..5aa1498 100644 --- a/ai01/utils/socket.py +++ b/ai01/utils/socket.py @@ -1,13 +1,11 @@ import asyncio import json -import logging from typing import Any, Dict, Optional import websockets -logging.basicConfig(level=logging.INFO) +from ai01.utils import logger -logger = logging.getLogger(__name__) class SocketClient: """ @@ -55,12 +53,8 @@ async def connect(self): """ Connect to the WebSocket server. 
""" - try: - self._logger.info(f"Attempting to connect to WebSocket at {self.url}") - - self.__ws = await websockets.connect(self.url, extra_headers=self.headers) - - self._logger.info("WebSocket connection established") + try: + self.__ws = await websockets.connect(self.url, extra_headers=self.headers) except Exception as e: self._logger.error(f"Error connecting to WebSocket: {e}") raise diff --git a/example/chatbot/main.py b/example/chatbot/main.py index a7b0de1..1d7482f 100644 --- a/example/chatbot/main.py +++ b/example/chatbot/main.py @@ -1,5 +1,4 @@ import asyncio -import logging import os from dotenv import load_dotenv @@ -15,14 +14,13 @@ RoomEventsData, RTCOptions, ) +from ai01.utils import logger from .prompt import bot_prompt load_dotenv() -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger("Chatbot") async def main(): @@ -71,7 +69,7 @@ async def main(): # Room Events @room.on(RoomEvents.RoomJoined) def on_room_joined(): - logger.info("Room Joined") + logger.info("Chatbot Joined the Huddle01 Room") # @room.on(RoomEvents.NewPeerJoined) # def on_new_remote_peer(data: RoomEventsData.NewPeerJoined): @@ -95,10 +93,9 @@ def on_room_joined(): @room.on(RoomEvents.NewConsumerAdded) def on_remote_consumer_added(data: RoomEventsData.NewConsumerAdded): - logger.info(f"Remote Consumer Added: {data}") - if track := data['consumer'].track: if track.kind == 'audio': + logger.info(f"✅ New Audio Consumer Added: {data['consumer_id']}") llm.add_track(track) # @room.on(RoomEvents.ConsumerClosed)