2 changes: 0 additions & 2 deletions conda/environment.yml
@@ -5,5 +5,3 @@ channels:
dependencies:
- python=3.12 # note: at the time of writing, xformer (< vllm) has a broken wheel for 3.13. https://github.com/facebookresearch/xformers/issues/740#issuecomment-2753869337
- uv
variables:
VLLM_USE_V1: 0 # need this to make outlines work
3 changes: 1 addition & 2 deletions conda/install.sh
@@ -4,8 +4,7 @@ CONDA=""
if which mamba > /dev/null
then
CONDA=$(which mamba)
fi
if which conda > /dev/null
elif which conda > /dev/null
then
CONDA=$(which conda)
fi
128 changes: 78 additions & 50 deletions mellea/backends/huggingface.py
@@ -18,15 +18,17 @@
from typing import TYPE_CHECKING, Any, cast

import granite_common
import outlines
import outlines_core
import llguidance
import llguidance.hf
import llguidance.torch
import peft
import torch
from transformers import (
AsyncTextIteratorStreamer,
AutoModelForCausalLM,
AutoTokenizer,
DynamicCache,
LogitsProcessorList,
PreTrainedModel,
PreTrainedTokenizer,
set_seed,
@@ -69,8 +71,6 @@
from mellea.stdlib.intrinsics.intrinsic import Intrinsic
from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement

assert outlines, "outlines needs to be present to make outlines_core work"

"""A configuration type for the unhappy path: Tokenizer * Model * torch device string

Huggingface backends can initialize themselves from a model string if the transformers `Auto*` classes can be used. Therefore, a TransformersTorchConfig usually isn't required. However, sometimes a model needs special care to instantiate properly, or a custom device type needs to be used. Instead of trying to do a lot of partial magic, we basically have two modalities: either the constructor can figure out everything from the model_id, or the user has to provide an entire config.
@@ -90,6 +90,59 @@ class HFAloraCacheInfo:
q_end: int = -1


# modified from VLLM v0.9.2 code base
# https://github.com/vllm-project/vllm/blob/v0.9.2/vllm/model_executor/guided_decoding/guidance_logits_processors.py
class _GuidanceLogitsProcessor:
def __init__(self, grammar: str, ll_tokenizer: llguidance.LLTokenizer) -> None:
self.grammar = grammar
self.vocab_size: int = ll_tokenizer.vocab_size
self.ll_tokenizer: llguidance.LLTokenizer = ll_tokenizer
self.ll_matchers: list[llguidance.LLMatcher] = []
self.bitmasks: list[torch.Tensor] = []
self.new_sampling: bool = False
self.batch_size: int = -1

def __call__(
self, batch_input_ids: torch.Tensor, batch_scores: torch.Tensor
) -> torch.Tensor:
i_batch, i_seqlen = batch_input_ids.shape
s_batch, s_vocab = batch_scores.shape
assert i_batch == s_batch
assert s_vocab == self.vocab_size

if self.batch_size != i_batch:
self.batch_size = i_batch
self.bitmasks = [
llguidance.torch.allocate_token_bitmask(1, self.vocab_size) # type: ignore[attr-defined]
for _ in range(self.batch_size)
]

self.ll_matchers = [
llguidance.LLMatcher(self.ll_tokenizer, self.grammar)
for _ in range(self.batch_size)
]

for input_ids, scores, ll_matcher, bitmask in zip(
batch_input_ids, batch_scores, self.ll_matchers, self.bitmasks
):
if self.new_sampling and len(input_ids) > 0:
ll_matcher.consume_token( # type: ignore[attr-defined]
input_ids.tolist()[-1]
)
err = ll_matcher.get_error() # type: ignore[attr-defined]
if err:
FancyLogger.get_logger().warning("Error in LLMatcher: %s", err)

llguidance.torch.fill_next_token_bitmask(ll_matcher, bitmask, 0)
llguidance.torch.apply_token_bitmask_inplace(
scores, bitmask.to(scores.device)
) # type: ignore[attr-defined]

self.new_sampling = True

return scores


class LocalHFBackend(FormatterBackend, AdapterMixin):
"""The LocalHFBackend uses Huggingface's transformers library for inference, and uses a Formatter to convert `Component`s into prompts. This backend also supports Activated LoRAs (ALoras)](https://arxiv.org/pdf/2504.12397).

@@ -178,6 +231,10 @@ def __init__(
case _:
self._tokenizer, self._model, self._device = custom_config

self._llguidance_tokenizer: llguidance.LLTokenizer = (
llguidance.hf.from_tokenizer(self._tokenizer) # type:ignore
)

self._use_caches = use_caches
self._cache = cache if cache is not None else SimpleLRUCache(3)

@@ -596,24 +653,15 @@ async def _generate_from_context_with_kv_cache(

format_kwargs = {}
if _format:
# outlines.generate.json always parses the resulting json into a python dict.
# We however want to keep it as a json string for later storing it in ModelOutputThunk
schema: dict[str, Any] = _format.model_json_schema()
schema_json: str = json.dumps(schema)
regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore
schema_json
grammar: str = llguidance.LLMatcher.grammar_from_json_schema(
schema, defaults={"whitespace_flexible": False}
)
logits_processor = _GuidanceLogitsProcessor(
grammar, self._llguidance_tokenizer
)

from outlines.models.transformers import TransformerTokenizer
from outlines.processors.structured import RegexLogitsProcessor
from transformers import LogitsProcessorList

format_kwargs["logits_processor"] = LogitsProcessorList(
[
RegexLogitsProcessor(
regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
)
]
[logits_processor]
)

streaming_kwargs = {}
@@ -763,24 +811,15 @@ async def _generate_from_context_standard(

format_kwargs = {}
if _format:
# outlines.generate.json always parses the resulting json into a python dict.
# We however want to keep it as a json string for later storing it in ModelOutputThunk
schema: dict[str, Any] = _format.model_json_schema()
schema_json: str = json.dumps(schema)
regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore
schema_json
grammar: str = llguidance.LLMatcher.grammar_from_json_schema(
schema, defaults={"whitespace_flexible": False}
)
logits_processor = _GuidanceLogitsProcessor(
grammar, self._llguidance_tokenizer
)

from outlines.models.transformers import TransformerTokenizer
from outlines.processors.structured import RegexLogitsProcessor
from transformers import LogitsProcessorList

format_kwargs["logits_processor"] = LogitsProcessorList(
[
RegexLogitsProcessor(
regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
)
]
[logits_processor]
)

streaming_kwargs = {}
@@ -990,25 +1029,14 @@ async def generate_from_raw(

format_kwargs = {}
if format:
# outlines.generate.json always parses the resulting json into a python dict.
# We however want to keep it as a json string for later storing it in ModelOutputThunk
schema: dict[str, Any] = format.model_json_schema()
schema_json: str = json.dumps(schema)
regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore
schema_json
grammar: str = llguidance.LLMatcher.grammar_from_json_schema(
schema, defaults={"whitespace_flexible": False}
)

from outlines.models.transformers import TransformerTokenizer
from outlines.processors.structured import RegexLogitsProcessor
from transformers import LogitsProcessorList

format_kwargs["logits_processor"] = LogitsProcessorList(
[
RegexLogitsProcessor(
regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
)
]
logits_processor = _GuidanceLogitsProcessor(
grammar, self._llguidance_tokenizer
)
format_kwargs["logits_processor"] = LogitsProcessorList([logits_processor])

outputs = await asyncio.to_thread(
self._generate_with_adapter_lock,
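As context for the diff above, here is a minimal, untested sketch of how the new _GuidanceLogitsProcessor is meant to plug into a plain transformers generate call. The model id, schema, and prompt are illustrative placeholders; only the llguidance calls mirror what LocalHFBackend now does internally.

# Sketch only: grammar-constrained decoding with llguidance + transformers.
# Model id, schema, and prompt are placeholders, not part of the PR.
import llguidance
import llguidance.hf
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

from mellea.backends.huggingface import _GuidanceLogitsProcessor

model_id = "ibm-granite/granite-3.3-2b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Build the llguidance tokenizer and a grammar from a JSON schema,
# mirroring LocalHFBackend.__init__ and the _format branches above.
ll_tokenizer = llguidance.hf.from_tokenizer(tokenizer)
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "value": {"type": "integer"}},
    "required": ["name", "value"],
}
grammar = llguidance.LLMatcher.grammar_from_json_schema(
    schema, defaults={"whitespace_flexible": False}
)
processor = _GuidanceLogitsProcessor(grammar, ll_tokenizer)

inputs = tokenizer("Answer as JSON: what is 2+2?", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    logits_processor=LogitsProcessorList([processor]),
)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))

The processor masks disallowed tokens in place at every step, so any decoding strategy supported by generate should still apply.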
74 changes: 21 additions & 53 deletions mellea/backends/vllm.py
@@ -19,8 +19,6 @@
from typing import TYPE_CHECKING, Any, Optional

import msgspec # type:ignore
import outlines
import outlines_core
import torch
import vllm # type:ignore
from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -50,8 +48,6 @@
from mellea.stdlib.chat import Message
from mellea.stdlib.requirement import LLMaJRequirement, Requirement

assert outlines, "outlines needs to be present to make outlines_core work"

format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors


@@ -83,14 +79,6 @@ def __init__(
formatter (Formatter): A mechanism for turning `stdlib` stuff into strings. Experimental Span-based models should use `mellea.backends.span.*` backends.
model_options (Optional[dict]): Default model options.
"""
if os.environ.get("VLLM_USE_V1", -1) != "0":
FancyLogger.get_logger().error(
"Mellea LocalVLLMBackend doesn't support VLLM V1. Must `export VLLM_USE_V1=0`."
)
raise ValueError(
"Mellea LocalVLLMBackend doesn't support VLLM V1. Must `export VLLM_USE_V1=0`."
)

formatter = (
formatter if formatter is not None else TemplateFormatter(model_id=model_id)
)
@@ -205,23 +193,20 @@ def __init__(

# Keep track of the event loop the engine was instantiated in.
self._event_loop = get_current_event_loop()
# we store the engine args because we have to reset the engine with a different event loop. See _model .
self.engine_args = engine_args

self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
self._hf_model_id
) # type:ignore

# See the notes in outlines.models.vllm.adapt_tokenizer for why this is needed.
# Note: there is a module named outlines.models.vllm and a function named outlines.models.vllm.vllm .
# However, outlines.models import outlines.models.vllm.vllm as vllm,
# thus the module outlines.models.vllm becomes inaccessible,
# hence the use of importlib to get the module.
self._tokenizer_for_outlines: PreTrainedTokenizerBase = importlib.import_module(
"outlines.models.vllm"
).adapt_tokenizer(self._tokenizer)

@property
def _model(self) -> vllm.AsyncLLMEngine:
"""Use model when making generation requests."""
# 2026/01/06 Masa: Temporarily canceling the mechanism below.
# After vllm 0.11.0, start/shutdown_background_loop is gone.
# 2026/01/07 Masa: Rewrote it to reinstantiate the engine.

el = get_current_event_loop()

# vLLM attaches itself to the event loop that is running when instantiated /
@@ -231,8 +216,13 @@ def _model(self) -> vllm.AsyncLLMEngine:
# Most of the time, this should be a no-op. The event loop will only change
# if switching between async and sync calls.
if el != self._event_loop:
self._underlying_model.shutdown_background_loop()
self._underlying_model.start_background_loop()
FancyLogger.get_logger().warning("restarting the vllm event loop")
# self._underlying_model.shutdown_background_loop()
# self._underlying_model.start_background_loop()
self._underlying_model.shutdown()
self._underlying_model = vllm.AsyncLLMEngine.from_engine_args(
vllm.AsyncEngineArgs(model=self._hf_model_id, **self.engine_args)
)
self._event_loop = el

return self._underlying_model
@@ -320,22 +310,10 @@ async def _generate_from_context_standard(
)

if _format is not None:
# outlines.generate.json always parses the resulting json into a python dict.
# We however want to keep it as a json string for later storing it in ModelOutputThunk
schema: dict[str, Any] = _format.model_json_schema()
schema_json: str = json.dumps(schema)
regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore
schema_json # type: ignore
) # type: ignore

from outlines.processors import RegexLogitsProcessor

logits_processor = RegexLogitsProcessor(
regex_str,
tokenizer=self._tokenizer_for_outlines, # type: ignore
)
sampling_params.logits_processors = (
[logits_processor] if logits_processor is not None else []
sampling_params.structured_outputs = (
vllm.sampling_params.StructuredOutputsParams(
json=_format.model_json_schema()
)
)

# stream = model_options.get(ModelOption.STREAM, False)
@@ -458,20 +436,10 @@ async def generate_from_raw(
)

if format is not None:
schema: dict[str, Any] = format.model_json_schema()
schema_json: str = json.dumps(schema)
regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore
schema_json # type: ignore
) # type: ignore

from outlines.processors import RegexLogitsProcessor

logits_processor = RegexLogitsProcessor(
regex_str,
tokenizer=self._tokenizer_for_outlines, # type: ignore
)
sampling_params.logits_processors = (
[logits_processor] if logits_processor is not None else []
sampling_params.structured_outputs = (
vllm.sampling_params.StructuredOutputsParams(
json=format.model_json_schema()
)
Comment on lines +461 to +463
Member:

FWIW, when I was trying this locally (uv run pytest test/backends/test_vllm.py on Fedora 43, Python 3.12.8, CUDA 13.0), the test_generate_from_raw_with_format test failed. Here's a snippet of the output:

session = <mellea.stdlib.session.MelleaSession object at 0x7ff08f129a90>

    @pytest.mark.qualitative
    async def test_generate_from_raw_with_format(session):
        prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]

        class Answer(pydantic.BaseModel):
            name: str
            value: int

        results = await session.backend.generate_from_raw(
            actions=[CBlock(value=prompt) for prompt in prompts],
            ctx=session.ctx,
            format=Answer,
        )

        assert len(results) == len(prompts)

        random_result = results[0]
        try:
            answer = Answer.model_validate_json(random_result.value)
        except pydantic.ValidationError as e:
>           assert False, (
                f"formatting directive failed for {random_result.value}: {e.json()}"
            )
E           AssertionError: formatting directive failed for {
E
E
E                 "name": "binary",
E                 "value": 1
E              : [{"type":"json_invalid","loc":[],"msg":"Invalid JSON: EOF while parsing an object at line 6 column 1","input":"{\n\n\n    \"name\": \"binary\",\n    \"value\": 1\n ","ctx":{"error":"EOF while parsing an object at line 6 column 1"},"url":"https://errors.pydantic.dev/2.12/v/json_invalid"}]
E           assert False

test/backends/test_vllm.py:133: AssertionError

Seems like what's happening is that the model is following the grammar but not producing valid JSON (which may just be a fact of life with a tiny model; I see the vllm test is using qwen3 0.6b).

Member:

A quick hack I tried locally that got the test to pass was to add some additional prompting to the list of prompts about producing proper json content:

diff --git a/mellea/backends/vllm.py b/mellea/backends/vllm.py
index a59ced7..6e3ee5a 100644
--- a/mellea/backends/vllm.py
+++ b/mellea/backends/vllm.py
@@ -447,8 +447,21 @@ class LocalVLLMBackend(FormatterBackend):

         model_options = self._simplify_and_merge(model_options)

+        # When structured output is requested, ensure there's a reasonable max_tokens limit
+        # to prevent excessive whitespace generation and ensure output completion.
+        if format is not None and ModelOption.MAX_NEW_TOKENS not in model_options:
+            model_options[ModelOption.MAX_NEW_TOKENS] = 512
+
         prompts = [self.formatter.print(action) for action in actions]

+        # When structured output is requested, prepend format instructions to help the model
+        # understand what JSON content to generate. Without this, models may produce valid JSON
+        # structure (due to constrained decoding) but with meaningless content like whitespace.
+        if format is not None:
+            schema_str = json.dumps(format.model_json_schema(), indent=2)
+            format_prefix = f"Output a JSON object matching this schema:\n{schema_str}\n\nQuery: "
+            prompts = [format_prefix + p for p in prompts]
+
         sampling_params = vllm.SamplingParams(
             **self._make_backend_specific_and_remove(
                 model_options, vllm.SamplingParams

Contributor:

We could make this error message better, but I think what's happening is that the model runs out of tokens before the json completes (at least I see this error with hf sometimes):

"{\n\n\n    \"name\": \"binary\",\n    \"value\": 1\n " <-- missing closing bracket

Member (@psschwei, Jan 22, 2026):

yeah, that seems to be it. just re-ran the test and I see:

E             Invalid JSON: EOF while parsing an object at line 130 column 0 [type=json_invalid, input_value='{  \n\n\n\n\n\n\n\n\n\n\...n\n\n\n\n\n\n\n\n\n\n\n', input_type=str]

and also

E             : [{"type":"json_invalid","loc":[],"msg":"Invalid JSON: EOF while parsing an object at line 130 column 0","input":"{  \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n","ctx":{"error":"EOF while parsing an object at line 130 column 0"},"url":"https://errors.pydantic.dev/2.12/v/json_invalid"}]

though it's interesting, on my run looks like the string is an opening bracket and then a LOT of newlines...

)

async def generate(prompt, request_id):
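On the vLLM side the custom logits processor goes away entirely. Below is a minimal, untested sketch of the built-in structured-outputs path the backend now relies on, using the offline vllm.LLM API; the model id is a placeholder (the test suite uses a small Qwen3 model), and it assumes vllm>=0.13 as pinned in pyproject.toml below.

# Sketch only: vLLM native structured outputs, mirroring the
# StructuredOutputsParams usage added in this diff.
import pydantic
import vllm


class Answer(pydantic.BaseModel):
    name: str
    value: int


llm = vllm.LLM(model="Qwen/Qwen3-0.6B")  # placeholder model id
params = vllm.SamplingParams(
    max_tokens=512,  # leave room for the model to close the JSON object
    structured_outputs=vllm.sampling_params.StructuredOutputsParams(
        json=Answer.model_json_schema()
    ),
)
for request_output in llm.generate(["what is 2+2? Answer as JSON."], params):
    print(request_output.outputs[0].text)

The max_tokens cap echoes the review discussion above: a tiny model can satisfy the grammar token-by-token yet still get cut off before the JSON object closes.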
20 changes: 2 additions & 18 deletions pyproject.toml
@@ -57,34 +57,18 @@ m = "cli.m:cli"
# uv pip install -e .[hf, watsonx]
# if you want to install all dependencies, use uv sync --all-extras


# note on outlines versions:
# outlines>=1.2.0 requires outlines-core==0.2.11
# outlines<=1.1.* requires outlines-core==0.1.26
# vllm==0.10.0 requires outlines-core==0.2.10
# vllm==0.9.* requires outlines-core==0.1.26
#
# thus the following version combination allows installing vllm and outlines
# (main library) at the same time.

hf = [
"accelerate>=1.9.0",
"alora==0.2.0",
"datasets>=4.0.0",
"outlines-core==0.1.26",
"outlines", # intentionally un-versioned, expecting a minor update. coutlines-core version should be enough to specify it
"llguidance",
"peft>=0.18.0", # aLoRA support was added in Peft 0.18.0
"transformers>=4.53.2",
"trl==0.19.1",
]

vllm = [
"transformers<4.54.0",
# see https://github.com/vllm-project/vllm-ascend/issues/2046
"numpy<2.0.0", # patching incorrect dependencies in vllm and outlines.
# see https://github.com/vllm-project/vllm/issues/5587
"outlines-core==0.1.26",
"vllm>=0.9.1",
"vllm>=0.13.0",
]

litellm = [
2 changes: 1 addition & 1 deletion test/backends/test_huggingface.py
@@ -39,7 +39,7 @@
def backend():
"""Shared HuggingFace backend for all tests in this module."""
backend = LocalHFBackend(
model_id="ibm-granite/granite-3.3-8b-instruct",
model_id="ibm-granite/granite-3.3-2b-instruct",
formatter=TemplateFormatter(model_id="ibm-granite/granite-4.0-tiny-preview"),
cache=SimpleLRUCache(5),
)
2 changes: 1 addition & 1 deletion test/backends/test_huggingface_tools.py
@@ -22,7 +22,7 @@
def backend():
"""Shared HuggingFace backend for all tests in this module."""
backend = LocalHFBackend(
model_id=model_ids.MISTRALAI_MISTRAL_0_3_7B, cache=SimpleLRUCache(5)
model_id="ibm-granite/granite-3.3-2b-instruct", cache=SimpleLRUCache(5)
)
# add_granite_aloras(backend)
return backend