Skip to content

Commit 043c794

Browse files
committed
feat: add model parameters
1 parent 1e4d353 commit 043c794

File tree

3 files changed

+94
-4
lines changed

3 files changed

+94
-4
lines changed

samples/GenerationCallWithThinking.java

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public static void streamCallWithThinking()
3333
.thinkingBudget(10000)
3434
.logprobs(true)
3535
.topLogprobs(1)
36+
.n(2)
3637
.messages(Arrays.asList(systemMsg, userMsg))
3738
.resultFormat(GenerationParam.ResultFormat.MESSAGE)
3839
.incrementalOutput(true)

samples/GenerationToolChoice.java

+84-2
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,92 @@ public static void forceCallFunctionAdd()
173173
System.out.println(JsonUtils.toJson(result));
174174
}
175175

176+
public static void parallelToolCalls()
177+
throws NoApiKeyException, ApiException, InputRequiredException {
178+
// create jsonschema generator
179+
SchemaGeneratorConfigBuilder configBuilder =
180+
new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12, OptionPreset.PLAIN_JSON);
181+
SchemaGeneratorConfig config = configBuilder.with(Option.EXTRA_OPEN_API_FORMAT_VALUES)
182+
.without(Option.FLATTENED_ENUMS_FROM_TOSTRING).build();
183+
SchemaGenerator generator = new SchemaGenerator(config);
184+
185+
// generate jsonSchema of function.
186+
ObjectNode jsonSchema = generator.generateSchema(AddFunctionTool.class);
187+
188+
// call with tools of function call, jsonSchema.toString() is jsonschema String.
189+
FunctionDefinition fd = FunctionDefinition.builder().name("add").description("add two number")
190+
.parameters(JsonUtils.parseString(jsonSchema.toString()).getAsJsonObject()).build();
191+
192+
// build system message
193+
Message systemMsg = Message.builder().role(Role.SYSTEM.getValue())
194+
.content("You are a helpful assistant. When asked a question, use tools wherever possible.")
195+
.build();
196+
197+
// user message to call function.
198+
Message userMsg =
199+
Message.builder().role(Role.USER.getValue()).content("Add 1234 and 4321, Add 2345 and 5432").build();
200+
201+
// messages to store message request and response.
202+
List<Message> messages = new ArrayList<>();
203+
messages.addAll(Arrays.asList(systemMsg, userMsg));
204+
205+
ToolFunction toolFunction =
206+
ToolFunction.builder().function(FunctionDefinition.builder().name("add").build()).build();
207+
// create generation call parameter
208+
GenerationParam param = GenerationParam.builder().model(Generation.Models.QWEN_MAX)
209+
.messages(messages).resultFormat(ResultFormat.MESSAGE).toolChoice(toolFunction)
210+
.tools(Arrays.asList(ToolFunction.builder().function(fd).build()))
211+
.parallelToolCalls(true)
212+
.build();
213+
214+
// call the Generation
215+
Generation gen = new Generation();
216+
GenerationResult result = gen.call(param);
217+
// print the result.
218+
System.out.println(JsonUtils.toJson(result));
219+
220+
// process the response
221+
for (Choice choice : result.getOutput().getChoices()) {
222+
// add the assistant message to list for next Generation call.
223+
messages.add(choice.getMessage());
224+
// check if we need call tool.
225+
if (result.getOutput().getChoices().get(0).getMessage().getToolCalls() != null) {
226+
// iterator the tool calls
227+
for (ToolCallBase toolCall : result.getOutput().getChoices().get(0).getMessage()
228+
.getToolCalls()) {
229+
// get function call.
230+
if (toolCall.getType().equals("function")) {
231+
// get function call name and argument, both String.
232+
String functionName = ((ToolCallFunction) toolCall).getFunction().getName();
233+
String functionArgument = ((ToolCallFunction) toolCall).getFunction().getArguments();
234+
if (functionName.equals("add")) {
235+
// Create the function object.
236+
AddFunctionTool addFunction =
237+
JsonUtils.fromJson(functionArgument, AddFunctionTool.class);
238+
// call function.
239+
int sum = addFunction.call();
240+
// create the tool message
241+
Message toolResultMessage = Message.builder().role("tool")
242+
.content(String.valueOf(sum)).toolCallId(toolCall.getId()).build();
243+
// add the tool message to messages list.
244+
messages.add(toolResultMessage);
245+
System.out.println(sum);
246+
}
247+
}
248+
}
249+
}
250+
}
251+
// new Generation call with messages include tool result.
252+
param.setMessages(messages);
253+
result = gen.call(param);
254+
System.out.println(JsonUtils.toJson(result));
255+
}
256+
176257
public static void main(String[] args) {
177258
try {
178-
disableToolCall();
179-
forceCallFunctionAdd();
259+
// disableToolCall();
260+
// forceCallFunctionAdd();
261+
parallelToolCalls();
180262
} catch (ApiException | NoApiKeyException | InputRequiredException e) {
181263
System.out.println(String.format("Exception %s", e.getMessage()));
182264
}

src/main/java/com/alibaba/dashscope/aigc/generation/GenerationParam.java

+9-2
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,13 @@ public static class ResultFormat {
129129
private Integer thinkingBudget;
130130

131131
/** 返回每个输出token的对数概率 */
132-
Boolean logprobs;
132+
private Boolean logprobs;
133133

134134
/** 指定在每个token位置返回的最可能token的数量 */
135-
Integer topLogprobs;
135+
private Integer topLogprobs;
136+
137+
/** 生成响应的个数 */
138+
private Integer n;
136139

137140
@Override
138141
public JsonObject getInput() {
@@ -233,6 +236,10 @@ public Map<String, Object> getParameters() {
233236
params.put("top_logprobs", topLogprobs);
234237
}
235238

239+
if (n != null) {
240+
params.put("n", n);
241+
}
242+
236243
params.putAll(parameters);
237244
return params;
238245
}

0 commit comments

Comments
 (0)