Skip to content

Commit bcaaab6

Browse files
added preserved tokens and <|observation|> handling
Improved overall stability of code
1 parent 45fac2a commit bcaaab6

File tree

1 file changed

+150
-36
lines changed

1 file changed

+150
-36
lines changed

common/chat.cpp

Lines changed: 150 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,60 +1331,174 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
13311331

13321332
static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) {
13331333
common_chat_params data;
1334+
13341335
data.prompt = apply(tmpl, inputs);
13351336
data.format = COMMON_CHAT_FORMAT_GLM_4_5;
1337+
1338+
data.preserved_tokens = {
1339+
"<|system|>", "<|assistant|>", "<|observation|>",
1340+
"<tool_call>", "</tool_call>", "<arg_key>", "</arg_key>",
1341+
"<arg_value>", "</arg_value>", "<think>", "</think>",
1342+
"<tool_response>", "</tool_response>",
1343+
};
1344+
1345+
// Todo @bruce XML grammer for GLM codes
1346+
// Step 3: Grammar for tools (if any)
1347+
// if (inputs.tools.is_array() && !inputs.tools.empty()) {
1348+
// data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1349+
1350+
// data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1351+
// // Define flexible arg pair pattern
1352+
// auto arg_pair = builder.add_rule("arg_pair",
1353+
// std::string("\"<arg_key>\" [a-zA-Z_][a-zA-Z0-9_]* \"</arg_key>\" space ") +
1354+
// "\"<arg_value>\" [^<]+ \"</arg_value>\" space"
1355+
// );
1356+
1357+
// // Zero or more arg pairs
1358+
// auto arg_section = builder.add_rule("arg_section", "(" + arg_pair + ")*");
1359+
1360+
// std::vector<std::string> tool_rules;
1361+
// foreach_function(inputs.tools, [&](const json & tool) {
1362+
// const auto & function = tool.at("function");
1363+
// std::string name = function.at("name");
1364+
1365+
// tool_rules.push_back(
1366+
// builder.add_rule(name + "-call",
1367+
// "\"<tool_call>\" \"" + name + "\" space " +
1368+
// arg_section + " space " +
1369+
// "\"</tool_call>\""
1370+
// )
1371+
// );
1372+
// });
1373+
1374+
// auto any_tool_call = builder.add_rule("any_tool_call",
1375+
// string_join(tool_rules, " | ")
1376+
// );
1377+
1378+
// auto multiple_tool_calls = builder.add_rule("multiple_tool_calls",
1379+
// "(" + any_tool_call + " space)+"
1380+
// );
1381+
1382+
// builder.add_rule("root", multiple_tool_calls);
1383+
// });
1384+
1385+
// data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"});
1386+
// }
1387+
13361388
return data;
13371389
}
13381390

