fix(prefix-cache): bypass prefix caching for multimodal sequences#1358

Open
carlushuang wants to merge 1 commit into main from carhuang/mm_prefix_cache_bypass
@carlushuang

Collaborator

Prefix-cache block hashes (BlockManager.compute_hash) are computed from token ids only. All image-placeholder tokens share a single id (e.g. 248056), so blocks of placeholders from different images hash identically. The token-id equality guard (the token_ids != token_ids comparison between the cached block and the incoming block) also passes, since the ids really are identical. A later image request therefore matches a cached prefix built from an earlier, different image and reuses its KV.
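
To make the collision concrete, here is a minimal sketch. The `compute_hash` below is a toy stand-in written for illustration (SHA-256 over packed token ids), not ATOM's actual `BlockManager.compute_hash`, and `BLOCK_SIZE` is a hypothetical value chosen for readability.

```python
import hashlib
import struct

IMAGE_PLACEHOLDER_ID = 248056  # example id quoted in the description
BLOCK_SIZE = 4                 # hypothetical; kept small for readability


def compute_hash(token_ids, prefix_hash=b""):
    """Toy token-id-only block hash (illustrative, not ATOM's implementation)."""
    h = hashlib.sha256()
    h.update(prefix_hash)
    h.update(struct.pack(f"<{len(token_ids)}q", *token_ids))
    return h.digest()


# Two requests carrying different images expand to the same placeholder ids.
block_image_a = [IMAGE_PLACEHOLDER_ID] * BLOCK_SIZE
block_image_b = [IMAGE_PLACEHOLDER_ID] * BLOCK_SIZE

# The hashes are identical because only token ids feed the hash.
hashes_collide = compute_hash(block_image_a) == compute_hash(block_image_b)

# The equality guard cannot tell the blocks apart either.
guard_rejects = block_image_a != block_image_b

print(hashes_collide, guard_rejects)
```

Both checks agree that the blocks are "the same", which is why the second image is treated as a cache hit.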

The vision tower then encodes N embeddings for the new image while only the uncached placeholders remain in the current forward, so the merge

inputs_embeds[mask] = vision_embeds      # e.g. 64 embeds, 4 placeholder slots
RuntimeError: shape mismatch [64, 2048] vs [4, 2048]

kills the model-runner process. The HTTP layer survives, so /health keeps returning 200 while every subsequent request hangs (silent stall). A single/cold image works; the crash only appears on the second image that shares a cached prefix.
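
The arithmetic behind the mismatch can be sketched as follows. The prompt layout, the `num_cached_tokens` value and both helper names are hypothetical, chosen only to reproduce the 64-versus-4 numbers in the traceback above; this is not code from the PR.

```python
PLACEHOLDER_ID = 248056  # example id quoted in the description


def remaining_placeholder_slots(token_ids, num_cached_tokens, placeholder_id):
    """Count placeholder positions that still run in the current forward pass."""
    return sum(1 for t in token_ids[num_cached_tokens:] if t == placeholder_id)


def merge_shapes_match(num_slots, num_vision_embeds):
    """True when every vision embedding has exactly one placeholder slot."""
    return num_slots == num_vision_embeds


# Hypothetical prompt: 2 text tokens, 64 image placeholders, 2 text tokens.
prompt = [1, 2] + [PLACEHOLDER_ID] * 64 + [3, 4]
num_vision_embeds = 64  # the vision tower encodes the whole new image

# Cold request: nothing cached, all 64 placeholders are in the forward pass.
cold_slots = remaining_placeholder_slots(prompt, 0, PLACEHOLDER_ID)

# Wrong cache hit: 62 tokens (2 text + 60 placeholders) are served from another
# image's KV, so only 4 placeholder slots remain for 64 embeddings.
warm_slots = remaining_placeholder_slots(prompt, 62, PLACEHOLDER_ID)

print(cold_slots, merge_shapes_match(cold_slots, num_vision_embeds))
print(warm_slots, merge_shapes_match(warm_slots, num_vision_embeds))
```

This also shows why a cold image works: with no cached prefix, the slot count and the embedding count agree.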

Fix

ATOM has no multimodal-aware cache hashing yet, so don't prefix-cache sequences that carry multimodal data:

  • can_allocate returns no cache hits for them, and
  • hash_blocks registers no hashes for them.

Text sequences are unaffected and keep full prefix caching.
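
The shape of the bypass can be sketched with a toy block manager. Everything here is an assumption made for illustration: `ToyBlockManager`, `Sequence.mm_inputs`, `has_multimodal` and `cached_blocks` are hypothetical names, and `cached_blocks` only stands in for the cache-hit portion of `can_allocate`. The actual change lives in the PR diff.

```python
import hashlib
from dataclasses import dataclass, field


@dataclass
class Sequence:
    token_ids: list
    mm_inputs: list = field(default_factory=list)  # hypothetical field name

    @property
    def has_multimodal(self):
        return bool(self.mm_inputs)


class ToyBlockManager:
    def __init__(self, block_size=4):
        self.block_size = block_size
        self.hash_to_block = {}  # chained block hash -> block id

    def _block_hashes(self, seq):
        """Yield a chained hash for each full block of token ids."""
        prefix = b""
        for i in range(len(seq.token_ids) // self.block_size):
            block = seq.token_ids[i * self.block_size:(i + 1) * self.block_size]
            prefix = hashlib.sha256(prefix + repr(block).encode()).digest()
            yield prefix

    def cached_blocks(self, seq):
        """Number of leading blocks served from cache (0 for multimodal)."""
        if seq.has_multimodal:
            return 0  # bypass: never report cache hits
        hits = 0
        for h in self._block_hashes(seq):
            if h not in self.hash_to_block:
                break
            hits += 1
        return hits

    def hash_blocks(self, seq):
        """Register block hashes for later reuse (skipped for multimodal)."""
        if seq.has_multimodal:
            return  # bypass: never register hashes
        for h in self._block_hashes(seq):
            self.hash_to_block.setdefault(h, len(self.hash_to_block))


mgr = ToyBlockManager()
image_tokens = [9] + [248056] * 7
mgr.hash_blocks(Sequence([1, 2, 3, 4, 5, 6, 7, 8]))
mgr.hash_blocks(Sequence(image_tokens, mm_inputs=["image-a"]))

text_hits = mgr.cached_blocks(Sequence([1, 2, 3, 4, 5, 6, 7, 8]))
image_hits = mgr.cached_blocks(Sequence(image_tokens, mm_inputs=["image-b"]))
print(text_hits, image_hits, len(mgr.hash_to_block))
```

Guarding both the lookup and the registration matters: skipping only the lookup would still leave multimodal hashes in the table for other sequences to match.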

Validation (gfx1151, prefix caching enabled)

Three different images in a row — previously the 2nd crashed the runner — now all succeed and are identified correctly (Red/Blue/Green); plain text completion still works (text caching intact).

Follow-up

A future enhancement could fold an mm content hash into the block hash to also cache identical images across turns (multi-turn VL chat), rather than skipping caching entirely.
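
One possible shape for that follow-up, as a rough sketch only: mix a digest of the image content into the block hash. The `mm_digest` parameter and `image_digest` helper are hypothetical, the hash is a toy stand-in rather than ATOM's `compute_hash`, and a real implementation would also need to decide which blocks overlap which image and extend the equality guard accordingly.

```python
import hashlib
import struct


def image_digest(image_bytes):
    """Content hash of the raw image payload (hypothetical helper)."""
    return hashlib.sha256(image_bytes).digest()


def compute_hash(token_ids, prefix_hash=b"", mm_digest=b""):
    """Toy block hash over token ids plus an optional multimodal digest."""
    h = hashlib.sha256()
    h.update(prefix_hash)
    h.update(struct.pack(f"<{len(token_ids)}q", *token_ids))
    h.update(mm_digest)  # empty for text-only blocks, so their hashes are unchanged
    return h.digest()


placeholders = [248056] * 4
red = image_digest(b"red image bytes")
blue = image_digest(b"blue image bytes")

# The same image hashes the same way, so it could be reused across turns.
same_image_matches = (
    compute_hash(placeholders, mm_digest=red)
    == compute_hash(placeholders, mm_digest=red)
)

# Different images no longer collide.
different_images_differ = (
    compute_hash(placeholders, mm_digest=red)
    != compute_hash(placeholders, mm_digest=blue)
)

# Text-only blocks hash exactly as before.
text_block_unchanged = compute_hash([1, 2, 3, 4]) == compute_hash(
    [1, 2, 3, 4], mm_digest=b""
)
print(same_image_matches, different_images_differ, text_block_unchanged)
```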

Prefix-cache block hashes are computed from token ids only. All image
placeholder tokens share a single id (e.g. 248056), so blocks of image
placeholders from *different* images hash identically and the token-id
equality guard also passes -- a later image request reuses the wrong
image's cached KV. The vision tower then produces N embeddings for the new
image while only the uncached placeholders remain in the current forward,
giving `inputs_embeds[mask] = vision_embeds` shape mismatches
(e.g. [64,2048] vs [4,2048]) that kill the model-runner process; the HTTP
layer survives so /health still returns 200 while requests hang.

ATOM has no multimodal-aware cache hashing yet, so don't prefix-cache
sequences that carry multimodal data: can_allocate returns no cache hits
and hash_blocks registers no hashes for them. Text sequences are
unaffected and keep full prefix caching.

Validated on gfx1151: three different images in a row (previously the 2nd
crashed) now all succeed with prefix caching enabled; text caching intact.

A future enhancement could fold an mm content hash into the block hash to
also cache identical images across turns.