[Feat][Attention] Add GQA paged prefill FP8 KV cache op #1268
superAngGao wants to merge 3 commits into tile-ai:main
Conversation
Code Review
This pull request introduces a new operation and kernel for Grouped Query Attention (GQA) prefill with a paged FP8 KV cache. The implementation includes the `GQAPrefillPagedWithFP8KVCacheFwdKernel`, the `GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp` operator, and associated unit tests. Review feedback highlights an opportunity to reduce code duplication in the kernel's quantization and dequantization logic by using `T.macro` for better maintainability.
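A minimal sketch of that `T.macro` suggestion, assuming tilelang's `T.macro`/`T.Parallel` primitives are usable this way in the kernel; the `dequant_tile` name, buffer arguments, and shapes are illustrative, not the PR's actual code:

```python
# Hedged sketch: factor the repeated FP8 dequant loop into a T.macro so the
# fast-path and fallback branches (for both K and V) can share one body.
import tilelang.language as T

@T.macro
def dequant_tile(dst, src, scale, block_n, dim, out_dtype):
    # fp8 -> float32, apply the per-tensor scale, cast to the attention dtype
    for i, j in T.Parallel(block_n, dim):
        dst[i, j] = (src[i, j].astype("float32") * scale[0]).astype(out_dtype)
```

Invoked from each branch, this would collapse the near-identical dequant loops the review points at into a single definition.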
Gabbering
left a comment
goose review — 570a2242
honk. A clean structural copy-paste from the non-FP8 paged kernel with FP8 quant/dequant bolted on. Mostly correct. Mostly. One real bug, one perf concern, one test gap.
Bugs
- `tileops/kernels/attention/gqa_fwd.py:2199` — The `page_size % block_n == 0` fast path for reading old cache pages assumes the entire `block_n` tile sits inside a single logical page. That's only true when `tile_start` is page-aligned or `page_offset + block_n <= page_size`. The code only checks `page_size % block_n == 0`, then computes `physical_start` from `tile_start`'s page. If `tile_start` could land mid-page with `block_n == page_size`, it would read past the physical page boundary into an adjacent (wrong) physical page. The non-FP8 sibling kernel at line ~1858 has the exact same latent structure, so this was faithfully copied. However: for the default config (`block_n=64`, `page_size=64`) the condition `page_size % block_n == 0` is true and `tile_end <= old_len` guarantees `tile_start` is a multiple of `block_n` (== `page_size`), so it happens to be safe for `page_size == block_n`. Does it break if someone autotunes to `block_n=32` with `page_size=64` and `old_len` isn't block_n-aligned? Working the example: `tile_start` can be 32, `page_idx` would be 0, `page_offset` 32, and you'd read indices 32..63 from that physical page — which is correct. Actually wait — `tile_start = k_idx * block_n`, so it's always a multiple of `block_n`, and since `page_size % block_n == 0`, `page_offset + block_n <= page_size` always holds (see the sketch after this list). The goose retracts its fangs on this one. Carry on.
- `tileops/kernels/attention/gqa_fwd.py:2159` (FP8 append, `page_size % block_n` fast path) — There is no synchronisation barrier between the FP8 cache append (guarded by `by < heads_kv`) and the attention read loop that reads from `k_pages`/`v_pages`. All `heads` CTAs in the same `(bx, bz)` group share the same `k_pages`/`v_pages` global memory. CTAs with `by >= heads_kv` skip the append and fall straight into the attention loop, where the boundary tile (`tile_start < old_len < tile_end` or `tile_start >= old_len`) reads `k_new`/`v_new` directly — so those CTAs don't actually read the freshly written FP8 data for the current chunk. The old-cache tiles (`tile_end <= old_len`) only read pages written by previous calls, so there's no WAR hazard there either. The non-FP8 sibling has exactly the same structure and ships, so this is architecturally consistent and not a new bug introduced here. Fine. Actually — let me look harder. For the old-cache dequant path the kernel reads `k_pages[..., cur_kv_head, ...]`, which is data written by a previous kernel launch. No hazard. For the new-chunk path (`tile_start >= old_len`) it reads from `k_new`/`v_new` (read-only inputs). For the boundary path it conditionally reads either old cache or `k_new`/`v_new`. So the freshly appended FP8 data in `k_pages` is never consumed within the same launch. The goose is satisfied. No sync bug. However, there IS a subtle data-race candidate in the append itself: multiple CTAs with the same `(bz, by)` but different `bx` values all perform the append for their own `block_m` tile of new tokens, which is fine — they write disjoint positions. And within a single CTA, if `by < heads_kv` and `by != cur_kv_head` (which is `by // groups`), the append writes to head `by` while the attention reads from head `cur_kv_head`. These are different head indices when `groups > 1`, so different memory locations. No race. Good. The goose found no actual bugs in the kernel logic. Disappointing but correct.
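A quick sanity check of the alignment argument from the first item above, as plain Python; `tile_stays_in_page` is an illustrative helper, not kernel code:

```python
# Hedged sanity check: when page_size % block_n == 0 and tile_start is always
# a multiple of block_n, a block_n-wide tile never crosses a page boundary.
def tile_stays_in_page(k_idx: int, block_n: int, page_size: int) -> bool:
    assert page_size % block_n == 0
    tile_start = k_idx * block_n          # tile_start = k_idx * block_n, per the kernel
    page_offset = tile_start % page_size  # offset of the tile inside its logical page
    return page_offset + block_n <= page_size

# Holds for both the default config and the autotuned case from the review.
assert all(tile_stays_in_page(k, 64, 64) for k in range(1024))
assert all(tile_stays_in_page(k, 32, 64) for k in range(1024))
```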
Performance
`tileops/kernels/attention/gqa_fwd.py:2148-2210` — Every element in the old-cache dequant path does two casts and a multiply per element, per KV: fp8 → float32 → multiply → cast to dtype. For the `page_size % block_n == 0` fast path, this is `2 * block_n * dim` casts + `block_n * dim` multiplies, repeated for both K and V. The scale values `k_scale[0]` and `v_scale[0]` are loaded from global memory inside the innermost parallel loop, relying on the compiler to hoist them to registers. If tilelang doesn't hoist scalar loads out of `T.Parallel` (it usually does, but worth verifying), every thread re-reads from global memory. Consider loading `k_scale[0]` and `v_scale[0]` into a local variable before the pipelined loop, as sketched below. This is a nit on a brand-new kernel path, but for long-prefix scenarios (e.g. the 32k-prefix workload in the manifest) the dequant overhead on old pages dominates, so it matters.
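A minimal sketch of the suggested hoist, assuming tilelang's `T.alloc_local`/`T.Pipelined`/`T.Parallel` primitives; the buffer names, loop bounds, and stage count are illustrative, not the kernel's actual code:

```python
# Hedged sketch: read the per-tensor scales into registers once, before the
# pipelined loop, instead of re-reading k_scale[0]/v_scale[0] per element.
k_scale_local = T.alloc_local([1], "float32")
v_scale_local = T.alloc_local([1], "float32")
k_scale_local[0] = k_scale[0]
v_scale_local[0] = v_scale[0]

for ko in T.Pipelined(num_kv_tiles, num_stages=2):  # illustrative bounds/stages
    for i, j in T.Parallel(block_n, dim):
        # fp8 -> float32 -> scale -> attention dtype; the scale now comes
        # from a register instead of a per-element global load
        K_shared[i, j] = (k_page_tile[i, j].astype("float32")
                          * k_scale_local[0]).astype(dtype)
        V_shared[i, j] = (v_page_tile[i, j].astype("float32")
                          * v_scale_local[0]).astype(dtype)
```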
Test gaps
- `tests/ops/attention/test_gqa_prefill_paged.py:213-312` — The FP8 test only exercises `dtype=float16`. The manifest declares `q: {dtype: "float16 | bfloat16"}`, and the Op's `_validate_attention_dtype` accepts bfloat16, but there is zero test coverage for `bfloat16 + float8_e4m3fn` cache. The existing non-FP8 tests parametrize over both dtypes. Add at least one bfloat16 case — FP8 dequant → bf16 cast has different rounding behaviour than → fp16, and the tolerance may need adjusting.
- `tests/ops/attention/test_gqa_prefill_paged.py:213-312` — All three FP8 test cases use `page_size=64, block_n=64` (the default config). The non-FP8 tests have a dedicated `test_gqa_prefill_paged_with_kv_cache_page_sizes` that exercises `page_size ∈ {16, 32, 128}` to hit the `block_n % page_size == 0` and generic fallback code paths in the kernel. The FP8 kernel has identical branching for those paths (with added dequant), but none of them are tested. A `page_size=16` or `page_size=128` case would exercise the `block_n % page_size` and per-element fallback dequant paths respectively. A parametrization along the lines of the sketch below would close both gaps.
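A hedged sketch of such a parametrization; `run_fp8_paged_prefill_case` is a hypothetical stand-in for whatever harness the existing FP8 test uses, and the tolerances are illustrative:

```python
# Hedged sketch: extend FP8 coverage across attention dtypes and page sizes.
import pytest
import torch

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("page_size", [16, 64, 128])
def test_gqa_prefill_paged_fp8_dtypes_and_page_sizes(dtype, page_size):
    # bf16 has fewer mantissa bits than fp16, so the FP8 dequant -> bf16 cast
    # rounds more coarsely and may need a looser tolerance.
    atol = 2e-2 if dtype is torch.bfloat16 else 1e-2  # illustrative values
    run_fp8_paged_prefill_case(dtype=dtype, page_size=page_size, atol=atol)
```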
Updated in 79dd506. Changes:
Validation:
Ibuki-wind
left a comment
Overall
Two blocking issues remain in the changed API surface; address the inline comments before the next round.
```python
if tensor.dtype != self.cache_dtype:
    raise ValueError(
        f"Expected {name}.dtype {self.cache_dtype}, got {tensor.dtype}")
for name, tensor in [("k_scale", k_scale), ("v_scale", v_scale)]:
```
`GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp._validate_forward_inputs` only checks `k_scale`/`v_scale` shape and dtype, while the kernel divides by those values in `quantize_fp8`; zero or non-finite scales will silently produce a saturated/NaN FP8 cache and output. Reject non-finite or `<= 0` `k_scale`/`v_scale` values before launching the kernel, and add a focused validation test.
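A minimal sketch of such a guard, assuming the scales arrive as torch tensors of shape `(1,)`; the `_validate_scale` helper is illustrative, not the Op's actual code:

```python
# Hedged sketch: reject degenerate scales before kernel launch, since the
# kernel divides by them in quantize_fp8.
import math
import torch

def _validate_scale(name: str, scale: torch.Tensor) -> None:
    value = float(scale.item())  # per-tensor scale of shape (1,)
    if not math.isfinite(value) or value <= 0:
        raise ValueError(
            f"{name} must be a finite positive scalar, got {value}")
```

Inside `_validate_forward_inputs` this would slot into the existing `for name, tensor in [("k_scale", k_scale), ("v_scale", v_scale)]` loop shown in the diff above.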
Summary
Add a storage-only FP8 KV cache variant for paged GQA prefill:
- New op: `GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp`.
- `q`, `k_new`, and `v_new` stay in the attention dtype for the current chunk.
- The paged KV cache is stored as `float8_e4m3fn`.
- Per-tensor `k_scale`/`v_scale` tensors of shape `(1,)`.

This is the storage-only path. Fused RoPE + FP8 append and FP8 tensor-core attention compute are intentionally left as follow-up kernel variants.
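For reference, a hedged PyTorch sketch of the storage-only round trip described above (quantize on append, dequantize on read, per-tensor scale); the helper names are illustrative, not the kernel's actual functions:

```python
# Hedged sketch of the FP8 storage round trip; mirrors the fp8 -> float32 ->
# multiply -> cast dequant path the review describes.
import torch

def quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Store x / scale as float8_e4m3fn; out-of-range values saturate.
    return (x / scale).to(torch.float8_e4m3fn)

def dequantize_fp8(cache: torch.Tensor, scale: torch.Tensor,
                   dtype: torch.dtype) -> torch.Tensor:
    return (cache.to(torch.float32) * scale).to(dtype)

k_scale = torch.tensor([0.05], dtype=torch.float32)
k = torch.randn(64, 128, dtype=torch.float16)
k_back = dequantize_fp8(quantize_fp8(k, k_scale), k_scale, torch.float16)
```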
Tests
Ran with the nightly test image
tileops-runner:nightly-tl019-fullstack-no-tileops-ldfix on H200: