From 7acd4f8c9c69c6a21916374477a09cb1a83696e3 Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Wed, 9 Oct 2024 03:07:21 +0000 Subject: [PATCH 01/14] feat(xgrammar): trying to replace outlines with xgrammar --- python/sglang/srt/constrained/bnf_cache.py | 84 +++++++++++++++ python/sglang/srt/managers/schedule_batch.py | 100 +++++++++++++++++- python/sglang/srt/managers/scheduler.py | 27 +++++ .../sglang/srt/model_executor/model_runner.py | 2 + .../srt/sampling/sampling_batch_info.py | 29 ++++- 5 files changed, 238 insertions(+), 4 deletions(-) create mode 100644 python/sglang/srt/constrained/bnf_cache.py diff --git a/python/sglang/srt/constrained/bnf_cache.py b/python/sglang/srt/constrained/bnf_cache.py new file mode 100644 index 00000000000..99ccfcd93fe --- /dev/null +++ b/python/sglang/srt/constrained/bnf_cache.py @@ -0,0 +1,84 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Cache for the compressed finite state machine.""" + +from transformers import AutoTokenizer +from xgrammar import BuiltinGrammar, GrammarStateMatcher + +from sglang.srt.constrained.base_tool_cache import BaseToolCache + + +class BNFCache(BaseToolCache): + def __init__( + self, + tokenizer_path, + tokenizer_args_dict, + skip_tokenizer_init=False, + enable=True, + ): + super().__init__(enable=enable) + if skip_tokenizer_init: + return + + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, **tokenizer_args_dict + ) + + def init_value(self, key): + key_type, key_string = key + + if key_type == "json": + grammar = BuiltinGrammar.json_schema(key_string) + elif key_type == "regex": + assert False, "Not supported by xgrammar yet" + else: + raise ValueError(f"Invalid key_type: {key_type}") + + return grammar + + def query(self, key): + grammar = super().query(key) + return GrammarStateMatcher(grammar, self.tokenizer) + + +# class BNFCache(BaseToolCache): +# def __init__( +# self, +# tokenizer_path, +# tokenizer_args_dict, +# enable=True, +# skip_tokenizer_init=False, +# constrained_json_whitespace_pattern=None, +# ): +# super().__init__(enable=enable) + +# if skip_tokenizer_init: +# return + +# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict) +# self.tokenizer = tokenizer +# self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern + +# def init_value(self, key): +# key_type, key_string = key +# if key_type == "json": +# grammar = BuiltinGrammar.json_schema(key_string) +# elif key_type == "regex": +# assert False, "Not supported yet" +# else: +# raise ValueError(f"Invalid key_type: {key_type}") + +# return GrammarStateMatcher(grammar, self.tokenizer) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index f4d41feeffd..ab6c8e275ec 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -32,6 +32,7 @@ from typing import List, Optional, Tuple, Union import torch +from xgrammar import GrammarStateMatcher from 
sglang.global_config import global_config from sglang.srt.constrained import RegexGuide @@ -231,6 +232,9 @@ def __init__( self.regex_fsm_state: int = 0 self.jump_forward_map: JumpForwardMap = None + self.regex_bnf: Optional[GrammarStateMatcher] = None + self.allow_jump_forward: bool = False + # whether request reached finished condition def finished(self) -> bool: return self.finished_reason is not None @@ -387,6 +391,66 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state): return True + def jump_forward_and_retokenize_bnf(self, jump_forward_str): + if self.origin_input_text is None: + # Recovering text can only use unpadded ids + self.origin_input_text = self.tokenizer.decode( + self.origin_input_ids_unpadded + ) + + all_text = self.origin_input_text + self.decoded_text + jump_forward_str + all_ids = self.tokenizer.encode(all_text) + if not all_ids: + logger.warning("Encoded all_text resulted in empty all_ids") + return False + + prompt_tokens = len(self.origin_input_ids_unpadded) + if prompt_tokens > len(all_ids): + logger.warning("prompt_tokens is larger than encoded all_ids") + return False + + if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]: + # TODO(lsyin): fix token fusion + logger.warning( + "Token fusion between input and output, try to avoid this by removing the space at the end of the input." + ) + return False + + old_output_ids = self.output_ids + self.output_ids = all_ids[prompt_tokens:] + self.decoded_text = self.decoded_text + jump_forward_str + self.surr_offset = prompt_tokens + self.read_offset = len(all_ids) + + # NOTE: A trick to reduce the surrouding tokens decoding overhead + for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET): + surr_text_ = self.tokenizer.decode( + all_ids[self.read_offset - i : self.read_offset] + ) + if not surr_text_.endswith("�"): + self.surr_offset = self.read_offset - i + break + + k = 0 + for i, old_id in enumerate(old_output_ids): + if old_id == self.output_ids[i]: + k = i + 1 + else: + break + + for i in range(k, len(self.output_ids)): + assert self.regex_bnf is not None, "regex_bnf is None" + self.regex_bnf.accept_token(self.output_ids[i]) + + if self.return_logprob: + # For fast-forward part's logprobs + self.output_token_logprobs = self.output_token_logprobs[:k] + self.output_top_logprobs = self.output_top_logprobs[:k] + self.logprob_start_len = prompt_tokens + k + self.last_update_decode_tokens = len(self.output_ids) - k + + return True + def __repr__(self): return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, " @@ -430,7 +494,8 @@ class ScheduleBatch: def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): return_logprob = any(req.return_logprob for req in reqs) has_stream = any(req.stream for req in reqs) - has_regex = any(req.regex_fsm for req in reqs) + # has_regex = any(req.regex_fsm for req in reqs) + has_regex = any(req.regex_bnf for req in reqs) return cls( reqs=reqs, @@ -695,6 +760,33 @@ def check_for_jump_forward(self, pad_input_ids_func): req.output_ids = cur_output_ids continue + print(f"Jump forward: {jump_forward_str}") + + # The decode status has diverged from detokenizer_manager + req.vid += 1 + + # insert the old request into tree_cache + self.tree_cache.cache_finished_req(req, cur_all_ids) + + # re-applying image padding + if req.image_inputs is not None: + req.origin_input_ids = pad_input_ids_func( + req.origin_input_ids_unpadded, req.image_inputs + ) + + jump_forward_reqs.append(req) + filter_indices.remove(i) + + if req.allow_jump_forward and 
req.regex_bnf is not None: + jump_forward_str = req.regex_bnf.find_jump_forward_string() + if len(jump_forward_str) > 1: + cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1] + cur_output_ids = req.output_ids + if not req.jump_forward_and_retokenize_bnf(jump_forward_str): + # Failed to jump forward, revert + req.output_ids = cur_output_ids + continue + # The decode status has diverged from detokenizer_manager req.vid += 1 @@ -762,7 +854,8 @@ def filter_batch(self, unfinished_indices: List[int]): self.top_logprobs_nums = None self.has_stream = any(req.stream for req in self.reqs) - self.has_regex = any(req.regex_fsm for req in self.reqs) + # self.has_regex = any(req.regex_fsm for req in self.reqs) + self.has_regex = any(req.regex_bnf for req in self.reqs) self.sampling_info.filter_batch(unfinished_indices, new_indices) @@ -807,6 +900,9 @@ def get_model_worker_batch(self): req.regex_fsm_state for req in self.reqs ] + # TODO(dark): remove the above and use the below + self.sampling_info.regex_bnfs = [req.regex_bnf for req in self.reqs] + return ModelWorkerBatch( forward_mode=self.forward_mode, input_ids=self.input_ids, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index f0e78f45b08..1af80a3d237 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -28,6 +28,7 @@ from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.constrained.bnf_cache import BNFCache from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer @@ -214,6 +215,14 @@ def __init__( skip_tokenizer_init=server_args.skip_tokenizer_init, constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern, ) + self.regex_bnf_cache = BNFCache( + server_args.tokenizer_path, + { + "tokenizer_mode": server_args.tokenizer_mode, + "trust_remote_code": server_args.trust_remote_code, + }, + skip_tokenizer_init=server_args.skip_tokenizer_init, + ) self.jump_forward_cache = JumpForwardCache() # Init new token estimation @@ -310,18 +319,27 @@ def handle_generate_request( req.sampling_params.json_schema is not None or req.sampling_params.regex is not None ): + # TODO(dark): replace FSM cache with BNF cache if req.sampling_params.json_schema is not None: req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query( ("json", req.sampling_params.json_schema) ) + req.regex_bnf = self.regex_bnf_cache.query( + ("json", req.sampling_params.json_schema) + ) elif req.sampling_params.regex is not None: req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query( ("regex", req.sampling_params.regex) ) + req.regex_bnf = self.regex_bnf_cache.query( + ("regex", req.sampling_params.regex) + ) if not self.disable_regex_jump_forward: req.jump_forward_map = self.jump_forward_cache.query( computed_regex_string ) + req.allow_jump_forward = True + # TODO(dark): add custom bnf jump forward map here # Truncate prompts that are too long if len(req.origin_input_ids) >= self.max_req_input_len: @@ -680,11 +698,15 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): req.output_ids.append(next_token_ids[i]) req.check_finished() + # TODO(dark): replace FSM cache with BNF cache if req.regex_fsm is not None: req.regex_fsm_state = req.regex_fsm.get_next_state( req.regex_fsm_state, next_token_ids[i] ) + if req.regex_bnf is not 
None: + assert req.regex_bnf.accept_token(next_token_ids[i]) + if req.finished(): self.tree_cache.cache_finished_req(req) elif req not in batch.decoding_reqs: @@ -751,11 +773,15 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): req.output_ids.append(next_token_id) req.check_finished() + # TODO(dark): replace FSM cache with BNF cache if req.regex_fsm is not None: req.regex_fsm_state = req.regex_fsm.get_next_state( req.regex_fsm_state, next_token_id ) + if req.regex_bnf is not None: + assert req.regex_bnf.accept_token(next_token_id) + if req.finished(): self.tree_cache.cache_finished_req(req) @@ -959,6 +985,7 @@ def flush_cache(self): self.tree_cache.reset() self.tree_cache_metrics = {"total": 0, "hit": 0} self.regex_fsm_cache.reset() + self.regex_bnf_cache.reset() self.req_to_token_pool.clear() self.token_to_kv_pool.clear() torch.cuda.empty_cache() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 83273bc4374..f3390ae46cb 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -522,6 +522,7 @@ def sample( ) -> torch.Tensor: # Put CPU-heavy tasks here. They will be overlapped with the forward pass. sampling_info = forward_batch.sampling_info + sampling_info.update_regex_vocab_mask_bnf() sampling_info.update_regex_vocab_mask() sampling_info.update_penalties() logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info) @@ -550,6 +551,7 @@ def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchIn # Apply regex vocab_mask if sampling_info.vocab_mask is not None: logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf")) + # TODO(dark): add some custom cuda kernel for bnf bitmask here return logits diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index de781acb3fc..569465ce5e2 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -1,9 +1,10 @@ from __future__ import annotations import dataclasses -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional import torch +from xgrammar import GrammarStateMatcher import sglang.srt.sampling.penaltylib as penaltylib from sglang.srt.constrained import RegexGuide @@ -26,12 +27,15 @@ class SamplingBatchInfo: # Bias Tensors vocab_size: int logit_bias: torch.Tensor = None - vocab_mask: torch.Tensor = None + vocab_mask: Optional[torch.Tensor] = None # FSM states regex_fsms: List[RegexGuide] = None regex_fsm_states: List[int] = None + # TODO(dark): remove the above and use the regex_bnf instead + regex_bnfs: Optional[List[Optional[GrammarStateMatcher]]] = None + # Penalizer penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None linear_penalties: torch.Tensor = None @@ -128,6 +132,27 @@ def update_regex_vocab_mask(self): regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens ] = 0 + # TODO(dark): rename this to update_regex_vocab_mask after removing the old one + def update_regex_vocab_mask_bnf(self): + # Reset the vocab mask + self.vocab_mask = None + + if self.regex_bnfs and any(regex_bnf for regex_bnf in self.regex_bnfs): + # If has regex, then we need to update the vocab mask + self.vocab_mask = torch.zeros( + len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda" + ) + for i, regex_bnf in enumerate(self.regex_bnfs): + if regex_bnf is not None: + # 
Note that this bitmask is a bitset, not bool + bitmask = regex_bnf.find_next_token_bitmask() + # Mask the tokens that are not allowed + self.vocab_mask[i][ + regex_bnf.get_rejected_tokens_from_bitmask( + bitmask, self.vocab_size + ) + ] = 1 + def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): self.penalizer_orchestrator.filter(unfinished_indices, new_indices) From 3a05b1adb9746e0ce4918df2012656e66e080038 Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Wed, 9 Oct 2024 15:45:27 +0000 Subject: [PATCH 02/14] fix(xgrammar): rollback when retokenization happens --- python/sglang/srt/managers/schedule_batch.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ab6c8e275ec..939d720cf21 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -392,6 +392,8 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state): return True def jump_forward_and_retokenize_bnf(self, jump_forward_str): + assert self.regex_bnf is not None, "should be a regex request" + if self.origin_input_text is None: # Recovering text can only use unpadded ids self.origin_input_text = self.tokenizer.decode( @@ -438,8 +440,10 @@ def jump_forward_and_retokenize_bnf(self, jump_forward_str): else: break + # rollback to the last token that is the same + self.regex_bnf.rollback(len(old_output_ids) - k) + for i in range(k, len(self.output_ids)): - assert self.regex_bnf is not None, "regex_bnf is None" self.regex_bnf.accept_token(self.output_ids[i]) if self.return_logprob: From f18823558c31ed67ff28aa3092dff1578638b849 Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Fri, 11 Oct 2024 03:16:32 +0000 Subject: [PATCH 03/14] minor(xgrammar): remove all dependencies on outlines --- python/sglang/srt/constrained/fsm_cache.py | 88 -------- python/sglang/srt/constrained/jump_forward.py | 203 ------------------ python/sglang/srt/managers/schedule_batch.py | 132 ------------ python/sglang/srt/managers/scheduler.py | 40 +--- .../sglang/srt/model_executor/model_runner.py | 1 - .../srt/sampling/sampling_batch_info.py | 21 -- 6 files changed, 2 insertions(+), 483 deletions(-) delete mode 100644 python/sglang/srt/constrained/fsm_cache.py delete mode 100644 python/sglang/srt/constrained/jump_forward.py diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py deleted file mode 100644 index ab025f26e9d..00000000000 --- a/python/sglang/srt/constrained/fsm_cache.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-""" - -"""Cache for the compressed finite state machine.""" -import logging - -from interegular import InvalidSyntax, parse_pattern -from outlines.fsm.json_schema import build_regex_from_schema -from transformers import AutoTokenizer - -from sglang.srt.constrained import RegexGuide, TransformerTokenizer -from sglang.srt.constrained.base_tool_cache import BaseToolCache - -logger = logging.getLogger(__name__) - - -class FSMCache(BaseToolCache): - def __init__( - self, - tokenizer_path, - tokenizer_args_dict, - enable=True, - skip_tokenizer_init=False, - constrained_json_whitespace_pattern=None, - ): - super().__init__(enable=enable) - - if ( - skip_tokenizer_init - or tokenizer_path.endswith(".json") - or tokenizer_path.endswith(".model") - ): - # Do not support TiktokenTokenizer or SentencePieceTokenizer - return - - tokenizer_args_dict.setdefault("padding_side", "left") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict) - try: - self.outlines_tokenizer = TransformerTokenizer(tokenizer) - except AttributeError: - # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0) - origin_pad_token_id = tokenizer.pad_token_id - - def fset(self, value): - self._value = value - - type(tokenizer).pad_token_id = property( - fget=type(tokenizer).pad_token_id.fget, fset=fset - ) - self.outlines_tokenizer = TransformerTokenizer(tokenizer) - self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id - self.outlines_tokenizer.pad_token_id = origin_pad_token_id - self.outlines_tokenizer.pad_token = ( - self.outlines_tokenizer.tokenizer.pad_token - ) - self.outlines_tokenizer.vocabulary = ( - self.outlines_tokenizer.tokenizer.get_vocab() - ) - self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern - - def init_value(self, key): - key_type, key_string = key - if key_type == "json": - regex = build_regex_from_schema( - key_string, whitespace_pattern=self.constrained_json_whitespace_pattern - ) - elif key_type == "regex": - regex = key_string - else: - raise ValueError(f"Invalid key_type: {key_type}") - try: - parse_pattern(regex) - except InvalidSyntax as e: - logger.warning(f"skip invalid regex guide: {regex=}, {e=}") - return None, regex - return RegexGuide(regex, self.outlines_tokenizer), regex diff --git a/python/sglang/srt/constrained/jump_forward.py b/python/sglang/srt/constrained/jump_forward.py deleted file mode 100644 index 1ebc8b21718..00000000000 --- a/python/sglang/srt/constrained/jump_forward.py +++ /dev/null @@ -1,203 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -""" -Faster constrained decoding. 
-Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/ -""" - -import dataclasses -import logging -from collections import defaultdict - -import interegular -import outlines.caching -from interegular import InvalidSyntax - -from sglang.srt.constrained import ( - FSMInfo, - disk_cache, - make_byte_level_fsm, - make_deterministic_fsm, -) -from sglang.srt.constrained.base_tool_cache import BaseToolCache - -IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" - -logger = logging.getLogger(__name__) - - -@dataclasses.dataclass -class JumpEdge: - symbol: str = None - symbol_next_state: int = None - byte: int = None - byte_next_state: int = None - - -class JumpForwardMap: - def __init__(self, regex_string): - @disk_cache() - def _init_state_to_jump_forward(regex_string): - try: - regex_pattern = interegular.parse_pattern(regex_string) - except InvalidSyntax as e: - logger.warning(f"skip invalid regex: {regex_string}, {e=}") - self.state_to_jump_forward = None - return - - byte_fsm = make_byte_level_fsm( - regex_pattern.to_fsm().reduce(), keep_utf8=True - ) - regex_fsm, _ = make_deterministic_fsm(byte_fsm) - - fsm_info: FSMInfo = regex_fsm.fsm_info - - symbol_to_id = fsm_info.alphabet_symbol_mapping - id_to_symbol = {} - for symbol, id_ in symbol_to_id.items(): - id_to_symbol.setdefault(id_, []).append(symbol) - - transitions = fsm_info.transitions - - outgoings_ct = defaultdict(int) - # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally - for s in fsm_info.finals: - outgoings_ct[s] = 1 - - state_to_jump_forward = {} - for (state, id_), next_state in transitions.items(): - if id_ == fsm_info.alphabet_anything_value: - # Arbitrarily symbol cannot be recognized as jump forward - continue - - symbols = id_to_symbol[id_] - for c in symbols: - if len(c) > 1: - # Skip byte level transitions like c = "5E" - continue - - outgoings_ct[state] += 1 - if outgoings_ct[state] > 1: - if state in state_to_jump_forward: - del state_to_jump_forward[state] - break - - state_to_jump_forward[state] = JumpEdge( - symbol=c, - symbol_next_state=next_state, - ) - - # Process the byte level jump forward - outgoings_ct = defaultdict(int) - for s in fsm_info.finals: - outgoings_ct[s] = 1 - - for (state, id_), next_state in transitions.items(): - if id_ == fsm_info.alphabet_anything_value: - continue - symbols = id_to_symbol[id_] - for c in symbols: - byte_ = None - if len(c) == 1 and ord(c) < 0x80: - # ASCII character - byte_ = ord(c) - elif len(c) > 1: - # FIXME: This logic is due to the leading \x00 - # https://github.com/outlines-dev/outlines/pull/930 - byte_ = int(symbols[0][1:], 16) - - if byte_ is not None: - outgoings_ct[state] += 1 - if outgoings_ct[state] > 1: - if state in state_to_jump_forward: - del state_to_jump_forward[state] - break - e = state_to_jump_forward.get(state, JumpEdge()) - e.byte = byte_ - e.byte_next_state = next_state - state_to_jump_forward[state] = e - - return state_to_jump_forward - - self.state_to_jump_forward = _init_state_to_jump_forward(regex_string) - - def jump_forward_symbol(self, state): - jump_forward_str = "" - next_state = state - while state in self.state_to_jump_forward: - e = self.state_to_jump_forward[state] - if e.symbol is None: - break - jump_forward_str += e.symbol - next_state = e.symbol_next_state - state = next_state - - return jump_forward_str, next_state - - def jump_forward_byte(self, state): - if state not in self.state_to_jump_forward: - return None - - jump_forward_bytes = [] - next_state = 
None - while state in self.state_to_jump_forward: - e = self.state_to_jump_forward[state] - assert e.byte is not None and e.byte_next_state is not None - jump_forward_bytes.append((e.byte, e.byte_next_state)) - next_state = e.byte_next_state - state = next_state - - return jump_forward_bytes - - def is_jump_forward_symbol_state(self, state): - return ( - state in self.state_to_jump_forward - and self.state_to_jump_forward[state].symbol is not None - ) - - -class JumpForwardCache(BaseToolCache): - def __init__(self): - super().__init__() - - def init_value(self, regex): - forward_map = JumpForwardMap(regex) - if forward_map.state_to_jump_forward: - return forward_map - else: - return None - - -def test_main(regex_string): - jump_forward_map = JumpForwardMap(regex_string) - for state, e in jump_forward_map.state_to_jump_forward.items(): - if e.symbol is not None: - jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state) - print(f"{state} -> {next_state}", jump_forward_str) - bytes_ = jump_forward_map.jump_forward_byte(state) - print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_]) - - -if __name__ == "__main__": - import outlines - - outlines.caching.clear_cache() - test_main(r"The google's DNS sever address is " + IP_REGEX) - test_main(r"霍格沃茨特快列车|霍比特人比尔博") - # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ... - # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ... - - test_main(r"[-+]?[0-9]+[ ]*") diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 939d720cf21..9cc036bba19 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -35,8 +35,6 @@ from xgrammar import GrammarStateMatcher from sglang.global_config import global_config -from sglang.srt.constrained import RegexGuide -from sglang.srt.constrained.jump_forward import JumpForwardMap from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool @@ -228,10 +226,6 @@ def __init__( self.embedding = None # Constrained decoding - self.regex_fsm: RegexGuide = None - self.regex_fsm_state: int = 0 - self.jump_forward_map: JumpForwardMap = None - self.regex_bnf: Optional[GrammarStateMatcher] = None self.allow_jump_forward: bool = False @@ -334,63 +328,6 @@ def check_finished(self): self.finished_reason = FINISH_MATCHED_STR(matched=stop_str) return - def jump_forward_and_retokenize(self, jump_forward_str, next_state): - if self.origin_input_text is None: - # Recovering text can only use unpadded ids - self.origin_input_text = self.tokenizer.decode( - self.origin_input_ids_unpadded - ) - - all_text = self.origin_input_text + self.decoded_text + jump_forward_str - all_ids = self.tokenizer.encode(all_text) - if not all_ids: - logger.warning("Encoded all_text resulted in empty all_ids") - return False - - prompt_tokens = len(self.origin_input_ids_unpadded) - if prompt_tokens > len(all_ids): - logger.warning("prompt_tokens is larger than encoded all_ids") - return False - - if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]: - # TODO(lsyin): fix token fusion - logger.warning( - "Token fusion between input and output, try to avoid this by removing the space at the end of the input." 
- ) - return False - - old_output_ids = self.output_ids - self.output_ids = all_ids[prompt_tokens:] - self.decoded_text = self.decoded_text + jump_forward_str - self.surr_offset = prompt_tokens - self.read_offset = len(all_ids) - - # NOTE: A trick to reduce the surrouding tokens decoding overhead - for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET): - surr_text_ = self.tokenizer.decode( - all_ids[self.read_offset - i : self.read_offset] - ) - if not surr_text_.endswith("�"): - self.surr_offset = self.read_offset - i - break - - self.regex_fsm_state = next_state - - if self.return_logprob: - # For fast-forward part's logprobs - k = 0 - for i, old_id in enumerate(old_output_ids): - if old_id == self.output_ids[i]: - k = k + 1 - else: - break - self.output_token_logprobs = self.output_token_logprobs[:k] - self.output_top_logprobs = self.output_top_logprobs[:k] - self.logprob_start_len = prompt_tokens + k - self.last_update_decode_tokens = len(self.output_ids) - k - - return True - def jump_forward_and_retokenize_bnf(self, jump_forward_str): assert self.regex_bnf is not None, "should be a regex request" @@ -498,7 +435,6 @@ class ScheduleBatch: def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): return_logprob = any(req.return_logprob for req in reqs) has_stream = any(req.stream for req in reqs) - # has_regex = any(req.regex_fsm for req in reqs) has_regex = any(req.regex_bnf for req in reqs) return cls( @@ -720,67 +656,6 @@ def check_for_jump_forward(self, pad_input_ids_func): filter_indices = [i for i in range(len(self.reqs))] for i, req in enumerate(self.reqs): - if req.jump_forward_map is not None: - jump_forward_bytes = req.jump_forward_map.jump_forward_byte( - req.regex_fsm_state - ) - if jump_forward_bytes is not None and len(jump_forward_bytes) > 1: - suffix_bytes = [] - continuation_range = range(0x80, 0xC0) - cur_state = req.regex_fsm_state - while ( - len(jump_forward_bytes) - and jump_forward_bytes[0][0] in continuation_range - ): - # continuation bytes - byte_edge = jump_forward_bytes.pop(0) - suffix_bytes.append(byte_edge[0]) - cur_state = byte_edge[1] - - suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes] - suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens) - - # Current ids, for cache and revert - cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1] - cur_output_ids = req.output_ids - - req.output_ids.extend(suffix_ids) - decode_res, new_text = req.get_next_inc_detokenization() - if not decode_res: - req.output_ids = cur_output_ids - continue - - ( - jump_forward_str, - next_state, - ) = req.jump_forward_map.jump_forward_symbol(cur_state) - - # Make the incrementally decoded text part of jump_forward_str - # so that the UTF-8 will not corrupt - jump_forward_str = new_text + jump_forward_str - if not req.jump_forward_and_retokenize( - jump_forward_str, next_state - ): - req.output_ids = cur_output_ids - continue - - print(f"Jump forward: {jump_forward_str}") - - # The decode status has diverged from detokenizer_manager - req.vid += 1 - - # insert the old request into tree_cache - self.tree_cache.cache_finished_req(req, cur_all_ids) - - # re-applying image padding - if req.image_inputs is not None: - req.origin_input_ids = pad_input_ids_func( - req.origin_input_ids_unpadded, req.image_inputs - ) - - jump_forward_reqs.append(req) - filter_indices.remove(i) - if req.allow_jump_forward and req.regex_bnf is not None: jump_forward_str = req.regex_bnf.find_jump_forward_string() if len(jump_forward_str) > 1: @@ -858,7 
+733,6 @@ def filter_batch(self, unfinished_indices: List[int]): self.top_logprobs_nums = None self.has_stream = any(req.stream for req in self.reqs) - # self.has_regex = any(req.regex_fsm for req in self.reqs) self.has_regex = any(req.regex_bnf for req in self.reqs) self.sampling_info.filter_batch(unfinished_indices, new_indices) @@ -899,12 +773,6 @@ def get_model_worker_batch(self): lora_paths = [req.lora_path for req in self.reqs] if self.has_regex: - self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs] - self.sampling_info.regex_fsm_states = [ - req.regex_fsm_state for req in self.reqs - ] - - # TODO(dark): remove the above and use the below self.sampling_info.regex_bnfs = [req.regex_bnf for req in self.reqs] return ModelWorkerBatch( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 1af80a3d237..cde093df668 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -29,8 +29,6 @@ from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.bnf_cache import BNFCache -from sglang.srt.constrained.fsm_cache import FSMCache -from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( @@ -204,17 +202,8 @@ def __init__( self.chunked_prefill_size is not None and server_args.enable_mixed_chunk ) - # Init the FSM cache for constrained generation + # Init the BNF cache for constrained generation if not server_args.skip_tokenizer_init: - self.regex_fsm_cache = FSMCache( - server_args.tokenizer_path, - { - "tokenizer_mode": server_args.tokenizer_mode, - "trust_remote_code": server_args.trust_remote_code, - }, - skip_tokenizer_init=server_args.skip_tokenizer_init, - constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern, - ) self.regex_bnf_cache = BNFCache( server_args.tokenizer_path, { @@ -223,7 +212,6 @@ def __init__( }, skip_tokenizer_init=server_args.skip_tokenizer_init, ) - self.jump_forward_cache = JumpForwardCache() # Init new token estimation assert ( @@ -314,32 +302,21 @@ def handle_generate_request( # By default, only return the logprobs for output tokens req.logprob_start_len = len(recv_req.input_ids) - 1 - # Init regex FSM + # Init regex BNF if ( req.sampling_params.json_schema is not None or req.sampling_params.regex is not None ): - # TODO(dark): replace FSM cache with BNF cache if req.sampling_params.json_schema is not None: - req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query( - ("json", req.sampling_params.json_schema) - ) req.regex_bnf = self.regex_bnf_cache.query( ("json", req.sampling_params.json_schema) ) elif req.sampling_params.regex is not None: - req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query( - ("regex", req.sampling_params.regex) - ) req.regex_bnf = self.regex_bnf_cache.query( ("regex", req.sampling_params.regex) ) if not self.disable_regex_jump_forward: - req.jump_forward_map = self.jump_forward_cache.query( - computed_regex_string - ) req.allow_jump_forward = True - # TODO(dark): add custom bnf jump forward map here # Truncate prompts that are too long if len(req.origin_input_ids) >= self.max_req_input_len: @@ -698,12 +675,6 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): req.output_ids.append(next_token_ids[i]) 
req.check_finished() - # TODO(dark): replace FSM cache with BNF cache - if req.regex_fsm is not None: - req.regex_fsm_state = req.regex_fsm.get_next_state( - req.regex_fsm_state, next_token_ids[i] - ) - if req.regex_bnf is not None: assert req.regex_bnf.accept_token(next_token_ids[i]) @@ -773,12 +744,6 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): req.output_ids.append(next_token_id) req.check_finished() - # TODO(dark): replace FSM cache with BNF cache - if req.regex_fsm is not None: - req.regex_fsm_state = req.regex_fsm.get_next_state( - req.regex_fsm_state, next_token_id - ) - if req.regex_bnf is not None: assert req.regex_bnf.accept_token(next_token_id) @@ -984,7 +949,6 @@ def flush_cache(self): ): self.tree_cache.reset() self.tree_cache_metrics = {"total": 0, "hit": 0} - self.regex_fsm_cache.reset() self.regex_bnf_cache.reset() self.req_to_token_pool.clear() self.token_to_kv_pool.clear() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f3390ae46cb..43fefd8f835 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -523,7 +523,6 @@ def sample( # Put CPU-heavy tasks here. They will be overlapped with the forward pass. sampling_info = forward_batch.sampling_info sampling_info.update_regex_vocab_mask_bnf() - sampling_info.update_regex_vocab_mask() sampling_info.update_penalties() logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 569465ce5e2..96365795956 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -30,10 +30,6 @@ class SamplingBatchInfo: vocab_mask: Optional[torch.Tensor] = None # FSM states - regex_fsms: List[RegexGuide] = None - regex_fsm_states: List[int] = None - - # TODO(dark): remove the above and use the regex_bnf instead regex_bnfs: Optional[List[Optional[GrammarStateMatcher]]] = None # Penalizer @@ -115,23 +111,6 @@ def update_penalties(self): ) self.linear_penalties = penalizer.apply(self.linear_penalties) - def update_regex_vocab_mask(self): - has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms) - - # Reset the vocab mask - self.vocab_mask = None - - if has_regex: - self.vocab_mask = torch.zeros( - len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda" - ) - for i, regex_fsm in enumerate(self.regex_fsms): - if regex_fsm is not None: - self.vocab_mask[i].fill_(1) - self.vocab_mask[i][ - regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens - ] = 0 - # TODO(dark): rename this to update_regex_vocab_mask after removing the old one def update_regex_vocab_mask_bnf(self): # Reset the vocab mask From 0adf268be446eb2609168193ab4064d82211320c Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Fri, 11 Oct 2024 04:24:49 +0000 Subject: [PATCH 04/14] test(xgrammar): customize some testcases for xgrammar | yet buggy... 
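
The reworked tests pass the prompt and schema per call instead of pinning one
class-level schema, so each case can exercise its own grammar. A minimal
standalone sketch of the request shape the new run_decode sends (illustrative
only: the URL below is a placeholder; in the suite the server is launched via
popen_launch_server and the address comes from DEFAULT_URL_FOR_TEST):

    import json

    import requests

    def generate_with_schema(base_url: str, prompt: str, json_schema: str) -> dict:
        # Mirrors run_decode in the diff below: POST /generate with a
        # json_schema inside sampling_params so decoding is constrained by
        # the xgrammar matcher on the server side.
        response = requests.post(
            base_url + "/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 128,
                    "json_schema": json_schema,
                },
                "stream": False,
            },
        )
        return response.json()

    schema = json.dumps(
        {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "population": {"type": "integer"},
            },
            "required": ["name", "population"],
        }
    )
    # result = generate_with_schema("http://127.0.0.1:30000", "The capital of France is", schema)
    # json.loads(result["text"]) should then satisfy the schema.
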
--- test/srt/test_json_constrained.py | 112 ++++++++++++++++++++---------- 1 file changed, 74 insertions(+), 38 deletions(-) diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index d3abc70a44f..8340118d48d 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -17,33 +17,25 @@ class TestJSONConstrained(unittest.TestCase): def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST - cls.json_schema = json.dumps( - { - "type": "object", - "properties": { - "name": {"type": "string", "pattern": "^[\\w]+$"}, - "population": {"type": "integer"}, - }, - "required": ["name", "population"], - } - ) cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) @classmethod def tearDownClass(cls): kill_child_process(cls.process.pid) - def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): - response = requests.post( + def run_decode( + self, prompt, json_schema, return_logprob=False, top_logprobs_num=0, n=1 + ): + return requests.post( self.base_url + "/generate", json={ - "text": "The capital of France is", + "text": prompt, "sampling_params": { "temperature": 0 if n == 1 else 0.5, "max_new_tokens": 128, "n": n, "stop_token_ids": [119690], - "json_schema": self.json_schema, + "json_schema": json_schema, }, "stream": False, "return_logprob": return_logprob, @@ -51,6 +43,20 @@ def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): "logprob_start_len": 0, }, ) + + def test_json_generate_simple(self): + prompt = "The capital of France is" + json_schema = json.dumps( + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "population": {"type": "integer"}, + }, + "required": ["name", "population"], + } + ) + response = self.run_decode(prompt, json_schema) print(json.dumps(response.json())) print("=" * 100) try: @@ -60,34 +66,64 @@ def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): assert isinstance(js_obj["name"], str) assert isinstance(js_obj["population"], int) - def test_json_generate(self): - self.run_decode() - - def test_json_openai(self): - client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1") - - response = client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "Introduce the capital of France."}, - ], - temperature=0, - max_tokens=128, - response_format={ - "type": "json_schema", - "json_schema": {"name": "foo", "schema": json.loads(self.json_schema)}, - }, - ) - text = response.choices[0].message.content - + def test_json_generate_complex(self): + # TODO(dark): Fix this test. $definitions is not supported by xgrammar yet. 
+ prompt = "Please create a character named Komeiji Satori:" + json_schema = """{ + "title": "Character", + "type": "object", + "properties": { + "name": { + "title": "Name", + "type": "string" + }, + "age": { + "title": "Age", + "type": "integer" + }, + "armor": {"$ref": "#/$definitions/Armor"}, + "weapon": {"$ref": "#/$definitions/Weapon"}, + "strength": { + "title": "Strength", + "type": "integer" + } + }, + "required": ["name", "age", "armor", "weapon", "strength"], + "definitions": { + "Armor": { + "title": "Armor", + "description": "An enumeration.", + "enum": ["skirt", "leather", "chainmail", "plate"], + "type": "string" + }, + "Weapon": { + "title": "Weapon", + "description": "An enumeration.", + "enum": ["third eye", "sword", "axe", "mace", "spear", "bow", "crossbow"], + "type": "string" + } + } +}""" + response = self.run_decode(prompt, json_schema) + print(json.dumps(response.json())) + print("=" * 100) try: - js_obj = json.loads(text) + js_obj = json.loads(response.json()["text"]) except (TypeError, json.decoder.JSONDecodeError): - print("JSONDecodeError", text) raise assert isinstance(js_obj["name"], str) - assert isinstance(js_obj["population"], int) + assert isinstance(js_obj["age"], int) + assert js_obj["armor"] in ["skirt", "leather", "chainmail", "plate"] + assert js_obj["weapon"] in [ + "third eye", + "sword", + "axe", + "mace", + "spear", + "bow", + "crossbow", + ] + assert isinstance(js_obj["strength"], int) if __name__ == "__main__": From 66ce4c6635b57f65b3a201f29384b3d583910d65 Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Sat, 12 Oct 2024 07:35:12 +0000 Subject: [PATCH 05/14] minor(xgrammar): fix the testcase --- test/srt/test_json_constrained.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 8340118d48d..3cb9f3be684 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -59,15 +59,11 @@ def test_json_generate_simple(self): response = self.run_decode(prompt, json_schema) print(json.dumps(response.json())) print("=" * 100) - try: - js_obj = json.loads(response.json()["text"]) - except (TypeError, json.decoder.JSONDecodeError): - raise + js_obj = json.loads(response.json()["text"]) assert isinstance(js_obj["name"], str) assert isinstance(js_obj["population"], int) def test_json_generate_complex(self): - # TODO(dark): Fix this test. $definitions is not supported by xgrammar yet. 
prompt = "Please create a character named Komeiji Satori:" json_schema = """{ "title": "Character", @@ -81,15 +77,15 @@ def test_json_generate_complex(self): "title": "Age", "type": "integer" }, - "armor": {"$ref": "#/$definitions/Armor"}, - "weapon": {"$ref": "#/$definitions/Weapon"}, + "armor": {"$ref": "#/$defs/Armor"}, + "weapon": {"$ref": "#/$defs/Weapon"}, "strength": { "title": "Strength", "type": "integer" } }, "required": ["name", "age", "armor", "weapon", "strength"], - "definitions": { + "$defs": { "Armor": { "title": "Armor", "description": "An enumeration.", @@ -107,10 +103,7 @@ def test_json_generate_complex(self): response = self.run_decode(prompt, json_schema) print(json.dumps(response.json())) print("=" * 100) - try: - js_obj = json.loads(response.json()["text"]) - except (TypeError, json.decoder.JSONDecodeError): - raise + js_obj = json.loads(response.json()["text"]) assert isinstance(js_obj["name"], str) assert isinstance(js_obj["age"], int) assert js_obj["armor"] in ["skirt", "leather", "chainmail", "plate"] From 44906b79363243f988186a2113bfc61c9351be76 Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Sat, 12 Oct 2024 13:56:16 +0000 Subject: [PATCH 06/14] fix(xgrammar): fix the rollback of jump_forward --- python/sglang/srt/constrained/bnf_cache.py | 8 ++++++-- python/sglang/srt/managers/schedule_batch.py | 12 ++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/constrained/bnf_cache.py b/python/sglang/srt/constrained/bnf_cache.py index 99ccfcd93fe..dabf0eff9b1 100644 --- a/python/sglang/srt/constrained/bnf_cache.py +++ b/python/sglang/srt/constrained/bnf_cache.py @@ -20,6 +20,8 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache +MAX_ROLLBACK_STEPS = 10 + class BNFCache(BaseToolCache): def __init__( @@ -41,7 +43,7 @@ def init_value(self, key): key_type, key_string = key if key_type == "json": - grammar = BuiltinGrammar.json_schema(key_string) + grammar = BuiltinGrammar.json_schema(key_string, indent=None) elif key_type == "regex": assert False, "Not supported by xgrammar yet" else: @@ -51,7 +53,9 @@ def init_value(self, key): def query(self, key): grammar = super().query(key) - return GrammarStateMatcher(grammar, self.tokenizer) + return GrammarStateMatcher( + grammar, self.tokenizer, max_rollback_steps=MAX_ROLLBACK_STEPS + ) # class BNFCache(BaseToolCache): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 9cc036bba19..6adfb9a208c 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -330,6 +330,7 @@ def check_finished(self): def jump_forward_and_retokenize_bnf(self, jump_forward_str): assert self.regex_bnf is not None, "should be a regex request" + assert self.tokenizer is not None, "should have a tokenizer" if self.origin_input_text is None: # Recovering text can only use unpadded ids @@ -378,10 +379,11 @@ def jump_forward_and_retokenize_bnf(self, jump_forward_str): break # rollback to the last token that is the same - self.regex_bnf.rollback(len(old_output_ids) - k) + if k < len(old_output_ids): + self.regex_bnf.rollback(len(old_output_ids) - k) for i in range(k, len(self.output_ids)): - self.regex_bnf.accept_token(self.output_ids[i]) + assert self.regex_bnf.accept_token(self.output_ids[i]) if self.return_logprob: # For fast-forward part's logprobs @@ -661,6 +663,12 @@ def check_for_jump_forward(self, pad_input_ids_func): if len(jump_forward_str) > 1: cur_all_ids = 
tuple(req.origin_input_ids + req.output_ids)[:-1] cur_output_ids = req.output_ids + decode_res, new_text = req.get_next_inc_detokenization() + if not decode_res: + req.output_ids = cur_output_ids + continue + + jump_forward_str = new_text + jump_forward_str if not req.jump_forward_and_retokenize_bnf(jump_forward_str): # Failed to jump forward, revert req.output_ids = cur_output_ids From 5582355f381c04f49407ddc7049192357689da6a Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Sun, 13 Oct 2024 04:05:08 +0000 Subject: [PATCH 07/14] minor(xgrammar): fix some merge errors --- .../srt/sampling/sampling_batch_info.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 9b7b384c448..417618ca6f8 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -126,22 +126,25 @@ def update_penalties(self): ) self.linear_penalties = penalizer.apply(self.linear_penalties) - def update_regex_vocab_mask(self): - has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms) - + def update_regex_vocab_mask_bnf(self): # Reset the vocab mask self.vocab_mask = None - if has_regex: + if self.regex_bnfs and any(regex_bnf for regex_bnf in self.regex_bnfs): + # If has regex, then we need to update the vocab mask self.vocab_mask = torch.zeros( - len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda" + len(self.temperatures), self.vocab_size, dtype=torch.bool, device=self.device ) - for i, regex_fsm in enumerate(self.regex_fsms): - if regex_fsm is not None: - self.vocab_mask[i].fill_(1) + for i, regex_bnf in enumerate(self.regex_bnfs): + if regex_bnf is not None: + # Note that this bitmask is a bitset, not bool + bitmask = regex_bnf.find_next_token_bitmask() + # Mask the tokens that are not allowed self.vocab_mask[i][ - regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens - ] = 0 + regex_bnf.get_rejected_tokens_from_bitmask( + bitmask, self.vocab_size + ) + ] = 1 def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): if self.penalizer_orchestrator: From 435525fba2ee50ed9622d45a079c374c17265478 Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Sun, 13 Oct 2024 04:48:27 +0000 Subject: [PATCH 08/14] minor(xgrammar): fix some merge errors in testcases --- test/srt/test_json_constrained.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index e87085a3f01..318037af914 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -91,7 +91,7 @@ def test_json_generate_complex(self): "Armor": { "title": "Armor", "description": "An enumeration.", - "enum": ["skirt", "leather", "chainmail", "plate"], + "enum": ["leather", "chainmail", "plate"], "type": "string" }, "Weapon": { @@ -107,8 +107,10 @@ def test_json_generate_complex(self): print("=" * 100) js_obj = json.loads(response.json()["text"]) assert isinstance(js_obj["name"], str) - assert isinstance(js_obj["population"], int) - + assert isinstance(js_obj["age"], int) + assert js_obj["armor"] in ["leather", "chainmail", "plate"] + assert js_obj["weapon"] in ["third eye", "sword", "axe", "mace", "spear", "bow", "crossbow"] + assert isinstance(js_obj["strength"], int) if __name__ == "__main__": unittest.main() From 0e05a194a1b2f0538ac4a92f4686c9cae512e5e4 Mon Sep 17 00:00:00 2001 
From: DarkSharpness Date: Sun, 13 Oct 2024 04:49:33 +0000 Subject: [PATCH 09/14] minor(xgrammar): format the code --- python/sglang/srt/sampling/sampling_batch_info.py | 5 ++++- test/srt/test_json_constrained.py | 11 ++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 417618ca6f8..ff68c14c1c7 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -133,7 +133,10 @@ def update_regex_vocab_mask_bnf(self): if self.regex_bnfs and any(regex_bnf for regex_bnf in self.regex_bnfs): # If has regex, then we need to update the vocab mask self.vocab_mask = torch.zeros( - len(self.temperatures), self.vocab_size, dtype=torch.bool, device=self.device + len(self.temperatures), + self.vocab_size, + dtype=torch.bool, + device=self.device, ) for i, regex_bnf in enumerate(self.regex_bnfs): if regex_bnf is not None: diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 318037af914..1d085e5d4e4 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -109,8 +109,17 @@ def test_json_generate_complex(self): assert isinstance(js_obj["name"], str) assert isinstance(js_obj["age"], int) assert js_obj["armor"] in ["leather", "chainmail", "plate"] - assert js_obj["weapon"] in ["third eye", "sword", "axe", "mace", "spear", "bow", "crossbow"] + assert js_obj["weapon"] in [ + "third eye", + "sword", + "axe", + "mace", + "spear", + "bow", + "crossbow", + ] assert isinstance(js_obj["strength"], int) + if __name__ == "__main__": unittest.main() From 0bda5ead786655201e334702b9e3ad1f89b39f7a Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Mon, 14 Oct 2024 09:37:45 +0000 Subject: [PATCH 10/14] feat(xgrammar): adapt to a newer version of xgrammar --- python/sglang/srt/constrained/bnf_cache.py | 63 +++++-------------- python/sglang/srt/managers/schedule_batch.py | 4 +- .../srt/sampling/sampling_batch_info.py | 4 +- 3 files changed, 20 insertions(+), 51 deletions(-) diff --git a/python/sglang/srt/constrained/bnf_cache.py b/python/sglang/srt/constrained/bnf_cache.py index dabf0eff9b1..605758a765e 100644 --- a/python/sglang/srt/constrained/bnf_cache.py +++ b/python/sglang/srt/constrained/bnf_cache.py @@ -16,14 +16,16 @@ """Cache for the compressed finite state machine.""" from transformers import AutoTokenizer -from xgrammar import BuiltinGrammar, GrammarStateMatcher - -from sglang.srt.constrained.base_tool_cache import BaseToolCache +from xgrammar import ( + GrammarMatcher, + GrammarMatcherInitContext, + GrammarMatcherInitContextCache, +) MAX_ROLLBACK_STEPS = 10 -class BNFCache(BaseToolCache): +class BNFCache: def __init__( self, tokenizer_path, @@ -31,58 +33,25 @@ def __init__( skip_tokenizer_init=False, enable=True, ): - super().__init__(enable=enable) - if skip_tokenizer_init: - return + # TODO(dark): determine how to handle with `skip_tokenizer_init` self.tokenizer = AutoTokenizer.from_pretrained( tokenizer_path, **tokenizer_args_dict ) + self.grammar_cache = GrammarMatcherInitContextCache( + tokenizer_or_vocab=self.tokenizer + ) - def init_value(self, key): + def get_context(self, key) -> GrammarMatcherInitContext: key_type, key_string = key if key_type == "json": - grammar = BuiltinGrammar.json_schema(key_string, indent=None) + return self.grammar_cache.get_init_context_for_json_schema(key_string) elif key_type == "regex": - assert False, "Not supported by 
xgrammar yet" + raise ValueError(f"regex hasn't been supported by xgrammar yet") else: raise ValueError(f"Invalid key_type: {key_type}") - return grammar - - def query(self, key): - grammar = super().query(key) - return GrammarStateMatcher( - grammar, self.tokenizer, max_rollback_steps=MAX_ROLLBACK_STEPS - ) - - -# class BNFCache(BaseToolCache): -# def __init__( -# self, -# tokenizer_path, -# tokenizer_args_dict, -# enable=True, -# skip_tokenizer_init=False, -# constrained_json_whitespace_pattern=None, -# ): -# super().__init__(enable=enable) - -# if skip_tokenizer_init: -# return - -# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict) -# self.tokenizer = tokenizer -# self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern - -# def init_value(self, key): -# key_type, key_string = key -# if key_type == "json": -# grammar = BuiltinGrammar.json_schema(key_string) -# elif key_type == "regex": -# assert False, "Not supported yet" -# else: -# raise ValueError(f"Invalid key_type: {key_type}") - -# return GrammarStateMatcher(grammar, self.tokenizer) + def query(self, key) -> GrammarMatcher: + ctx = self.get_context(key) + return GrammarMatcher(ctx, max_rollback_steps=MAX_ROLLBACK_STEPS) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 02692af3fc4..c15073239f5 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -32,7 +32,7 @@ from typing import List, Optional, Tuple, Union import torch -from xgrammar import GrammarStateMatcher +from xgrammar import GrammarMatcher from sglang.global_config import global_config from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache @@ -226,7 +226,7 @@ def __init__( self.embedding = None # Constrained decoding - self.regex_bnf: Optional[GrammarStateMatcher] = None + self.regex_bnf: Optional[GrammarMatcher] = None self.allow_jump_forward: bool = False # whether request reached finished condition diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index ff68c14c1c7..1024a091c4a 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, List, Optional import torch -from xgrammar import GrammarStateMatcher +from xgrammar import GrammarMatcher import sglang.srt.sampling.penaltylib as penaltylib from sglang.srt.constrained import RegexGuide @@ -30,7 +30,7 @@ class SamplingBatchInfo: vocab_mask: Optional[torch.Tensor] = None # FSM states - regex_bnfs: Optional[List[Optional[GrammarStateMatcher]]] = None + regex_bnfs: Optional[List[Optional[GrammarMatcher]]] = None # Penalizer penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None From 8cf411baa9872a48bfe0d2d1b2d2c7d370ed5650 Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Mon, 14 Oct 2024 09:40:46 +0000 Subject: [PATCH 11/14] minor(xgrammar): disable bnf cache reset temporarily(need support from xgrammar) --- python/sglang/srt/managers/scheduler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index fe07b09f908..029bdd856f7 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -986,7 +986,8 @@ def flush_cache(self): ): self.tree_cache.reset() self.tree_cache_metrics = {"total": 0, "hit": 0} - 
self.regex_bnf_cache.reset() + # TODO(dark): How to reset bnf cache? + # self.regex_bnf_cache.reset() self.req_to_token_pool.clear() self.token_to_kv_pool.clear() torch.cuda.empty_cache() From 262f56771d2b0d0923aeefdb9076d93a7a72ff6a Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Tue, 15 Oct 2024 05:29:41 +0000 Subject: [PATCH 12/14] minor(xgrammar): adapt to newer api of xgrammar --- python/sglang/srt/constrained/bnf_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/constrained/bnf_cache.py b/python/sglang/srt/constrained/bnf_cache.py index 605758a765e..03078be5fc3 100644 --- a/python/sglang/srt/constrained/bnf_cache.py +++ b/python/sglang/srt/constrained/bnf_cache.py @@ -54,4 +54,4 @@ def get_context(self, key) -> GrammarMatcherInitContext: def query(self, key) -> GrammarMatcher: ctx = self.get_context(key) - return GrammarMatcher(ctx, max_rollback_steps=MAX_ROLLBACK_STEPS) + return GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_STEPS) From dba81d3842e3b3262da277644abf17a905caa3c2 Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Wed, 16 Oct 2024 06:32:22 +0000 Subject: [PATCH 13/14] fix(xgrammar): pass model.vocab_size to xgrammar to generate correct state_matcher --- python/sglang/srt/constrained/bnf_cache.py | 12 +++++------- python/sglang/srt/managers/scheduler.py | 4 ++-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/constrained/bnf_cache.py b/python/sglang/srt/constrained/bnf_cache.py index 03078be5fc3..fd8015559b1 100644 --- a/python/sglang/srt/constrained/bnf_cache.py +++ b/python/sglang/srt/constrained/bnf_cache.py @@ -42,16 +42,14 @@ def __init__( tokenizer_or_vocab=self.tokenizer ) - def get_context(self, key) -> GrammarMatcherInitContext: - key_type, key_string = key - + def get_context(self, key_type, key_str) -> GrammarMatcherInitContext: if key_type == "json": - return self.grammar_cache.get_init_context_for_json_schema(key_string) + return self.grammar_cache.get_init_context_for_json_schema(key_str) elif key_type == "regex": raise ValueError(f"regex hasn't been supported by xgrammar yet") else: raise ValueError(f"Invalid key_type: {key_type}") - def query(self, key) -> GrammarMatcher: - ctx = self.get_context(key) - return GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_STEPS) + def query(self, key_type: str, key_str : str, vocab_size : int) -> GrammarMatcher: + ctx = self.get_context(key_type, key_str) + return GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_STEPS, mask_vocab_size=vocab_size) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 029bdd856f7..921345e02de 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -353,11 +353,11 @@ def handle_generate_request( ): if req.sampling_params.json_schema is not None: req.regex_bnf = self.regex_bnf_cache.query( - ("json", req.sampling_params.json_schema) + "json", req.sampling_params.json_schema, self.model_config.vocab_size ) elif req.sampling_params.regex is not None: req.regex_bnf = self.regex_bnf_cache.query( - ("regex", req.sampling_params.regex) + "regex", req.sampling_params.regex, self.model_config.vocab_size ) if not self.disable_regex_jump_forward: req.allow_jump_forward = True From 6ad4ccecc21d82a24999949dccf4d5ed4aa9699f Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Wed, 16 Oct 2024 06:37:35 +0000 Subject: [PATCH 14/14] minor(xgrammar): run pre-commit to format the code --- 
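Note (annotation, not part of the commit): threading model_config.vocab_size
through to GrammarMatcher in the previous commit presumably matters because
the mask built from the matcher must match the width of the logits tensor,
which for some models is padded beyond the tokenizer vocabulary. A rough
sketch of how the mask built in update_regex_vocab_mask_bnf meets the logits
in apply_logits_bias (shapes illustrative):

    import torch

    batch_size, vocab_size = 2, 128256  # assumed padded model vocab width

    logits = torch.randn(batch_size, vocab_size)
    vocab_mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)

    # update_regex_vocab_mask_bnf marks rejected ids per request, e.g.
    #   vocab_mask[i][matcher.get_rejected_tokens_from_bitmask(bitmask, vocab_size)] = 1
    vocab_mask[0, 5:] = True  # pretend the grammar allows only tokens 0-4 here

    # apply_logits_bias then removes those tokens from sampling:
    logits = logits.masked_fill(vocab_mask, float("-inf"))
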
 python/sglang/srt/constrained/bnf_cache.py | 6 ++++--
 python/sglang/srt/managers/scheduler.py    | 6 ++++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/constrained/bnf_cache.py b/python/sglang/srt/constrained/bnf_cache.py
index fd8015559b1..61cb99df857 100644
--- a/python/sglang/srt/constrained/bnf_cache.py
+++ b/python/sglang/srt/constrained/bnf_cache.py
@@ -50,6 +50,8 @@ def get_context(self, key_type, key_str) -> GrammarMatcherInitContext:
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
-    def query(self, key_type: str, key_str : str, vocab_size : int) -> GrammarMatcher:
+    def query(self, key_type: str, key_str: str, vocab_size: int) -> GrammarMatcher:
         ctx = self.get_context(key_type, key_str)
-        return GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_STEPS, mask_vocab_size=vocab_size)
+        return GrammarMatcher(
+            ctx, max_rollback_tokens=MAX_ROLLBACK_STEPS, mask_vocab_size=vocab_size
+        )
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 4517e5e1b4a..a72fb1737f7 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -356,11 +356,13 @@ def handle_generate_request(
         ):
             if req.sampling_params.json_schema is not None:
                 req.regex_bnf = self.regex_bnf_cache.query(
-                    "json", req.sampling_params.json_schema, self.model_config.vocab_size
+                    "json",
+                    req.sampling_params.json_schema,
+                    self.model_config.vocab_size,
                 )
             elif req.sampling_params.regex is not None:
                 req.regex_bnf = self.regex_bnf_cache.query(
-                    "regex", req.sampling_params.regex, self.model_config.vocab_size
+                    "regex", req.sampling_params.regex, self.model_config.vocab_size
                 )
             if not self.disable_regex_jump_forward:
                 req.allow_jump_forward = True
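
Series note: after these patches the constrained-decoding path is xgrammar
end to end. The sketch below strings together only the matcher calls used in
the series (find_next_token_bitmask, get_rejected_tokens_from_bitmask,
accept_token, find_jump_forward_string, rollback); batching, stop conditions,
and the sampler are elided, and the helper names here are hypothetical:

    import torch
    from xgrammar import GrammarMatcher, GrammarMatcherInitContextCache

    def build_matcher(tokenizer, json_schema: str, vocab_size: int) -> GrammarMatcher:
        # Mirrors BNFCache: one init-context cache per tokenizer, one matcher
        # per request, with a small rollback window for jump-forward reverts.
        cache = GrammarMatcherInitContextCache(tokenizer_or_vocab=tokenizer)
        ctx = cache.get_init_context_for_json_schema(json_schema)
        return GrammarMatcher(ctx, max_rollback_tokens=10, mask_vocab_size=vocab_size)

    def mask_logits(matcher: GrammarMatcher, logits: torch.Tensor, vocab_size: int) -> torch.Tensor:
        # Mirrors SamplingBatchInfo.update_regex_vocab_mask_bnf for a single
        # request: the bitmask is a packed bitset, so first expand it into the
        # token ids the grammar rejects, then mask them out as in
        # ModelRunner.apply_logits_bias.
        bitmask = matcher.find_next_token_bitmask()
        vocab_mask = torch.zeros(vocab_size, dtype=torch.bool, device=logits.device)
        vocab_mask[matcher.get_rejected_tokens_from_bitmask(bitmask, vocab_size)] = True
        return logits.masked_fill(vocab_mask, float("-inf"))

    # Per decode step the scheduler then advances the matcher with the sampled
    # token, and may jump forward over grammar-forced text, rolling back first
    # when retokenization diverges from already-accepted output ids:
    #
    #     assert matcher.accept_token(next_token_id)
    #     forced = matcher.find_jump_forward_string()
    #     matcher.rollback(num_diverged_tokens)  # num_diverged_tokens: hypothetical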