diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py
index 9195aa30d95..1bcc51c6468 100644
--- a/python/sglang/srt/constrained/xgrammar_backend.py
+++ b/python/sglang/srt/constrained/xgrammar_backend.py
@@ -17,21 +17,14 @@ from typing import List, Tuple
 
 import torch
-
-try:
-    from xgrammar import (
-        CachedGrammarCompiler,
-        CompiledGrammar,
-        GrammarMatcher,
-        TokenizerInfo,
-    )
-
-    import_error = None
-except ImportError as e:
-    CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
-        ImportError
-    )
-    import_error = e
+from xgrammar import (
+    CompiledGrammar,
+    GrammarCompiler,
+    GrammarMatcher,
+    TokenizerInfo,
+    allocate_token_bitmask,
+    apply_token_bitmask_inplace,
+)
 
 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
     BaseGrammarObject,
@@ -41,7 +34,7 @@
 
 logger = logging.getLogger(__name__)
 
-MAX_ROLLBACK_TOKENS = 10
+MAX_ROLLBACK_TOKENS = 200
 
 
 class XGrammarGrammar(BaseGrammarObject):
@@ -86,21 +79,22 @@ def jump_and_retokenize(
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
     ) -> torch.Tensor:
-        return self.matcher.allocate_token_bitmask(vocab_size, batch_size)
+        return allocate_token_bitmask(batch_size, vocab_size)
 
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)
 
     @staticmethod
     def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)
+        if vocab_mask.device.type != logits.device.type:
+            # vocab_mask must be on the same device as logits when applying
+            # the token bitmask, so move it over if needed
+            vocab_mask = vocab_mask.to(logits.device)
+
+        apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
-        matcher = GrammarMatcher(
-            self.ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            vocab_size=self.vocab_size,
-        )
+        matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
         return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
 
 
@@ -112,25 +106,18 @@ def __init__(
     ):
         super().__init__()
 
-        if import_error:
-            logger.warning(
-                f"Ignore import error for the grammar backend: {import_error}"
-            )
-            self.grammar_cache = None
-            return
-
-        tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
-        self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
+        tokenizer_info = TokenizerInfo.from_huggingface(
+            tokenizer, vocab_size=vocab_size
+        )
+        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
 
     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
-        if import_error:
-            raise import_error
         key_type, key_string = key
         if key_type == "json":
             try:
-                ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string)
+                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
                 )
@@ -144,13 +131,9 @@
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
-        matcher = GrammarMatcher(
-            ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            vocab_size=self.vocab_size,
-        )
+        matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
         return XGrammarGrammar(matcher, self.vocab_size, ctx)
 
     def reset(self):
-        if self.grammar_cache:
-            self.grammar_cache.clear()
+        if self.grammar_compiler:
+            self.grammar_compiler.clear_cache()
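The hunks above track xgrammar's API migration: CachedGrammarCompiler becomes GrammarCompiler (with clear_cache() replacing clear()), compile_json_schema_grammar becomes compile_json_schema, GrammarMatcher no longer takes a vocab_size argument, and the bitmask helpers move from matcher methods to module-level functions. Below is a minimal sketch of the new flow outside SGLang, not part of this patch; the "gpt2" model, its vocab size, and the single greedy step are illustrative placeholders.

import torch
from transformers import AutoTokenizer
from xgrammar import (
    GrammarCompiler,
    GrammarMatcher,
    TokenizerInfo,
    allocate_token_bitmask,
    apply_token_bitmask_inplace,
)

# Build tokenizer info once per model; vocab_size should match the LM head.
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
vocab_size = 50257
tokenizer_info = TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)
compiler = GrammarCompiler(tokenizer_info=tokenizer_info)

# Compile a schema; the compiler caches compiled grammars internally.
compiled = compiler.compile_json_schema('{"type": "object"}')
matcher = GrammarMatcher(compiled, max_rollback_tokens=200)

# One decoding step: fill the packed bitmask, then mask the logits in place.
bitmask = allocate_token_bitmask(1, vocab_size)  # int32, one packed row per sequence
matcher.fill_next_token_bitmask(bitmask, 0)
logits = torch.randn(1, vocab_size)  # stand-in for real model logits
if bitmask.device.type != logits.device.type:
    bitmask = bitmask.to(logits.device)  # the mask must live where the logits live
apply_token_bitmask_inplace(logits, bitmask)
next_token = int(torch.argmax(logits, dim=-1).item())
matcher.accept_token(next_token)  # advance the matcher for the next step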
diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py
index 2d08d66848e..b4a3cd2ac44 100644
--- a/test/srt/test_json_constrained.py
+++ b/test/srt/test_json_constrained.py
@@ -121,5 +121,114 @@ def test_mix_json_and_other(self):
             list(executor.map(self.run_decode, json_schemas))
 
 
+class TestJSONConstrainedXGrammarBackend(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.json_schema = json.dumps(
+            {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "population": {"type": "integer"},
+                },
+                "required": ["name", "population"],
+            }
+        )
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=300,
+            other_args=[
+                "--max-running-requests",
+                "10",
+                "--grammar-backend",
+                "xgrammar",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid, include_self=True)
+
+    def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0 if n == 1 else 0.5,
+                    "max_new_tokens": 128,
+                    "n": n,
+                    "stop_token_ids": [119690],
+                    "json_schema": json_schema,
+                },
+                "stream": False,
+                "return_logprob": return_logprob,
+                "top_logprobs_num": top_logprobs_num,
+                "logprob_start_len": 0,
+            },
+        )
+        ret = response.json()
+        print(json.dumps(ret))
+        print("=" * 100)
+
+        if not json_schema:
+            return
+
+        # Make sure the JSON output is valid
+        try:
+            js_obj = json.loads(ret["text"])
+        except (TypeError, json.decoder.JSONDecodeError):
+            raise
+
+        self.assertIsInstance(js_obj["name"], str)
+        self.assertIsInstance(js_obj["population"], int)
+
+        # Make sure jump forward is triggered
+        # NOTE: This is skipped because the overlap scheduler does not support jump forward
+        # self.assertGreater(
+        #     ret["meta_info"]["completion_tokens"],
+        #     ret["meta_info"]["completion_tokens_wo_jump_forward"],
+        # )
+
+    def test_json_generate(self):
+        self.run_decode(json_schema=self.json_schema)
+
+    def test_json_openai(self):
+        client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
+
+        response = client.chat.completions.create(
+            model=self.model,
+            messages=[
+                {"role": "system", "content": "You are a helpful AI assistant"},
+                {"role": "user", "content": "Introduce the capital of France."},
+            ],
+            temperature=0,
+            max_tokens=128,
+            response_format={
+                "type": "json_schema",
+                "json_schema": {"name": "foo", "schema": json.loads(self.json_schema)},
+            },
+        )
+        text = response.choices[0].message.content
+
+        try:
+            js_obj = json.loads(text)
+        except (TypeError, json.decoder.JSONDecodeError):
+            print("JSONDecodeError", text)
+            raise
+
+        self.assertIsInstance(js_obj["name"], str)
+        self.assertIsInstance(js_obj["population"], int)
+
+    def test_mix_json_and_other(self):
+        json_schemas = [None, None, self.json_schema, self.json_schema] * 10
+
+        with ThreadPoolExecutor(len(json_schemas)) as executor:
+            list(executor.map(self.run_decode, json_schemas))
+
+
 if __name__ == "__main__":
     unittest.main()
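On the MAX_ROLLBACK_TOKENS bump from 10 to 200 in the backend hunk: jump-forward decoding appends a grammar-determined string and then retokenizes the whole text, which can shift token boundaries well before the jump point, so the matcher must roll back across the divergence before re-accepting the new suffix. The sketch below is a hypothetical illustration of that interaction, assuming only the matcher API already used in this patch (find_jump_forward_string, rollback, accept_token); it is not SGLang's jump_and_retokenize implementation.

def jump_forward_sketch(matcher, tokenizer, text: str, tokens: list):
    # Hypothetical helper: shows why a generous max_rollback_tokens budget
    # matters when jump-forward strings force retokenization.
    jump = matcher.find_jump_forward_string()  # deterministic continuation, may be ""
    if not jump:
        return text, tokens
    new_text = text + jump
    new_tokens = tokenizer.encode(new_text)
    # Retokenization can change tokens before the jump point, so find the
    # first index where the old and new token streams diverge...
    k = 0
    while k < min(len(tokens), len(new_tokens)) and tokens[k] == new_tokens[k]:
        k += 1
    # ...roll the matcher back past the divergence (bounded by
    # max_rollback_tokens, hence the larger budget)...
    matcher.rollback(len(tokens) - k)
    # ...and re-accept the retokenized suffix.
    for t in new_tokens[k:]:
        matcher.accept_token(t)
    return new_text, new_tokens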