diff --git a/.github/workflows/matrix-checks.yml b/.github/workflows/matrix-checks.yml index 74ede508..374ede25 100644 --- a/.github/workflows/matrix-checks.yml +++ b/.github/workflows/matrix-checks.yml @@ -5,9 +5,6 @@ on: python_version: required: false type: string - secrets: - CI_SSH_KEY: - required: true jobs: matrix-checks: diff --git a/.github/workflows/tests-macos.yml b/.github/workflows/tests-macos.yml index 281f343c..d142fc90 100644 --- a/.github/workflows/tests-macos.yml +++ b/.github/workflows/tests-macos.yml @@ -5,9 +5,6 @@ on: python_version: required: false type: string - secrets: - CI_SSH_KEY: - required: true jobs: tests-macos: diff --git a/pyproject.toml b/pyproject.toml index 33ba5f4e..c3a3c8ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "dnet-p2p @ file://${PROJECT_ROOT}/lib/dnet-p2p/bindings/py", "rich>=13.0.0", "psutil>=5.9.0", + "outlines>=1.2.0", ] [project.optional-dependencies] diff --git a/src/dnet/api/inference.py b/src/dnet/api/inference.py index d84cd5a9..8d5f2194 100644 --- a/src/dnet/api/inference.py +++ b/src/dnet/api/inference.py @@ -1,9 +1,11 @@ import asyncio import time import uuid +import json import mlx.core as mx import numpy as np from typing import Optional, Any, List +from builtins import aiter, anext from dnet.core.tensor import to_bytes from .models import ( @@ -14,11 +16,13 @@ ChatUsage, ChatCompletionReason, ChatLogProbs, + StructuredOutputsParams, ) from .cluster import ClusterManager from .model_manager import ModelManager from .strategies.base import ApiAdapterBase from dnet.core.decoding.config import DecodingConfig +from dnet.utils.logger import logger async def arange(count: int): @@ -64,9 +68,9 @@ async def connect_to_ring( self._api_callback_addr = api_callback_addr async def generate_stream(self, req: ChatRequestModel): - """ - Generator for chat completion chunks. 
- """ + """Generator for chat completion chunks.""" + logger.debug(f"generate_stream called: model={req.model}") + if not self.model_manager.tokenizer: raise RuntimeError( "Inference manager not ready (ring not connected or tokenizer not loaded)" @@ -79,9 +83,12 @@ async def generate_stream(self, req: ChatRequestModel): hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None ): - message_dicts = [ - {"role": m.role, "content": m.content} for m in req.messages - ] + # Convert messages to dict format + message_dicts = [] + for m in req.messages: + msg_dict = {"role": m.role, "content": m.content or ""} + message_dicts.append(msg_dict) + prompt_text = tokenizer.apply_chat_template( message_dicts, add_generation_prompt=True, @@ -89,10 +96,13 @@ async def generate_stream(self, req: ChatRequestModel): ) else: prompt_text = ( - "\n".join(m.content for m in req.messages) + "\nAssistant:" + "\n".join(m.content or "" for m in req.messages) + "\nAssistant:" ) - except Exception: - prompt_text = "\n".join(m.content for m in req.messages) + "\nAssistant:" + except Exception as e: + logger.warning(f"Failed to apply chat template: {e}, using fallback") + prompt_text = ( + "\n".join(m.content or "" for m in req.messages) + "\nAssistant:" + ) prompt_tokens = tokenizer.encode(prompt_text) prompt_array = mx.array(prompt_tokens) @@ -104,6 +114,16 @@ async def generate_stream(self, req: ChatRequestModel): tokenizer.encode(stop_word, add_special_tokens=False) ) + # Convert OpenAI response_format to internal structured_outputs format + if req.response_format and req.response_format.get("type") == "json_schema": + json_schema = req.response_format["json_schema"]["schema"] + req.structured_outputs = StructuredOutputsParams(json_schema=json_schema) + + # Get grammar JSON schema for structured output + grammar_json_schema = None + if req.structured_outputs and req.structured_outputs.json_schema: + grammar_json_schema = json.dumps(req.structured_outputs.json_schema) + nonce = f"chatcmpl-{uuid.uuid4()}" t_start = time.perf_counter() t_first_token: Optional[float] = None @@ -152,6 +172,7 @@ async def generate_stream(self, req: ChatRequestModel): min_tokens_to_keep=req.min_tokens_to_keep if hasattr(req, "min_tokens_to_keep") else 1, + grammar_json_schema=grammar_json_schema, ) # Send tokens to first shard @@ -209,9 +230,29 @@ async def generate_stream(self, req: ChatRequestModel): if token == tokenizer.eos_token_id: completion_reason = ChatCompletionReason.STOP break + y = mx.array([token], dtype=mx.int32) detokenizer.finalize() + final_text = detokenizer.text + + # Strip special tokens from output + # mlx-lm's NaiveStreamingDetokenizer calls tokenizer.decode() without skip_special_tokens=True + # (see: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/tokenizer_utils.py) + # So we strip them manually as a post-processing step + SPECIAL_TOKENS_TO_STRIP = [ + "<|im_end|>", # Qwen, ChatML format + "<|im_start|>", # Qwen, ChatML format + "<|endoftext|>", # GPT/generic + "", # Llama, Mistral + "<|eot_id|>", # Llama 3 + "<|end|>", # Phi + "<|assistant|>", # Some chat templates + "<|user|>", # Some chat templates + ] + for token in SPECIAL_TOKENS_TO_STRIP: + final_text = final_text.replace(token, "") + final_text = final_text.strip() metrics_dict = None t_end = time.perf_counter() @@ -232,13 +273,19 @@ async def generate_stream(self, req: ChatRequestModel): ), } - # Final chunk with finish reason + final_message = ChatMessage( + role="assistant", + content=final_text, + ) + + # Final chunk yield 
ChatResponseModel( id=nonce, choices=[ ChatChoice( index=0, - delta=ChatMessage(role="assistant", content=""), + delta=None, + message=final_message, finish_reason=completion_reason, ) ], @@ -288,6 +335,13 @@ async def chat_completions(self, req: ChatRequestModel) -> ChatResponseModel: if chunk.usage: usage = chunk.usage + # Clean up structured output responses - remove end tokens + if req.structured_outputs and req.structured_outputs.json_schema: + full_content = full_content.strip() + for token in ["<|im_end|>", "<|endoftext|>", ""]: + if token in full_content: + full_content = full_content.split(token)[0].strip() + return ChatResponseModel( id=nonce, choices=[ diff --git a/src/dnet/api/models.py b/src/dnet/api/models.py index e27681ed..9d3249a1 100644 --- a/src/dnet/api/models.py +++ b/src/dnet/api/models.py @@ -29,6 +29,29 @@ class ChatCompletionReason(str, Enum): STOP = "stop" +class StructuredOutputsParams(BaseModel): + """Parameters for structured output generation.""" + + json_schema: Optional[Dict[str, Any]] = Field(default=None) + + @field_validator("json_schema") + @classmethod + def validate_json_schema(cls, v): + if v is None: + return v + if not isinstance(v, dict): + raise ValueError("JSON schema must be a dictionary") + if "type" not in v: + raise ValueError("JSON schema must have a 'type' field") + try: + import json + + json.dumps(v) + except (TypeError, ValueError) as e: + raise ValueError(f"JSON schema must be JSON serializable: {e}") + return v + + class RingInferenceError(BaseModel): """Error response for ring inference.""" @@ -49,10 +72,13 @@ class RingInferenceError(BaseModel): class ChatMessage(BaseModel): - """A single message in a chat conversation.""" + """A single message in a chat conversation. + + Compatible with OpenAI format. + """ - role: str # "system" | "user" | "assistant" | "tool" | "developer" # TODO: use Literal? - content: str + role: str # "system" | "user" | "assistant" | "developer" # TODO: use Literal? + content: Optional[str] = None class ChatParams(BaseModel): @@ -78,7 +104,12 @@ class ChatParams(BaseModel): # prediction: NOT USED # presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0) # NOTE: unused # prompt_cache_key: Optional[str] = Field(default=None) # NOTE: unused - # TODO: response_format: + structured_outputs: Optional[StructuredOutputsParams] = Field( + default=None + ) # Structured output parameters for grammar-constrained generation + response_format: Optional[Dict[str, Any]] = Field( + default=None + ) # OpenAI-compatible response format (json_schema, etc.) 
# safety_identifier: Optional[str] = Field(default=None) # NOTE: unused # service_tier: Optional[str] = Field(default=None) # NOTE: unused stop: Union[str, List[str]] = Field(default_factory=list) @@ -298,7 +329,7 @@ class ListModelsResponseModel(BaseModel): data: List[ModelObject] -type RetrieveModelResponseModel = ModelObject +RetrieveModelResponseModel = ModelObject # ------------------------ diff --git a/src/dnet/api/strategies/ring.py b/src/dnet/api/strategies/ring.py index e49357e4..f2c9cefb 100644 --- a/src/dnet/api/strategies/ring.py +++ b/src/dnet/api/strategies/ring.py @@ -175,6 +175,9 @@ async def send_tokens( min_tokens_to_keep=decoding_config.min_tokens_to_keep if decoding_config else 1, + grammar_json_schema=decoding_config.grammar_json_schema + if decoding_config and hasattr(decoding_config, "grammar_json_schema") + else None, ) req = msg.to_proto(tokens) diff --git a/src/dnet/core/decoding/config.py b/src/dnet/core/decoding/config.py index cc40eabf..129c2538 100644 --- a/src/dnet/core/decoding/config.py +++ b/src/dnet/core/decoding/config.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional @dataclass @@ -12,3 +13,5 @@ class DecodingConfig: logit_bias: dict[int, float] | None = None min_p: float = 0.0 min_tokens_to_keep: int = 1 + # Structured output support + grammar_json_schema: Optional[str] = None diff --git a/src/dnet/core/decoding/sampler.py b/src/dnet/core/decoding/sampler.py index d4ce9d89..8bb88f06 100644 --- a/src/dnet/core/decoding/sampler.py +++ b/src/dnet/core/decoding/sampler.py @@ -1,26 +1,213 @@ import mlx.core as mx import numpy as np +from typing import Optional, Any, Dict from mlx_lm.sample_utils import make_sampler from dnet.core.types.messages import TokenResult from dnet.core.decoding.config import DecodingConfig +from dnet.utils.logger import logger + + +class GrammarState: + def __init__( + self, + guide, + index, + bitmask_allocator, + vocab_size: int, + eos_token_id: Optional[int] = None, + ): + self.guide = guide + self.index = index + self.bitmask_allocator = bitmask_allocator + self.vocab_size = vocab_size + self._eos_token_id = eos_token_id + self._bitmask = None + self._terminated = False + + def get_bitmask(self): + if self._bitmask is None: + self._bitmask = self.bitmask_allocator(self.vocab_size) + return self._bitmask + + def fill_next_token_bitmask(self): + if self._terminated: + return None + + from outlines_core.kernels.mlx import fill_next_token_bitmask + + bitmask = self.get_bitmask() + fill_next_token_bitmask(self.guide, bitmask) + return bitmask + + def accept_token(self, token_id: int) -> None: + if self._terminated: + return + + if not self.guide.is_finished(): + self.guide.advance(token_id=token_id, return_tokens=False) + else: + self._terminated = True + + def is_terminated(self) -> bool: + if self._terminated: + return True + + if not self.guide.is_finished(): + return False + + try: + current_state = self.guide.get_state() + is_final = self.index.is_final_state(current_state) + + if is_final: + self._terminated = True + return True + + self._terminated = True + logger.debug( + f"Guide finished but not in final state - marking as terminated anyway. 
" + f"state={current_state}, is_final={is_final}" + ) + return True + except Exception as e: + self._terminated = True + logger.debug( + f"Could not verify final state: {e}, using is_finished()={self.guide.is_finished()}, " + f"marking as terminated" + ) + return True class Sampler: """ Handles the transformation of logits into tokens based on a DecodingConfig. Wraps mlx_lm's make_sampler for consistent sampling behavior. + Supports structured output via grammar-constrained generation using Outlines. """ + # Cache for compiled vocabulary to avoid recomputing per request + _vocabulary_cache: Dict[int, Any] = {} + + def __init__(self): + pass + + @staticmethod + def _get_or_create_vocabulary(tokenizer, vocab_size: int): + cache_key = id(tokenizer) + if cache_key in Sampler._vocabulary_cache: + return Sampler._vocabulary_cache[cache_key] + + try: + from outlines_core import Vocabulary + + vocab = tokenizer.get_vocab() + actual_vocab_size = len(vocab) + if vocab_size != actual_vocab_size: + logger.warning( + f"Vocab size mismatch: expected {vocab_size} (from model/logits) " + f"but tokenizer has {actual_vocab_size} tokens. " + f"Using model vocab_size {vocab_size} for bitmask allocation." + ) + + eos_token_id = tokenizer.eos_token_id + eos_token = tokenizer.eos_token or tokenizer.decode([eos_token_id]) + + # Build formatted vocabulary for Outlines + formatted_vocab = {} + for token, token_id in vocab.items(): + try: + token_as_str = tokenizer.convert_tokens_to_string([token]) + if token_as_str not in formatted_vocab: + formatted_vocab[token_as_str] = [token_id] + else: + formatted_vocab[token_as_str].append(token_id) + except Exception: + if token not in formatted_vocab: + formatted_vocab[token] = [token_id] + else: + formatted_vocab[token].append(token_id) + + # Remove EOS token from vocab (Outlines handles it separately) + formatted_vocab.pop(eos_token, None) + + vocabulary = Vocabulary(eos_token_id, formatted_vocab) + Sampler._vocabulary_cache[cache_key] = vocabulary + + logger.debug( + f"Created vocabulary: {len(formatted_vocab)} entries, vocab_size={vocab_size}, actual={actual_vocab_size}" + ) + return vocabulary + + except Exception as e: + logger.warning(f"Failed to create Outlines vocabulary: {e}") + import traceback + + logger.debug(traceback.format_exc()) + return None + + @staticmethod + def create_grammar_state( + json_schema: str, tokenizer, model_vocab_size: Optional[int] = None + ) -> Optional[GrammarState]: + if not json_schema: + return None + + try: + from outlines_core import Index, Guide + from outlines_core.outlines_core import json_schema as oc_json_schema + from outlines_core.kernels.mlx import allocate_token_bitmask + + # Get vocab_size: prefer model_vocab_size (from logits shape) over tokenizer.vocab_size + vocab_size = model_vocab_size or getattr(tokenizer, "vocab_size", None) + if vocab_size is None: + logger.warning("Could not determine vocab size for grammar state") + return None + + if model_vocab_size: + logger.debug(f"Using model_vocab_size={vocab_size} (from logits shape)") + else: + logger.debug(f"Using tokenizer.vocab_size={vocab_size} (fallback)") + + regex_pattern = oc_json_schema.build_regex_from_schema(json_schema) + logger.debug(f"Built regex from JSON schema (length: {len(regex_pattern)})") + + vocabulary = Sampler._get_or_create_vocabulary(tokenizer, vocab_size) + if vocabulary is None: + logger.warning("Failed to create vocabulary for grammar state") + return None + + index = Index(regex_pattern, vocabulary) + guide = Guide(index) + eos_token_id = 
getattr(tokenizer, "eos_token_id", None)
+
+            logger.debug("Created grammar state")
+            return GrammarState(
+                guide=guide,
+                index=index,
+                bitmask_allocator=allocate_token_bitmask,
+                vocab_size=vocab_size,
+                eos_token_id=eos_token_id,
+            )
+
+        except ImportError as e:
+            logger.warning(f"Outlines not installed or import error: {e}")
+            return None
+        except Exception as e:
+            logger.warning(f"Failed to create grammar state: {e}")
+            import traceback
+
+            logger.debug(traceback.format_exc())
+            return None
+
     @staticmethod
     def sample(
         logits: mx.array,
         config: DecodingConfig,
         req_logprobs: bool = False,
         req_top_logprobs: int = 0,
+        grammar_state: Optional[GrammarState] = None,
     ) -> TokenResult:
-        """
-        Sample a token from logits using the provided configuration.
-        """
+        """Sample a token from logits using the provided configuration."""
         sampler_fn = make_sampler(
             temp=config.temperature,
             top_p=config.top_p,
@@ -38,9 +225,38 @@
             v = logits[-1]
         else:
             v = logits
+
+        if grammar_state is not None:
+            try:
+                from outlines_core.kernels.mlx import apply_token_bitmask
+
+                bitmask = grammar_state.fill_next_token_bitmask()
+
+                if bitmask is not None:
+                    v_2d = v[None, :] if v.ndim == 1 else v
+                    v_masked = apply_token_bitmask(v_2d, bitmask)
+                    v = v_masked[0] if v_masked.ndim == 2 else v_masked
+                else:
+                    v = mx.full(v.shape, float("-inf"), dtype=v.dtype)
+
+            except Exception as e:
+                logger.warning(f"Failed to apply grammar mask: {e}")
+                import traceback
+
+                logger.debug(traceback.format_exc())
+
         token_tensor = sampler_fn(v)
         token_id = int(token_tensor.item())

+        if grammar_state is not None:
+            try:
+                grammar_state.accept_token(token_id)
+            except Exception as e:
+                logger.warning(f"Failed to accept token in grammar: {e}")
+                import traceback
+
+                logger.debug(traceback.format_exc())
+
         logprob = 0.0
         top_logprobs = {}
diff --git a/src/dnet/core/types/messages.py b/src/dnet/core/types/messages.py
index d8a54814..316cd8be 100644
--- a/src/dnet/core/types/messages.py
+++ b/src/dnet/core/types/messages.py
@@ -35,6 +35,7 @@ class ActivationMessage:
     token_id: int = -1
     logprob: float = 0.0
     top_logprobs: Optional[dict[int, float]] = None
+    grammar_terminated: bool = False  # True when Outlines grammar is complete

     # Request control
     req_logprobs: bool = False
@@ -46,6 +47,10 @@ class ActivationMessage:
     repetition_penalty: float = 1.0
     min_p: float = 0.0
     min_tokens_to_keep: int = 1
+    # Structured output support
+    grammar_json_schema: Optional[str] = (
+        None  # JSON schema for grammar-constrained generation
+    )

     @classmethod
     def from_proto(cls, proto_msg: ActivationRequest, pool_id: int = 0):
@@ -74,11 +79,14 @@ def from_proto(cls, proto_msg: ActivationRequest, pool_id: int = 0):
             min_tokens_to_keep=proto_msg.min_tokens_to_keep
            if proto_msg.HasField("min_tokens_to_keep")
            else 1,
+            grammar_json_schema=proto_msg.grammar_json_schema
+            if proto_msg.HasField("grammar_json_schema")
+            else None,
         )

     def to_proto(self, data: bytes) -> ActivationRequest:
         """Convert to protobuf request"""
-        return ActivationRequest(
+        req = ActivationRequest(
             nonce=self.nonce,
             activation=Activation(
                 data=data,
@@ -99,6 +107,10 @@ def to_proto(self, data: bytes) -> ActivationRequest:
             min_p=self.min_p,
             min_tokens_to_keep=self.min_tokens_to_keep,
         )
+        # Add optional grammar schema if present
+        if self.grammar_json_schema:
+            req.grammar_json_schema = self.grammar_json_schema
+        return req


 @dataclass(slots=True)
diff --git a/src/dnet/protos/dnet_ring.proto b/src/dnet/protos/dnet_ring.proto
index bf5cf402..e4197eb4 100644
--- a/src/dnet/protos/dnet_ring.proto
+++ b/src/dnet/protos/dnet_ring.proto
@@ -44,6 +44,9 @@ message ActivationRequest {
   optional float repetition_penalty = 11;
   optional float min_p = 12;
   optional int32 min_tokens_to_keep = 13;
+
+  // Structured output support
+  optional string grammar_json_schema = 14;
 }

 // Response message for activation sending
diff --git a/src/dnet/protos/shard_api_comm.proto b/src/dnet/protos/shard_api_comm.proto
index f12b3ffe..65e1326d 100644
--- a/src/dnet/protos/shard_api_comm.proto
+++ b/src/dnet/protos/shard_api_comm.proto
@@ -37,6 +37,7 @@ message TokenRequest {
   int64 timestamp = 3;
   float logprob = 4;
   map<int32, float> top_logprobs = 5;
+  bool grammar_terminated = 6;
 }

 // Response for token reception
diff --git a/src/dnet/shard/policies/fit_in_memory.py b/src/dnet/shard/policies/fit_in_memory.py
index 2801c566..3316e775 100644
--- a/src/dnet/shard/policies/fit_in_memory.py
+++ b/src/dnet/shard/policies/fit_in_memory.py
@@ -16,6 +16,11 @@
 class FitInMemoryPolicy(ComputePolicy):
     """Everything fits - no offloading needed"""

+    # Cache grammar states by nonce to maintain state across token generations
+    # TODO: Add TTL-based cleanup for _grammar_states to prevent memory growth
+    # See: _kv_by_nonce pattern in runtime.py
+    _grammar_states: dict = {}
+
     def configure_policy_for_model(self, req: ShardLoadModelRequest) -> None:
         self._mode = "fit"
         local_count = max(1, len(self.runtime.assigned_layers))
@@ -138,7 +143,10 @@ def process(self, msg: ActivationMessage) -> None:
             y = self.runtime.model.normalize(x_cast)
             y = self.runtime.model.lm_project(y)

-            # Sampling
+            grammar_schema = getattr(
+                msg, "grammar_json_schema", None
+            )
+
             decoding_config = DecodingConfig(
                 temperature=msg.temperature,
                 top_p=msg.top_p,
@@ -146,14 +154,54 @@ def process(self, msg: ActivationMessage) -> None:
                 repetition_penalty=msg.repetition_penalty,
                 min_p=msg.min_p,
                 min_tokens_to_keep=msg.min_tokens_to_keep,
+                grammar_json_schema=grammar_schema,
             )

-            sampler = Sampler()
-            result = sampler.sample(
+            # Get or create grammar state (cached by nonce for multi-token generation)
+            grammar_state = None
+            if grammar_schema:
+                nonce = msg.nonce
+                if nonce in FitInMemoryPolicy._grammar_states:
+                    grammar_state = FitInMemoryPolicy._grammar_states[
+                        nonce
+                    ]
+                    # Check if grammar state was already terminated - if so, don't reuse it
+                    if grammar_state is not None and getattr(
+                        grammar_state, "_terminated", False
+                    ):
+                        logger.info(
+                            f"Grammar state for nonce {nonce} already terminated, removing from cache (unexpected)"
+ ) + del FitInMemoryPolicy._grammar_states[nonce] + grammar_state = None + else: + logger.debug( + f"Reusing grammar state for nonce {nonce}, _terminated={getattr(grammar_state, '_terminated', False) if grammar_state else None}" + ) + + if grammar_state is None: + logger.debug( + f"Creating new grammar state for nonce {nonce}" + ) + tokenizer = getattr(self.runtime, "tokenizer", None) + model_vocab_size = ( + y.shape[-1] if hasattr(y, "shape") else None + ) + if tokenizer: + grammar_state = Sampler.create_grammar_state( + grammar_schema, tokenizer, model_vocab_size + ) + if grammar_state: + FitInMemoryPolicy._grammar_states[nonce] = ( + grammar_state + ) + + result = Sampler.sample( logits=y, config=decoding_config, req_logprobs=msg.req_logprobs, req_top_logprobs=msg.req_top_logprobs, + grammar_state=grammar_state, ) token_id = result.token_id diff --git a/src/dnet/shard/policies/offload.py b/src/dnet/shard/policies/offload.py index ad960617..9c42f088 100644 --- a/src/dnet/shard/policies/offload.py +++ b/src/dnet/shard/policies/offload.py @@ -24,6 +24,11 @@ class OffloadPolicy(ComputePolicy): Handles 'offload' and 'sliding_fit' modes. """ + # Cache grammar states by nonce to maintain state across token generations + # TODO: Add TTL-based cleanup for _grammar_states to prevent memory growth + # See: _kv_by_nonce pattern in runtime.py + _grammar_states: dict = {} + def configure_policy_for_model(self, req: ShardLoadModelRequest) -> None: local_count = max(1, len(self.runtime.assigned_layers)) requested_w = max(1, int(req.window_size)) @@ -331,7 +336,8 @@ def process(self, msg: ActivationMessage) -> None: y = self.runtime.model.normalize(x_cast) y = self.runtime.model.lm_project(y) - # Sampling + grammar_schema = getattr(msg, "grammar_json_schema", None) + decoding_config = DecodingConfig( temperature=msg.temperature, top_p=msg.top_p, @@ -339,14 +345,45 @@ def process(self, msg: ActivationMessage) -> None: repetition_penalty=msg.repetition_penalty, min_p=msg.min_p, min_tokens_to_keep=msg.min_tokens_to_keep, + grammar_json_schema=grammar_schema, ) - sampler = Sampler() - result = sampler.sample( + # Get or create grammar state (cached by nonce for multi-token generation) + grammar_state = None + if grammar_schema: + nonce = msg.nonce + if nonce in OffloadPolicy._grammar_states: + grammar_state = OffloadPolicy._grammar_states[nonce] + # Check if grammar state was already terminated - if so, don't reuse it + if grammar_state is not None and getattr( + grammar_state, "_terminated", False + ): + logger.debug( + f"Grammar state for nonce {nonce} already terminated, removing from cache" + ) + del OffloadPolicy._grammar_states[nonce] + grammar_state = None + + if grammar_state is None: + tokenizer = getattr(self.runtime, "tokenizer", None) + model_vocab_size = ( + y.shape[-1] if hasattr(y, "shape") else None + ) + if tokenizer: + grammar_state = Sampler.create_grammar_state( + grammar_schema, tokenizer, model_vocab_size + ) + if grammar_state: + OffloadPolicy._grammar_states[nonce] = ( + grammar_state + ) + + result = Sampler.sample( logits=y, config=decoding_config, req_logprobs=msg.req_logprobs, req_top_logprobs=msg.req_top_logprobs, + grammar_state=grammar_state, ) token_id = result.token_id diff --git a/src/dnet/shard/runtime.py b/src/dnet/shard/runtime.py index 890e72c9..4496486c 100644 --- a/src/dnet/shard/runtime.py +++ b/src/dnet/shard/runtime.py @@ -28,6 +28,7 @@ load_embeddings, load_final_norm, load_lm_head, + resolve_tokenizer_dir, ) @@ -106,6 +107,7 @@ def __init__( self.model: 
Optional[BaseShardModel] = None
         self.cache: Optional[Any] = None
         self.model_path: Optional[str] = None
+        self.tokenizer: Optional[Any] = None  # Cached tokenizer for grammar support

         # Memory Pools
         self.input_pool: Optional[LayerAwareMemoryPool] = None
@@ -280,6 +282,24 @@ def load_model_core(self, req: ShardLoadModelRequest) -> None:
                     int(has_end),
                     int(tied),
                 )
+
+                # Load tokenizer for grammar-constrained generation (only on end shard)
+                if has_end:
+                    try:
+                        from transformers import AutoTokenizer
+
+                        tok_dir = resolve_tokenizer_dir(self.model_path)
+                        self.tokenizer = AutoTokenizer.from_pretrained(tok_dir)
+                        logger.info(
+                            "Runtime %s: loaded HuggingFace tokenizer for grammar support",
+                            self.shard_id,
+                        )
+                    except Exception as e:
+                        logger.warning(
+                            "Runtime %s: failed to load tokenizer for grammar: %s",
+                            self.shard_id,
+                            e,
+                        )
             except Exception as e:
                 logger.warning(
                     "Runtime %s: failed to load API‑layer weights: %s", self.shard_id, e
diff --git a/tests/fakes/policies.py b/tests/fakes/policies.py
index f70096ca..a9408600 100644
--- a/tests/fakes/policies.py
+++ b/tests/fakes/policies.py
@@ -71,7 +71,8 @@ def unload_layers(self, layers):


 class FakeSampler:
-    def sample(self, logits, config, req_logprobs, req_top_logprobs):
+    @staticmethod
+    def sample(logits, config, req_logprobs, req_top_logprobs, grammar_state=None):
         from .api import FakeTokenResult

         return FakeTokenResult(7, -0.1, {7: -0.1})
diff --git a/tests/integration/test_model_catalog.py b/tests/integration/test_model_catalog.py
index 0aa89c12..053e772a 100644
--- a/tests/integration/test_model_catalog.py
+++ b/tests/integration/test_model_catalog.py
@@ -6,8 +6,7 @@
 Usage (with servers running):
     uv run pytest tests/integration/test_model_catalog.py -v -x

-Usage (standalone - starts servers automatically):
-    uv run pytest tests/integration/test_model_catalog.py -v -x --start-servers
+Usage (standalone - starts servers automatically):

 Usage (in CI - expects servers started externally):
     # Servers started by workflow steps
diff --git a/tests/integration/test_structured_outputs_e2e.py b/tests/integration/test_structured_outputs_e2e.py
new file mode 100644
index 00000000..56d60ed9
--- /dev/null
+++ b/tests/integration/test_structured_outputs_e2e.py
@@ -0,0 +1,310 @@
+"""End-to-end tests for structured outputs functionality.
+ +Usage (with servers running): + uv run pytest tests/integration/test_structured_outputs_e2e.py -v + +""" + +import json +import logging +import time +from typing import Any, Generator + +import pytest +import requests + +from dnet.api.catalog import get_ci_test_models + +logger = logging.getLogger(__name__) + +# Server configuration +API_HTTP_PORT = 8080 +BASE_URL = f"http://localhost:{API_HTTP_PORT}" + +# Timeouts +HEALTH_CHECK_TIMEOUT = 30 # seconds to wait for server health +INFERENCE_TIMEOUT = 60 # seconds for inference requests +MODEL_LOAD_TIMEOUT = 300 # seconds to wait for model loading + +# Get CI-testable models from catalog +CI_TEST_MODELS = get_ci_test_models() + +# Find the Qwen 4B 4bit model for structured outputs testing +STRUCTURED_OUTPUTS_MODEL = None +for model in CI_TEST_MODELS: + if model["id"] == "Qwen/Qwen3-4B-MLX-4bit": + STRUCTURED_OUTPUTS_MODEL = model + break + +# Fallback if not found +if STRUCTURED_OUTPUTS_MODEL is None: + STRUCTURED_OUTPUTS_MODEL = {"id": "Qwen/Qwen3-4B-MLX-4bit", "alias": "qwen3-4b"} + + +def wait_for_health(url: str, timeout: float = HEALTH_CHECK_TIMEOUT) -> bool: + """Wait for server health endpoint to respond.""" + deadline = time.time() + timeout + while time.time() < deadline: + try: + response = requests.get(f"{url}/health", timeout=2) + if response.status_code == 200: + return True + except (requests.RequestException, requests.ConnectionError): + time.sleep(0.5) + return False + + +@pytest.fixture(scope="module") +def servers() -> Generator[None, None, None]: + """Check that API server is healthy (assumes servers are already running).""" + # Assume servers are already running (like CI does) + if not wait_for_health(BASE_URL): + pytest.skip( + f"Server not healthy at {BASE_URL}/health (start servers manually first)" + ) + + yield + + +def prepare_and_load_model(model_id: str) -> None: + """Prepare topology and load model.""" + # Prepare topology + resp = requests.post( + f"{BASE_URL}/v1/prepare_topology", + json={"model": model_id}, + timeout=MODEL_LOAD_TIMEOUT, + ) + resp.raise_for_status() + + # Load model + resp = requests.post( + f"{BASE_URL}/v1/load_model", + json={"model": model_id}, + timeout=MODEL_LOAD_TIMEOUT, + ) + resp.raise_for_status() + + +def unload_model() -> None: + """Unload the current model. + + Logs a warning if unloading fails, as this is a best-effort cleanup. 
+ """ + try: + resp = requests.post(f"{BASE_URL}/v1/unload_model", timeout=30) + resp.raise_for_status() + except requests.RequestException as e: + logger.warning(f"Failed to unload model (best effort): {e}") + + +@pytest.mark.integration +@pytest.mark.parametrize( + "schema,prompt", + [ + ( + { + "type": "object", + "properties": { + "answer": {"type": "string"}, + "count": {"type": "integer"}, + }, + "required": ["answer"], + }, + "Give me a simple response with a count", + ), + ( + { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name"], + }, + "Create a profile for a person", + ), + ( + { + "type": "object", + "properties": {"items": {"type": "array", "items": {"type": "string"}}}, + "required": ["items"], + }, + "List three fruits", + ), + ], +) +def test_structured_outputs_end_to_end( + servers, schema: dict[str, Any], prompt: str +) -> None: + """Test that structured outputs produce valid JSON conforming to schema.""" + model_id = STRUCTURED_OUTPUTS_MODEL["id"] + + try: + # Prepare and load the model + prepare_and_load_model(model_id) + + # Run the actual test + _test_structured_outputs_core(schema, prompt, model_id) + + finally: + # Cleanup: unload model + unload_model() + + +def _test_structured_outputs_core( + schema: dict[str, Any], prompt: str, model_id: str +) -> None: + """Core test logic for structured outputs.""" + payload = { + "model": model_id, + "messages": [{"role": "user", "content": prompt}], + "structured_outputs": {"json_schema": schema}, + "max_tokens": 500, + "temperature": 0.1, + } + + response = requests.post( + f"{BASE_URL}/v1/chat/completions", json=payload, timeout=INFERENCE_TIMEOUT + ) + assert response.status_code == 200, f"Request failed: {response.text}" + + result = response.json() + assert "choices" in result, f"Response missing 'choices': {result}" + assert len(result["choices"]) > 0, f"No choices in response: {result}" + + choice = result["choices"][0] + assert "message" in choice, f"Choice missing 'message': {choice}" + content = choice["message"].get("content", "") + + # Parse JSON - structured outputs should produce clean JSON (no end tokens) + parsed = json.loads(content) + assert isinstance(parsed, dict), f"Response is not a JSON object: {content}" + + # Verify it matches the schema requirements + for required_field in schema.get("required", []): + assert required_field in parsed, ( + f"Required field '{required_field}' missing from response" + ) + + # Basic type checking for properties + properties = schema.get("properties", {}) + for field_name, field_schema in properties.items(): + if field_name in parsed: + field_type = field_schema.get("type") + if field_type == "string": + assert isinstance(parsed[field_name], str), ( + f"Field '{field_name}' should be string" + ) + elif field_type == "integer": + assert isinstance(parsed[field_name], int), ( + f"Field '{field_name}' should be integer" + ) + elif field_type == "array": + assert isinstance(parsed[field_name], list), ( + f"Field '{field_name}' should be array" + ) + + +@pytest.mark.integration +@pytest.mark.parametrize( + "schema,prompt", + [ + ( + { + "type": "object", + "properties": { + "answer": {"type": "string"}, + "count": {"type": "integer"}, + }, + "required": ["answer"], + }, + "Give me a simple response with a count", + ), + ( + { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name"], + }, + "Create a profile for a person", + ), + ( + { + "type": "object", 
+ "properties": {"items": {"type": "array", "items": {"type": "string"}}}, + "required": ["items"], + }, + "List three fruits", + ), + ], +) +def test_openai_response_format_end_to_end( + servers, schema: dict[str, Any], prompt: str +) -> None: + """Test that OpenAI response_format produces valid JSON conforming to schema.""" + model_id = STRUCTURED_OUTPUTS_MODEL["id"] + + try: + # Prepare and load the model + prepare_and_load_model(model_id) + + # Run the actual test + _test_openai_response_format_core(schema, prompt, model_id) + + finally: + # Cleanup: unload model + unload_model() + + +def _test_openai_response_format_core( + schema: dict[str, Any], prompt: str, model_id: str +) -> None: + """Core test logic for OpenAI response_format.""" + payload = { + "model": model_id, + "messages": [{"role": "user", "content": prompt}], + "response_format": { + "type": "json_schema", + "json_schema": {"name": "test_schema", "schema": schema}, + }, + "max_tokens": 500, + "temperature": 0.1, + } + + response = requests.post( + f"{BASE_URL}/v1/chat/completions", json=payload, timeout=INFERENCE_TIMEOUT + ) + assert response.status_code == 200, f"Request failed: {response.text}" + + result = response.json() + assert "choices" in result, f"Response missing 'choices': {result}" + assert len(result["choices"]) > 0, f"No choices in response: {result}" + + choice = result["choices"][0] + assert "message" in choice, f"Choice missing 'message': {choice}" + content = choice["message"].get("content", "") + + # Parse JSON - OpenAI response_format should produce clean JSON + parsed = json.loads(content) + assert isinstance(parsed, dict), f"Response is not a JSON object: {content}" + + # Verify it matches the schema requirements + for required_field in schema.get("required", []): + assert required_field in parsed, ( + f"Required field '{required_field}' missing from response" + ) + + # Basic type checking for properties + properties = schema.get("properties", {}) + for field_name, field_schema in properties.items(): + if field_name in parsed: + field_type = field_schema.get("type") + if field_type == "string": + assert isinstance(parsed[field_name], str), ( + f"Field '{field_name}' should be string" + ) + elif field_type == "integer": + assert isinstance(parsed[field_name], int), ( + f"Field '{field_name}' should be integer" + ) + elif field_type == "array": + assert isinstance(parsed[field_name], list), ( + f"Field '{field_name}' should be array" + ) diff --git a/tests/subsystems/test_inference_manager.py b/tests/subsystems/test_inference_manager.py index 8c770728..f1deb5fa 100644 --- a/tests/subsystems/test_inference_manager.py +++ b/tests/subsystems/test_inference_manager.py @@ -7,7 +7,7 @@ pytest.importorskip("mlx.core") from dnet.api.inference import InferenceManager # noqa: E402 -from dnet.api.models import ChatRequestModel, ChatMessage # noqa: E402 +from dnet.api.models import ChatRequestModel, ChatMessage, StructuredOutputsParams # noqa: E402 from tests.fakes import ( # noqa: E402 FakeTokenizer, FakeTokenizerWithTemplate, @@ -320,3 +320,73 @@ def test_invalid_request_params_min_tokens_to_keep(): min_tokens_to_keep=0, logprobs=True, ) + + +def test_structured_outputs_valid_json_schema(): + """Test valid JSON schema in structured outputs.""" + valid_schema = {"type": "object", "properties": {"name": {"type": "string"}}} + + # Test direct structured_outputs usage + req = ChatRequestModel( + model="m", + messages=[ChatMessage(role="user", content="test")], + 
structured_outputs=StructuredOutputsParams(json_schema=valid_schema), + ) + assert req.structured_outputs.json_schema == valid_schema + + +def test_structured_outputs_invalid_json_schema(): + """Test invalid JSON schema validation.""" + + # Missing 'type' field + with pytest.raises(ValidationError): + StructuredOutputsParams( + json_schema={"properties": {"name": {"type": "string"}}} + ) + + # Not a dict + with pytest.raises(ValidationError): + StructuredOutputsParams(json_schema="not a dict") + + # Not JSON serializable + with pytest.raises(ValidationError): + StructuredOutputsParams( + json_schema={"key": set([1, 2, 3])} + ) # sets aren't JSON serializable + + +def test_structured_outputs_inference_integration(): + """Test that structured outputs work in the inference flow.""" + + async def main(): + tok = FakeTokenizer() + mm = FakeModelManager(tok) + ad = FakeStrategyAdapter() + cm = FakeClusterManager() + mgr = InferenceManager(cm, mm, grpc_port=50500, adapter=ad) + await mgr.connect_to_ring("1.2.3.4", 9000, "9.9.9.9:50500") + + schema = {"type": "object", "properties": {"answer": {"type": "string"}}} + req = ChatRequestModel( + model="m", + messages=[ChatMessage(role="user", content="hi")], + max_tokens=5, + structured_outputs=StructuredOutputsParams(json_schema=schema), + ) + + agen = mgr.generate_stream(req) + c0 = await agen.__anext__() + assert c0.choices[0].delta.role == "assistant" + assert ad.reset is True + + nonce = c0.id + ad.queue_token(nonce, FakeTokenResult(1)) + ad.queue_token(nonce, FakeTokenResult(tok.eos_token_id)) + c1 = await agen.__anext__() + assert c1.choices[0].delta.content == "t1" + assert ( + ad.sent[0]["decoding_config"].grammar_json_schema + == '{"type": "object", "properties": {"answer": {"type": "string"}}}' + ) + + asyncio.run(main())
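
Note for reviewers: a minimal client-side sketch of how the new structured-output path can be exercised once this change lands. It mirrors the request payloads used in tests/integration/test_structured_outputs_e2e.py; the localhost:8080 base URL and the Qwen model id are the integration-test defaults rather than fixed values, and the dnet-native "structured_outputs" field can be used in place of the OpenAI-style "response_format" shown here.

import json
import requests

# JSON schema the server will constrain generation against
schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}, "count": {"type": "integer"}},
    "required": ["answer"],
}

payload = {
    "model": "Qwen/Qwen3-4B-MLX-4bit",
    "messages": [{"role": "user", "content": "Give me a simple response with a count"}],
    # OpenAI-compatible form; {"structured_outputs": {"json_schema": schema}} also works
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "test_schema", "schema": schema},
    },
    "max_tokens": 500,
    "temperature": 0.1,
}

resp = requests.post("http://localhost:8080/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
content = resp.json()["choices"][0]["message"]["content"]
print(json.loads(content))  # parsed object should contain the required "answer" field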