Skip to content

Commit 623f3dd

Browse files
GLM-4.5 tool calling: Complete implementation with schema-aware parsing
- Added GLM-4.5 chat format with native XML template support - Implemented schema-aware type conversion for tool arguments - Added JSON unpacking logic for single-argument cases - Enhanced server integration with tools_schema support - Added comprehensive test coverage and documentation - Fixed template polyfill conflicts and type conversion issues
1 parent 45fac2a commit 623f3dd

File tree

7 files changed

+298
-37
lines changed

7 files changed

+298
-37
lines changed

.github/chatmodes/custom_models.chatmode.md

Whitespace-only changes.

common/chat.cpp

Lines changed: 171 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,63 +1330,197 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
13301330
}
13311331

13321332
static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) {
1333+
LOG_INF("%s: initializing GLM-4.5 chat params\n", __func__);
13331334
common_chat_params data;
1334-
data.prompt = apply(tmpl, inputs);
1335+
1336+
// Configure template inputs
1337+
minja::chat_template_inputs tmpl_inputs;
1338+
tmpl_inputs.messages = inputs.messages;
1339+
tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools;
1340+
tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
1341+
tmpl_inputs.extra_context = inputs.extra_context;
1342+
tmpl_inputs.now = inputs.now; // Use the consistent timestamp from params
1343+
1344+
// Configure template options to disable polyfills and enforce native XML format
1345+
minja::chat_template_options opts;
1346+
opts.apply_polyfills = false; // Hard disable all polyfills
1347+
1348+
// The prompt is generated here
1349+
data.prompt = tmpl.apply(tmpl_inputs, opts);
13351350
data.format = COMMON_CHAT_FORMAT_GLM_4_5;
1351+
1352+
data.preserved_tokens = {
1353+
"<|system|>", "<|assistant|>", "<|observation|>",
1354+
"<tool_call>", "</tool_call>", "<arg_key>", "</arg_key>",
1355+
"<arg_value>", "</arg_value>", "<think>", "</think>",
1356+
"<tool_response>", "</tool_response>",
1357+
};
1358+
1359+
// Store tools schema for type-aware parsing later
1360+
data.tools_schema = inputs.tools;
1361+
1362+
LOG_INF("%s: GLM-4.5 native XML format enforced\n", __func__);
13361363
return data;
13371364
}
13381365

13391366
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
1340-
builder.consume_spaces();
1341-
builder.try_parse_reasoning("<think>", "</think>");
1342-
if (!builder.syntax().parse_tool_calls) {
1343-
builder.add_content(builder.consume_rest());
1344-
return;
1345-
}
1346-
1347-
// GLM 4.5 uses format: <tool_call>function_name\n<arg_key>key</arg_key>\n<arg_value>value</arg_value>\n</tool_call>
1348-
static const common_regex tool_call_start("<tool_call>([^\n<]+)");
1349-
static const common_regex arg_key_regex("<arg_key>([^<]+)</arg_key>");
1350-
static const common_regex arg_value_regex("<arg_value>([\\s\\S]*?)</arg_value>");
1351-
static const common_regex tool_call_end("</tool_call>");
1367+
1368+
auto get_expected_type = [&](const std::string& tool_name, const std::string& param_name) -> std::string {
1369+
// Access tools schema from builder syntax
1370+
const auto& tools_schema = builder.syntax().tools_schema;
1371+
if (tools_schema.is_array()) {
1372+
for (const auto& tool : tools_schema) {
1373+
if (tool.contains("function") && tool["function"]["name"] == tool_name) {
1374+
auto params = tool["function"]["parameters"];
1375+
if (params.contains("properties") && params["properties"].contains(param_name)) {
1376+
return params["properties"][param_name].value("type", "string");
1377+
}
1378+
}
1379+
}
1380+
}
1381+
return "string"; // Default fallback
1382+
};
13521383

