diff --git a/python/sglang/srt/constrained/bnf_cache.py b/python/sglang/srt/constrained/bnf_cache.py
new file mode 100644
index 00000000000..61cb99df857
--- /dev/null
+++ b/python/sglang/srt/constrained/bnf_cache.py
@@ -0,0 +1,57 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Cache for the grammar matcher init contexts used in BNF-constrained decoding."""
+
+from transformers import AutoTokenizer
+from xgrammar import (
+    GrammarMatcher,
+    GrammarMatcherInitContext,
+    GrammarMatcherInitContextCache,
+)
+
+MAX_ROLLBACK_STEPS = 10
+
+
+class BNFCache:
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        skip_tokenizer_init=False,
+        enable=True,
+    ):
+        # TODO(dark): determine how to handle `skip_tokenizer_init`
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_path, **tokenizer_args_dict
+        )
+        self.grammar_cache = GrammarMatcherInitContextCache(
+            tokenizer_or_vocab=self.tokenizer
+        )
+
+    def get_context(self, key_type: str, key_str: str) -> GrammarMatcherInitContext:
+        if key_type == "json":
+            return self.grammar_cache.get_init_context_for_json_schema(key_str)
+        elif key_type == "regex":
+            raise ValueError("regex is not yet supported by xgrammar")
+        else:
+            raise ValueError(f"Invalid key_type: {key_type}")
+
+    def query(self, key_type: str, key_str: str, vocab_size: int) -> GrammarMatcher:
+        ctx = self.get_context(key_type, key_str)
+        return GrammarMatcher(
+            ctx, max_rollback_tokens=MAX_ROLLBACK_STEPS, mask_vocab_size=vocab_size
+        )
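Reviewer note: a minimal sketch of how the new cache is meant to be consumed, mirroring the scheduler call site later in this diff. The model path, schema, and vocab size below are illustrative, not values from this PR:

```python
from sglang.srt.constrained.bnf_cache import BNFCache

# Illustrative arguments; any HF tokenizer path works.
cache = BNFCache("meta-llama/Llama-3.1-8B-Instruct", {"trust_remote_code": True})
schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'

# `query` returns a fresh stateful GrammarMatcher per call; only the compiled
# init context is cached, so repeated requests with the same schema skip
# grammar compilation.
matcher = cache.query("json", schema, vocab_size=128256)
```

Note the contract change: `FSMCache.query` returned a `(guide, computed_regex)` pair, while `BNFCache.query` returns only the matcher, which is why the call sites below no longer thread a `computed_regex_string` around.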
-""" - -"""Cache for the compressed finite state machine.""" -import logging - -from interegular import InvalidSyntax, parse_pattern -from outlines.fsm.json_schema import build_regex_from_schema -from transformers import AutoTokenizer - -from sglang.srt.constrained import RegexGuide, TransformerTokenizer -from sglang.srt.constrained.base_tool_cache import BaseToolCache - -logger = logging.getLogger(__name__) - - -class FSMCache(BaseToolCache): - def __init__( - self, - tokenizer_path, - tokenizer_args_dict, - enable=True, - skip_tokenizer_init=False, - constrained_json_whitespace_pattern=None, - ): - super().__init__(enable=enable) - - if ( - skip_tokenizer_init - or tokenizer_path.endswith(".json") - or tokenizer_path.endswith(".model") - ): - # Do not support TiktokenTokenizer or SentencePieceTokenizer - return - - tokenizer_args_dict.setdefault("padding_side", "left") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict) - try: - self.outlines_tokenizer = TransformerTokenizer(tokenizer) - except AttributeError: - # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0) - origin_pad_token_id = tokenizer.pad_token_id - - def fset(self, value): - self._value = value - - type(tokenizer).pad_token_id = property( - fget=type(tokenizer).pad_token_id.fget, fset=fset - ) - self.outlines_tokenizer = TransformerTokenizer(tokenizer) - self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id - self.outlines_tokenizer.pad_token_id = origin_pad_token_id - self.outlines_tokenizer.pad_token = ( - self.outlines_tokenizer.tokenizer.pad_token - ) - self.outlines_tokenizer.vocabulary = ( - self.outlines_tokenizer.tokenizer.get_vocab() - ) - self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern - - def init_value(self, key): - key_type, key_string = key - if key_type == "json": - regex = build_regex_from_schema( - key_string, whitespace_pattern=self.constrained_json_whitespace_pattern - ) - elif key_type == "regex": - regex = key_string - else: - raise ValueError(f"Invalid key_type: {key_type}") - try: - parse_pattern(regex) - except InvalidSyntax as e: - logger.warning(f"skip invalid regex guide: {regex=}, {e=}") - return None, regex - return RegexGuide(regex, self.outlines_tokenizer), regex diff --git a/python/sglang/srt/constrained/jump_forward.py b/python/sglang/srt/constrained/jump_forward.py deleted file mode 100644 index 1ebc8b21718..00000000000 --- a/python/sglang/srt/constrained/jump_forward.py +++ /dev/null @@ -1,203 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -""" -Faster constrained decoding. 
diff --git a/python/sglang/srt/constrained/jump_forward.py b/python/sglang/srt/constrained/jump_forward.py
deleted file mode 100644
index 1ebc8b21718..00000000000
--- a/python/sglang/srt/constrained/jump_forward.py
+++ /dev/null
@@ -1,203 +0,0 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-"""
-Faster constrained decoding.
-Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
-"""
-
-import dataclasses
-import logging
-from collections import defaultdict
-
-import interegular
-import outlines.caching
-from interegular import InvalidSyntax
-
-from sglang.srt.constrained import (
-    FSMInfo,
-    disk_cache,
-    make_byte_level_fsm,
-    make_deterministic_fsm,
-)
-from sglang.srt.constrained.base_tool_cache import BaseToolCache
-
-IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
-
-logger = logging.getLogger(__name__)
-
-
-@dataclasses.dataclass
-class JumpEdge:
-    symbol: str = None
-    symbol_next_state: int = None
-    byte: int = None
-    byte_next_state: int = None
-
-
-class JumpForwardMap:
-    def __init__(self, regex_string):
-        @disk_cache()
-        def _init_state_to_jump_forward(regex_string):
-            try:
-                regex_pattern = interegular.parse_pattern(regex_string)
-            except InvalidSyntax as e:
-                logger.warning(f"skip invalid regex: {regex_string}, {e=}")
-                self.state_to_jump_forward = None
-                return
-
-            byte_fsm = make_byte_level_fsm(
-                regex_pattern.to_fsm().reduce(), keep_utf8=True
-            )
-            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-
-            fsm_info: FSMInfo = regex_fsm.fsm_info
-
-            symbol_to_id = fsm_info.alphabet_symbol_mapping
-            id_to_symbol = {}
-            for symbol, id_ in symbol_to_id.items():
-                id_to_symbol.setdefault(id_, []).append(symbol)
-
-            transitions = fsm_info.transitions
-
-            outgoings_ct = defaultdict(int)
-            # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
-            for s in fsm_info.finals:
-                outgoings_ct[s] = 1
-
-            state_to_jump_forward = {}
-            for (state, id_), next_state in transitions.items():
-                if id_ == fsm_info.alphabet_anything_value:
-                    # Arbitrarily symbol cannot be recognized as jump forward
-                    continue
-
-                symbols = id_to_symbol[id_]
-                for c in symbols:
-                    if len(c) > 1:
-                        # Skip byte level transitions like c = "5E"
-                        continue
-
-                    outgoings_ct[state] += 1
-                    if outgoings_ct[state] > 1:
-                        if state in state_to_jump_forward:
-                            del state_to_jump_forward[state]
-                        break
-
-                    state_to_jump_forward[state] = JumpEdge(
-                        symbol=c,
-                        symbol_next_state=next_state,
-                    )
-
-            # Process the byte level jump forward
-            outgoings_ct = defaultdict(int)
-            for s in fsm_info.finals:
-                outgoings_ct[s] = 1
-
-            for (state, id_), next_state in transitions.items():
-                if id_ == fsm_info.alphabet_anything_value:
-                    continue
-                symbols = id_to_symbol[id_]
-                for c in symbols:
-                    byte_ = None
-                    if len(c) == 1 and ord(c) < 0x80:
-                        # ASCII character
-                        byte_ = ord(c)
-                    elif len(c) > 1:
-                        # FIXME: This logic is due to the leading \x00
-                        # https://github.com/outlines-dev/outlines/pull/930
-                        byte_ = int(symbols[0][1:], 16)
-
-                    if byte_ is not None:
-                        outgoings_ct[state] += 1
-                        if outgoings_ct[state] > 1:
-                            if state in state_to_jump_forward:
-                                del state_to_jump_forward[state]
-                            break
-                        e = state_to_jump_forward.get(state, JumpEdge())
-                        e.byte = byte_
-                        e.byte_next_state = next_state
-                        state_to_jump_forward[state] = e
-
-            return state_to_jump_forward
-
-        self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
-
-    def jump_forward_symbol(self, state):
-        jump_forward_str = ""
-        next_state = state
-        while state in self.state_to_jump_forward:
-            e = self.state_to_jump_forward[state]
-            if e.symbol is None:
-                break
-            jump_forward_str += e.symbol
-            next_state = e.symbol_next_state
-            state = next_state
-
-        return jump_forward_str, next_state
-
-    def jump_forward_byte(self, state):
-        if state not in self.state_to_jump_forward:
-            return None
-
-        jump_forward_bytes = []
-        next_state = None
-        while state in self.state_to_jump_forward:
-            e = self.state_to_jump_forward[state]
-            assert e.byte is not None and e.byte_next_state is not None
-            jump_forward_bytes.append((e.byte, e.byte_next_state))
-            next_state = e.byte_next_state
-            state = next_state
-
-        return jump_forward_bytes
-
-    def is_jump_forward_symbol_state(self, state):
-        return (
-            state in self.state_to_jump_forward
-            and self.state_to_jump_forward[state].symbol is not None
-        )
-
-
-class JumpForwardCache(BaseToolCache):
-    def __init__(self):
-        super().__init__()
-
-    def init_value(self, regex):
-        forward_map = JumpForwardMap(regex)
-        if forward_map.state_to_jump_forward:
-            return forward_map
-        else:
-            return None
-
-
-def test_main(regex_string):
-    jump_forward_map = JumpForwardMap(regex_string)
-    for state, e in jump_forward_map.state_to_jump_forward.items():
-        if e.symbol is not None:
-            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
-            print(f"{state} -> {next_state}", jump_forward_str)
-        bytes_ = jump_forward_map.jump_forward_byte(state)
-        print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
-
-
-if __name__ == "__main__":
-    import outlines
-
-    outlines.caching.clear_cache()
-    test_main(r"The google's DNS sever address is " + IP_REGEX)
-    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
-    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
-    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
-
-    test_main(r"[-+]?[0-9]+[ ]*")
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 8d7ccd3547e..5e202332047 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -32,10 +32,9 @@ from typing import List, Optional, Tuple, Union
 import torch
+from xgrammar import GrammarMatcher
 
 from sglang.global_config import global_config
-from sglang.srt.constrained import RegexGuide
-from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -228,9 +227,8 @@ def __init__(
         self.embedding = None
 
         # Constrained decoding
-        self.regex_fsm: RegexGuide = None
-        self.regex_fsm_state: int = 0
-        self.jump_forward_map: JumpForwardMap = None
+        self.regex_bnf: Optional[GrammarMatcher] = None
+        self.allow_jump_forward: bool = False
 
     # whether request reached finished condition
     def finished(self) -> bool:
@@ -331,7 +329,10 @@ def check_finished(self):
                 self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                 return
 
-    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+    def jump_forward_and_retokenize_bnf(self, jump_forward_str):
+        assert self.regex_bnf is not None, "should be a regex request"
+        assert self.tokenizer is not None, "should have a tokenizer"
+
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
@@ -371,16 +372,22 @@
                 self.surr_offset = self.read_offset - i
                 break
 
-        self.regex_fsm_state = next_state
+        k = 0
+        for i, old_id in enumerate(old_output_ids):
+            if old_id == self.output_ids[i]:
+                k = i + 1
+            else:
+                break
+
+        # Roll back to the last token shared by the old and new output ids
+        if k < len(old_output_ids):
+            self.regex_bnf.rollback(len(old_output_ids) - k)
+
+        for i in range(k, len(self.output_ids)):
+            assert self.regex_bnf.accept_token(self.output_ids[i])
 
         if self.return_logprob:
             # For fast-forward part's logprobs
-            k = 0
-            for i, old_id in enumerate(old_output_ids):
-                if old_id == self.output_ids[i]:
-                    k = k + 1
-                else:
-                    break
             self.output_token_logprobs = self.output_token_logprobs[:k]
             self.output_top_logprobs = self.output_top_logprobs[:k]
             self.logprob_start_len = prompt_tokens + k
@@ -440,7 +447,7 @@ class ScheduleBatch:
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
         has_stream = any(req.stream for req in reqs)
-        has_regex = any(req.regex_fsm for req in reqs)
+        has_regex = any(req.regex_bnf for req in reqs)
 
         return cls(
             reqs=reqs,
@@ -667,47 +674,19 @@ def check_for_jump_forward(self, pad_input_ids_func):
         keep_indices = set(i for i in range(len(self.reqs)))
 
         for i, req in enumerate(self.reqs):
-            if req.jump_forward_map is not None:
-                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
-                    req.regex_fsm_state
-                )
-                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
-                    suffix_bytes = []
-                    continuation_range = range(0x80, 0xC0)
-                    cur_state = req.regex_fsm_state
-                    while (
-                        len(jump_forward_bytes)
-                        and jump_forward_bytes[0][0] in continuation_range
-                    ):
-                        # continuation bytes
-                        byte_edge = jump_forward_bytes.pop(0)
-                        suffix_bytes.append(byte_edge[0])
-                        cur_state = byte_edge[1]
-
-                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
-                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
-
-                    # Current ids, for cache and revert
+            if req.allow_jump_forward and req.regex_bnf is not None:
+                jump_forward_str = req.regex_bnf.find_jump_forward_string()
+                if len(jump_forward_str) > 1:
                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                     cur_output_ids = req.output_ids
-
-                    req.output_ids.extend(suffix_ids)
                     decode_res, new_text = req.get_next_inc_detokenization()
                     if not decode_res:
                         req.output_ids = cur_output_ids
                         continue
 
-                    (
-                        jump_forward_str,
-                        next_state,
-                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)
-
-                    # Make the incrementally decoded text part of jump_forward_str
-                    # so that the UTF-8 will not corrupt
                     jump_forward_str = new_text + jump_forward_str
-                    if not req.jump_forward_and_retokenize(
-                        jump_forward_str, next_state
-                    ):
+                    if not req.jump_forward_and_retokenize_bnf(jump_forward_str):
+                        # Failed to jump forward, revert
                         req.output_ids = cur_output_ids
                         continue
@@ -782,7 +761,7 @@ def filter_batch(
             self.top_logprobs_nums = None
 
         self.has_stream = any(req.stream for req in self.reqs)
-        self.has_regex = any(req.regex_fsm for req in self.reqs)
+        self.has_regex = any(req.regex_bnf for req in self.reqs)
 
         self.sampling_info.filter_batch(keep_indices, new_indices)
 
@@ -824,12 +803,9 @@ def get_model_worker_batch(self):
             lora_paths = [req.lora_path for req in self.reqs]
 
         if self.has_regex:
-            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
-            self.sampling_info.regex_fsm_states = [
-                req.regex_fsm_state for req in self.reqs
-            ]
+            self.sampling_info.regex_bnfs = [req.regex_bnf for req in self.reqs]
         else:
-            self.sampling_info.regex_fsms = None
+            self.sampling_info.regex_bnfs = None
 
         global bid
         bid += 1
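The deleted `JumpForwardMap` walked FSM states by hand to find deterministic continuations. With xgrammar, the matcher exposes this directly; a sketch of the replacement flow (`matcher` is a `GrammarMatcher` returned by `BNFCache.query`, the rest is illustrative):

```python
# If the grammar admits exactly one continuation (e.g. a mandatory `", "`
# between JSON fields), the matcher can emit it without running the model.
jump_str = matcher.find_jump_forward_string()
if len(jump_str) > 1:
    # The scheduler appends this text and retokenizes, as shown in the
    # schedule_batch.py diff that follows, instead of decoding token by token.
    print(f"jump-forward candidate: {jump_str!r}")
```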
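The rollback logic in `jump_forward_and_retokenize_bnf` is subtle enough to deserve a standalone restatement. A sketch of what it does to keep the stateful matcher in sync after retokenization (names are illustrative; error handling elided):

```python
def resync_matcher(matcher, old_output_ids, new_output_ids) -> bool:
    """Re-align a GrammarMatcher after jump-forward retokenization."""
    # Length of the longest common prefix of the old and new token streams.
    k = 0
    for old_id, new_id in zip(old_output_ids, new_output_ids):
        if old_id != new_id:
            break
        k += 1

    # Undo the tokens that retokenization changed...
    if k < len(old_output_ids):
        matcher.rollback(len(old_output_ids) - k)

    # ...then replay the new suffix. Rollback depth is capped by
    # MAX_ROLLBACK_STEPS (10) in bnf_cache.py, so the two streams must not
    # diverge earlier than that.
    for token_id in new_output_ids[k:]:
        if not matcher.accept_token(token_id):
            return False  # caller reverts output_ids, as check_for_jump_forward does
    return True
```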
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index e3fea477788..a72fb1737f7 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -28,8 +28,7 @@
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.fsm_cache import FSMCache
-from sglang.srt.constrained.jump_forward import JumpForwardCache
+from sglang.srt.constrained.bnf_cache import BNFCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -207,18 +206,16 @@ def __init__(
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )
 
-        # Init the FSM cache for constrained generation
+        # Init the BNF cache for constrained generation
        if not server_args.skip_tokenizer_init:
-            self.regex_fsm_cache = FSMCache(
+            self.regex_bnf_cache = BNFCache(
                 server_args.tokenizer_path,
                 {
                     "tokenizer_mode": server_args.tokenizer_mode,
                     "trust_remote_code": server_args.trust_remote_code,
                 },
                 skip_tokenizer_init=server_args.skip_tokenizer_init,
-                constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
             )
-        self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
         assert (
@@ -352,23 +349,23 @@ def handle_generate_request(
             # By default, only return the logprobs for output tokens
             req.logprob_start_len = len(recv_req.input_ids) - 1
 
-        # Init regex FSM
+        # Init regex BNF
         if (
             req.sampling_params.json_schema is not None
             or req.sampling_params.regex is not None
         ):
             if req.sampling_params.json_schema is not None:
-                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
-                    ("json", req.sampling_params.json_schema)
+                req.regex_bnf = self.regex_bnf_cache.query(
+                    "json",
+                    req.sampling_params.json_schema,
+                    self.model_config.vocab_size,
                 )
             elif req.sampling_params.regex is not None:
-                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
-                    ("regex", req.sampling_params.regex)
+                req.regex_bnf = self.regex_bnf_cache.query(
+                    "regex", req.sampling_params.regex, self.model_config.vocab_size
                 )
             if not self.disable_regex_jump_forward:
-                req.jump_forward_map = self.jump_forward_cache.query(
-                    computed_regex_string
-                )
+                req.allow_jump_forward = True
 
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
@@ -746,10 +743,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)
 
-                if req.regex_fsm is not None:
-                    req.regex_fsm_state = req.regex_fsm.get_next_state(
-                        req.regex_fsm_state, next_token_ids[i]
-                    )
+                if req.regex_bnf is not None:
+                    assert req.regex_bnf.accept_token(next_token_ids[i])
 
                 if req.return_logprob:
                     logprob_pt += self.add_logprob_return_values(
@@ -800,10 +795,8 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
             req.output_ids.append(next_token_id)
             req.check_finished()
 
-            if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.get_next_state(
-                    req.regex_fsm_state, next_token_id
-                )
+            if req.regex_bnf is not None:
+                assert req.regex_bnf.accept_token(next_token_id)
 
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
@@ -996,7 +989,8 @@ def flush_cache(self):
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.regex_fsm_cache.reset()
+            # TODO(dark): how to reset the BNF cache?
+            # self.regex_bnf_cache.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
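One caveat on the new accept path: `assert req.regex_bnf.accept_token(...)` puts a state mutation inside an `assert`, so running under `python -O` would skip feeding the matcher entirely. A defensive variant (a sketch, not part of this PR):

```python
if req.regex_bnf is not None:
    # Keep the side effect unconditional; only the failure check is optional.
    accepted = req.regex_bnf.accept_token(next_token_id)
    if not accepted:
        raise RuntimeError(f"grammar rejected sampled token {next_token_id}")
```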
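Until the custom CUDA kernel from the TODO below lands, the packed bitmask has to be expanded into a bool mask with ordinary tensor ops. A sketch of that expansion, assuming xgrammar packs one vocab entry per bit, 32 per int32 word, LSB first (the packing order is an assumption here, not taken from this PR):

```python
import torch

def bitmask_to_rejected_mask(bitmask: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """Expand a packed int32 bitmask (1 = allowed) into a bool mask (True = rejected)."""
    bits = torch.arange(32, device=bitmask.device, dtype=torch.int32)
    unpacked = (bitmask.unsqueeze(-1) >> bits) & 1   # (num_words, 32)
    allowed = unpacked.flatten()[:vocab_size].bool()
    return ~allowed
```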
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index d583dcd34f7..b9c97f4fc76 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -582,7 +582,7 @@ def sample(
     ) -> torch.Tensor:
         # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
         sampling_info = forward_batch.sampling_info
-        sampling_info.update_regex_vocab_mask()
+        sampling_info.update_regex_vocab_mask_bnf()
         sampling_info.update_penalties()
         logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
 
@@ -610,6 +610,7 @@ def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchIn
         # Apply regex vocab_mask
         if sampling_info.vocab_mask is not None:
             logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+            # TODO(dark): add a custom CUDA kernel that applies the BNF bitmask directly
 
         return logits
 
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 779af5101c0..6fb0a920a5c 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -4,9 +4,9 @@ from typing import TYPE_CHECKING, List, Optional
 
 import torch
+from xgrammar import GrammarMatcher
 
 import sglang.srt.sampling.penaltylib as penaltylib
-from sglang.srt.constrained import RegexGuide
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -26,11 +26,10 @@ class SamplingBatchInfo:
     # Bias Tensors
     vocab_size: int
     logit_bias: torch.Tensor = None
-    vocab_mask: torch.Tensor = None
+    vocab_mask: Optional[torch.Tensor] = None
 
-    # FSM states
-    regex_fsms: List[RegexGuide] = None
-    regex_fsm_states: List[int] = None
+    # BNF grammar matchers (one per request; None for unconstrained requests)
+    regex_bnfs: Optional[List[Optional[GrammarMatcher]]] = None
 
     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -126,24 +125,28 @@ def update_penalties(self):
             )
             self.linear_penalties = penalizer.apply(self.linear_penalties)
 
-    def update_regex_vocab_mask(self):
-        has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
-        if not has_regex:
-            self.vocab_mask = None
-            return
-
-        self.vocab_mask = torch.zeros(
-            len(self.temperatures),
-            self.vocab_size,
-            dtype=torch.bool,
-            device=self.device,
-        )
-        for i, regex_fsm in enumerate(self.regex_fsms):
-            if regex_fsm is not None:
-                self.vocab_mask[i].fill_(1)
-                self.vocab_mask[i][
-                    regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
-                ] = 0
+    def update_regex_vocab_mask_bnf(self):
+        # Reset the vocab mask
+        self.vocab_mask = None
+
+        if self.regex_bnfs and any(regex_bnf for regex_bnf in self.regex_bnfs):
+            # If any request carries a grammar matcher, build the vocab mask
+            self.vocab_mask = torch.zeros(
+                len(self.temperatures),
+                self.vocab_size,
+                dtype=torch.bool,
+                device=self.device,
+            )
+            for i, regex_bnf in enumerate(self.regex_bnfs):
+                if regex_bnf is not None:
+                    # Note: the bitmask is a packed bitset, not a bool tensor
+                    bitmask = regex_bnf.find_next_token_bitmask()
+                    # Mask the tokens that are not allowed
+                    self.vocab_mask[i][
+                        regex_bnf.get_rejected_tokens_from_bitmask(
+                            bitmask, self.vocab_size
+                        )
+                    ] = 1
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         if self.penalizer_orchestrator:
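The mask semantics survive the rewrite, just built from the opposite direction: the old code filled each row with 1 and zeroed the allowed tokens, while the new code zeros the row and sets the rejected tokens; in both cases True means "disallow". A runnable miniature of how the mask then gates sampling in `apply_logits_bias` (shapes illustrative):

```python
import torch

logits = torch.randn(2, 8)                       # (batch, vocab)
vocab_mask = torch.zeros(2, 8, dtype=torch.bool)
vocab_mask[0, [1, 3, 5]] = True                  # tokens the grammar rejects for row 0

masked = logits.masked_fill(vocab_mask, float("-inf"))
probs = torch.softmax(masked, dim=-1)            # rejected tokens get probability 0
assert probs[0, [1, 3, 5]].sum() == 0
```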
diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py
index c054d72346f..30f3cddd1a7 100644
--- a/test/srt/test_json_constrained.py
+++ b/test/srt/test_json_constrained.py
@@ -22,38 +22,26 @@ class TestJSONConstrained(unittest.TestCase):
     def setUpClass(cls):
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.json_schema = json.dumps(
-            {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string", "pattern": "^[\\w]+$"},
-                    "population": {"type": "integer"},
-                },
-                "required": ["name", "population"],
-            }
-        )
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=300,
-            other_args=["--max-running-requests", "10"],
-        )
+        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
 
     @classmethod
     def tearDownClass(cls):
         kill_child_process(cls.process.pid)
 
-    def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
-        response = requests.post(
+    def run_decode(
+        self, prompt, json_schema, return_logprob=False, top_logprobs_num=0, n=1
+    ):
+        return requests.post(
             self.base_url + "/generate",
             json={
-                "text": "The capital of France is",
+                "text": prompt,
                 "sampling_params": {
                     "temperature": 0 if n == 1 else 0.5,
                     "max_new_tokens": 128,
                     "n": n,
                     "stop_token_ids": [119690],
                     "json_schema": json_schema,
                 },
                 "stream": False,
                 "return_logprob": return_logprob,
@@ -61,53 +49,80 @@ def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1)
                 "top_logprobs_num": top_logprobs_num,
                 "logprob_start_len": 0,
             },
         )
+
+    def test_json_generate_simple(self):
+        prompt = "The capital of France is"
+        json_schema = json.dumps(
+            {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "population": {"type": "integer"},
+                },
+                "required": ["name", "population"],
+            }
+        )
+        response = self.run_decode(prompt, json_schema)
         print(json.dumps(response.json()))
         print("=" * 100)
-
-        if not json_schema:
-            return
-
-        try:
-            js_obj = json.loads(response.json()["text"])
-        except (TypeError, json.decoder.JSONDecodeError):
-            raise
+        js_obj = json.loads(response.json()["text"])
         assert isinstance(js_obj["name"], str)
         assert isinstance(js_obj["population"], int)
 
-    def test_json_generate(self):
-        self.run_decode(json_schema=self.json_schema)
-
-    def test_json_openai(self):
-        client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
-
-        response = client.chat.completions.create(
-            model=self.model,
-            messages=[
-                {"role": "system", "content": "You are a helpful AI assistant"},
-                {"role": "user", "content": "Introduce the capital of France."},
-            ],
-            temperature=0,
-            max_tokens=128,
-            response_format={
-                "type": "json_schema",
-                "json_schema": {"name": "foo", "schema": json.loads(self.json_schema)},
-            },
-        )
-        text = response.choices[0].message.content
-
-        try:
-            js_obj = json.loads(text)
-        except (TypeError, json.decoder.JSONDecodeError):
-            print("JSONDecodeError", text)
-            raise
enumeration.", + "enum": ["third eye", "sword", "axe", "mace", "spear", "bow", "crossbow"], + "type": "string" + } + } +}""" + response = self.run_decode(prompt, json_schema) + print(json.dumps(response.json())) + print("=" * 100) + js_obj = json.loads(response.json()["text"]) assert isinstance(js_obj["name"], str) - assert isinstance(js_obj["population"], int) - - def test_mix_json_and_other(self): - json_schemas = [None, None, self.json_schema, self.json_schema] * 10 - - with ThreadPoolExecutor(len(json_schemas)) as executor: - list(executor.map(self.run_decode, json_schemas)) + assert isinstance(js_obj["age"], int) + assert js_obj["armor"] in ["leather", "chainmail", "plate"] + assert js_obj["weapon"] in [ + "third eye", + "sword", + "axe", + "mace", + "spear", + "bow", + "crossbow", + ] + assert isinstance(js_obj["strength"], int) if __name__ == "__main__":