diff --git a/omlx/cache/paged_cache.py b/omlx/cache/paged_cache.py index c3a3569e..fad723d0 100644 --- a/omlx/cache/paged_cache.py +++ b/omlx/cache/paged_cache.py @@ -41,6 +41,32 @@ BlockHash = NewType("BlockHash", bytes) +def resolve_block_extra_keys( + block_end: int, + extra_keys: Optional[Tuple[Any, ...]] = None, + extra_key_token_start: Optional[int] = None, + extra_key_ranges: Optional[List[Tuple[int, Tuple[Any, ...]]]] = None, +) -> Optional[Tuple[Any, ...]]: + """Resolve which cache key salt applies to a block ending at ``block_end``. + + ``extra_key_ranges`` takes precedence over ``extra_keys`` and is intended + for segmented VLM cache keying. + """ + if extra_key_ranges: + selected = None + for start, keys in extra_key_ranges: + if block_end > start: + selected = keys + else: + break + return selected + if extra_keys is not None and ( + extra_key_token_start is None or block_end > extra_key_token_start + ): + return extra_keys + return None + + def compute_block_hash( parent_hash: Optional[BlockHash], token_ids: List[int], @@ -923,6 +949,8 @@ def get_computed_blocks( self, token_ids: List[int], extra_keys: Optional[Tuple[Any, ...]] = None, + extra_key_token_start: Optional[int] = None, + extra_key_ranges: Optional[List[Tuple[int, Tuple[Any, ...]]]] = None, ) -> Tuple[List[CacheBlock], int]: """ Find cached blocks for a token prefix (vLLM style). @@ -948,11 +976,17 @@ def get_computed_blocks( start = i * self.block_size end = start + self.block_size block_tokens = token_ids[start:end] + block_extra_keys = resolve_block_extra_keys( + end, + extra_keys=extra_keys, + extra_key_token_start=extra_key_token_start, + extra_key_ranges=extra_key_ranges, + ) # Compute expected hash block_hash = compute_block_hash( parent_hash, block_tokens, - extra_keys=extra_keys, model_name=self.model_name, + extra_keys=block_extra_keys, model_name=self.model_name, ) # Look up in cache @@ -1103,6 +1137,8 @@ def find_shared_prefix( self, tokens: List[int], extra_keys: Optional[Tuple[Any, ...]] = None, + extra_key_token_start: Optional[int] = None, + extra_key_ranges: Optional[List[Tuple[int, Tuple[Any, ...]]]] = None, ) -> Tuple[List[int], List[int]]: """ Find shared prefix blocks for a token sequence. @@ -1110,7 +1146,10 @@ def find_shared_prefix( Uses get_computed_blocks for consistent chain-hash lookup. """ cached_blocks, num_cached_tokens = self.get_computed_blocks( - tokens, extra_keys=extra_keys + tokens, + extra_keys=extra_keys, + extra_key_token_start=extra_key_token_start, + extra_key_ranges=extra_key_ranges, ) shared_block_ids = [b.block_id for b in cached_blocks] diff --git a/omlx/cache/prefix_cache.py b/omlx/cache/prefix_cache.py index 5bda8213..0d9e0b6a 100644 --- a/omlx/cache/prefix_cache.py +++ b/omlx/cache/prefix_cache.py @@ -20,7 +20,13 @@ from .interface import CacheManager from .paged_ssd_cache import PagedSSDCacheManager -from .paged_cache import BlockTable, CacheBlock, PagedCacheManager, compute_block_hash +from .paged_cache import ( + BlockTable, + CacheBlock, + PagedCacheManager, + compute_block_hash, + resolve_block_extra_keys, +) from .stats import BaseCacheStats, PrefixCacheStats from .type_handlers import CacheType, CacheTypeHandler from .type_registry import CacheTypeRegistry @@ -228,6 +234,8 @@ def fetch_cache( request_id: str, tokens: List[int], extra_keys: Optional[Tuple[Any, ...]] = None, + extra_key_token_start: Optional[int] = None, + extra_key_ranges: Optional[List[Tuple[int, Tuple[Any, ...]]]] = None, ) -> Tuple[Optional[BlockTable], List[int]]: """ Find cached prefix blocks for the given tokens. @@ -247,7 +255,10 @@ def fetch_cache( # Try to find shared prefix blocks shared_block_ids, remaining = self.paged_cache.find_shared_prefix( - tokens, extra_keys=extra_keys + tokens, + extra_keys=extra_keys, + extra_key_token_start=extra_key_token_start, + extra_key_ranges=extra_key_ranges, ) if shared_block_ids: @@ -311,6 +322,8 @@ def store_cache( model_cache_config: Optional[ModelCacheConfig] = None, boundary_snapshots: Optional[Dict[int, List[Any]]] = None, extra_keys: Optional[Tuple[Any, ...]] = None, + extra_key_token_start: Optional[int] = None, + extra_key_ranges: Optional[List[Tuple[int, Tuple[Any, ...]]]] = None, ) -> Optional[BlockTable]: """ Store computed cache for future reuse. @@ -433,9 +446,20 @@ def store_cache( if prev_block and prev_block.block_hash: parent_hash = prev_block.block_hash + block_extra_keys = resolve_block_extra_keys( + global_end, + extra_keys=extra_keys, + extra_key_token_start=extra_key_token_start, + extra_key_ranges=extra_key_ranges, + ) + # Check if this block already exists (deduplication) if len(block_tokens) == self.block_size: - existing_block = self.paged_cache.find_cached_block(block_tokens, parent_hash) + existing_block = self.paged_cache.find_cached_block( + block_tokens, + parent_hash, + extra_keys=block_extra_keys, + ) if existing_block: # Reuse existing block self.paged_cache.increment_ref(existing_block.block_id) @@ -462,13 +486,13 @@ def store_cache( # Compute chain hash for this block block.block_hash = compute_block_hash( parent_hash, block_tokens, - extra_keys=extra_keys, model_name=self.paged_cache.model_name, + extra_keys=block_extra_keys, model_name=self.paged_cache.model_name, ) # Register hash for full blocks (for deduplication) if len(block_tokens) == self.block_size: self.paged_cache.register_block_hash( - block, block_tokens, parent_hash, extra_keys=extra_keys + block, block_tokens, parent_hash, extra_keys=block_extra_keys ) # Extract tensor slice and save to paged SSD diff --git a/omlx/engine/vlm.py b/omlx/engine/vlm.py index b5d843df..38ccbec5 100644 --- a/omlx/engine/vlm.py +++ b/omlx/engine/vlm.py @@ -787,7 +787,7 @@ def _format_messages_for_vlm_template( self, messages: list[dict[str, Any]], num_images: int, - ) -> list[dict[str, Any]]: + ) -> tuple[list[dict[str, Any]], list[tuple[int, int]]]: """Format VLM messages with image tokens on image-bearing user turns.""" from mlx_vlm.prompt_utils import extract_text_from_content, get_message_json @@ -805,8 +805,9 @@ def _format_messages_for_vlm_template( remaining_images = num_images assigned_fallback_images = False formatted_messages: list[dict[str, Any]] = [] + image_message_ranges: list[tuple[int, int]] = [] - for msg in messages: + for idx, msg in enumerate(messages): if not isinstance(msg, dict): msg = {"role": "user", "content": str(msg)} @@ -829,6 +830,9 @@ def _format_messages_for_vlm_template( remaining_images = 0 assigned_fallback_images = True + if msg_num_images > 0: + image_message_ranges.append((idx, msg_num_images)) + formatted_messages.append( get_message_json( model_type, @@ -841,7 +845,7 @@ def _format_messages_for_vlm_template( ) ) - return formatted_messages + return formatted_messages, image_message_ranges def _compute_vision_features( self, pixel_values: Any, extra_model_inputs: dict @@ -912,7 +916,14 @@ def _prepare_vision_inputs( images: list[Any], chat_template_kwargs: dict[str, Any] | None = None, tools: list[dict] | None = None, - ) -> Tuple[List[int], Optional[mx.array], Optional[Dict[str, Any]], Optional[str]]: + ) -> Tuple[ + List[int], + Optional[mx.array], + Optional[Dict[str, Any]], + Optional[str], + int, + List[Tuple[int, str]], + ]: """ Run the full VLM preprocessing pipeline: 1. Apply chat template with image placeholders @@ -925,11 +936,21 @@ def _prepare_vision_inputs( images: List of PIL Image objects Returns: - Tuple of (token_ids, inputs_embeds, extra_kwargs, image_hash): + Tuple of ( + token_ids, + inputs_embeds, + extra_kwargs, + image_hash, + image_cache_key_start, + image_cache_key_ranges, + ): - token_ids: List of token IDs for BatchGenerator - inputs_embeds: Merged vision+text embeddings (or None if text-only) - extra_kwargs: Model-specific kwargs for language model - image_hash: SHA256 hash of images for prefix cache + - image_cache_key_start: Token index where image-aware cache keying begins + - image_cache_key_ranges: Per-image-turn cache key boundaries with + cumulative image hashes """ from mlx_vlm.prompt_utils import apply_chat_template from mlx_vlm.utils import prepare_inputs @@ -948,7 +969,7 @@ def _prepare_vision_inputs( # Build per-message placeholders in oMLX so image-bearing turns always # receive image tokens, regardless of conversation history shape. try: - formatted_messages = self._format_messages_for_vlm_template( + formatted_messages, image_message_ranges = self._format_messages_for_vlm_template( messages, num_images=num_images ) except Exception as e: @@ -964,6 +985,15 @@ def _prepare_vision_inputs( num_images=num_images, return_messages=True, ) + image_message_ranges = [] + for idx, msg in enumerate(messages): + if not isinstance(msg, dict): + continue + image_count = self._count_content_parts( + msg.get("content"), {"image", "image_url", "input_image"} + ) + if image_count > 0: + image_message_ranges.append((idx, image_count)) # Strip partial field from messages (VLM always uses add_generation_prompt=True) detect_and_strip_partial(formatted_messages) @@ -1029,6 +1059,56 @@ def _prepare_vision_inputs( pixel_values = inputs.get("pixel_values") attention_mask = inputs.get("attention_mask") + image_cache_key_start = 0 + image_cache_key_ranges: list[Tuple[int, str]] = [] + if image_message_ranges: + prefix_template_kwargs = { + "tokenize": False, + "add_generation_prompt": False, + } + if self._enable_thinking is not None: + prefix_template_kwargs["enable_thinking"] = self._enable_thinking + if tools: + prefix_template_kwargs["tools"] = tools + if chat_template_kwargs: + prefix_template_kwargs.update(chat_template_kwargs) + + images_consumed = 0 + for msg_idx, msg_num_images in image_message_ranges: + prefix_messages = formatted_messages[:msg_idx] + boundary_tokens = 0 + if prefix_messages: + try: + prefix_prompt = template_target.apply_chat_template( + prefix_messages, **prefix_template_kwargs + ) + except TypeError: + local_kwargs = dict(prefix_template_kwargs) + if chat_template_kwargs: + for key in chat_template_kwargs: + local_kwargs.pop(key, None) + local_kwargs.pop("enable_thinking", None) + prefix_prompt = template_target.apply_chat_template( + prefix_messages, **local_kwargs + ) + prefix_inputs = prepare_inputs( + self._processor, + images=images[:images_consumed] if images_consumed > 0 else None, + prompts=[prefix_prompt] if isinstance(prefix_prompt, str) else prefix_prompt, + ) + prefix_ids = prefix_inputs["input_ids"] + boundary_tokens = ( + len(prefix_ids[0].tolist()) + if prefix_ids.ndim > 1 + else len(prefix_ids.tolist()) + ) + + images_consumed += msg_num_images + cumulative_hash = compute_image_hash(images[:images_consumed]) + image_cache_key_ranges.append((boundary_tokens, cumulative_hash)) + + image_cache_key_start = image_cache_key_ranges[0][0] + # Extract additional model-specific inputs (filter None values # since prepare_inputs may include them after mlx-vlm 348466f) extra_model_inputs = { @@ -1121,11 +1201,18 @@ def _prepare_vision_inputs( # Extract token IDs as list token_ids = input_ids[0].tolist() if input_ids.ndim > 1 else input_ids.tolist() - return token_ids, embed_features.inputs_embeds, extra_kwargs, image_hash + return ( + token_ids, + embed_features.inputs_embeds, + extra_kwargs, + image_hash, + image_cache_key_start, + image_cache_key_ranges, + ) else: # Text-only (no images in this message) token_ids = input_ids[0].tolist() if input_ids.ndim > 1 else input_ids.tolist() - return token_ids, None, None, None + return token_ids, None, None, None, 0, [] def _apply_chat_template( self, @@ -1175,6 +1262,8 @@ async def generate( vlm_inputs_embeds: Any = None, vlm_extra_kwargs: dict[str, Any] | None = None, vlm_image_hash: str | None = None, + vlm_cache_key_start: int = 0, + vlm_cache_key_ranges: Optional[List[Tuple[int, str]]] = None, **kwargs, ) -> GenerationOutput: """Generate a complete response (non-streaming).""" @@ -1214,6 +1303,8 @@ async def generate( vlm_inputs_embeds=vlm_inputs_embeds, vlm_extra_kwargs=vlm_extra_kwargs, vlm_image_hash=vlm_image_hash, + vlm_cache_key_start=vlm_cache_key_start, + vlm_cache_key_ranges=vlm_cache_key_ranges, ) text = clean_special_tokens(output.output_text) @@ -1241,6 +1332,8 @@ async def stream_generate( vlm_inputs_embeds: Any = None, vlm_extra_kwargs: dict[str, Any] | None = None, vlm_image_hash: str | None = None, + vlm_cache_key_start: int = 0, + vlm_cache_key_ranges: Optional[List[Tuple[int, str]]] = None, **kwargs, ) -> AsyncIterator[GenerationOutput]: """Stream generation token by token.""" @@ -1291,6 +1384,8 @@ async def stream_generate( vlm_inputs_embeds=vlm_inputs_embeds, vlm_extra_kwargs=vlm_extra_kwargs, vlm_image_hash=vlm_image_hash, + vlm_cache_key_start=vlm_cache_key_start, + vlm_cache_key_ranges=vlm_cache_key_ranges, **specprefill_kwargs, ) @@ -1337,7 +1432,7 @@ async def chat( await self.start() loop = asyncio.get_running_loop() - prompt, vlm_embeds, vlm_kwargs, image_hash = await loop.run_in_executor( + prompt, vlm_embeds, vlm_kwargs, image_hash, image_cache_key_start, image_cache_key_ranges = await loop.run_in_executor( self._engine._mlx_executor, self._process_chat_messages, messages, tools, kwargs, ) @@ -1354,6 +1449,8 @@ async def chat( vlm_inputs_embeds=vlm_embeds, vlm_extra_kwargs=vlm_kwargs, vlm_image_hash=image_hash, + vlm_cache_key_start=image_cache_key_start, + vlm_cache_key_ranges=image_cache_key_ranges, **kwargs, ) @@ -1379,7 +1476,7 @@ async def stream_chat( # uvicorn from managing HTTP keep-alive connections, causing # TransferEncodingError on the next request (issue #80). loop = asyncio.get_running_loop() - prompt, vlm_embeds, vlm_kwargs, image_hash = await loop.run_in_executor( + prompt, vlm_embeds, vlm_kwargs, image_hash, image_cache_key_start, image_cache_key_ranges = await loop.run_in_executor( self._engine._mlx_executor, self._process_chat_messages, messages, tools, kwargs, ) @@ -1414,6 +1511,8 @@ async def stream_chat( vlm_inputs_embeds=vlm_embeds, vlm_extra_kwargs=vlm_kwargs, vlm_image_hash=image_hash, + vlm_cache_key_start=image_cache_key_start, + vlm_cache_key_ranges=image_cache_key_ranges, **kwargs, ): yield output @@ -1478,7 +1577,7 @@ def _process_chat_messages( messages: list[dict[str, Any]], tools: list[dict] | None, kwargs: dict, - ) -> Tuple[str | list[int], Any, dict | None, str | None]: + ) -> Tuple[str | list[int], Any, dict | None, str | None, int, List[Tuple[int, str]]]: """ Process chat messages, extracting images and preparing VLM inputs. @@ -1490,19 +1589,19 @@ def _process_chat_messages( ct_kwargs = kwargs.pop("chat_template_kwargs", None) - if images: - # Apply OCR-specific prompt if applicable - ocr_messages = self._apply_ocr_prompt(messages) - - # Convert tools for template format (same as text-only path) - template_tools = convert_tools_for_template(tools) if tools else None + # Keep VLM-capable models on one prompt-rendering path, even before the + # first image arrives. Otherwise the conversation switches prompt families + # on the first image-bearing turn and invalidates early prefix blocks. + vlm_messages = self._apply_ocr_prompt(messages) if images else text_messages + template_tools = convert_tools_for_template(tools) if tools else None + token_ids, vlm_embeds, vlm_kwargs, image_hash, image_cache_key_start, image_cache_key_ranges = self._prepare_vision_inputs( + vlm_messages, + images, + chat_template_kwargs=ct_kwargs, + tools=template_tools, + ) - # VLM path: prepare vision inputs - token_ids, vlm_embeds, vlm_kwargs, image_hash = self._prepare_vision_inputs( - ocr_messages, images, - chat_template_kwargs=ct_kwargs, - tools=template_tools, - ) + if images: # Free Metal intermediates from vision encoding. # Vision tower + projector produce large intermediate buffers # that stay in the Metal cache pool until explicitly cleared. @@ -1510,14 +1609,15 @@ def _process_chat_messages( # eventually trigger ProcessMemoryEnforcer aborts (see #667). mx.synchronize() mx.clear_cache() - return token_ids, vlm_embeds, vlm_kwargs, image_hash - else: - # Text-only path: standard chat template - template_tools = convert_tools_for_template(tools) if tools else None - prompt = self._apply_chat_template( - text_messages, template_tools, chat_template_kwargs=ct_kwargs - ) - return prompt, None, None, None + + return ( + token_ids, + vlm_embeds, + vlm_kwargs, + image_hash, + image_cache_key_start, + image_cache_key_ranges, + ) def count_chat_tokens( self, diff --git a/omlx/engine_core.py b/omlx/engine_core.py index 4fb9d363..f2e082df 100644 --- a/omlx/engine_core.py +++ b/omlx/engine_core.py @@ -18,7 +18,7 @@ import time import uuid from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Dict, List, Optional, Set, Union +from typing import Any, AsyncIterator, Dict, List, Optional, Set, Tuple, Union import mlx.core as mx @@ -253,6 +253,8 @@ async def add_request( vlm_inputs_embeds: Optional[Any] = None, vlm_extra_kwargs: Optional[Dict[str, Any]] = None, vlm_image_hash: Optional[str] = None, + vlm_cache_key_start: int = 0, + vlm_cache_key_ranges: Optional[List[Tuple[int, str]]] = None, specprefill: Optional[bool] = None, specprefill_keep_pct: Optional[float] = None, specprefill_threshold: Optional[int] = None, @@ -292,6 +294,8 @@ async def add_request( vlm_inputs_embeds=vlm_inputs_embeds, vlm_extra_kwargs=vlm_extra_kwargs, vlm_image_hash=vlm_image_hash, + vlm_cache_key_start=vlm_cache_key_start, + vlm_cache_key_ranges=vlm_cache_key_ranges, ) # SpecPrefill: resolve per-request settings. diff --git a/omlx/request.py b/omlx/request.py index fe84ba5b..25382a29 100644 --- a/omlx/request.py +++ b/omlx/request.py @@ -10,7 +10,7 @@ import enum import time from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union if TYPE_CHECKING: from .cache.paged_cache import BlockTable @@ -143,6 +143,8 @@ class Request: vlm_inputs_embeds: Optional[Any] = None # Pre-computed vision+text embeddings (mx.array) vlm_extra_kwargs: Optional[Dict[str, Any]] = None # Model-specific kwargs (e.g., position_ids) vlm_image_hash: Optional[str] = None # SHA256 hash of images for prefix cache + vlm_cache_key_start: int = 0 # Token index where image-specific cache keying starts + vlm_cache_key_ranges: Optional[List[Tuple[int, str]]] = None # [(token_start, cumulative_image_hash)] rope_deltas: float = 0.0 # Per-request mRoPE position delta (set after VLM prefill) # Metadata diff --git a/omlx/scheduler.py b/omlx/scheduler.py index d948011e..facda241 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -2200,11 +2200,26 @@ def add_request(self, request: Request) -> None: extra_keys = None if request.vlm_image_hash: extra_keys = (request.vlm_image_hash,) + extra_key_token_start = ( + request.vlm_cache_key_start if request.vlm_image_hash else None + ) + extra_key_ranges = ( + [ + (start, (image_hash,)) + for start, image_hash in request.vlm_cache_key_ranges + ] + if request.vlm_cache_key_ranges + else None + ) + # Segmented VLM ranges take precedence when present; extra_keys is + # retained as the legacy whole-prompt fallback for non-segmented cases. block_table, remaining = self.block_aware_cache.fetch_cache( request.request_id, request.prompt_token_ids, extra_keys=extra_keys, + extra_key_token_start=extra_key_token_start, + extra_key_ranges=extra_key_ranges, ) if block_table and block_table.num_tokens > 0: # Reconstruct actual KVCache objects from stored tensor data @@ -3369,6 +3384,21 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None: store_extra_keys = None if request.vlm_image_hash: store_extra_keys = (request.vlm_image_hash,) + store_extra_key_token_start = ( + request.vlm_cache_key_start + if request.vlm_image_hash + else None + ) + store_extra_key_ranges = ( + [ + (start, (image_hash,)) + for start, image_hash in request.vlm_cache_key_ranges + ] + if request.vlm_cache_key_ranges + else None + ) + # Segmented VLM ranges take precedence when present; + # extra_keys remains the legacy fallback. block_table = self.block_aware_cache.store_cache( request_id, @@ -3377,6 +3407,8 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None: model_cache_config=model_cache_config, boundary_snapshots=intermediate_snapshots, extra_keys=store_extra_keys, + extra_key_token_start=store_extra_key_token_start, + extra_key_ranges=store_extra_key_ranges, ) logger.debug( f"Stored paged cache for request {request_id} " diff --git a/tests/test_paged_cache.py b/tests/test_paged_cache.py index bd319d41..55f15e23 100644 --- a/tests/test_paged_cache.py +++ b/tests/test_paged_cache.py @@ -20,6 +20,7 @@ FreeKVCacheBlockQueue, PagedCacheManager, compute_block_hash, + resolve_block_extra_keys, ) @@ -85,6 +86,48 @@ def test_chain_hash(self): assert len({hash1, hash2, hash3}) == 3 +class TestResolveBlockExtraKeys: + """Tests for segmented cache-key resolution.""" + + def test_returns_none_before_first_range(self): + """Blocks before the first multimodal boundary should be unsalted.""" + ranges = [ + (5, ("image-1",)), + (9, ("image-1", "image-2")), + ] + + assert resolve_block_extra_keys(4, extra_key_ranges=ranges) is None + + def test_selects_latest_matching_range(self): + """Blocks should use the latest applicable segmented cache key.""" + ranges = [ + (5, ("image-1",)), + (9, ("image-1", "image-2")), + ] + + assert resolve_block_extra_keys(8, extra_key_ranges=ranges) == ("image-1",) + assert resolve_block_extra_keys(12, extra_key_ranges=ranges) == ( + "image-1", + "image-2", + ) + + def test_ranges_take_precedence_over_legacy_extra_keys(self): + """Segmented cache keys should override the legacy whole-request salt.""" + ranges = [(5, ("image-1",))] + + assert resolve_block_extra_keys( + 8, + extra_keys=("legacy-image",), + extra_key_ranges=ranges, + ) == ("image-1",) + + def test_range_start_at_zero_applies_from_first_block(self): + """A first-image boundary at token 0 should salt the entire sequence.""" + ranges = [(0, ("image-1",))] + + assert resolve_block_extra_keys(4, extra_key_ranges=ranges) == ("image-1",) + + class TestCacheBlock: """Tests for CacheBlock dataclass.""" diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index c58c385a..ded2395f 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -1015,6 +1015,69 @@ def test_store_cache_exact_multiple_creates_all_blocks(self, mx): assert len(result.block_ids) == 2 assert result.num_tokens == 8 + def test_fetch_cache_with_segmented_extra_key_ranges(self): + """Later image changes should preserve reuse before their boundary.""" + block_size = 4 + paged_cache = PagedCacheManager( + block_size=block_size, + max_blocks=100, + model_name="test-model", + initial_blocks=100, + ) + model = MockModel(num_layers=1) + cache = BlockAwarePrefixCache( + model=model, + paged_cache_manager=paged_cache, + ) + + tokens = list(range(12)) + original_ranges = [ + (5, ("image-1",)), + (9, ("image-1", "image-2")), + ] + + stored = cache.store_cache( + "req-store", + tokens, + [], + extra_key_ranges=original_ranges, + ) + assert stored is not None + assert stored.num_tokens == 12 + + exact_table, exact_remaining = cache.fetch_cache( + "req-exact", + tokens, + extra_key_ranges=original_ranges, + ) + assert exact_table is not None + assert exact_table.num_tokens == 12 + assert exact_remaining == [] + + changed_later_image_table, changed_later_image_remaining = cache.fetch_cache( + "req-later-image", + tokens, + extra_key_ranges=[ + (5, ("image-1",)), + (9, ("image-1", "image-3")), + ], + ) + assert changed_later_image_table is not None + assert changed_later_image_table.num_tokens == 8 + assert changed_later_image_remaining == tokens[8:] + + changed_earlier_image_table, changed_earlier_image_remaining = cache.fetch_cache( + "req-earlier-image", + tokens, + extra_key_ranges=[ + (5, ("image-x",)), + (9, ("image-x", "image-2")), + ], + ) + assert changed_earlier_image_table is not None + assert changed_earlier_image_table.num_tokens == 4 + assert changed_earlier_image_remaining == tokens[4:] + def test_store_cache_with_existing_prefix_uses_global_cache_indices(self, mx): """Store new blocks from full-sequence cache slices after cache hit. diff --git a/tests/test_vlm_engine.py b/tests/test_vlm_engine.py index 6aaa3bf3..162a7eb1 100644 --- a/tests/test_vlm_engine.py +++ b/tests/test_vlm_engine.py @@ -160,15 +160,21 @@ def test_skips_when_tokens_not_in_vocab(self): assert getattr(tokenizer, "has_tool_calling", False) is False def test_skips_when_mlx_lm_not_available(self): - """ImportError from mlx_lm → silently skipped.""" + """When neither parser backend is available, injection is skipped.""" engine = _make_engine() tokenizer = MockVLMTokenizer( chat_template=" tool_call.name", vocab={"": 100, "": 101}, ) - with patch.dict("sys.modules", {"mlx_lm": None, "mlx_lm.tokenizer_utils": None}): - # Import will fail + with patch.dict( + "sys.modules", + { + "mlx_vlm.tool_parsers": None, + "mlx_lm": None, + "mlx_lm.tokenizer_utils": None, + }, + ): engine._inject_tool_calling(tokenizer) # Should not crash, attributes not set @@ -426,25 +432,44 @@ def test_deepcopy_no_mutation(self): class TestProcessChatMessages: """Tests for VLMBatchedEngine._process_chat_messages().""" - def test_text_only_uses_chat_template(self): - """Text-only messages → _apply_chat_template() called.""" + @patch("omlx.engine.vlm.extract_images_from_messages") + def test_text_only_uses_vlm_prepare_path(self, mock_extract): + """Text-only turns on a VLM model still use _prepare_vision_inputs().""" + text_msgs = [{"role": "user", "content": "Hello"}] + mock_extract.return_value = (text_msgs, []) + engine = _make_loaded_engine() - engine._apply_chat_template = MagicMock(return_value="") + engine._prepare_vision_inputs = MagicMock( + return_value=([1, 2, 3], None, None, None, 0, []) + ) messages = [{"role": "user", "content": "Hello"}] result = engine._process_chat_messages(messages, tools=None, kwargs={}) - prompt, vlm_embeds, vlm_kwargs, image_hash = result - assert prompt == "" + token_ids, vlm_embeds, vlm_kwargs, image_hash, image_cache_key_start, image_cache_key_ranges = result + assert token_ids == [1, 2, 3] assert vlm_embeds is None assert vlm_kwargs is None assert image_hash is None - engine._apply_chat_template.assert_called_once() + assert image_cache_key_start == 0 + assert image_cache_key_ranges == [] + engine._prepare_vision_inputs.assert_called_once_with( + text_msgs, + [], + chat_template_kwargs=None, + tools=None, + ) + + @patch("omlx.engine.vlm.extract_images_from_messages") + def test_text_only_passes_tools_to_prepare_vision(self, mock_extract): + """Text-only + tools still convert and pass tools through VLM path.""" + text_msgs = [{"role": "user", "content": "Hello"}] + mock_extract.return_value = (text_msgs, []) - def test_text_only_passes_tools(self): - """Text-only + tools → convert_tools_for_template() called.""" engine = _make_loaded_engine() - engine._apply_chat_template = MagicMock(return_value="") + engine._prepare_vision_inputs = MagicMock( + return_value=([1, 2, 3], None, None, None, 0, []) + ) tools = [{"type": "function", "function": {"name": "test", "parameters": {}}}] messages = [{"role": "user", "content": "Hello"}] @@ -454,6 +479,8 @@ def test_text_only_passes_tools(self): engine._process_chat_messages(messages, tools=tools, kwargs={}) mock_convert.assert_called_once_with(tools) + call_kwargs = engine._prepare_vision_inputs.call_args[1] + assert call_kwargs["tools"] == [{"converted": True}] @patch("omlx.engine.vlm.extract_images_from_messages") def test_image_path_calls_prepare_vision(self, mock_extract): @@ -467,7 +494,7 @@ def test_image_path_calls_prepare_vision(self, mock_extract): engine = _make_loaded_engine() engine._apply_ocr_prompt = MagicMock(return_value=text_msgs) engine._prepare_vision_inputs = MagicMock( - return_value=([1, 2, 3], MagicMock(), {}, "hash123") + return_value=([1, 2, 3], MagicMock(), {}, "hash123", 12, [(12, "hash123")]) ) messages = [{"role": "user", "content": [ @@ -478,9 +505,11 @@ def test_image_path_calls_prepare_vision(self, mock_extract): result = engine._process_chat_messages(messages, tools=None, kwargs={}) engine._prepare_vision_inputs.assert_called_once() - token_ids, vlm_embeds, vlm_kwargs, image_hash = result + token_ids, vlm_embeds, vlm_kwargs, image_hash, image_cache_key_start, image_cache_key_ranges = result assert token_ids == [1, 2, 3] assert image_hash == "hash123" + assert image_cache_key_start == 12 + assert image_cache_key_ranges == [(12, "hash123")] @patch("omlx.engine.vlm.extract_images_from_messages") def test_image_path_passes_tools(self, mock_extract): @@ -494,7 +523,7 @@ def test_image_path_passes_tools(self, mock_extract): engine = _make_loaded_engine() engine._apply_ocr_prompt = MagicMock(return_value=text_msgs) engine._prepare_vision_inputs = MagicMock( - return_value=([1, 2, 3], None, None, None) + return_value=([1, 2, 3], None, None, None, 0, []) ) tools = [{"type": "function", "function": {"name": "analyze", "parameters": {}}}] @@ -521,7 +550,7 @@ def test_image_path_without_tools(self, mock_extract): engine = _make_loaded_engine() engine._apply_ocr_prompt = MagicMock(return_value=text_msgs) engine._prepare_vision_inputs = MagicMock( - return_value=([1, 2, 3], None, None, None) + return_value=([1, 2, 3], None, None, None, 0, []) ) messages = [{"role": "user", "content": "Describe"}] @@ -651,10 +680,13 @@ def test_assigns_placeholder_to_late_user_image_turn(self): }, ] - formatted = engine._format_messages_for_vlm_template(messages, num_images=1) + formatted, image_ranges = engine._format_messages_for_vlm_template( + messages, num_images=1 + ) assert self._count_image_placeholders(formatted) == 1 assert self._count_image_placeholders([formatted[-1]]) == 1 + assert image_ranges == [(2, 1)] def test_caps_placeholders_by_loaded_image_count(self): """Do not add more placeholders than successfully loaded images.""" @@ -670,18 +702,24 @@ def test_caps_placeholders_by_loaded_image_count(self): }, ] - formatted = engine._format_messages_for_vlm_template(messages, num_images=1) + formatted, image_ranges = engine._format_messages_for_vlm_template( + messages, num_images=1 + ) assert self._count_image_placeholders(formatted) == 1 + assert image_ranges == [(0, 1)] def test_fallback_inserts_first_user_when_no_explicit_parts(self): """Legacy path: num_images without explicit image parts still injects once.""" engine = _make_loaded_engine(model_type="qwen3_5") messages = [{"role": "user", "content": "Describe this"}] - formatted = engine._format_messages_for_vlm_template(messages, num_images=1) + formatted, image_ranges = engine._format_messages_for_vlm_template( + messages, num_images=1 + ) assert self._count_image_placeholders(formatted) == 1 + assert image_ranges == [(0, 1)] # ---------------------------------------------------------------------------