Skip to content

Route prefix-cache-hit prefill through sink ASM MHA kernel#1345

Merged
valarLip merged 1 commit into
mainfrom
yihonglie/asm-prefix-prefill
Jun 25, 2026
Merged

Route prefix-cache-hit prefill through sink ASM MHA kernel#1345
valarLip merged 1 commit into
mainfrom
yihonglie/asm-prefix-prefill

Conversation

@yhl-amd

@yhl-amd yhl-amd commented Jun 24, 2026

Copy link
Copy Markdown
Contributor

On gfx1250 with ATOM_USE_UNIFIED_ATTN, a prefix-cache hit during prefill fell back to the Triton unified_attention path instead of the sink ASM varlen kernel, because _can_attempt_prefill_sink_asm bailed on has_cached and on max_seqlen_q != max_seqlen_k.

The gfx1250 sink varlen ASM kernel (fmha_fwd_with_sink_varlen_asm) actually handles bottom-right causal for sq < sk (chunked-prefill); rope_cache already gathers cached+new KV into a dense packed [total_kv, ...] tensor that the kernel consumes, and cu_seqlens_q/cu_seqlens_k carry the per-request new-token vs cached+new lengths. Verified on gfx1250 against a bottom-right causal + per-head sink reference (single/multi-batch, GQA, sq=1) within bf16 tolerance.

Changes:

  • _can_attempt_prefill_sink_asm: drop the has_cached and max_seqlen_q == max_seqlen_k gates.
  • rope_cache: decide the prefill backend before gathering (q/k/v still have matching token counts) and stash it; gather only when the ASM path will consume it. The Triton path reads the paged KV cache directly via block_table, so skip the now-wasted prefix gather there.
  • dispatch_backend: reuse the stashed decision instead of re-deriving it from the gathered (sq != sk) tensors, which would fail the q.shape[0] == k.shape[0] check.

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

On gfx1250 with ATOM_USE_UNIFIED_ATTN, a prefix-cache hit during prefill
fell back to the Triton unified_attention path instead of the sink ASM
varlen kernel, because _can_attempt_prefill_sink_asm bailed on has_cached
and on max_seqlen_q != max_seqlen_k.

The gfx1250 sink varlen ASM kernel (fmha_fwd_with_sink_varlen_asm) actually
handles bottom-right causal for sq < sk (chunked-prefill), and cu_seqlens_q/
cu_seqlens_k already carry the per-request new-token vs cached+new lengths.
Verified on gfx1250 against a bottom-right causal + per-head sink reference
(single/multi-batch, GQA, sq=1) within bf16 tolerance, and end-to-end on
gpt-oss-120b (full-attention layers take the ASM path on a cache hit; the
forced-Triton path never gathers).

Changes:
- _can_attempt_prefill_sink_asm: drop the has_cached and
  max_seqlen_q == max_seqlen_k gates.
- prefill_attention: gather the cached+new KV into a dense packed tensor here,
  where the ASM varlen kernel consumes it. Each prefill backend now prepares
  its own KV: the ASM path gathers; the Triton path reads the paged cache
  directly via block_table and never gathers.
- rope_cache: no longer gathers, so dispatch_backend sees q/k with matching
  token counts (sq == sk) and _can_use_prefill_sink_asm's shape check stays
  valid.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@HaonanWang98 HaonanWang98 force-pushed the yihonglie/asm-prefix-prefill branch from bb24447 to be3520a Compare June 25, 2026 09:18
@valarLip valarLip merged commit 57215ec into main Jun 25, 2026
24 of 32 checks passed
@valarLip valarLip deleted the yihonglie/asm-prefix-prefill branch June 25, 2026 12:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants