Skip to content

Commit 97b3e56

Browse files
authored
Merge pull request #11 from nitin4real/fix/tools
fix: tools support issue
2 parents 4aee22d + 9f9f2ad commit 97b3e56

File tree

4 files changed

+76
-11
lines changed

4 files changed

+76
-11
lines changed

realtime_agent/agent.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from agora_realtime_ai_api.rtc import Channel, ChatMessage, RtcEngine, RtcOptions
1212

1313
from .logger import setup_logger
14-
from .realtime.struct import InputAudioBufferCommitted, InputAudioBufferSpeechStarted, InputAudioBufferSpeechStopped, InputAudioTranscription, ItemCreated, ItemInputAudioTranscriptionCompleted, RateLimitsUpdated, ResponseAudioDelta, ResponseAudioDone, ResponseAudioTranscriptDelta, ResponseAudioTranscriptDone, ResponseContentPartAdded, ResponseContentPartDone, ResponseCreated, ResponseDone, ResponseOutputItemAdded, ResponseOutputItemDone, ServerVADUpdateParams, SessionUpdate, SessionUpdateParams, SessionUpdated, Voices, to_json
14+
from .realtime.struct import FunctionCallOutputItemParam, InputAudioBufferCommitted, InputAudioBufferSpeechStarted, InputAudioBufferSpeechStopped, InputAudioTranscription, ItemCreate, ItemCreated, ItemInputAudioTranscriptionCompleted, RateLimitsUpdated, ResponseAudioDelta, ResponseAudioDone, ResponseAudioTranscriptDelta, ResponseAudioTranscriptDone, ResponseContentPartAdded, ResponseContentPartDone, ResponseCreate, ResponseCreated, ResponseDone, ResponseFunctionCallArgumentsDelta, ResponseFunctionCallArgumentsDone, ResponseOutputItemAdded, ResponseOutputItemDone, ServerVADUpdateParams, SessionUpdate, SessionUpdateParams, SessionUpdated, Voices, to_json
1515
from .realtime.connection import RealtimeApiConnection
1616
from .tools import ClientToolCallResponse, ToolContext
1717
from .utils import PCMWriter
@@ -240,6 +240,21 @@ async def model_to_rtc(self) -> None:
240240
await pcm_writer.flush()
241241
raise # Re-raise the cancelled exception to properly exit the task
242242

243+
async def handle_funtion_call(self, message: ResponseFunctionCallArgumentsDone) -> None:
244+
function_call_response = await self.tools.execute_tool(message.name, message.arguments)
245+
logger.info(f"Function call response: {function_call_response}")
246+
await self.connection.send_request(
247+
ItemCreate(
248+
item = FunctionCallOutputItemParam(
249+
call_id=message.call_id,
250+
output=function_call_response.json_encoded_output
251+
)
252+
)
253+
)
254+
await self.connection.send_request(
255+
ResponseCreate()
256+
)
257+
243258
async def _process_model_messages(self) -> None:
244259
async for message in self.connection.listen():
245260
# logger.info(f"Received message {message=}")
@@ -312,5 +327,12 @@ async def _process_model_messages(self) -> None:
312327
pass
313328
case RateLimitsUpdated():
314329
pass
330+
case ResponseFunctionCallArgumentsDone():
331+
asyncio.create_task(
332+
self.handle_funtion_call(message)
333+
)
334+
case ResponseFunctionCallArgumentsDelta():
335+
pass
336+
315337
case _:
316338
logger.warning(f"Unhandled message {message=}")

realtime_agent/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from dotenv import load_dotenv
1010
from pydantic import BaseModel, Field, ValidationError
1111

12+
from realtime_agent.realtime.tools_example import AgentTools
13+
1214
from .realtime.struct import PCM_CHANNELS, PCM_SAMPLE_RATE, ServerVADUpdateParams, Voices
1315

1416
from .agent import InferenceConfig, RealtimeKitAgent
@@ -82,6 +84,7 @@ def run_agent_in_process(
8284
),
8385
inference_config=inference_config,
8486
tools=None,
87+
# tools=AgentTools() # tools example, replace with this line
8588
)
8689
)
8790

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
2+
from typing import Any
3+
from realtime_agent.tools import ToolContext
4+
5+
# Function calling Example
6+
# This is an example of how to add a new function to the agent tools.
7+
8+
class AgentTools(ToolContext):
9+
def __init__(self) -> None:
10+
super().__init__()
11+
12+
# create multiple functions here as per requirement
13+
self.register_function(
14+
name="get_avg_temp",
15+
description="Returns average temperature of a country",
16+
parameters={
17+
"type": "object",
18+
"properties": {
19+
"country": {
20+
"type": "string",
21+
"description": "Name of country",
22+
},
23+
},
24+
"required": ["country"],
25+
},
26+
fn=self._get_avg_temperature_by_country_name,
27+
)
28+
29+
async def _get_avg_temperature_by_country_name(
30+
self,
31+
country: str,
32+
) -> dict[str, Any]:
33+
try:
34+
result = "24 degree C" # Dummy data (Get the Required value here, like a DB call or API call)
35+
return {
36+
"status": "success",
37+
"message": f"Average temperature of {country} is {result}",
38+
"result": result,
39+
}
40+
except Exception as e:
41+
return {
42+
"status": "error",
43+
"message": f"Failed to get : {str(e)}",
44+
}

realtime_agent/tools.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@ class LocalFunctionToolDeclaration:
2424
def model_description(self) -> dict[str, Any]:
2525
return {
2626
"type": "function",
27-
"function": {
28-
"name": self.name,
29-
"description": self.description,
30-
"parameters": self.parameters,
31-
},
27+
"name": self.name,
28+
"description": self.description,
29+
"parameters": self.parameters,
3230
}
3331

3432

@@ -43,11 +41,9 @@ class PassThroughFunctionToolDeclaration:
4341
def model_description(self) -> dict[str, Any]:
4442
return {
4543
"type": "function",
46-
"function": {
47-
"name": self.name,
48-
"description": self.description,
49-
"parameters": self.parameters,
50-
},
44+
"name": self.name,
45+
"description": self.description,
46+
"parameters": self.parameters,
5147
}
5248

5349

0 commit comments

Comments
 (0)