Skip to content

Commit 22e3bda

Browse files
committed
semi working tool calling
1 parent bd18c3d commit 22e3bda

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

predictionguard/src/chat.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,34 @@ def __init__(self, api_key, url):
6969
def create(
7070
self,
7171
model: str,
72-
messages: Union[str, List[Dict[str, Any]]],
72+
messages: Union[
73+
str, List[
74+
Dict[str, Any]
75+
]
76+
],
7377
input: Optional[Dict[str, Any]] = None,
7478
output: Optional[Dict[str, Any]] = None,
7579
frequency_penalty: Optional[float] = None,
76-
logit_bias: Optional[Dict[str, int]] = None,
80+
logit_bias: Optional[
81+
Dict[str, int]
82+
] = None,
7783
max_completion_tokens: Optional[int] = 100,
7884
max_tokens: Optional[int] = None,
85+
parallel_tool_calls: Optional[bool] = None,
7986
presence_penalty: Optional[float] = None,
80-
stop: Optional[Union[str, List[str]]] = None,
87+
stop: Optional[
88+
Union[
89+
str, List[str]
90+
]
91+
] = None,
8192
stream: Optional[bool] = False,
8293
temperature: Optional[float] = 1.0,
94+
tool_choice: Optional[Union[
95+
str, Dict[
96+
str, Dict[str, str]
97+
]
98+
]] = None,
99+
tools: Optional[List[Dict[str, Union[str, Dict[str, str]]]]] = None,
83100
top_p: Optional[float] = 0.99,
84101
top_k: Optional[float] = 50,
85102
) -> Dict[str, Any]:
@@ -93,10 +110,13 @@ def create(
93110
:param frequency_penalty: The frequency penalty to use.
94111
:param logit_bias: The logit bias to use.
95112
:param max_completion_tokens: The maximum amount of tokens the model should return.
113+
:param parallel_tool_calls: Whether the model may invoke multiple tool calls in parallel.
96114
:param presence_penalty: The presence penalty to use.
97115
:param stop: The completion stopping criteria.
98116
:param stream: Option to stream the API response
99117
:param temperature: The consistency of the model responses to the same prompt. The higher the more consistent.
118+
:param tool_choice: Controls which (if any) tool the model calls; a string mode or a dict naming a specific tool.
119+
:param tools: A list of tool definitions the model may call.
100120
:param top_p: The sampling for the model to use.
101121
:param top_k: The Top-K sampling for the model to use.
102122
:return: A dictionary containing the chat response.
@@ -121,10 +141,13 @@ def create(
121141
frequency_penalty,
122142
logit_bias,
123143
max_completion_tokens,
124-
temperature,
144+
parallel_tool_calls,
125145
presence_penalty,
126146
stop,
127147
stream,
148+
temperature,
149+
tool_choice,
150+
tools,
128151
top_p,
129152
top_k
130153
)
@@ -143,10 +166,13 @@ def _generate_chat(
143166
frequency_penalty,
144167
logit_bias,
145168
max_completion_tokens,
169+
parallel_tool_calls,
146170
presence_penalty,
147171
stop,
148172
stream,
149173
temperature,
174+
tool_choice,
175+
tools,
150176
top_p,
151177
top_k,
152178
):
@@ -276,10 +302,13 @@ def stream_generator(url, headers, payload, stream):
276302
"frequency_penalty": frequency_penalty,
277303
"logit_bias": logit_bias,
278304
"max_completion_tokens": max_completion_tokens,
305+
"parallel_tool_calls": parallel_tool_calls,
279306
"presence_penalty": presence_penalty,
280307
"stop": stop,
281308
"stream": stream,
282309
"temperature": temperature,
310+
"tool_choice": tool_choice,
311+
"tools": tools,
283312
"top_p": top_p,
284313
"top_k": top_k,
285314
}

tests/test_chat.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,21 @@ def test_chat_completions_create_vision_stream_fail():
187187
response_list.append(res)
188188

189189

190+
def test_chat_completions_create_tool_call():
191+
test_client = PredictionGuard()
192+
193+
response = test_client.chat.completions.create(
194+
model=os.environ["TEST_MODEL_NAME"],
195+
messages=[
196+
{"role": "system", "content": "You are a helpful chatbot."},
197+
{"role": "user", "content": "Tell me a joke."},
198+
],
199+
200+
)
201+
202+
assert len(response["choices"][0]["message"]["content"]) > 0
203+
204+
190205
def test_chat_completions_list_models():
191206
test_client = PredictionGuard()
192207

0 commit comments

Comments
 (0)