[Draft] Resolving integration differences after XGrammar launch refactoring #2145

Closed · wants to merge 3 commits · Changes from 1 commit
51 changes: 19 additions & 32 deletions python/sglang/srt/constrained/xgrammar_backend.py
@@ -18,20 +18,18 @@

import torch

-try:
-    from xgrammar import (
-        CachedGrammarCompiler,
-        CompiledGrammar,
-        GrammarMatcher,
-        TokenizerInfo,
-    )
-
-    import_error = None
-except ImportError as e:
-    CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
-        ImportError
-    )
-    import_error = e
+from xgrammar import (
+    GrammarCompiler,
+    CompiledGrammar,
+    GrammarMatcher,
+    TokenizerInfo,
+)
+
+from xgrammar.matcher import (
+    allocate_token_bitmask,
+    apply_token_bitmask_inplace
+)

from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
@@ -86,20 +84,19 @@ def jump_and_retokenize(
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
     ) -> torch.Tensor:
-        return self.matcher.allocate_token_bitmask(vocab_size, batch_size)
+        return allocate_token_bitmask(vocab_size, batch_size)

     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)

     @staticmethod
     def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)
+        apply_token_bitmask_inplace(logits, vocab_mask)

     def copy(self):
         matcher = GrammarMatcher(
             self.ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            vocab_size=self.vocab_size,
+            max_rollback_tokens=MAX_ROLLBACK_TOKENS
         )
         return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
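
Note: as the hunks above show, the bitmask helpers that used to be methods on GrammarMatcher now appear to be free functions in xgrammar.matcher. A rough end-to-end sketch of the new flow under that assumption; argument order and keyword names follow this PR's own usage, and the tokenizer, schema, and tensors are illustrative:

import torch
from transformers import AutoTokenizer
from xgrammar import GrammarCompiler, GrammarMatcher, TokenizerInfo
from xgrammar.matcher import allocate_token_bitmask, apply_token_bitmask_inplace

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer, for illustration
vocab_size = len(tokenizer)

tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
compiled = compiler.compile_json_schema(schema='{"type": "object"}')
matcher = GrammarMatcher(compiled, max_rollback_tokens=4)  # 4 stands in for MAX_ROLLBACK_TOKENS

logits = torch.randn(1, vocab_size)
vocab_mask = allocate_token_bitmask(vocab_size, 1)  # argument order as used in this PR
matcher.fill_next_token_bitmask(vocab_mask, 0)      # still a matcher method
apply_token_bitmask_inplace(logits, vocab_mask)     # now a free function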

@@ -112,25 +109,16 @@ def __init__(
     ):
         super().__init__()

-        if import_error:
-            logger.warning(
-                f"Ignore import error for the grammar backend: {import_error}"
-            )
-            self.grammar_cache = None
-            return

         tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
Review thread on this line:

Contributor: Assign vocab_size and stop_token_ids (from the chat_template; optional, but it will make this more robust) when constructing tokenizer_info.

Contributor Author: Agreed. I initially hit an issue where the vocab_size provided by the tokenizer differed from the vocab_size param passed when initializing the grammar backend; in my most recent commit I use vocab_size when creating tokenizer_info. My blocker now is that GrammarMatcher accepts the stop token but still tries to find the next token mask. I believe this is because the backend does not have the correct stop tokens.

What is the best way to access the chat template's stop_token_ids when initializing the grammar backend?
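
A minimal sketch of what the reviewer is suggesting. The vocab_size and stop_token_ids keyword arguments on TokenizerInfo.from_huggingface are an assumption about the refactored xgrammar API, and the concrete model and token ids are illustrative, not from the PR:

from transformers import AutoTokenizer
from xgrammar import TokenizerInfo

# Illustrative setup; in the backend, tokenizer and vocab_size come from
# the server arguments rather than being hard-coded.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
vocab_size = 151936                # model logits width; can exceed len(tokenizer)
stop_token_ids = [151645, 151643]  # illustrative eos ids from the chat template

tokenizer_info = TokenizerInfo.from_huggingface(
    tokenizer,
    vocab_size=vocab_size,          # assumed kwarg: pins the bitmask width
    stop_token_ids=stop_token_ids,  # assumed kwarg: gives the matcher the real stop tokens
)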

-        self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
+        self.grammar_cache = GrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size

     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
-        if import_error:
-            raise import_error

         key_type, key_string = key
         if key_type == "json":
             try:
-                ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string)
+                ctx = self.grammar_cache.compile_json_schema(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
@@ -146,11 +134,10 @@ def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:

         matcher = GrammarMatcher(
             ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            vocab_size=self.vocab_size,
+            max_rollback_tokens=MAX_ROLLBACK_TOKENS
         )
         return XGrammarGrammar(matcher, self.vocab_size, ctx)

     def reset(self):
         if self.grammar_cache:
-            self.grammar_cache.clear()
+            self.grammar_cache.clear_cache()