
Make FlashInfer cache dtype hardware-aware (fix bf16 on SM<80) #56

Open
os-gabe wants to merge 1 commit into Robbyant:main from os-gabe:flashinfer-fp16-sm75

Conversation


os-gabe commented Apr 29, 2026

FlashInferKVCacheManager unconditionally mapped fp32 -> bf16, which fails at runtime on SM<80 (Turing/Volta, e.g. Titan RTX), where bf16 kernels aren't available. AggregatorStream also passed tokens.dtype, which autocast-exempt ops such as LayerNorm leak as fp32, so the bug fires even when demo.py selects fp16.
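
A minimal repro of the dtype leak in stock PyTorch (names below are illustrative, not from this patch):

```python
import torch
import torch.nn as nn

ln = nn.LayerNorm(16).cuda()
x = torch.randn(2, 16, device="cuda")
with torch.autocast("cuda", dtype=torch.float16):
    y = ln(x)
# layer_norm is on autocast's float32 list, so its output stays fp32
# even under fp16 autocast -- tokens.dtype reports torch.float32 and
# the cache manager then remaps it to bf16.
print(y.dtype)  # torch.float32
```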

  • flashinfer_cache: hardware-aware fp32/None fallback (bf16 only on SM>=80); see the sketch after this list.
  • aggregator/stream: prefer the aggregator's parameter dtype before falling through to the cache's default.

SM>=80 behavior is unchanged.
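
A minimal sketch of both changes, assuming a torch-based helper; `flashinfer_cache_dtype` and `stream_dtype` are hypothetical names for illustration, not the patch's actual API:

```python
import torch

def flashinfer_cache_dtype(requested):
    """Only remap fp32 -> bf16 when the GPU has bf16 kernels (SM >= 80)."""
    if requested is not torch.float32:
        return requested
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:  # Ampere and newer: bf16 kernels available
        return torch.bfloat16
    # Turing/Volta: keep fp32 (or None, deferring to the cache default).
    return None

def stream_dtype(aggregator, cache_default):
    """Prefer the aggregator's parameter dtype over tokens.dtype, which
    autocast-exempt ops (e.g. LayerNorm) can leak as fp32."""
    try:
        return next(aggregator.parameters()).dtype
    except StopIteration:  # parameterless aggregator
        return cache_default
```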

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