1353-
while (auto res = builder.try_find_regex(tool_call_start)) {
1354-
// Move to the start of the tool call and consume it
1355-
builder.move_to(res->groups[0].begin);
1356-
builder.consume_regex(tool_call_start);
1384+
auto handle_tool_call_end = [&] (common_chat_msg_parser & builder, auto end_pos) {
1385+
builder.move_to(end_pos);
1386+
builder.consume_literal("</tool_call>");
13571387

1358-
std::string function_name = builder.str(res->groups[1]);
1359-
json arguments = json::object();
1388+
size_t obs_pos = builder.input().find("<|observation|>", builder.pos());
1389+
if (obs_pos != std::string::npos) {
1390+
if (obs_pos > builder.pos()) {
1391+
std::string content = builder.input().substr(builder.pos(), obs_pos - builder.pos());
1392+
builder.add_content(content);
1393+
}
1394+
1395+
builder.move_to(obs_pos);
1396+
builder.consume_literal("<|observation|>");
1397+
} else {
1398+
std::string remaining = builder.consume_rest();
1399+
if (!remaining.empty()) builder.add_content(remaining);
1400+
}
1401+
};
1402+
1403+
builder.consume_spaces();
1404+
1405+
builder.try_parse_reasoning("<think>", "</think>");
1406+
1407+
size_t curr_pos = builder.pos();
1408+
while (builder.input().find("<tool_call>", builder.pos()) != std::string::npos) {
1409+
size_t tool_call_start = builder.input().find("<tool_call>", builder.pos());
1410+
if (tool_call_start > builder.pos()) {
1411+
std::string content = builder.input().substr(builder.pos(), tool_call_start - builder.pos());
1412+
builder.add_content(content);
1413+
}
13601414

1415+
size_t tool_call_end = builder.input().find("</tool_call>", tool_call_start);
1416+
if (tool_call_end == std::string::npos) return;
1417+
1418+
builder.move_to(tool_call_start);
1419+
builder.consume_literal("<tool_call>");
13611420
builder.consume_spaces();
1362-
1363-
// Parse all arg_key/arg_value pairs
1364-
while (auto key_res = builder.try_consume_regex(arg_key_regex)) {
1365-
std::string key = builder.str(key_res->groups[1]);
1366-
builder.consume_spaces();
1421+
1422+
size_t arg_key_start = builder.input().find("<arg_key>", tool_call_start);
1423+
if (arg_key_start == std::string::npos || arg_key_start > tool_call_end) {
1424+
std::string function_content = builder.input().substr(builder.pos(), tool_call_end - builder.pos());
1425+
std::string function_name = string_strip(function_content);
1426+
1427+
if (!builder.add_tool_call(function_name, "", "{}")) {
1428+
LOG_INF("%s: failed to add tool call\n", __func__);
1429+
}
1430+
1431+
handle_tool_call_end(builder, tool_call_end);
1432+
1433+
} else {
1434+
std::string function_content = builder.input().substr(builder.pos(), arg_key_start - builder.pos());
1435+
std::string function_name = string_strip(function_content);
1436+
1437+
json args_json = json::object();
1438+
builder.move_to(arg_key_start);
13671439

1368-
if (auto value_res = builder.try_consume_regex(arg_value_regex)) {
1369-
std::string value = builder.str(value_res->groups[1]);
1370-
arguments[key] = value;
1440+
while (builder.pos() < tool_call_end && builder.input().substr(builder.pos()).find("<arg_key>") == 0) {
1441+
if (!builder.try_consume_literal("<arg_key>")) break;
1442+
1443+
auto key_close = builder.try_find_literal("</arg_key>");
1444+
if (!key_close || key_close->groups[0].end > tool_call_end) {
1445+
throw common_chat_msg_partial_exception("incomplete tool call");
1446+
return;
1447+
}
1448+
1449+
std::string key = string_strip(key_close->prelude);
1450+
1451+
builder.consume_spaces();
1452+
1453+
if (!builder.try_consume_literal("<arg_value>")) {
1454+
throw common_chat_msg_partial_exception("incomplete tool call");
1455+
return;
1456+
}
1457+
1458+
auto value_close = builder.try_find_literal("</arg_value>");
1459+
if (!value_close || value_close->groups[0].end > tool_call_end) {
1460+
throw common_chat_msg_partial_exception("incomplete tool call");
1461+
return;
1462+
}
1463+
1464+
std::string value = string_strip(value_close->prelude);
1465+
1466+
// Schema-aware type conversion
1467+
std::string expected_type = get_expected_type(function_name, key);
1468+
json parsed_value;
1469+
1470+
if (expected_type == "array" || expected_type == "object") {
1471+
try {
1472+
parsed_value = json::parse(value);
1473+
} catch (...) {
1474+
parsed_value = value;
1475+
}
1476+
} else {
1477+
// For all other types, store as string and let the unpacking logic handle it
1478+
parsed_value = value;
1479+
}
1480+
1481+
args_json[key] = parsed_value;
13711482
builder.consume_spaces();
1483+
}
1484+
1485+
if (args_json.size() == 1) {
1486+
const auto key = args_json.begin().key();
1487+
auto& value = args_json.begin().value();
1488+
1489+
if (value.is_string()) {
1490+
try {
1491+
json unpacked_json = json::parse(value.get<std::string>());
1492+
if (unpacked_json.is_object()) {
1493+
args_json = unpacked_json;
1494+
}
1495+
} catch (const std::exception&) {
1496+
// Not a valid JSON string, proceed as normal
1497+
}
1498+
}
1499+
}
1500+
1501+
if (!builder.add_tool_call(function_name, "", args_json.dump())) {
1502+
LOG_INF("%s: failed to add tool call with arguments\n", __func__);
13721503
} else {
1373-
throw common_chat_msg_partial_exception("Expected <arg_value> after <arg_key>");
1504+
LOG_INF("%s: successfully added tool call with arguments\n", __func__);
13741505
}
1506+
1507+
handle_tool_call_end(builder, tool_call_end);
13751508
}
1376-
1377-
// Consume closing tag
1378-
builder.consume_regex(tool_call_end);
1379-
builder.consume_spaces();
1380-
1381-
// Add the parsed tool call
1382-
if (!builder.add_tool_call(function_name, "", arguments.dump())) {
1383-
throw common_chat_msg_partial_exception("Failed to add GLM tool call");
1509+
1510+
if (curr_pos == builder.pos()) {
1511+
// No progress made, avoid infinite loop
1512+
LOG_INF("%s: no progress in parsing, stopping to avoid infinite loop\n", __func__);
1513+
break;
13841514
}
1515+
curr_pos = builder.pos();
13851516
}
13861517

1387-
builder.add_content(builder.consume_rest());
1518+
if (builder.pos() < builder.input().size()) {
1519+
builder.add_content(builder.consume_rest());
1520+
}
13881521
}
13891522

1523+
13901524
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
13911525
LOG_DBG("%s\n", __func__);
13921526
common_chat_params data;

common/chat.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#pragma once
44

55
#include "common.h"
6+
#include <nlohmann/json.hpp>
67
#include <functional>
78
#include <chrono>
89
#include <string>
@@ -142,6 +143,7 @@ struct common_chat_params {
142143
std::vector<common_grammar_trigger> grammar_triggers;
143144
std::vector<std::string> preserved_tokens;
144145
std::vector<std::string> additional_stops;
146+
nlohmann::ordered_json tools_schema = nlohmann::ordered_json(); // Schema for tools to pass to parser
145147
};
146148

147149
struct common_chat_syntax {
@@ -151,6 +153,7 @@ struct common_chat_syntax {
151153
bool reasoning_in_content = false;
152154
bool thinking_forced_open = false;
153155
bool parse_tool_calls = true;
156+
nlohmann::ordered_json tools_schema = nlohmann::ordered_json(); // Schema for tools to enable type-aware parsing
154157
};
155158

156159
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid

models/templates/glm_4_5.jinja

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
[gMASK]<sop>
2+
{%- if tools -%}
3+
<|system|>
4+
# Tools
5+
6+
You may call one or more functions to assist with the user query.
7+
8+
You are provided with function signatures within <tools></tools> XML tags:
9+
<tools>
10+
{% for tool in tools %}
11+
{{ tool | tojson }}
12+
{% endfor %}
13+
</tools>
14+
15+
For each function call, output the function name and arguments within the following XML format:
16+
<tool_call>{function-name}
17+
<arg_key>{arg-key-1}</arg_key>
18+
<arg_value>{arg-value-1}</arg_value>
19+
<arg_key>{arg-key-2}</arg_key>
20+
<arg_value>{arg-value-2}</arg_value>
21+
...
22+
</tool_call>{%- endif -%}
23+
{%- macro visible_text(content) -%}
24+
{%- if content is string -%}
25+
{{- content }}
26+
{%- elif content is iterable and content is not mapping -%}
27+
{%- for item in content -%}
28+
{%- if item is mapping and item.type == 'text' -%}
29+
{{- item.text }}
30+
{%- elif item is string -%}
31+
{{- item }}
32+
{%- endif -%}
33+
{%- endfor -%}
34+
{%- else -%}
35+
{{- content }}
36+
{%- endif -%}
37+
{%- endmacro -%}
38+
{%- set ns = namespace(last_user_index=-1) %}
39+
{%- for m in messages %}
40+
{%- if m.role == 'user' %}
41+
{% set ns.last_user_index = loop.index0 -%}
42+
{%- endif %}
43+
{%- endfor %}
44+
{% for m in messages %}
45+
{%- if m.role == 'user' -%}<|user|>
46+
{%- set user_content = visible_text(m.content) -%}
47+
{{ user_content }}
48+
{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not user_content.endswith("/nothink")) else '' -}}
49+
{%- elif m.role == 'assistant' -%}
50+
<|assistant|>
51+
{%- set reasoning_content = '' %}
52+
{%- set content = visible_text(m.content) %}
53+
{%- if m.reasoning_content is string %}
54+
{%- set reasoning_content = m.reasoning_content %}
55+
{%- else %}
56+
{%- if '</think>' in content %}
57+
{%- set parts = content.split('</think>') -%}
58+
{%- set before_first_close = parts | first -%}
59+
{%- set inner_parts = before_first_close.rstrip('\n').split('<think>') -%}
60+
{%- set extracted_reasoning = inner_parts | last -%}
61+
{%- set reasoning_content = extracted_reasoning.lstrip('\n') -%}
62+
{%- set after_last_close = parts | last -%}
63+
{%- set content = after_last_close.lstrip('\n') -%}
64+
{%- endif %}
65+
{%- endif %}
66+
{%- if loop.index0 > ns.last_user_index and reasoning_content -%}
67+
{{ '\n<think>' + reasoning_content.strip() + '</think>'}}
68+
{%- else -%}
69+
{{ '\n<think></think>' }}
70+
{%- endif -%}
71+
{%- if content.strip() -%}
72+
{{ '\n' + content.strip() }}
73+
{%- endif -%}
74+
{% if m.tool_calls %}
75+
{% for tc in m.tool_calls %}
76+
{%- if tc.function %}
77+
{%- set tc_obj = tc.function %}
78+
{%- else %}
79+
{%- set tc_obj = tc %}
80+
{%- endif %}
81+
{{ '\n<tool_call>' + tc_obj.name }}
82+
{%- if tc_obj.arguments is mapping -%}
83+
{%- for k, v in tc_obj.arguments.items() -%}
84+
85+
<arg_key>{{ k }}</arg_key>
86+
<arg_value>{{ v | tojson if v is not string else v }}</arg_value>
87+
{%- endfor -%}
88+
{%- else -%}
89+
{#- Arguments came as string - this shouldn't happen with polyfills disabled -#}
90+
{#- Output as single argument for debugging -#}
91+
92+
<arg_key>raw_arguments</arg_key>
93+
<arg_value>{{ tc_obj.arguments }}</arg_value>
94+
{%- endif -%}
95+
</tool_call>{% endfor %}
96+
{% endif %}
97+
{%- elif m.role == 'tool' -%}
98+
{%- if m.content is string -%}
99+
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
100+
{{- '<|observation|>' }}
101+
{%- endif %}
102+
{{- '\n<tool_response>\n' }}
103+
{{- m.content }}
104+
{{- '\n</tool_response>' }}
105+
{%- else -%}
106+
<|observation|>{% for tr in m.content %}
107+
108+
<tool_response>
109+
{{ tr.output if tr.output is defined else tr }}
110+
</tool_response>{% endfor -%}
111+
{% endif -%}
112+
{%- elif m.role == 'system' -%}
113+
<|system|>
114+
{{ visible_text(m.content) }}
115+
{%- endif -%}
116+
{%- endfor -%}
117+
{%- if add_generation_prompt -%}
118+
<|assistant|>{{- '\n<think></think>' if (enable_thinking is defined and not enable_thinking) else '' -}}
119+
{%- endif -%}

tests/test-chat.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
294294
common_chat_syntax syntax;
295295
syntax.format = data.params.format;
296296
syntax.reasoning_format = reasoning_format;
297+
syntax.tools_schema = data.params.tools_schema;
297298
const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, syntax);
298299
assert_msg_equals(test_message, msg);
299300
}

tools/server/server.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ struct server_task {
387387
params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (params_base.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
388388
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
389389
params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
390+
params.oaicompat_chat_syntax.tools_schema = json_value(data, "tools_schema", nlohmann::ordered_json());
390391
}
391392

392393
{

0 commit comments

Comments
 (0)