@@ -173,10 +173,92 @@ public static void forceCallFunctionAdd()
173
173
System .out .println (JsonUtils .toJson (result ));
174
174
}
175
175
176
+ public static void parallelToolCalls ()
177
+ throws NoApiKeyException , ApiException , InputRequiredException {
178
+ // create jsonschema generator
179
+ SchemaGeneratorConfigBuilder configBuilder =
180
+ new SchemaGeneratorConfigBuilder (SchemaVersion .DRAFT_2020_12 , OptionPreset .PLAIN_JSON );
181
+ SchemaGeneratorConfig config = configBuilder .with (Option .EXTRA_OPEN_API_FORMAT_VALUES )
182
+ .without (Option .FLATTENED_ENUMS_FROM_TOSTRING ).build ();
183
+ SchemaGenerator generator = new SchemaGenerator (config );
184
+
185
+ // generate jsonSchema of function.
186
+ ObjectNode jsonSchema = generator .generateSchema (AddFunctionTool .class );
187
+
188
+ // call with tools of function call, jsonSchema.toString() is jsonschema String.
189
+ FunctionDefinition fd = FunctionDefinition .builder ().name ("add" ).description ("add two number" )
190
+ .parameters (JsonUtils .parseString (jsonSchema .toString ()).getAsJsonObject ()).build ();
191
+
192
+ // build system message
193
+ Message systemMsg = Message .builder ().role (Role .SYSTEM .getValue ())
194
+ .content ("You are a helpful assistant. When asked a question, use tools wherever possible." )
195
+ .build ();
196
+
197
+ // user message to call function.
198
+ Message userMsg =
199
+ Message .builder ().role (Role .USER .getValue ()).content ("Add 1234 and 4321, Add 2345 and 5432" ).build ();
200
+
201
+ // messages to store message request and response.
202
+ List <Message > messages = new ArrayList <>();
203
+ messages .addAll (Arrays .asList (systemMsg , userMsg ));
204
+
205
+ ToolFunction toolFunction =
206
+ ToolFunction .builder ().function (FunctionDefinition .builder ().name ("add" ).build ()).build ();
207
+ // create generation call parameter
208
+ GenerationParam param = GenerationParam .builder ().model (Generation .Models .QWEN_MAX )
209
+ .messages (messages ).resultFormat (ResultFormat .MESSAGE ).toolChoice (toolFunction )
210
+ .tools (Arrays .asList (ToolFunction .builder ().function (fd ).build ()))
211
+ .parallelToolCalls (true )
212
+ .build ();
213
+
214
+ // call the Generation
215
+ Generation gen = new Generation ();
216
+ GenerationResult result = gen .call (param );
217
+ // print the result.
218
+ System .out .println (JsonUtils .toJson (result ));
219
+
220
+ // process the response
221
+ for (Choice choice : result .getOutput ().getChoices ()) {
222
+ // add the assistant message to list for next Generation call.
223
+ messages .add (choice .getMessage ());
224
+ // check if we need call tool.
225
+ if (result .getOutput ().getChoices ().get (0 ).getMessage ().getToolCalls () != null ) {
226
+ // iterator the tool calls
227
+ for (ToolCallBase toolCall : result .getOutput ().getChoices ().get (0 ).getMessage ()
228
+ .getToolCalls ()) {
229
+ // get function call.
230
+ if (toolCall .getType ().equals ("function" )) {
231
+ // get function call name and argument, both String.
232
+ String functionName = ((ToolCallFunction ) toolCall ).getFunction ().getName ();
233
+ String functionArgument = ((ToolCallFunction ) toolCall ).getFunction ().getArguments ();
234
+ if (functionName .equals ("add" )) {
235
+ // Create the function object.
236
+ AddFunctionTool addFunction =
237
+ JsonUtils .fromJson (functionArgument , AddFunctionTool .class );
238
+ // call function.
239
+ int sum = addFunction .call ();
240
+ // create the tool message
241
+ Message toolResultMessage = Message .builder ().role ("tool" )
242
+ .content (String .valueOf (sum )).toolCallId (toolCall .getId ()).build ();
243
+ // add the tool message to messages list.
244
+ messages .add (toolResultMessage );
245
+ System .out .println (sum );
246
+ }
247
+ }
248
+ }
249
+ }
250
+ }
251
+ // new Generation call with messages include tool result.
252
+ param .setMessages (messages );
253
+ result = gen .call (param );
254
+ System .out .println (JsonUtils .toJson (result ));
255
+ }
256
+
176
257
public static void main (String [] args ) {
177
258
try {
178
- disableToolCall ();
179
- forceCallFunctionAdd ();
259
+ // disableToolCall();
260
+ // forceCallFunctionAdd();
261
+ parallelToolCalls ();
180
262
} catch (ApiException | NoApiKeyException | InputRequiredException e ) {
181
263
System .out .println (String .format ("Exception %s" , e .getMessage ()));
182
264
}
0 commit comments