13391391
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
1392+
1393+
auto handle_tool_call_end = [&] (common_chat_msg_parser & builder, auto end_pos) {
1394+
builder.move_to(end_pos);
1395+
builder.consume_literal("</tool_call>");
1396+
1397+
size_t obs_pos = builder.input().find("<|observation|>", builder.pos());
1398+
if (obs_pos != std::string::npos) {
1399+
if (obs_pos > builder.pos()) {
1400+
std::string content = builder.input().substr(builder.pos(), obs_pos - builder.pos());
1401+
builder.add_content(content);
1402+
}
1403+
1404+
builder.move_to(obs_pos);
1405+
builder.consume_literal("<|observation|>");
1406+
} else {
1407+
std::string remaining = builder.consume_rest();
1408+
if (!remaining.empty()) builder.add_content(remaining);
1409+
}
1410+
};
1411+
13401412
builder.consume_spaces();
1413+
13411414
builder.try_parse_reasoning("<think>", "</think>");
1342-
if (!builder.syntax().parse_tool_calls) {
1343-
builder.add_content(builder.consume_rest());
1344-
return;
1345-
}
1346-
1347-
// GLM 4.5 uses format: <tool_call>function_name\n<arg_key>key</arg_key>\n<arg_value>value</arg_value>\n</tool_call>
1348-
static const common_regex tool_call_start("<tool_call>([^\n<]+)");
1349-
static const common_regex arg_key_regex("<arg_key>([^<]+)</arg_key>");
1350-
static const common_regex arg_value_regex("<arg_value>([\\s\\S]*?)</arg_value>");
1351-
static const common_regex tool_call_end("</tool_call>");
1352-
1353-
while (auto res = builder.try_find_regex(tool_call_start)) {
1354-
// Move to the start of the tool call and consume it
1355-
builder.move_to(res->groups[0].begin);
1356-
builder.consume_regex(tool_call_start);
1357-
1358-
std::string function_name = builder.str(res->groups[1]);
1359-
json arguments = json::object();
1415+
1416+
size_t curr_pos = builder.pos();
1417+
while (builder.input().find("<tool_call>", builder.pos()) != std::string::npos) {
1418+
size_t tool_call_start = builder.input().find("<tool_call>", builder.pos());
1419+
if (tool_call_start > builder.pos()) {
1420+
std::string content = builder.input().substr(builder.pos(), tool_call_start - builder.pos());
1421+
builder.add_content(content);
1422+
}
13601423

1424+
size_t tool_call_end = builder.input().find("</tool_call>", tool_call_start);
1425+
if (tool_call_end == std::string::npos) return;
1426+
1427+
builder.move_to(tool_call_start);
1428+
builder.consume_literal("<tool_call>");
13611429
builder.consume_spaces();
1362-
1363-
// Parse all arg_key/arg_value pairs
1364-
while (auto key_res = builder.try_consume_regex(arg_key_regex)) {
1365-
std::string key = builder.str(key_res->groups[1]);
1366-
builder.consume_spaces();
1430+
1431+
size_t arg_key_start = builder.input().find("<arg_key>", tool_call_start);
1432+
if (arg_key_start == std::string::npos) {
1433+
std::string function_content = builder.input().substr(builder.pos(), tool_call_end - builder.pos());
1434+
std::string function_name = string_strip(function_content);
13671435

1368-
if (auto value_res = builder.try_consume_regex(arg_value_regex)) {
1369-
std::string value = builder.str(value_res->groups[1]);
1370-
arguments[key] = value;
1436+
if (!builder.add_tool_call(function_name, "", "{}")) {
1437+
LOG_INF("%s: failed to add tool call\n", __func__);
1438+
}
1439+
1440+
handle_tool_call_end(builder, tool_call_end);
1441+
1442+
} else {
1443+
std::string function_content = builder.input().substr(builder.pos(), arg_key_start - builder.pos());
1444+
std::string function_name = string_strip(function_content);
1445+
1446+
json args_json = json::object();
1447+
builder.move_to(arg_key_start);
1448+
1449+
while (builder.pos() < tool_call_end) {
1450+
if (!builder.try_consume_literal("<arg_key>")) {
1451+
builder.consume_spaces();
1452+
if (!builder.try_consume_literal("<arg_key>")) {
1453+
break;
1454+
}
1455+
}
1456+
1457+
auto key_close = builder.try_find_literal("</arg_key>");
1458+
if (!key_close || key_close->groups[0].end > tool_call_end) {
1459+
throw common_chat_msg_partial_exception("incomplete tool call");
1460+
return;
1461+
}
1462+
1463+
std::string key = string_strip(key_close->prelude);
1464+
1465+
builder.consume_spaces();
1466+
1467+
if (!builder.try_consume_literal("<arg_value>")) {
1468+
throw common_chat_msg_partial_exception("incomplete tool call");
1469+
return;
1470+
}
1471+
1472+
auto value_close = builder.try_find_literal("</arg_value>");
1473+
if (!value_close || value_close->groups[0].end > tool_call_end) {
1474+
throw common_chat_msg_partial_exception("incomplete tool call");
1475+
return;
1476+
}
1477+
1478+
std::string value = string_strip(value_close->prelude);
13711479
builder.consume_spaces();
1480+
args_json[key] = value;
1481+
}
1482+
1483+
if (!builder.add_tool_call(function_name, "", args_json.dump())) {
1484+
LOG_INF("%s: failed to add tool call with arguments\n", __func__);
13721485
} else {
1373-
throw common_chat_msg_partial_exception("Expected <arg_value> after <arg_key>");
1486+
LOG_INF("%s: successfully added tool call with arguments\n", __func__);
13741487
}
1488+
1489+
handle_tool_call_end(builder, tool_call_end);
13751490
}
1376-
1377-
// Consume closing tag
1378-
builder.consume_regex(tool_call_end);
1379-
builder.consume_spaces();
1380-
1381-
// Add the parsed tool call
1382-
if (!builder.add_tool_call(function_name, "", arguments.dump())) {
1383-
throw common_chat_msg_partial_exception("Failed to add GLM tool call");
1491+
1492+
if (curr_pos == builder.pos()) {
1493+
// No progress made, avoid infinite loop
1494+
LOG_INF("%s: no progress in parsing, stopping to avoid infinite loop\n", __func__);
1495+
break;
13841496
}
13851497
}
13861498

1387-
builder.add_content(builder.consume_rest());
1499+
if (builder.pos() < builder.input().size()) {
1500+
builder.add_content(builder.consume_rest());
1501+
}
13881502
}
13891503

13901504
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {

0 commit comments

Comments
 (0)