transformers-CFG incompatible with gemma-3: causes tokenizer and model vocab size mismatch #127

Open
GeorgeDeac opened this issue Apr 9, 2025 · 4 comments


GeorgeDeac commented Apr 9, 2025

I've encountered an issue when trying to use transformers-cfg for constrained generation with the recently released gemma-3 models.

Adding a GrammarConstrainedLogitsProcessor from transformers-cfg to the logits_processor list causes model.generate() to fail with:
AssertionError: impossible for tokenizer vocab to be less than model vocab

The error occurs regardless of whether the model is loaded with AutoModelForCausalLM or the model-specific Gemma3ForCausalLM class (and likewise for the tokenizer class). The example below uses Gemma3ForCausalLM:

import torch
import transformers
from transformers import (
    AutoTokenizer,
    Gemma3ForCausalLM, # Same goes for the auto class
    GenerationConfig
    # BitsAndBytesConfig
)
import transformers_cfg
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
import time
import gc
import platform
import traceback

MODEL_ID = "google/gemma-3-1b-it"
SIMPLE_GRAMMAR = 'root ::= ("A" | "B")' # Simple grammar
START_RULE = "root"

# Quantization
USE_4BIT = False
USE_8BIT = False

# Environment Info
print("-" * 60)
print("ENVIRONMENT:")
print(f"- Python:       {platform.python_version()}")
print(f"- transformers: {transformers.__version__}")
print(f"- torch:        {torch.__version__}")
try:
    import accelerate
    print(f"- accelerate:   {accelerate.__version__}")
except ImportError:
    print("- accelerate:   Not Installed")
try:
    import bitsandbytes
    print(f"- bitsandbytes: {bitsandbytes.__version__}")
except ImportError:
    print("- bitsandbytes: Not Installed (Needed for quantization)")
try:
    print(f"- trans-cfg:    {transformers_cfg.__version__}")
except AttributeError:
    print("- trans-cfg:    Installed (version attribute missing)")
print("-" * 60)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Test
try:
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tokenizer.pad_token is None: 
        tokenizer.pad_token = tokenizer.eos_token
    if not tokenizer.chat_template: 
        raise AttributeError("Tokenizer missing chat template")

    print("Loading model (Gemma3ForCausalLM)...")
    model = Gemma3ForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
        device_map="auto",
        quantization_config=None, # No quantization
        trust_remote_code=True
    ).eval()

    model_vocab = getattr(model.config, 'vocab_size', 'N/A')
    tokenizer_vocab = getattr(tokenizer, 'vocab_size', 'N/A')
    print(f"(Post-load check) Model vocab: {model_vocab}, Tokenizer vocab: {tokenizer_vocab}")
    if model_vocab != tokenizer_vocab: 
        print("Post-load vocab size mismatch detected!")

    # Grammar Processor
    print("Creating grammar processor...")
    grammar_constraint = IncrementalGrammarConstraint(SIMPLE_GRAMMAR, START_RULE, tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar_constraint)
    logits_processor_list = [grammar_processor]

    messages = [{"role": "user", "content": [{"type": "text", "text": "Output A or B."}]}]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt"
    ).to(model.device)
    input_ids_len = input_ids.shape[1]

    gen_config = GenerationConfig(
        max_new_tokens=3,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=False
    )

    # Expected to Fail
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            generation_config=gen_config,
            logits_processor=logits_processor_list
        )

    output_token_ids = outputs[:, input_ids_len:]
    output_text = tokenizer.decode(output_token_ids[0], skip_special_tokens=True)
    print(f"Generated text: {output_text!r}")

except ValueError as e:
    print(f"ERROR Type: {type(e).__name__}")
    print(f"ERROR Message: {e}")
    traceback.print_exc()
    print("-" * 35 + "\n")
except Exception as e:
    print(f"ERROR Type: {type(e).__name__}")
    print(f"ERROR Message: {e}")
    traceback.print_exc()
    print("-" * 35 + "\n")
finally:
    # Clean up; the objects may not exist if an earlier step failed, so remove them defensively
    for _name in ("grammar_processor", "model", "tokenizer"):
        globals().pop(_name, None)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

For me, this throws:

------------------------------------------------------------
ENVIRONMENT:
- Python:       3.12.3
- transformers: 4.51.1
- torch:        2.4.0
- accelerate:   1.0.1
- bitsandbytes: 0.45.5
- trans-cfg:    0.2.7
------------------------------------------------------------
Using device: cuda
Loading tokenizer...
Loading model (Gemma3ForCausalLM)...
(Post-load check) Model vocab: 262144, Tokenizer vocab: 262144
Creating grammar processor...
`generation_config` default values have been modified to match model-specific defaults: {'do_sample': True, 'cache_implementation': 'hybrid', 'top_k': 64, 'top_p': 0.95, 'bos_token_id': 2}. If this is not desired, please set these values explicitly.
ERROR Type: AssertionError
ERROR Message: impossible for tokenizer vocab to be less than model vocab
-----------------------------------

Traceback (most recent call last):
  File "C:\Users\georg\AppData\Local\Temp\ipykernel_45580\3883546750.py", line 95, in <module>
    outputs = model.generate(
              ^^^^^^^^^^^^^^^
  File "c:\Users\georg\anaconda3\Lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\georg\anaconda3\Lib\site-packages\transformers\generation\utils.py", line 2463, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "c:\Users\georg\anaconda3\Lib\site-packages\transformers\generation\utils.py", line 3448, in _sample
    next_token_scores = logits_processor(input_ids, next_token_logits)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\georg\anaconda3\Lib\site-packages\transformers\generation\logits_process.py", line 88, in __call__
    scores = processor(input_ids, scores)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\georg\anaconda3\Lib\site-packages\transformers_cfg\generation\logits_process.py", line 164, in __call__
    return self.process_logits(input_ids, scores)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\georg\anaconda3\Lib\site-packages\transformers_cfg\generation\logits_process.py", line 157, in process_logits
    masked_scores = self.mask_logits(scores, device)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\georg\anaconda3\Lib\site-packages\transformers_cfg\generation\logits_process.py", line 77, in mask_logits
    acceptance_vocab_size < masked_logits_vocab_size
AssertionError: impossible for tokenizer vocab to be less than model vocab
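
For reference, here is a slightly fuller size check than the config-level comparison in the script above (which prints 262144 for both). This is only an illustrative sketch using standard transformers/tokenizers attributes; the variable names are mine:

# Sketch only: compares the different notions of "vocab size" that can disagree.
import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

tok = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
mdl = Gemma3ForCausalLM.from_pretrained("google/gemma-3-1b-it", torch_dtype=torch.bfloat16)

print("tokenizer.vocab_size:", tok.vocab_size)                         # base vocabulary only
print("len(tokenizer):      ", len(tok))                               # base vocabulary + added tokens
print("config.vocab_size:   ", mdl.config.vocab_size)
print("input embedding rows:", mdl.get_input_embeddings().weight.shape[0])
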
urroxyz (Contributor) commented Apr 9, 2025

I put in a PR (#126) that should fix this once it's accepted.

Saibo-creator (Collaborator) commented Apr 9, 2025

Hello @GeorgeDeac,
It does seem to be caused by a mismatch between the model's embedding size and the tokenizer's vocabulary. Could you try PR #126? I still need to double-check it before merging, but it looks good to me.
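
(A possible way to test the open PR before it is merged, assuming the project's repository is epfl-dlab/transformers-CFG on GitHub, is to install it straight from the pull request ref:

pip install git+https://github.com/epfl-dlab/transformers-CFG.git@refs/pull/126/head

which fetches the branch behind #126.)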

Saibo-creator (Collaborator) commented

Hey @GeorgeDeac, thanks a lot for the detailed setup information and the reproduction script, really appreciated!
It turns out the issue was caused by an inconsistency in the tokenizer implementation for Gemma-3. I've added a fix here: #128

Thanks for contributing PR #126, @urroxyz; I will review it this weekend.

GeorgeDeac (Author) commented

Thanks a lot, I was about to investigate too.
