From 00a70c4081921b91e0f401f1e87b7e845607a34b Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 25 Jun 2026 14:04:28 +0000 Subject: [PATCH] fix(prefix-cache): bypass prefix caching for multimodal sequences Prefix-cache block hashes are computed from token ids only. All image placeholder tokens share a single id (e.g. 248056), so blocks of image placeholders from *different* images hash identically and the token-id equality guard also passes -- a later image request reuses the wrong image's cached KV. The vision tower then produces N embeddings for the new image while only the uncached placeholders remain in the current forward, giving `inputs_embeds[mask] = vision_embeds` shape mismatches (e.g. [64,2048] vs [4,2048]) that kill the model-runner process; the HTTP layer survives so /health still returns 200 while requests hang. ATOM has no multimodal-aware cache hashing yet, so don't prefix-cache sequences that carry multimodal data: can_allocate returns no cache hits and hash_blocks registers no hashes for them. Text sequences are unaffected and keep full prefix caching. Validated on gfx1151: three different images in a row (previously the 2nd crashed) now all succeed with prefix caching enabled; text caching intact. A future enhancement could fold an mm content hash into the block hash to also cache identical images across turns. --- atom/model_engine/block_manager.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/atom/model_engine/block_manager.py b/atom/model_engine/block_manager.py index 4ea5c14948..d6f4298658 100644 --- a/atom/model_engine/block_manager.py +++ b/atom/model_engine/block_manager.py @@ -146,7 +146,12 @@ def can_allocate(self, seq: Sequence) -> int: # blocks. See `allocate()` for the budget reasoning. if seq.has_per_req_cache and not self.free_per_req_cache_groups: return -1 - if not self.enable_prefix_caching: + # Multimodal sequences bypass prefix caching: image-placeholder tokens + # all share a single id, so token-id hashing collides across different + # images and would reuse the wrong image's KV (vision-embeds vs + # placeholder count mismatch -> runner crash). ATOM has no mm-aware + # cache hashing yet, so simply don't prefix-cache multimodal seqs. + if not self.enable_prefix_caching or seq.multimodal_data is not None: if len(self.free_block_ids_set) < seq.num_blocks: return -1 return 0 @@ -224,7 +229,7 @@ def hash_blocks(self, seq: Sequence, num_new_tokens: int) -> None: single-shot prefill that's `seq.num_tokens - seq.num_cached_tokens`; chunked prefill will pass the per-chunk count. """ - if not self.enable_prefix_caching: + if not self.enable_prefix_caching or seq.multimodal_data is not None: return start = seq.num_cached_tokens // self.block_size end = (seq.num_cached_tokens + num_new_tokens) // self.block_size