|
11 | 11 | from agora_realtime_ai_api.rtc import Channel, ChatMessage, RtcEngine, RtcOptions
|
12 | 12 |
|
13 | 13 | from .logger import setup_logger
|
14 |
| -from .realtime.struct import InputAudioBufferCommitted, InputAudioBufferSpeechStarted, InputAudioBufferSpeechStopped, InputAudioTranscription, ItemCreated, ItemInputAudioTranscriptionCompleted, RateLimitsUpdated, ResponseAudioDelta, ResponseAudioDone, ResponseAudioTranscriptDelta, ResponseAudioTranscriptDone, ResponseContentPartAdded, ResponseContentPartDone, ResponseCreated, ResponseDone, ResponseOutputItemAdded, ResponseOutputItemDone, ServerVADUpdateParams, SessionUpdate, SessionUpdateParams, SessionUpdated, Voices, to_json |
| 14 | +from .realtime.struct import FunctionCallOutputItemParam, InputAudioBufferCommitted, InputAudioBufferSpeechStarted, InputAudioBufferSpeechStopped, InputAudioTranscription, ItemCreate, ItemCreated, ItemInputAudioTranscriptionCompleted, RateLimitsUpdated, ResponseAudioDelta, ResponseAudioDone, ResponseAudioTranscriptDelta, ResponseAudioTranscriptDone, ResponseContentPartAdded, ResponseContentPartDone, ResponseCreate, ResponseCreated, ResponseDone, ResponseFunctionCallArgumentsDelta, ResponseFunctionCallArgumentsDone, ResponseOutputItemAdded, ResponseOutputItemDone, ServerVADUpdateParams, SessionUpdate, SessionUpdateParams, SessionUpdated, Voices, to_json |
15 | 15 | from .realtime.connection import RealtimeApiConnection
|
16 | 16 | from .tools import ClientToolCallResponse, ToolContext
|
17 | 17 | from .utils import PCMWriter
|
@@ -240,6 +240,21 @@ async def model_to_rtc(self) -> None:
|
240 | 240 | await pcm_writer.flush()
|
241 | 241 | raise # Re-raise the cancelled exception to properly exit the task
|
242 | 242 |
|
| 243 | + async def handle_funtion_call(self, message: ResponseFunctionCallArgumentsDone) -> None: |
| 244 | + function_call_response = await self.tools.execute_tool(message.name, message.arguments) |
| 245 | + logger.info(f"Function call response: {function_call_response}") |
| 246 | + await self.connection.send_request( |
| 247 | + ItemCreate( |
| 248 | + item = FunctionCallOutputItemParam( |
| 249 | + call_id=message.call_id, |
| 250 | + output=function_call_response.json_encoded_output |
| 251 | + ) |
| 252 | + ) |
| 253 | + ) |
| 254 | + await self.connection.send_request( |
| 255 | + ResponseCreate() |
| 256 | + ) |
| 257 | + |
243 | 258 | async def _process_model_messages(self) -> None:
|
244 | 259 | async for message in self.connection.listen():
|
245 | 260 | # logger.info(f"Received message {message=}")
|
@@ -312,5 +327,12 @@ async def _process_model_messages(self) -> None:
|
312 | 327 | pass
|
313 | 328 | case RateLimitsUpdated():
|
314 | 329 | pass
|
| 330 | + case ResponseFunctionCallArgumentsDone(): |
| 331 | + asyncio.create_task( |
| 332 | + self.handle_funtion_call(message) |
| 333 | + ) |
| 334 | + case ResponseFunctionCallArgumentsDelta(): |
| 335 | + pass |
| 336 | + |
315 | 337 | case _:
|
316 | 338 | logger.warning(f"Unhandled message {message=}")
|
0 commit comments