
Commit c6b1b4e

[Grammar] Integrate with XGrammar (#635)
This PR integrates with XGrammar: https://github.com/mlc-ai/xgrammar. Prior to this PR, grammar support was provided by the grammar portion of MLC-LLM compiled into the model WASM. That portion is now a standalone project, XGrammar. This PR therefore adds `@mlc-ai/web-xgrammar` as a dependency, removes `src/grammar.ts`, and updates `llm_chat.ts` for XGrammar's APIs.

In addition to `json_schema`, we now also support requests with EBNF-formatted grammar strings by using the following in the chat completion request. See `ebnfGrammarExample()` in `examples/json-schema` for a full example.

```typescript
response_format: {
  type: "grammar",
  grammar: jsonGrammarStr,
} as webllm.ResponseFormat,
```

We also add the following performance info to `CompletionUsage.extra`:

- `grammar_init_s` and `grammar_per_token_s` when using grammar
- `time_to_first_token_s` (TTFT), `time_per_output_token_s` (TPOT), and `e2e_latency_s`

Finally, we add `ignore_eos` to `Completion` and `ChatCompletion` requests, which can be useful for benchmarking purposes.
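For reference, the pieces above fit together as in the minimal sketch below, distilled from the updated `examples/json-schema`; the model id is one of those listed in the examples and the tiny grammar is a placeholder rather than the full JSON grammar:

```typescript
import * as webllm from "@mlc-ai/web-llm";

async function grammarSketch() {
  const engine = await webllm.CreateMLCEngine(
    "Llama-3.2-3B-Instruct-q4f16_1-MLC", // any grammar-capable model from the examples
  );

  // Placeholder EBNF grammar; see ebnfGrammarExample() for a full JSON grammar.
  const grammarStr = String.raw`root ::= "yes" | "no"`;

  const reply = await engine.chatCompletion({
    messages: [{ role: "user", content: "Is the sky blue? Answer yes or no." }],
    max_tokens: 16,
    response_format: {
      type: "grammar",
      grammar: grammarStr,
    } as webllm.ResponseFormat,
  });

  // New fields in CompletionUsage.extra: e2e_latency_s, time_to_first_token_s,
  // time_per_output_token_s, plus grammar_init_s and grammar_per_token_s
  // whenever a grammar-backed response_format is used.
  console.log(reply.usage?.extra);
}
```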
1 parent 6504047 commit c6b1b4e

13 files changed (+467 / -301 lines)

examples/json-mode/src/json_mode.ts

+4 -1

```diff
@@ -12,7 +12,10 @@ async function main() {
   const initProgressCallback = (report: webllm.InitProgressReport) => {
     setLabel("init-label", report.text);
   };
-  const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
+  // Pick any one of these models to start trying -- most models in WebLLM support grammar
+  const selectedModel = "Llama-3.2-3B-Instruct-q4f16_1-MLC";
+  // const selectedModel = "Qwen2.5-1.5B-Instruct-q4f16_1-MLC";
+  // const selectedModel = "Phi-3.5-mini-instruct-q4f16_1-MLC";
   const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
     selectedModel,
     { initProgressCallback: initProgressCallback },
```

examples/json-schema/src/json_schema.ts

+71 -6

```diff
@@ -37,9 +37,14 @@ async function simpleStructuredTextExample() {
   const initProgressCallback = (report: webllm.InitProgressReport) => {
     setLabel("init-label", report.text);
   };
+
+  // Pick any one of these models to start trying -- most models in WebLLM support grammar
+  // const selectedModel = "Llama-3.2-3B-Instruct-q4f16_1-MLC";
+  // const selectedModel = "Qwen2.5-1.5B-Instruct-q4f16_1-MLC";
+  const selectedModel = "Phi-3.5-mini-instruct-q4f16_1-MLC";
   const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
-    "Llama-3.1-8B-Instruct-q4f16_1-MLC",
-    { initProgressCallback: initProgressCallback },
+    selectedModel,
+    { initProgressCallback: initProgressCallback, logLevel: "INFO" },
   );
 
   // Note that you'd need to prompt the model to answer in JSON either in
@@ -106,9 +111,14 @@ async function harryPotterExample() {
     setLabel("init-label", report.text);
   };
 
+  // Pick any one of these models to start trying -- most models in WebLLM support grammar
+  const selectedModel = "Llama-3.2-3B-Instruct-q4f16_1-MLC";
+  // const selectedModel = "Qwen2.5-1.5B-Instruct-q4f16_1-MLC";
+  // const selectedModel = "Phi-3.5-mini-instruct-q4f16_1-MLC";
+
   const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
-    "Llama-3.1-8B-Instruct-q4f16_1-MLC",
-    { initProgressCallback: initProgressCallback },
+    selectedModel,
+    { initProgressCallback: initProgressCallback, logLevel: "INFO" },
   );
 
   // Note that you'd need to prompt the model to answer in JSON either in
@@ -134,6 +144,7 @@ async function harryPotterExample() {
   console.log(reply);
   console.log("Output:\n" + (await engine.getMessage()));
   console.log(reply.usage);
+  console.log(reply.usage!.extra);
 }
 
 async function functionCallingExample() {
@@ -214,10 +225,64 @@ async function functionCallingExample() {
   console.log(reply.usage);
 }
 
+async function ebnfGrammarExample() {
+  // You can directly define an EBNFGrammar string with ResponseFormat.grammar
+  const jsonGrammarStr = String.raw`
+root ::= basic_array | basic_object
+basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
+basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"?
+basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
+basic_string ::= (([\"] basic_string_1 [\"]))
+basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
+escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
+basic_boolean ::= "true" | "false"
+basic_null ::= "null"
+basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
+basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
+ws ::= [ \n\t]*
+`;
+
+  const initProgressCallback = (report: webllm.InitProgressReport) => {
+    setLabel("init-label", report.text);
+  };
+
+  // Pick any one of these models to start trying -- most models in WebLLM support grammar
+  const selectedModel = "Llama-3.2-3B-Instruct-q4f16_1-MLC";
+  // const selectedModel = "Qwen2.5-1.5B-Instruct-q4f16_1-MLC";
+  // const selectedModel = "Phi-3.5-mini-instruct-q4f16_1-MLC";
+  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
+    selectedModel,
+    { initProgressCallback: initProgressCallback, logLevel: "INFO" },
+  );
+
+  // Note that you'd need to prompt the model to answer in JSON either in
+  // user's message or the system prompt
+  const request: webllm.ChatCompletionRequest = {
+    stream: false, // works with streaming, logprobs, top_logprobs as well
+    messages: [
+      {
+        role: "user",
+        content: "Introduce yourself in JSON",
+      },
+    ],
+    max_tokens: 128,
+    response_format: {
+      type: "grammar",
+      grammar: jsonGrammarStr,
+    } as webllm.ResponseFormat,
+  };
+
+  const reply0 = await engine.chatCompletion(request);
+  console.log(reply0);
+  console.log("Output:\n" + (await engine.getMessage()));
+  console.log(reply0.usage);
+}
+
 async function main() {
   // await simpleStructuredTextExample();
-  // await harryPotterExample();
-  await functionCallingExample();
+  await harryPotterExample();
+  // await functionCallingExample();
+  // await ebnfGrammarExample();
 }
 
 main();
```

package-lock.json

+28
Some generated files are not rendered by default.

package.json

+1

```diff
@@ -51,6 +51,7 @@
     "ts-jest": "^29.1.2",
     "tslib": "^2.3.1",
     "@mlc-ai/web-runtime": "0.18.0-dev2",
+    "@mlc-ai/web-xgrammar": "../xgrammar/web",
    "typescript": "^4.9.5"
  },
  "dependencies": {
```

src/config.ts

+1

```diff
@@ -127,6 +127,7 @@ export interface MLCEngineConfig {
 export interface GenerationConfig {
   // Only used in MLC
   repetition_penalty?: number;
+  ignore_eos?: boolean;
   // Shared by MLC and OpenAI APIs
   top_p?: number | null;
   temperature?: number | null;
```
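The new `ignore_eos` flag is threaded from the request into this `GenerationConfig` (see the `engine.ts` changes below). A rough benchmarking sketch, with a placeholder model id and prompt:

```typescript
import * as webllm from "@mlc-ai/web-llm";

async function benchmarkDecode() {
  const engine = await webllm.CreateMLCEngine("Llama-3.2-3B-Instruct-q4f16_1-MLC");

  const reply = await engine.chatCompletion({
    messages: [{ role: "user", content: "Write a long story." }],
    max_tokens: 256,
    // With ignore_eos the EOS token no longer stops decoding, so generation
    // typically runs to max_tokens -- convenient for stable throughput numbers.
    ignore_eos: true,
  });

  // decode_tokens_per_s and time_per_output_token_s are reported in usage.extra.
  console.log(reply.usage?.extra);
}
```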

src/conversation.ts

+6 -4

```diff
@@ -257,10 +257,12 @@
   }
 
   getStopStr(): string[] {
-    if (this.config.stop_str.length > 0) {
-      return this.config.stop_str;
-    }
-    return [this.config.seps[this.config.seps.length - 1]];
+    // TODO(Charlie): Is this needed?
+    // if (this.config.stop_str.length > 0) {
+    //   return this.config.stop_str;
+    // }
+    // return [this.config.seps[this.config.seps.length - 1]];
+    return this.config.stop_str;
   }
 
   getStopTokens() {
```

src/engine.ts

+63 -9

```diff
@@ -465,20 +465,23 @@ export class MLCEngine implements MLCEngineInterface {
     pipeline: LLMChatPipeline,
     chatConfig: ChatConfig,
     genConfig: GenerationConfig,
+    timeReceived: number,
   ): AsyncGenerator<ChatCompletionChunk, void, void>;
   asyncGenerate(
     request: CompletionCreateParamsStreaming,
     model: string,
     pipeline: LLMChatPipeline,
     chatConfig: ChatConfig,
     genConfig: GenerationConfig,
+    timeReceived: number,
   ): AsyncGenerator<Completion, void, void>;
   async *asyncGenerate(
     request: ChatCompletionRequestStreaming | CompletionCreateParamsStreaming,
     model: string,
     pipeline: LLMChatPipeline,
     chatConfig: ChatConfig,
     genConfig: GenerationConfig,
+    timeReceived: number,
   ): AsyncGenerator<ChatCompletionChunk | Completion, void, void> {
     // Since it is an async generator, we need to do fine-grained try-catch to ensure lock is
     // released only when errors occur. Then release at the very end when no error occurs.
@@ -678,18 +681,39 @@
 
     // 4. Usage chunk
     if (request.stream_options?.include_usage) {
+      const usedGrammar =
+        "response_format" in request &&
+        (request.response_format?.type === "grammar" ||
+          request.response_format?.type === "json_object");
       const completion_tokens = pipeline.getCurRoundDecodingTotalTokens();
       const prompt_tokens = pipeline.getCurRoundPrefillTotalTokens();
       const prefill_tokens_per_s = pipeline.getCurRoundPrefillTokensPerSec();
       const decode_tokens_per_s = pipeline.getCurRoundDecodingTokensPerSec();
+      const grammar_init_s = pipeline.getCurRoundGrammarInitTotalTime();
+      const prefill_time = pipeline.getCurRoundPrefillTotalTime();
+      const decode_time = pipeline.getCurRoundDecodingTotalTime();
+      const grammar_per_token_s =
+        pipeline.getCurRoundGrammarPerTokenTotalTime();
+      const defaultExtra = {
+        e2e_latency_s: (Date.now() - timeReceived) / 1000,
+        prefill_tokens_per_s: prefill_tokens_per_s,
+        decode_tokens_per_s: decode_tokens_per_s,
+        time_to_first_token_s: prefill_time,
+        time_per_output_token_s: decode_time / completion_tokens,
+      };
       const usage: CompletionUsage = {
         completion_tokens: completion_tokens,
         prompt_tokens: prompt_tokens,
         total_tokens: completion_tokens + prompt_tokens,
-        extra: {
-          prefill_tokens_per_s: prefill_tokens_per_s,
-          decode_tokens_per_s: decode_tokens_per_s,
-        },
+        extra: usedGrammar
+          ? {
+              ...defaultExtra,
+              ...{
+                grammar_init_s: grammar_init_s,
+                grammar_per_token_s: grammar_per_token_s / completion_tokens,
+              },
+            }
+          : defaultExtra,
       };
       if (isChatCompletion) {
         const usageChunk: ChatCompletionChunk = {
@@ -745,6 +769,7 @@
   async chatCompletion(
     request: ChatCompletionRequest,
   ): Promise<AsyncIterable<ChatCompletionChunk> | ChatCompletion> {
+    const timeReceived = Date.now();
     // 0. Check model loaded and preprocess inputs
     const [selectedModelId, selectedPipeline, selectedChatConfig] =
       this.getLLMStates("ChatCompletionRequest", request.model);
@@ -766,6 +791,7 @@
       logprobs: request.logprobs,
       top_logprobs: request.top_logprobs,
       response_format: request.response_format,
+      ignore_eos: request.ignore_eos,
     };
 
     // 0.5 Block wait until this pipeline finishes all previous requests
@@ -780,6 +806,7 @@
         selectedPipeline,
         selectedChatConfig,
         genConfig,
+        timeReceived,
       );
     }
 
@@ -796,6 +823,8 @@
     let prompt_tokens = 0;
     let prefill_time = 0;
    let decode_time = 0;
+    let grammar_init_s = 0;
+    let grammar_per_token_s = 0;
    for (let i = 0; i < n; i++) {
      let outputMessage: string;
      if (this.interruptSignal) {
@@ -852,8 +881,21 @@
       prompt_tokens += selectedPipeline.getCurRoundPrefillTotalTokens();
       prefill_time += selectedPipeline.getCurRoundPrefillTotalTime();
       decode_time += selectedPipeline.getCurRoundDecodingTotalTime();
+      grammar_init_s += selectedPipeline.getCurRoundGrammarInitTotalTime();
+      grammar_per_token_s +=
+        selectedPipeline.getCurRoundGrammarPerTokenTotalTime();
     }
-
+    const usedGrammar =
+      "response_format" in request &&
+      (request.response_format?.type === "grammar" ||
+        request.response_format?.type === "json_object");
+    const defaultExtra = {
+      e2e_latency_s: (Date.now() - timeReceived) / 1000,
+      prefill_tokens_per_s: prompt_tokens / prefill_time,
+      decode_tokens_per_s: completion_tokens / decode_time,
+      time_to_first_token_s: prefill_time,
+      time_per_output_token_s: decode_time / completion_tokens,
+    };
     const response: ChatCompletion = {
       id: crypto.randomUUID(),
       choices: choices,
@@ -864,10 +906,15 @@
         completion_tokens: completion_tokens,
         prompt_tokens: prompt_tokens,
         total_tokens: completion_tokens + prompt_tokens,
-        extra: {
-          prefill_tokens_per_s: prompt_tokens / prefill_time,
-          decode_tokens_per_s: completion_tokens / decode_time,
-        },
+        extra: usedGrammar
+          ? {
+              ...defaultExtra,
+              ...{
+                grammar_init_s: grammar_init_s,
+                grammar_per_token_s: grammar_per_token_s / completion_tokens,
+              },
+            }
+          : defaultExtra,
       } as CompletionUsage,
     };
 
@@ -901,6 +948,8 @@
   async completion(
     request: CompletionCreateParams,
   ): Promise<AsyncIterable<Completion> | Completion> {
+    const timeReceived = Date.now();
+
     // 0. Check model loaded and preprocess inputs
     const [selectedModelId, selectedPipeline, selectedChatConfig] =
       this.getLLMStates("CompletionCreateParams", request.model);
@@ -915,6 +964,7 @@
       logit_bias: request.logit_bias,
       logprobs: request.logprobs,
       top_logprobs: request.top_logprobs,
+      ignore_eos: request.ignore_eos,
     };
 
     // 0.5 Block wait until this pipeline finishes all previous requests
@@ -929,6 +979,7 @@
         selectedPipeline,
         selectedChatConfig,
         genConfig,
+        timeReceived,
       );
     }
 
@@ -989,8 +1040,11 @@
         prompt_tokens: prompt_tokens,
         total_tokens: completion_tokens + prompt_tokens,
         extra: {
+          e2e_latency_s: (Date.now() - timeReceived) / 1000,
          prefill_tokens_per_s: prompt_tokens / prefill_time,
          decode_tokens_per_s: completion_tokens / decode_time,
+          time_to_first_token_s: prefill_time,
+          time_per_output_token_s: decode_time / completion_tokens,
        },
      } as CompletionUsage,
    };
```
